Commit 9232e5d4 authored by Taddeüs Kroes's avatar Taddeüs Kroes

Rewrote extensions API + reimplemented deflate-frame

parent 6c79550e
...@@ -10,5 +10,5 @@ from connection import Connection ...@@ -10,5 +10,5 @@ from connection import Connection
from message import Message, TextMessage, BinaryMessage from message import Message, TextMessage, BinaryMessage
from errors import SocketClosed, HandshakeError, PingError, SSLError from errors import SocketClosed, HandshakeError, PingError, SSLError
from extension import Extension from extension import Extension
from deflate_frame import DeflateFrame, WebkitDeflateFrame from deflate_frame import DeflateFrame
from async import AsyncConnection, AsyncServer from async import AsyncConnection, AsyncServer
...@@ -17,40 +17,39 @@ class DeflateFrame(Extension): ...@@ -17,40 +17,39 @@ class DeflateFrame(Extension):
Note that the deflate and inflate hooks modify the RSV1 bit and payload of Note that the deflate and inflate hooks modify the RSV1 bit and payload of
existing `Frame` objects. existing `Frame` objects.
""" """
names = ('deflate-frame', 'x-webkit-deflate-frame')
name = 'deflate-frame'
rsv1 = True rsv1 = True
defaults = {'max_window_bits': zlib.MAX_WBITS, 'no_context_takeover': False} defaults = {
'max_window_bits': zlib.MAX_WBITS,
COMPRESSION_THRESHOLD = 64 # minimal payload size for compression 'no_context_takeover': False
}
def init(self):
mwb = self.defaults['max_window_bits'] compression_threshold = 64 # minimal payload size for compression
cto = self.defaults['no_context_takeover']
def negotiate(self, name, params):
if not isinstance(mwb, int) or mwb < 1 or mwb > zlib.MAX_WBITS: if 'max_window_bits' in params:
raise ValueError('"max_window_bits" must be in range 1-15') mwb = int(params['max_window_bits'])
assert 8 <= mwb <= zlib.MAX_WBITS
if cto is not False and cto is not True: yield 'max_window_bits', mwb
raise ValueError('"no_context_takeover" must have no value')
if 'no_context_takeover' in params:
class Hook(Extension.Hook): assert params['no_context_takeover'] is True
def init(self, extension): yield 'no_context_takeover', True
self.defl = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
zlib.DEFLATED, -self.max_window_bits) class Instance(Extension.Instance):
other_wbits = extension.request.get('max_window_bits', zlib.MAX_WBITS) def init(self):
self.dec = zlib.decompressobj(-other_wbits) if not self.no_context_takeover:
self.defl = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
zlib.DEFLATED, -self.max_window_bits)
self.dec = zlib.decompressobj(-self.max_window_bits)
def send(self, frame): def onsend_frame(self, frame):
# FIXME: this does not seem to work properly on Android
if not frame.rsv1 and not isinstance(frame, ControlFrame) and \ if not frame.rsv1 and not isinstance(frame, ControlFrame) and \
len(frame.payload) > DeflateFrame.COMPRESSION_THRESHOLD: len(frame.payload) > self.extension.compression_threshold:
frame.rsv1 = True frame.rsv1 = True
frame.payload = self.deflate(frame) frame.payload = self.deflate(frame)
return frame def onrecv_frame(self, frame):
def recv(self, frame):
if frame.rsv1: if frame.rsv1:
if isinstance(frame, ControlFrame): if isinstance(frame, ControlFrame):
raise ValueError('received compressed control frame') raise ValueError('received compressed control frame')
...@@ -58,26 +57,22 @@ class DeflateFrame(Extension): ...@@ -58,26 +57,22 @@ class DeflateFrame(Extension):
frame.rsv1 = False frame.rsv1 = False
frame.payload = self.inflate(frame.payload) frame.payload = self.inflate(frame.payload)
return frame
def deflate(self, frame): def deflate(self, frame):
compressed = self.defl.compress(frame.payload) if self.no_context_takeover:
print 'no_context_takeover'
if frame.final or self.no_context_takeover: compressed = zlib.compress(frame.payload)
compressed += self.defl.flush(zlib.Z_FINISH) + '\x00'
self.defl = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
zlib.DEFLATED, -self.max_window_bits)
else: else:
compressed = self.defl.compress(frame.payload)
compressed += self.defl.flush(zlib.Z_SYNC_FLUSH) compressed += self.defl.flush(zlib.Z_SYNC_FLUSH)
assert compressed[-4:] == '\x00\x00\xff\xff'
compressed = compressed[:-4]
return compressed assert compressed[-4:] == '\x00\x00\xff\xff'
return compressed[:-4]
def inflate(self, data): def inflate(self, data):
return self.dec.decompress(data + '\x00\x00\xff\xff') + \ data = str(data + '\x00\x00\xff\xff')
self.dec.flush(zlib.Z_SYNC_FLUSH)
if self.no_context_takeover:
dec = zlib.decompressobj(-self.max_window_bits)
return dec.decompress(data) + dec.flush()
class WebkitDeflateFrame(DeflateFrame): return self.dec.decompress(data)
name = 'x-webkit-deflate-frame'
...@@ -3,67 +3,87 @@ class Extension(object): ...@@ -3,67 +3,87 @@ class Extension(object):
rsv1 = False rsv1 = False
rsv2 = False rsv2 = False
rsv3 = False rsv3 = False
opcodes = [] opcodes = ()
defaults = {} defaults = {}
request = {}
def __init__(self, defaults={}, request={}): def __init__(self, **kwargs):
for param in defaults.keys() + request.keys(): for param in kwargs.iterkeys():
if param not in self.defaults: if param not in self.defaults:
raise KeyError('unrecognized parameter "%s"' % param) raise KeyError('unrecognized parameter "%s"' % param)
# Copy dict first to avoid duplicate references to the same object # Copy dict first to avoid duplicate references to the same object
self.defaults = dict(self.__class__.defaults) self.defaults = dict(self.__class__.defaults)
self.defaults.update(defaults) self.defaults.update(kwargs)
self.request = dict(self.__class__.request)
self.request.update(request)
self.init()
def __str__(self): def __str__(self):
return '<Extension "%s" defaults=%s request=%s>' \ return '<Extension "%s" defaults=%s request=%s>' \
% (self.name, self.defaults, self.request) % (self.name, self.defaults, self.request)
def init(self): @property
return NotImplemented def names(self):
return (self.name,) if self.name else ()
def conflicts(self, ext):
"""
Check if the extension conflicts with an already accepted extension.
This may be the case when the two extensions use the same reserved
bits, or have the same name (when the same extension is negotiated
multiple times with different parameters).
"""
return ext.rsv1 and self.rsv1 \
or ext.rsv2 and self.rsv2 \
or ext.rsv3 and self.rsv3 \
or set(ext.names) & set(self.names) \
or set(ext.opcodes) & set(self.opcodes)
def negotiate(self, name, params):
"""
Same as `negotiate_safe`, but instead returns an iterator of (param,
value) tuples and raises an exception on error.
"""
raise NotImplementedError
def negotiate_safe(self, name, params):
"""
`name` and `params` are sent in the HTTP request by the client. Check
if the extension name is supported by this extension, and validate the
parameters. Returns a dict with accepted parameters, or None if not
accepted.
"""
for param in params.iterkeys():
if param not in self.defaults:
return
try:
return dict(self.negotiate(name, params))
except (KeyError, ValueError, AssertionError):
pass
def create_hook(self, **kwargs): class Instance:
params = {} def __init__(self, extension, name, params):
params.update(self.defaults) self.extension = extension
params.update(kwargs) self.name = name
hook = self.Hook(**params) self.params = params
hook.init(self)
return hook
class Hook: for param, value in extension.defaults.iteritems():
def __init__(self, **kwargs):
for param, value in kwargs.iteritems():
setattr(self, param, value) setattr(self, param, value)
def init(self, extension): for param, value in params.iteritems():
return NotImplemented setattr(self, param, value)
def send(self, frame): self.init()
return frame
def recv(self, frame): def init(self):
return frame return NotImplemented
def onsend_frame(self, frame):
pass
def extension_conflicts(ext, existing): def onrecv_frame(self, frame):
rsv1_reserved = False pass
rsv2_reserved = False
rsv3_reserved = False
reserved_opcodes = []
for e in existing: def onsend_message(self, message):
rsv1_reserved |= e.rsv1 pass
rsv2_reserved |= e.rsv2
rsv3_reserved |= e.rsv3
reserved_opcodes.extend(e.opcodes)
return ext.rsv1 and rsv1_reserved \ def onrecv_message(self, message):
or ext.rsv2 and rsv2_reserved \ pass
or ext.rsv3 and rsv3_reserved \
or len(set(ext.opcodes) & set(reserved_opcodes))
...@@ -7,7 +7,6 @@ from hashlib import sha1 ...@@ -7,7 +7,6 @@ from hashlib import sha1
from urlparse import urlparse from urlparse import urlparse
from errors import HandshakeError from errors import HandshakeError
from extension import extension_conflicts
from python_digest import build_authorization_request from python_digest import build_authorization_request
...@@ -172,20 +171,19 @@ class ServerHandshake(Handshake): ...@@ -172,20 +171,19 @@ class ServerHandshake(Handshake):
# Only supported extensions are returned # Only supported extensions are returned
if 'Sec-WebSocket-Extensions' in headers: if 'Sec-WebSocket-Extensions' in headers:
supported_ext = dict((e.name, e) for e in ssock.extensions) self.wsock.extension_instances = []
self.wsock.extension_hooks = []
extensions = []
for ext in split_stripped(headers['Sec-WebSocket-Extensions']): for hdr in split_stripped(headers['Sec-WebSocket-Extensions']):
name, params = parse_param_hdr(ext) name, params = parse_param_hdr(hdr)
if name in supported_ext: for ext in ssock.extensions:
ext = supported_ext[name] if not any(ext.conflicts(other.extension)
for other in self.wsock.extension_instances):
accept_params = ext.negotiate_safe(name, params)
if not extension_conflicts(ext, extensions): if accept_params is not None:
extensions.append(ext) instance = ext.Instance(ext, name, accept_params)
hook = ext.create_hook(**params) self.wsock.extension_instances.append(instance)
self.wsock.extension_hooks.append(hook)
# Check if requested resource location is served by this server # Check if requested resource location is served by this server
if ssock.locations: if ssock.locations:
...@@ -222,9 +220,9 @@ class ServerHandshake(Handshake): ...@@ -222,9 +220,9 @@ class ServerHandshake(Handshake):
if self.wsock.protocol: if self.wsock.protocol:
yield 'Sec-WebSocket-Protocol', self.wsock.protocol yield 'Sec-WebSocket-Protocol', self.wsock.protocol
if self.wsock.extensions: if self.wsock.extension_instances:
values = [format_param_hdr(e.name, e.request) values = [format_param_hdr(i.name, i.params)
for e in self.wsock.extensions] for i in self.wsock.extension_instances]
yield 'Sec-WebSocket-Extensions', ', '.join(values) yield 'Sec-WebSocket-Extensions', ', '.join(values)
...@@ -273,19 +271,23 @@ class ClientHandshake(Handshake): ...@@ -273,19 +271,23 @@ class ClientHandshake(Handshake):
# Compare extensions, add hooks only for those returned by server # Compare extensions, add hooks only for those returned by server
if 'Sec-WebSocket-Extensions' in headers: if 'Sec-WebSocket-Extensions' in headers:
supported_ext = dict((e.name, e) for e in self.wsock.extensions) # FIXME: there is no distinction between server/client extension
self.wsock.extension_hooks = [] # instances, while the extension instance may assume it belongs to
# a server, leading to undefined behavior
for ext in split_stripped(headers['Sec-WebSocket-Extensions']): self.wsock.extension_instances = []
name, params = parse_param_hdr(ext)
for hdr in split_stripped(headers['Sec-WebSocket-Extensions']):
if name not in supported_ext: name, accept_params = parse_param_hdr(hdr)
for ext in self.wsock.extensions:
if name in ext.names:
instance = ext.Instance(ext, name, accept_params)
self.wsock.extension_instances.append(instance)
break
else:
raise HandshakeError('server handshake contains ' raise HandshakeError('server handshake contains '
'unsupported extension "%s"' % name) 'unsupported extension "%s"' % name)
hook = supported_ext[name].create_hook(**params)
self.wsock.extension_hooks.append(hook)
# Assert that returned protocol (if any) is supported # Assert that returned protocol (if any) is supported
if 'Sec-WebSocket-Protocol' in headers: if 'Sec-WebSocket-Protocol' in headers:
protocol = headers['Sec-WebSocket-Protocol'] protocol = headers['Sec-WebSocket-Protocol']
......
...@@ -9,7 +9,6 @@ sys.path.insert(0, basepath) ...@@ -9,7 +9,6 @@ sys.path.insert(0, basepath)
from websocket import websocket from websocket import websocket
from connection import Connection from connection import Connection
from message import TextMessage from message import TextMessage
from errors import SocketClosed
ADDR = ('localhost', 8000) ADDR = ('localhost', 8000)
...@@ -30,9 +29,8 @@ class EchoClient(Connection): ...@@ -30,9 +29,8 @@ class EchoClient(Connection):
print 'Connection closed' print 'Connection closed'
secure = True
if __name__ == '__main__': if __name__ == '__main__':
secure = '-s' in sys.argv[1:]
scheme = 'wss' if secure else 'ws' scheme = 'wss' if secure else 'ws'
print 'Connecting to %s://%s' % (scheme, '%s:%d' % ADDR) print 'Connecting to %s://%s' % (scheme, '%s:%d' % ADDR)
sock = websocket() sock = websocket()
......
...@@ -7,7 +7,7 @@ basepath = abspath(dirname(abspath(__file__)) + '/..') ...@@ -7,7 +7,7 @@ basepath = abspath(dirname(abspath(__file__)) + '/..')
sys.path.insert(0, basepath) sys.path.insert(0, basepath)
from server import Server from server import Server
from deflate_frame import WebkitDeflateFrame from deflate_frame import DeflateFrame
class EchoServer(Server): class EchoServer(Server):
...@@ -17,8 +17,8 @@ class EchoServer(Server): ...@@ -17,8 +17,8 @@ class EchoServer(Server):
if __name__ == '__main__': if __name__ == '__main__':
deflate = WebkitDeflateFrame() EchoServer(('localhost', 8000),
#deflate = WebkitDeflateFrame(defaults={'no_context_takeover': True}) #extensions=[DeflateFrame(no_context_takeover=True)],
EchoServer(('localhost', 8000), extensions=[deflate], extensions=[DeflateFrame()],
#ssl_args=dict(keyfile='cert.pem', certfile='cert.pem'), #ssl_args=dict(keyfile='cert.pem', certfile='cert.pem'),
loglevel=logging.DEBUG).run() loglevel=logging.DEBUG).run()
...@@ -81,7 +81,7 @@ class websocket(object): ...@@ -81,7 +81,7 @@ class websocket(object):
""" """
self.protocols = protocols self.protocols = protocols
self.extensions = extensions self.extensions = extensions
self.extension_hooks = [] self.extension_instances = []
self.origin = origin self.origin = origin
self.location = location self.location = location
self.trusted_origins = trusted_origins self.trusted_origins = trusted_origins
...@@ -99,9 +99,6 @@ class websocket(object): ...@@ -99,9 +99,6 @@ class websocket(object):
self.sock = sock or socket.socket(sfamily, socket.SOCK_STREAM, sproto) self.sock = sock or socket.socket(sfamily, socket.SOCK_STREAM, sproto)
def set_extensions(self, extensions):
self.extensions = [ext.Hook() for ext in extensions]
def __getattr__(self, name): def __getattr__(self, name):
if name in INHERITED_ATTRS: if name in INHERITED_ATTRS:
return getattr(self.sock, name) return getattr(self.sock, name)
...@@ -135,14 +132,20 @@ class websocket(object): ...@@ -135,14 +132,20 @@ class websocket(object):
self.handshake_sent = True self.handshake_sent = True
def apply_send_hooks(self, frame): def apply_send_hooks(self, frame):
for hook in self.extension_hooks: for inst in self.extension_instances:
frame = hook.send(frame) replacement = inst.onsend_frame(frame)
if replacement is not None:
frame = replacement
return frame return frame
def apply_recv_hooks(self, frame): def apply_recv_hooks(self, frame):
for hook in reversed(self.extension_hooks): for inst in reversed(self.extension_instances):
frame = hook.recv(frame) replacement = inst.onrecv_frame(frame)
if replacement is not None:
frame = replacement
return frame return frame
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment