Commit e465862f authored by Taddeüs Kroes's avatar Taddeüs Kroes

Revised extension instantiation, now 'hooks' are installed which are cleaner and more flexible

parent 6efb8807
from errors import HandshakeError
class Extension(object): class Extension(object):
name = '' name = ''
rsv1 = False rsv1 = False
rsv2 = False rsv2 = False
rsv3 = False rsv3 = False
opcodes = [] opcodes = []
parameters = [] defaults = {}
request = {}
def __init__(self, **kwargs): def __init__(self, defaults={}, request={}):
for param in self.parameters: for param in defaults.keys() + request.keys():
setattr(self, param, None) if param not in self.defaults:
raise KeyError('unrecognized parameter "%s"' % param)
for param, value in kwargs.items(): # Copy dict first to avoid duplicate references to the same object
if param not in self.parameters: self.defaults = dict(self.__class__.defaults)
raise HandshakeError('unrecognized parameter "%s"' % param) self.defaults.update(defaults)
if value is None: self.request = dict(self.__class__.request)
value = True self.request.update(request)
setattr(self, param, value)
def __str__(self, frame): def __str__(self, frame):
if len(self.parameters): return '<Extension "%s" defaults=%s request=%s>' \
params = ' ' + ', '.join(p + '=' + str(getattr(self, p)) % (self.name, self.defaults, self.request)
for p in self.parameters)
else:
params = ''
return '<Extension "%s"%s>' % (self.name, params) class Hook:
def __init__(self, **kwargs):
def header_params(self, frame): for param, value in kwargs.iteritems():
return {} setattr(self, param, value)
def hook_send(self, frame): def send(self, frame):
return frame return frame
def hook_receive(self, frame): def recv(self, frame):
return frame return frame
...@@ -57,44 +51,38 @@ class DeflateFrame(Extension): ...@@ -57,44 +51,38 @@ class DeflateFrame(Extension):
name = 'deflate-frame' name = 'deflate-frame'
rsv1 = True rsv1 = True
parameters = ['max_window_bits', 'no_context_takeover'] # FIXME: is 32768 (below) correct?
defaults = {'max_window_bits': 32768, 'no_context_takeover': True}
# FIXME: is this correct? def __init__(self, defaults={}, request={}):
default_max_window_bits = 32768 Extension.__init__(self, defaults, request)
def __init__(self, **kwargs): mwb = self.defaults['max_window_bits']
super(DeflateFrame, self).__init__(**kwargs) cto = self.defaults['no_context_takeover']
if self.max_window_bits is None: if not isinstance(mwb, int):
self.max_window_bits = self.default_max_window_bits raise ValueError('"max_window_bits" must be an integer')
elif not isinstance(self.max_window_bits, int): elif mwb > 32768:
raise HandshakeError('"max_window_bits" must be an integer') raise ValueError('"max_window_bits" may not be larger than 32768')
elif self.max_window_bits > 32768:
raise HandshakeError('"max_window_bits" may not be larger than ' if cto is not False and cto is not True:
'32768') raise ValueError('"no_context_takeover" must have no value')
if self.no_context_takeover is None: class Hook:
self.no_context_takeover = False def send(self, frame):
elif self.no_context_takeover is not True:
raise HandshakeError('"no_context_takeover" must have no value')
def hook_send(self, frame):
if not frame.rsv1: if not frame.rsv1:
frame.rsv1 = True frame.rsv1 = True
frame.payload = self.deflate(frame.payload) frame.payload = self.deflate(frame.payload)
return frame return frame
def hook_recv(self, frame): def recv(self, frame):
if frame.rsv1: if frame.rsv1:
frame.rsv1 = False frame.rsv1 = False
frame.payload = self.inflate(frame.payload) frame.payload = self.inflate(frame.payload)
return frame return frame
def header_params(self):
raise NotImplementedError # TODO
def deflate(self, data): def deflate(self, data):
raise NotImplementedError # TODO raise NotImplementedError # TODO
...@@ -115,20 +103,18 @@ class Multiplex(Extension): ...@@ -115,20 +103,18 @@ class Multiplex(Extension):
rsv1 = True # FIXME rsv1 = True # FIXME
rsv2 = True # FIXME rsv2 = True # FIXME
rsv3 = True # FIXME rsv3 = True # FIXME
parameters = ['quota'] defaults = {'quota': None}
def __init__(self, **kwargs): def __init__(self, defaults={}, request={}):
super(Multiplex, self).__init__(**kwargs) Extension.__init__(self, defaults, request)
# TODO: check "quota" value # TODO: check "quota" value
def hook_send(self, frame): class Hook:
raise NotImplementedError # TODO def send(self, frame):
def hook_recv(self, frame):
raise NotImplementedError # TODO raise NotImplementedError # TODO
def header_params(self): def recv(self, frame):
raise NotImplementedError # TODO raise NotImplementedError # TODO
......
...@@ -142,14 +142,20 @@ class ServerHandshake(Handshake): ...@@ -142,14 +142,20 @@ class ServerHandshake(Handshake):
if 'Sec-WebSocket-Extensions' in headers: if 'Sec-WebSocket-Extensions' in headers:
supported_ext = dict((e.name, e) for e in self.wsock.extensions) supported_ext = dict((e.name, e) for e in self.wsock.extensions)
extensions = [] extensions = []
all_params = []
for ext in split_stripped(headers['Sec-WebSocket-Extensions']): for ext in split_stripped(headers['Sec-WebSocket-Extensions']):
name, params = parse_param_hdr(ext) name, params = parse_param_hdr(ext)
if name in supported_ext: if name in supported_ext:
extensions.append(supported_ext[name](**params)) extensions.append(supported_ext[name])
all_params.append(params)
self.wsock.extensions = filter_extensions(extensions) self.wsock.extensions = filter_extensions(extensions)
for ext, params in zip(self.wsock.extensions, all_params):
hook = ext.Hook(**params)
self.wsock.add_hook(send=hook.send, recv=hook.recv)
else: else:
self.wsock.extensions = [] self.wsock.extensions = []
...@@ -183,10 +189,11 @@ class ServerHandshake(Handshake): ...@@ -183,10 +189,11 @@ class ServerHandshake(Handshake):
yield 'Sec-WebSocket-Protocol', self.wsock.protocol yield 'Sec-WebSocket-Protocol', self.wsock.protocol
if self.wsock.extensions: if self.wsock.extensions:
values = [format_param_hdr(e.name, e.header_params()) values = [format_param_hdr(e.name, e.request)
for e in self.wsock.extensions] for e in self.wsock.extensions]
yield 'Sec-WebSocket-Extensions', ', '.join(values) yield 'Sec-WebSocket-Extensions', ', '.join(values)
class ClientHandshake(Handshake): class ClientHandshake(Handshake):
""" """
Executes a handshake as the client end point of the socket. May raise a Executes a handshake as the client end point of the socket. May raise a
...@@ -230,7 +237,7 @@ class ClientHandshake(Handshake): ...@@ -230,7 +237,7 @@ class ClientHandshake(Handshake):
if accept != required_accept: if accept != required_accept:
self.fail('invalid websocket accept header "%s"' % accept) self.fail('invalid websocket accept header "%s"' % accept)
# Compare extensions # Compare extensions, add hooks only for those returned by server
if 'Sec-WebSocket-Extensions' in headers: if 'Sec-WebSocket-Extensions' in headers:
supported_ext = dict((e.name, e) for e in self.wsock.extensions) supported_ext = dict((e.name, e) for e in self.wsock.extensions)
self.wsock.extensions = [] self.wsock.extensions = []
...@@ -242,7 +249,9 @@ class ClientHandshake(Handshake): ...@@ -242,7 +249,9 @@ class ClientHandshake(Handshake):
raise HandshakeError('server handshake contains ' raise HandshakeError('server handshake contains '
'unsupported extension "%s"' % name) 'unsupported extension "%s"' % name)
self.wsock.extensions.append(supported_ext[name](**params)) hook = supported_ext[name].Hook(**params)
self.wsock.extensions.append(supported_ext[name])
self.wsock.add_hook(send=hook.send, recv=hook.recv)
# Assert that returned protocol (if any) is supported # Assert that returned protocol (if any) is supported
if 'Sec-WebSocket-Protocol' in headers: if 'Sec-WebSocket-Protocol' in headers:
...@@ -325,7 +334,7 @@ class ClientHandshake(Handshake): ...@@ -325,7 +334,7 @@ class ClientHandshake(Handshake):
yield 'Sec-WebSocket-Protocol', ', '.join(self.wsock.protocols) yield 'Sec-WebSocket-Protocol', ', '.join(self.wsock.protocols)
if self.wsock.extensions: if self.wsock.extensions:
values = [format_param_hdr(e.name, e.header_params()) values = [format_param_hdr(e.name, e.request)
for e in self.wsock.extensions] for e in self.wsock.extensions]
yield 'Sec-WebSocket-Extensions', ', '.join(values) yield 'Sec-WebSocket-Extensions', ', '.join(values)
......
...@@ -41,7 +41,7 @@ class websocket(object): ...@@ -41,7 +41,7 @@ class websocket(object):
`protocols` is a list of supported protocol names. `protocols` is a list of supported protocol names.
`extensions` is a list of supported extension classes. `extensions` is a list of supported extensions (`Extension` instances).
`origin` (for client sockets) is the value for the "Origin" header sent `origin` (for client sockets) is the value for the "Origin" header sent
in a client handshake . in a client handshake .
...@@ -68,6 +68,8 @@ class websocket(object): ...@@ -68,6 +68,8 @@ class websocket(object):
self.sock = sock or socket.socket(sfamily, socket.SOCK_STREAM, sproto) self.sock = sock or socket.socket(sfamily, socket.SOCK_STREAM, sproto)
self.secure = False self.secure = False
self.handshake_sent = False self.handshake_sent = False
self.hooks_send = []
self.hooks_recv = []
def bind(self, address): def bind(self, address):
self.sock.bind(address) self.sock.bind(address)
...@@ -104,8 +106,8 @@ class websocket(object): ...@@ -104,8 +106,8 @@ class websocket(object):
Send a number of frames. Send a number of frames.
""" """
for frame in args: for frame in args:
for ext in self.extensions: for hook in self.hooks_send:
frame = ext.hook_send(frame) frame = hook(frame)
#print 'send frame:', frame, 'to %s:%d' % self.sock.getpeername() #print 'send frame:', frame, 'to %s:%d' % self.sock.getpeername()
self.sock.sendall(frame.pack()) self.sock.sendall(frame.pack())
...@@ -117,8 +119,8 @@ class websocket(object): ...@@ -117,8 +119,8 @@ class websocket(object):
""" """
frame = receive_frame(self.sock) frame = receive_frame(self.sock)
for ext in reversed(self.extensions): for hook in self.hooks_recv:
frame = ext.hook_recv(frame) frame = hook(frame)
#print 'receive frame:', frame, 'from %s:%d' % self.sock.getpeername() #print 'receive frame:', frame, 'from %s:%d' % self.sock.getpeername()
return frame return frame
...@@ -156,3 +158,10 @@ class websocket(object): ...@@ -156,3 +158,10 @@ class websocket(object):
self.secure = True self.secure = True
self.sock = ssl.wrap_socket(self.sock, *args, **kwargs) self.sock = ssl.wrap_socket(self.sock, *args, **kwargs)
def add_hook(self, send=None, recv=None):
if send:
self.hooks_send.append(send)
if recv:
self.hooks_recv.prepend(recv)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment