Commit 7c8972d4 authored by Taddeüs Kroes's avatar Taddeüs Kroes

Removed frame send/recv hooks, now working with a somewhat more robust extensions list

parent caa592ea
...@@ -20,38 +20,33 @@ class DeflateFrame(Extension): ...@@ -20,38 +20,33 @@ class DeflateFrame(Extension):
name = 'deflate-frame' name = 'deflate-frame'
rsv1 = True rsv1 = True
defaults = {'max_window_bits': 15, 'no_context_takeover': False} defaults = {'max_window_bits': zlib.MAX_WBITS, 'no_context_takeover': False}
def __init__(self, defaults={}, request={}): COMPRESSION_THRESHOLD = 64 # minimal payload size for compression
Extension.__init__(self, defaults, request)
def init(self):
mwb = self.defaults['max_window_bits'] mwb = self.defaults['max_window_bits']
cto = self.defaults['no_context_takeover'] cto = self.defaults['no_context_takeover']
if not isinstance(mwb, int): if not isinstance(mwb, int) or mwb < 1 or mwb > zlib.MAX_WBITS:
raise ValueError('"max_window_bits" must be an integer') raise ValueError('"max_window_bits" must be in range 1-15')
elif mwb > 15:
raise ValueError('"max_window_bits" may not be larger than 15')
if cto is not False and cto is not True: if cto is not False and cto is not True:
raise ValueError('"no_context_takeover" must have no value') raise ValueError('"no_context_takeover" must have no value')
class Hook(Extension.Hook): class Hook(Extension.Hook):
def __init__(self, extension, **kwargs): def init(self, extension):
Extension.Hook.__init__(self, extension, **kwargs) self.defl = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
zlib.DEFLATED, -self.max_window_bits)
if not self.no_context_takeover: other_wbits = extension.request.get('max_window_bits', zlib.MAX_WBITS)
self.defl = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
zlib.DEFLATED,
-self.max_window_bits)
other_wbits = self.extension.request.get('max_window_bits', 15)
self.dec = zlib.decompressobj(-other_wbits) self.dec = zlib.decompressobj(-other_wbits)
def send(self, frame): def send(self, frame):
if not frame.rsv1 and not isinstance(frame, ControlFrame): # FIXME: this does not seem to work properly on Android
if not frame.rsv1 and not isinstance(frame, ControlFrame) and \
len(frame.payload) > DeflateFrame.COMPRESSION_THRESHOLD:
frame.rsv1 = True frame.rsv1 = True
frame.payload = self.deflate(frame.payload) frame.payload = self.deflate(frame)
return frame return frame
...@@ -65,23 +60,23 @@ class DeflateFrame(Extension): ...@@ -65,23 +60,23 @@ class DeflateFrame(Extension):
return frame return frame
def deflate(self, data): def deflate(self, frame):
if self.no_context_takeover: compressed = self.defl.compress(frame.payload)
defl = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
zlib.DEFLATED, -self.max_window_bits) if frame.final or self.no_context_takeover:
# FIXME: why the '\x00' below? This was borrowed from compressed += self.defl.flush(zlib.Z_FINISH) + '\x00'
# https://github.com/fancycode/tornado/blob/bc317b6dcf63608ff004ff1f57073be0504b6550/tornado/websocket.py#L91 self.defl = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
return defl.compress(data) + defl.flush(zlib.Z_FINISH) + '\x00' zlib.DEFLATED, -self.max_window_bits)
else:
compressed += self.defl.flush(zlib.Z_SYNC_FLUSH)
assert compressed[-4:] == '\x00\x00\xff\xff'
compressed = compressed[:-4]
compressed = self.defl.compress(data) return compressed
compressed += self.defl.flush(zlib.Z_SYNC_FLUSH)
assert compressed[-4:] == '\x00\x00\xff\xff'
return compressed[:-4]
def inflate(self, data): def inflate(self, data):
data = self.dec.decompress(str(data + '\x00\x00\xff\xff')) return self.dec.decompress(data + '\x00\x00\xff\xff') + \
assert not self.dec.unused_data self.dec.flush(zlib.Z_SYNC_FLUSH)
return data
class WebkitDeflateFrame(DeflateFrame): class WebkitDeflateFrame(DeflateFrame):
......
...@@ -19,23 +19,31 @@ class Extension(object): ...@@ -19,23 +19,31 @@ class Extension(object):
self.request = dict(self.__class__.request) self.request = dict(self.__class__.request)
self.request.update(request) self.request.update(request)
self.init()
def __str__(self): def __str__(self):
return '<Extension "%s" defaults=%s request=%s>' \ return '<Extension "%s" defaults=%s request=%s>' \
% (self.name, self.defaults, self.request) % (self.name, self.defaults, self.request)
def init(self):
return NotImplemented
def create_hook(self, **kwargs): def create_hook(self, **kwargs):
params = {} params = {}
params.update(self.defaults) params.update(self.defaults)
params.update(kwargs) params.update(kwargs)
return self.Hook(self, **params) hook = self.Hook(**params)
hook.init(self)
return hook
class Hook: class Hook:
def __init__(self, extension, **kwargs): def __init__(self, **kwargs):
self.extension = extension
for param, value in kwargs.iteritems(): for param, value in kwargs.iteritems():
setattr(self, param, value) setattr(self, param, value)
def init(self, extension):
return NotImplemented
def send(self, frame): def send(self, frame):
return frame return frame
...@@ -43,28 +51,19 @@ class Extension(object): ...@@ -43,28 +51,19 @@ class Extension(object):
return frame return frame
def filter_extensions(extensions): def extension_conflicts(ext, existing):
"""
Remove extensions that use conflicting rsv bits and/or opcodes, with the
first options being the most preferable.
"""
rsv1_reserved = False rsv1_reserved = False
rsv2_reserved = False rsv2_reserved = False
rsv3_reserved = False rsv3_reserved = False
opcodes_reserved = [] reserved_opcodes = []
compat = []
for ext in extensions:
if ext.rsv1 and rsv1_reserved \
or ext.rsv2 and rsv2_reserved \
or ext.rsv3 and rsv3_reserved \
or len(set(ext.opcodes) & set(opcodes_reserved)):
continue
rsv1_reserved |= ext.rsv1 for e in existing:
rsv2_reserved |= ext.rsv2 rsv1_reserved |= e.rsv1
rsv3_reserved |= ext.rsv3 rsv2_reserved |= e.rsv2
opcodes_reserved.extend(ext.opcodes) rsv3_reserved |= e.rsv3
compat.append(ext) reserved_opcodes.extend(e.opcodes)
return compat return ext.rsv1 and rsv1_reserved \
or ext.rsv2 and rsv2_reserved \
or ext.rsv3 and rsv3_reserved \
or len(set(ext.opcodes) & set(reserved_opcodes))
...@@ -7,7 +7,7 @@ from hashlib import sha1 ...@@ -7,7 +7,7 @@ from hashlib import sha1
from urlparse import urlparse from urlparse import urlparse
from errors import HandshakeError from errors import HandshakeError
from extension import filter_extensions from extension import extension_conflicts
from python_digest import build_authorization_request from python_digest import build_authorization_request
...@@ -173,23 +173,19 @@ class ServerHandshake(Handshake): ...@@ -173,23 +173,19 @@ class ServerHandshake(Handshake):
# Only supported extensions are returned # Only supported extensions are returned
if 'Sec-WebSocket-Extensions' in headers: if 'Sec-WebSocket-Extensions' in headers:
supported_ext = dict((e.name, e) for e in ssock.extensions) supported_ext = dict((e.name, e) for e in ssock.extensions)
self.wsock.extension_hooks = []
extensions = [] extensions = []
all_params = []
for ext in split_stripped(headers['Sec-WebSocket-Extensions']): for ext in split_stripped(headers['Sec-WebSocket-Extensions']):
name, params = parse_param_hdr(ext) name, params = parse_param_hdr(ext)
if name in supported_ext: if name in supported_ext:
extensions.append(supported_ext[name]) ext = supported_ext[name]
all_params.append(params)
self.wsock.extensions = filter_extensions(extensions) if not extension_conflicts(ext, extensions):
extensions.append(ext)
for ext, params in zip(self.wsock.extensions, all_params): hook = ext.create_hook(**params)
hook = ext.create_hook(**params) self.wsock.extension_hooks.append(hook)
self.wsock.add_hook(send=hook.send, recv=hook.recv)
else:
self.wsock.extensions = []
# Check if requested resource location is served by this server # Check if requested resource location is served by this server
if ssock.locations: if ssock.locations:
...@@ -278,7 +274,7 @@ class ClientHandshake(Handshake): ...@@ -278,7 +274,7 @@ class ClientHandshake(Handshake):
# Compare extensions, add hooks only for those returned by server # Compare extensions, add hooks only for those returned by server
if 'Sec-WebSocket-Extensions' in headers: if 'Sec-WebSocket-Extensions' in headers:
supported_ext = dict((e.name, e) for e in self.wsock.extensions) supported_ext = dict((e.name, e) for e in self.wsock.extensions)
self.wsock.extensions = [] self.wsock.extension_hooks = []
for ext in split_stripped(headers['Sec-WebSocket-Extensions']): for ext in split_stripped(headers['Sec-WebSocket-Extensions']):
name, params = parse_param_hdr(ext) name, params = parse_param_hdr(ext)
...@@ -288,8 +284,7 @@ class ClientHandshake(Handshake): ...@@ -288,8 +284,7 @@ class ClientHandshake(Handshake):
'unsupported extension "%s"' % name) 'unsupported extension "%s"' % name)
hook = supported_ext[name].create_hook(**params) hook = supported_ext[name].create_hook(**params)
self.wsock.extensions.append(supported_ext[name]) self.wsock.extension_hooks.append(hook)
self.wsock.add_hook(send=hook.send, recv=hook.recv)
# Assert that returned protocol (if any) is supported # Assert that returned protocol (if any) is supported
if 'Sec-WebSocket-Protocol' in headers: if 'Sec-WebSocket-Protocol' in headers:
......
...@@ -35,7 +35,7 @@ class websocket(object): ...@@ -35,7 +35,7 @@ class websocket(object):
>>> sock.connect(('', 8000)) >>> sock.connect(('', 8000))
>>> sock.send(wspy.Frame(wspy.OPCODE_TEXT, 'Hello, Server!')) >>> sock.send(wspy.Frame(wspy.OPCODE_TEXT, 'Hello, Server!'))
""" """
def __init__(self, sock=None, protocols=[], extensions=[], origin=None, def __init__(self, sock=None, origin=None, protocols=[], extensions=[],
location='/', trusted_origins=[], locations=[], auth=None, location='/', trusted_origins=[], locations=[], auth=None,
recv_callback=None, sfamily=socket.AF_INET, sproto=0): recv_callback=None, sfamily=socket.AF_INET, sproto=0):
""" """
...@@ -44,13 +44,14 @@ class websocket(object): ...@@ -44,13 +44,14 @@ class websocket(object):
`sock` is an optional regular TCP socket to be used for sending binary `sock` is an optional regular TCP socket to be used for sending binary
data. If not specified, a new socket is created. data. If not specified, a new socket is created.
`protocols` is a list of supported protocol names.
`extensions` is a list of supported extensions (`Extension` instances).
`origin` (for client sockets) is the value for the "Origin" header sent `origin` (for client sockets) is the value for the "Origin" header sent
in a client handshake . in a client handshake .
`protocols` is a list of supported protocol names.
`extensions` (for server sockets) is a list of supported extensions
(`Extension` instances).
`location` (for client sockets) is optional, used to request a `location` (for client sockets) is optional, used to request a
particular resource in the HTTP handshake. In a URL, this would show as particular resource in the HTTP handshake. In a URL, this would show as
ws://host[:port]/<location>. Use this when the server serves multiple ws://host[:port]/<location>. Use this when the server serves multiple
...@@ -80,6 +81,7 @@ class websocket(object): ...@@ -80,6 +81,7 @@ class websocket(object):
""" """
self.protocols = protocols self.protocols = protocols
self.extensions = extensions self.extensions = extensions
self.extension_hooks = []
self.origin = origin self.origin = origin
self.location = location self.location = location
self.trusted_origins = trusted_origins self.trusted_origins = trusted_origins
...@@ -90,9 +92,6 @@ class websocket(object): ...@@ -90,9 +92,6 @@ class websocket(object):
self.handshake_sent = False self.handshake_sent = False
self.hooks_send = []
self.hooks_recv = []
self.sendbuf_frames = [] self.sendbuf_frames = []
self.sendbuf = '' self.sendbuf = ''
self.recvbuf = '' self.recvbuf = ''
...@@ -100,6 +99,9 @@ class websocket(object): ...@@ -100,6 +99,9 @@ class websocket(object):
self.sock = sock or socket.socket(sfamily, socket.SOCK_STREAM, sproto) self.sock = sock or socket.socket(sfamily, socket.SOCK_STREAM, sproto)
def set_extensions(self, extensions):
self.extensions = [ext.Hook() for ext in extensions]
def __getattr__(self, name): def __getattr__(self, name):
if name in INHERITED_ATTRS: if name in INHERITED_ATTRS:
return getattr(self.sock, name) return getattr(self.sock, name)
...@@ -132,29 +134,31 @@ class websocket(object): ...@@ -132,29 +134,31 @@ class websocket(object):
ClientHandshake(self).perform() ClientHandshake(self).perform()
self.handshake_sent = True self.handshake_sent = True
def apply_send_hooks(self, frame):
for hook in self.extension_hooks:
frame = hook.send(frame)
return frame
def apply_recv_hooks(self, frame):
for hook in reversed(self.extension_hooks):
frame = hook.recv(frame)
return frame
def send(self, *args): def send(self, *args):
""" """
Send a number of frames. Send a number of frames.
""" """
for frame in args: for frame in args:
for hook in self.hooks_send: self.sock.sendall(self.apply_send_hooks(frame).pack())
frame = hook(frame)
#print 'send frame:', frame, 'to %s:%d' % self.sock.getpeername()
self.sock.sendall(frame.pack())
def recv(self): def recv(self):
""" """
Receive a single frames. This can be either a data frame or a control Receive a single frames. This can be either a data frame or a control
frame. frame.
""" """
frame = receive_frame(self.sock) return self.apply_recv_hooks(receive_frame(self.sock))
for hook in self.hooks_recv:
frame = hook(frame)
#print 'receive frame:', frame, 'from %s:%d' % self.sock.getpeername()
return frame
def recvn(self, n): def recvn(self, n):
""" """
...@@ -170,9 +174,7 @@ class websocket(object): ...@@ -170,9 +174,7 @@ class websocket(object):
frame has been fully written. `recv_callback` is an optional callable frame has been fully written. `recv_callback` is an optional callable
to quickly set the `recv_callback` attribute to. to quickly set the `recv_callback` attribute to.
""" """
for hook in self.hooks_send: frame = self.apply_send_hooks(frame)
frame = hook(frame)
self.sendbuf += frame.pack() self.sendbuf += frame.pack()
self.sendbuf_frames.append([frame, len(self.sendbuf), callback]) self.sendbuf_frames.append([frame, len(self.sendbuf), callback])
...@@ -181,7 +183,8 @@ class websocket(object): ...@@ -181,7 +183,8 @@ class websocket(object):
def do_async_send(self): def do_async_send(self):
""" """
Send any queued data. Send any queued data. This function should only be called after a write
event on a file descriptor.
""" """
assert len(self.sendbuf) assert len(self.sendbuf)
...@@ -204,6 +207,8 @@ class websocket(object): ...@@ -204,6 +207,8 @@ class websocket(object):
def do_async_recv(self, bufsize): def do_async_recv(self, bufsize):
""" """
Receive any completed frames from the socket. This function should only
be called after a read event on a file descriptor.
""" """
data = self.sock.recv(bufsize) data = self.sock.recv(bufsize)
...@@ -214,6 +219,7 @@ class websocket(object): ...@@ -214,6 +219,7 @@ class websocket(object):
while contains_frame(self.recvbuf): while contains_frame(self.recvbuf):
frame, self.recvbuf = pop_frame(self.recvbuf) frame, self.recvbuf = pop_frame(self.recvbuf)
frame = self.apply_recv_hooks(frame)
if not self.recv_callback: if not self.recv_callback:
raise ValueError('no callback installed for %s' % frame) raise ValueError('no callback installed for %s' % frame)
...@@ -237,37 +243,3 @@ class websocket(object): ...@@ -237,37 +243,3 @@ class websocket(object):
self.secure = True self.secure = True
self.sock = ssl.wrap_socket(self.sock, *args, **kwargs) self.sock = ssl.wrap_socket(self.sock, *args, **kwargs)
def add_hook(self, send=None, recv=None, prepend=False):
"""
Add a pair of send and receive hooks that are called for each frame
that is sent or received. A hook is a function that receives a single
argument - a Frame instance - and returns a `Frame` instance as well.
`prepend` is a flag indicating whether the send hook is prepended to
the other send hooks. This is expecially useful when a program uses
extensions such as the built-in `DeflateFrame` extension. These
extensions are installed using these hooks as well.
For example, the following code creates a `Frame` instance for data
being sent and removes the instance for received data. This way, data
can be sent and received as if on a regular socket.
>>> import wspy
>>> sock = wspy.websocket()
>>> sock.add_hook(lambda data: tswpy.Frame(tswpy.OPCODE_TEXT, data),
>>> lambda frame: frame.payload)
To add base64 encoding to the example above:
>>> import base64
>>> sock.add_hook(base64.encodestring, base64.decodestring, True)
Note that here `prepend=True`, so that data passed to `send()` is first
encoded and then packed into a frame. Of course, one could also decide
to add the base64 hook first, or to return a new `Frame` instance with
base64-encoded data.
"""
if send:
self.hooks_send.insert(0 if prepend else -1, send)
if recv:
self.hooks_recv.insert(-1 if prepend else 0, recv)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment