websocket.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
import base64
import re
import socket
from hashlib import sha1

from frame import receive_frame
from exceptions import InvalidRequest
  6. WS_GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
  7. WS_VERSION = '13'
  8. class websocket(object):
  9. """
  10. Implementation of web socket, upgrades a regular TCP socket to a websocket
  11. using the HTTP handshakes and frame (un)packing, as specified by RFC 6455.
  12. Server example:
  13. >>> sock = websocket()
  14. >>> sock.bind(('', 80))
  15. >>> sock.listen()
  16. >>> client = sock.accept()
  17. >>> client.send(Frame(...))
  18. >>> frame = client.recv()
  19. Client example:
  20. >>> sock = websocket()
  21. >>> sock.connect(('kompiler.org', 80))
  22. """
  23. def __init__(self, wsprotocols=[], family=socket.AF_INET, proto=0):
  24. """
  25. Create aregular TCP socket of family `family` and protocol
  26. `wsprotocols` is a list of supported protocol names.
  27. """
  28. self.sock = socket.socket(family, socket.SOCK_STREAM, proto)
  29. self.protocols = wsprotocols
  30. def bind(self, address):
  31. self.sock.bind(address)
  32. def listen(self, backlog):
  33. self.sock.listen(backlog)
  34. def accept(self):
  35. client, address = socket.socket.accept(self)
  36. client = websocket(client)
  37. client.server_handshake()
  38. return client, address
  39. def connect(self, address):
  40. """
  41. Equivalent to socket.connect(), but sends an HTTP handshake request
  42. after connecting.
  43. """
  44. self.sock.sonnect(address)
  45. self.client_handshake()
  46. def send(self, *args):
  47. """
  48. Send a number of frames.
  49. """
  50. for frame in args:
  51. self.sock.sendall(frame.pack())
  52. def recv(self, n=1):
  53. """
  54. Receive exactly `n` frames. These can be either data frames or control
  55. frames, or a combination of both.
  56. """
  57. return [receive_frame(self.sock) for i in xrange(n)]
  58. def getpeername(self):
  59. return self.sock.getpeername()
  60. def getsockname(self):
  61. return self.sock.getpeername()
  62. def setsockopt(self, level, optname, value):
  63. self.sock.setsockopt(level, optname, value)
  64. def getsockopt(self, level, optname):
  65. return self.sock.getsockopt(level, optname)
  66. def server_handshake(self):
  67. """
  68. Execute a handshake as the server end point of the socket. If the HTTP
  69. request headers sent by the client are invalid, an InvalidRequest
  70. exception is raised.
  71. """
  72. raw_headers = self.sock.recv(512).decode('utf-8', 'ignore')
  73. # request must be HTTP (at least 1.1) GET request, find the location
  74. location = re.search(r'^GET (.*) HTTP/1.1\r\n', raw_headers).group(1)
  75. headers = dict(re.findall(r'(.*?): (.*?)\r\n', raw_headers))
  76. # Check if headers that MUST be present are actually present
  77. for name in ('Host', 'Upgrade', 'Connection', 'Sec-WebSocket-Key',
  78. 'Origin', 'Sec-WebSocket-Version'):
  79. if name not in headers:
  80. raise InvalidRequest('missing "%s" header' % name)
  81. # Check WebSocket version used by client
  82. version = headers['Sec-WebSocket-Version']
  83. if version != WS_VERSION:
  84. raise InvalidRequest('WebSocket version %s requested (only %s '
  85. 'is supported)' % (version, WS_VERSION))
  86. # Make sure the requested protocols are supported by this server
  87. if 'Sec-WebSocket-Protocol' in headers:
  88. parts = headers['Sec-WebSocket-Protocol'].split(',')
  89. protocols = map(str.strip, parts)
  90. for p in protocols:
  91. if p not in self.protocols:
  92. raise InvalidRequest('unsupported protocol "%s"' % p)
  93. else:
  94. protocols = []
  95. # Encode acceptation key using the WebSocket GUID
  96. key = headers['Sec-WebSocket-Key']
  97. accept = sha1(key + WS_GUID).digest().encode('base64')
  98. # Construct HTTP response header
  99. shake = 'HTTP/1.1 101 Web Socket Protocol Handshake\r\n'
  100. shake += 'Upgrade: WebSocket\r\n'
  101. shake += 'Connection: Upgrade\r\n'
  102. shake += 'WebSocket-Origin: %s\r\n' % headers['Origin']
  103. shake += 'WebSocket-Location: ws://%s%s\r\n' \
  104. % (headers['Host'], location)
  105. shake += 'Sec-WebSocket-Accept: %s\r\n' % accept
  106. if self.protocols:
  107. shake += 'Sec-WebSocket-Protocol: %s\r\n' \
  108. % ', '.join(self.protocols)
  109. self.sock.send(shake + '\r\n')
  110. def client_handshake(self):
  111. # TODO: implement HTTP request headers for client handshake
  112. raise NotImplementedError()