Procházet zdrojové kódy

Rewrote FrameReceiver to two cleaner functions

Taddeus Kroes před 13 roky
rodič
revize
ff54ce821c
3 změnil soubory, kde provedl 52 přidání a 82 odebrání
  1. 48 78
      frame.py
  2. 1 1
      server.py
  3. 3 3
      websocket.py

+ 48 - 78
frame.py

@@ -87,74 +87,55 @@ class Frame(object):
         return '<Frame opcode=%c len=%d>' % (self.opcode, len(self.payload))
 
 
-class FrameReceiver(object):
-    def __init__(self, sock):
-        self.sock = sock
-
-    def assert_received(self, n, exact=False):
-        if self.nreceived < n:
-            recv = recv_exactly if exact else recv_at_least
-            received = recv(self.sock, n - self.nreceived)
-
-            if not len(received):
-                raise SocketClosed()
-
-            self.buf += received
-            self.nreceived = len(self.buf)
-
-    def receive_fragments(self):
-        fragments = [self.receive_frame()]
-
-        while not fragments[-1].final:
-            fragments.append(self.receive_frame())
-
-        return fragments
-
-    def receive_frame(self):
-        self.buf = ''
-        self.nreceived = 0
-        total_len = 2
-
-        self.assert_received(2)
-        b1, b2 = struct.unpack('!BB', self.buf[:2])
-        final = bool(b1 & 0x80)
-        rsv1 = bool(b1 & 0x40)
-        rsv2 = bool(b1 & 0x20)
-        rsv3 = bool(b1 & 0x10)
-        opcode = b1 & 0x0F
-        mask = bool(b2 & 0x80)
-        payload_len = b2 & 0x7F
-
-        if mask:
-            total_len += 4
-
-        if payload_len == 126:
-            self.assert_received(4)
-            total_len += 4 + struct.unpack('!H', self.buf[2:4])
-            key_start = 4
-        elif payload_len == 127:
-            self.assert_received(8)
-            total_len += 8 + struct.unpack('!Q', self.buf[2:10])
-            key_start = 10
-        else:
-            total_len += payload_len
-            key_start = 2
-
-        self.assert_received(total_len, exact=True)
+def receive_fragments(sock):
+    """
+    Receive a sequence of frames that belong together:
+    - An ititial frame with non-zero opcode
+    - Zero or more frames with opcode = 0 and final = False
+    - A final frame with opcpde = 0 and final = True
+
+    The first and last frame may be the same frame, having a non-zero opcode
+    and final = True. Thus, this function returns a list of at least a single
+    frame.
+    """
+    fragments = [receive_frame(sock)]
 
-        if mask:
-            payload_start = key_start + 4
-            masking_key = self.buf[key_start:payload_start]
-            payload = mask(masking_key, self.buf[payload_start:])
-        else:
-            masking_key = ''
-            payload = self.buf[key_start:]
+    while not fragments[-1].final:
+        fragments.append(receive_frame(sock))
 
-        return Frame(opcode, payload, masking_key=masking_key, final=final,
-                      rsv1=rsv1, rsv2=rsv2, rsv3=rsv3)
+    return fragments
 
 
-def recv_exactly(sock, n):
+def receive_frame(sock):
+    """
+    Receive a single frame on the given socket.
+    """
+    b1, b2 = struct.unpack('!BB', recvn(sock, 2))
+    final = bool(b1 & 0x80)
+    rsv1 = bool(b1 & 0x40)
+    rsv2 = bool(b1 & 0x20)
+    rsv3 = bool(b1 & 0x10)
+    opcode = b1 & 0x0F
+    mask = bool(b2 & 0x80)
+    payload_len = b2 & 0x7F
+
+    if payload_len == 126:
+        payload_len = struct.unpack('!H', recvn(sock, 2))
+    elif payload_len == 127:
+        payload_len = struct.unpack('!Q', recvn(sock, 8))
+
+    if mask:
+        masking_key = recvn(sock, 4)
+        payload = mask(masking_key, recvn(payload_len))
+    else:
+        masking_key = ''
+        payload = recvn(payload_len)
+
+    return Frame(opcode, payload, masking_key=masking_key, final=final,
+                    rsv1=rsv1, rsv2=rsv2, rsv3=rsv3)
+
+
+def recvn(sock, n):
     """
     Keep receiving data from `sock' until exactly `n' bytes have been read.
     """
@@ -163,25 +144,14 @@ def recv_exactly(sock, n):
 
     while left > 0:
         received = sock.recv(left)
-        data += received
-        left -= len(received)
 
-    return received
-
-
-def recv_at_least(sock, n, at_least):
-    """
-    Keep receiving data from `sock' until at least `n' bytes have been read.
-    """
-    left = at_least
-    data = ''
+        if not len(received):
+            raise SocketClosed()
 
-    while left > 0:
-        received = sock.recv(n)
         data += received
         left -= len(received)
 
-    return data
+    return received
 
 
 def mask(key, original):

+ 1 - 1
server.py

@@ -69,7 +69,7 @@ class Client(WebSocket):
         self.server.onmessage(self, message)
 
     def onclose(self):
-        self.server.onclose(self, message)
+        self.server.onclose(self)
 
     def __str__(self):
         return '<Client at %s:%d>' % self.address

+ 3 - 3
websocket.py

@@ -2,7 +2,7 @@ import re
 from hashlib import sha1
 from threading import Thread
 
-from frame import FrameReceiver
+from frame import receive_fragments
 from message import create_message
 from exceptions import SocketClosed
 
@@ -11,7 +11,7 @@ WS_GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
 WS_VERSION = '13'
 
 
-class WebSocket(FrameReceiver):
+class WebSocket(object):
     def __init__(self, sock, address, encoding=None):
         super(WebSocket, self).__init__(sock)
         self.address = address
@@ -27,7 +27,7 @@ class WebSocket(FrameReceiver):
         self.sock.sendall(frame.pack())
 
     def receive_message(self):
-        frames = self.receive_fragments()
+        frames = receive_fragments(self.sock)
         payload = ''.join([f.payload for f in frames])
         return create_message(frames[0].opcode, payload)