Commit ff54ce82 authored by Taddeüs Kroes's avatar Taddeüs Kroes

Rewrote FrameReceiver to two cleaner functions

parent 82ac7d72
......@@ -87,74 +87,55 @@ class Frame(object):
return '<Frame opcode=%c len=%d>' % (self.opcode, len(self.payload))
class FrameReceiver(object):
def __init__(self, sock):
self.sock = sock
def assert_received(self, n, exact=False):
if self.nreceived < n:
recv = recv_exactly if exact else recv_at_least
received = recv(self.sock, n - self.nreceived)
if not len(received):
raise SocketClosed()
self.buf += received
self.nreceived = len(self.buf)
def receive_fragments(self):
fragments = [self.receive_frame()]
while not fragments[-1].final:
fragments.append(self.receive_frame())
return fragments
def receive_frame(self):
self.buf = ''
self.nreceived = 0
total_len = 2
self.assert_received(2)
b1, b2 = struct.unpack('!BB', self.buf[:2])
final = bool(b1 & 0x80)
rsv1 = bool(b1 & 0x40)
rsv2 = bool(b1 & 0x20)
rsv3 = bool(b1 & 0x10)
opcode = b1 & 0x0F
mask = bool(b2 & 0x80)
payload_len = b2 & 0x7F
if mask:
total_len += 4
if payload_len == 126:
self.assert_received(4)
total_len += 4 + struct.unpack('!H', self.buf[2:4])
key_start = 4
elif payload_len == 127:
self.assert_received(8)
total_len += 8 + struct.unpack('!Q', self.buf[2:10])
key_start = 10
else:
total_len += payload_len
key_start = 2
self.assert_received(total_len, exact=True)
def receive_fragments(sock):
"""
Receive a sequence of frames that belong together:
- An ititial frame with non-zero opcode
- Zero or more frames with opcode = 0 and final = False
- A final frame with opcpde = 0 and final = True
The first and last frame may be the same frame, having a non-zero opcode
and final = True. Thus, this function returns a list of at least a single
frame.
"""
fragments = [receive_frame(sock)]
if mask:
payload_start = key_start + 4
masking_key = self.buf[key_start:payload_start]
payload = mask(masking_key, self.buf[payload_start:])
else:
masking_key = ''
payload = self.buf[key_start:]
while not fragments[-1].final:
fragments.append(receive_frame(sock))
return Frame(opcode, payload, masking_key=masking_key, final=final,
rsv1=rsv1, rsv2=rsv2, rsv3=rsv3)
return fragments
def recv_exactly(sock, n):
def receive_frame(sock):
"""
Receive a single frame on the given socket.
"""
b1, b2 = struct.unpack('!BB', recvn(sock, 2))
final = bool(b1 & 0x80)
rsv1 = bool(b1 & 0x40)
rsv2 = bool(b1 & 0x20)
rsv3 = bool(b1 & 0x10)
opcode = b1 & 0x0F
mask = bool(b2 & 0x80)
payload_len = b2 & 0x7F
if payload_len == 126:
payload_len = struct.unpack('!H', recvn(sock, 2))
elif payload_len == 127:
payload_len = struct.unpack('!Q', recvn(sock, 8))
if mask:
masking_key = recvn(sock, 4)
payload = mask(masking_key, recvn(payload_len))
else:
masking_key = ''
payload = recvn(payload_len)
return Frame(opcode, payload, masking_key=masking_key, final=final,
rsv1=rsv1, rsv2=rsv2, rsv3=rsv3)
def recvn(sock, n):
"""
Keep receiving data from `sock' until exactly `n' bytes have been read.
"""
......@@ -163,25 +144,14 @@ def recv_exactly(sock, n):
while left > 0:
received = sock.recv(left)
data += received
left -= len(received)
return received
def recv_at_least(sock, n, at_least):
"""
Keep receiving data from `sock' until at least `n' bytes have been read.
"""
left = at_least
data = ''
if not len(received):
raise SocketClosed()
while left > 0:
received = sock.recv(n)
data += received
left -= len(received)
return data
return received
def mask(key, original):
......
......@@ -69,7 +69,7 @@ class Client(WebSocket):
self.server.onmessage(self, message)
def onclose(self):
self.server.onclose(self, message)
self.server.onclose(self)
def __str__(self):
return '<Client at %s:%d>' % self.address
......
......@@ -2,7 +2,7 @@ import re
from hashlib import sha1
from threading import Thread
from frame import FrameReceiver
from frame import receive_fragments
from message import create_message
from exceptions import SocketClosed
......@@ -11,7 +11,7 @@ WS_GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
WS_VERSION = '13'
class WebSocket(FrameReceiver):
class WebSocket(object):
def __init__(self, sock, address, encoding=None):
super(WebSocket, self).__init__(sock)
self.address = address
......@@ -27,7 +27,7 @@ class WebSocket(FrameReceiver):
self.sock.sendall(frame.pack())
def receive_message(self):
frames = self.receive_fragments()
frames = receive_fragments(self.sock)
payload = ''.join([f.payload for f in frames])
return create_message(frames[0].opcode, payload)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment