Sfoglia il codice sorgente

Merge branch 'async'

Taddeus Kroes 11 anni fa
parent
commit
6c79550e13
13 ha cambiato i file con 592 aggiunte e 224 eliminazioni
  1. 1 0
      README.md
  2. 3 2
      __init__.py
  3. 190 0
      async.py
  4. 52 47
      connection.py
  5. 27 32
      deflate_frame.py
  6. 8 4
      errors.py
  7. 23 24
      extension.py
  8. 92 22
      frame.py
  9. 16 17
      handshake.py
  10. 24 19
      server.py
  11. 10 2
      test/client.py
  12. 45 0
      test/talk.py
  13. 101 55
      websocket.py

+ 1 - 0
README.md

@@ -22,6 +22,7 @@ Her is a quick overview of the features in this library:
 - Secure sockets using SSL certificates (for 'wss://...' URLs).
 - The possibility to add extensions to the web socket protocol. An included
   implementation is [deflate-frame](http://tools.ietf.org/html/draft-tyoshino-hybi-websocket-perframe-deflate-06).
+- Asynchronous sockets with an EPOLL-based server.
 
 
 Installation

+ 3 - 2
__init__.py

@@ -4,10 +4,11 @@ from frame import Frame, ControlFrame, OPCODE_CONTINUATION, OPCODE_TEXT, \
         OPCODE_BINARY, OPCODE_CLOSE, OPCODE_PING, OPCODE_PONG, CLOSE_NORMAL, \
         CLOSE_GOING_AWAY, CLOSE_PROTOCOL_ERROR, CLOSE_NOACCEPT_DTYPE, \
         CLOSE_INVALID_DATA, CLOSE_POLICY, CLOSE_MESSAGE_TOOBIG, \
-        CLOSE_MISSING_EXTENSIONS, CLOSE_UNABLE
+        CLOSE_MISSING_EXTENSIONS, CLOSE_UNABLE, read_frame, pop_frame, \
+        contains_frame
 from connection import Connection
 from message import Message, TextMessage, BinaryMessage
 from errors import SocketClosed, HandshakeError, PingError, SSLError
 from extension import Extension
 from deflate_frame import DeflateFrame, WebkitDeflateFrame
-#from multiplex import Multiplex
+from async import AsyncConnection, AsyncServer

+ 190 - 0
async.py

@@ -0,0 +1,190 @@
+import socket
+from select import epoll, EPOLLIN, EPOLLOUT, EPOLLHUP
+from traceback import format_exc
+import logging
+
+from connection import Connection
+from frame import ControlFrame, OPCODE_PING, OPCODE_CONTINUATION, \
+                  create_close_frame
+from server import Server, Client
+from errors import HandshakeError, SocketClosed
+
+
+class AsyncConnection(Connection):
+    def __init__(self, sock):
+        sock.recv_callback = self.contruct_message
+        sock.recv_close_callback = self.onclose
+        self.recvbuf = []
+        Connection.__init__(self, sock)
+
+    def contruct_message(self, frame):
+        if isinstance(frame, ControlFrame):
+            self.handle_control_frame(frame)
+            return
+
+        self.recvbuf.append(frame)
+
+        if frame.final:
+            message = self.concat_fragments(self.recvbuf)
+            self.recvbuf = []
+            self.onmessage(message)
+        elif len(self.recvbuf) > 1 and frame.opcode != OPCODE_CONTINUATION:
+            raise ValueError('expected continuation/control frame, got %s '
+                             'instead' % frame)
+
+    def send(self, message, fragment_size=None, mask=False):
+        frames = list(self.message_to_frames(message, fragment_size, mask))
+
+        for frame in frames[:-1]:
+            self.sock.queue_send(frame)
+
+        self.sock.queue_send(frames[-1], lambda: self.onsend(message))
+
+    def send_frame(self, frame, callback):
+        self.sock.queue_send(frame, callback)
+
+    def do_async_send(self):
+        self.execute_controlled(self.sock.do_async_send)
+
+    def do_async_recv(self, bufsize):
+        self.execute_controlled(self.sock.do_async_recv, bufsize)
+
+    def execute_controlled(self, func, *args, **kwargs):
+        try:
+            func(*args, **kwargs)
+        except (KeyboardInterrupt, SystemExit, SocketClosed):
+            raise
+        except Exception as e:
+            self.onerror(e)
+            self.onclose(None, 'error: %s' % e)
+
+            try:
+                self.sock.close()
+            except socket.error:
+                pass
+
+            raise e
+
+    def send_close_frame(self, code, reason):
+        self.sock.queue_send(create_close_frame(code, reason),
+                             self.shutdown_write)
+        self.close_frame_sent = True
+
+    def close(self, code=None, reason=''):
+        self.send_close_frame(code, reason)
+
+    def send_ping(self, payload=''):
+        self.sock.queue_send(ControlFrame(OPCODE_PING, payload),
+                             lambda: self.onping(payload))
+        self.ping_payload = payload
+        self.ping_sent = True
+
+    def onsend(self, message):
+        """
+        Called after a message has been written.
+        """
+        return NotImplemented
+
+
+class AsyncServer(Server):
+    def __init__(self, *args, **kwargs):
+        Server.__init__(self, *args, **kwargs)
+
+        self.recvbuf_size = kwargs.get('recvbuf_size', 2048)
+
+        self.epoll = epoll()
+        self.epoll.register(self.sock.fileno(), EPOLLIN)
+        self.conns = {}
+
+    @property
+    def clients(self):
+        return self.conns.values()
+
+    def remove_client(self, client, code, reason):
+        self.epoll.unregister(client.fno)
+        del self.conns[client.fno]
+        self.onclose(client, code, reason)
+
+    def handle_events(self):
+        for fileno, event in self.epoll.poll(1):
+            if fileno == self.sock.fileno():
+                try:
+                    sock, addr = self.sock.accept()
+                except HandshakeError as e:
+                    logging.error('Invalid request: %s', e.message)
+                    continue
+
+                client = AsyncClient(self, sock)
+                client.fno = sock.fileno()
+                sock.setblocking(0)
+                self.epoll.register(client.fno, EPOLLIN)
+                self.conns[client.fno] = client
+                logging.debug('Registered client %s', client)
+
+            elif event & EPOLLHUP:
+                self.epoll.unregister(fileno)
+                del self.conns[fileno]
+
+            else:
+                conn = self.conns[fileno]
+
+                try:
+                    if event & EPOLLOUT:
+                        conn.do_async_send()
+                    elif event & EPOLLIN:
+                        conn.do_async_recv(self.recvbuf_size)
+                except (KeyboardInterrupt, SystemExit):
+                    raise
+                except SocketClosed:
+                    continue
+                except Exception as e:
+                    logging.error(format_exc(e).rstrip())
+                    continue
+
+                self.update_mask(conn)
+
+    def run(self):
+        try:
+            while True:
+                self.handle_events()
+        except (KeyboardInterrupt, SystemExit):
+            logging.info('Received interrupt, stopping server...')
+        finally:
+            self.epoll.unregister(self.sock.fileno())
+            self.epoll.close()
+            self.sock.close()
+
+    def update_mask(self, conn):
+        mask = 0
+
+        if conn.sock.can_send():
+            mask |= EPOLLOUT
+
+        if conn.sock.can_recv():
+            mask |= EPOLLIN
+
+        self.epoll.modify(conn.sock.fileno(), mask)
+
+    def onsend(self, client, message):
+        return NotImplemented
+
+
+class AsyncClient(Client, AsyncConnection):
+    def __init__(self, server, sock):
+        self.server = server
+        AsyncConnection.__init__(self, sock)
+
+    def send(self, message, fragment_size=None, mask=False):
+        logging.debug('Enqueueing %s to %s', message, self)
+        AsyncConnection.send(self, message, fragment_size, mask)
+        self.server.update_mask(self)
+
+    def onsend(self, message):
+        logging.debug('Finished sending %s to %s', message, self)
+        self.server.onsend(self, message)
+
+
+if __name__ == '__main__':
+    import sys
+    port = int(sys.argv[1]) if len(sys.argv) > 1 else 8000
+    AsyncServer(('', port), loglevel=logging.DEBUG).run()

+ 52 - 47
connection.py

@@ -1,8 +1,7 @@
-import struct
 import socket
 
 from frame import ControlFrame, OPCODE_CLOSE, OPCODE_PING, OPCODE_PONG, \
-                  OPCODE_CONTINUATION
+                  OPCODE_CONTINUATION, create_close_frame
 from message import create_message
 from errors import SocketClosed, PingError
 
@@ -54,19 +53,30 @@ class Connection(object):
 
         self.onopen()
 
+    def message_to_frames(self, message, fragment_size=None, mask=False):
+        for hook in self.hooks_send:
+            message = hook(message)
+
+        if fragment_size is None:
+            yield message.frame(mask=mask)
+        else:
+            for frame in message.fragment(fragment_size, mask=mask):
+                yield frame
+
     def send(self, message, fragment_size=None, mask=False):
         """
         Send a message. If `fragment_size` is specified, the message is
         fragmented into multiple frames whose payload size does not extend
         `fragment_size`.
         """
-        for hook in self.hooks_send:
-            message = hook(message)
+        for frame in self.message_to_frames(message, fragment_size, mask):
+            self.send_frame(frame)
 
-        if fragment_size is None:
-            self.sock.send(message.frame(mask=mask))
-        else:
-            self.sock.send(*message.fragment(fragment_size, mask=mask))
+    def send_frame(self, frame, callback=None):
+        self.sock.send(frame)
+
+        if callback:
+            callback()
 
     def recv(self):
         """
@@ -82,12 +92,15 @@ class Connection(object):
 
             if isinstance(frame, ControlFrame):
                 self.handle_control_frame(frame)
-            elif len(fragments) and frame.opcode != OPCODE_CONTINUATION:
+            elif len(fragments) > 0 and frame.opcode != OPCODE_CONTINUATION:
                 raise ValueError('expected continuation/control frame, got %s '
                                  'instead' % frame)
             else:
                 fragments.append(frame)
 
+        return self.concat_fragments(fragments)
+
+    def concat_fragments(self, fragments):
         payload = bytearray()
 
         for f in fragments:
@@ -105,16 +118,20 @@ class Connection(object):
         Handle a control frame as defined by RFC 6455.
         """
         if frame.opcode == OPCODE_CLOSE:
-            # Close the connection from this end as well
             self.close_frame_received = True
             code, reason = frame.unpack_close()
 
-            # No more receiving data after a close message
-            raise SocketClosed(code, reason)
+            if self.close_frame_sent:
+                self.onclose(code, reason)
+                self.sock.close()
+                raise SocketClosed(True)
+            else:
+                self.close_params = (code, reason)
+                self.send_close_frame(code, reason)
 
         elif frame.opcode == OPCODE_PING:
             # Respond with a pong message with identical payload
-            self.sock.send(ControlFrame(OPCODE_PONG, frame.payload))
+            self.send_frame(ControlFrame(OPCODE_PONG, frame.payload))
 
         elif frame.opcode == OPCODE_PONG:
             # Assert that the PONG payload is identical to that of the PING
@@ -138,38 +155,40 @@ class Connection(object):
         while True:
             try:
                 self.onmessage(self.recv())
-            except SocketClosed as e:
-                self.close(e.code, e.reason)
+            except (KeyboardInterrupt, SystemExit, SocketClosed):
                 break
-            except socket.error as e:
+            except Exception as e:
                 self.onerror(e)
+                self.onclose(None, 'error: %s' % e)
 
                 try:
                     self.sock.close()
                 except socket.error:
                     pass
 
-                self.onclose(None, '')
-                break
-            except Exception as e:
-                self.onerror(e)
+                raise e
 
     def send_ping(self, payload=''):
         """
         Send a PING control frame with an optional payload.
         """
-        self.sock.send(ControlFrame(OPCODE_PING, payload))
+        self.send_frame(ControlFrame(OPCODE_PING, payload),
+                        lambda: self.onping(payload))
         self.ping_payload = payload
         self.ping_sent = True
-        self.onping(payload)
 
-    def send_close_frame(self, code=None, reason=''):
-        """
-        Send a CLOSE control frame.
-        """
-        payload = '' if code is None else struct.pack('!H', code) + reason
-        self.sock.send(ControlFrame(OPCODE_CLOSE, payload))
+    def send_close_frame(self, code, reason):
+        self.send_frame(create_close_frame(code, reason))
         self.close_frame_sent = True
+        self.shutdown_write()
+
+    def shutdown_write(self):
+        if self.close_frame_received:
+            self.onclose(*self.close_params)
+            self.sock.close()
+            raise SocketClosed(False)
+        else:
+            self.sock.shutdown(socket.SHUT_WR)
 
     def close(self, code=None, reason=''):
         """
@@ -179,28 +198,14 @@ class Connection(object):
         called after the response has been received, but before the socket is
         actually closed.
         """
-        # Send CLOSE frame
-        if not self.close_frame_sent:
-            self.send_close_frame(code, reason)
-
-        # Receive CLOSE frame
-        if not self.close_frame_received:
-            frame = self.sock.recv()
-
-            if frame.opcode != OPCODE_CLOSE:
-                raise ValueError('expected CLOSE frame, got %s' % frame)
-
-            self.close_frame_received = True
-            res_code, res_reason = frame.unpack_close()
+        self.send_close_frame(code, reason)
 
-            # FIXME: check if res_code == code and res_reason == reason?
+        frame = self.sock.recv()
 
-            # FIXME: alternatively, keep receiving frames in a loop until a
-            # CLOSE frame is received, so that a fragmented chain may arrive
-            # fully first
+        if frame.opcode != OPCODE_CLOSE:
+            raise ValueError('expected CLOSE frame, got %s' % frame)
 
-        self.onclose(code, reason)
-        self.sock.close()
+        self.handle_control_frame(frame)
 
     def add_hook(self, send=None, recv=None, prepend=False):
         """

+ 27 - 32
deflate_frame.py

@@ -20,38 +20,33 @@ class DeflateFrame(Extension):
 
     name = 'deflate-frame'
     rsv1 = True
-    defaults = {'max_window_bits': 15, 'no_context_takeover': False}
+    defaults = {'max_window_bits': zlib.MAX_WBITS, 'no_context_takeover': False}
 
-    def __init__(self, defaults={}, request={}):
-        Extension.__init__(self, defaults, request)
+    COMPRESSION_THRESHOLD = 64  # minimal payload size for compression
 
+    def init(self):
         mwb = self.defaults['max_window_bits']
         cto = self.defaults['no_context_takeover']
 
-        if not isinstance(mwb, int):
-            raise ValueError('"max_window_bits" must be an integer')
-        elif mwb > 15:
-            raise ValueError('"max_window_bits" may not be larger than 15')
+        if not isinstance(mwb, int) or mwb < 1 or mwb > zlib.MAX_WBITS:
+            raise ValueError('"max_window_bits" must be in range 1-15')
 
         if cto is not False and cto is not True:
             raise ValueError('"no_context_takeover" must have no value')
 
     class Hook(Extension.Hook):
-        def __init__(self, extension, **kwargs):
-            Extension.Hook.__init__(self, extension, **kwargs)
-
-            if not self.no_context_takeover:
-                self.defl = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
-                                            zlib.DEFLATED,
-                                            -self.max_window_bits)
-
-            other_wbits = self.extension.request.get('max_window_bits', 15)
+        def init(self, extension):
+            self.defl = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
+                                         zlib.DEFLATED, -self.max_window_bits)
+            other_wbits = extension.request.get('max_window_bits', zlib.MAX_WBITS)
             self.dec = zlib.decompressobj(-other_wbits)
 
         def send(self, frame):
-            if not frame.rsv1 and not isinstance(frame, ControlFrame):
+            # FIXME: this does not seem to work properly on Android
+            if not frame.rsv1 and not isinstance(frame, ControlFrame) and \
+                   len(frame.payload) > DeflateFrame.COMPRESSION_THRESHOLD:
                 frame.rsv1 = True
-                frame.payload = self.deflate(frame.payload)
+                frame.payload = self.deflate(frame)
 
             return frame
 
@@ -65,23 +60,23 @@ class DeflateFrame(Extension):
 
             return frame
 
-        def deflate(self, data):
-            if self.no_context_takeover:
-                defl = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
-                                        zlib.DEFLATED, -self.max_window_bits)
-                # FIXME: why the '\x00' below? This was borrowed from
-                # https://github.com/fancycode/tornado/blob/bc317b6dcf63608ff004ff1f57073be0504b6550/tornado/websocket.py#L91
-                return defl.compress(data) + defl.flush(zlib.Z_FINISH) + '\x00'
+        def deflate(self, frame):
+            compressed = self.defl.compress(frame.payload)
+
+            if frame.final or self.no_context_takeover:
+                compressed += self.defl.flush(zlib.Z_FINISH) + '\x00'
+                self.defl = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
+                        zlib.DEFLATED, -self.max_window_bits)
+            else:
+                compressed += self.defl.flush(zlib.Z_SYNC_FLUSH)
+                assert compressed[-4:] == '\x00\x00\xff\xff'
+                compressed = compressed[:-4]
 
-            compressed = self.defl.compress(data)
-            compressed += self.defl.flush(zlib.Z_SYNC_FLUSH)
-            assert compressed[-4:] == '\x00\x00\xff\xff'
-            return compressed[:-4]
+            return compressed
 
         def inflate(self, data):
-            data = self.dec.decompress(str(data + '\x00\x00\xff\xff'))
-            assert not self.dec.unused_data
-            return data
+            return self.dec.decompress(data + '\x00\x00\xff\xff') + \
+                   self.dec.flush(zlib.Z_SYNC_FLUSH)
 
 
 class WebkitDeflateFrame(DeflateFrame):

+ 8 - 4
errors.py

@@ -1,11 +1,15 @@
 class SocketClosed(Exception):
-    def __init__(self, code=None, reason=''):
-        self.code = code
-        self.reason = reason
+    def __init__(self, initialized):
+        self.initialized = initialized
 
     @property
     def message(self):
-        return ('' if self.code is None else '[%d] ' % self.code) + self.reason
+        s = 'socket closed'
+
+        if self.initialized:
+            s += ' (initialized)'
+
+        return s
 
 
 class HandshakeError(Exception):

+ 23 - 24
extension.py

@@ -19,23 +19,31 @@ class Extension(object):
         self.request = dict(self.__class__.request)
         self.request.update(request)
 
+        self.init()
+
     def __str__(self):
         return '<Extension "%s" defaults=%s request=%s>' \
                % (self.name, self.defaults, self.request)
 
+    def init(self):
+        return NotImplemented
+
     def create_hook(self, **kwargs):
         params = {}
         params.update(self.defaults)
         params.update(kwargs)
-        return self.Hook(self, **params)
+        hook = self.Hook(**params)
+        hook.init(self)
+        return hook
 
     class Hook:
-        def __init__(self, extension, **kwargs):
-            self.extension = extension
-
+        def __init__(self, **kwargs):
             for param, value in kwargs.iteritems():
                 setattr(self, param, value)
 
+        def init(self, extension):
+            return NotImplemented
+
         def send(self, frame):
             return frame
 
@@ -43,28 +51,19 @@ class Extension(object):
             return frame
 
 
-def filter_extensions(extensions):
-    """
-    Remove extensions that use conflicting rsv bits and/or opcodes, with the
-    first options being the most preferable.
-    """
+def extension_conflicts(ext, existing):
     rsv1_reserved = False
     rsv2_reserved = False
     rsv3_reserved = False
-    opcodes_reserved = []
-    compat = []
-
-    for ext in extensions:
-        if ext.rsv1 and rsv1_reserved \
-                or ext.rsv2 and rsv2_reserved \
-                or ext.rsv3 and rsv3_reserved \
-                or len(set(ext.opcodes) & set(opcodes_reserved)):
-            continue
+    reserved_opcodes = []
 
-        rsv1_reserved |= ext.rsv1
-        rsv2_reserved |= ext.rsv2
-        rsv3_reserved |= ext.rsv3
-        opcodes_reserved.extend(ext.opcodes)
-        compat.append(ext)
+    for e in existing:
+        rsv1_reserved |= e.rsv1
+        rsv2_reserved |= e.rsv2
+        rsv3_reserved |= e.rsv3
+        reserved_opcodes.extend(e.opcodes)
 
-    return compat
+    return ext.rsv1 and rsv1_reserved \
+            or ext.rsv2 and rsv2_reserved \
+            or ext.rsv3 and rsv3_reserved \
+            or len(set(ext.opcodes) & set(reserved_opcodes))

+ 92 - 22
frame.py

@@ -21,9 +21,11 @@ CLOSE_MESSAGE_TOOBIG = 1009
 CLOSE_MISSING_EXTENSIONS = 1010
 CLOSE_UNABLE = 1011
 
+line_printable = [c for c in printable if c not in '\r\n\x0b\x0c']
+
 
 def printstr(s):
-    return ''.join(c if c in printable else '.' for c in s)
+    return ''.join(c if c in line_printable else '.' for c in str(s))
 
 
 class Frame(object):
@@ -154,7 +156,18 @@ class Frame(object):
         if len(self.payload) > max_pl_disp:
             pl += '...'
 
-        return s + ' payload=%s>' % pl
+        s += ' payload=%s' % pl
+
+        if self.rsv1:
+            s += ' rsv1'
+
+        if self.rsv2:
+            s += ' rsv2'
+
+        if self.rsv3:
+            s += ' rsv3'
+
+        return s + '>'
 
 
 class ControlFrame(Frame):
@@ -194,12 +207,8 @@ class ControlFrame(Frame):
         return code, reason
 
 
-def receive_frame(sock):
-    """
-    Receive a single frame on socket `sock`. The frame scheme is explained in
-    the docs of Frame.pack().
-    """
-    b1, b2 = struct.unpack('!BB', recvn(sock, 2))
+def decode_frame(reader):
+    b1, b2 = struct.unpack('!BB', reader.readn(2))
 
     final = bool(b1 & 0x80)
     rsv1 = bool(b1 & 0x40)
@@ -211,16 +220,16 @@ def receive_frame(sock):
     payload_len = b2 & 0x7F
 
     if payload_len == 126:
-        payload_len = struct.unpack('!H', recvn(sock, 2))
+        payload_len = struct.unpack('!H', reader.readn(2))
     elif payload_len == 127:
-        payload_len = struct.unpack('!Q', recvn(sock, 8))
+        payload_len = struct.unpack('!Q', reader.readn(8))
 
     if masked:
-        masking_key = recvn(sock, 4)
-        payload = mask(masking_key, recvn(sock, payload_len))
+        masking_key = reader.readn(4)
+        payload = mask(masking_key, reader.readn(payload_len))
     else:
         masking_key = ''
-        payload = recvn(sock, payload_len)
+        payload = reader.readn(payload_len)
 
     # Control frames have most significant bit 1
     cls = ControlFrame if opcode & 0x8 else Frame
@@ -229,21 +238,77 @@ def receive_frame(sock):
                rsv1=rsv1, rsv2=rsv2, rsv3=rsv3)
 
 
-def recvn(sock, n):
+def receive_frame(sock):
+    return decode_frame(SocketReader(sock))
+
+
+def read_frame(data):
+    reader = BufferReader(data)
+    frame = decode_frame(reader)
+    return frame, reader.offset
+
+
+def pop_frame(data):
+    frame, size = read_frame(data)
+    return frame, data[size:]
+
+
+class BufferReader(object):
+    def __init__(self, data):
+        self.data = data
+        self.offset = 0
+
+    def readn(self, n):
+        assert len(self.data) - self.offset >= n
+        self.offset += n
+        return self.data[self.offset - n:self.offset]
+
+
+class SocketReader(object):
+    def __init__(self, sock):
+        self.sock = sock
+
+    def readn(self, n):
+        """
+        Keep receiving data until exactly `n` bytes have been read.
+        """
+        data = ''
+
+        while len(data) < n:
+            received = self.sock.recv(n - len(data))
+
+            if not len(received):
+                raise socket.error('no data read from socket')
+
+            data += received
+
+        return data
+
+
+def contains_frame(data):
     """
-    Keep receiving data from `sock` until exactly `n` bytes have been read.
+    Read the frame length from the start of `data` and check if the data is
+    long enough to contain the entire frame.
     """
-    data = ''
+    if len(data) < 2:
+        return False
+
+    b2 = struct.unpack('!B', data[1])[0]
+    payload_len = b2 & 0x7F
+    payload_start = 2
 
-    while len(data) < n:
-        received = sock.recv(n - len(data))
+    if payload_len == 126:
+        if len(data) > 4:
+            payload_len = struct.unpack('!H', data[2:4])
 
-        if not len(received):
-            raise socket.error('no data read from socket')
+        payload_start = 4
+    elif payload_len == 127:
+        if len(data) > 12:
+            payload_len = struct.unpack('!Q', data[4:12])
 
-        data += received
+        payload_start = 12
 
-    return data
+    return len(data) >= payload_len + payload_start
 
 
 def mask(key, original):
@@ -265,3 +330,8 @@ def mask(key, original):
         masked[i] ^= key[i % 4]
 
     return masked
+
+
+def create_close_frame(code, reason):
+    payload = '' if code is None else struct.pack('!H', code) + reason
+    return ControlFrame(OPCODE_CLOSE, payload)

+ 16 - 17
handshake.py

@@ -7,7 +7,7 @@ from hashlib import sha1
 from urlparse import urlparse
 
 from errors import HandshakeError
-from extension import filter_extensions
+from extension import extension_conflicts
 from python_digest import build_authorization_request
 
 
@@ -15,7 +15,7 @@ WS_GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
 WS_VERSION = '13'
 MAX_REDIRECTS = 10
 HDR_TIMEOUT = 5
-MAX_HDR_LEN = 512
+MAX_HDR_LEN = 1024
 
 
 class Handshake(object):
@@ -65,7 +65,11 @@ class Handshake(object):
 
             start_time = time.time()
 
-            while hdr[-4:] != '\r\n\r\n' and len(hdr) < MAX_HDR_LEN:
+            while hdr[-4:] != '\r\n\r\n':
+                if len(hdr) == MAX_HDR_LEN:
+                    raise HandshakeError('request exceeds maximum header '
+                                         'length of %d' % MAX_HDR_LEN)
+
                 hdr += self.sock.recv(1)
 
                 time_diff = time.time() - start_time
@@ -169,23 +173,19 @@ class ServerHandshake(Handshake):
         # Only supported extensions are returned
         if 'Sec-WebSocket-Extensions' in headers:
             supported_ext = dict((e.name, e) for e in ssock.extensions)
+            self.wsock.extension_hooks = []
             extensions = []
-            all_params = []
 
             for ext in split_stripped(headers['Sec-WebSocket-Extensions']):
                 name, params = parse_param_hdr(ext)
 
                 if name in supported_ext:
-                    extensions.append(supported_ext[name])
-                    all_params.append(params)
-
-            self.wsock.extensions = filter_extensions(extensions)
+                    ext = supported_ext[name]
 
-            for ext, params in zip(self.wsock.extensions, all_params):
-                hook = ext.create_hook(**params)
-                self.wsock.add_hook(send=hook.send, recv=hook.recv)
-        else:
-            self.wsock.extensions = []
+                    if not extension_conflicts(ext, extensions):
+                        extensions.append(ext)
+                        hook = ext.create_hook(**params)
+                        self.wsock.extension_hooks.append(hook)
 
         # Check if requested resource location is served by this server
         if ssock.locations:
@@ -212,7 +212,7 @@ class ServerHandshake(Handshake):
         location = '%s://%s%s' % (scheme, host, self.wsock.location)
 
         # Construct HTTP response header
-        yield 'HTTP/1.1 101 Web Socket Protocol Handshake'
+        yield 'HTTP/1.1 101 Switching Protocols'
         yield 'Upgrade', 'websocket'
         yield 'Connection', 'Upgrade'
         yield 'Sec-WebSocket-Origin', origin
@@ -274,7 +274,7 @@ class ClientHandshake(Handshake):
         # Compare extensions, add hooks only for those returned by server
         if 'Sec-WebSocket-Extensions' in headers:
             supported_ext = dict((e.name, e) for e in self.wsock.extensions)
-            self.wsock.extensions = []
+            self.wsock.extension_hooks = []
 
             for ext in split_stripped(headers['Sec-WebSocket-Extensions']):
                 name, params = parse_param_hdr(ext)
@@ -284,8 +284,7 @@ class ClientHandshake(Handshake):
                                          'unsupported extension "%s"' % name)
 
                 hook = supported_ext[name].create_hook(**params)
-                self.wsock.extensions.append(supported_ext[name])
-                self.wsock.add_hook(send=hook.send, recv=hook.recv)
+                self.wsock.extension_hooks.append(hook)
 
         # Assert that returned protocol (if any) is supported
         if 'Sec-WebSocket-Protocol' in headers:

+ 24 - 19
server.py

@@ -32,7 +32,7 @@ class Server(object):
     """
 
     def __init__(self, address, loglevel=logging.INFO, ssl_args=None,
-                 max_join_time=2.0, **kwargs):
+                 max_join_time=2.0, backlog_size=32, **kwargs):
         """
         Constructor for a simple web socket server.
 
@@ -53,6 +53,8 @@ class Server(object):
 
         `max_join_time` is the maximum time (in seconds) to wait for client
         responses after sending CLOSE frames, it defaults to 2 seconds.
+
+        `backlog_size` is directly passed to `websocket.listen`.
         """
         logging.basicConfig(level=loglevel,
                 format='%(asctime)s: %(levelname)s: %(message)s',
@@ -69,14 +71,14 @@ class Server(object):
             self.sock.enable_ssl(server_side=True, **ssl_args)
 
         self.sock.bind(address)
-        self.sock.listen(5)
-
-        self.clients = []
-        self.client_threads = []
+        self.sock.listen(backlog_size)
 
         self.max_join_time = max_join_time
 
     def run(self):
+        self.clients = []
+        self.client_threads = []
+
         while True:
             try:
                 sock, address = self.sock.accept()
@@ -134,30 +136,22 @@ class Server(object):
         self.onclose(client, code, reason)
 
     def onopen(self, client):
-        logging.debug('Opened socket to %s', client)
+        return NotImplemented
 
     def onmessage(self, client, message):
-        logging.debug('Received %s from %s', message, client)
+        return NotImplemented
 
     def onping(self, client, payload):
-        logging.debug('Sent ping "%s" to %s', payload, client)
+        return NotImplemented
 
     def onpong(self, client, payload):
-        logging.debug('Received pong "%s" from %s', payload, client)
+        return NotImplemented
 
     def onclose(self, client, code, reason):
-        msg = 'Closed socket to %s' % client
-
-        if code is not None:
-            msg += ' [%d]' % code
-
-        if len(reason):
-            msg += ': ' + reason
-
-        logging.debug(msg)
+        return NotImplemented
 
     def onerror(self, client, e):
-        logging.error(format_exc(e))
+        return NotImplemented
 
 
 class Client(Connection):
@@ -176,21 +170,32 @@ class Client(Connection):
         Connection.send(self, message, fragment_size=fragment_size, mask=mask)
 
     def onopen(self):
+        logging.debug('Opened socket to %s', self)
         self.server.onopen(self)
 
     def onmessage(self, message):
+        logging.debug('Received %s from %s', message, self)
         self.server.onmessage(self, message)
 
     def onping(self, payload):
+        logging.debug('Sent ping "%s" to %s', payload, self)
         self.server.onping(self, payload)
 
     def onpong(self, payload):
+        logging.debug('Received pong "%s" from %s', payload, self)
         self.server.onpong(self, payload)
 
     def onclose(self, code, reason):
+        msg = 'Closed socket to %s' % self
+
+        if code is not None:
+            msg += ': [%d] %s' % (code, reason)
+
+        logging.debug(msg)
         self.server.remove_client(self, code, reason)
 
     def onerror(self, e):
+        logging.error(format_exc(e))
         self.server.onerror(self, e)
 
 

+ 10 - 2
test/client.py

@@ -1,5 +1,6 @@
 #!/usr/bin/env python
 import sys
+import ssl
 from os.path import abspath, dirname
 
 basepath = abspath(dirname(abspath(__file__)) + '/..')
@@ -20,7 +21,7 @@ class EchoClient(Connection):
 
     def onmessage(self, msg):
         print 'Received', msg
-        raise SocketClosed(None, 'response received')
+        self.close(None, 'response received')
 
     def onerror(self, e):
         print 'Error:', e
@@ -29,8 +30,15 @@ class EchoClient(Connection):
         print 'Connection closed'
 
 
+secure = True
+
 if __name__ == '__main__':
-    print 'Connecting to ws://%s:%d' % ADDR
+    scheme = 'wss' if secure else 'ws'
+    print 'Connecting to %s://%s' % (scheme, '%s:%d' % ADDR)
     sock = websocket()
+
+    if secure:
+        sock.enable_ssl(ca_certs='cert.pem', cert_reqs=ssl.CERT_REQUIRED)
+
     sock.connect(ADDR)
     EchoClient(sock).receive_forever()

+ 45 - 0
test/talk.py

@@ -0,0 +1,45 @@
+#!/usr/bin/env python
+import sys
+import socket
+from os.path import abspath, dirname
+
+basepath = abspath(dirname(abspath(__file__)) + '/..')
+sys.path.insert(0, basepath)
+
+from websocket import websocket
+from connection import Connection
+from message import TextMessage
+from errors import SocketClosed
+
+
+if __name__ == '__main__':
+    if len(sys.argv) < 3:
+        print >> sys.stderr, 'Usage: python %s HOST PORT' % sys.argv[0]
+        sys.exit(1)
+
+    host = sys.argv[1]
+    port = int(sys.argv[2])
+
+    sock = websocket()
+    sock.connect((host, port))
+    sock.settimeout(1.0)
+    conn = Connection(sock)
+
+    try:
+        try:
+            while True:
+                msg = TextMessage(raw_input())
+                print 'send:', msg
+                conn.send(msg)
+
+                try:
+                    print 'recv:', conn.recv()
+                except socket.timeout:
+                    print 'no response'
+        except EOFError:
+            conn.close()
+    except SocketClosed as e:
+        if e.initialized:
+            print 'closed connection'
+        else:
+            print 'other side closed connection'

+ 101 - 55
websocket.py

@@ -1,7 +1,7 @@
 import socket
 import ssl
 
-from frame import receive_frame
+from frame import receive_frame, pop_frame, contains_frame
 from handshake import ServerHandshake, ClientHandshake
 from errors import SSLError
 
@@ -11,7 +11,6 @@ INHERITED_ATTRS = ['bind', 'close', 'listen', 'fileno', 'getpeername',
                    'settimeout', 'gettimeout', 'shutdown', 'family', 'type',
                    'proto']
 
-
 class websocket(object):
     """
     Implementation of web socket, upgrades a regular TCP socket to a websocket
@@ -36,22 +35,23 @@ class websocket(object):
     >>> sock.connect(('', 8000))
     >>> sock.send(wspy.Frame(wspy.OPCODE_TEXT, 'Hello, Server!'))
     """
-    def __init__(self, sock=None, protocols=[], extensions=[], origin=None,
+    def __init__(self, sock=None, origin=None, protocols=[], extensions=[],
                  location='/', trusted_origins=[], locations=[], auth=None,
-                 sfamily=socket.AF_INET, sproto=0):
+                 recv_callback=None, sfamily=socket.AF_INET, sproto=0):
         """
         Create a regular TCP socket of family `family` and protocol
 
         `sock` is an optional regular TCP socket to be used for sending binary
         data. If not specified, a new socket is created.
 
-        `protocols` is a list of supported protocol names.
-
-        `extensions` is a list of supported extensions (`Extension` instances).
-
         `origin` (for client sockets) is the value for the "Origin" header sent
         in a client handshake .
 
+        `protocols` is a list of supported protocol names.
+
+        `extensions` (for server sockets) is a list of supported extensions
+        (`Extension` instances).
+
         `location` (for client sockets) is optional, used to request a
         particular resource in the HTTP handshake. In a URL, this would show as
         ws://host[:port]/<location>. Use this when the server serves multiple
@@ -71,10 +71,17 @@ class websocket(object):
         `auth` is optional, used for HTTP Basic or Digest authentication during
         the handshake. It must be specified as a (username, password) tuple.
 
+        `recv_callback` is the callback for received frames in asynchronous
+        sockets. Use in conjunction with setblocking(0). The callback itself
+        may for example change the recv_callback attribute to change the
+        behaviour for the next received message. Can be set when calling
+        `queue_send`.
+
         `sfamily` and `sproto` are used for the regular socket constructor.
         """
         self.protocols = protocols
         self.extensions = extensions
+        self.extension_hooks = []
         self.origin = origin
         self.location = location
         self.trusted_origins = trusted_origins
@@ -85,11 +92,16 @@ class websocket(object):
 
         self.handshake_sent = False
 
-        self.hooks_send = []
-        self.hooks_recv = []
+        self.sendbuf_frames = []
+        self.sendbuf = ''
+        self.recvbuf = ''
+        self.recv_callback = recv_callback
 
         self.sock = sock or socket.socket(sfamily, socket.SOCK_STREAM, sproto)
 
+    def set_extensions(self, extensions):
+        self.extensions = [ext.Hook() for ext in extensions]
+
     def __getattr__(self, name):
         if name in INHERITED_ATTRS:
             return getattr(self.sock, name)
@@ -122,29 +134,31 @@ class websocket(object):
         ClientHandshake(self).perform()
         self.handshake_sent = True
 
+    def apply_send_hooks(self, frame):
+        for hook in self.extension_hooks:
+            frame = hook.send(frame)
+
+        return frame
+
+    def apply_recv_hooks(self, frame):
+        for hook in reversed(self.extension_hooks):
+            frame = hook.recv(frame)
+
+        return frame
+
     def send(self, *args):
         """
         Send a number of frames.
         """
         for frame in args:
-            for hook in self.hooks_send:
-                frame = hook(frame)
-
-            #print 'send frame:', frame, 'to %s:%d' % self.sock.getpeername()
-            self.sock.sendall(frame.pack())
+            self.sock.sendall(self.apply_send_hooks(frame).pack())
 
     def recv(self):
         """
         Receive a single frames. This can be either a data frame or a control
         frame.
         """
-        frame = receive_frame(self.sock)
-
-        for hook in self.hooks_recv:
-            frame = hook(frame)
-
-        #print 'receive frame:', frame, 'from %s:%d' % self.sock.getpeername()
-        return frame
+        return self.apply_recv_hooks(receive_frame(self.sock))
 
     def recvn(self, n):
         """
@@ -153,47 +167,79 @@ class websocket(object):
         """
         return [self.recv() for i in xrange(n)]
 
-    def enable_ssl(self, *args, **kwargs):
+    def queue_send(self, frame, callback=None, recv_callback=None):
         """
-        Transforms the regular socket.socket to an ssl.SSLSocket for secure
-        connections. Any arguments are passed to ssl.wrap_socket:
-        http://docs.python.org/dev/library/ssl.html#ssl.wrap_socket
+        Enqueue `frame` to the send buffer so that it is send on the next
+        `do_async_send`. `callback` is an optional callable to call when the
+        frame has been fully written. `recv_callback` is an optional callable
+        to quickly set the `recv_callback` attribute to.
         """
-        if self.handshake_sent:
-            raise SSLError('can only enable SSL before handshake')
+        frame = self.apply_send_hooks(frame)
+        self.sendbuf += frame.pack()
+        self.sendbuf_frames.append([frame, len(self.sendbuf), callback])
 
-        self.secure = True
-        self.sock = ssl.wrap_socket(self.sock, *args, **kwargs)
+        if recv_callback:
+            self.recv_callback = recv_callback
+
+    def do_async_send(self):
+        """
+        Send any queued data. This function should only be called after a write
+        event on a file descriptor.
+        """
+        assert len(self.sendbuf)
+
+        nwritten = self.sock.send(self.sendbuf)
+        nframes = 0
 
-    def add_hook(self, send=None, recv=None, prepend=False):
+        for entry in self.sendbuf_frames:
+            frame, offset, callback = entry
+
+            if offset <= nwritten:
+                nframes += 1
+
+                if callback:
+                    callback()
+            else:
+                entry[1] -= nwritten
+
+        self.sendbuf = self.sendbuf[nwritten:]
+        self.sendbuf_frames = self.sendbuf_frames[nframes:]
+
+    def do_async_recv(self, bufsize):
+        """
+        Receive any completed frames from the socket. This function should only
+        be called after a read event on a file descriptor.
         """
-        Add a pair of send and receive hooks that are called for each frame
-        that is sent or received. A hook is a function that receives a single
-        argument - a Frame instance - and returns a `Frame` instance as well.
+        data = self.sock.recv(bufsize)
+
+        if len(data) == 0:
+            raise socket.error('no data to receive')
 
-        `prepend` is a flag indicating whether the send hook is prepended to
-        the other send hooks. This is expecially useful when a program uses
-        extensions such as the built-in `DeflateFrame` extension. These
-        extensions are installed using these hooks as well.
+        self.recvbuf += data
 
-        For example, the following code creates a `Frame` instance for data
-        being sent and removes the instance for received data. This way, data
-        can be sent and received as if on a regular socket.
-        >>> import wspy
-        >>> sock.add_hook(lambda data: tswpy.Frame(tswpy.OPCODE_TEXT, data),
-        >>>               lambda frame: frame.payload)
+        while contains_frame(self.recvbuf):
+            frame, self.recvbuf = pop_frame(self.recvbuf)
+            frame = self.apply_recv_hooks(frame)
 
-        To add base64 encoding to the example above:
-        >>> import base64
-        >>> sock.add_hook(base64.encodestring, base64.decodestring, True)
+            if not self.recv_callback:
+                raise ValueError('no callback installed for %s' % frame)
 
-        Note that here `prepend=True`, so that data passed to `send()` is first
-        encoded and then packed into a frame. Of course, one could also decide
-        to add the base64 hook first, or to return a new `Frame` instance with
-        base64-encoded data.
+            self.recv_callback(frame)
+
+    def can_send(self):
+        return len(self.sendbuf) > 0
+
+    def can_recv(self):
+        return self.recv_callback is not None
+
+    def enable_ssl(self, *args, **kwargs):
+        """
+        Transforms the regular socket.socket to an ssl.SSLSocket for secure
+        connections. Any arguments are passed to ssl.wrap_socket:
+        http://docs.python.org/dev/library/ssl.html#ssl.wrap_socket
         """
-        if send:
-            self.hooks_send.insert(0 if prepend else -1, send)
+        if self.handshake_sent:
+            raise SSLError('can only enable SSL before handshake')
 
-        if recv:
-            self.hooks_recv.insert(-1 if prepend else 0, recv)
+        self.secure = True
+        self.sock = ssl.wrap_socket(self.sock, *args, **kwargs)