Quellcode durchsuchen

Implemented frame (un)packing

Taddeus Kroes vor 13 Jahren
Ursprung
Commit
e8a8408ed4
3 geänderte Dateien mit 282 neuen und 1 gelöschten Zeilen
  1. 2 0
      exceptions.py
  2. 217 1
      frame.py
  3. 63 0
      message.py

+ 2 - 0
exceptions.py

@@ -0,0 +1,2 @@
+class SocketClosed(Exception):
+    pass

+ 217 - 1
frame.py

@@ -1,2 +1,218 @@
+import struct
+from os import urandom
+
+from exceptions import SocketClosed
+
+
+OPCODE_CONTINUATION = 0x0
+OPCODE_TEXT = 0x1
+OPCODE_BINARY = 0x2
+OPCODE_CLOSE = 0x8
+OPCODE_PING = 0x9
+OPCODE_PONG = 0xA
+
+
 class Frame(object):
-    pass
+    def __init__(self, opcode, payload, masking_key='', final=True, rsv1=False,
+            rsv2=False, rsv3=False):
+        if len(masking_key)  not in (0, 4):
+            raise ValueError('invalid masking key "%s"' % masking_key)
+
+        self.final = final
+        self.rsv1 = rsv1
+        self.rsv2 = rsv2
+        self.rsv3 = rsv3
+        self.opcode = opcode
+        self.masking_key = masking_key
+        self.payload = payload
+
+    def pack(self):
+        """
+        Pack the frame into a string according to the following scheme:
+
+        +-+-+-+-+-------+-+-------------+-------------------------------+
+        |F|R|R|R| opcode|M| Payload len |    Extended payload length    |
+        |I|S|S|S|  (4)  |A|     (7)     |             (16/64)           |
+        |N|V|V|V|       |S|             |   (if payload len==126/127)   |
+        | |1|2|3|       |K|             |                               |
+        +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
+        |     Extended payload length continued, if payload len == 127  |
+        + - - - - - - - - - - - - - - - +-------------------------------+
+        |                               |Masking-key, if MASK set to 1  |
+        +-------------------------------+-------------------------------+
+        | Masking-key (continued)       |          Payload Data         |
+        +-------------------------------- - - - - - - - - - - - - - - - +
+        :                     Payload Data continued ...                :
+        + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
+        |                     Payload Data continued ...                |
+        +---------------------------------------------------------------+
+        """
+        header = struct.pack('!B', (self.fin << 7) | (self.rsv1 << 6) |
+                             (self.rsv2 << 5) | (self.rsv3 << 4) | self.opcode)
+
+        mask = bool(self.masking_key) << 7
+        payload_len = len(self.payload)
+
+        if payload_len <= 125:
+            header += struct.pack('!B', mask | payload_len)
+        elif payload_len < (1 << 16):
+            header += struct.pack('!BH', mask | 126, payload_len)
+        elif payload_len < (1 << 63):
+            header += struct.pack('!BQ', mask | 127, payload_len)
+        else:
+            raise Exception('the payload length is too damn high!')
+
+        if self.masking_key:
+            return header + self.masking_key + self.mask_payload()
+
+        return header + self.payload
+
+    def mask_payload(self):
+        return mask(self.masking_key, self.payload)
+
+    def fragment(self, fragment_size, mask=False):
+        frames = []
+
+        for start in range(0, len(self.payload), fragment_size):
+            payload = self.payload[start:start + fragment_size]
+            key = urandom(4) if mask else ''
+            frames.append(Frame(OPCODE_CONTINUATION, payload, key, False))
+
+        frames[0].opcode = self.opcode
+        frames[-1].final = True
+
+        return frames
+
+    def __str__(self):
+        return '<Frame opcode=%c len=%d>' % (self.opcode, len(self.payload))
+
+
+class FrameReceiver(object):
+    def __init__(self, sock):
+        self.sock = sock
+
+    def assert_received(self, n, exact=False):
+        if self.nreceived < n:
+            recv = recv_exactly if exact else recv_at_least
+            received = recv(self.sock, n - self.nreceived)
+
+            if not len(received):
+                raise SocketClosed()
+
+            self.buf += received
+            self.nreceived = len(self.buf)
+
+    def receive_fragments(self):
+        fragments = [self.receive_frame()]
+
+        while not fragments[-1].final:
+            fragments.append(self.receive_frame())
+
+        return fragments
+
+    def receive_frame(self):
+        self.buf = ''
+        self.nreceived = 0
+        total_len = 2
+
+        self.assert_received(2)
+        b1, b2 = struct.unpack('!BB', self.buf[:2])
+        final = bool(b1 & 0x80)
+        rsv1 = bool(b1 & 0x40)
+        rsv2 = bool(b1 & 0x20)
+        rsv3 = bool(b1 & 0x10)
+        opcode = b1 & 0x0F
+        mask = bool(b2 & 0x80)
+        payload_len = b2 & 0x7F
+
+        if mask:
+            total_len += 4
+
+        if payload_len == 126:
+            self.assert_received(4)
+            total_len += 4 + struct.unpack('!H', self.buf[2:4])
+            key_start = 4
+        elif payload_len == 127:
+            self.assert_received(8)
+            total_len += 8 + struct.unpack('!Q', self.buf[2:10])
+            key_start = 10
+        else:
+            total_len += payload_len
+            key_start = 2
+
+        self.assert_received(total_len, exact=True)
+
+        if mask:
+            payload_start = key_start + 4
+            masking_key = self.buf[key_start:payload_start]
+            payload = mask(masking_key, self.buf[payload_start:])
+        else:
+            masking_key = ''
+            payload = self.buf[key_start:]
+
+        return Frame(opcode, payload, masking_key=masking_key, final=final,
+                      rsv1=rsv1, rsv2=rsv2, rsv3=rsv3)
+
+
+def recv_exactly(sock, n):
+    """
+    Keep receiving data from `sock' until exactly `n' bytes have been read.
+    """
+    left = n
+    data = ''
+
+    while left > 0:
+        received = sock.recv(left)
+        data += received
+        left -= len(received)
+
+    return received
+
+
+def recv_at_least(sock, n, at_least):
+    """
+    Keep receiving data from `sock' until at least `n' bytes have been read.
+    """
+    left = at_least
+    data = ''
+
+    while left > 0:
+        received = sock.recv(n)
+        data += received
+        left -= len(received)
+
+    return data
+
+
+def mask(key, original):
+    """
+    Mask an octet string using the given masking key.
+    The following masking algorithm is used, as defined in RFC 6455:
+
+    for each octet:
+        j = i MOD 4
+        transformed-octet-i = original-octet-i XOR masking-key-octet-j
+    """
+    if len(key) != 4:
+        raise ValueError('invalid masking key "%s"' % key)
+
+    key = map(ord, key)
+    masked = bytearray(original)
+
+    for i in xrange(len(masked)):
+        masked[i] ^= key[i % 4]
+
+    return masked
+
+
+def concat_frames(frames):
+    """
+    Create a new Frame object with the concatenated payload of the given list
+    of frames.
+    """
+    assert len(frames)
+    first = frames[0]
+    assert first.opcode != 0
+    assert frames[-1].final
+    return Frame(first.opcode, ''.join([f.payload for f in frames]),
+                 rsv1=first.rsv1, rsv2=first.rsv2, rsv3=first.rsv3)

+ 63 - 0
message.py

@@ -0,0 +1,63 @@
+from frame import Frame, OPCODE_TEXT, OPCODE_BINARY, OPCODE_CLOSE, \
+        OPCODE_PING, OPCODE_PONG
+
+
+__all__ = ['Message', 'TextMessage', 'BinaryMessage', 'CloseMessage',
+           'PingMessage', 'PongMessage']
+
+
+class Message(object):
+    def __init__(self, opcode, payload):
+        self.opcode = opcode
+        self.payload = payload
+
+    def frame(self):
+        return Frame(self.opcode, self.payload)
+
+    def fragment(self, fragment_size, mask=False):
+        return self.frame().fragment(fragment_size, mask)
+
+    def __str__(self):
+        return '<%s opcode=%x size=%d>' \
+               % (self.__class__.__name__, self.opcode, len(self.payload))
+
+
+class TextMessage(Message):
+    def __init__(self, payload):
+        super(TextMessage, self).__init__(OPCODE_TEXT, payload)
+
+
+class BinaryMessage(Message):
+    def __init__(self, payload):
+        super(TextMessage, self).__init__(OPCODE_BINARY, payload)
+
+
+class CloseMessage(Message):
+    def __init__(self, payload):
+        super(TextMessage, self).__init__(OPCODE_CLOSE, payload)
+
+
+class PingMessage(Message):
+    def __init__(self, payload):
+        super(TextMessage, self).__init__(OPCODE_PING, payload)
+
+
+class PongMessage(Message):
+    def __init__(self, payload):
+        super(TextMessage, self).__init__(OPCODE_PONG, payload)
+
+
+OPCODE_CLASS_MAP = {
+    OPCODE_TEXT: TextMessage,
+    OPCODE_BINARY: BinaryMessage,
+    OPCODE_CLOSE: CloseMessage,
+    OPCODE_PING: PingMessage,
+    OPCODE_PONG: PongMessage,
+}
+
+
+def create_message(opcode, payload):
+    if opcode in OPCODE_CLASS_MAP:
+        return OPCODE_CLASS_MAP(payload)
+
+    return Message(opcode, payload)