Commit 7d246d5d authored by Taddeüs Kroes's avatar Taddeüs Kroes

Implemented extensions in handshakes and installed extension hooks in websocket

parent 815fe7bb
- (Unit) tests - (Unit) tests
- Mutual exclusion in Server/Client (multiple threads sending stuff at the same - Mutual exclusion in Server/Client (multiple threads sending stuff at the same
time will go wrong) time will go wrong)
- Extensions - Extensions: pass parameters for Extension.header_params() to websocket
constructor
...@@ -2,6 +2,7 @@ from errors import HandshakeError ...@@ -2,6 +2,7 @@ from errors import HandshakeError
class Extension(object): class Extension(object):
name = ''
rsv1 = False rsv1 = False
rsv2 = False rsv2 = False
rsv3 = False rsv3 = False
...@@ -21,7 +22,16 @@ class Extension(object): ...@@ -21,7 +22,16 @@ class Extension(object):
setattr(self, param, value) setattr(self, param, value)
def client_params(self, frame): def __str__(self, frame):
if len(self.parameters):
params = ' ' + ', '.join(p + '=' + str(getattr(self, p))
for p in self.parameters)
else:
params = ''
return '<Extension "%s"%s>' % (self.name, params)
def header_params(self, frame):
return {} return {}
def hook_send(self, frame): def hook_send(self, frame):
...@@ -38,7 +48,10 @@ class DeflateFrame(Extension): ...@@ -38,7 +48,10 @@ class DeflateFrame(Extension):
def __init__(self, **kwargs): def __init__(self, **kwargs):
super(DeflateFrame, self).__init__(**kwargs) super(DeflateFrame, self).__init__(**kwargs)
self.max_window_bits = int(self.max_window_bits)
if self.max_window_bits is None:
# FIXME: is this correct? None may actually be a better value
self.max_window_bits = 0
def hook_send(self, frame): def hook_send(self, frame):
# FIXME: original `frame` is modified, maybe it should be copied? # FIXME: original `frame` is modified, maybe it should be copied?
...@@ -58,7 +71,7 @@ class DeflateFrame(Extension): ...@@ -58,7 +71,7 @@ class DeflateFrame(Extension):
return frame return frame
def client_params(self): def header_params(self):
raise NotImplementedError # TODO raise NotImplementedError # TODO
def encode(self, data): def encode(self, data):
...@@ -71,7 +84,7 @@ class DeflateFrame(Extension): ...@@ -71,7 +84,7 @@ class DeflateFrame(Extension):
def filter_extensions(extensions): def filter_extensions(extensions):
""" """
Remove extensions that use conflicting rsv bits and/or opcodes, with the Remove extensions that use conflicting rsv bits and/or opcodes, with the
first options being most preferable. first options being the most preferable.
""" """
rsv1_reserved = True rsv1_reserved = True
rsv2_reserved = True rsv2_reserved = True
...@@ -93,23 +106,3 @@ def filter_extensions(extensions): ...@@ -93,23 +106,3 @@ def filter_extensions(extensions):
compat.append(ext) compat.append(ext)
return compat return compat
"""
Class map used to find contructors for client-specified extensions. Not to be
modified manually, only through `register_extension`.
"""
extension_class_map = {}
def register_extension(ext):
if not isinstance(ext, Extension):
raise ValueError('extensions should extend the `Extension` class')
if ext.name in extension_clas_map:
raise KeyError('extension "%s" has already been registered' % ext.name)
extension_class_map[ext.name] = ext
register_extension(DeflateFrame)
...@@ -6,6 +6,7 @@ from urlparse import urlparse ...@@ -6,6 +6,7 @@ from urlparse import urlparse
from python_digest import build_authorization_request from python_digest import build_authorization_request
from errors import HandshakeError from errors import HandshakeError
from extension import filter_extensions
WS_GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11' WS_GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
...@@ -13,10 +14,6 @@ WS_VERSION = '13' ...@@ -13,10 +14,6 @@ WS_VERSION = '13'
MAX_REDIRECTS = 10 MAX_REDIRECTS = 10
def split_stripped(value, delim=','):
return map(str.strip, str(value).split(delim)) if value else []
class Handshake(object): class Handshake(object):
def __init__(self, wsock): def __init__(self, wsock):
self.wsock = wsock self.wsock = wsock
...@@ -134,20 +131,28 @@ class ServerHandshake(Handshake): ...@@ -134,20 +131,28 @@ class ServerHandshake(Handshake):
# Only a supported protocol can be returned # Only a supported protocol can be returned
client_proto = split_stripped(headers['Sec-WebSocket-Protocol']) \ client_proto = split_stripped(headers['Sec-WebSocket-Protocol']) \
if 'Sec-WebSocket-Protocol' in headers else [] if 'Sec-WebSocket-Protocol' in headers else []
protocol = None self.wsock.protocol = None
for p in client_proto: for p in client_proto:
if p in self.wsock.proto: if p in self.wsock.protocols:
protocol = p self.wsock.protocol = p
break break
# Only supported extensions are returned # Only supported extensions are returned
if 'Sec-WebSocket-Extensions' in headers: if 'Sec-WebSocket-Extensions' in headers:
client_ext = split_stripped(headers['Sec-WebSocket-Extensions']) supported_ext = dict((e.name, e) for e in self.wsock.extensions)
extensions = [e for e in client_ext if e in self.wsock.extensions]
else:
extensions = [] extensions = []
for ext in split_stripped(headers['Sec-WebSocket-Extensions']):
name, params = parse_param_hdr(ext)
if name in supported_ext:
extensions.append(supported_ext[name](**params))
self.wsock.extensions = filter_extensions(extensions)
else:
self.wsock.extensions = []
# Encode acceptation key using the WebSocket GUID # Encode acceptation key using the WebSocket GUID
key = headers['Sec-WebSocket-Key'].strip() key = headers['Sec-WebSocket-Key'].strip()
accept = b64encode(sha1(key + WS_GUID).digest()) accept = b64encode(sha1(key + WS_GUID).digest())
...@@ -174,12 +179,13 @@ class ServerHandshake(Handshake): ...@@ -174,12 +179,13 @@ class ServerHandshake(Handshake):
yield 'WebSocket-Location', location yield 'WebSocket-Location', location
yield 'Sec-WebSocket-Accept', accept yield 'Sec-WebSocket-Accept', accept
if protocol: if self.wsock.protocol:
yield 'Sec-WebSocket-Protocol', protocol yield 'Sec-WebSocket-Protocol', self.wsock.protocol
if extensions:
yield 'Sec-WebSocket-Extensions', ', '.join(extensions)
if self.wsock.extensions:
values = [format_param_hdr(e.name, e.header_params())
for e in self.wsock.extensions]
yield 'Sec-WebSocket-Extensions', ', '.join(values)
class ClientHandshake(Handshake): class ClientHandshake(Handshake):
""" """
...@@ -226,13 +232,17 @@ class ClientHandshake(Handshake): ...@@ -226,13 +232,17 @@ class ClientHandshake(Handshake):
# Compare extensions # Compare extensions
if 'Sec-WebSocket-Extensions' in headers: if 'Sec-WebSocket-Extensions' in headers:
server_ext = split_stripped(headers['Sec-WebSocket-Extensions']) supported_ext = dict((e.name, e) for e in self.wsock.extensions)
self.wsock.extensions = []
for e in set(server_ext) - set(self.wsock.extensions): for ext in split_stripped(headers['Sec-WebSocket-Extensions']):
self.fail('server extension "%s" unsupported by client' % e) name, params = parse_param_hdr(ext)
for e in set(self.wsock.extensions) - set(server_ext): if name not in supported_ext:
self.fail('client extension "%s" unsupported by server' % e) raise HandshakeError('server handshake contains '
'unsupported extension "%s"' % name)
self.wsock.extensions.append(supported_ext[name](**params))
# Assert that returned protocol (if any) is supported # Assert that returned protocol (if any) is supported
if 'Sec-WebSocket-Protocol' in headers: if 'Sec-WebSocket-Protocol' in headers:
...@@ -309,13 +319,15 @@ class ClientHandshake(Handshake): ...@@ -309,13 +319,15 @@ class ClientHandshake(Handshake):
yield 'Pragma', 'no-cache' yield 'Pragma', 'no-cache'
yield 'Cache-Control', 'no-cache' yield 'Cache-Control', 'no-cache'
# Request protocols and extension, these are later checked with the # Request protocols and extensions, these are later checked with the
# actual supported values from the server's response # actual supported values from the server's response
if self.wsock.protocols: if self.wsock.protocols:
yield 'Sec-WebSocket-Protocol', ', '.join(self.wsock.protocols) yield 'Sec-WebSocket-Protocol', ', '.join(self.wsock.protocols)
if self.wsock.extensions: if self.wsock.extensions:
yield 'Sec-WebSocket-Extensions', ', '.join(self.wsock.extensions) values = [format_param_hdr(e.name, e.header_params())
for e in self.wsock.extensions]
yield 'Sec-WebSocket-Extensions', ', '.join(values)
def http_auth_basic_headers(self, **kwargs): def http_auth_basic_headers(self, **kwargs):
u, p = self.wsock.auth u, p = self.wsock.auth
...@@ -334,3 +346,41 @@ class ClientHandshake(Handshake): ...@@ -334,3 +346,41 @@ class ClientHandshake(Handshake):
nonce=kwargs['nonce'], nonce=kwargs['nonce'],
opaque=kwargs['opaque'], opaque=kwargs['opaque'],
password=password.encode('utf-8')) password=password.encode('utf-8'))
def split_stripped(value, delim=','):
return map(str.strip, str(value).split(delim)) if value else []
def parse_param_hdr(hdr):
name, paramstr = split_stripped(hdr, ';')
params = {}
for param in split_stripped(paramstr):
if '=' in param:
key, value = split_stripped(param, '=')
if value.isdigit():
value = int(value)
else:
key = param
value = True
params[key] = value
yield name, params
def format_param_hdr(value, params):
if not params:
return value
def fmt_param((k, v)):
if v is True:
return k
if v is not False and v is not None:
return k + '=' + str(v)
strparams = filter(None, map(fmt_param, params.items()))
return '%s; %s' % (value, ', '.join(strparams))
...@@ -104,6 +104,9 @@ class websocket(object): ...@@ -104,6 +104,9 @@ class websocket(object):
Send a number of frames. Send a number of frames.
""" """
for frame in args: for frame in args:
for ext in self.extensions:
frame = ext.hook_send(frame)
#print 'send frame:', frame, 'to %s:%d' % self.sock.getpeername() #print 'send frame:', frame, 'to %s:%d' % self.sock.getpeername()
self.sock.sendall(frame.pack()) self.sock.sendall(frame.pack())
...@@ -113,6 +116,10 @@ class websocket(object): ...@@ -113,6 +116,10 @@ class websocket(object):
frame. frame.
""" """
frame = receive_frame(self.sock) frame = receive_frame(self.sock)
for ext in reversed(self.extensions):
frame = ext.hook_recv(frame)
#print 'receive frame:', frame, 'from %s:%d' % self.sock.getpeername() #print 'receive frame:', frame, 'from %s:%d' % self.sock.getpeername()
return frame return frame
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment