Răsfoiți Sursa

Started adding asynchronous methods

Taddeus Kroes 12 ani în urmă
părinte
comite
b304505126
3 a modificat fișierele cu 143 adăugiri și 24 ștergeri
  1. 4 3
      __init__.py
  2. 71 20
      frame.py
  3. 68 1
      websocket.py

+ 4 - 3
__init__.py

@@ -1,13 +1,14 @@
-from websocket import websocket
+from websocket import websocket, STATE_INIT, STATE_READ, STATE_WRITE, \
+                      STATE_CLOSE
 from server import Server
 from server import Server
 from frame import Frame, ControlFrame, OPCODE_CONTINUATION, OPCODE_TEXT, \
 from frame import Frame, ControlFrame, OPCODE_CONTINUATION, OPCODE_TEXT, \
         OPCODE_BINARY, OPCODE_CLOSE, OPCODE_PING, OPCODE_PONG, CLOSE_NORMAL, \
         OPCODE_BINARY, OPCODE_CLOSE, OPCODE_PING, OPCODE_PONG, CLOSE_NORMAL, \
         CLOSE_GOING_AWAY, CLOSE_PROTOCOL_ERROR, CLOSE_NOACCEPT_DTYPE, \
         CLOSE_GOING_AWAY, CLOSE_PROTOCOL_ERROR, CLOSE_NOACCEPT_DTYPE, \
         CLOSE_INVALID_DATA, CLOSE_POLICY, CLOSE_MESSAGE_TOOBIG, \
         CLOSE_INVALID_DATA, CLOSE_POLICY, CLOSE_MESSAGE_TOOBIG, \
-        CLOSE_MISSING_EXTENSIONS, CLOSE_UNABLE
+        CLOSE_MISSING_EXTENSIONS, CLOSE_UNABLE, read_frame, pop_frame, \
+        contains_frame
 from connection import Connection
 from connection import Connection
 from message import Message, TextMessage, BinaryMessage
 from message import Message, TextMessage, BinaryMessage
 from errors import SocketClosed, HandshakeError, PingError, SSLError
 from errors import SocketClosed, HandshakeError, PingError, SSLError
 from extension import Extension
 from extension import Extension
 from deflate_frame import DeflateFrame, WebkitDeflateFrame
 from deflate_frame import DeflateFrame, WebkitDeflateFrame
-#from multiplex import Multiplex

+ 71 - 20
frame.py

@@ -194,12 +194,8 @@ class ControlFrame(Frame):
         return code, reason
         return code, reason
 
 
 
 
-def receive_frame(sock):
-    """
-    Receive a single frame on socket `sock`. The frame scheme is explained in
-    the docs of Frame.pack().
-    """
-    b1, b2 = struct.unpack('!BB', recvn(sock, 2))
+def decode_frame(reader):
+    b1, b2 = struct.unpack('!BB', reader.recvn(2))
 
 
     final = bool(b1 & 0x80)
     final = bool(b1 & 0x80)
     rsv1 = bool(b1 & 0x40)
     rsv1 = bool(b1 & 0x40)
@@ -211,16 +207,16 @@ def receive_frame(sock):
     payload_len = b2 & 0x7F
     payload_len = b2 & 0x7F
 
 
     if payload_len == 126:
     if payload_len == 126:
-        payload_len = struct.unpack('!H', recvn(sock, 2))
+        payload_len = struct.unpack('!H', reader.recvn(2))
     elif payload_len == 127:
     elif payload_len == 127:
-        payload_len = struct.unpack('!Q', recvn(sock, 8))
+        payload_len = struct.unpack('!Q', reader.recvn(8))
 
 
     if masked:
     if masked:
-        masking_key = recvn(sock, 4)
-        payload = mask(masking_key, recvn(sock, payload_len))
+        masking_key = reader.recvn(4)
+        payload = mask(masking_key, reader.recvn(payload_len))
     else:
     else:
         masking_key = ''
         masking_key = ''
-        payload = recvn(sock, payload_len)
+        payload = reader.recvn(payload_len)
 
 
     # Control frames have most significant bit 1
     # Control frames have most significant bit 1
     cls = ControlFrame if opcode & 0x8 else Frame
     cls = ControlFrame if opcode & 0x8 else Frame
@@ -229,21 +225,76 @@ def receive_frame(sock):
                rsv1=rsv1, rsv2=rsv2, rsv3=rsv3)
                rsv1=rsv1, rsv2=rsv2, rsv3=rsv3)
 
 
 
 
-def recvn(sock, n):
+def receive_frame(sock):
+    return decode_frame(BlockingSocket(sock))
+
+
+def read_frame(data):
+    reader = BufferedReader(data)
+    frame = decode_frame(reader)
+    return frame, len(data) - reader.offset
+
+
+def pop_frame(data):
+    frame, l = read_frame(data)
+
+
+class BufferedReader(object):
+    def __init__(self, data):
+        self.data = data
+        self.offset = 0
+
+    def recvn(self, n):
+        assert len(self.data) - self.offset >= n
+        self.offset += n
+        return self.data[self.offset - n:self.offset]
+
+
+class BlockingSocket(object):
+    def __init__(self, sock):
+        self.sock = sock
+
+    def recvn(self, n):
+        """
+        Keep receiving data until exactly `n` bytes have been read.
+        """
+        data = ''
+
+        while len(data) < n:
+            received = self.sock.recv(n - len(data))
+
+            if not len(received):
+                raise socket.error('no data read from socket')
+
+            data += received
+
+        return data
+
+
+def contains_frame(data):
     """
     """
-    Keep receiving data from `sock` until exactly `n` bytes have been read.
+    Read the frame length from the start of `data` and check if the data is
+    long enough to contain the entire frame.
     """
     """
-    data = ''
+    if len(data) < 2:
+        return False
+
+    b2 = struct.unpack('!B', data[1])
+    payload_len = b2 & 0x7F
+    payload_start = 2
 
 
-    while len(data) < n:
-        received = sock.recv(n - len(data))
+    if payload_len == 126:
+        if len(data) > 4:
+            payload_len = struct.unpack('!H', data[2:4])
 
 
-        if not len(received):
-            raise socket.error('no data read from socket')
+        payload_start = 4
+    elif payload_len == 127:
+        if len(data) > 12:
+            payload_len = struct.unpack('!Q', data[4:12])
 
 
-        data += received
+        payload_start = 12
 
 
-    return data
+    return len(data) >= payload_len + payload_start
 
 
 
 
 def mask(key, original):
 def mask(key, original):

+ 68 - 1
websocket.py

@@ -1,7 +1,7 @@
 import socket
 import socket
 import ssl
 import ssl
 
 
-from frame import receive_frame
+from frame import receive_frame, pop_frame, contains_frame
 from handshake import ServerHandshake, ClientHandshake
 from handshake import ServerHandshake, ClientHandshake
 from errors import SSLError
 from errors import SSLError
 
 
@@ -11,6 +11,11 @@ INHERITED_ATTRS = ['bind', 'close', 'listen', 'fileno', 'getpeername',
                    'settimeout', 'gettimeout', 'shutdown', 'family', 'type',
                    'settimeout', 'gettimeout', 'shutdown', 'family', 'type',
                    'proto']
                    'proto']
 
 
+STATE_INIT = 0
+STATE_READ = 1
+STATE_WRITE = 2
+STATE_CLOSE = 4
+
 
 
 class websocket(object):
 class websocket(object):
     """
     """
@@ -88,6 +93,11 @@ class websocket(object):
         self.hooks_send = []
         self.hooks_send = []
         self.hooks_recv = []
         self.hooks_recv = []
 
 
+        self.state = STATE_INIT
+        self.sendbuf = ''
+        self.recvbuf = ''
+        self.recv_callbacks = []
+
         self.sock = sock or socket.socket(sfamily, socket.SOCK_STREAM, sproto)
         self.sock = sock or socket.socket(sfamily, socket.SOCK_STREAM, sproto)
 
 
     def __getattr__(self, name):
     def __getattr__(self, name):
@@ -153,6 +163,62 @@ class websocket(object):
         """
         """
         return [self.recv() for i in xrange(n)]
         return [self.recv() for i in xrange(n)]
 
 
+    def queue_send(self, frame):
+        """
+        Enqueue `frame` to the send buffer so that it is send on the next
+        `do_async_send`.
+        """
+        for hook in self.hooks_send:
+            frame = hook(frame)
+
+        self.sendbuf += frame.pack()
+        self.state |= STATE_WRITE
+
+    def queue_recv(self, callback):
+        """
+        Enqueue `callback` to be called when the next frame is recieved by
+        `do_async_recv`.
+        """
+        self.recv_callbacks.push(callback)
+        self.state |= STATE_READ
+
+    def queue_close(self):
+        self.state |= STATE_CLOSE
+
+    def do_async_send(self):
+        """
+        Send any queued data. If all data is sent, STATE_WRITE is removed from
+        the state mask.
+        """
+        assert self.state | STATE_WRITE
+        assert len(self.sendbuf)
+
+        nwritten = self.sock.send(self.sendbuf)
+        self.sendbuf = self.sendbuf[nwritten:]
+
+        if len(self.sendbuf) == 0:
+            self.state ^= STATE_WRITE
+
+    def do_async_recv(self, bufsize):
+        """
+        """
+        assert self.state | STATE_READ
+
+        self.recvbuf += self.sock.recv(bufsize)
+
+        while contains_frame(self.recvbuf):
+            frame, self.recvbuf = pop_frame(self.recvbuf)
+
+            if len(self.recv_callbacks) == 0:
+                raise IndexError('no callback installed for received frame %s'
+                                 % frame)
+
+            cb = self.recv_callbacks.pop(0)
+            cb(frame)
+
+        if len(self.recvbuf) == 0:
+            self.state ^= STATE_READ
+
     def enable_ssl(self, *args, **kwargs):
     def enable_ssl(self, *args, **kwargs):
         """
         """
         Transforms the regular socket.socket to an ssl.SSLSocket for secure
         Transforms the regular socket.socket to an ssl.SSLSocket for secure
@@ -180,6 +246,7 @@ class websocket(object):
         being sent and removes the instance for received data. This way, data
         being sent and removes the instance for received data. This way, data
         can be sent and received as if on a regular socket.
         can be sent and received as if on a regular socket.
         >>> import wspy
         >>> import wspy
+        >>> sock = wspy.websocket()
         >>> sock.add_hook(lambda data: tswpy.Frame(tswpy.OPCODE_TEXT, data),
         >>> sock.add_hook(lambda data: tswpy.Frame(tswpy.OPCODE_TEXT, data),
         >>>               lambda frame: frame.payload)
         >>>               lambda frame: frame.payload)