websocket.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322
  1. import os
  2. import re
  3. import socket
  4. import ssl
  5. from hashlib import sha1
  6. from base64 import b64encode
  7. from frame import receive_frame
  8. from errors import HandshakeError, SSLError
  9. WS_GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
  10. WS_VERSION = '13'
  11. def split_stripped(value, delim=','):
  12. return map(str.strip, str(value).split(delim)) if value else []
  13. class websocket(object):
  14. """
  15. Implementation of web socket, upgrades a regular TCP socket to a websocket
  16. using the HTTP handshakes and frame (un)packing, as specified by RFC 6455.
  17. The API of a websocket is identical to that of a regular socket, as
  18. illustrated by the examples below.
  19. Server example:
  20. >>> import twspy, socket
  21. >>> sock = twspy.websocket()
  22. >>> sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
  23. >>> sock.bind(('', 8000))
  24. >>> sock.listen()
  25. >>> client = sock.accept()
  26. >>> client.send(twspy.Frame(twspy.OPCODE_TEXT, 'Hello, Client!'))
  27. >>> frame = client.recv()
  28. Client example:
  29. >>> import twspy
  30. >>> sock = twspy.websocket()
  31. >>> sock.connect(('', 8000))
  32. >>> sock.send(twspy.Frame(twspy.OPCODE_TEXT, 'Hello, Server!'))
  33. """
  34. def __init__(self, sock=None, protocols=[], extensions=[], sfamily=socket.AF_INET,
  35. sproto=0):
  36. """
  37. Create a regular TCP socket of family `family` and protocol
  38. `sock` is an optional regular TCP socket to be used for sending binary
  39. data. If not specified, a new socket is created.
  40. `protocols` is a list of supported protocol names.
  41. `extensions` is a list of supported extensions.
  42. `sfamily` and `sproto` are used for the regular socket constructor.
  43. """
  44. self.protocols = protocols
  45. self.extensions = extensions
  46. self.sock = sock or socket.socket(sfamily, socket.SOCK_STREAM, sproto)
  47. self.secure = False
  48. self.handshake_started = False
  49. def bind(self, address):
  50. self.sock.bind(address)
  51. def listen(self, backlog):
  52. self.sock.listen(backlog)
  53. def accept(self):
  54. """
  55. Equivalent to socket.accept(), but transforms the socket into a
  56. websocket instance and sends a server handshake (after receiving a
  57. client handshake). Note that the handshake may raise a HandshakeError
  58. exception.
  59. """
  60. sock, address = self.sock.accept()
  61. wsock = websocket(sock)
  62. wsock.server_handshake()
  63. return wsock, address
  64. def connect(self, address, path='/'):
  65. """
  66. Equivalent to socket.connect(), but sends an client handshake request
  67. after connecting.
  68. `address` is a (host, port) tuple of the server to connect to.
  69. `path` is optional, used as the *location* part of the HTTP handshake.
  70. In a URL, this would show as ws://host[:port]/path.
  71. """
  72. self.sock.connect(address)
  73. self.client_handshake(address, path)
  74. def send(self, *args):
  75. """
  76. Send a number of frames.
  77. """
  78. for frame in args:
  79. #print 'send frame:', frame, 'to %s:%d' % self.sock.getpeername()
  80. self.sock.sendall(frame.pack())
  81. def recv(self):
  82. """
  83. Receive a single frames. This can be either a data frame or a control
  84. frame.
  85. """
  86. frame = receive_frame(self.sock)
  87. #print 'receive frame:', frame, 'from %s:%d' % self.sock.getpeername()
  88. return frame
  89. def recvn(self, n):
  90. """
  91. Receive exactly `n` frames. These can be either data frames or control
  92. frames, or a combination of both.
  93. """
  94. return [self.recv() for i in xrange(n)]
  95. def getpeername(self):
  96. return self.sock.getpeername()
  97. def getsockname(self):
  98. return self.sock.getsockname()
  99. def setsockopt(self, level, optname, value):
  100. self.sock.setsockopt(level, optname, value)
  101. def getsockopt(self, level, optname):
  102. return self.sock.getsockopt(level, optname)
  103. def close(self):
  104. self.sock.close()
  105. def server_handshake(self):
  106. """
  107. Execute a handshake as the server end point of the socket. If the HTTP
  108. request headers sent by the client are invalid, a HandshakeError
  109. is raised.
  110. """
  111. def fail(msg):
  112. self.sock.close()
  113. raise HandshakeError(msg)
  114. # Receive HTTP header
  115. raw_headers = ''
  116. while raw_headers[-4:] not in ('\r\n\r\n', '\n\n'):
  117. raw_headers += self.sock.recv(512).decode('utf-8', 'ignore')
  118. # Request must be HTTP (at least 1.1) GET request, find the location
  119. match = re.search(r'^GET (.*) HTTP/1.1\r\n', raw_headers)
  120. if match is None:
  121. fail('not a valid HTTP 1.1 GET request')
  122. location = match.group(1)
  123. headers = re.findall(r'(.*?): ?(.*?)\r\n', raw_headers)
  124. header_names = [name for name, value in headers]
  125. def header(name):
  126. return ', '.join([v for n, v in headers if n == name])
  127. # Check if headers that MUST be present are actually present
  128. for name in ('Host', 'Upgrade', 'Connection', 'Sec-WebSocket-Key',
  129. 'Sec-WebSocket-Version'):
  130. if name not in header_names:
  131. fail('missing "%s" header' % name)
  132. # Check WebSocket version used by client
  133. version = header('Sec-WebSocket-Version')
  134. if version != WS_VERSION:
  135. fail('WebSocket version %s requested (only %s '
  136. 'is supported)' % (version, WS_VERSION))
  137. # Verify required header keywords
  138. if 'websocket' not in header('Upgrade').lower():
  139. fail('"Upgrade" header must contain "websocket"')
  140. if 'upgrade' not in header('Connection').lower():
  141. fail('"Connection" header must contain "Upgrade"')
  142. # Origin must be present if browser client
  143. if 'User-Agent' in header_names and 'Origin' not in header_names:
  144. fail('browser client must specify "Origin" header')
  145. # Only supported protocols are returned
  146. client_protocols = split_stripped(header('Sec-WebSocket-Extensions'))
  147. protocol = 'null'
  148. for p in client_protocols:
  149. if p in self.protocols:
  150. protocol = p
  151. break
  152. # Only supported extensions are returned
  153. extensions = split_stripped(header('Sec-WebSocket-Extensions'))
  154. extensions = [e for e in extensions if e in self.extensions]
  155. # Encode acceptation key using the WebSocket GUID
  156. key = header('Sec-WebSocket-Key').strip()
  157. accept = b64encode(sha1(key + WS_GUID).digest())
  158. # Construct HTTP response header
  159. shake = 'HTTP/1.1 101 Web Socket Protocol Handshake\r\n'
  160. shake += 'Upgrade: websocket\r\n'
  161. shake += 'Connection: Upgrade\r\n'
  162. shake += 'WebSocket-Origin: %s\r\n' % header('Origin')
  163. shake += 'WebSocket-Location: ws://%s%s\r\n' \
  164. % (header('Host'), location)
  165. shake += 'Sec-WebSocket-Accept: %s\r\n' % accept
  166. shake += 'Sec-WebSocket-Protocol: %s\r\n' % protocol
  167. shake += 'Sec-WebSocket-Extensions: %s\r\n' % ', '.join(extensions)
  168. self.sock.sendall(shake + '\r\n')
  169. self.handshake_started = True
  170. def client_handshake(self, address, location):
  171. """
  172. Executes a handshake as the client end point of the socket. May raise a
  173. HandshakeError if the server response is invalid.
  174. """
  175. def fail(msg):
  176. self.sock.close()
  177. raise HandshakeError(msg)
  178. if len(location) == 0:
  179. fail('request location is empty')
  180. # Generate a 16-byte random base64-encoded key for this connection
  181. key = b64encode(os.urandom(16))
  182. # Send client handshake
  183. shake = 'GET %s HTTP/1.1\r\n' % location
  184. shake += 'Host: %s:%d\r\n' % address
  185. shake += 'Upgrade: websocket\r\n'
  186. shake += 'Connection: keep-alive, Upgrade\r\n'
  187. shake += 'Sec-WebSocket-Key: %s\r\n' % key
  188. shake += 'Origin: null\r\n' # FIXME: is this correct/necessary?
  189. shake += 'Sec-WebSocket-Version: %s\r\n' % WS_VERSION
  190. # These are for eagerly caching webservers
  191. shake += 'Pragma: no-cache\r\n'
  192. shake += 'Cache-Control: no-cache\r\n'
  193. # Request protocols and extension, these are later checked with the
  194. # actual supported values from the server's response
  195. if self.protocols:
  196. shake += 'Sec-WebSocket-Protocol: %s\r\n' % ', '.join(self.protocols)
  197. if self.extensions:
  198. shake += 'Sec-WebSocket-Extensions: %s\r\n' % ', '.join(self.extensions)
  199. self.sock.sendall(shake + '\r\n')
  200. self.handshake_started = True
  201. # Receive and process server handshake
  202. raw_headers = ''
  203. while raw_headers[-4:] not in ('\r\n\r\n', '\n\n'):
  204. raw_headers += self.sock.recv(512).decode('utf-8', 'ignore')
  205. # Response must be HTTP (at least 1.1) with status 101
  206. if not raw_headers.startswith('HTTP/1.1 101'):
  207. # TODO: implement HTTP authentication (401) and redirect (3xx)?
  208. fail('not a valid HTTP 1.1 status 101 response')
  209. headers = re.findall(r'(.*?): ?(.*?)\r\n', raw_headers)
  210. header_names = [name for name, value in headers]
  211. def header(name):
  212. return ', '.join([v for n, v in headers if n == name])
  213. # Check if headers that MUST be present are actually present
  214. for name in ('Upgrade', 'Connection', 'Sec-WebSocket-Accept'):
  215. if name not in header_names:
  216. fail('missing "%s" header' % name)
  217. if 'websocket' not in header('Upgrade').lower():
  218. fail('"Upgrade" header must contain "websocket"')
  219. if 'upgrade' not in header('Connection').lower():
  220. fail('"Connection" header must contain "Upgrade"')
  221. # Verify accept header
  222. accept = header('Sec-WebSocket-Accept').strip()
  223. required_accept = b64encode(sha1(key + WS_GUID).digest())
  224. if accept != required_accept:
  225. fail('invalid websocket accept header "%s"' % accept)
  226. # Compare extensions
  227. server_extensions = split_stripped(header('Sec-WebSocket-Extensions'))
  228. for ext in server_extensions:
  229. if ext not in self.extensions:
  230. fail('server extension "%s" is unsupported by client' % ext)
  231. for ext in self.extensions:
  232. if ext not in server_extensions:
  233. fail('client extension "%s" is unsupported by server' % ext)
  234. # Assert that returned protocol is supported
  235. protocol = header('Sec-WebSocket-Protocol')
  236. if protocol:
  237. if protocol != 'null' and protocol not in self.protocols:
  238. fail('unsupported protocol "%s"' % protocol)
  239. self.protocol = protocol
  240. def enable_ssl(self, *args, **kwargs):
  241. """
  242. Transforms the regular socket.socket to an ssl.SSLSocket for secure
  243. connections. Any arguments are passed to ssl.wrap_socket:
  244. http://docs.python.org/dev/library/ssl.html#ssl.wrap_socket
  245. """
  246. if self.handshake_started:
  247. raise SSLError('can only enable SSL before handshake')
  248. self.secure = True
  249. self.sock = ssl.wrap_socket(self.sock, *args, **kwargs)