Bläddra i källkod

Fixed deflate-frame implementation

Taddeus Kroes 12 år sedan
förälder
incheckning
91ec179511
2 ändrade filer med 34 tillägg och 25 borttagningar
  1. 31 24
      extension.py
  2. 3 1
      test/server.py

+ 31 - 24
extension.py

@@ -1,5 +1,8 @@
 import zlib
 
+from frame import ControlFrame
+from errors import SocketClosed
+
 
 class Extension(object):
     name = ''
@@ -22,7 +25,7 @@ class Extension(object):
         self.request = dict(self.__class__.request)
         self.request.update(request)
 
-    def __str__(self, frame):
+    def __str__(self):
         return '<Extension "%s" defaults=%s request=%s>' \
                % (self.name, self.defaults, self.request)
 
@@ -30,10 +33,12 @@ class Extension(object):
         params = {}
         params.update(self.defaults)
         params.update(kwargs)
-        return self.Hook(**params)
+        return self.Hook(self, **params)
 
     class Hook:
-        def __init__(self, **kwargs):
+        def __init__(self, extension, **kwargs):
+            self.extension = extension
+
             for param, value in kwargs.iteritems():
                 setattr(self, param, value)
 
@@ -61,7 +66,7 @@ class DeflateFrame(Extension):
     name = 'deflate-frame'
     rsv1 = True
     # FIXME: is 32768 (below) correct?
-    defaults = {'max_window_bits': 15, 'no_context_takeover': True}
+    defaults = {'max_window_bits': 15, 'no_context_takeover': False}
 
     def __init__(self, defaults={}, request={}):
         Extension.__init__(self, defaults, request)
@@ -78,24 +83,19 @@ class DeflateFrame(Extension):
             raise ValueError('"no_context_takeover" must have no value')
 
     class Hook(Extension.Hook):
-        def __init__(self, **kwargs):
-            Extension.Hook.__init__(**kwargs)
-
-            other_wbits = self.request.get('max_window_bits', 15)
-
-            # Don't request default value of max_window_bits
-            if 'max_window_bits' in self.request and other_wbits == 15:
-                del self.request['max_window_bits']
+        def __init__(self, extension, **kwargs):
+            Extension.Hook.__init__(self, extension, **kwargs)
 
             if not self.no_context_takeover:
-                self.com = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
+                self.defl = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
                                             zlib.DEFLATED,
-                                            self.max_window_bits)
+                                            -self.max_window_bits)
 
-            self.dec = zlib.decompressobj(other_wbits)
+            other_wbits = self.extension.request.get('max_window_bits', 15)
+            self.dec = zlib.decompressobj(-other_wbits)
 
         def send(self, frame):
-            if not frame.rsv1:
+            if not frame.rsv1 and not isinstance(frame, ControlFrame):
                 frame.rsv1 = True
                 frame.payload = self.deflate(frame.payload)
 
@@ -103,6 +103,9 @@ class DeflateFrame(Extension):
 
         def recv(self, frame):
             if frame.rsv1:
+                if isinstance(frame, ControlFrame):
+                    raise SocketClosed('received compressed control frame')
+
                 frame.rsv1 = False
                 frame.payload = self.inflate(frame.payload)
 
@@ -110,17 +113,21 @@ class DeflateFrame(Extension):
 
         def deflate(self, data):
             if self.no_context_takeover:
-                if self.max_window_bits == 15:
-                    return zlib.compress(data)
-
-                self.com = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
-                                            zlib.DEFLATED,
-                                            self.max_window_bits)
+                defl = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
+                                        zlib.DEFLATED, -self.max_window_bits)
+                # FIXME: why the '\x00' below? This was borrowed from
+                # https://github.com/fancycode/tornado/blob/bc317b6dcf63608ff004ff1f57073be0504b6550/tornado/websocket.py#L91
+                return defl.compress(data) + defl.flush(zlib.Z_FINISH) + '\x00'
 
-            return self.com.compress(data)
+            compressed = self.defl.compress(data)
+            compressed += self.defl.flush(zlib.Z_SYNC_FLUSH)
+            assert compressed[-4:] == '\x00\x00\xff\xff'
+            return compressed[:-4]
 
         def inflate(self, data):
-            return self.dec.decompress(data)
+            data = self.dec.decompress(str(data + '\x00\x00\xff\xff'))
+            assert not self.dec.unused_data
+            return data
 
 
 class Multiplex(Extension):

+ 3 - 1
test/server.py

@@ -21,6 +21,8 @@ class WebkitDeflateFrame(DeflateFrame):
 
 
 if __name__ == '__main__':
-    EchoServer(('localhost', 8000), extensions=[WebkitDeflateFrame()],
+    deflate = WebkitDeflateFrame()
+    #deflate = WebkitDeflateFrame(defaults={'no_context_takeover': True})
+    EchoServer(('localhost', 8000), extensions=[deflate],
                #ssl_args=dict(keyfile='cert.pem', certfile='cert.pem'),
                loglevel=logging.DEBUG).run()