# handshake.py
  1. import os
  2. import re
  3. from hashlib import sha1
  4. from base64 import b64encode
  5. from urlparse import urlparse
  6. from python_digest import build_authorization_request
  7. from errors import HandshakeError
  8. from extension import filter_extensions
# Fixed GUID that is appended to the client's Sec-WebSocket-Key before the
# SHA-1/base64 step that produces the Sec-WebSocket-Accept value
WS_GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
# WebSocket protocol version sent by the client handshake and required by the
# server handshake
WS_VERSION = '13'
# Number of HTTP redirects a client handshake follows before it fails
MAX_REDIRECTS = 10
  12. class Handshake(object):
  13. def __init__(self, wsock):
  14. self.wsock = wsock
  15. self.sock = wsock.sock
  16. def fail(self, msg):
  17. self.sock.close()
  18. raise HandshakeError(msg)
  19. def receive_request(self):
  20. raw, headers = self.receive_headers()
  21. # Request must be HTTP (at least 1.1) GET request, find the location
  22. match = re.search(r'^GET (.*) HTTP/1.1\r\n', raw)
  23. if match is None:
  24. self.fail('not a valid HTTP 1.1 GET request')
  25. location = match.group(1)
  26. return location, headers
  27. def receive_response(self):
  28. raw, headers = self.receive_headers()
  29. # Response must be HTTP (at least 1.1) with status 101
  30. match = re.search(r'^HTTP/1\.1 (\d{3})', raw)
  31. if match is None:
  32. self.fail('not a valid HTTP 1.1 response')
  33. status = int(match.group(1))
  34. return status, headers
  35. def receive_headers(self):
  36. # Receive entire HTTP header
  37. raw_headers = ''
  38. while raw_headers[-4:] not in ('\r\n\r\n', '\n\n'):
  39. raw_headers += self.sock.recv(512).decode('utf-8', 'ignore')
  40. headers = {}
  41. for key, value in re.findall(r'(.*?): ?(.*?)\r\n', raw_headers):
  42. if key in headers:
  43. headers[key] += ', ' + value
  44. else:
  45. headers[key] = value
  46. return raw_headers, headers
  47. def send_headers(self, headers):
  48. # Send request
  49. for hdr in list(headers):
  50. if isinstance(hdr, tuple):
  51. hdr = '%s: %s' % hdr
  52. self.sock.sendall(hdr + '\r\n')
  53. self.sock.sendall('\r\n')
  54. def perform(self):
  55. raise NotImplementedError
  56. class ServerHandshake(Handshake):
  57. """
  58. Executes a handshake as the server end point of the socket. If the HTTP
  59. request headers sent by the client are invalid, a HandshakeError is raised.
  60. """
  61. def perform(self):
  62. # Receive and validate client handshake
  63. self.wsock.location, headers = self.receive_request()
  64. # Send server handshake in response
  65. self.send_headers(self.response_headers(headers))
  66. def response_headers(self, headers):
  67. # Check if headers that MUST be present are actually present
  68. for name in ('Host', 'Upgrade', 'Connection', 'Sec-WebSocket-Key',
  69. 'Sec-WebSocket-Version'):
  70. if name not in headers:
  71. self.fail('missing "%s" header' % name)
  72. # Check WebSocket version used by client
  73. version = headers['Sec-WebSocket-Version']
  74. if version != WS_VERSION:
  75. self.fail('WebSocket version %s requested (only %s is supported)'
  76. % (version, WS_VERSION))
  77. # Verify required header keywords
  78. if 'websocket' not in headers['Upgrade'].lower():
  79. self.fail('"Upgrade" header must contain "websocket"')
  80. if 'upgrade' not in headers['Connection'].lower():
  81. self.fail('"Connection" header must contain "Upgrade"')
  82. # Origin must be present if browser client, and must match the list of
  83. # trusted origins
  84. origin = 'null'
  85. if 'Origin' not in headers:
  86. if 'User-Agent' in headers:
  87. self.fail('browser client must specify "Origin" header')
  88. if self.wsock.trusted_origins:
  89. self.fail('no "Origin" header specified, assuming untrusted')
  90. elif self.wsock.trusted_origins:
  91. origin = headers['Origin']
  92. if origin not in self.wsock.trusted_origins:
  93. self.fail('untrusted origin "%s"' % origin)
  94. # Only a supported protocol can be returned
  95. client_proto = split_stripped(headers['Sec-WebSocket-Protocol']) \
  96. if 'Sec-WebSocket-Protocol' in headers else []
  97. self.wsock.protocol = None
  98. for p in client_proto:
  99. if p in self.wsock.protocols:
  100. self.wsock.protocol = p
  101. break
  102. # Only supported extensions are returned
  103. if 'Sec-WebSocket-Extensions' in headers:
  104. supported_ext = dict((e.name, e) for e in self.wsock.extensions)
  105. extensions = []
  106. all_params = []
  107. for ext in split_stripped(headers['Sec-WebSocket-Extensions']):
  108. name, params = parse_param_hdr(ext)
  109. if name in supported_ext:
  110. extensions.append(supported_ext[name])
  111. all_params.append(params)
  112. self.wsock.extensions = filter_extensions(extensions)
  113. for ext, params in zip(self.wsock.extensions, all_params):
  114. hook = ext.Hook(**params)
  115. self.wsock.add_hook(send=hook.send, recv=hook.recv)
  116. else:
  117. self.wsock.extensions = []
  118. # Encode acceptation key using the WebSocket GUID
  119. key = headers['Sec-WebSocket-Key'].strip()
  120. accept = b64encode(sha1(key + WS_GUID).digest())
  121. # Location scheme differs for SSL-enabled connections
  122. scheme = 'wss' if self.wsock.secure else 'ws'
  123. if 'Host' in headers:
  124. host = headers['Host']
  125. else:
  126. host, port = self.sock.getpeername()
  127. default_port = 443 if self.wsock.secure else 80
  128. if port != default_port:
  129. host += ':%d' % port
  130. location = '%s://%s%s' % (scheme, host, self.wsock.location)
  131. # Construct HTTP response header
  132. yield 'HTTP/1.1 101 Web Socket Protocol Handshake'
  133. yield 'Upgrade', 'websocket'
  134. yield 'Connection', 'Upgrade'
  135. yield 'WebSocket-Origin', origin
  136. yield 'WebSocket-Location', location
  137. yield 'Sec-WebSocket-Accept', accept
  138. if self.wsock.protocol:
  139. yield 'Sec-WebSocket-Protocol', self.wsock.protocol
  140. if self.wsock.extensions:
  141. values = [format_param_hdr(e.name, e.request)
  142. for e in self.wsock.extensions]
  143. yield 'Sec-WebSocket-Extensions', ', '.join(values)
  144. class ClientHandshake(Handshake):
  145. """
  146. Executes a handshake as the client end point of the socket. May raise a
  147. HandshakeError if the server response is invalid.
  148. """
  149. def __init__(self, wsock):
  150. Handshake.__init__(self, wsock)
  151. self.redirects = 0
  152. def perform(self):
  153. self.send_headers(self.request_headers())
  154. self.handle_response(*self.receive_response())
  155. def handle_response(self, status, headers):
  156. if status == 101:
  157. self.handle_handshake(headers)
  158. elif status == 401:
  159. self.handle_auth(headers)
  160. elif status in (301, 302, 303, 307, 308):
  161. self.handle_redirect(headers)
  162. else:
  163. self.fail('invalid HTTP response status %d' % status)
  164. def handle_handshake(self, headers):
  165. # Check if headers that MUST be present are actually present
  166. for name in ('Upgrade', 'Connection', 'Sec-WebSocket-Accept'):
  167. if name not in headers:
  168. self.fail('missing "%s" header' % name)
  169. if 'websocket' not in headers['Upgrade'].lower():
  170. self.fail('"Upgrade" header must contain "websocket"')
  171. if 'upgrade' not in headers['Connection'].lower():
  172. self.fail('"Connection" header must contain "Upgrade"')
  173. # Verify accept header
  174. accept = headers['Sec-WebSocket-Accept'].strip()
  175. required_accept = b64encode(sha1(self.key + WS_GUID).digest())
  176. if accept != required_accept:
  177. self.fail('invalid websocket accept header "%s"' % accept)
  178. # Compare extensions, add hooks only for those returned by server
  179. if 'Sec-WebSocket-Extensions' in headers:
  180. supported_ext = dict((e.name, e) for e in self.wsock.extensions)
  181. self.wsock.extensions = []
  182. for ext in split_stripped(headers['Sec-WebSocket-Extensions']):
  183. name, params = parse_param_hdr(ext)
  184. if name not in supported_ext:
  185. raise HandshakeError('server handshake contains '
  186. 'unsupported extension "%s"' % name)
  187. hook = supported_ext[name].Hook(**params)
  188. self.wsock.extensions.append(supported_ext[name])
  189. self.wsock.add_hook(send=hook.send, recv=hook.recv)
  190. # Assert that returned protocol (if any) is supported
  191. if 'Sec-WebSocket-Protocol' in headers:
  192. protocol = headers['Sec-WebSocket-Protocol']
  193. if protocol != 'null' and protocol not in self.wsock.protocols:
  194. self.fail('unsupported protocol "%s"' % protocol)
  195. self.wsock.protocol = protocol
  196. def handle_auth(self, headers):
  197. # HTTP authentication is required in the request
  198. hdr = headers['WWW-Authenticate']
  199. authres = dict(re.findall(r'(\w+)[:=] ?"?(\w+)"?', hdr))
  200. mode = hdr.lstrip().split(' ', 1)[0]
  201. if not self.wsock.auth:
  202. self.fail('missing username and password for HTTP authentication')
  203. if mode == 'Basic':
  204. auth_hdr = self.http_auth_basic_headers(**authres)
  205. elif mode == 'Digest':
  206. auth_hdr = self.http_auth_digest_headers(**authres)
  207. else:
  208. self.fail('unsupported HTTP authentication mode "%s"' % mode)
  209. # Send new, authenticated handshake
  210. self.send_headers(list(self.request_headers()) + list(auth_hdr))
  211. self.handle_response(*self.receive_response())
  212. def handle_redirect(self, headers):
  213. self.redirects += 1
  214. if self.redirects > MAX_REDIRECTS:
  215. self.fail('reached maximum number of redirects (%d)'
  216. % MAX_REDIRECTS)
  217. # Handle HTTP redirect
  218. url = urlparse(headers['Location'].strip())
  219. # Reconnect socket to new host if net location changed
  220. if not url.port:
  221. url.port = 443 if self.secure else 80
  222. addr = (url.netloc, url.port)
  223. if addr != self.sock.getpeername():
  224. self.sock.close()
  225. self.sock.connect(addr)
  226. # Update websocket object and send new handshake
  227. self.wsock.location = url.path
  228. self.perform()
  229. def request_headers(self):
  230. if len(self.wsock.location) == 0:
  231. self.fail('request location is empty')
  232. # Generate a 16-byte random base64-encoded key for this connection
  233. self.key = b64encode(os.urandom(16))
  234. # Send client handshake
  235. yield 'GET %s HTTP/1.1' % self.wsock.location
  236. yield 'Host', '%s:%d' % self.sock.getpeername()
  237. yield 'Upgrade', 'websocket'
  238. yield 'Connection', 'keep-alive, Upgrade'
  239. yield 'Sec-WebSocket-Key', self.key
  240. yield 'Sec-WebSocket-Version', WS_VERSION
  241. if self.wsock.origin:
  242. yield 'Origin', self.wsock.origin
  243. # These are for eagerly caching webservers
  244. yield 'Pragma', 'no-cache'
  245. yield 'Cache-Control', 'no-cache'
  246. # Request protocols and extensions, these are later checked with the
  247. # actual supported values from the server's response
  248. if self.wsock.protocols:
  249. yield 'Sec-WebSocket-Protocol', ', '.join(self.wsock.protocols)
  250. if self.wsock.extensions:
  251. values = [format_param_hdr(e.name, e.request)
  252. for e in self.wsock.extensions]
  253. yield 'Sec-WebSocket-Extensions', ', '.join(values)
  254. def http_auth_basic_headers(self, **kwargs):
  255. u, p = self.wsock.auth
  256. u = u.encode('utf-8')
  257. p = p.encode('utf-8')
  258. yield 'Authorization', 'Basic ' + b64encode(u + ':' + p)
  259. def http_auth_digest_headers(self, **kwargs):
  260. username, password = self.wsock.auth
  261. yield 'Authorization', build_authorization_request(
  262. username=username.encode('utf-8'),
  263. method='GET',
  264. uri=self.wsock.location,
  265. nonce_count=0,
  266. realm=kwargs['realm'],
  267. nonce=kwargs['nonce'],
  268. opaque=kwargs['opaque'],
  269. password=password.encode('utf-8'))
  270. def split_stripped(value, delim=',', maxsplits=-1):
  271. return map(str.strip, str(value).split(delim, maxsplits)) if value else []
  272. def parse_param_hdr(hdr):
  273. if ';' in hdr:
  274. name, paramstr = split_stripped(hdr, ';', 1)
  275. else:
  276. name = hdr
  277. paramstr = ''
  278. params = {}
  279. for param in split_stripped(paramstr):
  280. if '=' in param:
  281. key, value = split_stripped(param, '=', 1)
  282. if value.isdigit():
  283. value = int(value)
  284. else:
  285. key = param
  286. value = True
  287. params[key] = value
  288. yield name, params
  289. def format_param_hdr(value, params):
  290. if not params:
  291. return value
  292. def fmt_param((k, v)):
  293. if v is True:
  294. return k
  295. if v is not False and v is not None:
  296. return k + '=' + str(v)
  297. strparams = filter(None, map(fmt_param, params.items()))
  298. return '%s; %s' % (value, ', '.join(strparams))