|
@@ -6,6 +6,7 @@ from urlparse import urlparse
|
|
|
|
|
|
|
|
from python_digest import build_authorization_request
|
|
from python_digest import build_authorization_request
|
|
|
from errors import HandshakeError
|
|
from errors import HandshakeError
|
|
|
|
|
+from extension import filter_extensions
|
|
|
|
|
|
|
|
|
|
|
|
|
WS_GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
|
|
WS_GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
|
|
@@ -13,10 +14,6 @@ WS_VERSION = '13'
|
|
|
MAX_REDIRECTS = 10
|
|
MAX_REDIRECTS = 10
|
|
|
|
|
|
|
|
|
|
|
|
|
-def split_stripped(value, delim=','):
|
|
|
|
|
- return map(str.strip, str(value).split(delim)) if value else []
|
|
|
|
|
-
|
|
|
|
|
-
|
|
|
|
|
class Handshake(object):
|
|
class Handshake(object):
|
|
|
def __init__(self, wsock):
|
|
def __init__(self, wsock):
|
|
|
self.wsock = wsock
|
|
self.wsock = wsock
|
|
@@ -134,20 +131,28 @@ class ServerHandshake(Handshake):
|
|
|
# Only a supported protocol can be returned
|
|
# Only a supported protocol can be returned
|
|
|
client_proto = split_stripped(headers['Sec-WebSocket-Protocol']) \
|
|
client_proto = split_stripped(headers['Sec-WebSocket-Protocol']) \
|
|
|
if 'Sec-WebSocket-Protocol' in headers else []
|
|
if 'Sec-WebSocket-Protocol' in headers else []
|
|
|
- protocol = None
|
|
|
|
|
|
|
+ self.wsock.protocol = None
|
|
|
|
|
|
|
|
for p in client_proto:
|
|
for p in client_proto:
|
|
|
- if p in self.wsock.proto:
|
|
|
|
|
- protocol = p
|
|
|
|
|
|
|
+ if p in self.wsock.protocols:
|
|
|
|
|
+ self.wsock.protocol = p
|
|
|
break
|
|
break
|
|
|
|
|
|
|
|
# Only supported extensions are returned
|
|
# Only supported extensions are returned
|
|
|
if 'Sec-WebSocket-Extensions' in headers:
|
|
if 'Sec-WebSocket-Extensions' in headers:
|
|
|
- client_ext = split_stripped(headers['Sec-WebSocket-Extensions'])
|
|
|
|
|
- extensions = [e for e in client_ext if e in self.wsock.extensions]
|
|
|
|
|
- else:
|
|
|
|
|
|
|
+ supported_ext = dict((e.name, e) for e in self.wsock.extensions)
|
|
|
extensions = []
|
|
extensions = []
|
|
|
|
|
|
|
|
|
|
+ for ext in split_stripped(headers['Sec-WebSocket-Extensions']):
|
|
|
|
|
+ name, params = parse_param_hdr(ext)
|
|
|
|
|
+
|
|
|
|
|
+ if name in supported_ext:
|
|
|
|
|
+ extensions.append(supported_ext[name](**params))
|
|
|
|
|
+
|
|
|
|
|
+ self.wsock.extensions = filter_extensions(extensions)
|
|
|
|
|
+ else:
|
|
|
|
|
+ self.wsock.extensions = []
|
|
|
|
|
+
|
|
|
# Encode acceptation key using the WebSocket GUID
|
|
# Encode acceptation key using the WebSocket GUID
|
|
|
key = headers['Sec-WebSocket-Key'].strip()
|
|
key = headers['Sec-WebSocket-Key'].strip()
|
|
|
accept = b64encode(sha1(key + WS_GUID).digest())
|
|
accept = b64encode(sha1(key + WS_GUID).digest())
|
|
@@ -174,12 +179,13 @@ class ServerHandshake(Handshake):
|
|
|
yield 'WebSocket-Location', location
|
|
yield 'WebSocket-Location', location
|
|
|
yield 'Sec-WebSocket-Accept', accept
|
|
yield 'Sec-WebSocket-Accept', accept
|
|
|
|
|
|
|
|
- if protocol:
|
|
|
|
|
- yield 'Sec-WebSocket-Protocol', protocol
|
|
|
|
|
-
|
|
|
|
|
- if extensions:
|
|
|
|
|
- yield 'Sec-WebSocket-Extensions', ', '.join(extensions)
|
|
|
|
|
|
|
+ if self.wsock.protocol:
|
|
|
|
|
+ yield 'Sec-WebSocket-Protocol', self.wsock.protocol
|
|
|
|
|
|
|
|
|
|
+ if self.wsock.extensions:
|
|
|
|
|
+ values = [format_param_hdr(e.name, e.header_params())
|
|
|
|
|
+ for e in self.wsock.extensions]
|
|
|
|
|
+ yield 'Sec-WebSocket-Extensions', ', '.join(values)
|
|
|
|
|
|
|
|
class ClientHandshake(Handshake):
|
|
class ClientHandshake(Handshake):
|
|
|
"""
|
|
"""
|
|
@@ -226,13 +232,17 @@ class ClientHandshake(Handshake):
|
|
|
|
|
|
|
|
# Compare extensions
|
|
# Compare extensions
|
|
|
if 'Sec-WebSocket-Extensions' in headers:
|
|
if 'Sec-WebSocket-Extensions' in headers:
|
|
|
- server_ext = split_stripped(headers['Sec-WebSocket-Extensions'])
|
|
|
|
|
|
|
+ supported_ext = dict((e.name, e) for e in self.wsock.extensions)
|
|
|
|
|
+ self.wsock.extensions = []
|
|
|
|
|
|
|
|
- for e in set(server_ext) - set(self.wsock.extensions):
|
|
|
|
|
- self.fail('server extension "%s" unsupported by client' % e)
|
|
|
|
|
|
|
+ for ext in split_stripped(headers['Sec-WebSocket-Extensions']):
|
|
|
|
|
+ name, params = parse_param_hdr(ext)
|
|
|
|
|
|
|
|
- for e in set(self.wsock.extensions) - set(server_ext):
|
|
|
|
|
- self.fail('client extension "%s" unsupported by server' % e)
|
|
|
|
|
|
|
+ if name not in supported_ext:
|
|
|
|
|
+ raise HandshakeError('server handshake contains '
|
|
|
|
|
+ 'unsupported extension "%s"' % name)
|
|
|
|
|
+
|
|
|
|
|
+ self.wsock.extensions.append(supported_ext[name](**params))
|
|
|
|
|
|
|
|
# Assert that returned protocol (if any) is supported
|
|
# Assert that returned protocol (if any) is supported
|
|
|
if 'Sec-WebSocket-Protocol' in headers:
|
|
if 'Sec-WebSocket-Protocol' in headers:
|
|
@@ -309,13 +319,15 @@ class ClientHandshake(Handshake):
|
|
|
yield 'Pragma', 'no-cache'
|
|
yield 'Pragma', 'no-cache'
|
|
|
yield 'Cache-Control', 'no-cache'
|
|
yield 'Cache-Control', 'no-cache'
|
|
|
|
|
|
|
|
- # Request protocols and extension, these are later checked with the
|
|
|
|
|
|
|
+ # Request protocols and extensions, these are later checked with the
|
|
|
# actual supported values from the server's response
|
|
# actual supported values from the server's response
|
|
|
if self.wsock.protocols:
|
|
if self.wsock.protocols:
|
|
|
yield 'Sec-WebSocket-Protocol', ', '.join(self.wsock.protocols)
|
|
yield 'Sec-WebSocket-Protocol', ', '.join(self.wsock.protocols)
|
|
|
|
|
|
|
|
if self.wsock.extensions:
|
|
if self.wsock.extensions:
|
|
|
- yield 'Sec-WebSocket-Extensions', ', '.join(self.wsock.extensions)
|
|
|
|
|
|
|
+ values = [format_param_hdr(e.name, e.header_params())
|
|
|
|
|
+ for e in self.wsock.extensions]
|
|
|
|
|
+ yield 'Sec-WebSocket-Extensions', ', '.join(values)
|
|
|
|
|
|
|
|
def http_auth_basic_headers(self, **kwargs):
|
|
def http_auth_basic_headers(self, **kwargs):
|
|
|
u, p = self.wsock.auth
|
|
u, p = self.wsock.auth
|
|
@@ -334,3 +346,41 @@ class ClientHandshake(Handshake):
|
|
|
nonce=kwargs['nonce'],
|
|
nonce=kwargs['nonce'],
|
|
|
opaque=kwargs['opaque'],
|
|
opaque=kwargs['opaque'],
|
|
|
password=password.encode('utf-8'))
|
|
password=password.encode('utf-8'))
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def split_stripped(value, delim=','):
|
|
|
|
|
+ return map(str.strip, str(value).split(delim)) if value else []
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def parse_param_hdr(hdr):
|
|
|
|
|
+ name, paramstr = split_stripped(hdr, ';')
|
|
|
|
|
+ params = {}
|
|
|
|
|
+
|
|
|
|
|
+ for param in split_stripped(paramstr):
|
|
|
|
|
+ if '=' in param:
|
|
|
|
|
+ key, value = split_stripped(param, '=')
|
|
|
|
|
+
|
|
|
|
|
+ if value.isdigit():
|
|
|
|
|
+ value = int(value)
|
|
|
|
|
+ else:
|
|
|
|
|
+ key = param
|
|
|
|
|
+ value = True
|
|
|
|
|
+
|
|
|
|
|
+ params[key] = value
|
|
|
|
|
+
|
|
|
|
|
+ yield name, params
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def format_param_hdr(value, params):
|
|
|
|
|
+ if not params:
|
|
|
|
|
+ return value
|
|
|
|
|
+
|
|
|
|
|
+ def fmt_param((k, v)):
|
|
|
|
|
+ if v is True:
|
|
|
|
|
+ return k
|
|
|
|
|
+
|
|
|
|
|
+ if v is not False and v is not None:
|
|
|
|
|
+ return k + '=' + str(v)
|
|
|
|
|
+
|
|
|
|
|
+ strparams = filter(None, map(fmt_param, params.items()))
|
|
|
|
|
+ return '%s; %s' % (value, ', '.join(strparams))
|