Przeglądaj źródła

Removed frame send/recv hooks, now working with a somewhat more robust extensions list

Taddeus Kroes 11 lat temu
rodzic
commit
7c8972d4f2
4 zmienionych plików z 89 dodań i 128 usunięć
  1. 27 32
      deflate_frame.py
  2. 23 24
      extension.py
  3. 9 14
      handshake.py
  4. 30 58
      websocket.py

+ 27 - 32
deflate_frame.py

@@ -20,38 +20,33 @@ class DeflateFrame(Extension):
 
     name = 'deflate-frame'
     rsv1 = True
-    defaults = {'max_window_bits': 15, 'no_context_takeover': False}
+    defaults = {'max_window_bits': zlib.MAX_WBITS, 'no_context_takeover': False}
 
-    def __init__(self, defaults={}, request={}):
-        Extension.__init__(self, defaults, request)
+    COMPRESSION_THRESHOLD = 64  # minimal payload size for compression
 
+    def init(self):
         mwb = self.defaults['max_window_bits']
         cto = self.defaults['no_context_takeover']
 
-        if not isinstance(mwb, int):
-            raise ValueError('"max_window_bits" must be an integer')
-        elif mwb > 15:
-            raise ValueError('"max_window_bits" may not be larger than 15')
+        if not isinstance(mwb, int) or mwb < 1 or mwb > zlib.MAX_WBITS:
+            raise ValueError('"max_window_bits" must be in range 1-15')
 
         if cto is not False and cto is not True:
             raise ValueError('"no_context_takeover" must have no value')
 
     class Hook(Extension.Hook):
-        def __init__(self, extension, **kwargs):
-            Extension.Hook.__init__(self, extension, **kwargs)
-
-            if not self.no_context_takeover:
-                self.defl = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
-                                            zlib.DEFLATED,
-                                            -self.max_window_bits)
-
-            other_wbits = self.extension.request.get('max_window_bits', 15)
+        def init(self, extension):
+            self.defl = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
+                                         zlib.DEFLATED, -self.max_window_bits)
+            other_wbits = extension.request.get('max_window_bits', zlib.MAX_WBITS)
             self.dec = zlib.decompressobj(-other_wbits)
 
         def send(self, frame):
-            if not frame.rsv1 and not isinstance(frame, ControlFrame):
+            # FIXME: this does not seem to work properly on Android
+            if not frame.rsv1 and not isinstance(frame, ControlFrame) and \
+                   len(frame.payload) > DeflateFrame.COMPRESSION_THRESHOLD:
                 frame.rsv1 = True
-                frame.payload = self.deflate(frame.payload)
+                frame.payload = self.deflate(frame)
 
             return frame
 
@@ -65,23 +60,23 @@ class DeflateFrame(Extension):
 
             return frame
 
-        def deflate(self, data):
-            if self.no_context_takeover:
-                defl = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
-                                        zlib.DEFLATED, -self.max_window_bits)
-                # FIXME: why the '\x00' below? This was borrowed from
-                # https://github.com/fancycode/tornado/blob/bc317b6dcf63608ff004ff1f57073be0504b6550/tornado/websocket.py#L91
-                return defl.compress(data) + defl.flush(zlib.Z_FINISH) + '\x00'
+        def deflate(self, frame):
+            compressed = self.defl.compress(frame.payload)
+
+            if frame.final or self.no_context_takeover:
+                compressed += self.defl.flush(zlib.Z_FINISH) + '\x00'
+                self.defl = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
+                        zlib.DEFLATED, -self.max_window_bits)
+            else:
+                compressed += self.defl.flush(zlib.Z_SYNC_FLUSH)
+                assert compressed[-4:] == '\x00\x00\xff\xff'
+                compressed = compressed[:-4]
 
-            compressed = self.defl.compress(data)
-            compressed += self.defl.flush(zlib.Z_SYNC_FLUSH)
-            assert compressed[-4:] == '\x00\x00\xff\xff'
-            return compressed[:-4]
+            return compressed
 
         def inflate(self, data):
-            data = self.dec.decompress(str(data + '\x00\x00\xff\xff'))
-            assert not self.dec.unused_data
-            return data
+            return self.dec.decompress(data + '\x00\x00\xff\xff') + \
+                   self.dec.flush(zlib.Z_SYNC_FLUSH)
 
 
 class WebkitDeflateFrame(DeflateFrame):

+ 23 - 24
extension.py

@@ -19,23 +19,31 @@ class Extension(object):
         self.request = dict(self.__class__.request)
         self.request.update(request)
 
+        self.init()
+
     def __str__(self):
         return '<Extension "%s" defaults=%s request=%s>' \
                % (self.name, self.defaults, self.request)
 
+    def init(self):
+        return NotImplemented
+
     def create_hook(self, **kwargs):
         params = {}
         params.update(self.defaults)
         params.update(kwargs)
-        return self.Hook(self, **params)
+        hook = self.Hook(**params)
+        hook.init(self)
+        return hook
 
     class Hook:
-        def __init__(self, extension, **kwargs):
-            self.extension = extension
-
+        def __init__(self, **kwargs):
             for param, value in kwargs.iteritems():
                 setattr(self, param, value)
 
+        def init(self, extension):
+            return NotImplemented
+
         def send(self, frame):
             return frame
 
@@ -43,28 +51,19 @@ class Extension(object):
             return frame
 
 
-def filter_extensions(extensions):
-    """
-    Remove extensions that use conflicting rsv bits and/or opcodes, with the
-    first options being the most preferable.
-    """
+def extension_conflicts(ext, existing):
     rsv1_reserved = False
     rsv2_reserved = False
     rsv3_reserved = False
-    opcodes_reserved = []
-    compat = []
-
-    for ext in extensions:
-        if ext.rsv1 and rsv1_reserved \
-                or ext.rsv2 and rsv2_reserved \
-                or ext.rsv3 and rsv3_reserved \
-                or len(set(ext.opcodes) & set(opcodes_reserved)):
-            continue
+    reserved_opcodes = []
 
-        rsv1_reserved |= ext.rsv1
-        rsv2_reserved |= ext.rsv2
-        rsv3_reserved |= ext.rsv3
-        opcodes_reserved.extend(ext.opcodes)
-        compat.append(ext)
+    for e in existing:
+        rsv1_reserved |= e.rsv1
+        rsv2_reserved |= e.rsv2
+        rsv3_reserved |= e.rsv3
+        reserved_opcodes.extend(e.opcodes)
 
-    return compat
+    return ext.rsv1 and rsv1_reserved \
+            or ext.rsv2 and rsv2_reserved \
+            or ext.rsv3 and rsv3_reserved \
+            or len(set(ext.opcodes) & set(reserved_opcodes))

+ 9 - 14
handshake.py

@@ -7,7 +7,7 @@ from hashlib import sha1
 from urlparse import urlparse
 
 from errors import HandshakeError
-from extension import filter_extensions
+from extension import extension_conflicts
 from python_digest import build_authorization_request
 
 
@@ -173,23 +173,19 @@ class ServerHandshake(Handshake):
         # Only supported extensions are returned
         if 'Sec-WebSocket-Extensions' in headers:
             supported_ext = dict((e.name, e) for e in ssock.extensions)
+            self.wsock.extension_hooks = []
             extensions = []
-            all_params = []
 
             for ext in split_stripped(headers['Sec-WebSocket-Extensions']):
                 name, params = parse_param_hdr(ext)
 
                 if name in supported_ext:
-                    extensions.append(supported_ext[name])
-                    all_params.append(params)
+                    ext = supported_ext[name]
 
-            self.wsock.extensions = filter_extensions(extensions)
-
-            for ext, params in zip(self.wsock.extensions, all_params):
-                hook = ext.create_hook(**params)
-                self.wsock.add_hook(send=hook.send, recv=hook.recv)
-        else:
-            self.wsock.extensions = []
+                    if not extension_conflicts(ext, extensions):
+                        extensions.append(ext)
+                        hook = ext.create_hook(**params)
+                        self.wsock.extension_hooks.append(hook)
 
         # Check if requested resource location is served by this server
         if ssock.locations:
@@ -278,7 +274,7 @@ class ClientHandshake(Handshake):
         # Compare extensions, add hooks only for those returned by server
         if 'Sec-WebSocket-Extensions' in headers:
             supported_ext = dict((e.name, e) for e in self.wsock.extensions)
-            self.wsock.extensions = []
+            self.wsock.extension_hooks = []
 
             for ext in split_stripped(headers['Sec-WebSocket-Extensions']):
                 name, params = parse_param_hdr(ext)
@@ -288,8 +284,7 @@ class ClientHandshake(Handshake):
                                          'unsupported extension "%s"' % name)
 
                 hook = supported_ext[name].create_hook(**params)
-                self.wsock.extensions.append(supported_ext[name])
-                self.wsock.add_hook(send=hook.send, recv=hook.recv)
+                self.wsock.extension_hooks.append(hook)
 
         # Assert that returned protocol (if any) is supported
         if 'Sec-WebSocket-Protocol' in headers:

+ 30 - 58
websocket.py

@@ -35,7 +35,7 @@ class websocket(object):
     >>> sock.connect(('', 8000))
     >>> sock.send(wspy.Frame(wspy.OPCODE_TEXT, 'Hello, Server!'))
     """
-    def __init__(self, sock=None, protocols=[], extensions=[], origin=None,
+    def __init__(self, sock=None, origin=None, protocols=[], extensions=[],
                  location='/', trusted_origins=[], locations=[], auth=None,
                  recv_callback=None, sfamily=socket.AF_INET, sproto=0):
         """
@@ -44,13 +44,14 @@ class websocket(object):
         `sock` is an optional regular TCP socket to be used for sending binary
         data. If not specified, a new socket is created.
 
-        `protocols` is a list of supported protocol names.
-
-        `extensions` is a list of supported extensions (`Extension` instances).
-
         `origin` (for client sockets) is the value for the "Origin" header sent
         in a client handshake .
 
+        `protocols` is a list of supported protocol names.
+
+        `extensions` (for server sockets) is a list of supported extensions
+        (`Extension` instances).
+
         `location` (for client sockets) is optional, used to request a
         particular resource in the HTTP handshake. In a URL, this would show as
         ws://host[:port]/<location>. Use this when the server serves multiple
@@ -80,6 +81,7 @@ class websocket(object):
         """
         self.protocols = protocols
         self.extensions = extensions
+        self.extension_hooks = []
         self.origin = origin
         self.location = location
         self.trusted_origins = trusted_origins
@@ -90,9 +92,6 @@ class websocket(object):
 
         self.handshake_sent = False
 
-        self.hooks_send = []
-        self.hooks_recv = []
-
         self.sendbuf_frames = []
         self.sendbuf = ''
         self.recvbuf = ''
@@ -100,6 +99,9 @@ class websocket(object):
 
         self.sock = sock or socket.socket(sfamily, socket.SOCK_STREAM, sproto)
 
+    def set_extensions(self, extensions):
+        self.extensions = [ext.Hook() for ext in extensions]
+
     def __getattr__(self, name):
         if name in INHERITED_ATTRS:
             return getattr(self.sock, name)
@@ -132,29 +134,31 @@ class websocket(object):
         ClientHandshake(self).perform()
         self.handshake_sent = True
 
+    def apply_send_hooks(self, frame):
+        for hook in self.extension_hooks:
+            frame = hook.send(frame)
+
+        return frame
+
+    def apply_recv_hooks(self, frame):
+        for hook in reversed(self.extension_hooks):
+            frame = hook.recv(frame)
+
+        return frame
+
     def send(self, *args):
         """
         Send a number of frames.
         """
         for frame in args:
-            for hook in self.hooks_send:
-                frame = hook(frame)
-
-            #print 'send frame:', frame, 'to %s:%d' % self.sock.getpeername()
-            self.sock.sendall(frame.pack())
+            self.sock.sendall(self.apply_send_hooks(frame).pack())
 
     def recv(self):
         """
         Receive a single frames. This can be either a data frame or a control
         frame.
         """
-        frame = receive_frame(self.sock)
-
-        for hook in self.hooks_recv:
-            frame = hook(frame)
-
-        #print 'receive frame:', frame, 'from %s:%d' % self.sock.getpeername()
-        return frame
+        return self.apply_recv_hooks(receive_frame(self.sock))
 
     def recvn(self, n):
         """
@@ -170,9 +174,7 @@ class websocket(object):
         frame has been fully written. `recv_callback` is an optional callable
         to quickly set the `recv_callback` attribute to.
         """
-        for hook in self.hooks_send:
-            frame = hook(frame)
-
+        frame = self.apply_send_hooks(frame)
         self.sendbuf += frame.pack()
         self.sendbuf_frames.append([frame, len(self.sendbuf), callback])
 
@@ -181,7 +183,8 @@ class websocket(object):
 
     def do_async_send(self):
         """
-        Send any queued data.
+        Send any queued data. This function should only be called after a write
+        event on a file descriptor.
         """
         assert len(self.sendbuf)
 
@@ -204,6 +207,8 @@ class websocket(object):
 
     def do_async_recv(self, bufsize):
         """
+        Receive any completed frames from the socket. This function should only
+        be called after a read event on a file descriptor.
         """
         data = self.sock.recv(bufsize)
 
@@ -214,6 +219,7 @@ class websocket(object):
 
         while contains_frame(self.recvbuf):
             frame, self.recvbuf = pop_frame(self.recvbuf)
+            frame = self.apply_recv_hooks(frame)
 
             if not self.recv_callback:
                 raise ValueError('no callback installed for %s' % frame)
@@ -237,37 +243,3 @@ class websocket(object):
 
         self.secure = True
         self.sock = ssl.wrap_socket(self.sock, *args, **kwargs)
-
-    def add_hook(self, send=None, recv=None, prepend=False):
-        """
-        Add a pair of send and receive hooks that are called for each frame
-        that is sent or received. A hook is a function that receives a single
-        argument - a Frame instance - and returns a `Frame` instance as well.
-
-        `prepend` is a flag indicating whether the send hook is prepended to
-        the other send hooks. This is expecially useful when a program uses
-        extensions such as the built-in `DeflateFrame` extension. These
-        extensions are installed using these hooks as well.
-
-        For example, the following code creates a `Frame` instance for data
-        being sent and removes the instance for received data. This way, data
-        can be sent and received as if on a regular socket.
-        >>> import wspy
-        >>> sock = wspy.websocket()
-        >>> sock.add_hook(lambda data: tswpy.Frame(tswpy.OPCODE_TEXT, data),
-        >>>               lambda frame: frame.payload)
-
-        To add base64 encoding to the example above:
-        >>> import base64
-        >>> sock.add_hook(base64.encodestring, base64.decodestring, True)
-
-        Note that here `prepend=True`, so that data passed to `send()` is first
-        encoded and then packed into a frame. Of course, one could also decide
-        to add the base64 hook first, or to return a new `Frame` instance with
-        base64-encoded data.
-        """
-        if send:
-            self.hooks_send.insert(0 if prepend else -1, send)
-
-        if recv:
-            self.hooks_recv.insert(-1 if prepend else 0, recv)