Ver Fonte

Added permessage-deflate + lots of general debug and cleanup

Taddeus Kroes há 11 anos atrás
pai
commit
5ed32daf31
9 ficheiros alterados com 172 adições e 88 exclusões
  1. 22 7
      README.md
  2. 1 0
      __init__.py
  3. 10 44
      connection.py
  4. 14 12
      deflate_frame.py
  5. 89 0
      deflate_message.py
  6. 21 8
      extension.py
  7. 3 0
      frame.py
  8. 2 3
      handshake.py
  9. 10 14
      websocket.py

+ 22 - 7
README.md

@@ -20,8 +20,10 @@ Her is a quick overview of the features in this library:
 - HTTP authentication during handshake.
 - An extendible server implementation.
 - Secure sockets using SSL certificates (for 'wss://...' URLs).
-- The possibility to add extensions to the web socket protocol. An included
-  implementation is [deflate-frame](http://tools.ietf.org/html/draft-tyoshino-hybi-websocket-perframe-deflate-06).
+- An API for implementing WebSocket extensions. Included implementations are
+  [deflate-frame](http://tools.ietf.org/html/draft-tyoshino-hybi-websocket-perframe-deflate-06)
+  and
+  [permessage-deflate](http://tools.ietf.org/html/draft-ietf-hybi-permessage-compression-17).
 - Asynchronous sockets with an EPOLL-based server.
 
 
@@ -112,16 +114,19 @@ Basic usage
         conn.send(msg(foo='Hello, World!'))
 
 
-Built-in Server
-===============
+Built-in servers
+================
 
-The built-in `Server` implementation is very basic. It starts a new thread with
-a `Connection.receive_forever()` loop for each client that connects. It also
+Threaded
+--------
+
+The `Server` class is very basic. It starts a new thread with a
+`Connection.receive_forever()` loop for each client that connects. It also
 handles client crashes properly. By default, a `Server` instance only logs
 every event using Python's `logging` module. To create a custom server, The
 `Server` class should be extended and its event handlers overwritten. The event
 handlers are named identically to the `Connection` event handlers, but they
-also receive an additional `client` argument. This argument is a modified
+also receive an additional `client` argument. The client argumetn is a modified
 `Connection` instance, so you can invoke `send()` and `recv()`.
 
 For example, the `EchoConnection` example above can be rewritten to:
@@ -144,6 +149,16 @@ For example, the `EchoConnection` example above can be rewritten to:
 The server can be stopped by typing CTRL-C in the command line. The
 `KeyboardInterrupt` raised when this happens is caught by the server.
 
+Asynchronous
+------------
+
+The `AsyncServer` class has the same API as `Server`, but uses
+[EPOLL](https://docs.python.org/2/library/select.html#epoll-objects) instead of
+threads. This means that when you send a message, it is put into a queue to be
+sent later when the socket is ready. The client argument is againa modified
+`Connection` instance, with a non-blocking `send()` method (`recv` is still
+blocking, use the server's `onmessage` callback instead).
+
 
 Extensions
 ==========

+ 1 - 0
__init__.py

@@ -11,4 +11,5 @@ from message import Message, TextMessage, BinaryMessage
 from errors import SocketClosed, HandshakeError, PingError, SSLError
 from extension import Extension
 from deflate_frame import DeflateFrame
+from deflate_message import DeflateMessage
 from async import AsyncConnection, AsyncServer

+ 10 - 44
connection.py

@@ -54,14 +54,13 @@ class Connection(object):
         self.onopen()
 
     def message_to_frames(self, message, fragment_size=None, mask=False):
-        for hook in self.hooks_send:
-            message = hook(message)
+        frame = self.sock.apply_send_hooks(message.frame(mask=mask), True)
 
         if fragment_size is None:
-            yield message.frame(mask=mask)
+            yield frame
         else:
-            for frame in message.fragment(fragment_size, mask=mask):
-                yield frame
+            for fragment in frame.fragment(fragment_size):
+                yield fragment
 
     def send(self, message, fragment_size=None, mask=False):
         """
@@ -101,17 +100,14 @@ class Connection(object):
         return self.concat_fragments(fragments)
 
     def concat_fragments(self, fragments):
-        payload = bytearray()
+        frame = fragments[0]
 
-        for f in fragments:
-            payload += f.payload
+        for f in fragments[1:]:
+            frame.payload += f.payload
 
-        message = create_message(fragments[0].opcode, payload)
-
-        for hook in self.hooks_recv:
-            message = hook(message)
-
-        return message
+        frame.final = True
+        frame = self.sock.apply_recv_hooks(frame, True)
+        return create_message(frame.opcode, frame.payload)
 
     def handle_control_frame(self, frame):
         """
@@ -207,36 +203,6 @@ class Connection(object):
 
         self.handle_control_frame(frame)
 
-    def add_hook(self, send=None, recv=None, prepend=False):
-        """
-        Add a pair of send and receive hooks that are called for each frame
-        that is sent or received. A hook is a function that receives a single
-        argument - a Message instance - and returns a `Message` instance as
-        well.
-
-        `prepend` is a flag indicating whether the send hook is prepended to
-        the other send hooks.
-
-        For example, to add an automatic JSON conversion to messages and
-        eliminate the need to contruct TextMessage instances to all messages:
-        >>> import wspy, json
-        >>> conn = Connection(...)
-        >>> conn.add_hook(lambda data: tswpy.TextMessage(json.dumps(data)),
-        >>>               lambda message: json.loads(message.payload))
-        >>> conn.send({'foo': 'bar'})  # Sends text message {"foo":"bar"}
-        >>> conn.recv()                # May be dict(foo='bar')
-
-        Note that here `prepend=True`, so that data passed to `send()` is first
-        encoded and then packed into a frame. Of course, one could also decide
-        to add the base64 hook first, or to return a new `Frame` instance with
-        base64-encoded data.
-        """
-        if send:
-            self.hooks_send.insert(0 if prepend else -1, send)
-
-        if recv:
-            self.hooks_recv.insert(-1 if prepend else 0, recv)
-
     def onopen(self):
         """
         Called after the connection is initialized.

+ 14 - 12
deflate_frame.py

@@ -24,7 +24,7 @@ class DeflateFrame(Extension):
         'no_context_takeover': False
     }
 
-    compression_threshold = 64  # minimal payload size for compression
+    compression_threshold = 20  # minimal payload size for compression
 
     def negotiate(self, name, params):
         if 'max_window_bits' in params:
@@ -43,13 +43,16 @@ class DeflateFrame(Extension):
                         zlib.DEFLATED, -self.max_window_bits)
                 self.dec = zlib.decompressobj(-self.max_window_bits)
 
-        def onsend_frame(self, frame):
+        def onsend(self, frame):
             if not frame.rsv1 and not isinstance(frame, ControlFrame) and \
                    len(frame.payload) > self.extension.compression_threshold:
-                frame.rsv1 = True
-                frame.payload = self.deflate(frame)
+                deflated = self.deflate(frame.payload)
 
-        def onrecv_frame(self, frame):
+                if len(deflated) < len(frame.payload):
+                    frame.rsv1 = True
+                    frame.payload = deflated
+
+        def onrecv(self, frame):
             if frame.rsv1:
                 if isinstance(frame, ControlFrame):
                     raise ValueError('received compressed control frame')
@@ -57,13 +60,13 @@ class DeflateFrame(Extension):
                 frame.rsv1 = False
                 frame.payload = self.inflate(frame.payload)
 
-        def deflate(self, frame):
+        def deflate(self, data):
             if self.no_context_takeover:
-                compressed = zlib.compress(frame.payload)
-            else:
-                compressed = self.defl.compress(frame.payload)
-                compressed += self.defl.flush(zlib.Z_SYNC_FLUSH)
+                self.defl = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
+                        zlib.DEFLATED, -self.max_window_bits)
 
+            compressed = self.defl.compress(data)
+            compressed += self.defl.flush(zlib.Z_SYNC_FLUSH)
             assert compressed[-4:] == '\x00\x00\xff\xff'
             return compressed[:-4]
 
@@ -71,7 +74,6 @@ class DeflateFrame(Extension):
             data = str(data + '\x00\x00\xff\xff')
 
             if self.no_context_takeover:
-                dec = zlib.decompressobj(-self.max_window_bits)
-                return dec.decompress(data) + dec.flush()
+                self.dec = zlib.decompressobj(-self.max_window_bits)
 
             return self.dec.decompress(data)

+ 89 - 0
deflate_message.py

@@ -0,0 +1,89 @@
+import zlib
+
+from extension import Extension
+from deflate_frame import DeflateFrame
+
+
+class DeflateMessage(Extension):
+    """
+    Implementation of the "permessage-deflate" extension, as defined by
+    http://tools.ietf.org/html/draft-ietf-hybi-permessage-compression-17.
+
+    Note: this implementetion is only eligible for server sockets, client
+    sockets must NOT use it.
+    """
+    name = 'permessage-deflate'
+    rsv1 = True
+    defaults = {
+        'client_max_window_bits': zlib.MAX_WBITS,
+        'client_no_context_takeover': False,
+        'server_max_window_bits': zlib.MAX_WBITS,
+        'server_no_context_takeover': False
+    }
+    before_fragmentation = True
+
+    compression_threshold = 20  # minimal message payload size for compression
+
+    def negotiate(self, name, params):
+        default = self.defaults['client_max_window_bits']
+
+        if 'client_max_window_bits' in params:
+            mwb = params['client_max_window_bits']
+
+            if mwb is True:
+                if default != zlib.MAX_WBITS:
+                    yield 'client_max_window_bits', default
+            else:
+                mwb = int(mwb)
+                assert 8 <= mwb <= zlib.MAX_WBITS
+                yield 'client_max_window_bits', min(mwb, default)
+        elif default != zlib.MAX_WBITS:
+            yield 'client_max_window_bits', default
+
+        if 'client_no_context_takeover' in params:
+            assert params['client_no_context_takeover'] is True
+            yield 'client_no_context_takeover', True
+        elif self.defaults['client_no_context_takeover']:
+            yield 'client_no_context_takeover', True
+
+        default = self.defaults['server_max_window_bits']
+
+        if 'server_max_window_bits' in params:
+            mwb = int(params['server_max_window_bits'])
+            assert 8 <= mwb <= zlib.MAX_WBITS
+            yield 'server_max_window_bits', min(mwb, default)
+        elif default != zlib.MAX_WBITS:
+            yield 'server_max_window_bits', default
+
+        if 'server_no_context_takeover' in params:
+            assert params['server_no_context_takeover'] is True
+            yield 'server_no_context_takeover', True
+        elif self.defaults['server_no_context_takeover']:
+            yield 'server_no_context_takeover', True
+
+    class Instance(DeflateFrame.Instance):
+        def init(self):
+            if not self.server_no_context_takeover:
+                self.defl = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
+                        zlib.DEFLATED, -self.server_max_window_bits)
+
+            if not self.client_no_context_takeover:
+                self.dec = zlib.decompressobj(-self.client_max_window_bits)
+
+        def deflate(self, data):
+            if self.server_no_context_takeover:
+                self.defl = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
+                        zlib.DEFLATED, -self.server_max_window_bits)
+
+            compressed = self.defl.compress(data)
+            compressed += self.defl.flush(zlib.Z_SYNC_FLUSH)
+            assert compressed[-4:] == '\x00\x00\xff\xff'
+            return compressed[:-4]
+
+        def inflate(self, data):
+            data = str(data + '\x00\x00\xff\xff')
+
+            if self.client_no_context_takeover:
+                self.dec = zlib.decompressobj(-self.client_max_window_bits)
+
+            return self.dec.decompress(data)

+ 21 - 8
extension.py

@@ -4,6 +4,7 @@ class Extension(object):
     rsv2 = False
     rsv3 = False
     opcodes = ()
+    before_fragmentation = False
     defaults = {}
 
     def __init__(self, **kwargs):
@@ -23,6 +24,10 @@ class Extension(object):
     def names(self):
         return (self.name,) if self.name else ()
 
+    def is_supported(self, name, other_instances):
+        return name in self.names and not any(self.conflicts(other.extension)
+                                              for other in other_instances)
+
     def conflicts(self, ext):
         """
         Check if the extension conflicts with an already accepted extension.
@@ -76,14 +81,22 @@ class Extension(object):
         def init(self):
             return NotImplemented
 
-        def onsend_frame(self, frame):
-            pass
+        def handle_send(self, frame):
+            if self.extension.before_fragmentation:
+                assert not frame.is_fragmented()
 
-        def onrecv_frame(self, frame):
-            pass
+            replacement = self.onsend(frame)
+            return frame if replacement is None else replacement
 
-        def onsend_message(self, message):
-            pass
+        def handle_recv(self, frame):
+            if self.extension.before_fragmentation:
+                assert not frame.is_fragmented()
 
-        def onrecv_message(self, message):
-            pass
+            replacement = self.onrecv(frame)
+            return frame if replacement is None else replacement
+
+        def onsend(self, frame):
+            raise NotImplementedError
+
+        def onrecv(self, frame):
+            raise NotImplementedError

+ 3 - 0
frame.py

@@ -143,6 +143,9 @@ class Frame(object):
 
         return frames
 
+    def is_fragmented(self):
+        return not self.final or self.opcode == OPCODE_CONTINUATION
+
     def __str__(self):
         s = '<%s opcode=0x%X len=%d' \
             % (self.__class__.__name__, self.opcode, len(self.payload))

+ 2 - 3
handshake.py

@@ -177,8 +177,7 @@ class ServerHandshake(Handshake):
                 name, params = parse_param_hdr(hdr)
 
                 for ext in ssock.extensions:
-                    if not any(ext.conflicts(other.extension)
-                               for other in self.wsock.extension_instances):
+                    if ext.is_supported(name, self.wsock.extension_instances):
                         accept_params = ext.negotiate_safe(name, params)
 
                         if accept_params is not None:
@@ -432,4 +431,4 @@ def format_param_hdr(value, params):
             return k + '=' + str(v)
 
     strparams = filter(None, map(fmt_param, params.items()))
-    return '%s; %s' % (value, ', '.join(strparams))
+    return '; '.join([value] + strparams)

+ 10 - 14
websocket.py

@@ -131,21 +131,17 @@ class websocket(object):
         ClientHandshake(self).perform()
         self.handshake_sent = True
 
-    def apply_send_hooks(self, frame):
+    def apply_send_hooks(self, frame, before_fragmentation):
         for inst in self.extension_instances:
-            replacement = inst.onsend_frame(frame)
-
-            if replacement is not None:
-                frame = replacement
+            if inst.extension.before_fragmentation == before_fragmentation:
+                frame = inst.handle_send(frame)
 
         return frame
 
-    def apply_recv_hooks(self, frame):
+    def apply_recv_hooks(self, frame, before_fragmentation):
         for inst in reversed(self.extension_instances):
-            replacement = inst.onrecv_frame(frame)
-
-            if replacement is not None:
-                frame = replacement
+            if inst.extension.before_fragmentation == before_fragmentation:
+                frame = inst.handle_recv(frame)
 
         return frame
 
@@ -154,14 +150,14 @@ class websocket(object):
         Send a number of frames.
         """
         for frame in args:
-            self.sock.sendall(self.apply_send_hooks(frame).pack())
+            self.sock.sendall(self.apply_send_hooks(frame, False).pack())
 
     def recv(self):
         """
         Receive a single frames. This can be either a data frame or a control
         frame.
         """
-        return self.apply_recv_hooks(receive_frame(self.sock))
+        return self.apply_recv_hooks(receive_frame(self.sock), False)
 
     def recvn(self, n):
         """
@@ -177,7 +173,7 @@ class websocket(object):
         frame has been fully written. `recv_callback` is an optional callable
         to quickly set the `recv_callback` attribute to.
         """
-        frame = self.apply_send_hooks(frame)
+        frame = self.apply_send_hooks(frame, False)
         self.sendbuf += frame.pack()
         self.sendbuf_frames.append([frame, len(self.sendbuf), callback])
 
@@ -222,7 +218,7 @@ class websocket(object):
 
         while contains_frame(self.recvbuf):
             frame, self.recvbuf = pop_frame(self.recvbuf)
-            frame = self.apply_recv_hooks(frame)
+            frame = self.apply_recv_hooks(frame, False)
 
             if not self.recv_callback:
                 raise ValueError('no callback installed for %s' % frame)