소스 검색

Started adding asynchronous methods

Taddeus Kroes 12 년 전
부모
커밋
b304505126
3개의 변경된 파일143개의 추가작업 그리고 24개의 파일을 삭제
  1. 4 3
      __init__.py
  2. 71 20
      frame.py
  3. 68 1
      websocket.py

+ 4 - 3
__init__.py

@@ -1,13 +1,14 @@
-from websocket import websocket
+from websocket import websocket, STATE_INIT, STATE_READ, STATE_WRITE, \
+                      STATE_CLOSE
 from server import Server
 from frame import Frame, ControlFrame, OPCODE_CONTINUATION, OPCODE_TEXT, \
         OPCODE_BINARY, OPCODE_CLOSE, OPCODE_PING, OPCODE_PONG, CLOSE_NORMAL, \
         CLOSE_GOING_AWAY, CLOSE_PROTOCOL_ERROR, CLOSE_NOACCEPT_DTYPE, \
         CLOSE_INVALID_DATA, CLOSE_POLICY, CLOSE_MESSAGE_TOOBIG, \
-        CLOSE_MISSING_EXTENSIONS, CLOSE_UNABLE
+        CLOSE_MISSING_EXTENSIONS, CLOSE_UNABLE, read_frame, pop_frame, \
+        contains_frame
 from connection import Connection
 from message import Message, TextMessage, BinaryMessage
 from errors import SocketClosed, HandshakeError, PingError, SSLError
 from extension import Extension
 from deflate_frame import DeflateFrame, WebkitDeflateFrame
-#from multiplex import Multiplex

+ 71 - 20
frame.py

@@ -194,12 +194,8 @@ class ControlFrame(Frame):
         return code, reason
 
 
-def receive_frame(sock):
-    """
-    Receive a single frame on socket `sock`. The frame scheme is explained in
-    the docs of Frame.pack().
-    """
-    b1, b2 = struct.unpack('!BB', recvn(sock, 2))
+def decode_frame(reader):
+    b1, b2 = struct.unpack('!BB', reader.recvn(2))
 
     final = bool(b1 & 0x80)
     rsv1 = bool(b1 & 0x40)
@@ -211,16 +207,16 @@ def receive_frame(sock):
     payload_len = b2 & 0x7F
 
     if payload_len == 126:
-        payload_len = struct.unpack('!H', recvn(sock, 2))
+        payload_len = struct.unpack('!H', reader.recvn(2))
     elif payload_len == 127:
-        payload_len = struct.unpack('!Q', recvn(sock, 8))
+        payload_len = struct.unpack('!Q', reader.recvn(8))
 
     if masked:
-        masking_key = recvn(sock, 4)
-        payload = mask(masking_key, recvn(sock, payload_len))
+        masking_key = reader.recvn(4)
+        payload = mask(masking_key, reader.recvn(payload_len))
     else:
         masking_key = ''
-        payload = recvn(sock, payload_len)
+        payload = reader.recvn(payload_len)
 
     # Control frames have most significant bit 1
     cls = ControlFrame if opcode & 0x8 else Frame
@@ -229,21 +225,76 @@ def receive_frame(sock):
                rsv1=rsv1, rsv2=rsv2, rsv3=rsv3)
 
 
-def recvn(sock, n):
+def receive_frame(sock):
+    return decode_frame(BlockingSocket(sock))
+
+
+def read_frame(data):
+    reader = BufferedReader(data)
+    frame = decode_frame(reader)
+    return frame, len(data) - reader.offset
+
+
+def pop_frame(data):
+    frame, l = read_frame(data)
+
+
+class BufferedReader(object):
+    def __init__(self, data):
+        self.data = data
+        self.offset = 0
+
+    def recvn(self, n):
+        assert len(self.data) - self.offset >= n
+        self.offset += n
+        return self.data[self.offset - n:self.offset]
+
+
+class BlockingSocket(object):
+    def __init__(self, sock):
+        self.sock = sock
+
+    def recvn(self, n):
+        """
+        Keep receiving data until exactly `n` bytes have been read.
+        """
+        data = ''
+
+        while len(data) < n:
+            received = self.sock.recv(n - len(data))
+
+            if not len(received):
+                raise socket.error('no data read from socket')
+
+            data += received
+
+        return data
+
+
+def contains_frame(data):
     """
-    Keep receiving data from `sock` until exactly `n` bytes have been read.
+    Read the frame length from the start of `data` and check if the data is
+    long enough to contain the entire frame.
     """
-    data = ''
+    if len(data) < 2:
+        return False
+
+    b2 = struct.unpack('!B', data[1])
+    payload_len = b2 & 0x7F
+    payload_start = 2
 
-    while len(data) < n:
-        received = sock.recv(n - len(data))
+    if payload_len == 126:
+        if len(data) > 4:
+            payload_len = struct.unpack('!H', data[2:4])
 
-        if not len(received):
-            raise socket.error('no data read from socket')
+        payload_start = 4
+    elif payload_len == 127:
+        if len(data) > 12:
+            payload_len = struct.unpack('!Q', data[4:12])
 
-        data += received
+        payload_start = 12
 
-    return data
+    return len(data) >= payload_len + payload_start
 
 
 def mask(key, original):

+ 68 - 1
websocket.py

@@ -1,7 +1,7 @@
 import socket
 import ssl
 
-from frame import receive_frame
+from frame import receive_frame, pop_frame, contains_frame
 from handshake import ServerHandshake, ClientHandshake
 from errors import SSLError
 
@@ -11,6 +11,11 @@ INHERITED_ATTRS = ['bind', 'close', 'listen', 'fileno', 'getpeername',
                    'settimeout', 'gettimeout', 'shutdown', 'family', 'type',
                    'proto']
 
+STATE_INIT = 0
+STATE_READ = 1
+STATE_WRITE = 2
+STATE_CLOSE = 4
+
 
 class websocket(object):
     """
@@ -88,6 +93,11 @@ class websocket(object):
         self.hooks_send = []
         self.hooks_recv = []
 
+        self.state = STATE_INIT
+        self.sendbuf = ''
+        self.recvbuf = ''
+        self.recv_callbacks = []
+
         self.sock = sock or socket.socket(sfamily, socket.SOCK_STREAM, sproto)
 
     def __getattr__(self, name):
@@ -153,6 +163,62 @@ class websocket(object):
         """
         return [self.recv() for i in xrange(n)]
 
+    def queue_send(self, frame):
+        """
+        Enqueue `frame` to the send buffer so that it is send on the next
+        `do_async_send`.
+        """
+        for hook in self.hooks_send:
+            frame = hook(frame)
+
+        self.sendbuf += frame.pack()
+        self.state |= STATE_WRITE
+
+    def queue_recv(self, callback):
+        """
+        Enqueue `callback` to be called when the next frame is recieved by
+        `do_async_recv`.
+        """
+        self.recv_callbacks.push(callback)
+        self.state |= STATE_READ
+
+    def queue_close(self):
+        self.state |= STATE_CLOSE
+
+    def do_async_send(self):
+        """
+        Send any queued data. If all data is sent, STATE_WRITE is removed from
+        the state mask.
+        """
+        assert self.state | STATE_WRITE
+        assert len(self.sendbuf)
+
+        nwritten = self.sock.send(self.sendbuf)
+        self.sendbuf = self.sendbuf[nwritten:]
+
+        if len(self.sendbuf) == 0:
+            self.state ^= STATE_WRITE
+
+    def do_async_recv(self, bufsize):
+        """
+        """
+        assert self.state | STATE_READ
+
+        self.recvbuf += self.sock.recv(bufsize)
+
+        while contains_frame(self.recvbuf):
+            frame, self.recvbuf = pop_frame(self.recvbuf)
+
+            if len(self.recv_callbacks) == 0:
+                raise IndexError('no callback installed for received frame %s'
+                                 % frame)
+
+            cb = self.recv_callbacks.pop(0)
+            cb(frame)
+
+        if len(self.recvbuf) == 0:
+            self.state ^= STATE_READ
+
     def enable_ssl(self, *args, **kwargs):
         """
         Transforms the regular socket.socket to an ssl.SSLSocket for secure
@@ -180,6 +246,7 @@ class websocket(object):
         being sent and removes the instance for received data. This way, data
         can be sent and received as if on a regular socket.
         >>> import wspy
+        >>> sock = wspy.websocket()
         >>> sock.add_hook(lambda data: tswpy.Frame(tswpy.OPCODE_TEXT, data),
         >>>               lambda frame: frame.payload)