websocket.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. import re
  2. import socket
  3. from hashlib import sha1
  4. from base64 import b64encode
  5. from frame import receive_frame
  6. from errors import HandshakeError
  7. WS_GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
  8. WS_VERSION = '13'
  9. def split_stripped(value, delim=','):
  10. return map(str.strip, str(value).split(delim))
  11. class websocket(object):
  12. """
  13. Implementation of web socket, upgrades a regular TCP socket to a websocket
  14. using the HTTP handshakes and frame (un)packing, as specified by RFC 6455.
  15. Server example:
  16. >>> sock = websocket()
  17. >>> sock.bind(('', 8000))
  18. >>> sock.listen()
  19. >>> client = sock.accept()
  20. >>> client.send(Frame(...))
  21. >>> frame = client.recv()
  22. Client example:
  23. >>> sock = websocket()
  24. >>> sock.connect(('', 8000))
  25. """
  26. def __init__(self, sock=None, protocols=[], extensions=[], sfamily=socket.AF_INET,
  27. sproto=0):
  28. """
  29. Create a regular TCP socket of family `family` and protocol
  30. `sock` is an optional regular TCP socket to be used for sending binary
  31. data. If not specified, a new socket is created.
  32. `protocols` is a list of supported protocol names.
  33. `extensions` is a list of supported extensions.
  34. `sfamily` and `sproto` are used for the regular socket constructor.
  35. """
  36. self.protocols = protocols
  37. self.extensions = extensions
  38. self.sock = sock or socket.socket(sfamily, socket.SOCK_STREAM, sproto)
  39. def bind(self, address):
  40. self.sock.bind(address)
  41. def listen(self, backlog):
  42. self.sock.listen(backlog)
  43. def accept(self):
  44. """
  45. Equivalent to socket.accept(), but transforms the socket into a
  46. websocket instance and sends a server handshake (after receiving a
  47. client handshake). Note that the handshake may raise a HandshakeError
  48. exception.
  49. """
  50. sock, address = self.sock.accept()
  51. wsock = websocket(sock)
  52. wsock.server_handshake()
  53. return wsock, address
  54. def connect(self, address):
  55. """
  56. Equivalent to socket.connect(), but sends an client handshake request
  57. after connecting.
  58. """
  59. self.sock.sonnect(address)
  60. self.client_handshake()
  61. def send(self, *args):
  62. """
  63. Send a number of frames.
  64. """
  65. for frame in args:
  66. #print 'send frame:', frame, 'to %s:%d' % self.sock.getpeername()
  67. self.sock.sendall(frame.pack())
  68. def recv(self):
  69. """
  70. Receive a single frames. This can be either a data frame or a control
  71. frame.
  72. """
  73. frame = receive_frame(self.sock)
  74. #print 'receive frame:', frame, 'from %s:%d' % self.sock.getpeername()
  75. return frame
  76. def recvn(self, n):
  77. """
  78. Receive exactly `n` frames. These can be either data frames or control
  79. frames, or a combination of both.
  80. """
  81. return [self.recv() for i in xrange(n)]
  82. def getpeername(self):
  83. return self.sock.getpeername()
  84. def getsockname(self):
  85. return self.sock.getsockname()
  86. def setsockopt(self, level, optname, value):
  87. self.sock.setsockopt(level, optname, value)
  88. def getsockopt(self, level, optname):
  89. return self.sock.getsockopt(level, optname)
  90. def close(self):
  91. self.sock.close()
  92. def server_handshake(self):
  93. """
  94. Execute a handshake as the server end point of the socket. If the HTTP
  95. request headers sent by the client are invalid, a HandshakeError
  96. is raised.
  97. """
  98. raw_headers = self.sock.recv(512).decode('utf-8', 'ignore')
  99. # request must be HTTP (at least 1.1) GET request, find the location
  100. location = re.search(r'^GET (.*) HTTP/1.1\r\n', raw_headers).group(1)
  101. headers = re.findall(r'(.*?): ?(.*?)\r\n', raw_headers)
  102. header_names = [name for name, value in headers]
  103. def header(name):
  104. return ', '.join([v for n, v in headers if n == name])
  105. # Check if headers that MUST be present are actually present
  106. for name in ('Host', 'Upgrade', 'Connection', 'Sec-WebSocket-Key',
  107. 'Origin', 'Sec-WebSocket-Version'):
  108. if name not in header_names:
  109. raise HandshakeError('missing "%s" header' % name)
  110. # Check WebSocket version used by client
  111. version = header('Sec-WebSocket-Version')
  112. if version != WS_VERSION:
  113. raise HandshakeError('WebSocket version %s requested (only %s '
  114. 'is supported)' % (version, WS_VERSION))
  115. # Only supported protocols are returned
  116. proto = header('Sec-WebSocket-Extensions')
  117. protocols = split_stripped(proto) if proto else []
  118. protocols = [p for p in protocols if p in self.protocols]
  119. # Only supported extensions are returned
  120. ext = header('Sec-WebSocket-Extensions')
  121. extensions = split_stripped(ext) if ext else []
  122. extensions = [e for e in extensions if e in self.extensions]
  123. # Encode acceptation key using the WebSocket GUID
  124. key = header('Sec-WebSocket-Key').strip()
  125. accept = b64encode(sha1(key + WS_GUID).digest())
  126. # Construct HTTP response header
  127. shake = 'HTTP/1.1 101 Web Socket Protocol Handshake\r\n'
  128. shake += 'Upgrade: WebSocket\r\n'
  129. shake += 'Connection: Upgrade\r\n'
  130. shake += 'WebSocket-Origin: %s\r\n' % header('Origin')
  131. shake += 'WebSocket-Location: ws://%s%s\r\n' \
  132. % (header('Host'), location)
  133. shake += 'Sec-WebSocket-Accept: %s\r\n' % accept
  134. if protocols:
  135. shake += 'Sec-WebSocket-Protocol: %s\r\n' % ', '.join(protocols)
  136. if extensions:
  137. shake += 'Sec-WebSocket-Extensions: %s\r\n' % ', '.join(extensions)
  138. self.sock.sendall(shake + '\r\n')
  139. def client_handshake(self):
  140. """
  141. Execute a handshake as the client end point of the socket. May raise a
  142. HandshakeError if the server response is invalid.
  143. """
  144. # TODO: implement HTTP request headers for client handshake
  145. raise NotImplementedError()