# websocket.py - server-side WebSocket (RFC 6455) connection handling.

import re
import struct
from base64 import b64encode
from hashlib import sha1
from threading import Thread

from frame import ControlFrame, receive_fragments, receive_frame, \
    OPCODE_CLOSE, OPCODE_PING, OPCODE_PONG
from message import create_message
from exceptions import InvalidRequest, SocketClosed, PingError
  # GUID that RFC 6455 (section 1.3) appends to the client's
# Sec-WebSocket-Key before hashing it into the Sec-WebSocket-Accept header.
WS_GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
# The only Sec-WebSocket-Version value accepted by handshake().
WS_VERSION = '13'


class WebSocket(object):
    """
    A WebSocket upgrades a regular TCP socket to a web socket. The class
    implements the server side of the opening handshake defined by RFC 6455,
    provides abstracted methods for sending (optionally fragmented) messages,
    and automatically handles control messages.

    Behaviour is customized by overriding the on*() event handlers and, to
    negotiate a subprotocol, the `protocols' attribute.
    """

    # Subprotocols this endpoint is willing to speak. Override in a subclass
    # (or assign on an instance); when empty, no subprotocol is negotiated.
    protocols = ()

    # Maximum number of bytes handshake() buffers while waiting for the end
    # of the request headers, so a client cannot make it read without bound.
    max_handshake_size = 8192

    def __init__(self, sock):
        """
        `sock' is a regular TCP socket instance.
        """
        self.sock = sock
        # Arguments for handle_close(), unpacked from a received close frame;
        # None as long as no close frame has been received.
        self.received_close_params = None
        # True once this end point has sent (or started sending) a close frame.
        self.close_frame_sent = False
        # State of the outstanding ping, used to validate incoming pongs.
        self.ping_sent = False
        self.ping_payload = None

    def send_message(self, message, fragment_size=None):
        """
        Send `message' as a single frame (message.frame()), or, when
        `fragment_size' is given, as the frames produced by
        message.fragment(fragment_size), in order.
        """
        if fragment_size is None:
            self.send_frame(message.frame())
            return
        # An explicit loop instead of map(): map() is lazy on Python 3, so the
        # fragments would silently never be sent.
        for fragment in message.fragment(fragment_size):
            self.send_frame(fragment)

    def send_frame(self, frame):
        """
        Write the packed representation of `frame' to the socket.
        """
        self.sock.sendall(frame.pack())

    def handle_control_frame(self, frame):
        """
        Handle a single control frame. receive_message() passes this method
        to receive_fragments() as its control-frame callback.

        - CLOSE: store the unpacked close parameters; receive_forever() acts
          on them after the current message has been delivered.
        - PING: answer immediately with a PONG carrying the same payload.
        - PONG: verify that it answers the outstanding ping, then call
          onpong(). Raises PingError if no ping is outstanding or if the
          payload differs from the one that was sent.
        """
        if frame.opcode == OPCODE_CLOSE:
            self.received_close_params = frame.unpack_close()
        elif frame.opcode == OPCODE_PING:
            # Respond with a pong message with identical payload
            self.send_frame(ControlFrame(OPCODE_PONG, frame.payload))
        elif frame.opcode == OPCODE_PONG:
            # NOTE(review): RFC 6455 (5.5.3) allows unsolicited PONG frames as
            # a unidirectional heartbeat; this implementation rejects them.
            if not self.ping_sent:
                raise PingError('received PONG while no PING was sent')
            self.ping_sent = False
            if frame.payload != self.ping_payload:
                raise PingError('received PONG with invalid payload')
            self.ping_payload = None
            self.onpong(frame.payload)

    def receive_message(self):
        """
        Receive one (possibly fragmented) message and return the result of
        create_message() for the first frame's opcode and the concatenation
        of all frame payloads. handle_control_frame() is handed to
        receive_fragments() as the callback for control frames.
        """
        frames = receive_fragments(self.sock, self.handle_control_frame)
        payload = ''.join([f.payload for f in frames])
        return create_message(frames[0].opcode, payload)

    def handshake(self):
        """
        Execute a handshake with the other end point of the socket. If the HTTP
        request headers read from the socket are invalid, an InvalidRequest
        exception is raised.

        On success the upgrade response is written to the socket and onopen()
        is called.
        """
        raw_request = self._receive_handshake_request()
        location, headers = self._parse_handshake_request(raw_request)

        # Check if headers that MUST be present are actually present
        # (`headers' is keyed by lower-cased name).
        for name in ('Host', 'Upgrade', 'Connection', 'Sec-WebSocket-Key',
                     'Origin', 'Sec-WebSocket-Version'):
            if name.lower() not in headers:
                raise InvalidRequest('missing "%s" header' % name)

        # Check WebSocket version used by client
        version = headers['sec-websocket-version']
        if version != WS_VERSION:
            raise InvalidRequest('WebSocket version %s requested (only %s '
                                 'is supported)' % (version, WS_VERSION))

        protocol = self._select_protocol(headers.get('sec-websocket-protocol'))

        # Encode acceptation key using the WebSocket GUID. b64encode() is used
        # because str.encode('base64') appends a newline, which would break
        # the Sec-WebSocket-Accept header line (and only exists on Python 2).
        key = headers['sec-websocket-key']
        digest = sha1((key + WS_GUID).encode('utf-8')).digest()
        accept = b64encode(digest).decode('ascii')

        # Construct HTTP response header
        lines = [
            'HTTP/1.1 101 Web Socket Protocol Handshake',
            'Upgrade: WebSocket',
            'Connection: Upgrade',
            'WebSocket-Origin: %s' % headers['origin'],
            'WebSocket-Location: ws://%s%s' % (headers['host'], location),
            'Sec-WebSocket-Accept: %s' % accept,
        ]
        # RFC 6455 4.2.2: confirm exactly one of the client's subprotocols,
        # and send no Sec-WebSocket-Protocol header if none was selected.
        if protocol is not None:
            lines.append('Sec-WebSocket-Protocol: %s' % protocol)
        response = '\r\n'.join(lines) + '\r\n\r\n'
        # sendall(), unlike send(), guarantees the whole response is written.
        self.sock.sendall(response.encode('utf-8'))
        self.onopen()

    def _receive_handshake_request(self):
        """
        Read the opening handshake request up to (and including) the blank
        line that ends the headers and return it decoded as text. Raises
        InvalidRequest if the connection closes first or the request grows
        beyond `max_handshake_size' bytes without terminating.
        """
        data = b''
        # One recv() may return only part of the request (browser handshakes
        # regularly exceed a few hundred bytes), so read until the
        # end-of-headers marker has arrived.
        while b'\r\n\r\n' not in data:
            if len(data) > self.max_handshake_size:
                raise InvalidRequest('handshake request too large')
            chunk = self.sock.recv(1024)
            if not chunk:
                raise InvalidRequest('connection closed during handshake')
            data += chunk
        return data.decode('utf-8', 'ignore')

    @staticmethod
    def _parse_handshake_request(raw_request):
        """
        Split a decoded handshake request into (location, headers), where
        `headers' maps lower-cased header names (HTTP header names are
        case-insensitive) to whitespace-stripped values. Raises
        InvalidRequest unless the request line is a GET of HTTP 1.1 or later.
        """
        head = raw_request.split('\r\n\r\n', 1)[0]
        lines = head.split('\r\n')
        # request must be HTTP (at least 1.1) GET request, find the location
        match = re.match(r'GET (\S+) HTTP/(\d+)\.(\d+)$', lines[0])
        if match is None or \
                (int(match.group(2)), int(match.group(3))) < (1, 1):
            raise InvalidRequest('request must be an HTTP/1.1 GET request')
        headers = {}
        for line in lines[1:]:
            name, sep, value = line.partition(':')
            if sep:
                headers[name.strip().lower()] = value.strip()
        return match.group(1), headers

    def _select_protocol(self, requested):
        """
        Choose the subprotocol to confirm in the handshake response.

        `requested' is the raw Sec-WebSocket-Protocol header value, or None
        if the client sent no such header (then None is returned). Otherwise
        the first client-offered protocol contained in `self.protocols' is
        returned; InvalidRequest is raised if none of them is supported.
        """
        if requested is None:
            return None
        offered = [p.strip() for p in requested.split(',') if p.strip()]
        for protocol in offered:
            if protocol in self.protocols:
                return protocol
        raise InvalidRequest('unsupported protocol "%s"' % requested)

    def receive_forever(self):
        """
        Receive and handle messages in an endless loop. A message may consist
        of multiple data frames, but this is not visible for onmessage().
        Control messages (or control frames) are handled automatically.
        The loop ends after a received close frame has been handled, or when
        SocketClosed is raised (onclose(None, '') is called in that case).
        """
        try:
            while True:
                # onmessage() is declared as onmessage(self, message); passing
                # `self' explicitly as well made every call a TypeError.
                self.onmessage(self.receive_message())
                if self.received_close_params is not None:
                    self.handle_close(*self.received_close_params)
                    break
        except SocketClosed:
            self.onclose(None, '')

    def run_threaded(self, daemon=True):
        """
        Spawn a new thread that receives messages in an endless loop.
        Returns the started Thread.
        """
        thread = Thread(target=self.receive_forever)
        thread.daemon = daemon
        thread.start()
        return thread

    @staticmethod
    def _close_payload(code, reason=''):
        """
        Build a close frame payload: empty when `code' is None, otherwise the
        status code as an unsigned 16-bit big-endian integer followed by the
        UTF-8 encoded `reason' (RFC 6455, section 5.5.1).
        """
        if code is None:
            return ''
        packed_code = struct.pack('!H', code)
        if not reason:
            return packed_code
        if not isinstance(reason, bytes):
            reason = reason.encode('utf-8')
        return packed_code + reason

    def send_close(self, code, reason):
        """
        Send a close frame with status `code' (may be None) and `reason'.
        """
        # Mark the close as sent before writing it: with run_threaded(), the
        # peer's response may be handled by the receiving thread before
        # send_frame() returns, and handle_close() must not answer it.
        self.close_frame_sent = True
        payload = self._close_payload(code, reason)
        self.send_frame(ControlFrame(OPCODE_CLOSE, payload))

    def send_ping(self, payload=''):
        """
        Send a ping control frame with an optional payload.
        """
        # Record the outstanding ping before sending it, so a fast PONG that
        # is processed by a receiving thread is not rejected as unsolicited.
        self.ping_payload = payload
        self.ping_sent = True
        self.send_frame(ControlFrame(OPCODE_PING, payload))
        self.onping(payload)

    def handle_close(self, code=None, reason=''):
        """
        Handle a close message by sending a response close message if no close
        message was sent before, and closing the connection. The onclose()
        handler is called afterwards.
        """
        try:
            if not self.close_frame_sent:
                # Echo the received status code (without a reason).
                self.close_frame_sent = True
                payload = self._close_payload(code)
                self.send_frame(ControlFrame(OPCODE_CLOSE, payload))
        finally:
            # Release the socket even if sending the response fails.
            self.sock.close()
        self.onclose(code, reason)

    def close(self, code=None, reason=''):
        """
        Close the socket by sending a close message and waiting for a response
        close message. The onclose() handler is called after the close message
        has been sent, but before the response has been received. Raises
        ValueError if the next received frame is not a close frame.
        """
        self.send_close(code, reason)
        # NOTE: onclose() intentionally runs before the response is awaited,
        # as documented above.
        self.onclose(code, reason)
        try:
            frame = receive_frame(self.sock)
        finally:
            # Close the socket even if receiving the response fails.
            self.sock.close()
        if frame.opcode != OPCODE_CLOSE:
            raise ValueError('Expected close frame, got %s instead' % frame)

    def onopen(self):
        """
        Called after the handshake has completed.
        """
        pass

    def onmessage(self, message):
        """
        Called when a message is received. `message' is a Message object, which
        can be constructed from a single frame or multiple fragmented frames.
        """
        return NotImplemented

    def onping(self, payload):
        """
        Called after a ping control frame has been sent. This handler could be
        used to start a timeout handler for a pong message that is not received
        in time.
        """
        pass

    def onpong(self, payload):
        """
        Called when a pong control frame is received.
        """
        pass

    def onclose(self, code, reason):
        """
        Called when the socket is closed by either end point.
        """
        pass