websocket.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. import re
  2. import socket
  3. from hashlib import sha1
  4. from frame import receive_frame
  5. from exceptions import InvalidRequest
  6. WS_GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
  7. WS_VERSION = '13'
  8. def split_stripped(value, delim=','):
  9. return map(str.strip, value.split(delim))
  10. class websocket(object):
  11. """
  12. Implementation of web socket, upgrades a regular TCP socket to a websocket
  13. using the HTTP handshakes and frame (un)packing, as specified by RFC 6455.
  14. Server example:
  15. >>> sock = websocket()
  16. >>> sock.bind(('', 80))
  17. >>> sock.listen()
  18. >>> client = sock.accept()
  19. >>> client.send(Frame(...))
  20. >>> frame = client.recv()
  21. Client example:
  22. >>> sock = websocket()
  23. >>> sock.connect(('kompiler.org', 80))
  24. """
  25. def __init__(self, protocols=[], extensions=[], family=socket.AF_INET,
  26. proto=0):
  27. """
  28. Create a regular TCP socket of family `family` and protocol
  29. `protocols` is a list of supported protocol names.
  30. `extensions` is a list of supported extensions.
  31. `family` and `proto` are used for the regular socket constructor.
  32. """
  33. self.protocols = protocols
  34. self.extensions = extensions
  35. self.sock = socket.socket(family, socket.SOCK_STREAM, proto)
  36. def bind(self, address):
  37. self.sock.bind(address)
  38. def listen(self, backlog):
  39. self.sock.listen(backlog)
  40. def accept(self):
  41. """
  42. Equivalent to socket.accept(), but transforms the socket into a
  43. websocket instance and sends a server handshake (after receiving a
  44. client handshake). Note that the handshake may raise an InvalidRequest
  45. exception.
  46. """
  47. client, address = socket.socket.accept(self)
  48. client = websocket(client)
  49. client.server_handshake()
  50. return client, address
  51. def connect(self, address):
  52. """
  53. Equivalent to socket.connect(), but sends an client handshake request
  54. after connecting.
  55. """
  56. self.sock.sonnect(address)
  57. self.client_handshake()
  58. def send(self, *args):
  59. """
  60. Send a number of frames.
  61. """
  62. for frame in args:
  63. self.sock.sendall(frame.pack())
  64. def recv(self, n=1):
  65. """
  66. Receive exactly `n` frames. These can be either data frames or control
  67. frames, or a combination of both.
  68. """
  69. return [receive_frame(self.sock) for i in xrange(n)]
  70. def getpeername(self):
  71. return self.sock.getpeername()
  72. def getsockname(self):
  73. return self.sock.getpeername()
  74. def setsockopt(self, level, optname, value):
  75. self.sock.setsockopt(level, optname, value)
  76. def getsockopt(self, level, optname):
  77. return self.sock.getsockopt(level, optname)
  78. def server_handshake(self):
  79. """
  80. Execute a handshake as the server end point of the socket. If the HTTP
  81. request headers sent by the client are invalid, an InvalidRequest
  82. exception is raised.
  83. """
  84. raw_headers = self.sock.recv(512).decode('utf-8', 'ignore')
  85. # request must be HTTP (at least 1.1) GET request, find the location
  86. location = re.search(r'^GET (.*) HTTP/1.1\r\n', raw_headers).group(1)
  87. headers = re.findall(r'(.*?): (.*?)\r\n', raw_headers)
  88. header_names = [name for name, value in headers]
  89. def header(name):
  90. return ', '.join([v for n, v in headers if n == name])
  91. # Check if headers that MUST be present are actually present
  92. for name in ('Host', 'Upgrade', 'Connection', 'Sec-WebSocket-Key',
  93. 'Origin', 'Sec-WebSocket-Version'):
  94. if name not in header_names:
  95. raise InvalidRequest('missing "%s" header' % name)
  96. # Check WebSocket version used by client
  97. version = header('Sec-WebSocket-Version')
  98. if version != WS_VERSION:
  99. raise InvalidRequest('WebSocket version %s requested (only %s '
  100. 'is supported)' % (version, WS_VERSION))
  101. # Only supported protocols are returned
  102. proto = header('Sec-WebSocket-Extensions')
  103. protocols = split_stripped(proto) if proto else []
  104. protocols = [p for p in protocols if p in self.protocols]
  105. # Only supported extensions are returned
  106. ext = header('Sec-WebSocket-Extensions')
  107. extensions = split_stripped(ext) if ext else []
  108. extensions = [e for e in extensions if e in self.extensions]
  109. # Encode acceptation key using the WebSocket GUID
  110. key = header('Sec-WebSocket-Key')
  111. accept = sha1(key + WS_GUID).digest().encode('base64')
  112. # Construct HTTP response header
  113. shake = 'HTTP/1.1 101 Web Socket Protocol Handshake\r\n'
  114. shake += 'Upgrade: WebSocket\r\n'
  115. shake += 'Connection: Upgrade\r\n'
  116. shake += 'WebSocket-Origin: %s\r\n' % header('Origin')
  117. shake += 'WebSocket-Location: ws://%s%s\r\n' \
  118. % (header('Host'), location)
  119. shake += 'Sec-WebSocket-Accept: %s\r\n' % accept
  120. if protocols:
  121. shake += 'Sec-WebSocket-Protocol: %s\r\n' % ', '.join(protocols)
  122. if extensions:
  123. shake += 'Sec-WebSocket-Extensions: %s\r\n' % ', '.join(extensions)
  124. self.sock.send(shake + '\r\n')
  125. def client_handshake(self):
  126. # TODO: implement HTTP request headers for client handshake
  127. raise NotImplementedError()