websocket.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. import re
  2. from hashlib import sha1
  3. from threading import Thread
  4. from frame import receive_fragments
  5. from message import create_message
  6. from exceptions import SocketClosed
  7. WS_GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
  8. WS_VERSION = '13'
  9. class WebSocket(object):
  10. def __init__(self, sock, address, encoding=None):
  11. super(WebSocket, self).__init__(sock)
  12. self.address = address
  13. self.encoding = encoding
  14. def send_message(self, message, fragment_size=None):
  15. if fragment_size is None:
  16. self.send_frame(message.frame())
  17. else:
  18. map(self.send_frame, message.fragment(fragment_size))
  19. def send_frame(self, frame):
  20. self.sock.sendall(frame.pack())
  21. def receive_message(self):
  22. frames = receive_fragments(self.sock)
  23. payload = ''.join([f.payload for f in frames])
  24. return create_message(frames[0].opcode, payload)
  25. def handshake(self):
  26. raw_headers = self.sock.recv(512)
  27. if self.encoding:
  28. raw_headers = raw_headers.decode(self.encoding, 'ignore')
  29. location = re.search(r'^GET (.*) HTTP/1.1\r\n', raw_headers).group(1)
  30. headers = dict(re.findall(r'(.*?): (.*?)\r\n', raw_headers))
  31. # Check if headers that MUST be present are actually present
  32. for name in ('Host', 'Upgrade', 'Connection', 'Sec-WebSocket-Key',
  33. 'Origin', 'Sec-WebSocket-Version'):
  34. assert name in headers
  35. # Check WebSocket version used by client
  36. assert headers['Sec-WebSocket-Version'] == WS_VERSION
  37. # Make sure the requested protocols are supported by this server
  38. if 'Sec-WebSocket-Protocol' in headers:
  39. parts = headers['Sec-WebSocket-Protocol'].split(',')
  40. protocols = map(str.strip, parts)
  41. for protocol in protocols:
  42. assert protocol in self.protocols
  43. else:
  44. protocols = []
  45. key = headers['Sec-WebSocket-Key']
  46. accept = sha1(key + WS_GUID).digest().encode('base64')
  47. shake = 'HTTP/1.1 101 Web Socket Protocol Handshake\r\n'
  48. shake += 'Upgrade: WebSocket\r\n'
  49. shake += 'Connection: Upgrade\r\n'
  50. shake += 'WebSocket-Origin: %s\r\n' % headers['Origin']
  51. shake += 'WebSocket-Location: ws://%s%s\r\n' % (headers['Host'], location)
  52. shake += 'Sec-WebSocket-Accept: %s\r\n' % accept
  53. if self.protocols:
  54. shake += 'Sec-WebSocket-Protocol: %s\r\n' \
  55. % ', '.join(self.protocols)
  56. self.sock.send(shake + '\r\n')
  57. self.onopen()
  58. def receive_forever(self):
  59. try:
  60. while True:
  61. self.onmessage(self, self.receive_message())
  62. except SocketClosed:
  63. self.onclose()
  64. def run_threaded(self, daemon=True):
  65. t = Thread(target=self.receive_forever)
  66. t.daemon = daemon
  67. t.start()
  68. def onopen(self):
  69. """
  70. Called after the handshake has completed.
  71. """
  72. pass
  73. def onmessage(self, message):
  74. """
  75. Called when a message is received. `message' is a Message object, which
  76. can be constructed from a single frame or multiple fragmented frames.
  77. """
  78. raise NotImplemented
  79. def onclose(self):
  80. """
  81. Called when the other end of the socket disconnects.
  82. """
  83. pass
  84. def close(self):
  85. raise SocketClosed()