Эх сурвалжийг харах

Rewrote extensions API + reimplemented deflate-frame

Taddeus Kroes 11 жил өмнө
parent
commit
9232e5d4b4
7 өөрчлөгдсөн 144 нэмэгдсэн , 126 устгасан
  1. 1 1
      __init__.py
  2. 38 43
      deflate_frame.py
  3. 62 42
      extension.py
  4. 27 25
      handshake.py
  5. 1 3
      test/client.py
  6. 4 4
      test/server.py
  7. 11 8
      websocket.py

+ 1 - 1
__init__.py

@@ -10,5 +10,5 @@ from connection import Connection
 from message import Message, TextMessage, BinaryMessage
 from errors import SocketClosed, HandshakeError, PingError, SSLError
 from extension import Extension
-from deflate_frame import DeflateFrame, WebkitDeflateFrame
+from deflate_frame import DeflateFrame
 from async import AsyncConnection, AsyncServer

+ 38 - 43
deflate_frame.py

@@ -17,40 +17,39 @@ class DeflateFrame(Extension):
     Note that the deflate and inflate hooks modify the RSV1 bit and payload of
     existing `Frame` objects.
     """
-
-    name = 'deflate-frame'
+    names = ('deflate-frame', 'x-webkit-deflate-frame')
     rsv1 = True
-    defaults = {'max_window_bits': zlib.MAX_WBITS, 'no_context_takeover': False}
-
-    COMPRESSION_THRESHOLD = 64  # minimal payload size for compression
-
-    def init(self):
-        mwb = self.defaults['max_window_bits']
-        cto = self.defaults['no_context_takeover']
-
-        if not isinstance(mwb, int) or mwb < 1 or mwb > zlib.MAX_WBITS:
-            raise ValueError('"max_window_bits" must be in range 1-15')
-
-        if cto is not False and cto is not True:
-            raise ValueError('"no_context_takeover" must have no value')
-
-    class Hook(Extension.Hook):
-        def init(self, extension):
-            self.defl = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
-                                         zlib.DEFLATED, -self.max_window_bits)
-            other_wbits = extension.request.get('max_window_bits', zlib.MAX_WBITS)
-            self.dec = zlib.decompressobj(-other_wbits)
+    defaults = {
+        'max_window_bits': zlib.MAX_WBITS,
+        'no_context_takeover': False
+    }
+
+    compression_threshold = 64  # minimal payload size for compression
+
+    def negotiate(self, name, params):
+        if 'max_window_bits' in params:
+            mwb = int(params['max_window_bits'])
+            assert 8 <= mwb <= zlib.MAX_WBITS
+            yield 'max_window_bits', mwb
+
+        if 'no_context_takeover' in params:
+            assert params['no_context_takeover'] is True
+            yield 'no_context_takeover', True
+
+    class Instance(Extension.Instance):
+        def init(self):
+            if not self.no_context_takeover:
+                self.defl = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
+                        zlib.DEFLATED, -self.max_window_bits)
+                self.dec = zlib.decompressobj(-self.max_window_bits)
 
-        def send(self, frame):
-            # FIXME: this does not seem to work properly on Android
+        def onsend_frame(self, frame):
             if not frame.rsv1 and not isinstance(frame, ControlFrame) and \
-                   len(frame.payload) > DeflateFrame.COMPRESSION_THRESHOLD:
+                   len(frame.payload) > self.extension.compression_threshold:
                 frame.rsv1 = True
                 frame.payload = self.deflate(frame)
 
-            return frame
-
-        def recv(self, frame):
+        def onrecv_frame(self, frame):
             if frame.rsv1:
                 if isinstance(frame, ControlFrame):
                     raise ValueError('received compressed control frame')
@@ -58,26 +57,22 @@ class DeflateFrame(Extension):
                 frame.rsv1 = False
                 frame.payload = self.inflate(frame.payload)
 
-            return frame
-
         def deflate(self, frame):
-            compressed = self.defl.compress(frame.payload)
-
-            if frame.final or self.no_context_takeover:
-                compressed += self.defl.flush(zlib.Z_FINISH) + '\x00'
-                self.defl = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
-                        zlib.DEFLATED, -self.max_window_bits)
+            if self.no_context_takeover:
+                print 'no_context_takeover'
+                compressed = zlib.compress(frame.payload)
             else:
+                compressed = self.defl.compress(frame.payload)
                 compressed += self.defl.flush(zlib.Z_SYNC_FLUSH)
-                assert compressed[-4:] == '\x00\x00\xff\xff'
-                compressed = compressed[:-4]
 
-            return compressed
+            assert compressed[-4:] == '\x00\x00\xff\xff'
+            return compressed[:-4]
 
         def inflate(self, data):
-            return self.dec.decompress(data + '\x00\x00\xff\xff') + \
-                   self.dec.flush(zlib.Z_SYNC_FLUSH)
+            data = str(data + '\x00\x00\xff\xff')
 
+            if self.no_context_takeover:
+                dec = zlib.decompressobj(-self.max_window_bits)
+                return dec.decompress(data) + dec.flush()
 
-class WebkitDeflateFrame(DeflateFrame):
-    name = 'x-webkit-deflate-frame'
+            return self.dec.decompress(data)

+ 62 - 42
extension.py

@@ -3,67 +3,87 @@ class Extension(object):
     rsv1 = False
     rsv2 = False
     rsv3 = False
-    opcodes = []
+    opcodes = ()
     defaults = {}
-    request = {}
 
-    def __init__(self, defaults={}, request={}):
-        for param in defaults.keys() + request.keys():
+    def __init__(self, **kwargs):
+        for param in kwargs.iterkeys():
             if param not in self.defaults:
                 raise KeyError('unrecognized parameter "%s"' % param)
 
         # Copy dict first to avoid duplicate references to the same object
         self.defaults = dict(self.__class__.defaults)
-        self.defaults.update(defaults)
-
-        self.request = dict(self.__class__.request)
-        self.request.update(request)
-
-        self.init()
+        self.defaults.update(kwargs)
 
     def __str__(self):
         return '<Extension "%s" defaults=%s request=%s>' \
                % (self.name, self.defaults, self.request)
 
-    def init(self):
-        return NotImplemented
+    @property
+    def names(self):
+        return (self.name,) if self.name else ()
+
+    def conflicts(self, ext):
+        """
+        Check if the extension conflicts with an already accepted extension.
+        This may be the case when the two extensions use the same reserved
+        bits, or have the same name (when the same extension is negotiated
+        multiple times with different parameters).
+        """
+        return ext.rsv1 and self.rsv1 \
+            or ext.rsv2 and self.rsv2 \
+            or ext.rsv3 and self.rsv3 \
+            or set(ext.names) & set(self.names) \
+            or set(ext.opcodes) & set(self.opcodes)
+
+    def negotiate(self, name, params):
+        """
+        Same as `negotiate_safe`, but instead returns an iterator of (param,
+        value) tuples and raises an exception on error.
+        """
+        raise NotImplementedError
+
+    def negotiate_safe(self, name, params):
+        """
+        `name` and `params` are sent in the HTTP request by the client. Check
+        if the extension name is supported by this extension, and validate the
+        parameters. Returns a dict with accepted parameters, or None if not
+        accepted.
+        """
+        for param in params.iterkeys():
+            if param not in self.defaults:
+                return
+
+        try:
+            return dict(self.negotiate(name, params))
+        except (KeyError, ValueError, AssertionError):
+            pass
 
-    def create_hook(self, **kwargs):
-        params = {}
-        params.update(self.defaults)
-        params.update(kwargs)
-        hook = self.Hook(**params)
-        hook.init(self)
-        return hook
+    class Instance:
+        def __init__(self, extension, name, params):
+            self.extension = extension
+            self.name = name
+            self.params = params
 
-    class Hook:
-        def __init__(self, **kwargs):
-            for param, value in kwargs.iteritems():
+            for param, value in extension.defaults.iteritems():
                 setattr(self, param, value)
 
-        def init(self, extension):
-            return NotImplemented
+            for param, value in params.iteritems():
+                setattr(self, param, value)
 
-        def send(self, frame):
-            return frame
+            self.init()
 
-        def recv(self, frame):
-            return frame
+        def init(self):
+            return NotImplemented
 
+        def onsend_frame(self, frame):
+            pass
 
-def extension_conflicts(ext, existing):
-    rsv1_reserved = False
-    rsv2_reserved = False
-    rsv3_reserved = False
-    reserved_opcodes = []
+        def onrecv_frame(self, frame):
+            pass
 
-    for e in existing:
-        rsv1_reserved |= e.rsv1
-        rsv2_reserved |= e.rsv2
-        rsv3_reserved |= e.rsv3
-        reserved_opcodes.extend(e.opcodes)
+        def onsend_message(self, message):
+            pass
 
-    return ext.rsv1 and rsv1_reserved \
-            or ext.rsv2 and rsv2_reserved \
-            or ext.rsv3 and rsv3_reserved \
-            or len(set(ext.opcodes) & set(reserved_opcodes))
+        def onrecv_message(self, message):
+            pass

+ 27 - 25
handshake.py

@@ -7,7 +7,6 @@ from hashlib import sha1
 from urlparse import urlparse
 
 from errors import HandshakeError
-from extension import extension_conflicts
 from python_digest import build_authorization_request
 
 
@@ -172,20 +171,19 @@ class ServerHandshake(Handshake):
 
         # Only supported extensions are returned
         if 'Sec-WebSocket-Extensions' in headers:
-            supported_ext = dict((e.name, e) for e in ssock.extensions)
-            self.wsock.extension_hooks = []
-            extensions = []
+            self.wsock.extension_instances = []
 
-            for ext in split_stripped(headers['Sec-WebSocket-Extensions']):
-                name, params = parse_param_hdr(ext)
+            for hdr in split_stripped(headers['Sec-WebSocket-Extensions']):
+                name, params = parse_param_hdr(hdr)
 
-                if name in supported_ext:
-                    ext = supported_ext[name]
+                for ext in ssock.extensions:
+                    if not any(ext.conflicts(other.extension)
+                               for other in self.wsock.extension_instances):
+                        accept_params = ext.negotiate_safe(name, params)
 
-                    if not extension_conflicts(ext, extensions):
-                        extensions.append(ext)
-                        hook = ext.create_hook(**params)
-                        self.wsock.extension_hooks.append(hook)
+                        if accept_params is not None:
+                            instance = ext.Instance(ext, name, accept_params)
+                            self.wsock.extension_instances.append(instance)
 
         # Check if requested resource location is served by this server
         if ssock.locations:
@@ -222,9 +220,9 @@ class ServerHandshake(Handshake):
         if self.wsock.protocol:
             yield 'Sec-WebSocket-Protocol', self.wsock.protocol
 
-        if self.wsock.extensions:
-            values = [format_param_hdr(e.name, e.request)
-                      for e in self.wsock.extensions]
+        if self.wsock.extension_instances:
+            values = [format_param_hdr(i.name, i.params)
+                      for i in self.wsock.extension_instances]
             yield 'Sec-WebSocket-Extensions', ', '.join(values)
 
 
@@ -273,19 +271,23 @@ class ClientHandshake(Handshake):
 
         # Compare extensions, add hooks only for those returned by server
         if 'Sec-WebSocket-Extensions' in headers:
-            supported_ext = dict((e.name, e) for e in self.wsock.extensions)
-            self.wsock.extension_hooks = []
-
-            for ext in split_stripped(headers['Sec-WebSocket-Extensions']):
-                name, params = parse_param_hdr(ext)
-
-                if name not in supported_ext:
+            # FIXME: there is no distinction between server/client extension
+            # instances, while the extension instance may assume it belongs to
+            # a server, leading to undefined behavior
+            self.wsock.extension_instances = []
+
+            for hdr in split_stripped(headers['Sec-WebSocket-Extensions']):
+                name, accept_params = parse_param_hdr(hdr)
+
+                for ext in self.wsock.extensions:
+                    if name in ext.names:
+                        instance = ext.Instance(ext, name, accept_params)
+                        self.wsock.extension_instances.append(instance)
+                        break
+                else:
                     raise HandshakeError('server handshake contains '
                                          'unsupported extension "%s"' % name)
 
-                hook = supported_ext[name].create_hook(**params)
-                self.wsock.extension_hooks.append(hook)
-
         # Assert that returned protocol (if any) is supported
         if 'Sec-WebSocket-Protocol' in headers:
             protocol = headers['Sec-WebSocket-Protocol']

+ 1 - 3
test/client.py

@@ -9,7 +9,6 @@ sys.path.insert(0, basepath)
 from websocket import websocket
 from connection import Connection
 from message import TextMessage
-from errors import SocketClosed
 
 ADDR = ('localhost', 8000)
 
@@ -30,9 +29,8 @@ class EchoClient(Connection):
         print 'Connection closed'
 
 
-secure = True
-
 if __name__ == '__main__':
+    secure = '-s' in sys.argv[1:]
     scheme = 'wss' if secure else 'ws'
     print 'Connecting to %s://%s' % (scheme, '%s:%d' % ADDR)
     sock = websocket()

+ 4 - 4
test/server.py

@@ -7,7 +7,7 @@ basepath = abspath(dirname(abspath(__file__)) + '/..')
 sys.path.insert(0, basepath)
 
 from server import Server
-from deflate_frame import WebkitDeflateFrame
+from deflate_frame import DeflateFrame
 
 
 class EchoServer(Server):
@@ -17,8 +17,8 @@ class EchoServer(Server):
 
 
 if __name__ == '__main__':
-    deflate = WebkitDeflateFrame()
-    #deflate = WebkitDeflateFrame(defaults={'no_context_takeover': True})
-    EchoServer(('localhost', 8000), extensions=[deflate],
+    EchoServer(('localhost', 8000),
+               #extensions=[DeflateFrame(no_context_takeover=True)],
+               extensions=[DeflateFrame()],
                #ssl_args=dict(keyfile='cert.pem', certfile='cert.pem'),
                loglevel=logging.DEBUG).run()

+ 11 - 8
websocket.py

@@ -81,7 +81,7 @@ class websocket(object):
         """
         self.protocols = protocols
         self.extensions = extensions
-        self.extension_hooks = []
+        self.extension_instances = []
         self.origin = origin
         self.location = location
         self.trusted_origins = trusted_origins
@@ -99,9 +99,6 @@ class websocket(object):
 
         self.sock = sock or socket.socket(sfamily, socket.SOCK_STREAM, sproto)
 
-    def set_extensions(self, extensions):
-        self.extensions = [ext.Hook() for ext in extensions]
-
     def __getattr__(self, name):
         if name in INHERITED_ATTRS:
             return getattr(self.sock, name)
@@ -135,14 +132,20 @@ class websocket(object):
         self.handshake_sent = True
 
     def apply_send_hooks(self, frame):
-        for hook in self.extension_hooks:
-            frame = hook.send(frame)
+        for inst in self.extension_instances:
+            replacement = inst.onsend_frame(frame)
+
+            if replacement is not None:
+                frame = replacement
 
         return frame
 
     def apply_recv_hooks(self, frame):
-        for hook in reversed(self.extension_hooks):
-            frame = hook.recv(frame)
+        for inst in reversed(self.extension_instances):
+            replacement = inst.onrecv_frame(frame)
+
+            if replacement is not None:
+                frame = replacement
 
         return frame