|
|
@@ -1,5 +1,8 @@
|
|
|
import zlib
|
|
|
|
|
|
+from frame import ControlFrame
|
|
|
+from errors import SocketClosed
|
|
|
+
|
|
|
|
|
|
class Extension(object):
|
|
|
name = ''
|
|
|
@@ -22,7 +25,7 @@ class Extension(object):
|
|
|
self.request = dict(self.__class__.request)
|
|
|
self.request.update(request)
|
|
|
|
|
|
- def __str__(self, frame):
|
|
|
+ def __str__(self):
|
|
|
return '<Extension "%s" defaults=%s request=%s>' \
|
|
|
% (self.name, self.defaults, self.request)
|
|
|
|
|
|
@@ -30,10 +33,12 @@ class Extension(object):
|
|
|
params = {}
|
|
|
params.update(self.defaults)
|
|
|
params.update(kwargs)
|
|
|
- return self.Hook(**params)
|
|
|
+ return self.Hook(self, **params)
|
|
|
|
|
|
class Hook:
|
|
|
- def __init__(self, **kwargs):
|
|
|
+ def __init__(self, extension, **kwargs):
|
|
|
+ self.extension = extension
|
|
|
+
|
|
|
for param, value in kwargs.iteritems():
|
|
|
setattr(self, param, value)
|
|
|
|
|
|
@@ -61,7 +66,7 @@ class DeflateFrame(Extension):
|
|
|
name = 'deflate-frame'
|
|
|
rsv1 = True
|
|
|
# FIXME: is 32768 (below) correct?
|
|
|
- defaults = {'max_window_bits': 15, 'no_context_takeover': True}
|
|
|
+ defaults = {'max_window_bits': 15, 'no_context_takeover': False}
|
|
|
|
|
|
def __init__(self, defaults={}, request={}):
|
|
|
Extension.__init__(self, defaults, request)
|
|
|
@@ -78,24 +83,19 @@ class DeflateFrame(Extension):
|
|
|
raise ValueError('"no_context_takeover" must have no value')
|
|
|
|
|
|
class Hook(Extension.Hook):
|
|
|
- def __init__(self, **kwargs):
|
|
|
- Extension.Hook.__init__(**kwargs)
|
|
|
-
|
|
|
- other_wbits = self.request.get('max_window_bits', 15)
|
|
|
-
|
|
|
- # Don't request default value of max_window_bits
|
|
|
- if 'max_window_bits' in self.request and other_wbits == 15:
|
|
|
- del self.request['max_window_bits']
|
|
|
+ def __init__(self, extension, **kwargs):
|
|
|
+ Extension.Hook.__init__(self, extension, **kwargs)
|
|
|
|
|
|
if not self.no_context_takeover:
|
|
|
- self.com = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
|
|
|
+ self.defl = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
|
|
|
zlib.DEFLATED,
|
|
|
- self.max_window_bits)
|
|
|
+ -self.max_window_bits)
|
|
|
|
|
|
- self.dec = zlib.decompressobj(other_wbits)
|
|
|
+ other_wbits = self.extension.request.get('max_window_bits', 15)
|
|
|
+ self.dec = zlib.decompressobj(-other_wbits)
|
|
|
|
|
|
def send(self, frame):
|
|
|
- if not frame.rsv1:
|
|
|
+ if not frame.rsv1 and not isinstance(frame, ControlFrame):
|
|
|
frame.rsv1 = True
|
|
|
frame.payload = self.deflate(frame.payload)
|
|
|
|
|
|
@@ -103,6 +103,9 @@ class DeflateFrame(Extension):
|
|
|
|
|
|
def recv(self, frame):
|
|
|
if frame.rsv1:
|
|
|
+ if isinstance(frame, ControlFrame):
|
|
|
+ raise SocketClosed('received compressed control frame')
|
|
|
+
|
|
|
frame.rsv1 = False
|
|
|
frame.payload = self.inflate(frame.payload)
|
|
|
|
|
|
@@ -110,17 +113,21 @@ class DeflateFrame(Extension):
|
|
|
|
|
|
def deflate(self, data):
|
|
|
if self.no_context_takeover:
|
|
|
- if self.max_window_bits == 15:
|
|
|
- return zlib.compress(data)
|
|
|
-
|
|
|
- self.com = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
|
|
|
- zlib.DEFLATED,
|
|
|
- self.max_window_bits)
|
|
|
+ defl = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
|
|
|
+ zlib.DEFLATED, -self.max_window_bits)
|
|
|
+ # FIXME: why the '\x00' below? This was borrowed from
|
|
|
+ # https://github.com/fancycode/tornado/blob/bc317b6dcf63608ff004ff1f57073be0504b6550/tornado/websocket.py#L91
|
|
|
+ return defl.compress(data) + defl.flush(zlib.Z_FINISH) + '\x00'
|
|
|
|
|
|
- return self.com.compress(data)
|
|
|
+ compressed = self.defl.compress(data)
|
|
|
+ compressed += self.defl.flush(zlib.Z_SYNC_FLUSH)
|
|
|
+ assert compressed[-4:] == '\x00\x00\xff\xff'
|
|
|
+ return compressed[:-4]
|
|
|
|
|
|
def inflate(self, data):
|
|
|
- return self.dec.decompress(data)
|
|
|
+ data = self.dec.decompress(str(data + '\x00\x00\xff\xff'))
|
|
|
+ assert not self.dec.unused_data
|
|
|
+ return data
|
|
|
|
|
|
|
|
|
class Multiplex(Extension):
|