websocket.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. import re
  2. import socket
  3. from hashlib import sha1
  4. from frame import receive_frame
  5. from exceptions import InvalidRequest
  6. WS_GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
  7. WS_VERSION = '13'
  8. class websocket(object):
  9. """
  10. Implementation of web socket, upgrades a regular TCP socket to a websocket
  11. using the HTTP handshakes and frame (un)packing, as specified by RFC 6455.
  12. Server example:
  13. >>> sock = websocket()
  14. >>> sock.bind(('', 80))
  15. >>> sock.listen()
  16. >>> client = sock.accept()
  17. >>> client.send(Frame(...))
  18. >>> frame = client.recv()
  19. Client example:
  20. >>> sock = websocket()
  21. >>> sock.connect(('kompiler.org', 80))
  22. """
  23. def __init__(self, wsprotocols=[], family=socket.AF_INET, proto=0):
  24. """
  25. Create aregular TCP socket of family `family` and protocol
  26. `wsprotocols` is a list of supported protocol names.
  27. """
  28. self.sock = socket.socket(family, socket.SOCK_STREAM, proto)
  29. self.protocols = wsprotocols
  30. def bind(self, address):
  31. self.sock.bind(address)
  32. def listen(self, backlog):
  33. self.sock.listen(backlog)
  34. def accept(self):
  35. """
  36. Equivalent to socket.accept(), but transforms the socket into a
  37. websocket instance and sends a server handshake (after receiving a
  38. client handshake). Note that the handshake may raise an InvalidRequest
  39. exception.
  40. """
  41. client, address = socket.socket.accept(self)
  42. client = websocket(client)
  43. client.server_handshake()
  44. return client, address
  45. def connect(self, address):
  46. """
  47. Equivalent to socket.connect(), but sends an client handshake request
  48. after connecting.
  49. """
  50. self.sock.sonnect(address)
  51. self.client_handshake()
  52. def send(self, *args):
  53. """
  54. Send a number of frames.
  55. """
  56. for frame in args:
  57. self.sock.sendall(frame.pack())
  58. def recv(self, n=1):
  59. """
  60. Receive exactly `n` frames. These can be either data frames or control
  61. frames, or a combination of both.
  62. """
  63. return [receive_frame(self.sock) for i in xrange(n)]
  64. def getpeername(self):
  65. return self.sock.getpeername()
  66. def getsockname(self):
  67. return self.sock.getpeername()
  68. def setsockopt(self, level, optname, value):
  69. self.sock.setsockopt(level, optname, value)
  70. def getsockopt(self, level, optname):
  71. return self.sock.getsockopt(level, optname)
  72. def server_handshake(self):
  73. """
  74. Execute a handshake as the server end point of the socket. If the HTTP
  75. request headers sent by the client are invalid, an InvalidRequest
  76. exception is raised.
  77. """
  78. raw_headers = self.sock.recv(512).decode('utf-8', 'ignore')
  79. # request must be HTTP (at least 1.1) GET request, find the location
  80. location = re.search(r'^GET (.*) HTTP/1.1\r\n', raw_headers).group(1)
  81. headers = dict(re.findall(r'(.*?): (.*?)\r\n', raw_headers))
  82. # Check if headers that MUST be present are actually present
  83. for name in ('Host', 'Upgrade', 'Connection', 'Sec-WebSocket-Key',
  84. 'Origin', 'Sec-WebSocket-Version'):
  85. if name not in headers:
  86. raise InvalidRequest('missing "%s" header' % name)
  87. # Check WebSocket version used by client
  88. version = headers['Sec-WebSocket-Version']
  89. if version != WS_VERSION:
  90. raise InvalidRequest('WebSocket version %s requested (only %s '
  91. 'is supported)' % (version, WS_VERSION))
  92. # Make sure the requested protocols are supported by this server
  93. if 'Sec-WebSocket-Protocol' in headers:
  94. parts = headers['Sec-WebSocket-Protocol'].split(',')
  95. protocols = map(str.strip, parts)
  96. for p in protocols:
  97. if p not in self.protocols:
  98. raise InvalidRequest('unsupported protocol "%s"' % p)
  99. else:
  100. protocols = []
  101. # Encode acceptation key using the WebSocket GUID
  102. key = headers['Sec-WebSocket-Key']
  103. accept = sha1(key + WS_GUID).digest().encode('base64')
  104. # Construct HTTP response header
  105. shake = 'HTTP/1.1 101 Web Socket Protocol Handshake\r\n'
  106. shake += 'Upgrade: WebSocket\r\n'
  107. shake += 'Connection: Upgrade\r\n'
  108. shake += 'WebSocket-Origin: %s\r\n' % headers['Origin']
  109. shake += 'WebSocket-Location: ws://%s%s\r\n' \
  110. % (headers['Host'], location)
  111. shake += 'Sec-WebSocket-Accept: %s\r\n' % accept
  112. if self.protocols:
  113. shake += 'Sec-WebSocket-Protocol: %s\r\n' \
  114. % ', '.join(self.protocols)
  115. self.sock.send(shake + '\r\n')
  116. def client_handshake(self):
  117. # TODO: implement HTTP request headers for client handshake
  118. raise NotImplementedError()