# handshake.py: WebSocket opening handshake, server and client side.

  1. import os
  2. import re
  3. import socket
  4. import time
  5. from base64 import b64encode
  6. from hashlib import sha1
  7. from urlparse import urlparse
  8. from errors import HandshakeError
  9. from extension import filter_extensions
  10. from python_digest import build_authorization_request
# Fixed GUID that is concatenated with the Sec-WebSocket-Key value before
# hashing it into the Sec-WebSocket-Accept value
WS_GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'

# Value sent by clients in, and required by servers from, the
# Sec-WebSocket-Version header
WS_VERSION = '13'

# Number of HTTP redirects a client handshake follows before giving up
MAX_REDIRECTS = 10

# Timeout (seconds) for receiving the handshake headers, applied when the
# socket itself has no timeout configured
HDR_TIMEOUT = 5

# Upper bound on the number of header bytes read from the socket
# NOTE(review): 512 bytes looks small for browser requests that carry
# cookies or long User-Agent strings -- confirm this limit is intended
MAX_HDR_LEN = 512
  16. class Handshake(object):
  17. def __init__(self, wsock):
  18. self.wsock = wsock
  19. self.sock = wsock.sock
  20. def fail(self, msg):
  21. self.sock.close()
  22. raise HandshakeError(msg)
  23. def receive_request(self):
  24. raw, headers = self.receive_headers()
  25. # Request must be HTTP (at least 1.1) GET request, find the location
  26. # (without trailing slash)
  27. match = re.search(r'^GET (.*?)/* HTTP/1.1\r\n', raw)
  28. if match is None:
  29. self.fail('not a valid HTTP 1.1 GET request')
  30. location = match.group(1)
  31. return location, headers
  32. def receive_response(self):
  33. raw, headers = self.receive_headers()
  34. # Response must be HTTP (at least 1.1) with status 101
  35. match = re.search(r'^HTTP/1\.1 (\d{3})', raw)
  36. if match is None:
  37. self.fail('not a valid HTTP 1.1 response')
  38. status = int(match.group(1))
  39. return status, headers
  40. def receive_headers(self):
  41. # Receive entire HTTP header
  42. hdr = ''
  43. sock_timeout = self.sock.gettimeout()
  44. try:
  45. force_timeout = sock_timeout is None
  46. timeout = HDR_TIMEOUT if force_timeout else sock_timeout
  47. self.sock.settimeout(timeout)
  48. start_time = time.time()
  49. while hdr[-4:] != '\r\n\r\n' and len(hdr) < MAX_HDR_LEN:
  50. hdr += self.sock.recv(1)
  51. time_diff = time.time() - start_time
  52. if time_diff > timeout:
  53. raise socket.timeout
  54. self.sock.settimeout(timeout - time_diff)
  55. except socket.timeout:
  56. self.sock.close()
  57. raise HandshakeError('timeout while receiving handshake headers')
  58. self.sock.settimeout(sock_timeout)
  59. hdr = hdr.decode('utf-8', 'ignore')
  60. headers = {}
  61. for key, value in re.findall(r'(.*?): ?(.*?)\r\n', hdr):
  62. if key in headers:
  63. headers[key] += ', ' + value
  64. else:
  65. headers[key] = value
  66. return hdr, headers
  67. def send_headers(self, headers):
  68. # Send request
  69. for hdr in list(headers):
  70. if isinstance(hdr, tuple):
  71. hdr = '%s: %s' % hdr
  72. self.sock.sendall(hdr + '\r\n')
  73. self.sock.sendall('\r\n')
  74. def perform(self):
  75. raise NotImplementedError
  76. class ServerHandshake(Handshake):
  77. """
  78. Executes a handshake as the server end point of the socket. If the HTTP
  79. request headers sent by the client are invalid, a HandshakeError is raised.
  80. """
  81. def perform(self, ssock):
  82. # Receive and validate client handshake
  83. location, headers = self.receive_request()
  84. self.wsock.location = location
  85. self.wsock.request_headers = headers
  86. # Send server handshake in response
  87. self.send_headers(self.response_headers(ssock))
  88. def response_headers(self, ssock):
  89. headers = self.wsock.request_headers
  90. # Check if headers that MUST be present are actually present
  91. for name in ('Host', 'Upgrade', 'Connection', 'Sec-WebSocket-Key',
  92. 'Sec-WebSocket-Version'):
  93. if name not in headers:
  94. self.fail('missing "%s" header' % name)
  95. # Check WebSocket version used by client
  96. version = headers['Sec-WebSocket-Version']
  97. if version != WS_VERSION:
  98. self.fail('WebSocket version %s requested (only %s is supported)'
  99. % (version, WS_VERSION))
  100. # Verify required header keywords
  101. if 'websocket' not in headers['Upgrade'].lower():
  102. self.fail('"Upgrade" header must contain "websocket"')
  103. if 'upgrade' not in headers['Connection'].lower():
  104. self.fail('"Connection" header must contain "Upgrade"')
  105. # Origin must be present if browser client, and must match the list of
  106. # trusted origins
  107. if 'Origin' not in headers and 'User-Agent' in headers:
  108. self.fail('browser client must specify "Origin" header')
  109. origin = headers.get('Origin', 'null')
  110. if origin == 'null':
  111. if ssock.trusted_origins:
  112. self.fail('no "Origin" header specified, assuming untrusted')
  113. elif ssock.trusted_origins and origin not in ssock.trusted_origins:
  114. self.fail('untrusted origin "%s"' % origin)
  115. # Only a supported protocol can be returned
  116. client_proto = split_stripped(headers['Sec-WebSocket-Protocol']) \
  117. if 'Sec-WebSocket-Protocol' in headers else []
  118. self.wsock.protocol = None
  119. for p in client_proto:
  120. if p in ssock.protocols:
  121. self.wsock.protocol = p
  122. break
  123. # Only supported extensions are returned
  124. if 'Sec-WebSocket-Extensions' in headers:
  125. supported_ext = dict((e.name, e) for e in ssock.extensions)
  126. extensions = []
  127. all_params = []
  128. for ext in split_stripped(headers['Sec-WebSocket-Extensions']):
  129. name, params = parse_param_hdr(ext)
  130. if name in supported_ext:
  131. extensions.append(supported_ext[name])
  132. all_params.append(params)
  133. self.wsock.extensions = filter_extensions(extensions)
  134. for ext, params in zip(self.wsock.extensions, all_params):
  135. hook = ext.create_hook(**params)
  136. self.wsock.add_hook(send=hook.send, recv=hook.recv)
  137. else:
  138. self.wsock.extensions = []
  139. # Check if requested resource location is served by this server
  140. if ssock.locations:
  141. if self.wsock.location not in ssock.locations:
  142. raise HandshakeError('location "%s" is not supported by this '
  143. 'server' % self.wsock.location)
  144. # Encode acceptation key using the WebSocket GUID
  145. key = headers['Sec-WebSocket-Key'].strip()
  146. accept = b64encode(sha1(key + WS_GUID).digest())
  147. # Location scheme differs for SSL-enabled connections
  148. scheme = 'wss' if self.wsock.secure else 'ws'
  149. if 'Host' in headers:
  150. host = headers['Host']
  151. else:
  152. host, port = self.sock.getpeername()
  153. default_port = 443 if self.wsock.secure else 80
  154. if port != default_port:
  155. host += ':%d' % port
  156. location = '%s://%s%s' % (scheme, host, self.wsock.location)
  157. # Construct HTTP response header
  158. yield 'HTTP/1.1 101 Web Socket Protocol Handshake'
  159. yield 'Upgrade', 'websocket'
  160. yield 'Connection', 'Upgrade'
  161. yield 'Sec-WebSocket-Origin', origin
  162. yield 'Sec-WebSocket-Location', location
  163. yield 'Sec-WebSocket-Accept', accept
  164. if self.wsock.protocol:
  165. yield 'Sec-WebSocket-Protocol', self.wsock.protocol
  166. if self.wsock.extensions:
  167. values = [format_param_hdr(e.name, e.request)
  168. for e in self.wsock.extensions]
  169. yield 'Sec-WebSocket-Extensions', ', '.join(values)
  170. class ClientHandshake(Handshake):
  171. """
  172. Executes a handshake as the client end point of the socket. May raise a
  173. HandshakeError if the server response is invalid.
  174. """
  175. def __init__(self, wsock):
  176. Handshake.__init__(self, wsock)
  177. self.redirects = 0
  178. def perform(self):
  179. self.send_headers(self.request_headers())
  180. self.handle_response(*self.receive_response())
  181. def handle_response(self, status, headers):
  182. if status == 101:
  183. self.handle_handshake(headers)
  184. elif status == 401:
  185. self.handle_auth(headers)
  186. elif status in (301, 302, 303, 307, 308):
  187. self.handle_redirect(headers)
  188. else:
  189. self.fail('invalid HTTP response status %d' % status)
  190. def handle_handshake(self, headers):
  191. # Check if headers that MUST be present are actually present
  192. for name in ('Upgrade', 'Connection', 'Sec-WebSocket-Accept'):
  193. if name not in headers:
  194. self.fail('missing "%s" header' % name)
  195. if 'websocket' not in headers['Upgrade'].lower():
  196. self.fail('"Upgrade" header must contain "websocket"')
  197. if 'upgrade' not in headers['Connection'].lower():
  198. self.fail('"Connection" header must contain "Upgrade"')
  199. # Verify accept header
  200. accept = headers['Sec-WebSocket-Accept'].strip()
  201. required_accept = b64encode(sha1(self.key + WS_GUID).digest())
  202. if accept != required_accept:
  203. self.fail('invalid websocket accept header "%s"' % accept)
  204. # Compare extensions, add hooks only for those returned by server
  205. if 'Sec-WebSocket-Extensions' in headers:
  206. supported_ext = dict((e.name, e) for e in self.wsock.extensions)
  207. self.wsock.extensions = []
  208. for ext in split_stripped(headers['Sec-WebSocket-Extensions']):
  209. name, params = parse_param_hdr(ext)
  210. if name not in supported_ext:
  211. raise HandshakeError('server handshake contains '
  212. 'unsupported extension "%s"' % name)
  213. hook = supported_ext[name].create_hook(**params)
  214. self.wsock.extensions.append(supported_ext[name])
  215. self.wsock.add_hook(send=hook.send, recv=hook.recv)
  216. # Assert that returned protocol (if any) is supported
  217. if 'Sec-WebSocket-Protocol' in headers:
  218. protocol = headers['Sec-WebSocket-Protocol']
  219. if protocol != 'null' and protocol not in self.wsock.protocols:
  220. self.fail('unsupported protocol "%s"' % protocol)
  221. self.wsock.protocol = protocol
  222. def handle_auth(self, headers):
  223. # HTTP authentication is required in the request
  224. hdr = headers['WWW-Authenticate']
  225. authres = dict(re.findall(r'(\w+)[:=] ?"?(\w+)"?', hdr))
  226. mode = hdr.lstrip().split(' ', 1)[0]
  227. if not self.wsock.auth:
  228. self.fail('missing username and password for HTTP authentication')
  229. if mode == 'Basic':
  230. auth_hdr = self.http_auth_basic_headers(**authres)
  231. elif mode == 'Digest':
  232. auth_hdr = self.http_auth_digest_headers(**authres)
  233. else:
  234. self.fail('unsupported HTTP authentication mode "%s"' % mode)
  235. # Send new, authenticated handshake
  236. self.send_headers(list(self.request_headers()) + list(auth_hdr))
  237. self.handle_response(*self.receive_response())
  238. def handle_redirect(self, headers):
  239. self.redirects += 1
  240. if self.redirects > MAX_REDIRECTS:
  241. self.fail('reached maximum number of redirects (%d)'
  242. % MAX_REDIRECTS)
  243. # Handle HTTP redirect
  244. url = urlparse(headers['Location'].strip())
  245. # Reconnect socket to new host if net location changed
  246. if not url.port:
  247. url.port = 443 if self.secure else 80
  248. addr = (url.netloc, url.port)
  249. if addr != self.sock.getpeername():
  250. self.sock.close()
  251. self.sock.connect(addr)
  252. # Update websocket object and send new handshake
  253. self.wsock.location = url.path
  254. self.perform()
  255. def request_headers(self):
  256. if len(self.wsock.location) == 0:
  257. self.fail('request location is empty')
  258. # Generate a 16-byte random base64-encoded key for this connection
  259. self.key = b64encode(os.urandom(16))
  260. # Send client handshake
  261. yield 'GET %s HTTP/1.1' % self.wsock.location
  262. yield 'Host', '%s:%d' % self.sock.getpeername()
  263. yield 'Upgrade', 'websocket'
  264. yield 'Connection', 'keep-alive, Upgrade'
  265. yield 'Sec-WebSocket-Key', self.key
  266. yield 'Sec-WebSocket-Version', WS_VERSION
  267. if self.wsock.origin:
  268. yield 'Origin', self.wsock.origin
  269. # These are for eagerly caching webservers
  270. yield 'Pragma', 'no-cache'
  271. yield 'Cache-Control', 'no-cache'
  272. # Request protocols and extensions, these are later checked with the
  273. # actual supported values from the server's response
  274. if self.wsock.protocols:
  275. yield 'Sec-WebSocket-Protocol', ', '.join(self.wsock.protocols)
  276. if self.wsock.extensions:
  277. values = [format_param_hdr(e.name, e.request)
  278. for e in self.wsock.extensions]
  279. yield 'Sec-WebSocket-Extensions', ', '.join(values)
  280. def http_auth_basic_headers(self, **kwargs):
  281. u, p = self.wsock.auth
  282. u = u.encode('utf-8')
  283. p = p.encode('utf-8')
  284. yield 'Authorization', 'Basic ' + b64encode(u + ':' + p)
  285. def http_auth_digest_headers(self, **kwargs):
  286. username, password = self.wsock.auth
  287. yield 'Authorization', build_authorization_request(
  288. username=username.encode('utf-8'),
  289. method='GET',
  290. uri=self.wsock.location,
  291. nonce_count=0,
  292. realm=kwargs['realm'],
  293. nonce=kwargs['nonce'],
  294. opaque=kwargs['opaque'],
  295. password=password.encode('utf-8'))
  296. def split_stripped(value, delim=',', maxsplits=-1):
  297. return map(str.strip, str(value).split(delim, maxsplits)) if value else []
  298. def parse_param_hdr(hdr):
  299. if ';' in hdr:
  300. name, paramstr = split_stripped(hdr, ';', 1)
  301. else:
  302. name = hdr
  303. paramstr = ''
  304. params = {}
  305. for param in split_stripped(paramstr):
  306. if '=' in param:
  307. key, value = split_stripped(param, '=', 1)
  308. if value.isdigit():
  309. value = int(value)
  310. else:
  311. key = param
  312. value = True
  313. params[key] = value
  314. return name, params
  315. def format_param_hdr(value, params):
  316. if not params:
  317. return value
  318. def fmt_param((k, v)):
  319. if v is True:
  320. return k
  321. if v is not False and v is not None:
  322. return k + '=' + str(v)
  323. strparams = filter(None, map(fmt_param, params.items()))
  324. return '%s; %s' % (value, ', '.join(strparams))