Commit 7c8972d4 authored by Taddeüs Kroes's avatar Taddeüs Kroes

Removed frame send/recv hooks, now working with a somewhat more robust extensions list

parent caa592ea
......@@ -20,38 +20,33 @@ class DeflateFrame(Extension):
name = 'deflate-frame'
rsv1 = True
defaults = {'max_window_bits': 15, 'no_context_takeover': False}
defaults = {'max_window_bits': zlib.MAX_WBITS, 'no_context_takeover': False}
def __init__(self, defaults={}, request={}):
Extension.__init__(self, defaults, request)
COMPRESSION_THRESHOLD = 64 # minimal payload size for compression
def init(self):
mwb = self.defaults['max_window_bits']
cto = self.defaults['no_context_takeover']
if not isinstance(mwb, int):
raise ValueError('"max_window_bits" must be an integer')
elif mwb > 15:
raise ValueError('"max_window_bits" may not be larger than 15')
if not isinstance(mwb, int) or mwb < 1 or mwb > zlib.MAX_WBITS:
raise ValueError('"max_window_bits" must be in range 1-15')
if cto is not False and cto is not True:
raise ValueError('"no_context_takeover" must have no value')
class Hook(Extension.Hook):
def __init__(self, extension, **kwargs):
Extension.Hook.__init__(self, extension, **kwargs)
if not self.no_context_takeover:
self.defl = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
zlib.DEFLATED,
-self.max_window_bits)
other_wbits = self.extension.request.get('max_window_bits', 15)
def init(self, extension):
self.defl = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
zlib.DEFLATED, -self.max_window_bits)
other_wbits = extension.request.get('max_window_bits', zlib.MAX_WBITS)
self.dec = zlib.decompressobj(-other_wbits)
def send(self, frame):
if not frame.rsv1 and not isinstance(frame, ControlFrame):
# FIXME: this does not seem to work properly on Android
if not frame.rsv1 and not isinstance(frame, ControlFrame) and \
len(frame.payload) > DeflateFrame.COMPRESSION_THRESHOLD:
frame.rsv1 = True
frame.payload = self.deflate(frame.payload)
frame.payload = self.deflate(frame)
return frame
......@@ -65,23 +60,23 @@ class DeflateFrame(Extension):
return frame
def deflate(self, data):
if self.no_context_takeover:
defl = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
zlib.DEFLATED, -self.max_window_bits)
# FIXME: why the '\x00' below? This was borrowed from
# https://github.com/fancycode/tornado/blob/bc317b6dcf63608ff004ff1f57073be0504b6550/tornado/websocket.py#L91
return defl.compress(data) + defl.flush(zlib.Z_FINISH) + '\x00'
def deflate(self, frame):
compressed = self.defl.compress(frame.payload)
if frame.final or self.no_context_takeover:
compressed += self.defl.flush(zlib.Z_FINISH) + '\x00'
self.defl = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
zlib.DEFLATED, -self.max_window_bits)
else:
compressed += self.defl.flush(zlib.Z_SYNC_FLUSH)
assert compressed[-4:] == '\x00\x00\xff\xff'
compressed = compressed[:-4]
compressed = self.defl.compress(data)
compressed += self.defl.flush(zlib.Z_SYNC_FLUSH)
assert compressed[-4:] == '\x00\x00\xff\xff'
return compressed[:-4]
return compressed
def inflate(self, data):
data = self.dec.decompress(str(data + '\x00\x00\xff\xff'))
assert not self.dec.unused_data
return data
return self.dec.decompress(data + '\x00\x00\xff\xff') + \
self.dec.flush(zlib.Z_SYNC_FLUSH)
class WebkitDeflateFrame(DeflateFrame):
......
......@@ -19,23 +19,31 @@ class Extension(object):
self.request = dict(self.__class__.request)
self.request.update(request)
self.init()
def __str__(self):
return '<Extension "%s" defaults=%s request=%s>' \
% (self.name, self.defaults, self.request)
def init(self):
return NotImplemented
def create_hook(self, **kwargs):
params = {}
params.update(self.defaults)
params.update(kwargs)
return self.Hook(self, **params)
hook = self.Hook(**params)
hook.init(self)
return hook
class Hook:
def __init__(self, extension, **kwargs):
self.extension = extension
def __init__(self, **kwargs):
for param, value in kwargs.iteritems():
setattr(self, param, value)
def init(self, extension):
return NotImplemented
def send(self, frame):
return frame
......@@ -43,28 +51,19 @@ class Extension(object):
return frame
def filter_extensions(extensions):
"""
Remove extensions that use conflicting rsv bits and/or opcodes, with the
first options being the most preferable.
"""
def extension_conflicts(ext, existing):
rsv1_reserved = False
rsv2_reserved = False
rsv3_reserved = False
opcodes_reserved = []
compat = []
for ext in extensions:
if ext.rsv1 and rsv1_reserved \
or ext.rsv2 and rsv2_reserved \
or ext.rsv3 and rsv3_reserved \
or len(set(ext.opcodes) & set(opcodes_reserved)):
continue
reserved_opcodes = []
rsv1_reserved |= ext.rsv1
rsv2_reserved |= ext.rsv2
rsv3_reserved |= ext.rsv3
opcodes_reserved.extend(ext.opcodes)
compat.append(ext)
for e in existing:
rsv1_reserved |= e.rsv1
rsv2_reserved |= e.rsv2
rsv3_reserved |= e.rsv3
reserved_opcodes.extend(e.opcodes)
return compat
return ext.rsv1 and rsv1_reserved \
or ext.rsv2 and rsv2_reserved \
or ext.rsv3 and rsv3_reserved \
or len(set(ext.opcodes) & set(reserved_opcodes))
......@@ -7,7 +7,7 @@ from hashlib import sha1
from urlparse import urlparse
from errors import HandshakeError
from extension import filter_extensions
from extension import extension_conflicts
from python_digest import build_authorization_request
......@@ -173,23 +173,19 @@ class ServerHandshake(Handshake):
# Only supported extensions are returned
if 'Sec-WebSocket-Extensions' in headers:
supported_ext = dict((e.name, e) for e in ssock.extensions)
self.wsock.extension_hooks = []
extensions = []
all_params = []
for ext in split_stripped(headers['Sec-WebSocket-Extensions']):
name, params = parse_param_hdr(ext)
if name in supported_ext:
extensions.append(supported_ext[name])
all_params.append(params)
ext = supported_ext[name]
self.wsock.extensions = filter_extensions(extensions)
for ext, params in zip(self.wsock.extensions, all_params):
hook = ext.create_hook(**params)
self.wsock.add_hook(send=hook.send, recv=hook.recv)
else:
self.wsock.extensions = []
if not extension_conflicts(ext, extensions):
extensions.append(ext)
hook = ext.create_hook(**params)
self.wsock.extension_hooks.append(hook)
# Check if requested resource location is served by this server
if ssock.locations:
......@@ -278,7 +274,7 @@ class ClientHandshake(Handshake):
# Compare extensions, add hooks only for those returned by server
if 'Sec-WebSocket-Extensions' in headers:
supported_ext = dict((e.name, e) for e in self.wsock.extensions)
self.wsock.extensions = []
self.wsock.extension_hooks = []
for ext in split_stripped(headers['Sec-WebSocket-Extensions']):
name, params = parse_param_hdr(ext)
......@@ -288,8 +284,7 @@ class ClientHandshake(Handshake):
'unsupported extension "%s"' % name)
hook = supported_ext[name].create_hook(**params)
self.wsock.extensions.append(supported_ext[name])
self.wsock.add_hook(send=hook.send, recv=hook.recv)
self.wsock.extension_hooks.append(hook)
# Assert that returned protocol (if any) is supported
if 'Sec-WebSocket-Protocol' in headers:
......
......@@ -35,7 +35,7 @@ class websocket(object):
>>> sock.connect(('', 8000))
>>> sock.send(wspy.Frame(wspy.OPCODE_TEXT, 'Hello, Server!'))
"""
def __init__(self, sock=None, protocols=[], extensions=[], origin=None,
def __init__(self, sock=None, origin=None, protocols=[], extensions=[],
location='/', trusted_origins=[], locations=[], auth=None,
recv_callback=None, sfamily=socket.AF_INET, sproto=0):
"""
......@@ -44,13 +44,14 @@ class websocket(object):
`sock` is an optional regular TCP socket to be used for sending binary
data. If not specified, a new socket is created.
`protocols` is a list of supported protocol names.
`extensions` is a list of supported extensions (`Extension` instances).
`origin` (for client sockets) is the value for the "Origin" header sent
in a client handshake .
`protocols` is a list of supported protocol names.
`extensions` (for server sockets) is a list of supported extensions
(`Extension` instances).
`location` (for client sockets) is optional, used to request a
particular resource in the HTTP handshake. In a URL, this would show as
ws://host[:port]/<location>. Use this when the server serves multiple
......@@ -80,6 +81,7 @@ class websocket(object):
"""
self.protocols = protocols
self.extensions = extensions
self.extension_hooks = []
self.origin = origin
self.location = location
self.trusted_origins = trusted_origins
......@@ -90,9 +92,6 @@ class websocket(object):
self.handshake_sent = False
self.hooks_send = []
self.hooks_recv = []
self.sendbuf_frames = []
self.sendbuf = ''
self.recvbuf = ''
......@@ -100,6 +99,9 @@ class websocket(object):
self.sock = sock or socket.socket(sfamily, socket.SOCK_STREAM, sproto)
def set_extensions(self, extensions):
self.extensions = [ext.Hook() for ext in extensions]
def __getattr__(self, name):
if name in INHERITED_ATTRS:
return getattr(self.sock, name)
......@@ -132,29 +134,31 @@ class websocket(object):
ClientHandshake(self).perform()
self.handshake_sent = True
def apply_send_hooks(self, frame):
for hook in self.extension_hooks:
frame = hook.send(frame)
return frame
def apply_recv_hooks(self, frame):
for hook in reversed(self.extension_hooks):
frame = hook.recv(frame)
return frame
def send(self, *args):
"""
Send a number of frames.
"""
for frame in args:
for hook in self.hooks_send:
frame = hook(frame)
#print 'send frame:', frame, 'to %s:%d' % self.sock.getpeername()
self.sock.sendall(frame.pack())
self.sock.sendall(self.apply_send_hooks(frame).pack())
def recv(self):
"""
Receive a single frames. This can be either a data frame or a control
frame.
"""
frame = receive_frame(self.sock)
for hook in self.hooks_recv:
frame = hook(frame)
#print 'receive frame:', frame, 'from %s:%d' % self.sock.getpeername()
return frame
return self.apply_recv_hooks(receive_frame(self.sock))
def recvn(self, n):
"""
......@@ -170,9 +174,7 @@ class websocket(object):
frame has been fully written. `recv_callback` is an optional callable
to quickly set the `recv_callback` attribute to.
"""
for hook in self.hooks_send:
frame = hook(frame)
frame = self.apply_send_hooks(frame)
self.sendbuf += frame.pack()
self.sendbuf_frames.append([frame, len(self.sendbuf), callback])
......@@ -181,7 +183,8 @@ class websocket(object):
def do_async_send(self):
"""
Send any queued data.
Send any queued data. This function should only be called after a write
event on a file descriptor.
"""
assert len(self.sendbuf)
......@@ -204,6 +207,8 @@ class websocket(object):
def do_async_recv(self, bufsize):
"""
Receive any completed frames from the socket. This function should only
be called after a read event on a file descriptor.
"""
data = self.sock.recv(bufsize)
......@@ -214,6 +219,7 @@ class websocket(object):
while contains_frame(self.recvbuf):
frame, self.recvbuf = pop_frame(self.recvbuf)
frame = self.apply_recv_hooks(frame)
if not self.recv_callback:
raise ValueError('no callback installed for %s' % frame)
......@@ -237,37 +243,3 @@ class websocket(object):
self.secure = True
self.sock = ssl.wrap_socket(self.sock, *args, **kwargs)
def add_hook(self, send=None, recv=None, prepend=False):
"""
Add a pair of send and receive hooks that are called for each frame
that is sent or received. A hook is a function that receives a single
argument - a Frame instance - and returns a `Frame` instance as well.
`prepend` is a flag indicating whether the send hook is prepended to
the other send hooks. This is expecially useful when a program uses
extensions such as the built-in `DeflateFrame` extension. These
extensions are installed using these hooks as well.
For example, the following code creates a `Frame` instance for data
being sent and removes the instance for received data. This way, data
can be sent and received as if on a regular socket.
>>> import wspy
>>> sock = wspy.websocket()
>>> sock.add_hook(lambda data: tswpy.Frame(tswpy.OPCODE_TEXT, data),
>>> lambda frame: frame.payload)
To add base64 encoding to the example above:
>>> import base64
>>> sock.add_hook(base64.encodestring, base64.decodestring, True)
Note that here `prepend=True`, so that data passed to `send()` is first
encoded and then packed into a frame. Of course, one could also decide
to add the base64 hook first, or to return a new `Frame` instance with
base64-encoded data.
"""
if send:
self.hooks_send.insert(0 if prepend else -1, send)
if recv:
self.hooks_recv.insert(-1 if prepend else 0, recv)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment