handshake.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438
  1. import os
  2. import re
  3. import socket
  4. import time
  5. from base64 import b64encode
  6. from hashlib import sha1
  7. from urlparse import urlparse
  8. from errors import HandshakeError
  9. from extension import filter_extensions
  10. from python_digest import build_authorization_request
  # Fixed GUID that is appended to the client's "Sec-WebSocket-Key" before
  # hashing to produce / verify the "Sec-WebSocket-Accept" value
  WS_GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'

  # The only value accepted in (and sent as) "Sec-WebSocket-Version"
  WS_VERSION = '13'

  # Number of HTTP redirects a client handshake follows before failing
  MAX_REDIRECTS = 10

  # Seconds allowed for receiving the handshake headers when the socket has no
  # timeout of its own
  HDR_TIMEOUT = 5

  # Number of header bytes read before the handshake is rejected
  MAX_HDR_LEN = 1024
  class Handshake(object):
      """
      Shared plumbing for the HTTP part of a WebSocket handshake: reading and
      parsing the peer's header block and writing our own. Subclasses implement
      perform() for the client and the server role.
      """

      def __init__(self, wsock):
          """
          Remember `wsock` and its `sock` attribute; the latter is used for all
          reads and writes done during the handshake.
          """
          self.wsock = wsock
          self.sock = wsock.sock

      def fail(self, msg):
          """
          Close the socket and raise a HandshakeError carrying `msg`.
          """
          self.sock.close()
          raise HandshakeError(msg)

      def receive_request(self):
          """
          Read a client request and return a `(location, headers)` tuple.

          `location` is the request target with trailing slashes removed, so a
          request for "/" yields an empty string. The handshake fails if the
          request line is not an HTTP/1.1 GET.
          """
          raw, headers = self.receive_headers()

          # The dots are escaped so that only a literal "HTTP/1.1" is accepted
          match = re.search(r'^GET (.*?)/* HTTP/1\.1\r\n', raw)

          if match is None:
              self.fail('not a valid HTTP 1.1 GET request')

          return match.group(1), headers

      def receive_response(self):
          """
          Read a server response and return a `(status, headers)` tuple, where
          `status` is the integer HTTP status code. The handshake fails if the
          status line is not an HTTP/1.1 status line.
          """
          raw, headers = self.receive_headers()
          match = re.search(r'^HTTP/1\.1 (\d{3})', raw)

          if match is None:
              self.fail('not a valid HTTP 1.1 response')

          return int(match.group(1)), headers

      def receive_headers(self):
          """
          Receive a complete HTTP header block (terminated by an empty line).

          Returns `(raw, headers)`: the decoded header text and a dictionary
          mapping header names to values. Values of repeated headers are
          joined with ", ". Header names are stored exactly as received.

          The socket's own timeout (or HDR_TIMEOUT if it has none) bounds the
          time spent on the whole header block and is restored afterwards.
          The handshake fails, closing the socket, on timeout, when the peer
          closes the connection early, or when MAX_HDR_LEN bytes have been
          read without reaching the end of the headers.
          """
          hdr = ''
          sock_timeout = self.sock.gettimeout()
          timeout = HDR_TIMEOUT if sock_timeout is None else sock_timeout

          try:
              self.sock.settimeout(timeout)
              start_time = time.time()

              while hdr[-4:] != '\r\n\r\n':
                  if len(hdr) >= MAX_HDR_LEN:
                      self.fail('request exceeds maximum header '
                                'length of %d' % MAX_HDR_LEN)

                  chunk = self.sock.recv(1)

                  # An empty read means the peer closed the connection; without
                  # this check the loop would spin until the timeout expires
                  if not chunk:
                      self.fail('connection closed while receiving '
                                'handshake headers')

                  hdr += chunk
                  time_diff = time.time() - start_time

                  # ">=" keeps the remaining timeout strictly positive:
                  # settimeout(0) would make the socket non-blocking
                  if time_diff >= timeout:
                      raise socket.timeout

                  self.sock.settimeout(timeout - time_diff)
          except socket.timeout:
              self.fail('timeout while receiving handshake headers')

          self.sock.settimeout(sock_timeout)
          hdr = hdr.decode('utf-8', 'ignore')
          headers = {}

          for key, value in re.findall(r'(.*?): ?(.*?)\r\n', hdr):
              if key in headers:
                  headers[key] += ', ' + value
              else:
                  headers[key] = value

          return hdr, headers

      def send_headers(self, headers):
          """
          Send a header block followed by the terminating empty line.

          `headers` is an iterable of complete lines (strings) and/or
          `(name, value)` tuples. It is consumed entirely before anything is
          sent, so a generator that fails halfway sends nothing.
          """
          for hdr in list(headers):
              if isinstance(hdr, tuple):
                  hdr = '%s: %s' % hdr

              self.sock.sendall(hdr + '\r\n')

          self.sock.sendall('\r\n')

      def perform(self):
          """
          Execute the handshake; must be implemented by subclasses.
          """
          raise NotImplementedError
  class ServerHandshake(Handshake):
      """
      Executes a handshake as the server end point of the socket. If the HTTP
      request headers sent by the client are invalid, a HandshakeError is raised.
      """

      def perform(self, ssock):
          """
          Receive the client handshake, store its location and headers on the
          websocket and send the server handshake in response. `ssock` is the
          server socket whose trusted origins, protocols, extensions and
          locations the request is checked against.
          """
          location, headers = self.receive_request()
          self.wsock.location = location
          self.wsock.request_headers = headers

          self.send_headers(self.response_headers(ssock))

      def response_headers(self, ssock):
          """
          Validate the stored request headers and generate the lines of the
          "101" response.

          This is a generator: validation and negotiation only run once it is
          iterated. Any violation closes the socket and raises HandshakeError.
          As a side effect the negotiated `protocol` and `extensions` are
          stored on the websocket and extension hooks are installed.
          """
          headers = self.wsock.request_headers

          self._check_request_headers(headers)
          origin = self._check_origin(headers, ssock)
          self._negotiate_protocol(headers, ssock)
          self._negotiate_extensions(headers, ssock)

          # Check if requested resource location is served by this server
          if ssock.locations and self.wsock.location not in ssock.locations:
              self.fail('location "%s" is not supported by this server'
                        % self.wsock.location)

          # Encode acceptation key using the WebSocket GUID
          key = headers['Sec-WebSocket-Key'].strip()
          accept = b64encode(sha1(key + WS_GUID).digest())

          # Location scheme differs for SSL-enabled connections. "Host" is
          # guaranteed to be present by _check_request_headers().
          scheme = 'wss' if self.wsock.secure else 'ws'
          location = '%s://%s%s' % (scheme, headers['Host'],
                                    self.wsock.location)

          # Construct HTTP response header
          yield 'HTTP/1.1 101 Web Socket Protocol Handshake'
          yield 'Upgrade', 'websocket'
          yield 'Connection', 'Upgrade'
          yield 'Sec-WebSocket-Origin', origin
          yield 'Sec-WebSocket-Location', location
          yield 'Sec-WebSocket-Accept', accept

          if self.wsock.protocol:
              yield 'Sec-WebSocket-Protocol', self.wsock.protocol

          if self.wsock.extensions:
              values = [format_param_hdr(e.name, e.request)
                        for e in self.wsock.extensions]
              yield 'Sec-WebSocket-Extensions', ', '.join(values)

      def _check_request_headers(self, headers):
          """
          Fail unless all mandatory headers are present and carry the
          required WebSocket version and upgrade keywords.
          """
          for name in ('Host', 'Upgrade', 'Connection', 'Sec-WebSocket-Key',
                       'Sec-WebSocket-Version'):
              if name not in headers:
                  self.fail('missing "%s" header' % name)

          version = headers['Sec-WebSocket-Version'].strip()

          if version != WS_VERSION:
              self.fail('WebSocket version %s requested (only %s is supported)'
                        % (version, WS_VERSION))

          if 'websocket' not in headers['Upgrade'].lower():
              self.fail('"Upgrade" header must contain "websocket"')

          if 'upgrade' not in headers['Connection'].lower():
              self.fail('"Connection" header must contain "Upgrade"')

      def _check_origin(self, headers, ssock):
          """
          Return the request's origin ("null" if absent), failing if it is
          required but missing or not among the server's trusted origins.
          """
          # A "User-Agent" header is taken as the sign of a browser client,
          # which must always specify its origin
          if 'Origin' not in headers and 'User-Agent' in headers:
              self.fail('browser client must specify "Origin" header')

          origin = headers.get('Origin', 'null')

          if origin == 'null':
              if ssock.trusted_origins:
                  self.fail('no "Origin" header specified, assuming untrusted')
          elif ssock.trusted_origins and origin not in ssock.trusted_origins:
              self.fail('untrusted origin "%s"' % origin)

          return origin

      def _negotiate_protocol(self, headers, ssock):
          """
          Store the first client-requested protocol that the server supports
          on the websocket, or None if there is no such protocol.
          """
          self.wsock.protocol = None

          if 'Sec-WebSocket-Protocol' not in headers:
              return

          for proto in split_stripped(headers['Sec-WebSocket-Protocol']):
              if proto in ssock.protocols:
                  self.wsock.protocol = proto
                  break

      def _negotiate_extensions(self, headers, ssock):
          """
          Store the supported subset of the requested extensions on the
          websocket and install a send/recv hook for each of them.
          """
          if 'Sec-WebSocket-Extensions' not in headers:
              self.wsock.extensions = []
              return

          supported_ext = dict((e.name, e) for e in ssock.extensions)
          requested = []
          params_by_name = {}

          for ext in split_stripped(headers['Sec-WebSocket-Extensions']):
              name, params = parse_param_hdr(ext)

              # When an extension is offered more than once, the first offer
              # wins
              if name in supported_ext and name not in params_by_name:
                  requested.append(supported_ext[name])
                  params_by_name[name] = params

          self.wsock.extensions = filter_extensions(requested)

          # Parameters are looked up by extension name instead of by position,
          # because filter_extensions() may return fewer extensions than it
          # was given, which would shift a positional pairing
          for ext in self.wsock.extensions:
              hook = ext.create_hook(**params_by_name.get(ext.name, {}))
              self.wsock.add_hook(send=hook.send, recv=hook.recv)
  class ClientHandshake(Handshake):
      """
      Executes a handshake as the client end point of the socket. May raise a
      HandshakeError if the server response is invalid.
      """

      # Number of 401 challenges per location that are answered with
      # credentials before giving up, so that a server which keeps rejecting
      # them cannot cause unbounded recursion
      MAX_AUTH_ATTEMPTS = 2

      def __init__(self, wsock):
          Handshake.__init__(self, wsock)
          self.redirects = 0
          self.auth_attempts = 0

      def perform(self):
          """
          Send the client handshake and process the server's response.
          """
          self.send_headers(self.request_headers())
          self.handle_response(*self.receive_response())

      def handle_response(self, status, headers):
          """
          Dispatch on the HTTP status: 101 completes the handshake, 401
          triggers HTTP authentication, 301/302/303/307/308 are followed as
          redirects and anything else fails the handshake.
          """
          if status == 101:
              self.handle_handshake(headers)
          elif status == 401:
              self.handle_auth(headers)
          elif status in (301, 302, 303, 307, 308):
              self.handle_redirect(headers)
          else:
              self.fail('invalid HTTP response status %d' % status)

      def handle_handshake(self, headers):
          """
          Validate the server handshake and store the negotiated extensions
          and protocol on the websocket, installing a send/recv hook for each
          extension the server accepted.
          """
          # Check if headers that MUST be present are actually present
          for name in ('Upgrade', 'Connection', 'Sec-WebSocket-Accept'):
              if name not in headers:
                  self.fail('missing "%s" header' % name)

          if 'websocket' not in headers['Upgrade'].lower():
              self.fail('"Upgrade" header must contain "websocket"')

          if 'upgrade' not in headers['Connection'].lower():
              self.fail('"Connection" header must contain "Upgrade"')

          # Verify accept header
          accept = headers['Sec-WebSocket-Accept'].strip()
          required_accept = b64encode(sha1(self.key + WS_GUID).digest())

          if accept != required_accept:
              self.fail('invalid websocket accept header "%s"' % accept)

          # Compare extensions, add hooks only for those returned by server.
          # Without an extensions header nothing was accepted, so the list is
          # emptied as well, like the server side does.
          supported_ext = dict((e.name, e) for e in self.wsock.extensions)
          self.wsock.extensions = []

          if 'Sec-WebSocket-Extensions' in headers:
              for ext in split_stripped(headers['Sec-WebSocket-Extensions']):
                  name, params = parse_param_hdr(ext)

                  if name not in supported_ext:
                      self.fail('server handshake contains unsupported '
                                'extension "%s"' % name)

                  hook = supported_ext[name].create_hook(**params)
                  self.wsock.extensions.append(supported_ext[name])
                  self.wsock.add_hook(send=hook.send, recv=hook.recv)

          # Assert that returned protocol (if any) is supported
          if 'Sec-WebSocket-Protocol' in headers:
              protocol = headers['Sec-WebSocket-Protocol']

              if protocol != 'null' and protocol not in self.wsock.protocols:
                  self.fail('unsupported protocol "%s"' % protocol)

              self.wsock.protocol = protocol

      def handle_auth(self, headers):
          """
          Answer a 401 response by repeating the handshake with an
          "Authorization" header for the Basic or Digest challenge. Fails if
          no credentials are configured, the challenge is missing or uses an
          unsupported scheme, or MAX_AUTH_ATTEMPTS is exceeded.
          """
          if not self.wsock.auth:
              self.fail('missing username and password for HTTP authentication')

          if 'WWW-Authenticate' not in headers:
              self.fail('missing "WWW-Authenticate" header')

          self.auth_attempts += 1

          if self.auth_attempts > self.MAX_AUTH_ATTEMPTS:
              self.fail('HTTP authentication failed')

          hdr = headers['WWW-Authenticate'].strip()
          mode = hdr.split(' ', 1)[0]

          # A value is either a quoted string, taken whole so that spaces and
          # base64 characters survive, or a bare token
          pairs = re.findall(r'(\w+)[:=] ?(?:"([^"]*)"|([^\s",]+))', hdr)
          authres = dict((k, quoted or bare) for k, quoted, bare in pairs)

          # Authentication scheme names are case-insensitive
          scheme = mode.lower()

          if scheme == 'basic':
              auth_hdr = self.http_auth_basic_headers(**authres)
          elif scheme == 'digest':
              auth_hdr = self.http_auth_digest_headers(**authres)
          else:
              self.fail('unsupported HTTP authentication mode "%s"' % mode)

          # Send new, authenticated handshake
          self.send_headers(list(self.request_headers()) + list(auth_hdr))
          self.handle_response(*self.receive_response())

      def handle_redirect(self, headers):
          """
          Follow an HTTP redirect: reconnect if the target host or port
          differs from the current peer, then repeat the handshake at the new
          location. Fails after MAX_REDIRECTS redirects or if the "Location"
          header is missing or has an invalid port.
          """
          self.redirects += 1

          if self.redirects > MAX_REDIRECTS:
              self.fail('reached maximum number of redirects (%d)'
                        % MAX_REDIRECTS)

          if 'Location' not in headers:
              self.fail('redirect response is missing "Location" header')

          url = urlparse(headers['Location'].strip())

          # A relative redirect has no host and stays on this connection
          if url.hostname:
              try:
                  # `port` is a read-only property of the parse result, so the
                  # default is applied to a local variable
                  port = url.port
              except ValueError:
                  self.fail('invalid port in redirect location')

              if not port:
                  port = 443 if self.wsock.secure else 80

              # `hostname` is used rather than `netloc`, which would include
              # the port. The slice drops the extra fields of IPv6 addresses.
              addr = (url.hostname, port)

              if addr != self.sock.getpeername()[:2]:
                  self._reconnect(addr)

          # Update websocket object and send new handshake
          location = url.path or '/'

          if url.query:
              location += '?' + url.query

          self.wsock.location = location
          self.auth_attempts = 0
          self.perform()

      def _reconnect(self, addr):
          """
          Replace the current socket by a new connection to `addr`.
          """
          timeout = self.sock.gettimeout()
          self.sock.close()

          # A closed socket cannot be connected again, a new one is needed.
          # NOTE(review): this yields a plain TCP socket; if wsock wraps its
          # socket for SSL, the wrapping presumably has to be re-applied
          # here -- confirm against the websocket class.
          self.sock = socket.create_connection(addr, timeout)
          self.wsock.sock = self.sock

      def request_headers(self):
          """
          Generate the lines of the client handshake. A new random key is
          stored in `self.key` for verifying the server's accept header.
          """
          if not self.wsock.location:
              self.fail('request location is empty')

          # Generate a 16-byte random base64-encoded key for this connection
          self.key = b64encode(os.urandom(16))

          # IPv6 peer names are 4-tuples and need brackets around the host
          host, port = self.sock.getpeername()[:2]

          if ':' in host:
              host = '[%s]' % host

          # Send client handshake
          yield 'GET %s HTTP/1.1' % self.wsock.location
          yield 'Host', '%s:%d' % (host, port)
          yield 'Upgrade', 'websocket'
          yield 'Connection', 'keep-alive, Upgrade'
          yield 'Sec-WebSocket-Key', self.key
          yield 'Sec-WebSocket-Version', WS_VERSION

          if self.wsock.origin:
              yield 'Origin', self.wsock.origin

          # These are for eagerly caching webservers
          yield 'Pragma', 'no-cache'
          yield 'Cache-Control', 'no-cache'

          # Request protocols and extensions, these are later checked with the
          # actual supported values from the server's response
          if self.wsock.protocols:
              yield 'Sec-WebSocket-Protocol', ', '.join(self.wsock.protocols)

          if self.wsock.extensions:
              values = [format_param_hdr(e.name, e.request)
                        for e in self.wsock.extensions]
              yield 'Sec-WebSocket-Extensions', ', '.join(values)

      def http_auth_basic_headers(self, **kwargs):
          """
          Generate the "Authorization" header for HTTP Basic authentication.
          The challenge parameters in `kwargs` are not used.
          """
          u, p = self.wsock.auth
          u = u.encode('utf-8')
          p = p.encode('utf-8')
          yield 'Authorization', 'Basic ' + b64encode(u + ':' + p)

      def http_auth_digest_headers(self, **kwargs):
          """
          Generate the "Authorization" header for HTTP Digest authentication
          from the challenge parameters `realm`, `nonce` and `opaque`. Fails
          the handshake if one of them is missing from the challenge.
          """
          for name in ('realm', 'nonce', 'opaque'):
              if name not in kwargs:
                  self.fail('missing "%s" in HTTP digest challenge' % name)

          username, password = self.wsock.auth

          # The nonce count is the number of requests sent with this nonce
          # including the current one, so the first request uses 1
          yield 'Authorization', build_authorization_request(
              username=username.encode('utf-8'),
              method='GET',
              uri=self.wsock.location,
              nonce_count=1,
              realm=kwargs['realm'],
              nonce=kwargs['nonce'],
              opaque=kwargs['opaque'],
              password=password.encode('utf-8'))
  def split_stripped(value, delim=',', maxsplits=-1):
      """
      Split `value` on `delim` (at most `maxsplits` times, -1 for no limit)
      and return the list of parts with surrounding whitespace removed. A
      false value (None, empty string) yields an empty list.
      """
      if not value:
          return []

      # Only non-string values are converted: calling str() on a unicode
      # header value would raise for non-ASCII characters on Python 2
      text = value if hasattr(value, 'split') else str(value)

      return [part.strip() for part in text.split(delim, maxsplits)]
  def parse_param_hdr(hdr):
      """
      Parse a single parameterized header element such as
      "name; key=value; flag" into a `(name, params)` tuple.

      `params` maps each parameter name to its value: True for a parameter
      without a value, an int for a numeric value and a string otherwise.
      Surrounding double quotes are removed from values. Parameters are
      separated by ";"; a "," is accepted as well for compatibility with the
      format this module used to produce.
      """
      name, _, paramstr = hdr.partition(';')
      params = {}

      for param in re.split(r'[;,]', paramstr):
          param = param.strip()

          # Skip the empty parts left by trailing or doubled separators
          if not param:
              continue

          key, has_value, value = param.partition('=')
          key = key.strip()

          if not has_value:
              params[key] = True
              continue

          value = value.strip()

          if len(value) >= 2 and value[0] == '"' and value[-1] == '"':
              value = value[1:-1]

          if value.isdigit():
              try:
                  value = int(value)
              except ValueError:
                  # isdigit() also accepts characters such as superscript
                  # digits, which int() rejects; keep those as a string
                  pass

          params[key] = value

      return name.strip(), params
  def format_param_hdr(value, params):
      """
      Format `value` and the dictionary `params` as a parameterized header
      element: "value; key1; key2=val".

      A parameter whose value is True is written as its bare name, one whose
      value is False or None is left out, and any other value is written as
      "key=value". Parameters are separated by ";", since a "," in the header
      would start the next extension.
      """
      if not params:
          return value

      parts = [value]

      for key, val in params.items():
          if val is True:
              parts.append(key)
          elif val is not False and val is not None:
              parts.append('%s=%s' % (key, val))

      # With every parameter disabled this is just the bare value
      return '; '.join(parts)