handshake.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434
  1. import os
  2. import re
  3. import socket
  4. import time
  5. from base64 import b64encode
  6. from hashlib import sha1
  7. from urlparse import urlparse
  8. from errors import HandshakeError
  9. from python_digest import build_authorization_request
  10. WS_GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
  11. WS_VERSION = '13'
  12. MAX_REDIRECTS = 10
  13. HDR_TIMEOUT = 5
  14. MAX_HDR_LEN = 1024
  15. class Handshake(object):
  16. def __init__(self, wsock):
  17. self.wsock = wsock
  18. self.sock = wsock.sock
  19. def fail(self, msg):
  20. self.sock.close()
  21. raise HandshakeError(msg)
  22. def receive_request(self):
  23. raw, headers = self.receive_headers()
  24. # Request must be HTTP (at least 1.1) GET request, find the location
  25. # (without trailing slash)
  26. match = re.search(r'^GET (.*?)/* HTTP/1.1\r\n', raw)
  27. if match is None:
  28. self.fail('not a valid HTTP 1.1 GET request')
  29. location = match.group(1)
  30. return location, headers
  31. def receive_response(self):
  32. raw, headers = self.receive_headers()
  33. # Response must be HTTP (at least 1.1) with status 101
  34. match = re.search(r'^HTTP/1\.1 (\d{3})', raw)
  35. if match is None:
  36. self.fail('not a valid HTTP 1.1 response')
  37. status = int(match.group(1))
  38. return status, headers
  39. def receive_headers(self):
  40. # Receive entire HTTP header
  41. hdr = ''
  42. sock_timeout = self.sock.gettimeout()
  43. try:
  44. force_timeout = sock_timeout is None
  45. timeout = HDR_TIMEOUT if force_timeout else sock_timeout
  46. self.sock.settimeout(timeout)
  47. start_time = time.time()
  48. while hdr[-4:] != '\r\n\r\n':
  49. if len(hdr) == MAX_HDR_LEN:
  50. raise HandshakeError('request exceeds maximum header '
  51. 'length of %d' % MAX_HDR_LEN)
  52. hdr += self.sock.recv(1)
  53. time_diff = time.time() - start_time
  54. if time_diff > timeout:
  55. raise socket.timeout
  56. self.sock.settimeout(timeout - time_diff)
  57. except socket.timeout:
  58. self.sock.close()
  59. raise HandshakeError('timeout while receiving handshake headers')
  60. self.sock.settimeout(sock_timeout)
  61. hdr = hdr.decode('utf-8', 'ignore')
  62. headers = {}
  63. for key, value in re.findall(r'(.*?): ?(.*?)\r\n', hdr):
  64. if key in headers:
  65. headers[key] += ', ' + value
  66. else:
  67. headers[key] = value
  68. return hdr, headers
  69. def send_headers(self, headers):
  70. # Send request
  71. for hdr in list(headers):
  72. if isinstance(hdr, tuple):
  73. hdr = '%s: %s' % hdr
  74. self.sock.sendall(hdr + '\r\n')
  75. self.sock.sendall('\r\n')
  76. def perform(self):
  77. raise NotImplementedError
  78. class ServerHandshake(Handshake):
  79. """
  80. Executes a handshake as the server end point of the socket. If the HTTP
  81. request headers sent by the client are invalid, a HandshakeError is raised.
  82. """
  83. def perform(self, ssock):
  84. # Receive and validate client handshake
  85. location, headers = self.receive_request()
  86. self.wsock.location = location
  87. self.wsock.request_headers = headers
  88. # Send server handshake in response
  89. self.send_headers(self.response_headers(ssock))
  90. def response_headers(self, ssock):
  91. headers = self.wsock.request_headers
  92. # Check if headers that MUST be present are actually present
  93. for name in ('Host', 'Upgrade', 'Connection', 'Sec-WebSocket-Key',
  94. 'Sec-WebSocket-Version'):
  95. if name not in headers:
  96. self.fail('missing "%s" header' % name)
  97. # Check WebSocket version used by client
  98. version = headers['Sec-WebSocket-Version']
  99. if version != WS_VERSION:
  100. self.fail('WebSocket version %s requested (only %s is supported)'
  101. % (version, WS_VERSION))
  102. # Verify required header keywords
  103. if 'websocket' not in headers['Upgrade'].lower():
  104. self.fail('"Upgrade" header must contain "websocket"')
  105. if 'upgrade' not in headers['Connection'].lower():
  106. self.fail('"Connection" header must contain "Upgrade"')
  107. # Origin must be present if browser client, and must match the list of
  108. # trusted origins
  109. if 'Origin' not in headers and 'User-Agent' in headers:
  110. self.fail('browser client must specify "Origin" header')
  111. origin = headers.get('Origin', 'null')
  112. if origin == 'null':
  113. if ssock.trusted_origins:
  114. self.fail('no "Origin" header specified, assuming untrusted')
  115. elif ssock.trusted_origins and origin not in ssock.trusted_origins:
  116. self.fail('untrusted origin "%s"' % origin)
  117. # Only a supported protocol can be returned
  118. client_proto = split_stripped(headers['Sec-WebSocket-Protocol']) \
  119. if 'Sec-WebSocket-Protocol' in headers else []
  120. self.wsock.protocol = None
  121. for p in client_proto:
  122. if p in ssock.protocols:
  123. self.wsock.protocol = p
  124. break
  125. # Only supported extensions are returned
  126. if 'Sec-WebSocket-Extensions' in headers:
  127. self.wsock.extension_instances = []
  128. for hdr in split_stripped(headers['Sec-WebSocket-Extensions']):
  129. name, params = parse_param_hdr(hdr)
  130. for ext in ssock.extensions:
  131. if ext.is_supported(name, self.wsock.extension_instances):
  132. accept_params = ext.negotiate_safe(name, params)
  133. if accept_params is not None:
  134. instance = ext.Instance(ext, name, accept_params)
  135. self.wsock.extension_instances.append(instance)
  136. # Check if requested resource location is served by this server
  137. if ssock.locations:
  138. if self.wsock.location not in ssock.locations:
  139. raise HandshakeError('location "%s" is not supported by this '
  140. 'server' % self.wsock.location)
  141. # Encode acceptation key using the WebSocket GUID
  142. key = headers['Sec-WebSocket-Key'].strip()
  143. accept = b64encode(sha1(key + WS_GUID).digest())
  144. # Location scheme differs for SSL-enabled connections
  145. scheme = 'wss' if self.wsock.secure else 'ws'
  146. if 'Host' in headers:
  147. host = headers['Host']
  148. else:
  149. host, port = self.sock.getpeername()
  150. default_port = 443 if self.wsock.secure else 80
  151. if port != default_port:
  152. host += ':%d' % port
  153. location = '%s://%s%s' % (scheme, host, self.wsock.location)
  154. # Construct HTTP response header
  155. yield 'HTTP/1.1 101 Switching Protocols'
  156. yield 'Upgrade', 'websocket'
  157. yield 'Connection', 'Upgrade'
  158. yield 'Sec-WebSocket-Origin', origin
  159. yield 'Sec-WebSocket-Location', location
  160. yield 'Sec-WebSocket-Accept', accept
  161. if self.wsock.protocol:
  162. yield 'Sec-WebSocket-Protocol', self.wsock.protocol
  163. if self.wsock.extension_instances:
  164. values = [format_param_hdr(i.name, i.params)
  165. for i in self.wsock.extension_instances]
  166. yield 'Sec-WebSocket-Extensions', ', '.join(values)
  167. class ClientHandshake(Handshake):
  168. """
  169. Executes a handshake as the client end point of the socket. May raise a
  170. HandshakeError if the server response is invalid.
  171. """
  172. def __init__(self, wsock):
  173. Handshake.__init__(self, wsock)
  174. self.redirects = 0
  175. def perform(self):
  176. self.send_headers(self.request_headers())
  177. self.handle_response(*self.receive_response())
  178. def handle_response(self, status, headers):
  179. if status == 101:
  180. self.handle_handshake(headers)
  181. elif status == 401:
  182. self.handle_auth(headers)
  183. elif status in (301, 302, 303, 307, 308):
  184. self.handle_redirect(headers)
  185. else:
  186. self.fail('invalid HTTP response status %d' % status)
  187. def handle_handshake(self, headers):
  188. # Check if headers that MUST be present are actually present
  189. for name in ('Upgrade', 'Connection', 'Sec-WebSocket-Accept'):
  190. if name not in headers:
  191. self.fail('missing "%s" header' % name)
  192. if 'websocket' not in headers['Upgrade'].lower():
  193. self.fail('"Upgrade" header must contain "websocket"')
  194. if 'upgrade' not in headers['Connection'].lower():
  195. self.fail('"Connection" header must contain "Upgrade"')
  196. # Verify accept header
  197. accept = headers['Sec-WebSocket-Accept'].strip()
  198. required_accept = b64encode(sha1(self.key + WS_GUID).digest())
  199. if accept != required_accept:
  200. self.fail('invalid websocket accept header "%s"' % accept)
  201. # Compare extensions, add hooks only for those returned by server
  202. if 'Sec-WebSocket-Extensions' in headers:
  203. # FIXME: there is no distinction between server/client extension
  204. # instances, while the extension instance may assume it belongs to
  205. # a server, leading to undefined behavior
  206. self.wsock.extension_instances = []
  207. for hdr in split_stripped(headers['Sec-WebSocket-Extensions']):
  208. name, accept_params = parse_param_hdr(hdr)
  209. for ext in self.wsock.extensions:
  210. if name in ext.names:
  211. instance = ext.Instance(ext, name, accept_params)
  212. self.wsock.extension_instances.append(instance)
  213. break
  214. else:
  215. raise HandshakeError('server handshake contains '
  216. 'unsupported extension "%s"' % name)
  217. # Assert that returned protocol (if any) is supported
  218. if 'Sec-WebSocket-Protocol' in headers:
  219. protocol = headers['Sec-WebSocket-Protocol']
  220. if protocol != 'null' and protocol not in self.wsock.protocols:
  221. self.fail('unsupported protocol "%s"' % protocol)
  222. self.wsock.protocol = protocol
  223. def handle_auth(self, headers):
  224. # HTTP authentication is required in the request
  225. hdr = headers['WWW-Authenticate']
  226. authres = dict(re.findall(r'(\w+)[:=] ?"?(\w+)"?', hdr))
  227. mode = hdr.lstrip().split(' ', 1)[0]
  228. if not self.wsock.auth:
  229. self.fail('missing username and password for HTTP authentication')
  230. if mode == 'Basic':
  231. auth_hdr = self.http_auth_basic_headers(**authres)
  232. elif mode == 'Digest':
  233. auth_hdr = self.http_auth_digest_headers(**authres)
  234. else:
  235. self.fail('unsupported HTTP authentication mode "%s"' % mode)
  236. # Send new, authenticated handshake
  237. self.send_headers(list(self.request_headers()) + list(auth_hdr))
  238. self.handle_response(*self.receive_response())
  239. def handle_redirect(self, headers):
  240. self.redirects += 1
  241. if self.redirects > MAX_REDIRECTS:
  242. self.fail('reached maximum number of redirects (%d)'
  243. % MAX_REDIRECTS)
  244. # Handle HTTP redirect
  245. url = urlparse(headers['Location'].strip())
  246. # Reconnect socket to new host if net location changed
  247. if not url.port:
  248. url.port = 443 if self.secure else 80
  249. addr = (url.netloc, url.port)
  250. if addr != self.sock.getpeername():
  251. self.sock.close()
  252. self.sock.connect(addr)
  253. # Update websocket object and send new handshake
  254. self.wsock.location = url.path
  255. self.perform()
  256. def request_headers(self):
  257. if len(self.wsock.location) == 0:
  258. self.fail('request location is empty')
  259. # Generate a 16-byte random base64-encoded key for this connection
  260. self.key = b64encode(os.urandom(16))
  261. # Send client handshake
  262. yield 'GET %s HTTP/1.1' % self.wsock.location
  263. yield 'Host', '%s:%d' % self.sock.getpeername()
  264. yield 'Upgrade', 'websocket'
  265. yield 'Connection', 'keep-alive, Upgrade'
  266. yield 'Sec-WebSocket-Key', self.key
  267. yield 'Sec-WebSocket-Version', WS_VERSION
  268. if self.wsock.origin:
  269. yield 'Origin', self.wsock.origin
  270. # These are for eagerly caching webservers
  271. yield 'Pragma', 'no-cache'
  272. yield 'Cache-Control', 'no-cache'
  273. # Request protocols and extensions, these are later checked with the
  274. # actual supported values from the server's response
  275. if self.wsock.protocols:
  276. yield 'Sec-WebSocket-Protocol', ', '.join(self.wsock.protocols)
  277. if self.wsock.extensions:
  278. values = [format_param_hdr(e.name, e.request)
  279. for e in self.wsock.extensions]
  280. yield 'Sec-WebSocket-Extensions', ', '.join(values)
  281. def http_auth_basic_headers(self, **kwargs):
  282. u, p = self.wsock.auth
  283. u = u.encode('utf-8')
  284. p = p.encode('utf-8')
  285. yield 'Authorization', 'Basic ' + b64encode(u + ':' + p)
  286. def http_auth_digest_headers(self, **kwargs):
  287. username, password = self.wsock.auth
  288. yield 'Authorization', build_authorization_request(
  289. username=username.encode('utf-8'),
  290. method='GET',
  291. uri=self.wsock.location,
  292. nonce_count=0,
  293. realm=kwargs['realm'],
  294. nonce=kwargs['nonce'],
  295. opaque=kwargs['opaque'],
  296. password=password.encode('utf-8'))
  297. def split_stripped(value, delim=',', maxsplits=-1):
  298. return map(str.strip, str(value).split(delim, maxsplits)) if value else []
  299. def parse_param_hdr(hdr):
  300. if ';' in hdr:
  301. name, paramstr = split_stripped(hdr, ';', 1)
  302. else:
  303. name = hdr
  304. paramstr = ''
  305. params = {}
  306. for param in split_stripped(paramstr):
  307. if '=' in param:
  308. key, value = split_stripped(param, '=', 1)
  309. if value.isdigit():
  310. value = int(value)
  311. else:
  312. key = param
  313. value = True
  314. params[key] = value
  315. return name, params
  316. def format_param_hdr(value, params):
  317. if not params:
  318. return value
  319. def fmt_param((k, v)):
  320. if v is True:
  321. return k
  322. if v is not False and v is not None:
  323. return k + '=' + str(v)
  324. strparams = filter(None, map(fmt_param, params.items()))
  325. return '; '.join([value] + strparams)