frame.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340
  1. import struct
  2. import socket
  3. from os import urandom
  4. from string import printable
  5. OPCODE_CONTINUATION = 0x0
  6. OPCODE_TEXT = 0x1
  7. OPCODE_BINARY = 0x2
  8. OPCODE_CLOSE = 0x8
  9. OPCODE_PING = 0x9
  10. OPCODE_PONG = 0xA
  11. CLOSE_NORMAL = 1000
  12. CLOSE_GOING_AWAY = 1001
  13. CLOSE_PROTOCOL_ERROR = 1002
  14. CLOSE_NOACCEPT_DTYPE = 1003
  15. CLOSE_INVALID_DATA = 1007
  16. CLOSE_POLICY = 1008
  17. CLOSE_MESSAGE_TOOBIG = 1009
  18. CLOSE_MISSING_EXTENSIONS = 1010
  19. CLOSE_UNABLE = 1011
  20. line_printable = [c for c in printable if c not in '\r\n\x0b\x0c']
  21. def printstr(s):
  22. return ''.join(c if c in line_printable else '.' for c in str(s))
  23. class Frame(object):
  24. """
  25. A Frame instance represents a web socket data frame as defined in RFC 6455.
  26. To encoding a frame for sending it over a socket, use Frame.pack(). To
  27. receive and decode a frame from a socket, use receive_frame().
  28. """
  29. def __init__(self, opcode, payload, masking_key='', mask=False, final=True,
  30. rsv1=False, rsv2=False, rsv3=False):
  31. """
  32. Create a new frame.
  33. `opcode` is one of the constants as defined above.
  34. `payload` is a string of bytes containing the data sendt in the frame.
  35. `masking_key` is an optional custom key to use for masking, or `mask`
  36. can be used instead to let this constructor generate a random masking
  37. key.
  38. `final` is a boolean indicating whether this frame is the last in a
  39. chain of fragments.
  40. `rsv1`, `rsv2` and `rsv3` are booleans indicating bit values for RSV1,
  41. RVS2 and RSV3, which are only non-zero if defined so by extensions.
  42. """
  43. if mask:
  44. masking_key = urandom(4)
  45. if len(masking_key) not in (0, 4):
  46. raise ValueError('invalid masking key "%s"' % masking_key)
  47. self.final = final
  48. self.rsv1 = rsv1
  49. self.rsv2 = rsv2
  50. self.rsv3 = rsv3
  51. self.opcode = opcode
  52. self.masking_key = masking_key
  53. self.payload = payload
  54. def pack(self):
  55. """
  56. Pack the frame into a string according to the following scheme:
  57. +-+-+-+-+-------+-+-------------+-------------------------------+
  58. |F|R|R|R| opcode|M| Payload len | Extended payload length |
  59. |I|S|S|S| (4) |A| (7) | (16/64) |
  60. |N|V|V|V| |S| | (if payload len==126/127) |
  61. | |1|2|3| |K| | |
  62. +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
  63. | Extended payload length continued, if payload len == 127 |
  64. + - - - - - - - - - - - - - - - +-------------------------------+
  65. | |Masking-key, if MASK set to 1 |
  66. +-------------------------------+-------------------------------+
  67. | Masking-key (continued) | Payload Data |
  68. +-------------------------------- - - - - - - - - - - - - - - - +
  69. : Payload Data continued ... :
  70. + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
  71. | Payload Data continued ... |
  72. +---------------------------------------------------------------+
  73. """
  74. header = struct.pack('!B', (self.final << 7) | (self.rsv1 << 6)
  75. | (self.rsv2 << 5) | (self.rsv3 << 4)
  76. | (self.opcode & 0xf))
  77. mask = bool(self.masking_key) << 7
  78. payload_len = len(self.payload)
  79. if payload_len <= 125:
  80. header += struct.pack('!B', mask | payload_len)
  81. elif payload_len < (1 << 16):
  82. header += struct.pack('!BH', mask | 126, payload_len)
  83. elif payload_len < (1 << 63):
  84. header += struct.pack('!BQ', mask | 127, payload_len)
  85. else:
  86. # FIXME: RFC 6455 defines an action for this...
  87. raise Exception('the payload length is too damn high!')
  88. if mask:
  89. return header + self.masking_key + self.mask_payload()
  90. return header + self.payload
  91. def mask_payload(self):
  92. return mask(self.masking_key, self.payload)
  93. def fragment(self, fragment_size, mask=False):
  94. """
  95. Fragment the frame into a chain of fragment frames:
  96. - An initial frame with non-zero opcode
  97. - Zero or more frames with opcode = 0 and final = False
  98. - A final frame with opcode = 0 and final = True
  99. The first and last frame may be the same frame, having a non-zero
  100. opcode and final = True. Thus, this function returns a list containing
  101. at least a single frame.
  102. `fragment_size` indicates the maximum payload size of each fragment.
  103. The payload of the original frame is split into one or more parts, and
  104. each part is converted to a Frame instance.
  105. `mask` is a boolean (default False) indicating whether the payloads
  106. should be masked. If True, each frame is assigned a randomly generated
  107. masking key.
  108. """
  109. frames = []
  110. for start in xrange(0, len(self.payload), fragment_size):
  111. payload = self.payload[start:start + fragment_size]
  112. frames.append(Frame(OPCODE_CONTINUATION, payload, mask=mask,
  113. final=False))
  114. frames[0].opcode = self.opcode
  115. frames[-1].final = True
  116. return frames
  117. def is_fragmented(self):
  118. return not self.final or self.opcode == OPCODE_CONTINUATION
  119. def __str__(self):
  120. s = '<%s opcode=0x%X len=%d' \
  121. % (self.__class__.__name__, self.opcode, len(self.payload))
  122. if self.masking_key:
  123. s += ' masking_key=%4s' % printstr(self.masking_key)
  124. max_pl_disp = 30
  125. pl = printstr(self.payload)[:max_pl_disp]
  126. if len(self.payload) > max_pl_disp:
  127. pl += '...'
  128. s += ' payload=%s' % pl
  129. if self.rsv1:
  130. s += ' rsv1'
  131. if self.rsv2:
  132. s += ' rsv2'
  133. if self.rsv3:
  134. s += ' rsv3'
  135. return s + '>'
  136. class ControlFrame(Frame):
  137. """
  138. A control frame is a frame with an opcode OPCODE_CLOSE, OPCODE_PING or
  139. OPCODE_PONG. These frames must be handled as defined by RFC 6455, and
  140. """
  141. def fragment(self, fragment_size, mask=False):
  142. """
  143. Control frames must not be fragmented.
  144. """
  145. raise TypeError('control frames must not be fragmented')
  146. def pack(self):
  147. """
  148. Same as Frame.pack(), but asserts that the payload size does not exceed
  149. 125 bytes.
  150. """
  151. if len(self.payload) > 125:
  152. raise ValueError('control frames must not be larger than 125 '
  153. 'bytes')
  154. return Frame.pack(self)
  155. def unpack_close(self):
  156. """
  157. Unpack a close message into a status code and a reason. If no payload
  158. is given, the code is None and the reason is an empty string.
  159. """
  160. if self.payload:
  161. code = struct.unpack('!H', str(self.payload[:2]))[0]
  162. reason = str(self.payload[2:])
  163. else:
  164. code = None
  165. reason = ''
  166. return code, reason
  167. def decode_frame(reader):
  168. b1, b2 = struct.unpack('!BB', reader.readn(2))
  169. final = bool(b1 & 0x80)
  170. rsv1 = bool(b1 & 0x40)
  171. rsv2 = bool(b1 & 0x20)
  172. rsv3 = bool(b1 & 0x10)
  173. opcode = b1 & 0x0F
  174. masked = bool(b2 & 0x80)
  175. payload_len = b2 & 0x7F
  176. if payload_len == 126:
  177. payload_len = struct.unpack('!H', reader.readn(2))
  178. elif payload_len == 127:
  179. payload_len = struct.unpack('!Q', reader.readn(8))
  180. if masked:
  181. masking_key = reader.readn(4)
  182. payload = mask(masking_key, reader.readn(payload_len))
  183. else:
  184. masking_key = ''
  185. payload = reader.readn(payload_len)
  186. # Control frames have most significant bit 1
  187. cls = ControlFrame if opcode & 0x8 else Frame
  188. return cls(opcode, payload, masking_key=masking_key, final=final,
  189. rsv1=rsv1, rsv2=rsv2, rsv3=rsv3)
  190. def receive_frame(sock):
  191. return decode_frame(SocketReader(sock))
  192. def read_frame(data):
  193. reader = BufferReader(data)
  194. frame = decode_frame(reader)
  195. return frame, reader.offset
  196. def pop_frame(data):
  197. frame, size = read_frame(data)
  198. return frame, data[size:]
  199. class BufferReader(object):
  200. def __init__(self, data):
  201. self.data = data
  202. self.offset = 0
  203. def readn(self, n):
  204. assert len(self.data) - self.offset >= n
  205. self.offset += n
  206. return self.data[self.offset - n:self.offset]
  207. class SocketReader(object):
  208. def __init__(self, sock):
  209. self.sock = sock
  210. def readn(self, n):
  211. """
  212. Keep receiving data until exactly `n` bytes have been read.
  213. """
  214. data = ''
  215. while len(data) < n:
  216. received = self.sock.recv(n - len(data))
  217. if not len(received):
  218. raise socket.error('no data read from socket')
  219. data += received
  220. return data
  221. def contains_frame(data):
  222. """
  223. Read the frame length from the start of `data` and check if the data is
  224. long enough to contain the entire frame.
  225. """
  226. if len(data) < 2:
  227. return False
  228. b2 = struct.unpack('!B', data[1])[0]
  229. payload_len = b2 & 0x7F
  230. payload_start = 2
  231. if payload_len == 126:
  232. if len(data) > 4:
  233. payload_len = struct.unpack('!H', data[2:4])
  234. payload_start = 4
  235. elif payload_len == 127:
  236. if len(data) > 12:
  237. payload_len = struct.unpack('!Q', data[4:12])
  238. payload_start = 12
  239. return len(data) >= payload_len + payload_start
  240. def mask(key, original):
  241. """
  242. Mask an octet string using the given masking key.
  243. The following masking algorithm is used, as defined in RFC 6455:
  244. for each octet:
  245. j = i MOD 4
  246. transformed-octet-i = original-octet-i XOR masking-key-octet-j
  247. """
  248. if len(key) != 4:
  249. raise ValueError('invalid masking key "%s"' % key)
  250. key = map(ord, key)
  251. masked = bytearray(original)
  252. for i in xrange(len(masked)):
  253. masked[i] ^= key[i % 4]
  254. return masked
  255. def create_close_frame(code, reason):
  256. payload = '' if code is None else struct.pack('!H', code) + reason
  257. return ControlFrame(OPCODE_CLOSE, payload)