websocket.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. import re
  2. import socket
  3. from hashlib import sha1
  4. from frame import receive_frame
  5. from exceptions import HandshakeError
  # Fixed GUID that the server appends to the client's Sec-WebSocket-Key
  # before SHA-1 hashing it into the Sec-WebSocket-Accept value.
  WS_GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'

  # The only Sec-WebSocket-Version value accepted from clients.
  WS_VERSION = '13'
  def split_stripped(value, delim=','):
      """
      Split `value` on `delim` and return the pieces, each stripped of
      surrounding whitespace, as a list.

      Each piece's own strip() method is used rather than `str.strip`. On
      Python 2, `str.strip` rejects unicode arguments, and the handshake code
      passes decoded (unicode) header values here. Returning a list (rather
      than a lazy map object on Python 3) lets callers iterate the result
      more than once.
      """
      return [part.strip() for part in value.split(delim)]
  class websocket(object):
      """
      Implementation of web socket, upgrades a regular TCP socket to a websocket
      using the HTTP handshakes and frame (un)packing, as specified by RFC 6455.

      Server example:
      >>> sock = websocket()
      >>> sock.bind(('', 80))
      >>> sock.listen()
      >>> client, address = sock.accept()
      >>> client.send(Frame(...))
      >>> frame = client.recv()

      Client example:
      >>> sock = websocket()
      >>> sock.connect(('kompiler.org', 80))
      """

      # Upper bound (in bytes) on a client's handshake request, so a peer that
      # never terminates its headers cannot make the server buffer forever.
      MAX_HANDSHAKE_SIZE = 8192

      def __init__(self, protocols=None, extensions=None,
                   sfamily=socket.AF_INET, sproto=0, sock=None):
          """
          Create a websocket on top of a regular TCP socket.

          `protocols` is a list of supported protocol names (default: none).
          `extensions` is a list of supported extensions (default: none).
          `sfamily` and `sproto` are used for the regular socket constructor.
          `sock` is an existing TCP socket to wrap instead of creating a new
          one (accept() uses this); `sfamily` and `sproto` are then ignored.
          """
          # None defaults avoid sharing one mutable list between all instances.
          self.protocols = protocols if protocols is not None else []
          self.extensions = extensions if extensions is not None else []
          if sock is None:
              sock = socket.socket(sfamily, socket.SOCK_STREAM, sproto)
          self.sock = sock

      def bind(self, address):
          self.sock.bind(address)

      def listen(self, backlog=5):
          """
          Start listening for connections. `backlog` defaults to 5 so that
          listen() can be called without arguments, as in the class example.
          """
          self.sock.listen(backlog)

      def accept(self):
          """
          Equivalent to socket.accept(), but transforms the socket into a
          websocket instance and sends a server handshake (after receiving a
          client handshake). Note that the handshake may raise an HandshakeError
          exception; the accepted connection is closed before it propagates.

          Returns a (websocket, address) tuple.
          """
          client, address = self.sock.accept()
          # Wrap the accepted connection (rather than opening a fresh socket)
          # and offer the same protocols/extensions as this listening socket.
          client = websocket(self.protocols, self.extensions, sock=client)
          try:
              client.server_handshake()
          except Exception:
              # Do not leak the connection when the handshake fails.
              client.sock.close()
              raise
          return client, address

      def connect(self, address):
          """
          Equivalent to socket.connect(), but sends an client handshake request
          after connecting.
          """
          self.sock.connect(address)
          self.client_handshake()

      def send(self, *args):
          """
          Send a number of frames, in order. Each frame is serialized with its
          pack() method and written completely with socket.sendall().
          """
          for frame in args:
              self.sock.sendall(frame.pack())

      def recv(self, n=1):
          """
          Receive exactly `n` frames. These can be either data frames or control
          frames, or a combination of both. Returns them as a list.
          """
          return [receive_frame(self.sock) for _ in range(n)]

      def getpeername(self):
          return self.sock.getpeername()

      def getsockname(self):
          return self.sock.getsockname()

      def setsockopt(self, level, optname, value):
          self.sock.setsockopt(level, optname, value)

      def getsockopt(self, level, optname):
          return self.sock.getsockopt(level, optname)

      def _receive_handshake(self):
          """
          Read from the socket until the blank line that ends the client's
          HTTP headers has arrived, and return everything read, decoded as
          UTF-8 (undecodable bytes are dropped).

          Raises HandshakeError if the peer closes the connection first or the
          request grows beyond MAX_HANDSHAKE_SIZE bytes.
          """
          data = b''
          while b'\r\n\r\n' not in data:
              if len(data) >= self.MAX_HANDSHAKE_SIZE:
                  raise HandshakeError('handshake request exceeds %d bytes'
                                       % self.MAX_HANDSHAKE_SIZE)
              chunk = self.sock.recv(512)
              if not chunk:
                  raise HandshakeError('connection closed during handshake')
              data += chunk
          return data.decode('utf-8', 'ignore')

      def server_handshake(self):
          """
          Execute a handshake as the server end point of the socket. If the HTTP
          request headers sent by the client are invalid, a HandshakeError
          is raised.
          """
          # Imported locally so this block does not depend on file-level imports.
          import base64

          raw_headers = self._receive_handshake()

          # Request must be an HTTP/1.1 GET request; find the location.
          request_line = re.match(r'GET (\S+) HTTP/1\.1\r\n', raw_headers)
          if request_line is None:
              raise HandshakeError('request is not an HTTP/1.1 GET request')
          location = request_line.group(1)

          # Header names are case-insensitive, so they are stored lower-cased.
          # Parsing stops at the blank line that terminates the header block.
          headers = []
          for line in raw_headers.split('\r\n')[1:]:
              if not line:
                  break
              name, colon, value = line.partition(':')
              if colon:
                  headers.append((name.strip().lower(), value.strip()))
          header_names = set(name for name, _ in headers)

          def header(name):
              # Repeated headers are combined into one comma-separated value.
              wanted = name.lower()
              return ', '.join([v for n, v in headers if n == wanted])

          # Check if headers that MUST be present are actually present. Origin
          # is optional (RFC 6455, section 4.2.1), e.g. for non-browser clients.
          for name in ('Host', 'Upgrade', 'Connection', 'Sec-WebSocket-Key',
                       'Sec-WebSocket-Version'):
              if name.lower() not in header_names:
                  raise HandshakeError('missing "%s" header' % name)

          # The request must actually ask for a WebSocket upgrade.
          upgrade = [t.lower() for t in split_stripped(header('Upgrade'))]
          if 'websocket' not in upgrade:
              raise HandshakeError('"Upgrade" header must contain "websocket"')
          connection = [t.lower() for t in split_stripped(header('Connection'))]
          if 'upgrade' not in connection:
              raise HandshakeError('"Connection" header must contain "Upgrade"')

          # Check WebSocket version used by client
          version = header('Sec-WebSocket-Version')
          if version != WS_VERSION:
              raise HandshakeError('WebSocket version %s requested (only %s '
                                   'is supported)' % (version, WS_VERSION))

          # The server selects a single subprotocol: the first one offered by
          # the client that is also supported here.
          requested = header('Sec-WebSocket-Protocol')
          offered = split_stripped(requested) if requested else []
          protocol = next((p for p in offered if p in self.protocols), None)

          # Only supported extensions are returned
          ext = header('Sec-WebSocket-Extensions')
          extensions = split_stripped(ext) if ext else []
          extensions = [e for e in extensions if e in self.extensions]

          # Encode acceptation key using the WebSocket GUID. b64encode adds no
          # trailing newline (unlike str.encode('base64')), which would
          # otherwise corrupt the response header.
          key = header('Sec-WebSocket-Key')
          digest = sha1((key + WS_GUID).encode('utf-8')).digest()
          accept = base64.b64encode(digest).decode('ascii')

          # Construct HTTP response header
          lines = ['HTTP/1.1 101 Web Socket Protocol Handshake',
                   'Upgrade: WebSocket',
                   'Connection: Upgrade']
          origin = header('Origin')
          if origin:
              lines.append('WebSocket-Origin: %s' % origin)
          lines.append('WebSocket-Location: ws://%s%s'
                       % (header('Host'), location))
          lines.append('Sec-WebSocket-Accept: %s' % accept)
          if protocol:
              lines.append('Sec-WebSocket-Protocol: %s' % protocol)
          if extensions:
              lines.append('Sec-WebSocket-Extensions: %s' % ', '.join(extensions))
          shake = '\r\n'.join(lines) + '\r\n\r\n'

          # sendall() keeps writing until the whole response is sent.
          self.sock.sendall(shake.encode('utf-8'))

      def client_handshake(self):
          """
          Execute a handshake as the client end point of the socket. May raise a
          HandshakeError if the server response is invalid.
          """
          # TODO: implement HTTP request headers for client handshake
          raise NotImplementedError()