handshake.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435
  1. import os
  2. import re
  3. import socket
  4. import time
  5. from base64 import b64encode
  6. from hashlib import sha1
  7. from urlparse import urlparse
  8. from errors import HandshakeError
  9. from python_digest import build_authorization_request
  10. WS_GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
  11. WS_VERSION = '13'
  12. MAX_REDIRECTS = 10
  13. HDR_TIMEOUT = 5
  14. MAX_HDR_LEN = 1024
  15. class Handshake(object):
  16. def __init__(self, wsock):
  17. self.wsock = wsock
  18. self.sock = wsock.sock
  19. def fail(self, msg):
  20. self.sock.close()
  21. raise HandshakeError(msg)
  22. def receive_request(self):
  23. raw, headers = self.receive_headers()
  24. # Request must be HTTP (at least 1.1) GET request, find the location
  25. # (without trailing slash)
  26. match = re.search(r'^GET (.*?)/* HTTP/1.1\r\n', raw)
  27. if match is None:
  28. self.fail('not a valid HTTP 1.1 GET request')
  29. location = match.group(1)
  30. return location, headers
  31. def receive_response(self):
  32. raw, headers = self.receive_headers()
  33. # Response must be HTTP (at least 1.1) with status 101
  34. match = re.search(r'^HTTP/1\.1 (\d{3})', raw)
  35. if match is None:
  36. self.fail('not a valid HTTP 1.1 response')
  37. status = int(match.group(1))
  38. return status, headers
  39. def receive_headers(self):
  40. # Receive entire HTTP header
  41. hdr = ''
  42. sock_timeout = self.sock.gettimeout()
  43. try:
  44. force_timeout = sock_timeout is None
  45. timeout = HDR_TIMEOUT if force_timeout else sock_timeout
  46. self.sock.settimeout(timeout)
  47. start_time = time.time()
  48. while hdr[-4:] != '\r\n\r\n':
  49. if len(hdr) == MAX_HDR_LEN:
  50. raise HandshakeError('request exceeds maximum header '
  51. 'length of %d' % MAX_HDR_LEN)
  52. hdr += self.sock.recv(1)
  53. time_diff = time.time() - start_time
  54. if time_diff > timeout:
  55. raise socket.timeout
  56. self.sock.settimeout(timeout - time_diff)
  57. except socket.timeout:
  58. self.sock.close()
  59. raise HandshakeError('timeout while receiving handshake headers')
  60. self.sock.settimeout(sock_timeout)
  61. hdr = hdr.decode('utf-8', 'ignore')
  62. headers = {}
  63. for key, value in re.findall(r'(.*?): ?(.*?)\r\n', hdr):
  64. if key in headers:
  65. headers[key] += ', ' + value
  66. else:
  67. headers[key] = value
  68. return hdr, headers
  69. def send_headers(self, headers):
  70. # Send request
  71. for hdr in list(headers):
  72. if isinstance(hdr, tuple):
  73. hdr = '%s: %s' % hdr
  74. self.sock.sendall(hdr + '\r\n')
  75. self.sock.sendall('\r\n')
  76. def perform(self):
  77. raise NotImplementedError
  78. class ServerHandshake(Handshake):
  79. """
  80. Executes a handshake as the server end point of the socket. If the HTTP
  81. request headers sent by the client are invalid, a HandshakeError is raised.
  82. """
  83. def perform(self, ssock):
  84. # Receive and validate client handshake
  85. location, headers = self.receive_request()
  86. self.wsock.location = location
  87. self.wsock.request_headers = headers
  88. # Send server handshake in response
  89. self.send_headers(self.response_headers(ssock))
  90. def response_headers(self, ssock):
  91. headers = self.wsock.request_headers
  92. # Check if headers that MUST be present are actually present
  93. for name in ('Host', 'Upgrade', 'Connection', 'Sec-WebSocket-Key',
  94. 'Sec-WebSocket-Version'):
  95. if name not in headers:
  96. self.fail('missing "%s" header' % name)
  97. # Check WebSocket version used by client
  98. version = headers['Sec-WebSocket-Version']
  99. if version != WS_VERSION:
  100. self.fail('WebSocket version %s requested (only %s is supported)'
  101. % (version, WS_VERSION))
  102. # Verify required header keywords
  103. if 'websocket' not in headers['Upgrade'].lower():
  104. self.fail('"Upgrade" header must contain "websocket"')
  105. if 'upgrade' not in headers['Connection'].lower():
  106. self.fail('"Connection" header must contain "Upgrade"')
  107. # Origin must be present if browser client, and must match the list of
  108. # trusted origins
  109. if 'Origin' not in headers and 'User-Agent' in headers:
  110. self.fail('browser client must specify "Origin" header')
  111. origin = headers.get('Origin', 'null')
  112. if origin == 'null':
  113. if ssock.trusted_origins:
  114. self.fail('no "Origin" header specified, assuming untrusted')
  115. elif ssock.trusted_origins and origin not in ssock.trusted_origins:
  116. self.fail('untrusted origin "%s"' % origin)
  117. # Only a supported protocol can be returned
  118. client_proto = split_stripped(headers['Sec-WebSocket-Protocol']) \
  119. if 'Sec-WebSocket-Protocol' in headers else []
  120. self.wsock.protocol = None
  121. for p in client_proto:
  122. if p in ssock.protocols:
  123. self.wsock.protocol = p
  124. break
  125. # Only supported extensions are returned
  126. if 'Sec-WebSocket-Extensions' in headers:
  127. self.wsock.extension_instances = []
  128. for hdr in split_stripped(headers['Sec-WebSocket-Extensions']):
  129. name, params = parse_param_hdr(hdr)
  130. for ext in ssock.extensions:
  131. if not any(ext.conflicts(other.extension)
  132. for other in self.wsock.extension_instances):
  133. accept_params = ext.negotiate_safe(name, params)
  134. if accept_params is not None:
  135. instance = ext.Instance(ext, name, accept_params)
  136. self.wsock.extension_instances.append(instance)
  137. # Check if requested resource location is served by this server
  138. if ssock.locations:
  139. if self.wsock.location not in ssock.locations:
  140. raise HandshakeError('location "%s" is not supported by this '
  141. 'server' % self.wsock.location)
  142. # Encode acceptation key using the WebSocket GUID
  143. key = headers['Sec-WebSocket-Key'].strip()
  144. accept = b64encode(sha1(key + WS_GUID).digest())
  145. # Location scheme differs for SSL-enabled connections
  146. scheme = 'wss' if self.wsock.secure else 'ws'
  147. if 'Host' in headers:
  148. host = headers['Host']
  149. else:
  150. host, port = self.sock.getpeername()
  151. default_port = 443 if self.wsock.secure else 80
  152. if port != default_port:
  153. host += ':%d' % port
  154. location = '%s://%s%s' % (scheme, host, self.wsock.location)
  155. # Construct HTTP response header
  156. yield 'HTTP/1.1 101 Switching Protocols'
  157. yield 'Upgrade', 'websocket'
  158. yield 'Connection', 'Upgrade'
  159. yield 'Sec-WebSocket-Origin', origin
  160. yield 'Sec-WebSocket-Location', location
  161. yield 'Sec-WebSocket-Accept', accept
  162. if self.wsock.protocol:
  163. yield 'Sec-WebSocket-Protocol', self.wsock.protocol
  164. if self.wsock.extension_instances:
  165. values = [format_param_hdr(i.name, i.params)
  166. for i in self.wsock.extension_instances]
  167. yield 'Sec-WebSocket-Extensions', ', '.join(values)
  168. class ClientHandshake(Handshake):
  169. """
  170. Executes a handshake as the client end point of the socket. May raise a
  171. HandshakeError if the server response is invalid.
  172. """
  173. def __init__(self, wsock):
  174. Handshake.__init__(self, wsock)
  175. self.redirects = 0
  176. def perform(self):
  177. self.send_headers(self.request_headers())
  178. self.handle_response(*self.receive_response())
  179. def handle_response(self, status, headers):
  180. if status == 101:
  181. self.handle_handshake(headers)
  182. elif status == 401:
  183. self.handle_auth(headers)
  184. elif status in (301, 302, 303, 307, 308):
  185. self.handle_redirect(headers)
  186. else:
  187. self.fail('invalid HTTP response status %d' % status)
  188. def handle_handshake(self, headers):
  189. # Check if headers that MUST be present are actually present
  190. for name in ('Upgrade', 'Connection', 'Sec-WebSocket-Accept'):
  191. if name not in headers:
  192. self.fail('missing "%s" header' % name)
  193. if 'websocket' not in headers['Upgrade'].lower():
  194. self.fail('"Upgrade" header must contain "websocket"')
  195. if 'upgrade' not in headers['Connection'].lower():
  196. self.fail('"Connection" header must contain "Upgrade"')
  197. # Verify accept header
  198. accept = headers['Sec-WebSocket-Accept'].strip()
  199. required_accept = b64encode(sha1(self.key + WS_GUID).digest())
  200. if accept != required_accept:
  201. self.fail('invalid websocket accept header "%s"' % accept)
  202. # Compare extensions, add hooks only for those returned by server
  203. if 'Sec-WebSocket-Extensions' in headers:
  204. # FIXME: there is no distinction between server/client extension
  205. # instances, while the extension instance may assume it belongs to
  206. # a server, leading to undefined behavior
  207. self.wsock.extension_instances = []
  208. for hdr in split_stripped(headers['Sec-WebSocket-Extensions']):
  209. name, accept_params = parse_param_hdr(hdr)
  210. for ext in self.wsock.extensions:
  211. if name in ext.names:
  212. instance = ext.Instance(ext, name, accept_params)
  213. self.wsock.extension_instances.append(instance)
  214. break
  215. else:
  216. raise HandshakeError('server handshake contains '
  217. 'unsupported extension "%s"' % name)
  218. # Assert that returned protocol (if any) is supported
  219. if 'Sec-WebSocket-Protocol' in headers:
  220. protocol = headers['Sec-WebSocket-Protocol']
  221. if protocol != 'null' and protocol not in self.wsock.protocols:
  222. self.fail('unsupported protocol "%s"' % protocol)
  223. self.wsock.protocol = protocol
  224. def handle_auth(self, headers):
  225. # HTTP authentication is required in the request
  226. hdr = headers['WWW-Authenticate']
  227. authres = dict(re.findall(r'(\w+)[:=] ?"?(\w+)"?', hdr))
  228. mode = hdr.lstrip().split(' ', 1)[0]
  229. if not self.wsock.auth:
  230. self.fail('missing username and password for HTTP authentication')
  231. if mode == 'Basic':
  232. auth_hdr = self.http_auth_basic_headers(**authres)
  233. elif mode == 'Digest':
  234. auth_hdr = self.http_auth_digest_headers(**authres)
  235. else:
  236. self.fail('unsupported HTTP authentication mode "%s"' % mode)
  237. # Send new, authenticated handshake
  238. self.send_headers(list(self.request_headers()) + list(auth_hdr))
  239. self.handle_response(*self.receive_response())
  240. def handle_redirect(self, headers):
  241. self.redirects += 1
  242. if self.redirects > MAX_REDIRECTS:
  243. self.fail('reached maximum number of redirects (%d)'
  244. % MAX_REDIRECTS)
  245. # Handle HTTP redirect
  246. url = urlparse(headers['Location'].strip())
  247. # Reconnect socket to new host if net location changed
  248. if not url.port:
  249. url.port = 443 if self.secure else 80
  250. addr = (url.netloc, url.port)
  251. if addr != self.sock.getpeername():
  252. self.sock.close()
  253. self.sock.connect(addr)
  254. # Update websocket object and send new handshake
  255. self.wsock.location = url.path
  256. self.perform()
  257. def request_headers(self):
  258. if len(self.wsock.location) == 0:
  259. self.fail('request location is empty')
  260. # Generate a 16-byte random base64-encoded key for this connection
  261. self.key = b64encode(os.urandom(16))
  262. # Send client handshake
  263. yield 'GET %s HTTP/1.1' % self.wsock.location
  264. yield 'Host', '%s:%d' % self.sock.getpeername()
  265. yield 'Upgrade', 'websocket'
  266. yield 'Connection', 'keep-alive, Upgrade'
  267. yield 'Sec-WebSocket-Key', self.key
  268. yield 'Sec-WebSocket-Version', WS_VERSION
  269. if self.wsock.origin:
  270. yield 'Origin', self.wsock.origin
  271. # These are for eagerly caching webservers
  272. yield 'Pragma', 'no-cache'
  273. yield 'Cache-Control', 'no-cache'
  274. # Request protocols and extensions, these are later checked with the
  275. # actual supported values from the server's response
  276. if self.wsock.protocols:
  277. yield 'Sec-WebSocket-Protocol', ', '.join(self.wsock.protocols)
  278. if self.wsock.extensions:
  279. values = [format_param_hdr(e.name, e.request)
  280. for e in self.wsock.extensions]
  281. yield 'Sec-WebSocket-Extensions', ', '.join(values)
  282. def http_auth_basic_headers(self, **kwargs):
  283. u, p = self.wsock.auth
  284. u = u.encode('utf-8')
  285. p = p.encode('utf-8')
  286. yield 'Authorization', 'Basic ' + b64encode(u + ':' + p)
  287. def http_auth_digest_headers(self, **kwargs):
  288. username, password = self.wsock.auth
  289. yield 'Authorization', build_authorization_request(
  290. username=username.encode('utf-8'),
  291. method='GET',
  292. uri=self.wsock.location,
  293. nonce_count=0,
  294. realm=kwargs['realm'],
  295. nonce=kwargs['nonce'],
  296. opaque=kwargs['opaque'],
  297. password=password.encode('utf-8'))
  298. def split_stripped(value, delim=',', maxsplits=-1):
  299. return map(str.strip, str(value).split(delim, maxsplits)) if value else []
  300. def parse_param_hdr(hdr):
  301. if ';' in hdr:
  302. name, paramstr = split_stripped(hdr, ';', 1)
  303. else:
  304. name = hdr
  305. paramstr = ''
  306. params = {}
  307. for param in split_stripped(paramstr):
  308. if '=' in param:
  309. key, value = split_stripped(param, '=', 1)
  310. if value.isdigit():
  311. value = int(value)
  312. else:
  313. key = param
  314. value = True
  315. params[key] = value
  316. return name, params
  317. def format_param_hdr(value, params):
  318. if not params:
  319. return value
  320. def fmt_param((k, v)):
  321. if v is True:
  322. return k
  323. if v is not False and v is not None:
  324. return k + '=' + str(v)
  325. strparams = filter(None, map(fmt_param, params.items()))
  326. return '%s; %s' % (value, ', '.join(strparams))