Commit 9232e5d4 authored by Taddeüs Kroes's avatar Taddeüs Kroes

Rewrote extensions API + reimplemented deflate-frame

parent 6c79550e
......@@ -10,5 +10,5 @@ from connection import Connection
from message import Message, TextMessage, BinaryMessage
from errors import SocketClosed, HandshakeError, PingError, SSLError
from extension import Extension
from deflate_frame import DeflateFrame, WebkitDeflateFrame
from deflate_frame import DeflateFrame
from async import AsyncConnection, AsyncServer
......@@ -17,40 +17,39 @@ class DeflateFrame(Extension):
Note that the deflate and inflate hooks modify the RSV1 bit and payload of
existing `Frame` objects.
"""
name = 'deflate-frame'
names = ('deflate-frame', 'x-webkit-deflate-frame')
rsv1 = True
defaults = {'max_window_bits': zlib.MAX_WBITS, 'no_context_takeover': False}
COMPRESSION_THRESHOLD = 64 # minimal payload size for compression
def init(self):
mwb = self.defaults['max_window_bits']
cto = self.defaults['no_context_takeover']
if not isinstance(mwb, int) or mwb < 1 or mwb > zlib.MAX_WBITS:
raise ValueError('"max_window_bits" must be in range 1-15')
if cto is not False and cto is not True:
raise ValueError('"no_context_takeover" must have no value')
class Hook(Extension.Hook):
def init(self, extension):
self.defl = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
zlib.DEFLATED, -self.max_window_bits)
other_wbits = extension.request.get('max_window_bits', zlib.MAX_WBITS)
self.dec = zlib.decompressobj(-other_wbits)
defaults = {
'max_window_bits': zlib.MAX_WBITS,
'no_context_takeover': False
}
compression_threshold = 64 # minimal payload size for compression
def negotiate(self, name, params):
if 'max_window_bits' in params:
mwb = int(params['max_window_bits'])
assert 8 <= mwb <= zlib.MAX_WBITS
yield 'max_window_bits', mwb
if 'no_context_takeover' in params:
assert params['no_context_takeover'] is True
yield 'no_context_takeover', True
class Instance(Extension.Instance):
def init(self):
if not self.no_context_takeover:
self.defl = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
zlib.DEFLATED, -self.max_window_bits)
self.dec = zlib.decompressobj(-self.max_window_bits)
def send(self, frame):
# FIXME: this does not seem to work properly on Android
def onsend_frame(self, frame):
if not frame.rsv1 and not isinstance(frame, ControlFrame) and \
len(frame.payload) > DeflateFrame.COMPRESSION_THRESHOLD:
len(frame.payload) > self.extension.compression_threshold:
frame.rsv1 = True
frame.payload = self.deflate(frame)
return frame
def recv(self, frame):
def onrecv_frame(self, frame):
if frame.rsv1:
if isinstance(frame, ControlFrame):
raise ValueError('received compressed control frame')
......@@ -58,26 +57,22 @@ class DeflateFrame(Extension):
frame.rsv1 = False
frame.payload = self.inflate(frame.payload)
return frame
def deflate(self, frame):
compressed = self.defl.compress(frame.payload)
if frame.final or self.no_context_takeover:
compressed += self.defl.flush(zlib.Z_FINISH) + '\x00'
self.defl = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
zlib.DEFLATED, -self.max_window_bits)
if self.no_context_takeover:
print 'no_context_takeover'
compressed = zlib.compress(frame.payload)
else:
compressed = self.defl.compress(frame.payload)
compressed += self.defl.flush(zlib.Z_SYNC_FLUSH)
assert compressed[-4:] == '\x00\x00\xff\xff'
compressed = compressed[:-4]
return compressed
assert compressed[-4:] == '\x00\x00\xff\xff'
return compressed[:-4]
def inflate(self, data):
return self.dec.decompress(data + '\x00\x00\xff\xff') + \
self.dec.flush(zlib.Z_SYNC_FLUSH)
data = str(data + '\x00\x00\xff\xff')
if self.no_context_takeover:
dec = zlib.decompressobj(-self.max_window_bits)
return dec.decompress(data) + dec.flush()
class WebkitDeflateFrame(DeflateFrame):
name = 'x-webkit-deflate-frame'
return self.dec.decompress(data)
......@@ -3,67 +3,87 @@ class Extension(object):
rsv1 = False
rsv2 = False
rsv3 = False
opcodes = []
opcodes = ()
defaults = {}
request = {}
def __init__(self, defaults={}, request={}):
for param in defaults.keys() + request.keys():
def __init__(self, **kwargs):
for param in kwargs.iterkeys():
if param not in self.defaults:
raise KeyError('unrecognized parameter "%s"' % param)
# Copy dict first to avoid duplicate references to the same object
self.defaults = dict(self.__class__.defaults)
self.defaults.update(defaults)
self.request = dict(self.__class__.request)
self.request.update(request)
self.init()
self.defaults.update(kwargs)
def __str__(self):
return '<Extension "%s" defaults=%s request=%s>' \
% (self.name, self.defaults, self.request)
def init(self):
return NotImplemented
@property
def names(self):
return (self.name,) if self.name else ()
def conflicts(self, ext):
"""
Check if the extension conflicts with an already accepted extension.
This may be the case when the two extensions use the same reserved
bits, or have the same name (when the same extension is negotiated
multiple times with different parameters).
"""
return ext.rsv1 and self.rsv1 \
or ext.rsv2 and self.rsv2 \
or ext.rsv3 and self.rsv3 \
or set(ext.names) & set(self.names) \
or set(ext.opcodes) & set(self.opcodes)
def negotiate(self, name, params):
"""
Same as `negotiate_safe`, but instead returns an iterator of (param,
value) tuples and raises an exception on error.
"""
raise NotImplementedError
def negotiate_safe(self, name, params):
"""
`name` and `params` are sent in the HTTP request by the client. Check
if the extension name is supported by this extension, and validate the
parameters. Returns a dict with accepted parameters, or None if not
accepted.
"""
for param in params.iterkeys():
if param not in self.defaults:
return
try:
return dict(self.negotiate(name, params))
except (KeyError, ValueError, AssertionError):
pass
def create_hook(self, **kwargs):
params = {}
params.update(self.defaults)
params.update(kwargs)
hook = self.Hook(**params)
hook.init(self)
return hook
class Instance:
def __init__(self, extension, name, params):
self.extension = extension
self.name = name
self.params = params
class Hook:
def __init__(self, **kwargs):
for param, value in kwargs.iteritems():
for param, value in extension.defaults.iteritems():
setattr(self, param, value)
def init(self, extension):
return NotImplemented
for param, value in params.iteritems():
setattr(self, param, value)
def send(self, frame):
return frame
self.init()
def recv(self, frame):
return frame
def init(self):
return NotImplemented
def onsend_frame(self, frame):
pass
def extension_conflicts(ext, existing):
rsv1_reserved = False
rsv2_reserved = False
rsv3_reserved = False
reserved_opcodes = []
def onrecv_frame(self, frame):
pass
for e in existing:
rsv1_reserved |= e.rsv1
rsv2_reserved |= e.rsv2
rsv3_reserved |= e.rsv3
reserved_opcodes.extend(e.opcodes)
def onsend_message(self, message):
pass
return ext.rsv1 and rsv1_reserved \
or ext.rsv2 and rsv2_reserved \
or ext.rsv3 and rsv3_reserved \
or len(set(ext.opcodes) & set(reserved_opcodes))
def onrecv_message(self, message):
pass
......@@ -7,7 +7,6 @@ from hashlib import sha1
from urlparse import urlparse
from errors import HandshakeError
from extension import extension_conflicts
from python_digest import build_authorization_request
......@@ -172,20 +171,19 @@ class ServerHandshake(Handshake):
# Only supported extensions are returned
if 'Sec-WebSocket-Extensions' in headers:
supported_ext = dict((e.name, e) for e in ssock.extensions)
self.wsock.extension_hooks = []
extensions = []
self.wsock.extension_instances = []
for ext in split_stripped(headers['Sec-WebSocket-Extensions']):
name, params = parse_param_hdr(ext)
for hdr in split_stripped(headers['Sec-WebSocket-Extensions']):
name, params = parse_param_hdr(hdr)
if name in supported_ext:
ext = supported_ext[name]
for ext in ssock.extensions:
if not any(ext.conflicts(other.extension)
for other in self.wsock.extension_instances):
accept_params = ext.negotiate_safe(name, params)
if not extension_conflicts(ext, extensions):
extensions.append(ext)
hook = ext.create_hook(**params)
self.wsock.extension_hooks.append(hook)
if accept_params is not None:
instance = ext.Instance(ext, name, accept_params)
self.wsock.extension_instances.append(instance)
# Check if requested resource location is served by this server
if ssock.locations:
......@@ -222,9 +220,9 @@ class ServerHandshake(Handshake):
if self.wsock.protocol:
yield 'Sec-WebSocket-Protocol', self.wsock.protocol
if self.wsock.extensions:
values = [format_param_hdr(e.name, e.request)
for e in self.wsock.extensions]
if self.wsock.extension_instances:
values = [format_param_hdr(i.name, i.params)
for i in self.wsock.extension_instances]
yield 'Sec-WebSocket-Extensions', ', '.join(values)
......@@ -273,19 +271,23 @@ class ClientHandshake(Handshake):
# Compare extensions, add hooks only for those returned by server
if 'Sec-WebSocket-Extensions' in headers:
supported_ext = dict((e.name, e) for e in self.wsock.extensions)
self.wsock.extension_hooks = []
for ext in split_stripped(headers['Sec-WebSocket-Extensions']):
name, params = parse_param_hdr(ext)
if name not in supported_ext:
# FIXME: there is no distinction between server/client extension
# instances, while the extension instance may assume it belongs to
# a server, leading to undefined behavior
self.wsock.extension_instances = []
for hdr in split_stripped(headers['Sec-WebSocket-Extensions']):
name, accept_params = parse_param_hdr(hdr)
for ext in self.wsock.extensions:
if name in ext.names:
instance = ext.Instance(ext, name, accept_params)
self.wsock.extension_instances.append(instance)
break
else:
raise HandshakeError('server handshake contains '
'unsupported extension "%s"' % name)
hook = supported_ext[name].create_hook(**params)
self.wsock.extension_hooks.append(hook)
# Assert that returned protocol (if any) is supported
if 'Sec-WebSocket-Protocol' in headers:
protocol = headers['Sec-WebSocket-Protocol']
......
......@@ -9,7 +9,6 @@ sys.path.insert(0, basepath)
from websocket import websocket
from connection import Connection
from message import TextMessage
from errors import SocketClosed
ADDR = ('localhost', 8000)
......@@ -30,9 +29,8 @@ class EchoClient(Connection):
print 'Connection closed'
secure = True
if __name__ == '__main__':
secure = '-s' in sys.argv[1:]
scheme = 'wss' if secure else 'ws'
print 'Connecting to %s://%s' % (scheme, '%s:%d' % ADDR)
sock = websocket()
......
......@@ -7,7 +7,7 @@ basepath = abspath(dirname(abspath(__file__)) + '/..')
sys.path.insert(0, basepath)
from server import Server
from deflate_frame import WebkitDeflateFrame
from deflate_frame import DeflateFrame
class EchoServer(Server):
......@@ -17,8 +17,8 @@ class EchoServer(Server):
if __name__ == '__main__':
deflate = WebkitDeflateFrame()
#deflate = WebkitDeflateFrame(defaults={'no_context_takeover': True})
EchoServer(('localhost', 8000), extensions=[deflate],
EchoServer(('localhost', 8000),
#extensions=[DeflateFrame(no_context_takeover=True)],
extensions=[DeflateFrame()],
#ssl_args=dict(keyfile='cert.pem', certfile='cert.pem'),
loglevel=logging.DEBUG).run()
......@@ -81,7 +81,7 @@ class websocket(object):
"""
self.protocols = protocols
self.extensions = extensions
self.extension_hooks = []
self.extension_instances = []
self.origin = origin
self.location = location
self.trusted_origins = trusted_origins
......@@ -99,9 +99,6 @@ class websocket(object):
self.sock = sock or socket.socket(sfamily, socket.SOCK_STREAM, sproto)
def set_extensions(self, extensions):
self.extensions = [ext.Hook() for ext in extensions]
def __getattr__(self, name):
if name in INHERITED_ATTRS:
return getattr(self.sock, name)
......@@ -135,14 +132,20 @@ class websocket(object):
self.handshake_sent = True
def apply_send_hooks(self, frame):
for hook in self.extension_hooks:
frame = hook.send(frame)
for inst in self.extension_instances:
replacement = inst.onsend_frame(frame)
if replacement is not None:
frame = replacement
return frame
def apply_recv_hooks(self, frame):
for hook in reversed(self.extension_hooks):
frame = hook.recv(frame)
for inst in reversed(self.extension_instances):
replacement = inst.onrecv_frame(frame)
if replacement is not None:
frame = replacement
return frame
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment