handshake.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336
  1. import os
  2. import re
  3. from hashlib import sha1
  4. from base64 import b64encode
  5. from urlparse import urlparse
  6. from python_digest import build_authorization_request
  7. from errors import HandshakeError
# GUID that is concatenated to the Sec-WebSocket-Key value before it is
# SHA-1 hashed and base64-encoded into the Sec-WebSocket-Accept value
WS_GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
# Value sent by the client in, and required by the server from, the
# "Sec-WebSocket-Version" header
WS_VERSION = '13'
# Number of HTTP redirects a client handshake follows before it fails
MAX_REDIRECTS = 10
  11. def split_stripped(value, delim=','):
  12. return map(str.strip, str(value).split(delim)) if value else []
  13. class Handshake(object):
  14. def __init__(self, wsock):
  15. self.wsock = wsock
  16. self.sock = wsock.sock
  17. def fail(self, msg):
  18. self.sock.close()
  19. raise HandshakeError(msg)
  20. def receive_request(self):
  21. raw, headers = self.receive_headers()
  22. # Request must be HTTP (at least 1.1) GET request, find the location
  23. match = re.search(r'^GET (.*) HTTP/1.1\r\n', raw)
  24. if match is None:
  25. self.fail('not a valid HTTP 1.1 GET request')
  26. location = match.group(1)
  27. return location, headers
  28. def receive_response(self):
  29. raw, headers = self.receive_headers()
  30. # Response must be HTTP (at least 1.1) with status 101
  31. match = re.search(r'^HTTP/1\.1 (\d{3})', raw)
  32. if match is None:
  33. self.fail('not a valid HTTP 1.1 response')
  34. status = int(match.group(1))
  35. return status, headers
  36. def receive_headers(self):
  37. # Receive entire HTTP header
  38. raw_headers = ''
  39. while raw_headers[-4:] not in ('\r\n\r\n', '\n\n'):
  40. raw_headers += self.sock.recv(512).decode('utf-8', 'ignore')
  41. headers = {}
  42. for key, value in re.findall(r'(.*?): ?(.*?)\r\n', raw_headers):
  43. if key in headers:
  44. headers[key] += ', ' + value
  45. else:
  46. headers[key] = value
  47. return raw_headers, headers
  48. def send_headers(self, headers):
  49. # Send request
  50. for hdr in list(headers):
  51. if isinstance(hdr, tuple):
  52. hdr = '%s: %s' % hdr
  53. self.sock.sendall(hdr + '\r\n')
  54. self.sock.sendall('\r\n')
  55. def perform(self):
  56. raise NotImplementedError
  57. class ServerHandshake(Handshake):
  58. """
  59. Executes a handshake as the server end point of the socket. If the HTTP
  60. request headers sent by the client are invalid, a HandshakeError is raised.
  61. """
  62. def perform(self):
  63. # Receive and validate client handshake
  64. self.wsock.location, headers = self.receive_request()
  65. # Send server handshake in response
  66. self.send_headers(self.response_headers(headers))
  67. def response_headers(self, headers):
  68. # Check if headers that MUST be present are actually present
  69. for name in ('Host', 'Upgrade', 'Connection', 'Sec-WebSocket-Key',
  70. 'Sec-WebSocket-Version'):
  71. if name not in headers:
  72. self.fail('missing "%s" header' % name)
  73. # Check WebSocket version used by client
  74. version = headers['Sec-WebSocket-Version']
  75. if version != WS_VERSION:
  76. self.fail('WebSocket version %s requested (only %s is supported)'
  77. % (version, WS_VERSION))
  78. # Verify required header keywords
  79. if 'websocket' not in headers['Upgrade'].lower():
  80. self.fail('"Upgrade" header must contain "websocket"')
  81. if 'upgrade' not in headers['Connection'].lower():
  82. self.fail('"Connection" header must contain "Upgrade"')
  83. # Origin must be present if browser client, and must match the list of
  84. # trusted origins
  85. origin = 'null'
  86. if 'Origin' not in headers:
  87. if 'User-Agent' in headers:
  88. self.fail('browser client must specify "Origin" header')
  89. if self.wsock.trusted_origins:
  90. self.fail('no "Origin" header specified, assuming untrusted')
  91. elif self.wsock.trusted_origins:
  92. origin = headers['Origin']
  93. if origin not in self.wsock.trusted_origins:
  94. self.fail('untrusted origin "%s"' % origin)
  95. # Only a supported protocol can be returned
  96. client_proto = split_stripped(headers['Sec-WebSocket-Protocol']) \
  97. if 'Sec-WebSocket-Protocol' in headers else []
  98. protocol = None
  99. for p in client_proto:
  100. if p in self.wsock.proto:
  101. protocol = p
  102. break
  103. # Only supported extensions are returned
  104. if 'Sec-WebSocket-Extensions' in headers:
  105. client_ext = split_stripped(headers['Sec-WebSocket-Extensions'])
  106. extensions = [e for e in client_ext if e in self.wsock.extensions]
  107. else:
  108. extensions = []
  109. # Encode acceptation key using the WebSocket GUID
  110. key = headers['Sec-WebSocket-Key'].strip()
  111. accept = b64encode(sha1(key + WS_GUID).digest())
  112. # Location scheme differs for SSL-enabled connections
  113. scheme = 'wss' if self.wsock.secure else 'ws'
  114. if 'Host' in headers:
  115. host = headers['Host']
  116. else:
  117. host, port = self.sock.getpeername()
  118. default_port = 443 if self.wsock.secure else 80
  119. if port != default_port:
  120. host += ':%d' % port
  121. location = '%s://%s%s' % (scheme, host, self.wsock.location)
  122. # Construct HTTP response header
  123. yield 'HTTP/1.1 101 Web Socket Protocol Handshake'
  124. yield 'Upgrade', 'websocket'
  125. yield 'Connection', 'Upgrade'
  126. yield 'WebSocket-Origin', origin
  127. yield 'WebSocket-Location', location
  128. yield 'Sec-WebSocket-Accept', accept
  129. if protocol:
  130. yield 'Sec-WebSocket-Protocol', protocol
  131. if extensions:
  132. yield 'Sec-WebSocket-Extensions', ', '.join(extensions)
  133. class ClientHandshake(Handshake):
  134. """
  135. Executes a handshake as the client end point of the socket. May raise a
  136. HandshakeError if the server response is invalid.
  137. """
  138. def __init__(self, wsock):
  139. Handshake.__init__(self, wsock)
  140. self.redirects = 0
  141. def perform(self):
  142. self.send_headers(self.request_headers())
  143. self.handle_response(*self.receive_response())
  144. def handle_response(self, status, headers):
  145. if status == 101:
  146. self.handle_handshake(headers)
  147. elif status == 401:
  148. self.handle_auth(headers)
  149. elif status in (301, 302, 303, 307, 308):
  150. self.handle_redirect(headers)
  151. else:
  152. self.fail('invalid HTTP response status %d' % status)
  153. def handle_handshake(self, headers):
  154. # Check if headers that MUST be present are actually present
  155. for name in ('Upgrade', 'Connection', 'Sec-WebSocket-Accept'):
  156. if name not in headers:
  157. self.fail('missing "%s" header' % name)
  158. if 'websocket' not in headers['Upgrade'].lower():
  159. self.fail('"Upgrade" header must contain "websocket"')
  160. if 'upgrade' not in headers['Connection'].lower():
  161. self.fail('"Connection" header must contain "Upgrade"')
  162. # Verify accept header
  163. accept = headers['Sec-WebSocket-Accept'].strip()
  164. required_accept = b64encode(sha1(self.key + WS_GUID).digest())
  165. if accept != required_accept:
  166. self.fail('invalid websocket accept header "%s"' % accept)
  167. # Compare extensions
  168. if 'Sec-WebSocket-Extensions' in headers:
  169. server_ext = split_stripped(headers['Sec-WebSocket-Extensions'])
  170. for e in set(server_ext) - set(self.wsock.extensions):
  171. self.fail('server extension "%s" unsupported by client' % e)
  172. for e in set(self.wsock.extensions) - set(server_ext):
  173. self.fail('client extension "%s" unsupported by server' % e)
  174. # Assert that returned protocol (if any) is supported
  175. if 'Sec-WebSocket-Protocol' in headers:
  176. protocol = headers['Sec-WebSocket-Protocol']
  177. if protocol != 'null' and protocol not in self.wsock.protocols:
  178. self.fail('unsupported protocol "%s"' % protocol)
  179. self.wsock.protocol = protocol
  180. def handle_auth(self, headers):
  181. # HTTP authentication is required in the request
  182. hdr = headers['WWW-Authenticate']
  183. authres = dict(re.findall(r'(\w+)[:=] ?"?(\w+)"?', hdr))
  184. mode = hdr.lstrip().split(' ', 1)[0]
  185. if not self.wsock.auth:
  186. self.fail('missing username and password for HTTP authentication')
  187. if mode == 'Basic':
  188. auth_hdr = self.http_auth_basic_headers(**authres)
  189. elif mode == 'Digest':
  190. auth_hdr = self.http_auth_digest_headers(**authres)
  191. else:
  192. self.fail('unsupported HTTP authentication mode "%s"' % mode)
  193. # Send new, authenticated handshake
  194. self.send_headers(list(self.request_headers()) + list(auth_hdr))
  195. self.handle_response(*self.receive_response())
  196. def handle_redirect(self, headers):
  197. self.redirects += 1
  198. if self.redirects > MAX_REDIRECTS:
  199. self.fail('reached maximum number of redirects (%d)'
  200. % MAX_REDIRECTS)
  201. # Handle HTTP redirect
  202. url = urlparse(headers['Location'].strip())
  203. # Reconnect socket to new host if net location changed
  204. if not url.port:
  205. url.port = 443 if self.secure else 80
  206. addr = (url.netloc, url.port)
  207. if addr != self.sock.getpeername():
  208. self.sock.close()
  209. self.sock.connect(addr)
  210. # Update websocket object and send new handshake
  211. self.wsock.location = url.path
  212. self.perform()
  213. def request_headers(self):
  214. if len(self.wsock.location) == 0:
  215. self.fail('request location is empty')
  216. # Generate a 16-byte random base64-encoded key for this connection
  217. self.key = b64encode(os.urandom(16))
  218. # Send client handshake
  219. yield 'GET %s HTTP/1.1' % self.wsock.location
  220. yield 'Host', '%s:%d' % self.sock.getpeername()
  221. yield 'Upgrade', 'websocket'
  222. yield 'Connection', 'keep-alive, Upgrade'
  223. yield 'Sec-WebSocket-Key', self.key
  224. yield 'Sec-WebSocket-Version', WS_VERSION
  225. if self.wsock.origin:
  226. yield 'Origin', self.wsock.origin
  227. # These are for eagerly caching webservers
  228. yield 'Pragma', 'no-cache'
  229. yield 'Cache-Control', 'no-cache'
  230. # Request protocols and extension, these are later checked with the
  231. # actual supported values from the server's response
  232. if self.wsock.protocols:
  233. yield 'Sec-WebSocket-Protocol', ', '.join(self.wsock.protocols)
  234. if self.wsock.extensions:
  235. yield 'Sec-WebSocket-Extensions', ', '.join(self.wsock.extensions)
  236. def http_auth_basic_headers(self, **kwargs):
  237. u, p = self.wsock.auth
  238. u = u.encode('utf-8')
  239. p = p.encode('utf-8')
  240. yield 'Authorization', 'Basic ' + b64encode(u + ':' + p)
  241. def http_auth_digest_headers(self, **kwargs):
  242. username, password = self.wsock.auth
  243. yield 'Authorization', build_authorization_request(
  244. username=username.encode('utf-8'),
  245. method='GET',
  246. uri=self.wsock.location,
  247. nonce_count=0,
  248. realm=kwargs['realm'],
  249. nonce=kwargs['nonce'],
  250. opaque=kwargs['opaque'],
  251. password=password.encode('utf-8'))