Commit e465862f authored by Taddeüs Kroes's avatar Taddeüs Kroes

Revised extension instantiation, now 'hooks' are installed which are cleaner and more flexible

parent 6efb8807
from errors import HandshakeError
class Extension(object):
name = ''
rsv1 = False
rsv2 = False
rsv3 = False
opcodes = []
parameters = []
def __init__(self, **kwargs):
for param in self.parameters:
setattr(self, param, None)
defaults = {}
request = {}
for param, value in kwargs.items():
if param not in self.parameters:
raise HandshakeError('unrecognized parameter "%s"' % param)
def __init__(self, defaults={}, request={}):
for param in defaults.keys() + request.keys():
if param not in self.defaults:
raise KeyError('unrecognized parameter "%s"' % param)
if value is None:
value = True
# Copy dict first to avoid duplicate references to the same object
self.defaults = dict(self.__class__.defaults)
self.defaults.update(defaults)
setattr(self, param, value)
self.request = dict(self.__class__.request)
self.request.update(request)
def __str__(self, frame):
if len(self.parameters):
params = ' ' + ', '.join(p + '=' + str(getattr(self, p))
for p in self.parameters)
else:
params = ''
return '<Extension "%s" defaults=%s request=%s>' \
% (self.name, self.defaults, self.request)
return '<Extension "%s"%s>' % (self.name, params)
class Hook:
def __init__(self, **kwargs):
for param, value in kwargs.iteritems():
setattr(self, param, value)
def header_params(self, frame):
return {}
def send(self, frame):
return frame
def hook_send(self, frame):
return frame
def hook_receive(self, frame):
return frame
def recv(self, frame):
return frame
class DeflateFrame(Extension):
......@@ -57,49 +51,43 @@ class DeflateFrame(Extension):
name = 'deflate-frame'
rsv1 = True
parameters = ['max_window_bits', 'no_context_takeover']
# FIXME: is this correct?
default_max_window_bits = 32768
# FIXME: is 32768 (below) correct?
defaults = {'max_window_bits': 32768, 'no_context_takeover': True}
def __init__(self, **kwargs):
super(DeflateFrame, self).__init__(**kwargs)
def __init__(self, defaults={}, request={}):
Extension.__init__(self, defaults, request)
if self.max_window_bits is None:
self.max_window_bits = self.default_max_window_bits
elif not isinstance(self.max_window_bits, int):
raise HandshakeError('"max_window_bits" must be an integer')
elif self.max_window_bits > 32768:
raise HandshakeError('"max_window_bits" may not be larger than '
'32768')
mwb = self.defaults['max_window_bits']
cto = self.defaults['no_context_takeover']
if self.no_context_takeover is None:
self.no_context_takeover = False
elif self.no_context_takeover is not True:
raise HandshakeError('"no_context_takeover" must have no value')
if not isinstance(mwb, int):
raise ValueError('"max_window_bits" must be an integer')
elif mwb > 32768:
raise ValueError('"max_window_bits" may not be larger than 32768')
def hook_send(self, frame):
if not frame.rsv1:
frame.rsv1 = True
frame.payload = self.deflate(frame.payload)
if cto is not False and cto is not True:
raise ValueError('"no_context_takeover" must have no value')
return frame
class Hook:
def send(self, frame):
if not frame.rsv1:
frame.rsv1 = True
frame.payload = self.deflate(frame.payload)
def hook_recv(self, frame):
if frame.rsv1:
frame.rsv1 = False
frame.payload = self.inflate(frame.payload)
return frame
return frame
def recv(self, frame):
if frame.rsv1:
frame.rsv1 = False
frame.payload = self.inflate(frame.payload)
def header_params(self):
raise NotImplementedError # TODO
return frame
def deflate(self, data):
raise NotImplementedError # TODO
def deflate(self, data):
raise NotImplementedError # TODO
def inflate(self, data):
raise NotImplementedError # TODO
def inflate(self, data):
raise NotImplementedError # TODO
class Multiplex(Extension):
......@@ -115,21 +103,19 @@ class Multiplex(Extension):
rsv1 = True # FIXME
rsv2 = True # FIXME
rsv3 = True # FIXME
parameters = ['quota']
defaults = {'quota': None}
def __init__(self, **kwargs):
super(Multiplex, self).__init__(**kwargs)
def __init__(self, defaults={}, request={}):
Extension.__init__(self, defaults, request)
# TODO: check "quota" value
def hook_send(self, frame):
raise NotImplementedError # TODO
def hook_recv(self, frame):
raise NotImplementedError # TODO
class Hook:
def send(self, frame):
raise NotImplementedError # TODO
def header_params(self):
raise NotImplementedError # TODO
def recv(self, frame):
raise NotImplementedError # TODO
def filter_extensions(extensions):
......
......@@ -142,14 +142,20 @@ class ServerHandshake(Handshake):
if 'Sec-WebSocket-Extensions' in headers:
supported_ext = dict((e.name, e) for e in self.wsock.extensions)
extensions = []
all_params = []
for ext in split_stripped(headers['Sec-WebSocket-Extensions']):
name, params = parse_param_hdr(ext)
if name in supported_ext:
extensions.append(supported_ext[name](**params))
extensions.append(supported_ext[name])
all_params.append(params)
self.wsock.extensions = filter_extensions(extensions)
for ext, params in zip(self.wsock.extensions, all_params):
hook = ext.Hook(**params)
self.wsock.add_hook(send=hook.send, recv=hook.recv)
else:
self.wsock.extensions = []
......@@ -183,10 +189,11 @@ class ServerHandshake(Handshake):
yield 'Sec-WebSocket-Protocol', self.wsock.protocol
if self.wsock.extensions:
values = [format_param_hdr(e.name, e.header_params())
values = [format_param_hdr(e.name, e.request)
for e in self.wsock.extensions]
yield 'Sec-WebSocket-Extensions', ', '.join(values)
class ClientHandshake(Handshake):
"""
Executes a handshake as the client end point of the socket. May raise a
......@@ -230,7 +237,7 @@ class ClientHandshake(Handshake):
if accept != required_accept:
self.fail('invalid websocket accept header "%s"' % accept)
# Compare extensions
# Compare extensions, add hooks only for those returned by server
if 'Sec-WebSocket-Extensions' in headers:
supported_ext = dict((e.name, e) for e in self.wsock.extensions)
self.wsock.extensions = []
......@@ -242,7 +249,9 @@ class ClientHandshake(Handshake):
raise HandshakeError('server handshake contains '
'unsupported extension "%s"' % name)
self.wsock.extensions.append(supported_ext[name](**params))
hook = supported_ext[name].Hook(**params)
self.wsock.extensions.append(supported_ext[name])
self.wsock.add_hook(send=hook.send, recv=hook.recv)
# Assert that returned protocol (if any) is supported
if 'Sec-WebSocket-Protocol' in headers:
......@@ -325,7 +334,7 @@ class ClientHandshake(Handshake):
yield 'Sec-WebSocket-Protocol', ', '.join(self.wsock.protocols)
if self.wsock.extensions:
values = [format_param_hdr(e.name, e.header_params())
values = [format_param_hdr(e.name, e.request)
for e in self.wsock.extensions]
yield 'Sec-WebSocket-Extensions', ', '.join(values)
......
......@@ -41,7 +41,7 @@ class websocket(object):
`protocols` is a list of supported protocol names.
`extensions` is a list of supported extension classes.
`extensions` is a list of supported extensions (`Extension` instances).
`origin` (for client sockets) is the value for the "Origin" header sent
in a client handshake .
......@@ -68,6 +68,8 @@ class websocket(object):
self.sock = sock or socket.socket(sfamily, socket.SOCK_STREAM, sproto)
self.secure = False
self.handshake_sent = False
self.hooks_send = []
self.hooks_recv = []
def bind(self, address):
self.sock.bind(address)
......@@ -104,8 +106,8 @@ class websocket(object):
Send a number of frames.
"""
for frame in args:
for ext in self.extensions:
frame = ext.hook_send(frame)
for hook in self.hooks_send:
frame = hook(frame)
#print 'send frame:', frame, 'to %s:%d' % self.sock.getpeername()
self.sock.sendall(frame.pack())
......@@ -117,8 +119,8 @@ class websocket(object):
"""
frame = receive_frame(self.sock)
for ext in reversed(self.extensions):
frame = ext.hook_recv(frame)
for hook in self.hooks_recv:
frame = hook(frame)
#print 'receive frame:', frame, 'from %s:%d' % self.sock.getpeername()
return frame
......@@ -156,3 +158,10 @@ class websocket(object):
self.secure = True
self.sock = ssl.wrap_socket(self.sock, *args, **kwargs)
def add_hook(self, send=None, recv=None):
if send:
self.hooks_send.append(send)
if recv:
self.hooks_recv.prepend(recv)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment