Commit e8a8408e authored by Taddeüs Kroes's avatar Taddeüs Kroes

Implemented frame (un)packing

parent 3c9af2c0
class SocketClosed(Exception):
pass
import struct
from os import urandom
from exceptions import SocketClosed
OPCODE_CONTINUATION = 0x0
OPCODE_TEXT = 0x1
OPCODE_BINARY = 0x2
OPCODE_CLOSE = 0x8
OPCODE_PING = 0x9
OPCODE_PONG = 0xA
class Frame(object): class Frame(object):
pass def __init__(self, opcode, payload, masking_key='', final=True, rsv1=False,
rsv2=False, rsv3=False):
if len(masking_key) not in (0, 4):
raise ValueError('invalid masking key "%s"' % masking_key)
self.final = final
self.rsv1 = rsv1
self.rsv2 = rsv2
self.rsv3 = rsv3
self.opcode = opcode
self.masking_key = masking_key
self.payload = payload
def pack(self):
"""
Pack the frame into a string according to the following scheme:
+-+-+-+-+-------+-+-------------+-------------------------------+
|F|R|R|R| opcode|M| Payload len | Extended payload length |
|I|S|S|S| (4) |A| (7) | (16/64) |
|N|V|V|V| |S| | (if payload len==126/127) |
| |1|2|3| |K| | |
+-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
| Extended payload length continued, if payload len == 127 |
+ - - - - - - - - - - - - - - - +-------------------------------+
| |Masking-key, if MASK set to 1 |
+-------------------------------+-------------------------------+
| Masking-key (continued) | Payload Data |
+-------------------------------- - - - - - - - - - - - - - - - +
: Payload Data continued ... :
+ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
| Payload Data continued ... |
+---------------------------------------------------------------+
"""
header = struct.pack('!B', (self.fin << 7) | (self.rsv1 << 6) |
(self.rsv2 << 5) | (self.rsv3 << 4) | self.opcode)
mask = bool(self.masking_key) << 7
payload_len = len(self.payload)
if payload_len <= 125:
header += struct.pack('!B', mask | payload_len)
elif payload_len < (1 << 16):
header += struct.pack('!BH', mask | 126, payload_len)
elif payload_len < (1 << 63):
header += struct.pack('!BQ', mask | 127, payload_len)
else:
raise Exception('the payload length is too damn high!')
if self.masking_key:
return header + self.masking_key + self.mask_payload()
return header + self.payload
def mask_payload(self):
return mask(self.masking_key, self.payload)
def fragment(self, fragment_size, mask=False):
frames = []
for start in range(0, len(self.payload), fragment_size):
payload = self.payload[start:start + fragment_size]
key = urandom(4) if mask else ''
frames.append(Frame(OPCODE_CONTINUATION, payload, key, False))
frames[0].opcode = self.opcode
frames[-1].final = True
return frames
def __str__(self):
return '<Frame opcode=%c len=%d>' % (self.opcode, len(self.payload))
class FrameReceiver(object):
def __init__(self, sock):
self.sock = sock
def assert_received(self, n, exact=False):
if self.nreceived < n:
recv = recv_exactly if exact else recv_at_least
received = recv(self.sock, n - self.nreceived)
if not len(received):
raise SocketClosed()
self.buf += received
self.nreceived = len(self.buf)
def receive_fragments(self):
fragments = [self.receive_frame()]
while not fragments[-1].final:
fragments.append(self.receive_frame())
return fragments
def receive_frame(self):
self.buf = ''
self.nreceived = 0
total_len = 2
self.assert_received(2)
b1, b2 = struct.unpack('!BB', self.buf[:2])
final = bool(b1 & 0x80)
rsv1 = bool(b1 & 0x40)
rsv2 = bool(b1 & 0x20)
rsv3 = bool(b1 & 0x10)
opcode = b1 & 0x0F
mask = bool(b2 & 0x80)
payload_len = b2 & 0x7F
if mask:
total_len += 4
if payload_len == 126:
self.assert_received(4)
total_len += 4 + struct.unpack('!H', self.buf[2:4])
key_start = 4
elif payload_len == 127:
self.assert_received(8)
total_len += 8 + struct.unpack('!Q', self.buf[2:10])
key_start = 10
else:
total_len += payload_len
key_start = 2
self.assert_received(total_len, exact=True)
if mask:
payload_start = key_start + 4
masking_key = self.buf[key_start:payload_start]
payload = mask(masking_key, self.buf[payload_start:])
else:
masking_key = ''
payload = self.buf[key_start:]
return Frame(opcode, payload, masking_key=masking_key, final=final,
rsv1=rsv1, rsv2=rsv2, rsv3=rsv3)
def recv_exactly(sock, n):
"""
Keep receiving data from `sock' until exactly `n' bytes have been read.
"""
left = n
data = ''
while left > 0:
received = sock.recv(left)
data += received
left -= len(received)
return received
def recv_at_least(sock, n, at_least):
"""
Keep receiving data from `sock' until at least `n' bytes have been read.
"""
left = at_least
data = ''
while left > 0:
received = sock.recv(n)
data += received
left -= len(received)
return data
def mask(key, original):
"""
Mask an octet string using the given masking key.
The following masking algorithm is used, as defined in RFC 6455:
for each octet:
j = i MOD 4
transformed-octet-i = original-octet-i XOR masking-key-octet-j
"""
if len(key) != 4:
raise ValueError('invalid masking key "%s"' % key)
key = map(ord, key)
masked = bytearray(original)
for i in xrange(len(masked)):
masked[i] ^= key[i % 4]
return masked
def concat_frames(frames):
"""
Create a new Frame object with the concatenated payload of the given list
of frames.
"""
assert len(frames)
first = frames[0]
assert first.opcode != 0
assert frames[-1].final
return Frame(first.opcode, ''.join([f.payload for f in frames]),
rsv1=first.rsv1, rsv2=first.rsv2, rsv3=first.rsv3)
from frame import Frame, OPCODE_TEXT, OPCODE_BINARY, OPCODE_CLOSE, \
OPCODE_PING, OPCODE_PONG
__all__ = ['Message', 'TextMessage', 'BinaryMessage', 'CloseMessage',
'PingMessage', 'PongMessage']
class Message(object):
def __init__(self, opcode, payload):
self.opcode = opcode
self.payload = payload
def frame(self):
return Frame(self.opcode, self.payload)
def fragment(self, fragment_size, mask=False):
return self.frame().fragment(fragment_size, mask)
def __str__(self):
return '<%s opcode=%x size=%d>' \
% (self.__class__.__name__, self.opcode, len(self.payload))
class TextMessage(Message):
def __init__(self, payload):
super(TextMessage, self).__init__(OPCODE_TEXT, payload)
class BinaryMessage(Message):
def __init__(self, payload):
super(TextMessage, self).__init__(OPCODE_BINARY, payload)
class CloseMessage(Message):
def __init__(self, payload):
super(TextMessage, self).__init__(OPCODE_CLOSE, payload)
class PingMessage(Message):
def __init__(self, payload):
super(TextMessage, self).__init__(OPCODE_PING, payload)
class PongMessage(Message):
def __init__(self, payload):
super(TextMessage, self).__init__(OPCODE_PONG, payload)
OPCODE_CLASS_MAP = {
OPCODE_TEXT: TextMessage,
OPCODE_BINARY: BinaryMessage,
OPCODE_CLOSE: CloseMessage,
OPCODE_PING: PingMessage,
OPCODE_PONG: PongMessage,
}
def create_message(opcode, payload):
if opcode in OPCODE_CLASS_MAP:
return OPCODE_CLASS_MAP(payload)
return Message(opcode, payload)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment