# handshake.py
  1. import os
  2. import re
  3. import socket
  4. import time
  5. from base64 import b64encode
  6. from hashlib import sha1
  7. from urlparse import urlparse
  8. from errors import HandshakeError
  9. from extension import extension_conflicts
  10. from python_digest import build_authorization_request
  11. WS_GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
  12. WS_VERSION = '13'
  13. MAX_REDIRECTS = 10
  14. HDR_TIMEOUT = 5
  15. MAX_HDR_LEN = 1024
  16. class Handshake(object):
  17. def __init__(self, wsock):
  18. self.wsock = wsock
  19. self.sock = wsock.sock
  20. def fail(self, msg):
  21. self.sock.close()
  22. raise HandshakeError(msg)
  23. def receive_request(self):
  24. raw, headers = self.receive_headers()
  25. # Request must be HTTP (at least 1.1) GET request, find the location
  26. # (without trailing slash)
  27. match = re.search(r'^GET (.*?)/* HTTP/1.1\r\n', raw)
  28. if match is None:
  29. self.fail('not a valid HTTP 1.1 GET request')
  30. location = match.group(1)
  31. return location, headers
  32. def receive_response(self):
  33. raw, headers = self.receive_headers()
  34. # Response must be HTTP (at least 1.1) with status 101
  35. match = re.search(r'^HTTP/1\.1 (\d{3})', raw)
  36. if match is None:
  37. self.fail('not a valid HTTP 1.1 response')
  38. status = int(match.group(1))
  39. return status, headers
  40. def receive_headers(self):
  41. # Receive entire HTTP header
  42. hdr = ''
  43. sock_timeout = self.sock.gettimeout()
  44. try:
  45. force_timeout = sock_timeout is None
  46. timeout = HDR_TIMEOUT if force_timeout else sock_timeout
  47. self.sock.settimeout(timeout)
  48. start_time = time.time()
  49. while hdr[-4:] != '\r\n\r\n':
  50. if len(hdr) == MAX_HDR_LEN:
  51. raise HandshakeError('request exceeds maximum header '
  52. 'length of %d' % MAX_HDR_LEN)
  53. hdr += self.sock.recv(1)
  54. time_diff = time.time() - start_time
  55. if time_diff > timeout:
  56. raise socket.timeout
  57. self.sock.settimeout(timeout - time_diff)
  58. except socket.timeout:
  59. self.sock.close()
  60. raise HandshakeError('timeout while receiving handshake headers')
  61. self.sock.settimeout(sock_timeout)
  62. hdr = hdr.decode('utf-8', 'ignore')
  63. headers = {}
  64. for key, value in re.findall(r'(.*?): ?(.*?)\r\n', hdr):
  65. if key in headers:
  66. headers[key] += ', ' + value
  67. else:
  68. headers[key] = value
  69. return hdr, headers
  70. def send_headers(self, headers):
  71. # Send request
  72. for hdr in list(headers):
  73. if isinstance(hdr, tuple):
  74. hdr = '%s: %s' % hdr
  75. self.sock.sendall(hdr + '\r\n')
  76. self.sock.sendall('\r\n')
  77. def perform(self):
  78. raise NotImplementedError
  79. class ServerHandshake(Handshake):
  80. """
  81. Executes a handshake as the server end point of the socket. If the HTTP
  82. request headers sent by the client are invalid, a HandshakeError is raised.
  83. """
  84. def perform(self, ssock):
  85. # Receive and validate client handshake
  86. location, headers = self.receive_request()
  87. self.wsock.location = location
  88. self.wsock.request_headers = headers
  89. # Send server handshake in response
  90. self.send_headers(self.response_headers(ssock))
  91. def response_headers(self, ssock):
  92. headers = self.wsock.request_headers
  93. # Check if headers that MUST be present are actually present
  94. for name in ('Host', 'Upgrade', 'Connection', 'Sec-WebSocket-Key',
  95. 'Sec-WebSocket-Version'):
  96. if name not in headers:
  97. self.fail('missing "%s" header' % name)
  98. # Check WebSocket version used by client
  99. version = headers['Sec-WebSocket-Version']
  100. if version != WS_VERSION:
  101. self.fail('WebSocket version %s requested (only %s is supported)'
  102. % (version, WS_VERSION))
  103. # Verify required header keywords
  104. if 'websocket' not in headers['Upgrade'].lower():
  105. self.fail('"Upgrade" header must contain "websocket"')
  106. if 'upgrade' not in headers['Connection'].lower():
  107. self.fail('"Connection" header must contain "Upgrade"')
  108. # Origin must be present if browser client, and must match the list of
  109. # trusted origins
  110. if 'Origin' not in headers and 'User-Agent' in headers:
  111. self.fail('browser client must specify "Origin" header')
  112. origin = headers.get('Origin', 'null')
  113. if origin == 'null':
  114. if ssock.trusted_origins:
  115. self.fail('no "Origin" header specified, assuming untrusted')
  116. elif ssock.trusted_origins and origin not in ssock.trusted_origins:
  117. self.fail('untrusted origin "%s"' % origin)
  118. # Only a supported protocol can be returned
  119. client_proto = split_stripped(headers['Sec-WebSocket-Protocol']) \
  120. if 'Sec-WebSocket-Protocol' in headers else []
  121. self.wsock.protocol = None
  122. for p in client_proto:
  123. if p in ssock.protocols:
  124. self.wsock.protocol = p
  125. break
  126. # Only supported extensions are returned
  127. if 'Sec-WebSocket-Extensions' in headers:
  128. supported_ext = dict((e.name, e) for e in ssock.extensions)
  129. self.wsock.extension_hooks = []
  130. extensions = []
  131. for ext in split_stripped(headers['Sec-WebSocket-Extensions']):
  132. name, params = parse_param_hdr(ext)
  133. if name in supported_ext:
  134. ext = supported_ext[name]
  135. if not extension_conflicts(ext, extensions):
  136. extensions.append(ext)
  137. hook = ext.create_hook(**params)
  138. self.wsock.extension_hooks.append(hook)
  139. # Check if requested resource location is served by this server
  140. if ssock.locations:
  141. if self.wsock.location not in ssock.locations:
  142. raise HandshakeError('location "%s" is not supported by this '
  143. 'server' % self.wsock.location)
  144. # Encode acceptation key using the WebSocket GUID
  145. key = headers['Sec-WebSocket-Key'].strip()
  146. accept = b64encode(sha1(key + WS_GUID).digest())
  147. # Location scheme differs for SSL-enabled connections
  148. scheme = 'wss' if self.wsock.secure else 'ws'
  149. if 'Host' in headers:
  150. host = headers['Host']
  151. else:
  152. host, port = self.sock.getpeername()
  153. default_port = 443 if self.wsock.secure else 80
  154. if port != default_port:
  155. host += ':%d' % port
  156. location = '%s://%s%s' % (scheme, host, self.wsock.location)
  157. # Construct HTTP response header
  158. yield 'HTTP/1.1 101 Web Socket Protocol Handshake'
  159. yield 'Upgrade', 'websocket'
  160. yield 'Connection', 'Upgrade'
  161. yield 'Sec-WebSocket-Origin', origin
  162. yield 'Sec-WebSocket-Location', location
  163. yield 'Sec-WebSocket-Accept', accept
  164. if self.wsock.protocol:
  165. yield 'Sec-WebSocket-Protocol', self.wsock.protocol
  166. if self.wsock.extensions:
  167. values = [format_param_hdr(e.name, e.request)
  168. for e in self.wsock.extensions]
  169. yield 'Sec-WebSocket-Extensions', ', '.join(values)
  170. class ClientHandshake(Handshake):
  171. """
  172. Executes a handshake as the client end point of the socket. May raise a
  173. HandshakeError if the server response is invalid.
  174. """
  175. def __init__(self, wsock):
  176. Handshake.__init__(self, wsock)
  177. self.redirects = 0
  178. def perform(self):
  179. self.send_headers(self.request_headers())
  180. self.handle_response(*self.receive_response())
  181. def handle_response(self, status, headers):
  182. if status == 101:
  183. self.handle_handshake(headers)
  184. elif status == 401:
  185. self.handle_auth(headers)
  186. elif status in (301, 302, 303, 307, 308):
  187. self.handle_redirect(headers)
  188. else:
  189. self.fail('invalid HTTP response status %d' % status)
  190. def handle_handshake(self, headers):
  191. # Check if headers that MUST be present are actually present
  192. for name in ('Upgrade', 'Connection', 'Sec-WebSocket-Accept'):
  193. if name not in headers:
  194. self.fail('missing "%s" header' % name)
  195. if 'websocket' not in headers['Upgrade'].lower():
  196. self.fail('"Upgrade" header must contain "websocket"')
  197. if 'upgrade' not in headers['Connection'].lower():
  198. self.fail('"Connection" header must contain "Upgrade"')
  199. # Verify accept header
  200. accept = headers['Sec-WebSocket-Accept'].strip()
  201. required_accept = b64encode(sha1(self.key + WS_GUID).digest())
  202. if accept != required_accept:
  203. self.fail('invalid websocket accept header "%s"' % accept)
  204. # Compare extensions, add hooks only for those returned by server
  205. if 'Sec-WebSocket-Extensions' in headers:
  206. supported_ext = dict((e.name, e) for e in self.wsock.extensions)
  207. self.wsock.extension_hooks = []
  208. for ext in split_stripped(headers['Sec-WebSocket-Extensions']):
  209. name, params = parse_param_hdr(ext)
  210. if name not in supported_ext:
  211. raise HandshakeError('server handshake contains '
  212. 'unsupported extension "%s"' % name)
  213. hook = supported_ext[name].create_hook(**params)
  214. self.wsock.extension_hooks.append(hook)
  215. # Assert that returned protocol (if any) is supported
  216. if 'Sec-WebSocket-Protocol' in headers:
  217. protocol = headers['Sec-WebSocket-Protocol']
  218. if protocol != 'null' and protocol not in self.wsock.protocols:
  219. self.fail('unsupported protocol "%s"' % protocol)
  220. self.wsock.protocol = protocol
  221. def handle_auth(self, headers):
  222. # HTTP authentication is required in the request
  223. hdr = headers['WWW-Authenticate']
  224. authres = dict(re.findall(r'(\w+)[:=] ?"?(\w+)"?', hdr))
  225. mode = hdr.lstrip().split(' ', 1)[0]
  226. if not self.wsock.auth:
  227. self.fail('missing username and password for HTTP authentication')
  228. if mode == 'Basic':
  229. auth_hdr = self.http_auth_basic_headers(**authres)
  230. elif mode == 'Digest':
  231. auth_hdr = self.http_auth_digest_headers(**authres)
  232. else:
  233. self.fail('unsupported HTTP authentication mode "%s"' % mode)
  234. # Send new, authenticated handshake
  235. self.send_headers(list(self.request_headers()) + list(auth_hdr))
  236. self.handle_response(*self.receive_response())
  237. def handle_redirect(self, headers):
  238. self.redirects += 1
  239. if self.redirects > MAX_REDIRECTS:
  240. self.fail('reached maximum number of redirects (%d)'
  241. % MAX_REDIRECTS)
  242. # Handle HTTP redirect
  243. url = urlparse(headers['Location'].strip())
  244. # Reconnect socket to new host if net location changed
  245. if not url.port:
  246. url.port = 443 if self.secure else 80
  247. addr = (url.netloc, url.port)
  248. if addr != self.sock.getpeername():
  249. self.sock.close()
  250. self.sock.connect(addr)
  251. # Update websocket object and send new handshake
  252. self.wsock.location = url.path
  253. self.perform()
  254. def request_headers(self):
  255. if len(self.wsock.location) == 0:
  256. self.fail('request location is empty')
  257. # Generate a 16-byte random base64-encoded key for this connection
  258. self.key = b64encode(os.urandom(16))
  259. # Send client handshake
  260. yield 'GET %s HTTP/1.1' % self.wsock.location
  261. yield 'Host', '%s:%d' % self.sock.getpeername()
  262. yield 'Upgrade', 'websocket'
  263. yield 'Connection', 'keep-alive, Upgrade'
  264. yield 'Sec-WebSocket-Key', self.key
  265. yield 'Sec-WebSocket-Version', WS_VERSION
  266. if self.wsock.origin:
  267. yield 'Origin', self.wsock.origin
  268. # These are for eagerly caching webservers
  269. yield 'Pragma', 'no-cache'
  270. yield 'Cache-Control', 'no-cache'
  271. # Request protocols and extensions, these are later checked with the
  272. # actual supported values from the server's response
  273. if self.wsock.protocols:
  274. yield 'Sec-WebSocket-Protocol', ', '.join(self.wsock.protocols)
  275. if self.wsock.extensions:
  276. values = [format_param_hdr(e.name, e.request)
  277. for e in self.wsock.extensions]
  278. yield 'Sec-WebSocket-Extensions', ', '.join(values)
  279. def http_auth_basic_headers(self, **kwargs):
  280. u, p = self.wsock.auth
  281. u = u.encode('utf-8')
  282. p = p.encode('utf-8')
  283. yield 'Authorization', 'Basic ' + b64encode(u + ':' + p)
  284. def http_auth_digest_headers(self, **kwargs):
  285. username, password = self.wsock.auth
  286. yield 'Authorization', build_authorization_request(
  287. username=username.encode('utf-8'),
  288. method='GET',
  289. uri=self.wsock.location,
  290. nonce_count=0,
  291. realm=kwargs['realm'],
  292. nonce=kwargs['nonce'],
  293. opaque=kwargs['opaque'],
  294. password=password.encode('utf-8'))
  295. def split_stripped(value, delim=',', maxsplits=-1):
  296. return map(str.strip, str(value).split(delim, maxsplits)) if value else []
  297. def parse_param_hdr(hdr):
  298. if ';' in hdr:
  299. name, paramstr = split_stripped(hdr, ';', 1)
  300. else:
  301. name = hdr
  302. paramstr = ''
  303. params = {}
  304. for param in split_stripped(paramstr):
  305. if '=' in param:
  306. key, value = split_stripped(param, '=', 1)
  307. if value.isdigit():
  308. value = int(value)
  309. else:
  310. key = param
  311. value = True
  312. params[key] = value
  313. return name, params
  314. def format_param_hdr(value, params):
  315. if not params:
  316. return value
  317. def fmt_param((k, v)):
  318. if v is True:
  319. return k
  320. if v is not False and v is not None:
  321. return k + '=' + str(v)
  322. strparams = filter(None, map(fmt_param, params.items()))
  323. return '%s; %s' % (value, ', '.join(strparams))