Commit 5ed32daf authored by Taddeüs Kroes's avatar Taddeüs Kroes

Added permessage-deflate + lots of general debug and cleanup

parent b2e1f4b3
......@@ -20,8 +20,10 @@ Her is a quick overview of the features in this library:
- HTTP authentication during handshake.
- An extendible server implementation.
- Secure sockets using SSL certificates (for 'wss://...' URLs).
- The possibility to add extensions to the web socket protocol. An included
implementation is [deflate-frame](http://tools.ietf.org/html/draft-tyoshino-hybi-websocket-perframe-deflate-06).
- An API for implementing WebSocket extensions. Included implementations are
[deflate-frame](http://tools.ietf.org/html/draft-tyoshino-hybi-websocket-perframe-deflate-06)
and
[permessage-deflate](http://tools.ietf.org/html/draft-ietf-hybi-permessage-compression-17).
- Asynchronous sockets with an EPOLL-based server.
......@@ -112,16 +114,19 @@ Basic usage
conn.send(msg(foo='Hello, World!'))
Built-in Server
===============
Built-in servers
================
The built-in `Server` implementation is very basic. It starts a new thread with
a `Connection.receive_forever()` loop for each client that connects. It also
Threaded
--------
The `Server` class is very basic. It starts a new thread with a
`Connection.receive_forever()` loop for each client that connects. It also
handles client crashes properly. By default, a `Server` instance only logs
every event using Python's `logging` module. To create a custom server, The
`Server` class should be extended and its event handlers overwritten. The event
handlers are named identically to the `Connection` event handlers, but they
also receive an additional `client` argument. This argument is a modified
also receive an additional `client` argument. The client argumetn is a modified
`Connection` instance, so you can invoke `send()` and `recv()`.
For example, the `EchoConnection` example above can be rewritten to:
......@@ -144,6 +149,16 @@ For example, the `EchoConnection` example above can be rewritten to:
The server can be stopped by typing CTRL-C in the command line. The
`KeyboardInterrupt` raised when this happens is caught by the server.
Asynchronous
------------
The `AsyncServer` class has the same API as `Server`, but uses
[EPOLL](https://docs.python.org/2/library/select.html#epoll-objects) instead of
threads. This means that when you send a message, it is put into a queue to be
sent later when the socket is ready. The client argument is againa modified
`Connection` instance, with a non-blocking `send()` method (`recv` is still
blocking, use the server's `onmessage` callback instead).
Extensions
==========
......
......@@ -11,4 +11,5 @@ from message import Message, TextMessage, BinaryMessage
from errors import SocketClosed, HandshakeError, PingError, SSLError
from extension import Extension
from deflate_frame import DeflateFrame
from deflate_message import DeflateMessage
from async import AsyncConnection, AsyncServer
......@@ -54,14 +54,13 @@ class Connection(object):
self.onopen()
def message_to_frames(self, message, fragment_size=None, mask=False):
for hook in self.hooks_send:
message = hook(message)
frame = self.sock.apply_send_hooks(message.frame(mask=mask), True)
if fragment_size is None:
yield message.frame(mask=mask)
yield frame
else:
for frame in message.fragment(fragment_size, mask=mask):
yield frame
for fragment in frame.fragment(fragment_size):
yield fragment
def send(self, message, fragment_size=None, mask=False):
"""
......@@ -101,17 +100,14 @@ class Connection(object):
return self.concat_fragments(fragments)
def concat_fragments(self, fragments):
payload = bytearray()
frame = fragments[0]
for f in fragments:
payload += f.payload
for f in fragments[1:]:
frame.payload += f.payload
message = create_message(fragments[0].opcode, payload)
for hook in self.hooks_recv:
message = hook(message)
return message
frame.final = True
frame = self.sock.apply_recv_hooks(frame, True)
return create_message(frame.opcode, frame.payload)
def handle_control_frame(self, frame):
"""
......@@ -207,36 +203,6 @@ class Connection(object):
self.handle_control_frame(frame)
def add_hook(self, send=None, recv=None, prepend=False):
"""
Add a pair of send and receive hooks that are called for each frame
that is sent or received. A hook is a function that receives a single
argument - a Message instance - and returns a `Message` instance as
well.
`prepend` is a flag indicating whether the send hook is prepended to
the other send hooks.
For example, to add an automatic JSON conversion to messages and
eliminate the need to contruct TextMessage instances to all messages:
>>> import wspy, json
>>> conn = Connection(...)
>>> conn.add_hook(lambda data: tswpy.TextMessage(json.dumps(data)),
>>> lambda message: json.loads(message.payload))
>>> conn.send({'foo': 'bar'}) # Sends text message {"foo":"bar"}
>>> conn.recv() # May be dict(foo='bar')
Note that here `prepend=True`, so that data passed to `send()` is first
encoded and then packed into a frame. Of course, one could also decide
to add the base64 hook first, or to return a new `Frame` instance with
base64-encoded data.
"""
if send:
self.hooks_send.insert(0 if prepend else -1, send)
if recv:
self.hooks_recv.insert(-1 if prepend else 0, recv)
def onopen(self):
"""
Called after the connection is initialized.
......
......@@ -24,7 +24,7 @@ class DeflateFrame(Extension):
'no_context_takeover': False
}
compression_threshold = 64 # minimal payload size for compression
compression_threshold = 20 # minimal payload size for compression
def negotiate(self, name, params):
if 'max_window_bits' in params:
......@@ -43,13 +43,16 @@ class DeflateFrame(Extension):
zlib.DEFLATED, -self.max_window_bits)
self.dec = zlib.decompressobj(-self.max_window_bits)
def onsend_frame(self, frame):
def onsend(self, frame):
if not frame.rsv1 and not isinstance(frame, ControlFrame) and \
len(frame.payload) > self.extension.compression_threshold:
frame.rsv1 = True
frame.payload = self.deflate(frame)
deflated = self.deflate(frame.payload)
def onrecv_frame(self, frame):
if len(deflated) < len(frame.payload):
frame.rsv1 = True
frame.payload = deflated
def onrecv(self, frame):
if frame.rsv1:
if isinstance(frame, ControlFrame):
raise ValueError('received compressed control frame')
......@@ -57,13 +60,13 @@ class DeflateFrame(Extension):
frame.rsv1 = False
frame.payload = self.inflate(frame.payload)
def deflate(self, frame):
def deflate(self, data):
if self.no_context_takeover:
compressed = zlib.compress(frame.payload)
else:
compressed = self.defl.compress(frame.payload)
compressed += self.defl.flush(zlib.Z_SYNC_FLUSH)
self.defl = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
zlib.DEFLATED, -self.max_window_bits)
compressed = self.defl.compress(data)
compressed += self.defl.flush(zlib.Z_SYNC_FLUSH)
assert compressed[-4:] == '\x00\x00\xff\xff'
return compressed[:-4]
......@@ -71,7 +74,6 @@ class DeflateFrame(Extension):
data = str(data + '\x00\x00\xff\xff')
if self.no_context_takeover:
dec = zlib.decompressobj(-self.max_window_bits)
return dec.decompress(data) + dec.flush()
self.dec = zlib.decompressobj(-self.max_window_bits)
return self.dec.decompress(data)
import zlib
from extension import Extension
from deflate_frame import DeflateFrame
class DeflateMessage(Extension):
"""
Implementation of the "permessage-deflate" extension, as defined by
http://tools.ietf.org/html/draft-ietf-hybi-permessage-compression-17.
Note: this implementetion is only eligible for server sockets, client
sockets must NOT use it.
"""
name = 'permessage-deflate'
rsv1 = True
defaults = {
'client_max_window_bits': zlib.MAX_WBITS,
'client_no_context_takeover': False,
'server_max_window_bits': zlib.MAX_WBITS,
'server_no_context_takeover': False
}
before_fragmentation = True
compression_threshold = 20 # minimal message payload size for compression
def negotiate(self, name, params):
default = self.defaults['client_max_window_bits']
if 'client_max_window_bits' in params:
mwb = params['client_max_window_bits']
if mwb is True:
if default != zlib.MAX_WBITS:
yield 'client_max_window_bits', default
else:
mwb = int(mwb)
assert 8 <= mwb <= zlib.MAX_WBITS
yield 'client_max_window_bits', min(mwb, default)
elif default != zlib.MAX_WBITS:
yield 'client_max_window_bits', default
if 'client_no_context_takeover' in params:
assert params['client_no_context_takeover'] is True
yield 'client_no_context_takeover', True
elif self.defaults['client_no_context_takeover']:
yield 'client_no_context_takeover', True
default = self.defaults['server_max_window_bits']
if 'server_max_window_bits' in params:
mwb = int(params['server_max_window_bits'])
assert 8 <= mwb <= zlib.MAX_WBITS
yield 'server_max_window_bits', min(mwb, default)
elif default != zlib.MAX_WBITS:
yield 'server_max_window_bits', default
if 'server_no_context_takeover' in params:
assert params['server_no_context_takeover'] is True
yield 'server_no_context_takeover', True
elif self.defaults['server_no_context_takeover']:
yield 'server_no_context_takeover', True
class Instance(DeflateFrame.Instance):
def init(self):
if not self.server_no_context_takeover:
self.defl = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
zlib.DEFLATED, -self.server_max_window_bits)
if not self.client_no_context_takeover:
self.dec = zlib.decompressobj(-self.client_max_window_bits)
def deflate(self, data):
if self.server_no_context_takeover:
self.defl = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
zlib.DEFLATED, -self.server_max_window_bits)
compressed = self.defl.compress(data)
compressed += self.defl.flush(zlib.Z_SYNC_FLUSH)
assert compressed[-4:] == '\x00\x00\xff\xff'
return compressed[:-4]
def inflate(self, data):
data = str(data + '\x00\x00\xff\xff')
if self.client_no_context_takeover:
self.dec = zlib.decompressobj(-self.client_max_window_bits)
return self.dec.decompress(data)
......@@ -4,6 +4,7 @@ class Extension(object):
rsv2 = False
rsv3 = False
opcodes = ()
before_fragmentation = False
defaults = {}
def __init__(self, **kwargs):
......@@ -23,6 +24,10 @@ class Extension(object):
def names(self):
return (self.name,) if self.name else ()
def is_supported(self, name, other_instances):
return name in self.names and not any(self.conflicts(other.extension)
for other in other_instances)
def conflicts(self, ext):
"""
Check if the extension conflicts with an already accepted extension.
......@@ -76,14 +81,22 @@ class Extension(object):
def init(self):
return NotImplemented
def onsend_frame(self, frame):
pass
def handle_send(self, frame):
if self.extension.before_fragmentation:
assert not frame.is_fragmented()
def onrecv_frame(self, frame):
pass
replacement = self.onsend(frame)
return frame if replacement is None else replacement
def onsend_message(self, message):
pass
def handle_recv(self, frame):
if self.extension.before_fragmentation:
assert not frame.is_fragmented()
def onrecv_message(self, message):
pass
replacement = self.onrecv(frame)
return frame if replacement is None else replacement
def onsend(self, frame):
raise NotImplementedError
def onrecv(self, frame):
raise NotImplementedError
......@@ -143,6 +143,9 @@ class Frame(object):
return frames
def is_fragmented(self):
return not self.final or self.opcode == OPCODE_CONTINUATION
def __str__(self):
s = '<%s opcode=0x%X len=%d' \
% (self.__class__.__name__, self.opcode, len(self.payload))
......
......@@ -177,8 +177,7 @@ class ServerHandshake(Handshake):
name, params = parse_param_hdr(hdr)
for ext in ssock.extensions:
if not any(ext.conflicts(other.extension)
for other in self.wsock.extension_instances):
if ext.is_supported(name, self.wsock.extension_instances):
accept_params = ext.negotiate_safe(name, params)
if accept_params is not None:
......@@ -432,4 +431,4 @@ def format_param_hdr(value, params):
return k + '=' + str(v)
strparams = filter(None, map(fmt_param, params.items()))
return '%s; %s' % (value, ', '.join(strparams))
return '; '.join([value] + strparams)
......@@ -131,21 +131,17 @@ class websocket(object):
ClientHandshake(self).perform()
self.handshake_sent = True
def apply_send_hooks(self, frame):
def apply_send_hooks(self, frame, before_fragmentation):
for inst in self.extension_instances:
replacement = inst.onsend_frame(frame)
if replacement is not None:
frame = replacement
if inst.extension.before_fragmentation == before_fragmentation:
frame = inst.handle_send(frame)
return frame
def apply_recv_hooks(self, frame):
def apply_recv_hooks(self, frame, before_fragmentation):
for inst in reversed(self.extension_instances):
replacement = inst.onrecv_frame(frame)
if replacement is not None:
frame = replacement
if inst.extension.before_fragmentation == before_fragmentation:
frame = inst.handle_recv(frame)
return frame
......@@ -154,14 +150,14 @@ class websocket(object):
Send a number of frames.
"""
for frame in args:
self.sock.sendall(self.apply_send_hooks(frame).pack())
self.sock.sendall(self.apply_send_hooks(frame, False).pack())
def recv(self):
"""
Receive a single frames. This can be either a data frame or a control
frame.
"""
return self.apply_recv_hooks(receive_frame(self.sock))
return self.apply_recv_hooks(receive_frame(self.sock), False)
def recvn(self, n):
"""
......@@ -177,7 +173,7 @@ class websocket(object):
frame has been fully written. `recv_callback` is an optional callable
to quickly set the `recv_callback` attribute to.
"""
frame = self.apply_send_hooks(frame)
frame = self.apply_send_hooks(frame, False)
self.sendbuf += frame.pack()
self.sendbuf_frames.append([frame, len(self.sendbuf), callback])
......@@ -222,7 +218,7 @@ class websocket(object):
while contains_frame(self.recvbuf):
frame, self.recvbuf = pop_frame(self.recvbuf)
frame = self.apply_recv_hooks(frame)
frame = self.apply_recv_hooks(frame, False)
if not self.recv_callback:
raise ValueError('no callback installed for %s' % frame)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment