handshake.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386
  1. import os
  2. import re
  3. from hashlib import sha1
  4. from base64 import b64encode
  5. from urlparse import urlparse
  6. from python_digest import build_authorization_request
  7. from errors import HandshakeError
  8. from extension import filter_extensions
# GUID that is appended to the Sec-WebSocket-Key value before it is SHA-1
# hashed and base64-encoded into the Sec-WebSocket-Accept value
WS_GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
# Value sent by the client in Sec-WebSocket-Version, and the only value the
# server handshake accepts
WS_VERSION = '13'
# Number of HTTP redirects the client handshake follows before failing
MAX_REDIRECTS = 10
  12. class Handshake(object):
  13. def __init__(self, wsock):
  14. self.wsock = wsock
  15. self.sock = wsock.sock
  16. def fail(self, msg):
  17. self.sock.close()
  18. raise HandshakeError(msg)
  19. def receive_request(self):
  20. raw, headers = self.receive_headers()
  21. # Request must be HTTP (at least 1.1) GET request, find the location
  22. match = re.search(r'^GET (.*) HTTP/1.1\r\n', raw)
  23. if match is None:
  24. self.fail('not a valid HTTP 1.1 GET request')
  25. location = match.group(1)
  26. return location, headers
  27. def receive_response(self):
  28. raw, headers = self.receive_headers()
  29. # Response must be HTTP (at least 1.1) with status 101
  30. match = re.search(r'^HTTP/1\.1 (\d{3})', raw)
  31. if match is None:
  32. self.fail('not a valid HTTP 1.1 response')
  33. status = int(match.group(1))
  34. return status, headers
  35. def receive_headers(self):
  36. # Receive entire HTTP header
  37. raw_headers = ''
  38. while raw_headers[-4:] not in ('\r\n\r\n', '\n\n'):
  39. raw_headers += self.sock.recv(512).decode('utf-8', 'ignore')
  40. headers = {}
  41. for key, value in re.findall(r'(.*?): ?(.*?)\r\n', raw_headers):
  42. if key in headers:
  43. headers[key] += ', ' + value
  44. else:
  45. headers[key] = value
  46. return raw_headers, headers
  47. def send_headers(self, headers):
  48. # Send request
  49. for hdr in list(headers):
  50. if isinstance(hdr, tuple):
  51. hdr = '%s: %s' % hdr
  52. self.sock.sendall(hdr + '\r\n')
  53. self.sock.sendall('\r\n')
  54. def perform(self):
  55. raise NotImplementedError
  56. class ServerHandshake(Handshake):
  57. """
  58. Executes a handshake as the server end point of the socket. If the HTTP
  59. request headers sent by the client are invalid, a HandshakeError is raised.
  60. """
  61. def perform(self):
  62. # Receive and validate client handshake
  63. self.wsock.location, headers = self.receive_request()
  64. # Send server handshake in response
  65. self.send_headers(self.response_headers(headers))
  66. def response_headers(self, headers):
  67. # Check if headers that MUST be present are actually present
  68. for name in ('Host', 'Upgrade', 'Connection', 'Sec-WebSocket-Key',
  69. 'Sec-WebSocket-Version'):
  70. if name not in headers:
  71. self.fail('missing "%s" header' % name)
  72. # Check WebSocket version used by client
  73. version = headers['Sec-WebSocket-Version']
  74. if version != WS_VERSION:
  75. self.fail('WebSocket version %s requested (only %s is supported)'
  76. % (version, WS_VERSION))
  77. # Verify required header keywords
  78. if 'websocket' not in headers['Upgrade'].lower():
  79. self.fail('"Upgrade" header must contain "websocket"')
  80. if 'upgrade' not in headers['Connection'].lower():
  81. self.fail('"Connection" header must contain "Upgrade"')
  82. # Origin must be present if browser client, and must match the list of
  83. # trusted origins
  84. origin = 'null'
  85. if 'Origin' not in headers:
  86. if 'User-Agent' in headers:
  87. self.fail('browser client must specify "Origin" header')
  88. if self.wsock.trusted_origins:
  89. self.fail('no "Origin" header specified, assuming untrusted')
  90. elif self.wsock.trusted_origins:
  91. origin = headers['Origin']
  92. if origin not in self.wsock.trusted_origins:
  93. self.fail('untrusted origin "%s"' % origin)
  94. # Only a supported protocol can be returned
  95. client_proto = split_stripped(headers['Sec-WebSocket-Protocol']) \
  96. if 'Sec-WebSocket-Protocol' in headers else []
  97. self.wsock.protocol = None
  98. for p in client_proto:
  99. if p in self.wsock.protocols:
  100. self.wsock.protocol = p
  101. break
  102. # Only supported extensions are returned
  103. if 'Sec-WebSocket-Extensions' in headers:
  104. supported_ext = dict((e.name, e) for e in self.wsock.extensions)
  105. extensions = []
  106. for ext in split_stripped(headers['Sec-WebSocket-Extensions']):
  107. name, params = parse_param_hdr(ext)
  108. if name in supported_ext:
  109. extensions.append(supported_ext[name](**params))
  110. self.wsock.extensions = filter_extensions(extensions)
  111. else:
  112. self.wsock.extensions = []
  113. # Encode acceptation key using the WebSocket GUID
  114. key = headers['Sec-WebSocket-Key'].strip()
  115. accept = b64encode(sha1(key + WS_GUID).digest())
  116. # Location scheme differs for SSL-enabled connections
  117. scheme = 'wss' if self.wsock.secure else 'ws'
  118. if 'Host' in headers:
  119. host = headers['Host']
  120. else:
  121. host, port = self.sock.getpeername()
  122. default_port = 443 if self.wsock.secure else 80
  123. if port != default_port:
  124. host += ':%d' % port
  125. location = '%s://%s%s' % (scheme, host, self.wsock.location)
  126. # Construct HTTP response header
  127. yield 'HTTP/1.1 101 Web Socket Protocol Handshake'
  128. yield 'Upgrade', 'websocket'
  129. yield 'Connection', 'Upgrade'
  130. yield 'WebSocket-Origin', origin
  131. yield 'WebSocket-Location', location
  132. yield 'Sec-WebSocket-Accept', accept
  133. if self.wsock.protocol:
  134. yield 'Sec-WebSocket-Protocol', self.wsock.protocol
  135. if self.wsock.extensions:
  136. values = [format_param_hdr(e.name, e.header_params())
  137. for e in self.wsock.extensions]
  138. yield 'Sec-WebSocket-Extensions', ', '.join(values)
  139. class ClientHandshake(Handshake):
  140. """
  141. Executes a handshake as the client end point of the socket. May raise a
  142. HandshakeError if the server response is invalid.
  143. """
  144. def __init__(self, wsock):
  145. Handshake.__init__(self, wsock)
  146. self.redirects = 0
  147. def perform(self):
  148. self.send_headers(self.request_headers())
  149. self.handle_response(*self.receive_response())
  150. def handle_response(self, status, headers):
  151. if status == 101:
  152. self.handle_handshake(headers)
  153. elif status == 401:
  154. self.handle_auth(headers)
  155. elif status in (301, 302, 303, 307, 308):
  156. self.handle_redirect(headers)
  157. else:
  158. self.fail('invalid HTTP response status %d' % status)
  159. def handle_handshake(self, headers):
  160. # Check if headers that MUST be present are actually present
  161. for name in ('Upgrade', 'Connection', 'Sec-WebSocket-Accept'):
  162. if name not in headers:
  163. self.fail('missing "%s" header' % name)
  164. if 'websocket' not in headers['Upgrade'].lower():
  165. self.fail('"Upgrade" header must contain "websocket"')
  166. if 'upgrade' not in headers['Connection'].lower():
  167. self.fail('"Connection" header must contain "Upgrade"')
  168. # Verify accept header
  169. accept = headers['Sec-WebSocket-Accept'].strip()
  170. required_accept = b64encode(sha1(self.key + WS_GUID).digest())
  171. if accept != required_accept:
  172. self.fail('invalid websocket accept header "%s"' % accept)
  173. # Compare extensions
  174. if 'Sec-WebSocket-Extensions' in headers:
  175. supported_ext = dict((e.name, e) for e in self.wsock.extensions)
  176. self.wsock.extensions = []
  177. for ext in split_stripped(headers['Sec-WebSocket-Extensions']):
  178. name, params = parse_param_hdr(ext)
  179. if name not in supported_ext:
  180. raise HandshakeError('server handshake contains '
  181. 'unsupported extension "%s"' % name)
  182. self.wsock.extensions.append(supported_ext[name](**params))
  183. # Assert that returned protocol (if any) is supported
  184. if 'Sec-WebSocket-Protocol' in headers:
  185. protocol = headers['Sec-WebSocket-Protocol']
  186. if protocol != 'null' and protocol not in self.wsock.protocols:
  187. self.fail('unsupported protocol "%s"' % protocol)
  188. self.wsock.protocol = protocol
  189. def handle_auth(self, headers):
  190. # HTTP authentication is required in the request
  191. hdr = headers['WWW-Authenticate']
  192. authres = dict(re.findall(r'(\w+)[:=] ?"?(\w+)"?', hdr))
  193. mode = hdr.lstrip().split(' ', 1)[0]
  194. if not self.wsock.auth:
  195. self.fail('missing username and password for HTTP authentication')
  196. if mode == 'Basic':
  197. auth_hdr = self.http_auth_basic_headers(**authres)
  198. elif mode == 'Digest':
  199. auth_hdr = self.http_auth_digest_headers(**authres)
  200. else:
  201. self.fail('unsupported HTTP authentication mode "%s"' % mode)
  202. # Send new, authenticated handshake
  203. self.send_headers(list(self.request_headers()) + list(auth_hdr))
  204. self.handle_response(*self.receive_response())
  205. def handle_redirect(self, headers):
  206. self.redirects += 1
  207. if self.redirects > MAX_REDIRECTS:
  208. self.fail('reached maximum number of redirects (%d)'
  209. % MAX_REDIRECTS)
  210. # Handle HTTP redirect
  211. url = urlparse(headers['Location'].strip())
  212. # Reconnect socket to new host if net location changed
  213. if not url.port:
  214. url.port = 443 if self.secure else 80
  215. addr = (url.netloc, url.port)
  216. if addr != self.sock.getpeername():
  217. self.sock.close()
  218. self.sock.connect(addr)
  219. # Update websocket object and send new handshake
  220. self.wsock.location = url.path
  221. self.perform()
  222. def request_headers(self):
  223. if len(self.wsock.location) == 0:
  224. self.fail('request location is empty')
  225. # Generate a 16-byte random base64-encoded key for this connection
  226. self.key = b64encode(os.urandom(16))
  227. # Send client handshake
  228. yield 'GET %s HTTP/1.1' % self.wsock.location
  229. yield 'Host', '%s:%d' % self.sock.getpeername()
  230. yield 'Upgrade', 'websocket'
  231. yield 'Connection', 'keep-alive, Upgrade'
  232. yield 'Sec-WebSocket-Key', self.key
  233. yield 'Sec-WebSocket-Version', WS_VERSION
  234. if self.wsock.origin:
  235. yield 'Origin', self.wsock.origin
  236. # These are for eagerly caching webservers
  237. yield 'Pragma', 'no-cache'
  238. yield 'Cache-Control', 'no-cache'
  239. # Request protocols and extensions, these are later checked with the
  240. # actual supported values from the server's response
  241. if self.wsock.protocols:
  242. yield 'Sec-WebSocket-Protocol', ', '.join(self.wsock.protocols)
  243. if self.wsock.extensions:
  244. values = [format_param_hdr(e.name, e.header_params())
  245. for e in self.wsock.extensions]
  246. yield 'Sec-WebSocket-Extensions', ', '.join(values)
  247. def http_auth_basic_headers(self, **kwargs):
  248. u, p = self.wsock.auth
  249. u = u.encode('utf-8')
  250. p = p.encode('utf-8')
  251. yield 'Authorization', 'Basic ' + b64encode(u + ':' + p)
  252. def http_auth_digest_headers(self, **kwargs):
  253. username, password = self.wsock.auth
  254. yield 'Authorization', build_authorization_request(
  255. username=username.encode('utf-8'),
  256. method='GET',
  257. uri=self.wsock.location,
  258. nonce_count=0,
  259. realm=kwargs['realm'],
  260. nonce=kwargs['nonce'],
  261. opaque=kwargs['opaque'],
  262. password=password.encode('utf-8'))
  263. def split_stripped(value, delim=','):
  264. return map(str.strip, str(value).split(delim)) if value else []
  265. def parse_param_hdr(hdr):
  266. name, paramstr = split_stripped(hdr, ';')
  267. params = {}
  268. for param in split_stripped(paramstr):
  269. if '=' in param:
  270. key, value = split_stripped(param, '=')
  271. if value.isdigit():
  272. value = int(value)
  273. else:
  274. key = param
  275. value = True
  276. params[key] = value
  277. yield name, params
  278. def format_param_hdr(value, params):
  279. if not params:
  280. return value
  281. def fmt_param((k, v)):
  282. if v is True:
  283. return k
  284. if v is not False and v is not None:
  285. return k + '=' + str(v)
  286. strparams = filter(None, map(fmt_param, params.items()))
  287. return '%s; %s' % (value, ', '.join(strparams))