Commit 5ed32daf authored by Taddeüs Kroes's avatar Taddeüs Kroes

Added permessage-deflate + lots of general debug and cleanup

parent b2e1f4b3
...@@ -20,8 +20,10 @@ Her is a quick overview of the features in this library: ...@@ -20,8 +20,10 @@ Her is a quick overview of the features in this library:
- HTTP authentication during handshake. - HTTP authentication during handshake.
- An extendible server implementation. - An extendible server implementation.
- Secure sockets using SSL certificates (for 'wss://...' URLs). - Secure sockets using SSL certificates (for 'wss://...' URLs).
- The possibility to add extensions to the web socket protocol. An included - An API for implementing WebSocket extensions. Included implementations are
implementation is [deflate-frame](http://tools.ietf.org/html/draft-tyoshino-hybi-websocket-perframe-deflate-06). [deflate-frame](http://tools.ietf.org/html/draft-tyoshino-hybi-websocket-perframe-deflate-06)
and
[permessage-deflate](http://tools.ietf.org/html/draft-ietf-hybi-permessage-compression-17).
- Asynchronous sockets with an EPOLL-based server. - Asynchronous sockets with an EPOLL-based server.
...@@ -112,16 +114,19 @@ Basic usage ...@@ -112,16 +114,19 @@ Basic usage
conn.send(msg(foo='Hello, World!')) conn.send(msg(foo='Hello, World!'))
Built-in Server Built-in servers
=============== ================
The built-in `Server` implementation is very basic. It starts a new thread with Threaded
a `Connection.receive_forever()` loop for each client that connects. It also --------
The `Server` class is very basic. It starts a new thread with a
`Connection.receive_forever()` loop for each client that connects. It also
handles client crashes properly. By default, a `Server` instance only logs handles client crashes properly. By default, a `Server` instance only logs
every event using Python's `logging` module. To create a custom server, The every event using Python's `logging` module. To create a custom server, The
`Server` class should be extended and its event handlers overwritten. The event `Server` class should be extended and its event handlers overwritten. The event
handlers are named identically to the `Connection` event handlers, but they handlers are named identically to the `Connection` event handlers, but they
also receive an additional `client` argument. This argument is a modified also receive an additional `client` argument. The client argumetn is a modified
`Connection` instance, so you can invoke `send()` and `recv()`. `Connection` instance, so you can invoke `send()` and `recv()`.
For example, the `EchoConnection` example above can be rewritten to: For example, the `EchoConnection` example above can be rewritten to:
...@@ -144,6 +149,16 @@ For example, the `EchoConnection` example above can be rewritten to: ...@@ -144,6 +149,16 @@ For example, the `EchoConnection` example above can be rewritten to:
The server can be stopped by typing CTRL-C in the command line. The The server can be stopped by typing CTRL-C in the command line. The
`KeyboardInterrupt` raised when this happens is caught by the server. `KeyboardInterrupt` raised when this happens is caught by the server.
Asynchronous
------------
The `AsyncServer` class has the same API as `Server`, but uses
[EPOLL](https://docs.python.org/2/library/select.html#epoll-objects) instead of
threads. This means that when you send a message, it is put into a queue to be
sent later when the socket is ready. The client argument is againa modified
`Connection` instance, with a non-blocking `send()` method (`recv` is still
blocking, use the server's `onmessage` callback instead).
Extensions Extensions
========== ==========
......
...@@ -11,4 +11,5 @@ from message import Message, TextMessage, BinaryMessage ...@@ -11,4 +11,5 @@ from message import Message, TextMessage, BinaryMessage
from errors import SocketClosed, HandshakeError, PingError, SSLError from errors import SocketClosed, HandshakeError, PingError, SSLError
from extension import Extension from extension import Extension
from deflate_frame import DeflateFrame from deflate_frame import DeflateFrame
from deflate_message import DeflateMessage
from async import AsyncConnection, AsyncServer from async import AsyncConnection, AsyncServer
...@@ -54,14 +54,13 @@ class Connection(object): ...@@ -54,14 +54,13 @@ class Connection(object):
self.onopen() self.onopen()
def message_to_frames(self, message, fragment_size=None, mask=False): def message_to_frames(self, message, fragment_size=None, mask=False):
for hook in self.hooks_send: frame = self.sock.apply_send_hooks(message.frame(mask=mask), True)
message = hook(message)
if fragment_size is None: if fragment_size is None:
yield message.frame(mask=mask)
else:
for frame in message.fragment(fragment_size, mask=mask):
yield frame yield frame
else:
for fragment in frame.fragment(fragment_size):
yield fragment
def send(self, message, fragment_size=None, mask=False): def send(self, message, fragment_size=None, mask=False):
""" """
...@@ -101,17 +100,14 @@ class Connection(object): ...@@ -101,17 +100,14 @@ class Connection(object):
return self.concat_fragments(fragments) return self.concat_fragments(fragments)
def concat_fragments(self, fragments): def concat_fragments(self, fragments):
payload = bytearray() frame = fragments[0]
for f in fragments:
payload += f.payload
message = create_message(fragments[0].opcode, payload)
for hook in self.hooks_recv: for f in fragments[1:]:
message = hook(message) frame.payload += f.payload
return message frame.final = True
frame = self.sock.apply_recv_hooks(frame, True)
return create_message(frame.opcode, frame.payload)
def handle_control_frame(self, frame): def handle_control_frame(self, frame):
""" """
...@@ -207,36 +203,6 @@ class Connection(object): ...@@ -207,36 +203,6 @@ class Connection(object):
self.handle_control_frame(frame) self.handle_control_frame(frame)
def add_hook(self, send=None, recv=None, prepend=False):
"""
Add a pair of send and receive hooks that are called for each frame
that is sent or received. A hook is a function that receives a single
argument - a Message instance - and returns a `Message` instance as
well.
`prepend` is a flag indicating whether the send hook is prepended to
the other send hooks.
For example, to add an automatic JSON conversion to messages and
eliminate the need to contruct TextMessage instances to all messages:
>>> import wspy, json
>>> conn = Connection(...)
>>> conn.add_hook(lambda data: tswpy.TextMessage(json.dumps(data)),
>>> lambda message: json.loads(message.payload))
>>> conn.send({'foo': 'bar'}) # Sends text message {"foo":"bar"}
>>> conn.recv() # May be dict(foo='bar')
Note that here `prepend=True`, so that data passed to `send()` is first
encoded and then packed into a frame. Of course, one could also decide
to add the base64 hook first, or to return a new `Frame` instance with
base64-encoded data.
"""
if send:
self.hooks_send.insert(0 if prepend else -1, send)
if recv:
self.hooks_recv.insert(-1 if prepend else 0, recv)
def onopen(self): def onopen(self):
""" """
Called after the connection is initialized. Called after the connection is initialized.
......
...@@ -24,7 +24,7 @@ class DeflateFrame(Extension): ...@@ -24,7 +24,7 @@ class DeflateFrame(Extension):
'no_context_takeover': False 'no_context_takeover': False
} }
compression_threshold = 64 # minimal payload size for compression compression_threshold = 20 # minimal payload size for compression
def negotiate(self, name, params): def negotiate(self, name, params):
if 'max_window_bits' in params: if 'max_window_bits' in params:
...@@ -43,13 +43,16 @@ class DeflateFrame(Extension): ...@@ -43,13 +43,16 @@ class DeflateFrame(Extension):
zlib.DEFLATED, -self.max_window_bits) zlib.DEFLATED, -self.max_window_bits)
self.dec = zlib.decompressobj(-self.max_window_bits) self.dec = zlib.decompressobj(-self.max_window_bits)
def onsend_frame(self, frame): def onsend(self, frame):
if not frame.rsv1 and not isinstance(frame, ControlFrame) and \ if not frame.rsv1 and not isinstance(frame, ControlFrame) and \
len(frame.payload) > self.extension.compression_threshold: len(frame.payload) > self.extension.compression_threshold:
deflated = self.deflate(frame.payload)
if len(deflated) < len(frame.payload):
frame.rsv1 = True frame.rsv1 = True
frame.payload = self.deflate(frame) frame.payload = deflated
def onrecv_frame(self, frame): def onrecv(self, frame):
if frame.rsv1: if frame.rsv1:
if isinstance(frame, ControlFrame): if isinstance(frame, ControlFrame):
raise ValueError('received compressed control frame') raise ValueError('received compressed control frame')
...@@ -57,13 +60,13 @@ class DeflateFrame(Extension): ...@@ -57,13 +60,13 @@ class DeflateFrame(Extension):
frame.rsv1 = False frame.rsv1 = False
frame.payload = self.inflate(frame.payload) frame.payload = self.inflate(frame.payload)
def deflate(self, frame): def deflate(self, data):
if self.no_context_takeover: if self.no_context_takeover:
compressed = zlib.compress(frame.payload) self.defl = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
else: zlib.DEFLATED, -self.max_window_bits)
compressed = self.defl.compress(frame.payload)
compressed += self.defl.flush(zlib.Z_SYNC_FLUSH)
compressed = self.defl.compress(data)
compressed += self.defl.flush(zlib.Z_SYNC_FLUSH)
assert compressed[-4:] == '\x00\x00\xff\xff' assert compressed[-4:] == '\x00\x00\xff\xff'
return compressed[:-4] return compressed[:-4]
...@@ -71,7 +74,6 @@ class DeflateFrame(Extension): ...@@ -71,7 +74,6 @@ class DeflateFrame(Extension):
data = str(data + '\x00\x00\xff\xff') data = str(data + '\x00\x00\xff\xff')
if self.no_context_takeover: if self.no_context_takeover:
dec = zlib.decompressobj(-self.max_window_bits) self.dec = zlib.decompressobj(-self.max_window_bits)
return dec.decompress(data) + dec.flush()
return self.dec.decompress(data) return self.dec.decompress(data)
import zlib
from extension import Extension
from deflate_frame import DeflateFrame
class DeflateMessage(Extension):
"""
Implementation of the "permessage-deflate" extension, as defined by
http://tools.ietf.org/html/draft-ietf-hybi-permessage-compression-17.
Note: this implementetion is only eligible for server sockets, client
sockets must NOT use it.
"""
name = 'permessage-deflate'
rsv1 = True
defaults = {
'client_max_window_bits': zlib.MAX_WBITS,
'client_no_context_takeover': False,
'server_max_window_bits': zlib.MAX_WBITS,
'server_no_context_takeover': False
}
before_fragmentation = True
compression_threshold = 20 # minimal message payload size for compression
def negotiate(self, name, params):
default = self.defaults['client_max_window_bits']
if 'client_max_window_bits' in params:
mwb = params['client_max_window_bits']
if mwb is True:
if default != zlib.MAX_WBITS:
yield 'client_max_window_bits', default
else:
mwb = int(mwb)
assert 8 <= mwb <= zlib.MAX_WBITS
yield 'client_max_window_bits', min(mwb, default)
elif default != zlib.MAX_WBITS:
yield 'client_max_window_bits', default
if 'client_no_context_takeover' in params:
assert params['client_no_context_takeover'] is True
yield 'client_no_context_takeover', True
elif self.defaults['client_no_context_takeover']:
yield 'client_no_context_takeover', True
default = self.defaults['server_max_window_bits']
if 'server_max_window_bits' in params:
mwb = int(params['server_max_window_bits'])
assert 8 <= mwb <= zlib.MAX_WBITS
yield 'server_max_window_bits', min(mwb, default)
elif default != zlib.MAX_WBITS:
yield 'server_max_window_bits', default
if 'server_no_context_takeover' in params:
assert params['server_no_context_takeover'] is True
yield 'server_no_context_takeover', True
elif self.defaults['server_no_context_takeover']:
yield 'server_no_context_takeover', True
class Instance(DeflateFrame.Instance):
def init(self):
if not self.server_no_context_takeover:
self.defl = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
zlib.DEFLATED, -self.server_max_window_bits)
if not self.client_no_context_takeover:
self.dec = zlib.decompressobj(-self.client_max_window_bits)
def deflate(self, data):
if self.server_no_context_takeover:
self.defl = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
zlib.DEFLATED, -self.server_max_window_bits)
compressed = self.defl.compress(data)
compressed += self.defl.flush(zlib.Z_SYNC_FLUSH)
assert compressed[-4:] == '\x00\x00\xff\xff'
return compressed[:-4]
def inflate(self, data):
data = str(data + '\x00\x00\xff\xff')
if self.client_no_context_takeover:
self.dec = zlib.decompressobj(-self.client_max_window_bits)
return self.dec.decompress(data)
...@@ -4,6 +4,7 @@ class Extension(object): ...@@ -4,6 +4,7 @@ class Extension(object):
rsv2 = False rsv2 = False
rsv3 = False rsv3 = False
opcodes = () opcodes = ()
before_fragmentation = False
defaults = {} defaults = {}
def __init__(self, **kwargs): def __init__(self, **kwargs):
...@@ -23,6 +24,10 @@ class Extension(object): ...@@ -23,6 +24,10 @@ class Extension(object):
def names(self): def names(self):
return (self.name,) if self.name else () return (self.name,) if self.name else ()
def is_supported(self, name, other_instances):
return name in self.names and not any(self.conflicts(other.extension)
for other in other_instances)
def conflicts(self, ext): def conflicts(self, ext):
""" """
Check if the extension conflicts with an already accepted extension. Check if the extension conflicts with an already accepted extension.
...@@ -76,14 +81,22 @@ class Extension(object): ...@@ -76,14 +81,22 @@ class Extension(object):
def init(self): def init(self):
return NotImplemented return NotImplemented
def onsend_frame(self, frame): def handle_send(self, frame):
pass if self.extension.before_fragmentation:
assert not frame.is_fragmented()
def onrecv_frame(self, frame): replacement = self.onsend(frame)
pass return frame if replacement is None else replacement
def onsend_message(self, message): def handle_recv(self, frame):
pass if self.extension.before_fragmentation:
assert not frame.is_fragmented()
def onrecv_message(self, message): replacement = self.onrecv(frame)
pass return frame if replacement is None else replacement
def onsend(self, frame):
raise NotImplementedError
def onrecv(self, frame):
raise NotImplementedError
...@@ -143,6 +143,9 @@ class Frame(object): ...@@ -143,6 +143,9 @@ class Frame(object):
return frames return frames
def is_fragmented(self):
return not self.final or self.opcode == OPCODE_CONTINUATION
def __str__(self): def __str__(self):
s = '<%s opcode=0x%X len=%d' \ s = '<%s opcode=0x%X len=%d' \
% (self.__class__.__name__, self.opcode, len(self.payload)) % (self.__class__.__name__, self.opcode, len(self.payload))
......
...@@ -177,8 +177,7 @@ class ServerHandshake(Handshake): ...@@ -177,8 +177,7 @@ class ServerHandshake(Handshake):
name, params = parse_param_hdr(hdr) name, params = parse_param_hdr(hdr)
for ext in ssock.extensions: for ext in ssock.extensions:
if not any(ext.conflicts(other.extension) if ext.is_supported(name, self.wsock.extension_instances):
for other in self.wsock.extension_instances):
accept_params = ext.negotiate_safe(name, params) accept_params = ext.negotiate_safe(name, params)
if accept_params is not None: if accept_params is not None:
...@@ -432,4 +431,4 @@ def format_param_hdr(value, params): ...@@ -432,4 +431,4 @@ def format_param_hdr(value, params):
return k + '=' + str(v) return k + '=' + str(v)
strparams = filter(None, map(fmt_param, params.items())) strparams = filter(None, map(fmt_param, params.items()))
return '%s; %s' % (value, ', '.join(strparams)) return '; '.join([value] + strparams)
...@@ -131,21 +131,17 @@ class websocket(object): ...@@ -131,21 +131,17 @@ class websocket(object):
ClientHandshake(self).perform() ClientHandshake(self).perform()
self.handshake_sent = True self.handshake_sent = True
def apply_send_hooks(self, frame): def apply_send_hooks(self, frame, before_fragmentation):
for inst in self.extension_instances: for inst in self.extension_instances:
replacement = inst.onsend_frame(frame) if inst.extension.before_fragmentation == before_fragmentation:
frame = inst.handle_send(frame)
if replacement is not None:
frame = replacement
return frame return frame
def apply_recv_hooks(self, frame): def apply_recv_hooks(self, frame, before_fragmentation):
for inst in reversed(self.extension_instances): for inst in reversed(self.extension_instances):
replacement = inst.onrecv_frame(frame) if inst.extension.before_fragmentation == before_fragmentation:
frame = inst.handle_recv(frame)
if replacement is not None:
frame = replacement
return frame return frame
...@@ -154,14 +150,14 @@ class websocket(object): ...@@ -154,14 +150,14 @@ class websocket(object):
Send a number of frames. Send a number of frames.
""" """
for frame in args: for frame in args:
self.sock.sendall(self.apply_send_hooks(frame).pack()) self.sock.sendall(self.apply_send_hooks(frame, False).pack())
def recv(self): def recv(self):
""" """
Receive a single frames. This can be either a data frame or a control Receive a single frames. This can be either a data frame or a control
frame. frame.
""" """
return self.apply_recv_hooks(receive_frame(self.sock)) return self.apply_recv_hooks(receive_frame(self.sock), False)
def recvn(self, n): def recvn(self, n):
""" """
...@@ -177,7 +173,7 @@ class websocket(object): ...@@ -177,7 +173,7 @@ class websocket(object):
frame has been fully written. `recv_callback` is an optional callable frame has been fully written. `recv_callback` is an optional callable
to quickly set the `recv_callback` attribute to. to quickly set the `recv_callback` attribute to.
""" """
frame = self.apply_send_hooks(frame) frame = self.apply_send_hooks(frame, False)
self.sendbuf += frame.pack() self.sendbuf += frame.pack()
self.sendbuf_frames.append([frame, len(self.sendbuf), callback]) self.sendbuf_frames.append([frame, len(self.sendbuf), callback])
...@@ -222,7 +218,7 @@ class websocket(object): ...@@ -222,7 +218,7 @@ class websocket(object):
while contains_frame(self.recvbuf): while contains_frame(self.recvbuf):
frame, self.recvbuf = pop_frame(self.recvbuf) frame, self.recvbuf = pop_frame(self.recvbuf)
frame = self.apply_recv_hooks(frame) frame = self.apply_recv_hooks(frame, False)
if not self.recv_callback: if not self.recv_callback:
raise ValueError('no callback installed for %s' % frame) raise ValueError('no callback installed for %s' % frame)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment