Explorar o código

Rewrote websocket + connection to add support for asynchronous sockets, added asynchrounous connecten & server

Taddeus Kroes %!s(int64=12) %!d(string=hai) anos
pai
achega
a55b2fad33
Modificáronse 10 ficheiros con 322 adicións e 106 borrados
  1. 1 0
      README.md
  2. 2 2
      __init__.py
  3. 185 0
      async.py
  4. 52 47
      connection.py
  5. 8 4
      errors.py
  6. 9 3
      frame.py
  7. 9 10
      server.py
  8. 10 2
      test/client.py
  9. 1 1
      test/server.py
  10. 45 37
      websocket.py

+ 1 - 0
README.md

@@ -22,6 +22,7 @@ Her is a quick overview of the features in this library:
 - Secure sockets using SSL certificates (for 'wss://...' URLs).
 - The possibility to add extensions to the web socket protocol. An included
   implementation is [deflate-frame](http://tools.ietf.org/html/draft-tyoshino-hybi-websocket-perframe-deflate-06).
+- Asynchronous sockets with an EPOLL-based server.
 
 
 Installation

+ 2 - 2
__init__.py

@@ -1,5 +1,4 @@
-from websocket import websocket, STATE_INIT, STATE_READ, STATE_WRITE, \
-                      STATE_CLOSE
+from websocket import websocket
 from server import Server
 from frame import Frame, ControlFrame, OPCODE_CONTINUATION, OPCODE_TEXT, \
         OPCODE_BINARY, OPCODE_CLOSE, OPCODE_PING, OPCODE_PONG, CLOSE_NORMAL, \
@@ -12,3 +11,4 @@ from message import Message, TextMessage, BinaryMessage
 from errors import SocketClosed, HandshakeError, PingError, SSLError
 from extension import Extension
 from deflate_frame import DeflateFrame, WebkitDeflateFrame
+from async import AsyncConnection, AsyncServer

+ 185 - 0
async.py

@@ -0,0 +1,185 @@
+import socket
+from select import epoll, EPOLLIN, EPOLLOUT, EPOLLHUP
+from traceback import format_exc
+import logging
+
+from connection import Connection
+from frame import ControlFrame, OPCODE_PING, OPCODE_CONTINUATION, \
+                  create_close_frame
+from server import Server, Client
+from errors import HandshakeError, SocketClosed
+
+
+class AsyncConnection(Connection):
+    def __init__(self, sock):
+        sock.recv_callback = self.contruct_message
+        sock.recv_close_callback = self.onclose
+        self.recvbuf = []
+        Connection.__init__(self, sock)
+
+    def contruct_message(self, frame):
+        if isinstance(frame, ControlFrame):
+            self.handle_control_frame(frame)
+            return
+
+        self.recvbuf.append(frame)
+
+        if frame.final:
+            message = self.concat_fragments(self.recvbuf)
+            self.recvbuf = []
+            self.onmessage(message)
+        elif len(self.recvbuf) > 1 and frame.opcode != OPCODE_CONTINUATION:
+            raise ValueError('expected continuation/control frame, got %s '
+                             'instead' % frame)
+
+    def send(self, message, fragment_size=None, mask=False):
+        frames = self.message_to_frames(message, fragment_size, mask)
+
+        for frame in frames[:-1]:
+            self.sock.queue_send(frame)
+
+        self.sock.queue_send(frames[-1], lambda: self.onsend(message))
+
+    def send_frame(self, frame, callback):
+        self.sock.queue_send(frame, callback)
+
+    def do_async_send(self):
+        self.execute_controlled(self.sock.do_async_send)
+
+    def do_async_recv(self, bufsize):
+        self.execute_controlled(self.sock.do_async_recv, bufsize)
+
+    def execute_controlled(self, func, *args, **kwargs):
+        try:
+            func(*args, **kwargs)
+        except (KeyboardInterrupt, SystemExit, SocketClosed):
+            raise
+        except Exception as e:
+            self.onerror(e)
+            self.onclose(None, 'error: %s' % e)
+
+            try:
+                self.sock.close()
+            except socket.error:
+                pass
+
+            raise e
+
+    def send_close_frame(self, code, reason):
+        self.sock.queue_send(create_close_frame(code, reason),
+                             self.shutdown_write)
+        self.close_frame_sent = True
+
+    def close(self, code=None, reason=''):
+        self.send_close_frame(code, reason)
+
+    def send_ping(self, payload=''):
+        self.sock.queue_send(ControlFrame(OPCODE_PING, payload),
+                             lambda: self.onping(payload))
+        self.ping_payload = payload
+        self.ping_sent = True
+
+    def onsend(self, message):
+        """
+        Called after a message has been written.
+        """
+        return NotImplemented
+
+
+class AsyncServer(Server):
+    def __init__(self, *args, **kwargs):
+        Server.__init__(self, *args, **kwargs)
+
+        self.recvbuf_size = kwargs.get('recvbuf_size', 2048)
+
+        self.epoll = epoll()
+        self.epoll.register(self.sock.fileno(), EPOLLIN)
+        self.conns = {}
+
+    @property
+    def clients(self):
+        return self.conns.itervalues()
+
+    def remove_client(self, client, code, reason):
+        self.epoll.unregister(client.fno)
+        del self.conns[client.fno]
+        self.onclose(client, code, reason)
+
+    def handle_events(self):
+        for fileno, event in self.epoll.poll(1):
+            if fileno == self.sock.fileno():
+                try:
+                    sock, addr = self.sock.accept()
+                except HandshakeError as e:
+                    logging.error('Invalid request: %s', e.message)
+                    continue
+
+                client = AsyncClient(self, sock)
+                client.fno = sock.fileno()
+                sock.setblocking(0)
+                self.epoll.register(client.fno, EPOLLIN)
+                self.conns[client.fno] = client
+                logging.debug('Registered client %s', client)
+
+            elif event & EPOLLHUP:
+                self.epoll.unregister(fileno)
+                del self.conns[fileno]
+
+            else:
+                conn = self.conns[fileno]
+
+                try:
+                    if event & EPOLLOUT:
+                        conn.do_async_send()
+                    elif event & EPOLLIN:
+                        conn.do_async_recv(self.recvbuf_size)
+                except (KeyboardInterrupt, SystemExit):
+                    raise
+                except SocketClosed:
+                    continue
+                except Exception as e:
+                    logging.error(format_exc(e))
+                    continue
+
+                mask = 0
+
+                if conn.sock.can_send():
+                    mask |= EPOLLOUT
+
+                if conn.sock.can_recv():
+                    mask |= EPOLLIN
+
+                self.epoll.modify(fileno, mask)
+
+    def run(self):
+        try:
+            while True:
+                self.handle_events()
+        except (KeyboardInterrupt, SystemExit):
+            logging.info('Received interrupt, stopping server...')
+        finally:
+            self.epoll.unregister(self.sock.fileno())
+            self.epoll.close()
+            self.sock.close()
+
+    def onsend(self, client, message):
+        logging.debug('Written "%s" to %s', message, client)
+
+
+class AsyncClient(Client, AsyncConnection):
+    def __init__(self, server, sock):
+        self.server = server
+        AsyncConnection.__init__(self, sock)
+
+    def send(self, message, fragment_size=None, mask=False):
+        logging.debug('Enqueueing %s to %s', message, self)
+        AsyncConnection.send(self, message, fragment_size, mask)
+
+    def onsend(self, message):
+        self.server.onsend(self, message)
+
+
+if __name__ == '__main__':
+    import sys
+    port = int(sys.argv[1]) if len(sys.argv) > 1 else 8000
+    AsyncServer(('', port), loglevel=logging.DEBUG).run()

+ 52 - 47
connection.py

@@ -1,8 +1,7 @@
-import struct
 import socket
 
 from frame import ControlFrame, OPCODE_CLOSE, OPCODE_PING, OPCODE_PONG, \
-                  OPCODE_CONTINUATION
+                  OPCODE_CONTINUATION, create_close_frame
 from message import create_message
 from errors import SocketClosed, PingError
 
@@ -54,19 +53,30 @@ class Connection(object):
 
         self.onopen()
 
+    def message_to_frames(self, message, fragment_size=None, mask=False):
+        for hook in self.hooks_send:
+            message = hook(message)
+
+        if fragment_size is None:
+            yield message.frame(mask=mask)
+        else:
+            for frame in message.fragment(fragment_size, mask=mask):
+                yield frame
+
     def send(self, message, fragment_size=None, mask=False):
         """
         Send a message. If `fragment_size` is specified, the message is
         fragmented into multiple frames whose payload size does not extend
         `fragment_size`.
         """
-        for hook in self.hooks_send:
-            message = hook(message)
+        for frame in self.message_to_frames(message, fragment_size, mask):
+            self.send_frame(frame)
 
-        if fragment_size is None:
-            self.sock.send(message.frame(mask=mask))
-        else:
-            self.sock.send(*message.fragment(fragment_size, mask=mask))
+    def send_frame(self, frame, callback=None):
+        self.sock.send(frame)
+
+        if callback:
+            callback()
 
     def recv(self):
         """
@@ -82,12 +92,15 @@ class Connection(object):
 
             if isinstance(frame, ControlFrame):
                 self.handle_control_frame(frame)
-            elif len(fragments) and frame.opcode != OPCODE_CONTINUATION:
+            elif len(fragments) > 0 and frame.opcode != OPCODE_CONTINUATION:
                 raise ValueError('expected continuation/control frame, got %s '
                                  'instead' % frame)
             else:
                 fragments.append(frame)
 
+        return self.concat_fragments(fragments)
+
+    def concat_fragments(self, fragments):
         payload = bytearray()
 
         for f in fragments:
@@ -105,16 +118,20 @@ class Connection(object):
         Handle a control frame as defined by RFC 6455.
         """
         if frame.opcode == OPCODE_CLOSE:
-            # Close the connection from this end as well
             self.close_frame_received = True
             code, reason = frame.unpack_close()
 
-            # No more receiving data after a close message
-            raise SocketClosed(code, reason)
+            if self.close_frame_sent:
+                self.onclose(code, reason)
+                self.sock.close()
+                raise SocketClosed(True)
+            else:
+                self.close_params = (code, reason)
+                self.send_close_frame(code, reason)
 
         elif frame.opcode == OPCODE_PING:
             # Respond with a pong message with identical payload
-            self.sock.send(ControlFrame(OPCODE_PONG, frame.payload))
+            self.send_frame(ControlFrame(OPCODE_PONG, frame.payload))
 
         elif frame.opcode == OPCODE_PONG:
             # Assert that the PONG payload is identical to that of the PING
@@ -138,38 +155,40 @@ class Connection(object):
         while True:
             try:
                 self.onmessage(self.recv())
-            except SocketClosed as e:
-                self.close(e.code, e.reason)
+            except (KeyboardInterrupt, SystemExit, SocketClosed):
                 break
-            except socket.error as e:
+            except Exception as e:
                 self.onerror(e)
+                self.onclose(None, 'error: %s' % e)
 
                 try:
                     self.sock.close()
                 except socket.error:
                     pass
 
-                self.onclose(None, '')
-                break
-            except Exception as e:
-                self.onerror(e)
+                raise e
 
     def send_ping(self, payload=''):
         """
         Send a PING control frame with an optional payload.
         """
-        self.sock.send(ControlFrame(OPCODE_PING, payload))
+        self.send_frame(ControlFrame(OPCODE_PING, payload),
+                        lambda: self.onping(payload))
         self.ping_payload = payload
         self.ping_sent = True
-        self.onping(payload)
 
-    def send_close_frame(self, code=None, reason=''):
-        """
-        Send a CLOSE control frame.
-        """
-        payload = '' if code is None else struct.pack('!H', code) + reason
-        self.sock.send(ControlFrame(OPCODE_CLOSE, payload))
+    def send_close_frame(self, code, reason):
+        self.send_frame(create_close_frame(code, reason))
         self.close_frame_sent = True
+        self.shutdown_write()
+
+    def shutdown_write(self):
+        if self.close_frame_received:
+            self.onclose(*self.close_params)
+            self.sock.close()
+            raise SocketClosed(False)
+        else:
+            self.sock.shutdown(socket.SHUT_WR)
 
     def close(self, code=None, reason=''):
         """
@@ -179,28 +198,14 @@ class Connection(object):
         called after the response has been received, but before the socket is
         actually closed.
         """
-        # Send CLOSE frame
-        if not self.close_frame_sent:
-            self.send_close_frame(code, reason)
-
-        # Receive CLOSE frame
-        if not self.close_frame_received:
-            frame = self.sock.recv()
-
-            if frame.opcode != OPCODE_CLOSE:
-                raise ValueError('expected CLOSE frame, got %s' % frame)
-
-            self.close_frame_received = True
-            res_code, res_reason = frame.unpack_close()
+        self.send_close_frame(code, reason)
 
-            # FIXME: check if res_code == code and res_reason == reason?
+        frame = self.sock.recv()
 
-            # FIXME: alternatively, keep receiving frames in a loop until a
-            # CLOSE frame is received, so that a fragmented chain may arrive
-            # fully first
+        if frame.opcode != OPCODE_CLOSE:
+            raise ValueError('expected CLOSE frame, got %s' % frame)
 
-        self.onclose(code, reason)
-        self.sock.close()
+        self.handle_control_frame(frame)
 
     def add_hook(self, send=None, recv=None, prepend=False):
         """

+ 8 - 4
errors.py

@@ -1,11 +1,15 @@
 class SocketClosed(Exception):
-    def __init__(self, code=None, reason=''):
-        self.code = code
-        self.reason = reason
+    def __init__(self, initialized):
+        self.initialized = initialized
 
     @property
     def message(self):
-        return ('' if self.code is None else '[%d] ' % self.code) + self.reason
+        s = 'socket closed'
+
+        if self.initialized:
+            s += ' (initialized)'
+
+        return s
 
 
 class HandshakeError(Exception):

+ 9 - 3
frame.py

@@ -232,11 +232,12 @@ def receive_frame(sock):
 def read_frame(data):
     reader = BufferReader(data)
     frame = decode_frame(reader)
-    return frame, len(data) - reader.offset
+    return frame, reader.offset
 
 
 def pop_frame(data):
-    frame, l = read_frame(data)
+    frame, size = read_frame(data)
+    return frame, data[size:]
 
 
 class BufferReader(object):
@@ -279,7 +280,7 @@ def contains_frame(data):
     if len(data) < 2:
         return False
 
-    b2 = struct.unpack('!B', data[1])
+    b2 = struct.unpack('!B', data[1])[0]
     payload_len = b2 & 0x7F
     payload_start = 2
 
@@ -316,3 +317,8 @@ def mask(key, original):
         masked[i] ^= key[i % 4]
 
     return masked
+
+
+def create_close_frame(code, reason):
+    payload = '' if code is None else struct.pack('!H', code) + reason
+    return ControlFrame(OPCODE_CLOSE, payload)

+ 9 - 10
server.py

@@ -25,14 +25,14 @@ class Server(object):
     >>>         print 'Received message "%s"' % message.payload
     >>>         client.send(wspy.TextMessage(message.payload))
 
-    >>>     def onclose(self, client):
+    >>>     def onclose(self, client, code, reason):
     >>>         print 'Client %s disconnected' % client
 
     >>> EchoServer(('', 8000)).run()
     """
 
     def __init__(self, address, loglevel=logging.INFO, ssl_args=None,
-                 max_join_time=2.0, **kwargs):
+                 max_join_time=2.0, backlog_size=32, **kwargs):
         """
         Constructor for a simple web socket server.
 
@@ -53,6 +53,8 @@ class Server(object):
 
         `max_join_time` is the maximum time (in seconds) to wait for client
         responses after sending CLOSE frames, it defaults to 2 seconds.
+
+        `backlog_size` is directly passed to `websocket.listen`.
         """
         logging.basicConfig(level=loglevel,
                 format='%(asctime)s: %(levelname)s: %(message)s',
@@ -69,14 +71,14 @@ class Server(object):
             self.sock.enable_ssl(server_side=True, **ssl_args)
 
         self.sock.bind(address)
-        self.sock.listen(5)
-
-        self.clients = []
-        self.client_threads = []
+        self.sock.listen(backlog_size)
 
         self.max_join_time = max_join_time
 
     def run(self):
+        self.clients = []
+        self.client_threads = []
+
         while True:
             try:
                 sock, address = self.sock.accept()
@@ -149,10 +151,7 @@ class Server(object):
         msg = 'Closed socket to %s' % client
 
         if code is not None:
-            msg += ' [%d]' % code
-
-        if len(reason):
-            msg += ': ' + reason
+            msg += ': [%d] %s' % (code, reason)
 
         logging.debug(msg)
 

+ 10 - 2
test/client.py

@@ -1,5 +1,6 @@
 #!/usr/bin/env python
 import sys
+import ssl
 from os.path import abspath, dirname
 
 basepath = abspath(dirname(abspath(__file__)) + '/..')
@@ -20,7 +21,7 @@ class EchoClient(Connection):
 
     def onmessage(self, msg):
         print 'Received', msg
-        raise SocketClosed(None, 'response received')
+        self.close(None, 'response received')
 
     def onerror(self, e):
         print 'Error:', e
@@ -29,8 +30,15 @@ class EchoClient(Connection):
         print 'Connection closed'
 
 
+secure = True
+
 if __name__ == '__main__':
-    print 'Connecting to ws://%s:%d' % ADDR
+    scheme = 'wss' if secure else 'ws'
+    print 'Connecting to %s://%s' % (scheme, '%s:%d' % ADDR)
     sock = websocket()
+
+    if secure:
+        sock.enable_ssl(ca_certs='cert.pem', cert_reqs=ssl.CERT_REQUIRED)
+
     sock.connect(ADDR)
     EchoClient(sock).receive_forever()

+ 1 - 1
test/server.py

@@ -20,5 +20,5 @@ if __name__ == '__main__':
     deflate = WebkitDeflateFrame()
     #deflate = WebkitDeflateFrame(defaults={'no_context_takeover': True})
     EchoServer(('localhost', 8000), extensions=[deflate],
-               #ssl_args=dict(keyfile='cert.pem', certfile='cert.pem'),
+               ssl_args=dict(keyfile='cert.pem', certfile='cert.pem'),
                loglevel=logging.DEBUG).run()

+ 45 - 37
websocket.py

@@ -11,12 +11,6 @@ INHERITED_ATTRS = ['bind', 'close', 'listen', 'fileno', 'getpeername',
                    'settimeout', 'gettimeout', 'shutdown', 'family', 'type',
                    'proto']
 
-STATE_INIT = 0
-STATE_READ = 1
-STATE_WRITE = 2
-STATE_CLOSE = 4
-
-
 class websocket(object):
     """
     Implementation of web socket, upgrades a regular TCP socket to a websocket
@@ -43,7 +37,7 @@ class websocket(object):
     """
     def __init__(self, sock=None, protocols=[], extensions=[], origin=None,
                  location='/', trusted_origins=[], locations=[], auth=None,
-                 sfamily=socket.AF_INET, sproto=0):
+                 recv_callback=None, sfamily=socket.AF_INET, sproto=0):
         """
         Create a regular TCP socket of family `family` and protocol
 
@@ -76,6 +70,12 @@ class websocket(object):
         `auth` is optional, used for HTTP Basic or Digest authentication during
         the handshake. It must be specified as a (username, password) tuple.
 
+        `recv_callback` is the callback for received frames in asynchronous
+        sockets. Use in conjunction with setblocking(0). The callback itself
+        may for example change the recv_callback attribute to change the
+        behaviour for the next received message. Can be set when calling
+        `queue_send`.
+
         `sfamily` and `sproto` are used for the regular socket constructor.
         """
         self.protocols = protocols
@@ -93,10 +93,10 @@ class websocket(object):
         self.hooks_send = []
         self.hooks_recv = []
 
-        self.state = STATE_INIT
+        self.sendbuf_frames = []
         self.sendbuf = ''
         self.recvbuf = ''
-        self.recv_callbacks = []
+        self.recv_callback = recv_callback
 
         self.sock = sock or socket.socket(sfamily, socket.SOCK_STREAM, sproto)
 
@@ -163,61 +163,69 @@ class websocket(object):
         """
         return [self.recv() for i in xrange(n)]
 
-    def queue_send(self, frame):
+    def queue_send(self, frame, callback=None, recv_callback=None):
         """
         Enqueue `frame` to the send buffer so that it is send on the next
-        `do_async_send`.
+        `do_async_send`. `callback` is an optional callable to call when the
+        frame has been fully written. `recv_callback` is an optional callable
+        to quickly set the `recv_callback` attribute to.
         """
         for hook in self.hooks_send:
             frame = hook(frame)
 
         self.sendbuf += frame.pack()
-        self.state |= STATE_WRITE
+        self.sendbuf_frames.append([frame, len(self.sendbuf), callback])
 
-    def queue_recv(self, callback):
-        """
-        Enqueue `callback` to be called when the next frame is recieved by
-        `do_async_recv`.
-        """
-        self.recv_callbacks.push(callback)
-        self.state |= STATE_READ
-
-    def queue_close(self):
-        self.state |= STATE_CLOSE
+        if recv_callback:
+            self.recv_callback = recv_callback
 
     def do_async_send(self):
         """
-        Send any queued data. If all data is sent, STATE_WRITE is removed from
-        the state mask.
+        Send any queued data.
         """
-        assert self.state & STATE_WRITE
         assert len(self.sendbuf)
 
         nwritten = self.sock.send(self.sendbuf)
-        self.sendbuf = self.sendbuf[nwritten:]
+        nframes = 0
 
-        if len(self.sendbuf) == 0:
-            self.state ^= STATE_WRITE
+        for entry in self.sendbuf_frames:
+            frame, offset, callback = entry
+
+            if offset <= nwritten:
+                nframes += 1
+
+                if callback:
+                    print 'write cb'
+                    callback()
+            else:
+                entry[1] -= nwritten
+
+        self.sendbuf = self.sendbuf[nwritten:]
+        self.sendbuf_frames = self.sendbuf_frames[nframes:]
 
     def do_async_recv(self, bufsize):
         """
         """
-        assert self.state & STATE_READ
+        data = self.sock.recv(bufsize)
+
+        if len(data) == 0:
+            raise socket.error('no data to receive')
 
-        self.recvbuf += self.sock.recv(bufsize)
+        self.recvbuf += data
 
         while contains_frame(self.recvbuf):
             frame, self.recvbuf = pop_frame(self.recvbuf)
 
-            if len(self.recv_callbacks) == 0:
-                raise IndexError('no callback installed for received frame %s'
-                                 % frame)
+            if not self.recv_callback:
+                raise ValueError('no callback installed for %s' % frame)
+
+            self.recv_callback(frame)
 
-            cb = self.recv_callbacks.pop(0)
-            cb(frame)
+    def can_send(self):
+        return len(self.sendbuf) > 0
 
-        if len(self.recvbuf) == 0:
-            self.state ^= STATE_READ
+    def can_recv(self):
+        return self.recv_callback is not None
 
     def enable_ssl(self, *args, **kwargs):
         """