Commit 91ec1795 authored by Taddeüs Kroes's avatar Taddeüs Kroes

Fixed deflate-frame implementation

parent d17b2ebb
import zlib import zlib
from frame import ControlFrame
from errors import SocketClosed
class Extension(object): class Extension(object):
name = '' name = ''
...@@ -22,7 +25,7 @@ class Extension(object): ...@@ -22,7 +25,7 @@ class Extension(object):
self.request = dict(self.__class__.request) self.request = dict(self.__class__.request)
self.request.update(request) self.request.update(request)
def __str__(self, frame): def __str__(self):
return '<Extension "%s" defaults=%s request=%s>' \ return '<Extension "%s" defaults=%s request=%s>' \
% (self.name, self.defaults, self.request) % (self.name, self.defaults, self.request)
...@@ -30,10 +33,12 @@ class Extension(object): ...@@ -30,10 +33,12 @@ class Extension(object):
params = {} params = {}
params.update(self.defaults) params.update(self.defaults)
params.update(kwargs) params.update(kwargs)
return self.Hook(**params) return self.Hook(self, **params)
class Hook: class Hook:
def __init__(self, **kwargs): def __init__(self, extension, **kwargs):
self.extension = extension
for param, value in kwargs.iteritems(): for param, value in kwargs.iteritems():
setattr(self, param, value) setattr(self, param, value)
...@@ -61,7 +66,7 @@ class DeflateFrame(Extension): ...@@ -61,7 +66,7 @@ class DeflateFrame(Extension):
name = 'deflate-frame' name = 'deflate-frame'
rsv1 = True rsv1 = True
# FIXME: is 32768 (below) correct? # FIXME: is 32768 (below) correct?
defaults = {'max_window_bits': 15, 'no_context_takeover': True} defaults = {'max_window_bits': 15, 'no_context_takeover': False}
def __init__(self, defaults={}, request={}): def __init__(self, defaults={}, request={}):
Extension.__init__(self, defaults, request) Extension.__init__(self, defaults, request)
...@@ -78,24 +83,19 @@ class DeflateFrame(Extension): ...@@ -78,24 +83,19 @@ class DeflateFrame(Extension):
raise ValueError('"no_context_takeover" must have no value') raise ValueError('"no_context_takeover" must have no value')
class Hook(Extension.Hook): class Hook(Extension.Hook):
def __init__(self, **kwargs): def __init__(self, extension, **kwargs):
Extension.Hook.__init__(**kwargs) Extension.Hook.__init__(self, extension, **kwargs)
other_wbits = self.request.get('max_window_bits', 15)
# Don't request default value of max_window_bits
if 'max_window_bits' in self.request and other_wbits == 15:
del self.request['max_window_bits']
if not self.no_context_takeover: if not self.no_context_takeover:
self.com = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, self.defl = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
zlib.DEFLATED, zlib.DEFLATED,
self.max_window_bits) -self.max_window_bits)
self.dec = zlib.decompressobj(other_wbits) other_wbits = self.extension.request.get('max_window_bits', 15)
self.dec = zlib.decompressobj(-other_wbits)
def send(self, frame): def send(self, frame):
if not frame.rsv1: if not frame.rsv1 and not isinstance(frame, ControlFrame):
frame.rsv1 = True frame.rsv1 = True
frame.payload = self.deflate(frame.payload) frame.payload = self.deflate(frame.payload)
...@@ -103,6 +103,9 @@ class DeflateFrame(Extension): ...@@ -103,6 +103,9 @@ class DeflateFrame(Extension):
def recv(self, frame): def recv(self, frame):
if frame.rsv1: if frame.rsv1:
if isinstance(frame, ControlFrame):
raise SocketClosed('received compressed control frame')
frame.rsv1 = False frame.rsv1 = False
frame.payload = self.inflate(frame.payload) frame.payload = self.inflate(frame.payload)
...@@ -110,17 +113,21 @@ class DeflateFrame(Extension): ...@@ -110,17 +113,21 @@ class DeflateFrame(Extension):
def deflate(self, data): def deflate(self, data):
if self.no_context_takeover: if self.no_context_takeover:
if self.max_window_bits == 15: defl = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
return zlib.compress(data) zlib.DEFLATED, -self.max_window_bits)
# FIXME: why the '\x00' below? This was borrowed from
self.com = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, # https://github.com/fancycode/tornado/blob/bc317b6dcf63608ff004ff1f57073be0504b6550/tornado/websocket.py#L91
zlib.DEFLATED, return defl.compress(data) + defl.flush(zlib.Z_FINISH) + '\x00'
self.max_window_bits)
return self.com.compress(data) compressed = self.defl.compress(data)
compressed += self.defl.flush(zlib.Z_SYNC_FLUSH)
assert compressed[-4:] == '\x00\x00\xff\xff'
return compressed[:-4]
def inflate(self, data): def inflate(self, data):
return self.dec.decompress(data) data = self.dec.decompress(str(data + '\x00\x00\xff\xff'))
assert not self.dec.unused_data
return data
class Multiplex(Extension): class Multiplex(Extension):
......
...@@ -21,6 +21,8 @@ class WebkitDeflateFrame(DeflateFrame): ...@@ -21,6 +21,8 @@ class WebkitDeflateFrame(DeflateFrame):
if __name__ == '__main__': if __name__ == '__main__':
EchoServer(('localhost', 8000), extensions=[WebkitDeflateFrame()], deflate = WebkitDeflateFrame()
#deflate = WebkitDeflateFrame(defaults={'no_context_takeover': True})
EchoServer(('localhost', 8000), extensions=[deflate],
#ssl_args=dict(keyfile='cert.pem', certfile='cert.pem'), #ssl_args=dict(keyfile='cert.pem', certfile='cert.pem'),
loglevel=logging.DEBUG).run() loglevel=logging.DEBUG).run()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment