浏览代码

Rewrote extensions API + reimplemented deflate-frame

Taddeus Kroes 11 年之前
父节点
当前提交
9232e5d4b4
共有 7 个文件被更改,包括 144 次插入126 次删除
  1. 1 1
      __init__.py
  2. 38 43
      deflate_frame.py
  3. 62 42
      extension.py
  4. 27 25
      handshake.py
  5. 1 3
      test/client.py
  6. 4 4
      test/server.py
  7. 11 8
      websocket.py

+ 1 - 1
__init__.py

@@ -10,5 +10,5 @@ from connection import Connection
 from message import Message, TextMessage, BinaryMessage
 from message import Message, TextMessage, BinaryMessage
 from errors import SocketClosed, HandshakeError, PingError, SSLError
 from errors import SocketClosed, HandshakeError, PingError, SSLError
 from extension import Extension
 from extension import Extension
-from deflate_frame import DeflateFrame, WebkitDeflateFrame
+from deflate_frame import DeflateFrame
 from async import AsyncConnection, AsyncServer
 from async import AsyncConnection, AsyncServer

+ 38 - 43
deflate_frame.py

@@ -17,40 +17,39 @@ class DeflateFrame(Extension):
     Note that the deflate and inflate hooks modify the RSV1 bit and payload of
     Note that the deflate and inflate hooks modify the RSV1 bit and payload of
     existing `Frame` objects.
     existing `Frame` objects.
     """
     """
-
-    name = 'deflate-frame'
+    names = ('deflate-frame', 'x-webkit-deflate-frame')
     rsv1 = True
     rsv1 = True
-    defaults = {'max_window_bits': zlib.MAX_WBITS, 'no_context_takeover': False}
-
-    COMPRESSION_THRESHOLD = 64  # minimal payload size for compression
-
-    def init(self):
-        mwb = self.defaults['max_window_bits']
-        cto = self.defaults['no_context_takeover']
-
-        if not isinstance(mwb, int) or mwb < 1 or mwb > zlib.MAX_WBITS:
-            raise ValueError('"max_window_bits" must be in range 1-15')
-
-        if cto is not False and cto is not True:
-            raise ValueError('"no_context_takeover" must have no value')
-
-    class Hook(Extension.Hook):
-        def init(self, extension):
-            self.defl = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
-                                         zlib.DEFLATED, -self.max_window_bits)
-            other_wbits = extension.request.get('max_window_bits', zlib.MAX_WBITS)
-            self.dec = zlib.decompressobj(-other_wbits)
+    defaults = {
+        'max_window_bits': zlib.MAX_WBITS,
+        'no_context_takeover': False
+    }
+
+    compression_threshold = 64  # minimal payload size for compression
+
+    def negotiate(self, name, params):
+        if 'max_window_bits' in params:
+            mwb = int(params['max_window_bits'])
+            assert 8 <= mwb <= zlib.MAX_WBITS
+            yield 'max_window_bits', mwb
+
+        if 'no_context_takeover' in params:
+            assert params['no_context_takeover'] is True
+            yield 'no_context_takeover', True
+
+    class Instance(Extension.Instance):
+        def init(self):
+            if not self.no_context_takeover:
+                self.defl = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
+                        zlib.DEFLATED, -self.max_window_bits)
+                self.dec = zlib.decompressobj(-self.max_window_bits)
 
 
-        def send(self, frame):
-            # FIXME: this does not seem to work properly on Android
+        def onsend_frame(self, frame):
             if not frame.rsv1 and not isinstance(frame, ControlFrame) and \
             if not frame.rsv1 and not isinstance(frame, ControlFrame) and \
-                   len(frame.payload) > DeflateFrame.COMPRESSION_THRESHOLD:
+                   len(frame.payload) > self.extension.compression_threshold:
                 frame.rsv1 = True
                 frame.rsv1 = True
                 frame.payload = self.deflate(frame)
                 frame.payload = self.deflate(frame)
 
 
-            return frame
-
-        def recv(self, frame):
+        def onrecv_frame(self, frame):
             if frame.rsv1:
             if frame.rsv1:
                 if isinstance(frame, ControlFrame):
                 if isinstance(frame, ControlFrame):
                     raise ValueError('received compressed control frame')
                     raise ValueError('received compressed control frame')
@@ -58,26 +57,22 @@ class DeflateFrame(Extension):
                 frame.rsv1 = False
                 frame.rsv1 = False
                 frame.payload = self.inflate(frame.payload)
                 frame.payload = self.inflate(frame.payload)
 
 
-            return frame
-
         def deflate(self, frame):
         def deflate(self, frame):
-            compressed = self.defl.compress(frame.payload)
-
-            if frame.final or self.no_context_takeover:
-                compressed += self.defl.flush(zlib.Z_FINISH) + '\x00'
-                self.defl = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
-                        zlib.DEFLATED, -self.max_window_bits)
+            if self.no_context_takeover:
+                print 'no_context_takeover'
+                compressed = zlib.compress(frame.payload)
             else:
             else:
+                compressed = self.defl.compress(frame.payload)
                 compressed += self.defl.flush(zlib.Z_SYNC_FLUSH)
                 compressed += self.defl.flush(zlib.Z_SYNC_FLUSH)
-                assert compressed[-4:] == '\x00\x00\xff\xff'
-                compressed = compressed[:-4]
 
 
-            return compressed
+            assert compressed[-4:] == '\x00\x00\xff\xff'
+            return compressed[:-4]
 
 
         def inflate(self, data):
         def inflate(self, data):
-            return self.dec.decompress(data + '\x00\x00\xff\xff') + \
-                   self.dec.flush(zlib.Z_SYNC_FLUSH)
+            data = str(data + '\x00\x00\xff\xff')
 
 
+            if self.no_context_takeover:
+                dec = zlib.decompressobj(-self.max_window_bits)
+                return dec.decompress(data) + dec.flush()
 
 
-class WebkitDeflateFrame(DeflateFrame):
-    name = 'x-webkit-deflate-frame'
+            return self.dec.decompress(data)

+ 62 - 42
extension.py

@@ -3,67 +3,87 @@ class Extension(object):
     rsv1 = False
     rsv1 = False
     rsv2 = False
     rsv2 = False
     rsv3 = False
     rsv3 = False
-    opcodes = []
+    opcodes = ()
     defaults = {}
     defaults = {}
-    request = {}
 
 
-    def __init__(self, defaults={}, request={}):
-        for param in defaults.keys() + request.keys():
+    def __init__(self, **kwargs):
+        for param in kwargs.iterkeys():
             if param not in self.defaults:
             if param not in self.defaults:
                 raise KeyError('unrecognized parameter "%s"' % param)
                 raise KeyError('unrecognized parameter "%s"' % param)
 
 
         # Copy dict first to avoid duplicate references to the same object
         # Copy dict first to avoid duplicate references to the same object
         self.defaults = dict(self.__class__.defaults)
         self.defaults = dict(self.__class__.defaults)
-        self.defaults.update(defaults)
-
-        self.request = dict(self.__class__.request)
-        self.request.update(request)
-
-        self.init()
+        self.defaults.update(kwargs)
 
 
     def __str__(self):
     def __str__(self):
         return '<Extension "%s" defaults=%s request=%s>' \
         return '<Extension "%s" defaults=%s request=%s>' \
                % (self.name, self.defaults, self.request)
                % (self.name, self.defaults, self.request)
 
 
-    def init(self):
-        return NotImplemented
+    @property
+    def names(self):
+        return (self.name,) if self.name else ()
+
+    def conflicts(self, ext):
+        """
+        Check if the extension conflicts with an already accepted extension.
+        This may be the case when the two extensions use the same reserved
+        bits, or have the same name (when the same extension is negotiated
+        multiple times with different parameters).
+        """
+        return ext.rsv1 and self.rsv1 \
+            or ext.rsv2 and self.rsv2 \
+            or ext.rsv3 and self.rsv3 \
+            or set(ext.names) & set(self.names) \
+            or set(ext.opcodes) & set(self.opcodes)
+
+    def negotiate(self, name, params):
+        """
+        Same as `negotiate_safe`, but instead returns an iterator of (param,
+        value) tuples and raises an exception on error.
+        """
+        raise NotImplementedError
+
+    def negotiate_safe(self, name, params):
+        """
+        `name` and `params` are sent in the HTTP request by the client. Check
+        if the extension name is supported by this extension, and validate the
+        parameters. Returns a dict with accepted parameters, or None if not
+        accepted.
+        """
+        for param in params.iterkeys():
+            if param not in self.defaults:
+                return
+
+        try:
+            return dict(self.negotiate(name, params))
+        except (KeyError, ValueError, AssertionError):
+            pass
 
 
-    def create_hook(self, **kwargs):
-        params = {}
-        params.update(self.defaults)
-        params.update(kwargs)
-        hook = self.Hook(**params)
-        hook.init(self)
-        return hook
+    class Instance:
+        def __init__(self, extension, name, params):
+            self.extension = extension
+            self.name = name
+            self.params = params
 
 
-    class Hook:
-        def __init__(self, **kwargs):
-            for param, value in kwargs.iteritems():
+            for param, value in extension.defaults.iteritems():
                 setattr(self, param, value)
                 setattr(self, param, value)
 
 
-        def init(self, extension):
-            return NotImplemented
+            for param, value in params.iteritems():
+                setattr(self, param, value)
 
 
-        def send(self, frame):
-            return frame
+            self.init()
 
 
-        def recv(self, frame):
-            return frame
+        def init(self):
+            return NotImplemented
 
 
+        def onsend_frame(self, frame):
+            pass
 
 
-def extension_conflicts(ext, existing):
-    rsv1_reserved = False
-    rsv2_reserved = False
-    rsv3_reserved = False
-    reserved_opcodes = []
+        def onrecv_frame(self, frame):
+            pass
 
 
-    for e in existing:
-        rsv1_reserved |= e.rsv1
-        rsv2_reserved |= e.rsv2
-        rsv3_reserved |= e.rsv3
-        reserved_opcodes.extend(e.opcodes)
+        def onsend_message(self, message):
+            pass
 
 
-    return ext.rsv1 and rsv1_reserved \
-            or ext.rsv2 and rsv2_reserved \
-            or ext.rsv3 and rsv3_reserved \
-            or len(set(ext.opcodes) & set(reserved_opcodes))
+        def onrecv_message(self, message):
+            pass

+ 27 - 25
handshake.py

@@ -7,7 +7,6 @@ from hashlib import sha1
 from urlparse import urlparse
 from urlparse import urlparse
 
 
 from errors import HandshakeError
 from errors import HandshakeError
-from extension import extension_conflicts
 from python_digest import build_authorization_request
 from python_digest import build_authorization_request
 
 
 
 
@@ -172,20 +171,19 @@ class ServerHandshake(Handshake):
 
 
         # Only supported extensions are returned
         # Only supported extensions are returned
         if 'Sec-WebSocket-Extensions' in headers:
         if 'Sec-WebSocket-Extensions' in headers:
-            supported_ext = dict((e.name, e) for e in ssock.extensions)
-            self.wsock.extension_hooks = []
-            extensions = []
+            self.wsock.extension_instances = []
 
 
-            for ext in split_stripped(headers['Sec-WebSocket-Extensions']):
-                name, params = parse_param_hdr(ext)
+            for hdr in split_stripped(headers['Sec-WebSocket-Extensions']):
+                name, params = parse_param_hdr(hdr)
 
 
-                if name in supported_ext:
-                    ext = supported_ext[name]
+                for ext in ssock.extensions:
+                    if not any(ext.conflicts(other.extension)
+                               for other in self.wsock.extension_instances):
+                        accept_params = ext.negotiate_safe(name, params)
 
 
-                    if not extension_conflicts(ext, extensions):
-                        extensions.append(ext)
-                        hook = ext.create_hook(**params)
-                        self.wsock.extension_hooks.append(hook)
+                        if accept_params is not None:
+                            instance = ext.Instance(ext, name, accept_params)
+                            self.wsock.extension_instances.append(instance)
 
 
         # Check if requested resource location is served by this server
         # Check if requested resource location is served by this server
         if ssock.locations:
         if ssock.locations:
@@ -222,9 +220,9 @@ class ServerHandshake(Handshake):
         if self.wsock.protocol:
         if self.wsock.protocol:
             yield 'Sec-WebSocket-Protocol', self.wsock.protocol
             yield 'Sec-WebSocket-Protocol', self.wsock.protocol
 
 
-        if self.wsock.extensions:
-            values = [format_param_hdr(e.name, e.request)
-                      for e in self.wsock.extensions]
+        if self.wsock.extension_instances:
+            values = [format_param_hdr(i.name, i.params)
+                      for i in self.wsock.extension_instances]
             yield 'Sec-WebSocket-Extensions', ', '.join(values)
             yield 'Sec-WebSocket-Extensions', ', '.join(values)
 
 
 
 
@@ -273,19 +271,23 @@ class ClientHandshake(Handshake):
 
 
         # Compare extensions, add hooks only for those returned by server
         # Compare extensions, add hooks only for those returned by server
         if 'Sec-WebSocket-Extensions' in headers:
         if 'Sec-WebSocket-Extensions' in headers:
-            supported_ext = dict((e.name, e) for e in self.wsock.extensions)
-            self.wsock.extension_hooks = []
-
-            for ext in split_stripped(headers['Sec-WebSocket-Extensions']):
-                name, params = parse_param_hdr(ext)
-
-                if name not in supported_ext:
+            # FIXME: there is no distinction between server/client extension
+            # instances, while the extension instance may assume it belongs to
+            # a server, leading to undefined behavior
+            self.wsock.extension_instances = []
+
+            for hdr in split_stripped(headers['Sec-WebSocket-Extensions']):
+                name, accept_params = parse_param_hdr(hdr)
+
+                for ext in self.wsock.extensions:
+                    if name in ext.names:
+                        instance = ext.Instance(ext, name, accept_params)
+                        self.wsock.extension_instances.append(instance)
+                        break
+                else:
                     raise HandshakeError('server handshake contains '
                     raise HandshakeError('server handshake contains '
                                          'unsupported extension "%s"' % name)
                                          'unsupported extension "%s"' % name)
 
 
-                hook = supported_ext[name].create_hook(**params)
-                self.wsock.extension_hooks.append(hook)
-
         # Assert that returned protocol (if any) is supported
         # Assert that returned protocol (if any) is supported
         if 'Sec-WebSocket-Protocol' in headers:
         if 'Sec-WebSocket-Protocol' in headers:
             protocol = headers['Sec-WebSocket-Protocol']
             protocol = headers['Sec-WebSocket-Protocol']

+ 1 - 3
test/client.py

@@ -9,7 +9,6 @@ sys.path.insert(0, basepath)
 from websocket import websocket
 from websocket import websocket
 from connection import Connection
 from connection import Connection
 from message import TextMessage
 from message import TextMessage
-from errors import SocketClosed
 
 
 ADDR = ('localhost', 8000)
 ADDR = ('localhost', 8000)
 
 
@@ -30,9 +29,8 @@ class EchoClient(Connection):
         print 'Connection closed'
         print 'Connection closed'
 
 
 
 
-secure = True
-
 if __name__ == '__main__':
 if __name__ == '__main__':
+    secure = '-s' in sys.argv[1:]
     scheme = 'wss' if secure else 'ws'
     scheme = 'wss' if secure else 'ws'
     print 'Connecting to %s://%s' % (scheme, '%s:%d' % ADDR)
     print 'Connecting to %s://%s' % (scheme, '%s:%d' % ADDR)
     sock = websocket()
     sock = websocket()

+ 4 - 4
test/server.py

@@ -7,7 +7,7 @@ basepath = abspath(dirname(abspath(__file__)) + '/..')
 sys.path.insert(0, basepath)
 sys.path.insert(0, basepath)
 
 
 from server import Server
 from server import Server
-from deflate_frame import WebkitDeflateFrame
+from deflate_frame import DeflateFrame
 
 
 
 
 class EchoServer(Server):
 class EchoServer(Server):
@@ -17,8 +17,8 @@ class EchoServer(Server):
 
 
 
 
 if __name__ == '__main__':
 if __name__ == '__main__':
-    deflate = WebkitDeflateFrame()
-    #deflate = WebkitDeflateFrame(defaults={'no_context_takeover': True})
-    EchoServer(('localhost', 8000), extensions=[deflate],
+    EchoServer(('localhost', 8000),
+               #extensions=[DeflateFrame(no_context_takeover=True)],
+               extensions=[DeflateFrame()],
                #ssl_args=dict(keyfile='cert.pem', certfile='cert.pem'),
                #ssl_args=dict(keyfile='cert.pem', certfile='cert.pem'),
                loglevel=logging.DEBUG).run()
                loglevel=logging.DEBUG).run()

+ 11 - 8
websocket.py

@@ -81,7 +81,7 @@ class websocket(object):
         """
         """
         self.protocols = protocols
         self.protocols = protocols
         self.extensions = extensions
         self.extensions = extensions
-        self.extension_hooks = []
+        self.extension_instances = []
         self.origin = origin
         self.origin = origin
         self.location = location
         self.location = location
         self.trusted_origins = trusted_origins
         self.trusted_origins = trusted_origins
@@ -99,9 +99,6 @@ class websocket(object):
 
 
         self.sock = sock or socket.socket(sfamily, socket.SOCK_STREAM, sproto)
         self.sock = sock or socket.socket(sfamily, socket.SOCK_STREAM, sproto)
 
 
-    def set_extensions(self, extensions):
-        self.extensions = [ext.Hook() for ext in extensions]
-
     def __getattr__(self, name):
     def __getattr__(self, name):
         if name in INHERITED_ATTRS:
         if name in INHERITED_ATTRS:
             return getattr(self.sock, name)
             return getattr(self.sock, name)
@@ -135,14 +132,20 @@ class websocket(object):
         self.handshake_sent = True
         self.handshake_sent = True
 
 
     def apply_send_hooks(self, frame):
     def apply_send_hooks(self, frame):
-        for hook in self.extension_hooks:
-            frame = hook.send(frame)
+        for inst in self.extension_instances:
+            replacement = inst.onsend_frame(frame)
+
+            if replacement is not None:
+                frame = replacement
 
 
         return frame
         return frame
 
 
     def apply_recv_hooks(self, frame):
     def apply_recv_hooks(self, frame):
-        for hook in reversed(self.extension_hooks):
-            frame = hook.recv(frame)
+        for inst in reversed(self.extension_instances):
+            replacement = inst.onrecv_frame(frame)
+
+            if replacement is not None:
+                frame = replacement
 
 
         return frame
         return frame