Jelajahi Sumber

Merge branch 'async'

Taddeus Kroes 11 tahun lalu
induk
melakukan
6c79550e13
13 mengubah file dengan 592 tambahan dan 224 penghapusan
  1. 1 0
      README.md
  2. 3 2
      __init__.py
  3. 190 0
      async.py
  4. 52 47
      connection.py
  5. 27 32
      deflate_frame.py
  6. 8 4
      errors.py
  7. 23 24
      extension.py
  8. 92 22
      frame.py
  9. 16 17
      handshake.py
  10. 24 19
      server.py
  11. 10 2
      test/client.py
  12. 45 0
      test/talk.py
  13. 101 55
      websocket.py

+ 1 - 0
README.md

@@ -22,6 +22,7 @@ Her is a quick overview of the features in this library:
 - Secure sockets using SSL certificates (for 'wss://...' URLs).
 - Secure sockets using SSL certificates (for 'wss://...' URLs).
 - The possibility to add extensions to the web socket protocol. An included
 - The possibility to add extensions to the web socket protocol. An included
   implementation is [deflate-frame](http://tools.ietf.org/html/draft-tyoshino-hybi-websocket-perframe-deflate-06).
   implementation is [deflate-frame](http://tools.ietf.org/html/draft-tyoshino-hybi-websocket-perframe-deflate-06).
+- Asynchronous sockets with an EPOLL-based server.
 
 
 
 
 Installation
 Installation

+ 3 - 2
__init__.py

@@ -4,10 +4,11 @@ from frame import Frame, ControlFrame, OPCODE_CONTINUATION, OPCODE_TEXT, \
         OPCODE_BINARY, OPCODE_CLOSE, OPCODE_PING, OPCODE_PONG, CLOSE_NORMAL, \
         OPCODE_BINARY, OPCODE_CLOSE, OPCODE_PING, OPCODE_PONG, CLOSE_NORMAL, \
         CLOSE_GOING_AWAY, CLOSE_PROTOCOL_ERROR, CLOSE_NOACCEPT_DTYPE, \
         CLOSE_GOING_AWAY, CLOSE_PROTOCOL_ERROR, CLOSE_NOACCEPT_DTYPE, \
         CLOSE_INVALID_DATA, CLOSE_POLICY, CLOSE_MESSAGE_TOOBIG, \
         CLOSE_INVALID_DATA, CLOSE_POLICY, CLOSE_MESSAGE_TOOBIG, \
-        CLOSE_MISSING_EXTENSIONS, CLOSE_UNABLE
+        CLOSE_MISSING_EXTENSIONS, CLOSE_UNABLE, read_frame, pop_frame, \
+        contains_frame
 from connection import Connection
 from connection import Connection
 from message import Message, TextMessage, BinaryMessage
 from message import Message, TextMessage, BinaryMessage
 from errors import SocketClosed, HandshakeError, PingError, SSLError
 from errors import SocketClosed, HandshakeError, PingError, SSLError
 from extension import Extension
 from extension import Extension
 from deflate_frame import DeflateFrame, WebkitDeflateFrame
 from deflate_frame import DeflateFrame, WebkitDeflateFrame
-#from multiplex import Multiplex
+from async import AsyncConnection, AsyncServer

+ 190 - 0
async.py

@@ -0,0 +1,190 @@
+import socket
+from select import epoll, EPOLLIN, EPOLLOUT, EPOLLHUP
+from traceback import format_exc
+import logging
+
+from connection import Connection
+from frame import ControlFrame, OPCODE_PING, OPCODE_CONTINUATION, \
+                  create_close_frame
+from server import Server, Client
+from errors import HandshakeError, SocketClosed
+
+
+class AsyncConnection(Connection):
+    def __init__(self, sock):
+        sock.recv_callback = self.contruct_message
+        sock.recv_close_callback = self.onclose
+        self.recvbuf = []
+        Connection.__init__(self, sock)
+
+    def contruct_message(self, frame):
+        if isinstance(frame, ControlFrame):
+            self.handle_control_frame(frame)
+            return
+
+        self.recvbuf.append(frame)
+
+        if frame.final:
+            message = self.concat_fragments(self.recvbuf)
+            self.recvbuf = []
+            self.onmessage(message)
+        elif len(self.recvbuf) > 1 and frame.opcode != OPCODE_CONTINUATION:
+            raise ValueError('expected continuation/control frame, got %s '
+                             'instead' % frame)
+
+    def send(self, message, fragment_size=None, mask=False):
+        frames = list(self.message_to_frames(message, fragment_size, mask))
+
+        for frame in frames[:-1]:
+            self.sock.queue_send(frame)
+
+        self.sock.queue_send(frames[-1], lambda: self.onsend(message))
+
+    def send_frame(self, frame, callback):
+        self.sock.queue_send(frame, callback)
+
+    def do_async_send(self):
+        self.execute_controlled(self.sock.do_async_send)
+
+    def do_async_recv(self, bufsize):
+        self.execute_controlled(self.sock.do_async_recv, bufsize)
+
+    def execute_controlled(self, func, *args, **kwargs):
+        try:
+            func(*args, **kwargs)
+        except (KeyboardInterrupt, SystemExit, SocketClosed):
+            raise
+        except Exception as e:
+            self.onerror(e)
+            self.onclose(None, 'error: %s' % e)
+
+            try:
+                self.sock.close()
+            except socket.error:
+                pass
+
+            raise e
+
+    def send_close_frame(self, code, reason):
+        self.sock.queue_send(create_close_frame(code, reason),
+                             self.shutdown_write)
+        self.close_frame_sent = True
+
+    def close(self, code=None, reason=''):
+        self.send_close_frame(code, reason)
+
+    def send_ping(self, payload=''):
+        self.sock.queue_send(ControlFrame(OPCODE_PING, payload),
+                             lambda: self.onping(payload))
+        self.ping_payload = payload
+        self.ping_sent = True
+
+    def onsend(self, message):
+        """
+        Called after a message has been written.
+        """
+        return NotImplemented
+
+
+class AsyncServer(Server):
+    def __init__(self, *args, **kwargs):
+        Server.__init__(self, *args, **kwargs)
+
+        self.recvbuf_size = kwargs.get('recvbuf_size', 2048)
+
+        self.epoll = epoll()
+        self.epoll.register(self.sock.fileno(), EPOLLIN)
+        self.conns = {}
+
+    @property
+    def clients(self):
+        return self.conns.values()
+
+    def remove_client(self, client, code, reason):
+        self.epoll.unregister(client.fno)
+        del self.conns[client.fno]
+        self.onclose(client, code, reason)
+
+    def handle_events(self):
+        for fileno, event in self.epoll.poll(1):
+            if fileno == self.sock.fileno():
+                try:
+                    sock, addr = self.sock.accept()
+                except HandshakeError as e:
+                    logging.error('Invalid request: %s', e.message)
+                    continue
+
+                client = AsyncClient(self, sock)
+                client.fno = sock.fileno()
+                sock.setblocking(0)
+                self.epoll.register(client.fno, EPOLLIN)
+                self.conns[client.fno] = client
+                logging.debug('Registered client %s', client)
+
+            elif event & EPOLLHUP:
+                self.epoll.unregister(fileno)
+                del self.conns[fileno]
+
+            else:
+                conn = self.conns[fileno]
+
+                try:
+                    if event & EPOLLOUT:
+                        conn.do_async_send()
+                    elif event & EPOLLIN:
+                        conn.do_async_recv(self.recvbuf_size)
+                except (KeyboardInterrupt, SystemExit):
+                    raise
+                except SocketClosed:
+                    continue
+                except Exception as e:
+                    logging.error(format_exc(e).rstrip())
+                    continue
+
+                self.update_mask(conn)
+
+    def run(self):
+        try:
+            while True:
+                self.handle_events()
+        except (KeyboardInterrupt, SystemExit):
+            logging.info('Received interrupt, stopping server...')
+        finally:
+            self.epoll.unregister(self.sock.fileno())
+            self.epoll.close()
+            self.sock.close()
+
+    def update_mask(self, conn):
+        mask = 0
+
+        if conn.sock.can_send():
+            mask |= EPOLLOUT
+
+        if conn.sock.can_recv():
+            mask |= EPOLLIN
+
+        self.epoll.modify(conn.sock.fileno(), mask)
+
+    def onsend(self, client, message):
+        return NotImplemented
+
+
+class AsyncClient(Client, AsyncConnection):
+    def __init__(self, server, sock):
+        self.server = server
+        AsyncConnection.__init__(self, sock)
+
+    def send(self, message, fragment_size=None, mask=False):
+        logging.debug('Enqueueing %s to %s', message, self)
+        AsyncConnection.send(self, message, fragment_size, mask)
+        self.server.update_mask(self)
+
+    def onsend(self, message):
+        logging.debug('Finished sending %s to %s', message, self)
+        self.server.onsend(self, message)
+
+
+if __name__ == '__main__':
+    import sys
+    port = int(sys.argv[1]) if len(sys.argv) > 1 else 8000
+    AsyncServer(('', port), loglevel=logging.DEBUG).run()

+ 52 - 47
connection.py

@@ -1,8 +1,7 @@
-import struct
 import socket
 import socket
 
 
 from frame import ControlFrame, OPCODE_CLOSE, OPCODE_PING, OPCODE_PONG, \
 from frame import ControlFrame, OPCODE_CLOSE, OPCODE_PING, OPCODE_PONG, \
-                  OPCODE_CONTINUATION
+                  OPCODE_CONTINUATION, create_close_frame
 from message import create_message
 from message import create_message
 from errors import SocketClosed, PingError
 from errors import SocketClosed, PingError
 
 
@@ -54,19 +53,30 @@ class Connection(object):
 
 
         self.onopen()
         self.onopen()
 
 
+    def message_to_frames(self, message, fragment_size=None, mask=False):
+        for hook in self.hooks_send:
+            message = hook(message)
+
+        if fragment_size is None:
+            yield message.frame(mask=mask)
+        else:
+            for frame in message.fragment(fragment_size, mask=mask):
+                yield frame
+
     def send(self, message, fragment_size=None, mask=False):
     def send(self, message, fragment_size=None, mask=False):
         """
         """
         Send a message. If `fragment_size` is specified, the message is
         Send a message. If `fragment_size` is specified, the message is
         fragmented into multiple frames whose payload size does not extend
         fragmented into multiple frames whose payload size does not extend
         `fragment_size`.
         `fragment_size`.
         """
         """
-        for hook in self.hooks_send:
-            message = hook(message)
+        for frame in self.message_to_frames(message, fragment_size, mask):
+            self.send_frame(frame)
 
 
-        if fragment_size is None:
-            self.sock.send(message.frame(mask=mask))
-        else:
-            self.sock.send(*message.fragment(fragment_size, mask=mask))
+    def send_frame(self, frame, callback=None):
+        self.sock.send(frame)
+
+        if callback:
+            callback()
 
 
     def recv(self):
     def recv(self):
         """
         """
@@ -82,12 +92,15 @@ class Connection(object):
 
 
             if isinstance(frame, ControlFrame):
             if isinstance(frame, ControlFrame):
                 self.handle_control_frame(frame)
                 self.handle_control_frame(frame)
-            elif len(fragments) and frame.opcode != OPCODE_CONTINUATION:
+            elif len(fragments) > 0 and frame.opcode != OPCODE_CONTINUATION:
                 raise ValueError('expected continuation/control frame, got %s '
                 raise ValueError('expected continuation/control frame, got %s '
                                  'instead' % frame)
                                  'instead' % frame)
             else:
             else:
                 fragments.append(frame)
                 fragments.append(frame)
 
 
+        return self.concat_fragments(fragments)
+
+    def concat_fragments(self, fragments):
         payload = bytearray()
         payload = bytearray()
 
 
         for f in fragments:
         for f in fragments:
@@ -105,16 +118,20 @@ class Connection(object):
         Handle a control frame as defined by RFC 6455.
         Handle a control frame as defined by RFC 6455.
         """
         """
         if frame.opcode == OPCODE_CLOSE:
         if frame.opcode == OPCODE_CLOSE:
-            # Close the connection from this end as well
             self.close_frame_received = True
             self.close_frame_received = True
             code, reason = frame.unpack_close()
             code, reason = frame.unpack_close()
 
 
-            # No more receiving data after a close message
-            raise SocketClosed(code, reason)
+            if self.close_frame_sent:
+                self.onclose(code, reason)
+                self.sock.close()
+                raise SocketClosed(True)
+            else:
+                self.close_params = (code, reason)
+                self.send_close_frame(code, reason)
 
 
         elif frame.opcode == OPCODE_PING:
         elif frame.opcode == OPCODE_PING:
             # Respond with a pong message with identical payload
             # Respond with a pong message with identical payload
-            self.sock.send(ControlFrame(OPCODE_PONG, frame.payload))
+            self.send_frame(ControlFrame(OPCODE_PONG, frame.payload))
 
 
         elif frame.opcode == OPCODE_PONG:
         elif frame.opcode == OPCODE_PONG:
             # Assert that the PONG payload is identical to that of the PING
             # Assert that the PONG payload is identical to that of the PING
@@ -138,38 +155,40 @@ class Connection(object):
         while True:
         while True:
             try:
             try:
                 self.onmessage(self.recv())
                 self.onmessage(self.recv())
-            except SocketClosed as e:
-                self.close(e.code, e.reason)
+            except (KeyboardInterrupt, SystemExit, SocketClosed):
                 break
                 break
-            except socket.error as e:
+            except Exception as e:
                 self.onerror(e)
                 self.onerror(e)
+                self.onclose(None, 'error: %s' % e)
 
 
                 try:
                 try:
                     self.sock.close()
                     self.sock.close()
                 except socket.error:
                 except socket.error:
                     pass
                     pass
 
 
-                self.onclose(None, '')
-                break
-            except Exception as e:
-                self.onerror(e)
+                raise e
 
 
     def send_ping(self, payload=''):
     def send_ping(self, payload=''):
         """
         """
         Send a PING control frame with an optional payload.
         Send a PING control frame with an optional payload.
         """
         """
-        self.sock.send(ControlFrame(OPCODE_PING, payload))
+        self.send_frame(ControlFrame(OPCODE_PING, payload),
+                        lambda: self.onping(payload))
         self.ping_payload = payload
         self.ping_payload = payload
         self.ping_sent = True
         self.ping_sent = True
-        self.onping(payload)
 
 
-    def send_close_frame(self, code=None, reason=''):
-        """
-        Send a CLOSE control frame.
-        """
-        payload = '' if code is None else struct.pack('!H', code) + reason
-        self.sock.send(ControlFrame(OPCODE_CLOSE, payload))
+    def send_close_frame(self, code, reason):
+        self.send_frame(create_close_frame(code, reason))
         self.close_frame_sent = True
         self.close_frame_sent = True
+        self.shutdown_write()
+
+    def shutdown_write(self):
+        if self.close_frame_received:
+            self.onclose(*self.close_params)
+            self.sock.close()
+            raise SocketClosed(False)
+        else:
+            self.sock.shutdown(socket.SHUT_WR)
 
 
     def close(self, code=None, reason=''):
     def close(self, code=None, reason=''):
         """
         """
@@ -179,28 +198,14 @@ class Connection(object):
         called after the response has been received, but before the socket is
         called after the response has been received, but before the socket is
         actually closed.
         actually closed.
         """
         """
-        # Send CLOSE frame
-        if not self.close_frame_sent:
-            self.send_close_frame(code, reason)
-
-        # Receive CLOSE frame
-        if not self.close_frame_received:
-            frame = self.sock.recv()
-
-            if frame.opcode != OPCODE_CLOSE:
-                raise ValueError('expected CLOSE frame, got %s' % frame)
-
-            self.close_frame_received = True
-            res_code, res_reason = frame.unpack_close()
+        self.send_close_frame(code, reason)
 
 
-            # FIXME: check if res_code == code and res_reason == reason?
+        frame = self.sock.recv()
 
 
-            # FIXME: alternatively, keep receiving frames in a loop until a
-            # CLOSE frame is received, so that a fragmented chain may arrive
-            # fully first
+        if frame.opcode != OPCODE_CLOSE:
+            raise ValueError('expected CLOSE frame, got %s' % frame)
 
 
-        self.onclose(code, reason)
-        self.sock.close()
+        self.handle_control_frame(frame)
 
 
     def add_hook(self, send=None, recv=None, prepend=False):
     def add_hook(self, send=None, recv=None, prepend=False):
         """
         """

+ 27 - 32
deflate_frame.py

@@ -20,38 +20,33 @@ class DeflateFrame(Extension):
 
 
     name = 'deflate-frame'
     name = 'deflate-frame'
     rsv1 = True
     rsv1 = True
-    defaults = {'max_window_bits': 15, 'no_context_takeover': False}
+    defaults = {'max_window_bits': zlib.MAX_WBITS, 'no_context_takeover': False}
 
 
-    def __init__(self, defaults={}, request={}):
-        Extension.__init__(self, defaults, request)
+    COMPRESSION_THRESHOLD = 64  # minimal payload size for compression
 
 
+    def init(self):
         mwb = self.defaults['max_window_bits']
         mwb = self.defaults['max_window_bits']
         cto = self.defaults['no_context_takeover']
         cto = self.defaults['no_context_takeover']
 
 
-        if not isinstance(mwb, int):
-            raise ValueError('"max_window_bits" must be an integer')
-        elif mwb > 15:
-            raise ValueError('"max_window_bits" may not be larger than 15')
+        if not isinstance(mwb, int) or mwb < 1 or mwb > zlib.MAX_WBITS:
+            raise ValueError('"max_window_bits" must be in range 1-15')
 
 
         if cto is not False and cto is not True:
         if cto is not False and cto is not True:
             raise ValueError('"no_context_takeover" must have no value')
             raise ValueError('"no_context_takeover" must have no value')
 
 
     class Hook(Extension.Hook):
     class Hook(Extension.Hook):
-        def __init__(self, extension, **kwargs):
-            Extension.Hook.__init__(self, extension, **kwargs)
-
-            if not self.no_context_takeover:
-                self.defl = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
-                                            zlib.DEFLATED,
-                                            -self.max_window_bits)
-
-            other_wbits = self.extension.request.get('max_window_bits', 15)
+        def init(self, extension):
+            self.defl = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
+                                         zlib.DEFLATED, -self.max_window_bits)
+            other_wbits = extension.request.get('max_window_bits', zlib.MAX_WBITS)
             self.dec = zlib.decompressobj(-other_wbits)
             self.dec = zlib.decompressobj(-other_wbits)
 
 
         def send(self, frame):
         def send(self, frame):
-            if not frame.rsv1 and not isinstance(frame, ControlFrame):
+            # FIXME: this does not seem to work properly on Android
+            if not frame.rsv1 and not isinstance(frame, ControlFrame) and \
+                   len(frame.payload) > DeflateFrame.COMPRESSION_THRESHOLD:
                 frame.rsv1 = True
                 frame.rsv1 = True
-                frame.payload = self.deflate(frame.payload)
+                frame.payload = self.deflate(frame)
 
 
             return frame
             return frame
 
 
@@ -65,23 +60,23 @@ class DeflateFrame(Extension):
 
 
             return frame
             return frame
 
 
-        def deflate(self, data):
-            if self.no_context_takeover:
-                defl = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
-                                        zlib.DEFLATED, -self.max_window_bits)
-                # FIXME: why the '\x00' below? This was borrowed from
-                # https://github.com/fancycode/tornado/blob/bc317b6dcf63608ff004ff1f57073be0504b6550/tornado/websocket.py#L91
-                return defl.compress(data) + defl.flush(zlib.Z_FINISH) + '\x00'
+        def deflate(self, frame):
+            compressed = self.defl.compress(frame.payload)
+
+            if frame.final or self.no_context_takeover:
+                compressed += self.defl.flush(zlib.Z_FINISH) + '\x00'
+                self.defl = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
+                        zlib.DEFLATED, -self.max_window_bits)
+            else:
+                compressed += self.defl.flush(zlib.Z_SYNC_FLUSH)
+                assert compressed[-4:] == '\x00\x00\xff\xff'
+                compressed = compressed[:-4]
 
 
-            compressed = self.defl.compress(data)
-            compressed += self.defl.flush(zlib.Z_SYNC_FLUSH)
-            assert compressed[-4:] == '\x00\x00\xff\xff'
-            return compressed[:-4]
+            return compressed
 
 
         def inflate(self, data):
         def inflate(self, data):
-            data = self.dec.decompress(str(data + '\x00\x00\xff\xff'))
-            assert not self.dec.unused_data
-            return data
+            return self.dec.decompress(data + '\x00\x00\xff\xff') + \
+                   self.dec.flush(zlib.Z_SYNC_FLUSH)
 
 
 
 
 class WebkitDeflateFrame(DeflateFrame):
 class WebkitDeflateFrame(DeflateFrame):

+ 8 - 4
errors.py

@@ -1,11 +1,15 @@
 class SocketClosed(Exception):
 class SocketClosed(Exception):
-    def __init__(self, code=None, reason=''):
-        self.code = code
-        self.reason = reason
+    def __init__(self, initialized):
+        self.initialized = initialized
 
 
     @property
     @property
     def message(self):
     def message(self):
-        return ('' if self.code is None else '[%d] ' % self.code) + self.reason
+        s = 'socket closed'
+
+        if self.initialized:
+            s += ' (initialized)'
+
+        return s
 
 
 
 
 class HandshakeError(Exception):
 class HandshakeError(Exception):

+ 23 - 24
extension.py

@@ -19,23 +19,31 @@ class Extension(object):
         self.request = dict(self.__class__.request)
         self.request = dict(self.__class__.request)
         self.request.update(request)
         self.request.update(request)
 
 
+        self.init()
+
     def __str__(self):
     def __str__(self):
         return '<Extension "%s" defaults=%s request=%s>' \
         return '<Extension "%s" defaults=%s request=%s>' \
                % (self.name, self.defaults, self.request)
                % (self.name, self.defaults, self.request)
 
 
+    def init(self):
+        return NotImplemented
+
     def create_hook(self, **kwargs):
     def create_hook(self, **kwargs):
         params = {}
         params = {}
         params.update(self.defaults)
         params.update(self.defaults)
         params.update(kwargs)
         params.update(kwargs)
-        return self.Hook(self, **params)
+        hook = self.Hook(**params)
+        hook.init(self)
+        return hook
 
 
     class Hook:
     class Hook:
-        def __init__(self, extension, **kwargs):
-            self.extension = extension
-
+        def __init__(self, **kwargs):
             for param, value in kwargs.iteritems():
             for param, value in kwargs.iteritems():
                 setattr(self, param, value)
                 setattr(self, param, value)
 
 
+        def init(self, extension):
+            return NotImplemented
+
         def send(self, frame):
         def send(self, frame):
             return frame
             return frame
 
 
@@ -43,28 +51,19 @@ class Extension(object):
             return frame
             return frame
 
 
 
 
-def filter_extensions(extensions):
-    """
-    Remove extensions that use conflicting rsv bits and/or opcodes, with the
-    first options being the most preferable.
-    """
+def extension_conflicts(ext, existing):
     rsv1_reserved = False
     rsv1_reserved = False
     rsv2_reserved = False
     rsv2_reserved = False
     rsv3_reserved = False
     rsv3_reserved = False
-    opcodes_reserved = []
-    compat = []
-
-    for ext in extensions:
-        if ext.rsv1 and rsv1_reserved \
-                or ext.rsv2 and rsv2_reserved \
-                or ext.rsv3 and rsv3_reserved \
-                or len(set(ext.opcodes) & set(opcodes_reserved)):
-            continue
+    reserved_opcodes = []
 
 
-        rsv1_reserved |= ext.rsv1
-        rsv2_reserved |= ext.rsv2
-        rsv3_reserved |= ext.rsv3
-        opcodes_reserved.extend(ext.opcodes)
-        compat.append(ext)
+    for e in existing:
+        rsv1_reserved |= e.rsv1
+        rsv2_reserved |= e.rsv2
+        rsv3_reserved |= e.rsv3
+        reserved_opcodes.extend(e.opcodes)
 
 
-    return compat
+    return ext.rsv1 and rsv1_reserved \
+            or ext.rsv2 and rsv2_reserved \
+            or ext.rsv3 and rsv3_reserved \
+            or len(set(ext.opcodes) & set(reserved_opcodes))

+ 92 - 22
frame.py

@@ -21,9 +21,11 @@ CLOSE_MESSAGE_TOOBIG = 1009
 CLOSE_MISSING_EXTENSIONS = 1010
 CLOSE_MISSING_EXTENSIONS = 1010
 CLOSE_UNABLE = 1011
 CLOSE_UNABLE = 1011
 
 
+line_printable = [c for c in printable if c not in '\r\n\x0b\x0c']
+
 
 
 def printstr(s):
 def printstr(s):
-    return ''.join(c if c in printable else '.' for c in s)
+    return ''.join(c if c in line_printable else '.' for c in str(s))
 
 
 
 
 class Frame(object):
 class Frame(object):
@@ -154,7 +156,18 @@ class Frame(object):
         if len(self.payload) > max_pl_disp:
         if len(self.payload) > max_pl_disp:
             pl += '...'
             pl += '...'
 
 
-        return s + ' payload=%s>' % pl
+        s += ' payload=%s' % pl
+
+        if self.rsv1:
+            s += ' rsv1'
+
+        if self.rsv2:
+            s += ' rsv2'
+
+        if self.rsv3:
+            s += ' rsv3'
+
+        return s + '>'
 
 
 
 
 class ControlFrame(Frame):
 class ControlFrame(Frame):
@@ -194,12 +207,8 @@ class ControlFrame(Frame):
         return code, reason
         return code, reason
 
 
 
 
-def receive_frame(sock):
-    """
-    Receive a single frame on socket `sock`. The frame scheme is explained in
-    the docs of Frame.pack().
-    """
-    b1, b2 = struct.unpack('!BB', recvn(sock, 2))
+def decode_frame(reader):
+    b1, b2 = struct.unpack('!BB', reader.readn(2))
 
 
     final = bool(b1 & 0x80)
     final = bool(b1 & 0x80)
     rsv1 = bool(b1 & 0x40)
     rsv1 = bool(b1 & 0x40)
@@ -211,16 +220,16 @@ def receive_frame(sock):
     payload_len = b2 & 0x7F
     payload_len = b2 & 0x7F
 
 
     if payload_len == 126:
     if payload_len == 126:
-        payload_len = struct.unpack('!H', recvn(sock, 2))
+        payload_len = struct.unpack('!H', reader.readn(2))
     elif payload_len == 127:
     elif payload_len == 127:
-        payload_len = struct.unpack('!Q', recvn(sock, 8))
+        payload_len = struct.unpack('!Q', reader.readn(8))
 
 
     if masked:
     if masked:
-        masking_key = recvn(sock, 4)
-        payload = mask(masking_key, recvn(sock, payload_len))
+        masking_key = reader.readn(4)
+        payload = mask(masking_key, reader.readn(payload_len))
     else:
     else:
         masking_key = ''
         masking_key = ''
-        payload = recvn(sock, payload_len)
+        payload = reader.readn(payload_len)
 
 
     # Control frames have most significant bit 1
     # Control frames have most significant bit 1
     cls = ControlFrame if opcode & 0x8 else Frame
     cls = ControlFrame if opcode & 0x8 else Frame
@@ -229,21 +238,77 @@ def receive_frame(sock):
                rsv1=rsv1, rsv2=rsv2, rsv3=rsv3)
                rsv1=rsv1, rsv2=rsv2, rsv3=rsv3)
 
 
 
 
-def recvn(sock, n):
+def receive_frame(sock):
+    return decode_frame(SocketReader(sock))
+
+
+def read_frame(data):
+    reader = BufferReader(data)
+    frame = decode_frame(reader)
+    return frame, reader.offset
+
+
+def pop_frame(data):
+    frame, size = read_frame(data)
+    return frame, data[size:]
+
+
+class BufferReader(object):
+    def __init__(self, data):
+        self.data = data
+        self.offset = 0
+
+    def readn(self, n):
+        assert len(self.data) - self.offset >= n
+        self.offset += n
+        return self.data[self.offset - n:self.offset]
+
+
+class SocketReader(object):
+    def __init__(self, sock):
+        self.sock = sock
+
+    def readn(self, n):
+        """
+        Keep receiving data until exactly `n` bytes have been read.
+        """
+        data = ''
+
+        while len(data) < n:
+            received = self.sock.recv(n - len(data))
+
+            if not len(received):
+                raise socket.error('no data read from socket')
+
+            data += received
+
+        return data
+
+
+def contains_frame(data):
     """
     """
-    Keep receiving data from `sock` until exactly `n` bytes have been read.
+    Read the frame length from the start of `data` and check if the data is
+    long enough to contain the entire frame.
     """
     """
-    data = ''
+    if len(data) < 2:
+        return False
+
+    b2 = struct.unpack('!B', data[1])[0]
+    payload_len = b2 & 0x7F
+    payload_start = 2
 
 
-    while len(data) < n:
-        received = sock.recv(n - len(data))
+    if payload_len == 126:
+        if len(data) > 4:
+            payload_len = struct.unpack('!H', data[2:4])
 
 
-        if not len(received):
-            raise socket.error('no data read from socket')
+        payload_start = 4
+    elif payload_len == 127:
+        if len(data) > 12:
+            payload_len = struct.unpack('!Q', data[4:12])
 
 
-        data += received
+        payload_start = 12
 
 
-    return data
+    return len(data) >= payload_len + payload_start
 
 
 
 
 def mask(key, original):
 def mask(key, original):
@@ -265,3 +330,8 @@ def mask(key, original):
         masked[i] ^= key[i % 4]
         masked[i] ^= key[i % 4]
 
 
     return masked
     return masked
+
+
+def create_close_frame(code, reason):
+    payload = '' if code is None else struct.pack('!H', code) + reason
+    return ControlFrame(OPCODE_CLOSE, payload)

+ 16 - 17
handshake.py

@@ -7,7 +7,7 @@ from hashlib import sha1
 from urlparse import urlparse
 from urlparse import urlparse
 
 
 from errors import HandshakeError
 from errors import HandshakeError
-from extension import filter_extensions
+from extension import extension_conflicts
 from python_digest import build_authorization_request
 from python_digest import build_authorization_request
 
 
 
 
@@ -15,7 +15,7 @@ WS_GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
 WS_VERSION = '13'
 WS_VERSION = '13'
 MAX_REDIRECTS = 10
 MAX_REDIRECTS = 10
 HDR_TIMEOUT = 5
 HDR_TIMEOUT = 5
-MAX_HDR_LEN = 512
+MAX_HDR_LEN = 1024
 
 
 
 
 class Handshake(object):
 class Handshake(object):
@@ -65,7 +65,11 @@ class Handshake(object):
 
 
             start_time = time.time()
             start_time = time.time()
 
 
-            while hdr[-4:] != '\r\n\r\n' and len(hdr) < MAX_HDR_LEN:
+            while hdr[-4:] != '\r\n\r\n':
+                if len(hdr) == MAX_HDR_LEN:
+                    raise HandshakeError('request exceeds maximum header '
+                                         'length of %d' % MAX_HDR_LEN)
+
                 hdr += self.sock.recv(1)
                 hdr += self.sock.recv(1)
 
 
                 time_diff = time.time() - start_time
                 time_diff = time.time() - start_time
@@ -169,23 +173,19 @@ class ServerHandshake(Handshake):
         # Only supported extensions are returned
         # Only supported extensions are returned
         if 'Sec-WebSocket-Extensions' in headers:
         if 'Sec-WebSocket-Extensions' in headers:
             supported_ext = dict((e.name, e) for e in ssock.extensions)
             supported_ext = dict((e.name, e) for e in ssock.extensions)
+            self.wsock.extension_hooks = []
             extensions = []
             extensions = []
-            all_params = []
 
 
             for ext in split_stripped(headers['Sec-WebSocket-Extensions']):
             for ext in split_stripped(headers['Sec-WebSocket-Extensions']):
                 name, params = parse_param_hdr(ext)
                 name, params = parse_param_hdr(ext)
 
 
                 if name in supported_ext:
                 if name in supported_ext:
-                    extensions.append(supported_ext[name])
-                    all_params.append(params)
-
-            self.wsock.extensions = filter_extensions(extensions)
+                    ext = supported_ext[name]
 
 
-            for ext, params in zip(self.wsock.extensions, all_params):
-                hook = ext.create_hook(**params)
-                self.wsock.add_hook(send=hook.send, recv=hook.recv)
-        else:
-            self.wsock.extensions = []
+                    if not extension_conflicts(ext, extensions):
+                        extensions.append(ext)
+                        hook = ext.create_hook(**params)
+                        self.wsock.extension_hooks.append(hook)
 
 
         # Check if requested resource location is served by this server
         # Check if requested resource location is served by this server
         if ssock.locations:
         if ssock.locations:
@@ -212,7 +212,7 @@ class ServerHandshake(Handshake):
         location = '%s://%s%s' % (scheme, host, self.wsock.location)
         location = '%s://%s%s' % (scheme, host, self.wsock.location)
 
 
         # Construct HTTP response header
         # Construct HTTP response header
-        yield 'HTTP/1.1 101 Web Socket Protocol Handshake'
+        yield 'HTTP/1.1 101 Switching Protocols'
         yield 'Upgrade', 'websocket'
         yield 'Upgrade', 'websocket'
         yield 'Connection', 'Upgrade'
         yield 'Connection', 'Upgrade'
         yield 'Sec-WebSocket-Origin', origin
         yield 'Sec-WebSocket-Origin', origin
@@ -274,7 +274,7 @@ class ClientHandshake(Handshake):
         # Compare extensions, add hooks only for those returned by server
         # Compare extensions, add hooks only for those returned by server
         if 'Sec-WebSocket-Extensions' in headers:
         if 'Sec-WebSocket-Extensions' in headers:
             supported_ext = dict((e.name, e) for e in self.wsock.extensions)
             supported_ext = dict((e.name, e) for e in self.wsock.extensions)
-            self.wsock.extensions = []
+            self.wsock.extension_hooks = []
 
 
             for ext in split_stripped(headers['Sec-WebSocket-Extensions']):
             for ext in split_stripped(headers['Sec-WebSocket-Extensions']):
                 name, params = parse_param_hdr(ext)
                 name, params = parse_param_hdr(ext)
@@ -284,8 +284,7 @@ class ClientHandshake(Handshake):
                                          'unsupported extension "%s"' % name)
                                          'unsupported extension "%s"' % name)
 
 
                 hook = supported_ext[name].create_hook(**params)
                 hook = supported_ext[name].create_hook(**params)
-                self.wsock.extensions.append(supported_ext[name])
-                self.wsock.add_hook(send=hook.send, recv=hook.recv)
+                self.wsock.extension_hooks.append(hook)
 
 
         # Assert that returned protocol (if any) is supported
         # Assert that returned protocol (if any) is supported
         if 'Sec-WebSocket-Protocol' in headers:
         if 'Sec-WebSocket-Protocol' in headers:

+ 24 - 19
server.py

@@ -32,7 +32,7 @@ class Server(object):
     """
     """
 
 
     def __init__(self, address, loglevel=logging.INFO, ssl_args=None,
     def __init__(self, address, loglevel=logging.INFO, ssl_args=None,
-                 max_join_time=2.0, **kwargs):
+                 max_join_time=2.0, backlog_size=32, **kwargs):
         """
         """
         Constructor for a simple web socket server.
         Constructor for a simple web socket server.
 
 
@@ -53,6 +53,8 @@ class Server(object):
 
 
         `max_join_time` is the maximum time (in seconds) to wait for client
         `max_join_time` is the maximum time (in seconds) to wait for client
         responses after sending CLOSE frames, it defaults to 2 seconds.
         responses after sending CLOSE frames, it defaults to 2 seconds.
+
+        `backlog_size` is directly passed to `websocket.listen`.
         """
         """
         logging.basicConfig(level=loglevel,
         logging.basicConfig(level=loglevel,
                 format='%(asctime)s: %(levelname)s: %(message)s',
                 format='%(asctime)s: %(levelname)s: %(message)s',
@@ -69,14 +71,14 @@ class Server(object):
             self.sock.enable_ssl(server_side=True, **ssl_args)
             self.sock.enable_ssl(server_side=True, **ssl_args)
 
 
         self.sock.bind(address)
         self.sock.bind(address)
-        self.sock.listen(5)
-
-        self.clients = []
-        self.client_threads = []
+        self.sock.listen(backlog_size)
 
 
         self.max_join_time = max_join_time
         self.max_join_time = max_join_time
 
 
     def run(self):
     def run(self):
+        self.clients = []
+        self.client_threads = []
+
         while True:
         while True:
             try:
             try:
                 sock, address = self.sock.accept()
                 sock, address = self.sock.accept()
@@ -134,30 +136,22 @@ class Server(object):
         self.onclose(client, code, reason)
         self.onclose(client, code, reason)
 
 
     def onopen(self, client):
     def onopen(self, client):
-        logging.debug('Opened socket to %s', client)
+        return NotImplemented
 
 
     def onmessage(self, client, message):
     def onmessage(self, client, message):
-        logging.debug('Received %s from %s', message, client)
+        return NotImplemented
 
 
     def onping(self, client, payload):
     def onping(self, client, payload):
-        logging.debug('Sent ping "%s" to %s', payload, client)
+        return NotImplemented
 
 
     def onpong(self, client, payload):
     def onpong(self, client, payload):
-        logging.debug('Received pong "%s" from %s', payload, client)
+        return NotImplemented
 
 
     def onclose(self, client, code, reason):
     def onclose(self, client, code, reason):
-        msg = 'Closed socket to %s' % client
-
-        if code is not None:
-            msg += ' [%d]' % code
-
-        if len(reason):
-            msg += ': ' + reason
-
-        logging.debug(msg)
+        return NotImplemented
 
 
     def onerror(self, client, e):
     def onerror(self, client, e):
-        logging.error(format_exc(e))
+        return NotImplemented
 
 
 
 
 class Client(Connection):
 class Client(Connection):
@@ -176,21 +170,32 @@ class Client(Connection):
         Connection.send(self, message, fragment_size=fragment_size, mask=mask)
         Connection.send(self, message, fragment_size=fragment_size, mask=mask)
 
 
     def onopen(self):
     def onopen(self):
+        logging.debug('Opened socket to %s', self)
         self.server.onopen(self)
         self.server.onopen(self)
 
 
     def onmessage(self, message):
     def onmessage(self, message):
+        logging.debug('Received %s from %s', message, self)
         self.server.onmessage(self, message)
         self.server.onmessage(self, message)
 
 
     def onping(self, payload):
     def onping(self, payload):
+        logging.debug('Sent ping "%s" to %s', payload, self)
         self.server.onping(self, payload)
         self.server.onping(self, payload)
 
 
     def onpong(self, payload):
     def onpong(self, payload):
+        logging.debug('Received pong "%s" from %s', payload, self)
         self.server.onpong(self, payload)
         self.server.onpong(self, payload)
 
 
     def onclose(self, code, reason):
     def onclose(self, code, reason):
+        msg = 'Closed socket to %s' % self
+
+        if code is not None:
+            msg += ': [%d] %s' % (code, reason)
+
+        logging.debug(msg)
         self.server.remove_client(self, code, reason)
         self.server.remove_client(self, code, reason)
 
 
     def onerror(self, e):
     def onerror(self, e):
+        logging.error(format_exc(e))
         self.server.onerror(self, e)
         self.server.onerror(self, e)
 
 
 
 

+ 10 - 2
test/client.py

@@ -1,5 +1,6 @@
 #!/usr/bin/env python
 #!/usr/bin/env python
 import sys
 import sys
+import ssl
 from os.path import abspath, dirname
 from os.path import abspath, dirname
 
 
 basepath = abspath(dirname(abspath(__file__)) + '/..')
 basepath = abspath(dirname(abspath(__file__)) + '/..')
@@ -20,7 +21,7 @@ class EchoClient(Connection):
 
 
     def onmessage(self, msg):
     def onmessage(self, msg):
         print 'Received', msg
         print 'Received', msg
-        raise SocketClosed(None, 'response received')
+        self.close(None, 'response received')
 
 
     def onerror(self, e):
     def onerror(self, e):
         print 'Error:', e
         print 'Error:', e
@@ -29,8 +30,15 @@ class EchoClient(Connection):
         print 'Connection closed'
         print 'Connection closed'
 
 
 
 
+secure = True
+
 if __name__ == '__main__':
 if __name__ == '__main__':
-    print 'Connecting to ws://%s:%d' % ADDR
+    scheme = 'wss' if secure else 'ws'
+    print 'Connecting to %s://%s' % (scheme, '%s:%d' % ADDR)
     sock = websocket()
     sock = websocket()
+
+    if secure:
+        sock.enable_ssl(ca_certs='cert.pem', cert_reqs=ssl.CERT_REQUIRED)
+
     sock.connect(ADDR)
     sock.connect(ADDR)
     EchoClient(sock).receive_forever()
     EchoClient(sock).receive_forever()

+ 45 - 0
test/talk.py

@@ -0,0 +1,45 @@
+#!/usr/bin/env python
+import sys
+import socket
+from os.path import abspath, dirname
+
+basepath = abspath(dirname(abspath(__file__)) + '/..')
+sys.path.insert(0, basepath)
+
+from websocket import websocket
+from connection import Connection
+from message import TextMessage
+from errors import SocketClosed
+
+
+if __name__ == '__main__':
+    if len(sys.argv) < 3:
+        print >> sys.stderr, 'Usage: python %s HOST PORT' % sys.argv[0]
+        sys.exit(1)
+
+    host = sys.argv[1]
+    port = int(sys.argv[2])
+
+    sock = websocket()
+    sock.connect((host, port))
+    sock.settimeout(1.0)
+    conn = Connection(sock)
+
+    try:
+        try:
+            while True:
+                msg = TextMessage(raw_input())
+                print 'send:', msg
+                conn.send(msg)
+
+                try:
+                    print 'recv:', conn.recv()
+                except socket.timeout:
+                    print 'no response'
+        except EOFError:
+            conn.close()
+    except SocketClosed as e:
+        if e.initialized:
+            print 'closed connection'
+        else:
+            print 'other side closed connection'

+ 101 - 55
websocket.py

@@ -1,7 +1,7 @@
 import socket
 import socket
 import ssl
 import ssl
 
 
-from frame import receive_frame
+from frame import receive_frame, pop_frame, contains_frame
 from handshake import ServerHandshake, ClientHandshake
 from handshake import ServerHandshake, ClientHandshake
 from errors import SSLError
 from errors import SSLError
 
 
@@ -11,7 +11,6 @@ INHERITED_ATTRS = ['bind', 'close', 'listen', 'fileno', 'getpeername',
                    'settimeout', 'gettimeout', 'shutdown', 'family', 'type',
                    'settimeout', 'gettimeout', 'shutdown', 'family', 'type',
                    'proto']
                    'proto']
 
 
-
 class websocket(object):
 class websocket(object):
     """
     """
     Implementation of web socket, upgrades a regular TCP socket to a websocket
     Implementation of web socket, upgrades a regular TCP socket to a websocket
@@ -36,22 +35,23 @@ class websocket(object):
     >>> sock.connect(('', 8000))
     >>> sock.connect(('', 8000))
     >>> sock.send(wspy.Frame(wspy.OPCODE_TEXT, 'Hello, Server!'))
     >>> sock.send(wspy.Frame(wspy.OPCODE_TEXT, 'Hello, Server!'))
     """
     """
-    def __init__(self, sock=None, protocols=[], extensions=[], origin=None,
+    def __init__(self, sock=None, origin=None, protocols=[], extensions=[],
                  location='/', trusted_origins=[], locations=[], auth=None,
                  location='/', trusted_origins=[], locations=[], auth=None,
-                 sfamily=socket.AF_INET, sproto=0):
+                 recv_callback=None, sfamily=socket.AF_INET, sproto=0):
         """
         """
         Create a regular TCP socket of family `family` and protocol
         Create a regular TCP socket of family `family` and protocol
 
 
         `sock` is an optional regular TCP socket to be used for sending binary
         `sock` is an optional regular TCP socket to be used for sending binary
         data. If not specified, a new socket is created.
         data. If not specified, a new socket is created.
 
 
-        `protocols` is a list of supported protocol names.
-
-        `extensions` is a list of supported extensions (`Extension` instances).
-
         `origin` (for client sockets) is the value for the "Origin" header sent
         `origin` (for client sockets) is the value for the "Origin" header sent
         in a client handshake .
         in a client handshake .
 
 
+        `protocols` is a list of supported protocol names.
+
+        `extensions` (for server sockets) is a list of supported extensions
+        (`Extension` instances).
+
         `location` (for client sockets) is optional, used to request a
         `location` (for client sockets) is optional, used to request a
         particular resource in the HTTP handshake. In a URL, this would show as
         particular resource in the HTTP handshake. In a URL, this would show as
         ws://host[:port]/<location>. Use this when the server serves multiple
         ws://host[:port]/<location>. Use this when the server serves multiple
@@ -71,10 +71,17 @@ class websocket(object):
         `auth` is optional, used for HTTP Basic or Digest authentication during
         `auth` is optional, used for HTTP Basic or Digest authentication during
         the handshake. It must be specified as a (username, password) tuple.
         the handshake. It must be specified as a (username, password) tuple.
 
 
+        `recv_callback` is the callback for received frames in asynchronous
+        sockets. Use in conjunction with setblocking(0). The callback itself
+        may for example change the recv_callback attribute to change the
+        behaviour for the next received message. Can be set when calling
+        `queue_send`.
+
         `sfamily` and `sproto` are used for the regular socket constructor.
         `sfamily` and `sproto` are used for the regular socket constructor.
         """
         """
         self.protocols = protocols
         self.protocols = protocols
         self.extensions = extensions
         self.extensions = extensions
+        self.extension_hooks = []
         self.origin = origin
         self.origin = origin
         self.location = location
         self.location = location
         self.trusted_origins = trusted_origins
         self.trusted_origins = trusted_origins
@@ -85,11 +92,16 @@ class websocket(object):
 
 
         self.handshake_sent = False
         self.handshake_sent = False
 
 
-        self.hooks_send = []
-        self.hooks_recv = []
+        self.sendbuf_frames = []
+        self.sendbuf = ''
+        self.recvbuf = ''
+        self.recv_callback = recv_callback
 
 
         self.sock = sock or socket.socket(sfamily, socket.SOCK_STREAM, sproto)
         self.sock = sock or socket.socket(sfamily, socket.SOCK_STREAM, sproto)
 
 
+    def set_extensions(self, extensions):
+        self.extensions = [ext.Hook() for ext in extensions]
+
     def __getattr__(self, name):
     def __getattr__(self, name):
         if name in INHERITED_ATTRS:
         if name in INHERITED_ATTRS:
             return getattr(self.sock, name)
             return getattr(self.sock, name)
@@ -122,29 +134,31 @@ class websocket(object):
         ClientHandshake(self).perform()
         ClientHandshake(self).perform()
         self.handshake_sent = True
         self.handshake_sent = True
 
 
+    def apply_send_hooks(self, frame):
+        for hook in self.extension_hooks:
+            frame = hook.send(frame)
+
+        return frame
+
+    def apply_recv_hooks(self, frame):
+        for hook in reversed(self.extension_hooks):
+            frame = hook.recv(frame)
+
+        return frame
+
     def send(self, *args):
     def send(self, *args):
         """
         """
         Send a number of frames.
         Send a number of frames.
         """
         """
         for frame in args:
         for frame in args:
-            for hook in self.hooks_send:
-                frame = hook(frame)
-
-            #print 'send frame:', frame, 'to %s:%d' % self.sock.getpeername()
-            self.sock.sendall(frame.pack())
+            self.sock.sendall(self.apply_send_hooks(frame).pack())
 
 
     def recv(self):
     def recv(self):
         """
         """
         Receive a single frames. This can be either a data frame or a control
         Receive a single frames. This can be either a data frame or a control
         frame.
         frame.
         """
         """
-        frame = receive_frame(self.sock)
-
-        for hook in self.hooks_recv:
-            frame = hook(frame)
-
-        #print 'receive frame:', frame, 'from %s:%d' % self.sock.getpeername()
-        return frame
+        return self.apply_recv_hooks(receive_frame(self.sock))
 
 
     def recvn(self, n):
     def recvn(self, n):
         """
         """
@@ -153,47 +167,79 @@ class websocket(object):
         """
         """
         return [self.recv() for i in xrange(n)]
         return [self.recv() for i in xrange(n)]
 
 
-    def enable_ssl(self, *args, **kwargs):
+    def queue_send(self, frame, callback=None, recv_callback=None):
         """
         """
-        Transforms the regular socket.socket to an ssl.SSLSocket for secure
-        connections. Any arguments are passed to ssl.wrap_socket:
-        http://docs.python.org/dev/library/ssl.html#ssl.wrap_socket
+        Enqueue `frame` to the send buffer so that it is send on the next
+        `do_async_send`. `callback` is an optional callable to call when the
+        frame has been fully written. `recv_callback` is an optional callable
+        to quickly set the `recv_callback` attribute to.
         """
         """
-        if self.handshake_sent:
-            raise SSLError('can only enable SSL before handshake')
+        frame = self.apply_send_hooks(frame)
+        self.sendbuf += frame.pack()
+        self.sendbuf_frames.append([frame, len(self.sendbuf), callback])
 
 
-        self.secure = True
-        self.sock = ssl.wrap_socket(self.sock, *args, **kwargs)
+        if recv_callback:
+            self.recv_callback = recv_callback
+
+    def do_async_send(self):
+        """
+        Send any queued data. This function should only be called after a write
+        event on a file descriptor.
+        """
+        assert len(self.sendbuf)
+
+        nwritten = self.sock.send(self.sendbuf)
+        nframes = 0
 
 
-    def add_hook(self, send=None, recv=None, prepend=False):
+        for entry in self.sendbuf_frames:
+            frame, offset, callback = entry
+
+            if offset <= nwritten:
+                nframes += 1
+
+                if callback:
+                    callback()
+            else:
+                entry[1] -= nwritten
+
+        self.sendbuf = self.sendbuf[nwritten:]
+        self.sendbuf_frames = self.sendbuf_frames[nframes:]
+
+    def do_async_recv(self, bufsize):
+        """
+        Receive any completed frames from the socket. This function should only
+        be called after a read event on a file descriptor.
         """
         """
-        Add a pair of send and receive hooks that are called for each frame
-        that is sent or received. A hook is a function that receives a single
-        argument - a Frame instance - and returns a `Frame` instance as well.
+        data = self.sock.recv(bufsize)
+
+        if len(data) == 0:
+            raise socket.error('no data to receive')
 
 
-        `prepend` is a flag indicating whether the send hook is prepended to
-        the other send hooks. This is expecially useful when a program uses
-        extensions such as the built-in `DeflateFrame` extension. These
-        extensions are installed using these hooks as well.
+        self.recvbuf += data
 
 
-        For example, the following code creates a `Frame` instance for data
-        being sent and removes the instance for received data. This way, data
-        can be sent and received as if on a regular socket.
-        >>> import wspy
-        >>> sock.add_hook(lambda data: tswpy.Frame(tswpy.OPCODE_TEXT, data),
-        >>>               lambda frame: frame.payload)
+        while contains_frame(self.recvbuf):
+            frame, self.recvbuf = pop_frame(self.recvbuf)
+            frame = self.apply_recv_hooks(frame)
 
 
-        To add base64 encoding to the example above:
-        >>> import base64
-        >>> sock.add_hook(base64.encodestring, base64.decodestring, True)
+            if not self.recv_callback:
+                raise ValueError('no callback installed for %s' % frame)
 
 
-        Note that here `prepend=True`, so that data passed to `send()` is first
-        encoded and then packed into a frame. Of course, one could also decide
-        to add the base64 hook first, or to return a new `Frame` instance with
-        base64-encoded data.
+            self.recv_callback(frame)
+
+    def can_send(self):
+        return len(self.sendbuf) > 0
+
+    def can_recv(self):
+        return self.recv_callback is not None
+
+    def enable_ssl(self, *args, **kwargs):
+        """
+        Transforms the regular socket.socket to an ssl.SSLSocket for secure
+        connections. Any arguments are passed to ssl.wrap_socket:
+        http://docs.python.org/dev/library/ssl.html#ssl.wrap_socket
         """
         """
-        if send:
-            self.hooks_send.insert(0 if prepend else -1, send)
+        if self.handshake_sent:
+            raise SSLError('can only enable SSL before handshake')
 
 
-        if recv:
-            self.hooks_recv.insert(-1 if prepend else 0, recv)
+        self.secure = True
+        self.sock = ssl.wrap_socket(self.sock, *args, **kwargs)