frame.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218
  1. import struct
  2. from os import urandom
  3. from exceptions import SocketClosed
  # WebSocket frame opcodes: the low four bits of a frame's first header octet.
  # Data frames:
  OPCODE_CONTINUATION = 0   # continues a fragmented message
  OPCODE_TEXT = 1           # text data frame
  OPCODE_BINARY = 2         # binary data frame
  # Control frames:
  OPCODE_CLOSE = 8
  OPCODE_PING = 9
  OPCODE_PONG = 10
  class Frame(object):
      """
      A single WebSocket frame.

      `payload` always holds the unmasked payload; when `masking_key` is
      non-empty, `pack` masks the payload with it on the way out.
      """

      def __init__(self, opcode, payload, masking_key='', final=True, rsv1=False,
                   rsv2=False, rsv3=False):
          """
          :param opcode: 4-bit frame opcode (one of the OPCODE_* constants).
          :param payload: the unmasked payload octets.
          :param masking_key: '' for an unmasked frame, otherwise the 4-octet
              key that `pack` uses to mask the payload.
          :param final: value of the FIN bit.
          :param rsv1: value of the RSV1 bit (likewise `rsv2` and `rsv3`).
          :raises ValueError: if `masking_key` is neither empty nor 4 octets.
          """
          if len(masking_key) not in (0, 4):
              raise ValueError('invalid masking key "%s"' % (masking_key,))
          self.final = final
          self.rsv1 = rsv1
          self.rsv2 = rsv2
          self.rsv3 = rsv3
          self.opcode = opcode
          self.masking_key = masking_key
          self.payload = payload

      def pack(self):
          """
          Pack the frame into a string according to the following scheme
          (RFC 6455, section 5.2):

          +-+-+-+-+-------+-+-------------+-------------------------------+
          |F|R|R|R| opcode|M| Payload len |    Extended payload length    |
          |I|S|S|S|  (4)  |A|     (7)     |             (16/64)           |
          |N|V|V|V|       |S|             |   (if payload len==126/127)   |
          | |1|2|3|       |K|             |                               |
          +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
          |     Extended payload length continued, if payload len == 127  |
          + - - - - - - - - - - - - - - - +-------------------------------+
          |                               |Masking-key, if MASK set to 1  |
          +-------------------------------+-------------------------------+
          | Masking-key (continued)       |          Payload Data         |
          +-------------------------------- - - - - - - - - - - - - - - - +
          :                     Payload Data continued ...                :
          + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
          |                     Payload Data continued ...                |
          +---------------------------------------------------------------+

          The shortest length encoding (7, 16 or 64 bits) is used. If a
          masking key is set, it follows the length and the payload is masked
          with it.

          :raises ValueError: if the payload is 2**63 octets or longer.
          """
          # FIN, RSV1-3 and the 4-bit opcode share the first octet.
          first_octet = ((self.final << 7) | (self.rsv1 << 6) |
                         (self.rsv2 << 5) | (self.rsv3 << 4) | self.opcode)
          header = struct.pack('!B', first_octet)
          # Named mask_bit (not `mask`) so it does not shadow the module-level
          # mask() function.
          mask_bit = bool(self.masking_key) << 7
          payload_len = len(self.payload)
          if payload_len <= 125:
              header += struct.pack('!B', mask_bit | payload_len)
          elif payload_len < (1 << 16):
              header += struct.pack('!BH', mask_bit | 126, payload_len)
          elif payload_len < (1 << 63):
              header += struct.pack('!BQ', mask_bit | 127, payload_len)
          else:
              raise ValueError('the payload length is too damn high!')
          if self.masking_key:
              return header + self.masking_key + self.mask_payload()
          return header + self.payload

      def mask_payload(self):
          """Return the payload masked with this frame's masking key."""
          return mask(self.masking_key, self.payload)

      def fragment(self, fragment_size, mask=False):
          """
          Split this frame into frames whose payloads hold at most
          `fragment_size` octets.

          The first frame keeps this frame's opcode and RSV bits, the others
          use OPCODE_CONTINUATION, and only the last one has FIN set. An empty
          payload yields a single (empty) frame. If `mask` is true, every
          frame gets its own random masking key.

          :raises ValueError: if `fragment_size` is smaller than 1.
          """
          if fragment_size < 1:
              raise ValueError('fragment_size must be a positive integer')
          frames = []
          # max(..., 1) makes an empty payload produce one empty frame instead
          # of no frames at all.
          for start in range(0, max(len(self.payload), 1), fragment_size):
              first = start == 0
              frames.append(Frame(
                  self.opcode if first else OPCODE_CONTINUATION,
                  self.payload[start:start + fragment_size],
                  masking_key=urandom(4) if mask else '',
                  final=False,
                  # Extensions signal through the RSV bits of the first frame;
                  # concat_frames reads them back from there.
                  rsv1=first and self.rsv1,
                  rsv2=first and self.rsv2,
                  rsv3=first and self.rsv3))
          frames[-1].final = True
          return frames

      def __str__(self):
          return '<Frame opcode=%d len=%d>' % (self.opcode, len(self.payload))
  class FrameReceiver(object):
      """
      Read WebSocket frames from a socket.

      Octets read past the end of the current frame stay in `buf` and are used
      for the next frame, so reading ahead never loses data.
      """

      # Maximum number of octets requested per recv() call for non-exact reads.
      recv_bufsize = 4096

      def __init__(self, sock):
          self.sock = sock
          self.buf = b''       # received but not yet consumed octets
          self.nreceived = 0   # kept equal to len(self.buf)

      def assert_received(self, n, exact=False):
          """
          Make sure at least `n` octets are buffered, reading from the socket
          as needed.

          :param n: number of octets that must be available in `buf`.
          :param exact: request exactly the missing number of octets instead
              of reading ahead up to `recv_bufsize` octets per call.
          :raises SocketClosed: if a read returns no data before `n` octets
              have been buffered.
          """
          while self.nreceived < n:
              missing = n - self.nreceived
              if exact:
                  received = recv_exactly(self.sock, missing)
              else:
                  received = recv_at_least(self.sock,
                                           max(missing, self.recv_bufsize),
                                           missing)
              if not received:
                  raise SocketClosed()
              self.buf += received
              self.nreceived = len(self.buf)

      def receive_fragments(self):
          """
          Receive frames until one with FIN set arrives; return all of them
          in order.
          """
          fragments = [self.receive_frame()]
          while not fragments[-1].final:
              fragments.append(self.receive_frame())
          return fragments

      def receive_frame(self):
          """
          Receive and parse one frame, unmasking its payload if it is masked.

          :returns: a `Frame` whose `masking_key` is the key found on the wire
              ('' if the frame was not masked).
          :raises SocketClosed: if the socket closes before the frame is
              complete.
          """
          self.assert_received(2)
          b1, b2 = struct.unpack('!BB', self.buf[:2])
          final = bool(b1 & 0x80)
          rsv1 = bool(b1 & 0x40)
          rsv2 = bool(b1 & 0x20)
          rsv3 = bool(b1 & 0x10)
          opcode = b1 & 0x0F
          # Named `masked` (not `mask`) so mask() is still callable below.
          masked = bool(b2 & 0x80)
          payload_len = b2 & 0x7F

          # 126/127 announce a 16/64-bit extended length after the first two
          # octets; the masking key (if any) starts right after it.
          if payload_len == 126:
              key_start = 4
              self.assert_received(key_start)
              payload_len, = struct.unpack('!H', self.buf[2:key_start])
          elif payload_len == 127:
              key_start = 10
              self.assert_received(key_start)
              payload_len, = struct.unpack('!Q', self.buf[2:key_start])
          else:
              key_start = 2

          payload_start = key_start + 4 if masked else key_start
          total_len = payload_start + payload_len
          self.assert_received(total_len, exact=True)

          frame_bytes = self.buf[:total_len]
          # Keep anything read past this frame for the next call.
          self.buf = self.buf[total_len:]
          self.nreceived = len(self.buf)

          if masked:
              masking_key = frame_bytes[key_start:payload_start]
              payload = mask(masking_key, frame_bytes[payload_start:])
          else:
              masking_key = ''
              payload = frame_bytes[payload_start:]
          return Frame(opcode, payload, masking_key=masking_key, final=final,
                       rsv1=rsv1, rsv2=rsv2, rsv3=rsv3)
  def recv_exactly(sock, n):
      """
      Keep receiving data from `sock' until exactly `n' bytes have been read.

      Returns all bytes received. If `sock.recv' returns an empty string
      (connection closed) first, the bytes received so far are returned,
      which may be fewer than `n'.
      """
      chunks = []
      left = n
      while left > 0:
          received = sock.recv(left)
          if not received:
              break  # connection closed: do not spin forever on EOF
          chunks.append(received)
          left -= len(received)
      return b''.join(chunks)
  def recv_at_least(sock, n, at_least):
      """
      Keep receiving data from `sock' until at least `at_least' bytes have been
      read, asking for up to `n' bytes per `sock.recv' call.

      The result may be longer than `at_least' bytes. If `sock.recv' returns
      an empty string (connection closed) first, the bytes received so far
      are returned, which may be fewer than `at_least'.
      """
      chunks = []
      left = at_least
      while left > 0:
          received = sock.recv(n)
          if not received:
              break  # connection closed: do not spin forever on EOF
          chunks.append(received)
          left -= len(received)
      return b''.join(chunks)
  def mask(key, original):
      """
      Mask an octet string using the given masking key.

      The following masking algorithm is used, as defined in RFC 6455:

          for each octet:
              j = i MOD 4
              transformed-octet-i = original-octet-i XOR masking-key-octet-j

      XOR masking is its own inverse, so the same call also unmasks.

      :param key: the 4-octet masking key.
      :param original: the octets to (un)mask.
      :returns: the transformed octets as a new bytearray.
      :raises ValueError: if `key` is not exactly 4 octets long.
      """
      if len(key) != 4:
          raise ValueError('invalid masking key "%s"' % (key,))
      # bytearray() yields integer octets for both Python 2 str and Python 3
      # bytes, unlike indexing map(ord, ...) and xrange, which are Python 2 only.
      key = bytearray(key)
      masked = bytearray(original)
      for i in range(len(masked)):
          masked[i] ^= key[i % 4]
      return masked
  def concat_frames(frames):
      """
      Create a new Frame object with the concatenated payload of the given list
      of frames.

      The new frame takes its opcode and RSV bits from the first frame and has
      FIN set. Payloads may be strings or bytearrays (frames that were masked
      on the wire carry the bytearray returned by mask()).

      :raises ValueError: if `frames` is empty, starts with a continuation
          frame, or does not end with a final frame.
      """
      if not frames:
          raise ValueError('no frames to concatenate')
      first = frames[0]
      if first.opcode == OPCODE_CONTINUATION:
          raise ValueError('first frame must not be a continuation frame')
      if not frames[-1].final:
          raise ValueError('last frame must have FIN set')
      # str.join rejects bytearray items on Python 2, so convert them first.
      payload = b''.join([bytes(f.payload) if isinstance(f.payload, bytearray)
                          else f.payload for f in frames])
      return Frame(first.opcode, payload,
                   rsv1=first.rsv1, rsv2=first.rsv2, rsv3=first.rsv3)