# handshake.py
  1. import os
  2. import re
  3. from hashlib import sha1
  4. from base64 import b64encode
  5. from urlparse import urlparse
  6. from python_digest import build_authorization_request
  7. from errors import HandshakeError
  8. from extension import filter_extensions
  9. WS_GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
  10. WS_VERSION = '13'
  11. MAX_REDIRECTS = 10
  12. class Handshake(object):
  13. def __init__(self, wsock):
  14. self.wsock = wsock
  15. self.sock = wsock.sock
  16. def fail(self, msg):
  17. self.sock.close()
  18. raise HandshakeError(msg)
  19. def receive_request(self):
  20. raw, headers = self.receive_headers()
  21. # Request must be HTTP (at least 1.1) GET request, find the location
  22. # (without trailing slash)
  23. match = re.search(r'^GET (.*?)/* HTTP/1.1\r\n', raw)
  24. if match is None:
  25. self.fail('not a valid HTTP 1.1 GET request')
  26. location = match.group(1)
  27. return location, headers
  28. def receive_response(self):
  29. raw, headers = self.receive_headers()
  30. # Response must be HTTP (at least 1.1) with status 101
  31. match = re.search(r'^HTTP/1\.1 (\d{3})', raw)
  32. if match is None:
  33. self.fail('not a valid HTTP 1.1 response')
  34. status = int(match.group(1))
  35. return status, headers
  36. def receive_headers(self):
  37. # Receive entire HTTP header
  38. raw_headers = ''
  39. while raw_headers[-4:] not in ('\r\n\r\n', '\n\n'):
  40. raw_headers += self.sock.recv(512).decode('utf-8', 'ignore')
  41. headers = {}
  42. for key, value in re.findall(r'(.*?): ?(.*?)\r\n', raw_headers):
  43. if key in headers:
  44. headers[key] += ', ' + value
  45. else:
  46. headers[key] = value
  47. return raw_headers, headers
  48. def send_headers(self, headers):
  49. # Send request
  50. for hdr in list(headers):
  51. if isinstance(hdr, tuple):
  52. hdr = '%s: %s' % hdr
  53. self.sock.sendall(hdr + '\r\n')
  54. self.sock.sendall('\r\n')
  55. def perform(self):
  56. raise NotImplementedError
  57. class ServerHandshake(Handshake):
  58. """
  59. Executes a handshake as the server end point of the socket. If the HTTP
  60. request headers sent by the client are invalid, a HandshakeError is raised.
  61. """
  62. def perform(self, ssock):
  63. # Receive and validate client handshake
  64. location, headers = self.receive_request()
  65. self.wsock.location = location
  66. self.wsock.request_headers = headers
  67. # Send server handshake in response
  68. self.send_headers(self.response_headers(ssock))
  69. def response_headers(self, ssock):
  70. headers = self.wsock.request_headers
  71. # Check if headers that MUST be present are actually present
  72. for name in ('Host', 'Upgrade', 'Connection', 'Sec-WebSocket-Key',
  73. 'Sec-WebSocket-Version'):
  74. if name not in headers:
  75. self.fail('missing "%s" header' % name)
  76. # Check WebSocket version used by client
  77. version = headers['Sec-WebSocket-Version']
  78. if version != WS_VERSION:
  79. self.fail('WebSocket version %s requested (only %s is supported)'
  80. % (version, WS_VERSION))
  81. # Verify required header keywords
  82. if 'websocket' not in headers['Upgrade'].lower():
  83. self.fail('"Upgrade" header must contain "websocket"')
  84. if 'upgrade' not in headers['Connection'].lower():
  85. self.fail('"Connection" header must contain "Upgrade"')
  86. # Origin must be present if browser client, and must match the list of
  87. # trusted origins
  88. if 'Origin' not in headers and 'User-Agent' in headers:
  89. self.fail('browser client must specify "Origin" header')
  90. origin = headers.get('Origin', 'null')
  91. if origin == 'null':
  92. if ssock.trusted_origins:
  93. self.fail('no "Origin" header specified, assuming untrusted')
  94. elif ssock.trusted_origins and origin not in ssock.trusted_origins:
  95. self.fail('untrusted origin "%s"' % origin)
  96. # Only a supported protocol can be returned
  97. client_proto = split_stripped(headers['Sec-WebSocket-Protocol']) \
  98. if 'Sec-WebSocket-Protocol' in headers else []
  99. self.wsock.protocol = None
  100. for p in client_proto:
  101. if p in ssock.protocols:
  102. self.wsock.protocol = p
  103. break
  104. # Only supported extensions are returned
  105. if 'Sec-WebSocket-Extensions' in headers:
  106. supported_ext = dict((e.name, e) for e in ssock.extensions)
  107. extensions = []
  108. all_params = []
  109. for ext in split_stripped(headers['Sec-WebSocket-Extensions']):
  110. name, params = parse_param_hdr(ext)
  111. if name in supported_ext:
  112. extensions.append(supported_ext[name])
  113. all_params.append(params)
  114. self.wsock.extensions = filter_extensions(extensions)
  115. for ext, params in zip(self.wsock.extensions, all_params):
  116. hook = ext.create_hook(**params)
  117. self.wsock.add_hook(send=hook.send, recv=hook.recv)
  118. else:
  119. self.wsock.extensions = []
  120. # Check if requested resource location is served by this server
  121. if ssock.locations:
  122. if self.wsock.location not in ssock.locations:
  123. raise HandshakeError('location "%s" is not supported by this '
  124. 'server' % self.wsock.location)
  125. # Encode acceptation key using the WebSocket GUID
  126. key = headers['Sec-WebSocket-Key'].strip()
  127. accept = b64encode(sha1(key + WS_GUID).digest())
  128. # Location scheme differs for SSL-enabled connections
  129. scheme = 'wss' if self.wsock.secure else 'ws'
  130. if 'Host' in headers:
  131. host = headers['Host']
  132. else:
  133. host, port = self.sock.getpeername()
  134. default_port = 443 if self.wsock.secure else 80
  135. if port != default_port:
  136. host += ':%d' % port
  137. location = '%s://%s%s' % (scheme, host, self.wsock.location)
  138. # Construct HTTP response header
  139. yield 'HTTP/1.1 101 Web Socket Protocol Handshake'
  140. yield 'Upgrade', 'websocket'
  141. yield 'Connection', 'Upgrade'
  142. yield 'Sec-WebSocket-Origin', origin
  143. yield 'Sec-WebSocket-Location', location
  144. yield 'Sec-WebSocket-Accept', accept
  145. if self.wsock.protocol:
  146. yield 'Sec-WebSocket-Protocol', self.wsock.protocol
  147. if self.wsock.extensions:
  148. values = [format_param_hdr(e.name, e.request)
  149. for e in self.wsock.extensions]
  150. yield 'Sec-WebSocket-Extensions', ', '.join(values)
  151. class ClientHandshake(Handshake):
  152. """
  153. Executes a handshake as the client end point of the socket. May raise a
  154. HandshakeError if the server response is invalid.
  155. """
  156. def __init__(self, wsock):
  157. Handshake.__init__(self, wsock)
  158. self.redirects = 0
  159. def perform(self):
  160. self.send_headers(self.request_headers())
  161. self.handle_response(*self.receive_response())
  162. def handle_response(self, status, headers):
  163. if status == 101:
  164. self.handle_handshake(headers)
  165. elif status == 401:
  166. self.handle_auth(headers)
  167. elif status in (301, 302, 303, 307, 308):
  168. self.handle_redirect(headers)
  169. else:
  170. self.fail('invalid HTTP response status %d' % status)
  171. def handle_handshake(self, headers):
  172. # Check if headers that MUST be present are actually present
  173. for name in ('Upgrade', 'Connection', 'Sec-WebSocket-Accept'):
  174. if name not in headers:
  175. self.fail('missing "%s" header' % name)
  176. if 'websocket' not in headers['Upgrade'].lower():
  177. self.fail('"Upgrade" header must contain "websocket"')
  178. if 'upgrade' not in headers['Connection'].lower():
  179. self.fail('"Connection" header must contain "Upgrade"')
  180. # Verify accept header
  181. accept = headers['Sec-WebSocket-Accept'].strip()
  182. required_accept = b64encode(sha1(self.key + WS_GUID).digest())
  183. if accept != required_accept:
  184. self.fail('invalid websocket accept header "%s"' % accept)
  185. # Compare extensions, add hooks only for those returned by server
  186. if 'Sec-WebSocket-Extensions' in headers:
  187. supported_ext = dict((e.name, e) for e in self.wsock.extensions)
  188. self.wsock.extensions = []
  189. for ext in split_stripped(headers['Sec-WebSocket-Extensions']):
  190. name, params = parse_param_hdr(ext)
  191. if name not in supported_ext:
  192. raise HandshakeError('server handshake contains '
  193. 'unsupported extension "%s"' % name)
  194. hook = supported_ext[name].create_hook(**params)
  195. self.wsock.extensions.append(supported_ext[name])
  196. self.wsock.add_hook(send=hook.send, recv=hook.recv)
  197. # Assert that returned protocol (if any) is supported
  198. if 'Sec-WebSocket-Protocol' in headers:
  199. protocol = headers['Sec-WebSocket-Protocol']
  200. if protocol != 'null' and protocol not in self.wsock.protocols:
  201. self.fail('unsupported protocol "%s"' % protocol)
  202. self.wsock.protocol = protocol
  203. def handle_auth(self, headers):
  204. # HTTP authentication is required in the request
  205. hdr = headers['WWW-Authenticate']
  206. authres = dict(re.findall(r'(\w+)[:=] ?"?(\w+)"?', hdr))
  207. mode = hdr.lstrip().split(' ', 1)[0]
  208. if not self.wsock.auth:
  209. self.fail('missing username and password for HTTP authentication')
  210. if mode == 'Basic':
  211. auth_hdr = self.http_auth_basic_headers(**authres)
  212. elif mode == 'Digest':
  213. auth_hdr = self.http_auth_digest_headers(**authres)
  214. else:
  215. self.fail('unsupported HTTP authentication mode "%s"' % mode)
  216. # Send new, authenticated handshake
  217. self.send_headers(list(self.request_headers()) + list(auth_hdr))
  218. self.handle_response(*self.receive_response())
  219. def handle_redirect(self, headers):
  220. self.redirects += 1
  221. if self.redirects > MAX_REDIRECTS:
  222. self.fail('reached maximum number of redirects (%d)'
  223. % MAX_REDIRECTS)
  224. # Handle HTTP redirect
  225. url = urlparse(headers['Location'].strip())
  226. # Reconnect socket to new host if net location changed
  227. if not url.port:
  228. url.port = 443 if self.secure else 80
  229. addr = (url.netloc, url.port)
  230. if addr != self.sock.getpeername():
  231. self.sock.close()
  232. self.sock.connect(addr)
  233. # Update websocket object and send new handshake
  234. self.wsock.location = url.path
  235. self.perform()
  236. def request_headers(self):
  237. if len(self.wsock.location) == 0:
  238. self.fail('request location is empty')
  239. # Generate a 16-byte random base64-encoded key for this connection
  240. self.key = b64encode(os.urandom(16))
  241. # Send client handshake
  242. yield 'GET %s HTTP/1.1' % self.wsock.location
  243. yield 'Host', '%s:%d' % self.sock.getpeername()
  244. yield 'Upgrade', 'websocket'
  245. yield 'Connection', 'keep-alive, Upgrade'
  246. yield 'Sec-WebSocket-Key', self.key
  247. yield 'Sec-WebSocket-Version', WS_VERSION
  248. if self.wsock.origin:
  249. yield 'Origin', self.wsock.origin
  250. # These are for eagerly caching webservers
  251. yield 'Pragma', 'no-cache'
  252. yield 'Cache-Control', 'no-cache'
  253. # Request protocols and extensions, these are later checked with the
  254. # actual supported values from the server's response
  255. if self.wsock.protocols:
  256. yield 'Sec-WebSocket-Protocol', ', '.join(self.wsock.protocols)
  257. if self.wsock.extensions:
  258. values = [format_param_hdr(e.name, e.request)
  259. for e in self.wsock.extensions]
  260. yield 'Sec-WebSocket-Extensions', ', '.join(values)
  261. def http_auth_basic_headers(self, **kwargs):
  262. u, p = self.wsock.auth
  263. u = u.encode('utf-8')
  264. p = p.encode('utf-8')
  265. yield 'Authorization', 'Basic ' + b64encode(u + ':' + p)
  266. def http_auth_digest_headers(self, **kwargs):
  267. username, password = self.wsock.auth
  268. yield 'Authorization', build_authorization_request(
  269. username=username.encode('utf-8'),
  270. method='GET',
  271. uri=self.wsock.location,
  272. nonce_count=0,
  273. realm=kwargs['realm'],
  274. nonce=kwargs['nonce'],
  275. opaque=kwargs['opaque'],
  276. password=password.encode('utf-8'))
  277. def split_stripped(value, delim=',', maxsplits=-1):
  278. return map(str.strip, str(value).split(delim, maxsplits)) if value else []
  279. def parse_param_hdr(hdr):
  280. if ';' in hdr:
  281. name, paramstr = split_stripped(hdr, ';', 1)
  282. else:
  283. name = hdr
  284. paramstr = ''
  285. params = {}
  286. for param in split_stripped(paramstr):
  287. if '=' in param:
  288. key, value = split_stripped(param, '=', 1)
  289. if value.isdigit():
  290. value = int(value)
  291. else:
  292. key = param
  293. value = True
  294. params[key] = value
  295. return name, params
  296. def format_param_hdr(value, params):
  297. if not params:
  298. return value
  299. def fmt_param((k, v)):
  300. if v is True:
  301. return k
  302. if v is not False and v is not None:
  303. return k + '=' + str(v)
  304. strparams = filter(None, map(fmt_param, params.items()))
  305. return '%s; %s' % (value, ', '.join(strparams))