Commit 6c79550e authored by Taddeüs Kroes's avatar Taddeüs Kroes

Merge branch 'async'

parents 4d9fbb0c 447ee6fa
...@@ -22,6 +22,7 @@ Her is a quick overview of the features in this library: ...@@ -22,6 +22,7 @@ Her is a quick overview of the features in this library:
- Secure sockets using SSL certificates (for 'wss://...' URLs). - Secure sockets using SSL certificates (for 'wss://...' URLs).
- The possibility to add extensions to the web socket protocol. An included - The possibility to add extensions to the web socket protocol. An included
implementation is [deflate-frame](http://tools.ietf.org/html/draft-tyoshino-hybi-websocket-perframe-deflate-06). implementation is [deflate-frame](http://tools.ietf.org/html/draft-tyoshino-hybi-websocket-perframe-deflate-06).
- Asynchronous sockets with an EPOLL-based server.
Installation Installation
......
...@@ -4,10 +4,11 @@ from frame import Frame, ControlFrame, OPCODE_CONTINUATION, OPCODE_TEXT, \ ...@@ -4,10 +4,11 @@ from frame import Frame, ControlFrame, OPCODE_CONTINUATION, OPCODE_TEXT, \
OPCODE_BINARY, OPCODE_CLOSE, OPCODE_PING, OPCODE_PONG, CLOSE_NORMAL, \ OPCODE_BINARY, OPCODE_CLOSE, OPCODE_PING, OPCODE_PONG, CLOSE_NORMAL, \
CLOSE_GOING_AWAY, CLOSE_PROTOCOL_ERROR, CLOSE_NOACCEPT_DTYPE, \ CLOSE_GOING_AWAY, CLOSE_PROTOCOL_ERROR, CLOSE_NOACCEPT_DTYPE, \
CLOSE_INVALID_DATA, CLOSE_POLICY, CLOSE_MESSAGE_TOOBIG, \ CLOSE_INVALID_DATA, CLOSE_POLICY, CLOSE_MESSAGE_TOOBIG, \
CLOSE_MISSING_EXTENSIONS, CLOSE_UNABLE CLOSE_MISSING_EXTENSIONS, CLOSE_UNABLE, read_frame, pop_frame, \
contains_frame
from connection import Connection from connection import Connection
from message import Message, TextMessage, BinaryMessage from message import Message, TextMessage, BinaryMessage
from errors import SocketClosed, HandshakeError, PingError, SSLError from errors import SocketClosed, HandshakeError, PingError, SSLError
from extension import Extension from extension import Extension
from deflate_frame import DeflateFrame, WebkitDeflateFrame from deflate_frame import DeflateFrame, WebkitDeflateFrame
#from multiplex import Multiplex from async import AsyncConnection, AsyncServer
import socket
from select import epoll, EPOLLIN, EPOLLOUT, EPOLLHUP
from traceback import format_exc
import logging
from connection import Connection
from frame import ControlFrame, OPCODE_PING, OPCODE_CONTINUATION, \
create_close_frame
from server import Server, Client
from errors import HandshakeError, SocketClosed
class AsyncConnection(Connection):
def __init__(self, sock):
sock.recv_callback = self.contruct_message
sock.recv_close_callback = self.onclose
self.recvbuf = []
Connection.__init__(self, sock)
def contruct_message(self, frame):
if isinstance(frame, ControlFrame):
self.handle_control_frame(frame)
return
self.recvbuf.append(frame)
if frame.final:
message = self.concat_fragments(self.recvbuf)
self.recvbuf = []
self.onmessage(message)
elif len(self.recvbuf) > 1 and frame.opcode != OPCODE_CONTINUATION:
raise ValueError('expected continuation/control frame, got %s '
'instead' % frame)
def send(self, message, fragment_size=None, mask=False):
frames = list(self.message_to_frames(message, fragment_size, mask))
for frame in frames[:-1]:
self.sock.queue_send(frame)
self.sock.queue_send(frames[-1], lambda: self.onsend(message))
def send_frame(self, frame, callback):
self.sock.queue_send(frame, callback)
def do_async_send(self):
self.execute_controlled(self.sock.do_async_send)
def do_async_recv(self, bufsize):
self.execute_controlled(self.sock.do_async_recv, bufsize)
def execute_controlled(self, func, *args, **kwargs):
try:
func(*args, **kwargs)
except (KeyboardInterrupt, SystemExit, SocketClosed):
raise
except Exception as e:
self.onerror(e)
self.onclose(None, 'error: %s' % e)
try:
self.sock.close()
except socket.error:
pass
raise e
def send_close_frame(self, code, reason):
self.sock.queue_send(create_close_frame(code, reason),
self.shutdown_write)
self.close_frame_sent = True
def close(self, code=None, reason=''):
self.send_close_frame(code, reason)
def send_ping(self, payload=''):
self.sock.queue_send(ControlFrame(OPCODE_PING, payload),
lambda: self.onping(payload))
self.ping_payload = payload
self.ping_sent = True
def onsend(self, message):
"""
Called after a message has been written.
"""
return NotImplemented
class AsyncServer(Server):
def __init__(self, *args, **kwargs):
Server.__init__(self, *args, **kwargs)
self.recvbuf_size = kwargs.get('recvbuf_size', 2048)
self.epoll = epoll()
self.epoll.register(self.sock.fileno(), EPOLLIN)
self.conns = {}
@property
def clients(self):
return self.conns.values()
def remove_client(self, client, code, reason):
self.epoll.unregister(client.fno)
del self.conns[client.fno]
self.onclose(client, code, reason)
def handle_events(self):
for fileno, event in self.epoll.poll(1):
if fileno == self.sock.fileno():
try:
sock, addr = self.sock.accept()
except HandshakeError as e:
logging.error('Invalid request: %s', e.message)
continue
client = AsyncClient(self, sock)
client.fno = sock.fileno()
sock.setblocking(0)
self.epoll.register(client.fno, EPOLLIN)
self.conns[client.fno] = client
logging.debug('Registered client %s', client)
elif event & EPOLLHUP:
self.epoll.unregister(fileno)
del self.conns[fileno]
else:
conn = self.conns[fileno]
try:
if event & EPOLLOUT:
conn.do_async_send()
elif event & EPOLLIN:
conn.do_async_recv(self.recvbuf_size)
except (KeyboardInterrupt, SystemExit):
raise
except SocketClosed:
continue
except Exception as e:
logging.error(format_exc(e).rstrip())
continue
self.update_mask(conn)
def run(self):
try:
while True:
self.handle_events()
except (KeyboardInterrupt, SystemExit):
logging.info('Received interrupt, stopping server...')
finally:
self.epoll.unregister(self.sock.fileno())
self.epoll.close()
self.sock.close()
def update_mask(self, conn):
mask = 0
if conn.sock.can_send():
mask |= EPOLLOUT
if conn.sock.can_recv():
mask |= EPOLLIN
self.epoll.modify(conn.sock.fileno(), mask)
def onsend(self, client, message):
return NotImplemented
class AsyncClient(Client, AsyncConnection):
def __init__(self, server, sock):
self.server = server
AsyncConnection.__init__(self, sock)
def send(self, message, fragment_size=None, mask=False):
logging.debug('Enqueueing %s to %s', message, self)
AsyncConnection.send(self, message, fragment_size, mask)
self.server.update_mask(self)
def onsend(self, message):
logging.debug('Finished sending %s to %s', message, self)
self.server.onsend(self, message)
if __name__ == '__main__':
import sys
port = int(sys.argv[1]) if len(sys.argv) > 1 else 8000
AsyncServer(('', port), loglevel=logging.DEBUG).run()
import struct
import socket import socket
from frame import ControlFrame, OPCODE_CLOSE, OPCODE_PING, OPCODE_PONG, \ from frame import ControlFrame, OPCODE_CLOSE, OPCODE_PING, OPCODE_PONG, \
OPCODE_CONTINUATION OPCODE_CONTINUATION, create_close_frame
from message import create_message from message import create_message
from errors import SocketClosed, PingError from errors import SocketClosed, PingError
...@@ -54,19 +53,30 @@ class Connection(object): ...@@ -54,19 +53,30 @@ class Connection(object):
self.onopen() self.onopen()
def message_to_frames(self, message, fragment_size=None, mask=False):
for hook in self.hooks_send:
message = hook(message)
if fragment_size is None:
yield message.frame(mask=mask)
else:
for frame in message.fragment(fragment_size, mask=mask):
yield frame
def send(self, message, fragment_size=None, mask=False): def send(self, message, fragment_size=None, mask=False):
""" """
Send a message. If `fragment_size` is specified, the message is Send a message. If `fragment_size` is specified, the message is
fragmented into multiple frames whose payload size does not extend fragmented into multiple frames whose payload size does not extend
`fragment_size`. `fragment_size`.
""" """
for hook in self.hooks_send: for frame in self.message_to_frames(message, fragment_size, mask):
message = hook(message) self.send_frame(frame)
if fragment_size is None: def send_frame(self, frame, callback=None):
self.sock.send(message.frame(mask=mask)) self.sock.send(frame)
else:
self.sock.send(*message.fragment(fragment_size, mask=mask)) if callback:
callback()
def recv(self): def recv(self):
""" """
...@@ -82,12 +92,15 @@ class Connection(object): ...@@ -82,12 +92,15 @@ class Connection(object):
if isinstance(frame, ControlFrame): if isinstance(frame, ControlFrame):
self.handle_control_frame(frame) self.handle_control_frame(frame)
elif len(fragments) and frame.opcode != OPCODE_CONTINUATION: elif len(fragments) > 0 and frame.opcode != OPCODE_CONTINUATION:
raise ValueError('expected continuation/control frame, got %s ' raise ValueError('expected continuation/control frame, got %s '
'instead' % frame) 'instead' % frame)
else: else:
fragments.append(frame) fragments.append(frame)
return self.concat_fragments(fragments)
def concat_fragments(self, fragments):
payload = bytearray() payload = bytearray()
for f in fragments: for f in fragments:
...@@ -105,16 +118,20 @@ class Connection(object): ...@@ -105,16 +118,20 @@ class Connection(object):
Handle a control frame as defined by RFC 6455. Handle a control frame as defined by RFC 6455.
""" """
if frame.opcode == OPCODE_CLOSE: if frame.opcode == OPCODE_CLOSE:
# Close the connection from this end as well
self.close_frame_received = True self.close_frame_received = True
code, reason = frame.unpack_close() code, reason = frame.unpack_close()
# No more receiving data after a close message if self.close_frame_sent:
raise SocketClosed(code, reason) self.onclose(code, reason)
self.sock.close()
raise SocketClosed(True)
else:
self.close_params = (code, reason)
self.send_close_frame(code, reason)
elif frame.opcode == OPCODE_PING: elif frame.opcode == OPCODE_PING:
# Respond with a pong message with identical payload # Respond with a pong message with identical payload
self.sock.send(ControlFrame(OPCODE_PONG, frame.payload)) self.send_frame(ControlFrame(OPCODE_PONG, frame.payload))
elif frame.opcode == OPCODE_PONG: elif frame.opcode == OPCODE_PONG:
# Assert that the PONG payload is identical to that of the PING # Assert that the PONG payload is identical to that of the PING
...@@ -138,38 +155,40 @@ class Connection(object): ...@@ -138,38 +155,40 @@ class Connection(object):
while True: while True:
try: try:
self.onmessage(self.recv()) self.onmessage(self.recv())
except SocketClosed as e: except (KeyboardInterrupt, SystemExit, SocketClosed):
self.close(e.code, e.reason)
break break
except socket.error as e: except Exception as e:
self.onerror(e) self.onerror(e)
self.onclose(None, 'error: %s' % e)
try: try:
self.sock.close() self.sock.close()
except socket.error: except socket.error:
pass pass
self.onclose(None, '') raise e
break
except Exception as e:
self.onerror(e)
def send_ping(self, payload=''): def send_ping(self, payload=''):
""" """
Send a PING control frame with an optional payload. Send a PING control frame with an optional payload.
""" """
self.sock.send(ControlFrame(OPCODE_PING, payload)) self.send_frame(ControlFrame(OPCODE_PING, payload),
lambda: self.onping(payload))
self.ping_payload = payload self.ping_payload = payload
self.ping_sent = True self.ping_sent = True
self.onping(payload)
def send_close_frame(self, code=None, reason=''): def send_close_frame(self, code, reason):
""" self.send_frame(create_close_frame(code, reason))
Send a CLOSE control frame.
"""
payload = '' if code is None else struct.pack('!H', code) + reason
self.sock.send(ControlFrame(OPCODE_CLOSE, payload))
self.close_frame_sent = True self.close_frame_sent = True
self.shutdown_write()
def shutdown_write(self):
if self.close_frame_received:
self.onclose(*self.close_params)
self.sock.close()
raise SocketClosed(False)
else:
self.sock.shutdown(socket.SHUT_WR)
def close(self, code=None, reason=''): def close(self, code=None, reason=''):
""" """
...@@ -179,28 +198,14 @@ class Connection(object): ...@@ -179,28 +198,14 @@ class Connection(object):
called after the response has been received, but before the socket is called after the response has been received, but before the socket is
actually closed. actually closed.
""" """
# Send CLOSE frame
if not self.close_frame_sent:
self.send_close_frame(code, reason) self.send_close_frame(code, reason)
# Receive CLOSE frame
if not self.close_frame_received:
frame = self.sock.recv() frame = self.sock.recv()
if frame.opcode != OPCODE_CLOSE: if frame.opcode != OPCODE_CLOSE:
raise ValueError('expected CLOSE frame, got %s' % frame) raise ValueError('expected CLOSE frame, got %s' % frame)
self.close_frame_received = True self.handle_control_frame(frame)
res_code, res_reason = frame.unpack_close()
# FIXME: check if res_code == code and res_reason == reason?
# FIXME: alternatively, keep receiving frames in a loop until a
# CLOSE frame is received, so that a fragmented chain may arrive
# fully first
self.onclose(code, reason)
self.sock.close()
def add_hook(self, send=None, recv=None, prepend=False): def add_hook(self, send=None, recv=None, prepend=False):
""" """
......
...@@ -20,38 +20,33 @@ class DeflateFrame(Extension): ...@@ -20,38 +20,33 @@ class DeflateFrame(Extension):
name = 'deflate-frame' name = 'deflate-frame'
rsv1 = True rsv1 = True
defaults = {'max_window_bits': 15, 'no_context_takeover': False} defaults = {'max_window_bits': zlib.MAX_WBITS, 'no_context_takeover': False}
def __init__(self, defaults={}, request={}): COMPRESSION_THRESHOLD = 64 # minimal payload size for compression
Extension.__init__(self, defaults, request)
def init(self):
mwb = self.defaults['max_window_bits'] mwb = self.defaults['max_window_bits']
cto = self.defaults['no_context_takeover'] cto = self.defaults['no_context_takeover']
if not isinstance(mwb, int): if not isinstance(mwb, int) or mwb < 1 or mwb > zlib.MAX_WBITS:
raise ValueError('"max_window_bits" must be an integer') raise ValueError('"max_window_bits" must be in range 1-15')
elif mwb > 15:
raise ValueError('"max_window_bits" may not be larger than 15')
if cto is not False and cto is not True: if cto is not False and cto is not True:
raise ValueError('"no_context_takeover" must have no value') raise ValueError('"no_context_takeover" must have no value')
class Hook(Extension.Hook): class Hook(Extension.Hook):
def __init__(self, extension, **kwargs): def init(self, extension):
Extension.Hook.__init__(self, extension, **kwargs)
if not self.no_context_takeover:
self.defl = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, self.defl = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
zlib.DEFLATED, zlib.DEFLATED, -self.max_window_bits)
-self.max_window_bits) other_wbits = extension.request.get('max_window_bits', zlib.MAX_WBITS)
other_wbits = self.extension.request.get('max_window_bits', 15)
self.dec = zlib.decompressobj(-other_wbits) self.dec = zlib.decompressobj(-other_wbits)
def send(self, frame): def send(self, frame):
if not frame.rsv1 and not isinstance(frame, ControlFrame): # FIXME: this does not seem to work properly on Android
if not frame.rsv1 and not isinstance(frame, ControlFrame) and \
len(frame.payload) > DeflateFrame.COMPRESSION_THRESHOLD:
frame.rsv1 = True frame.rsv1 = True
frame.payload = self.deflate(frame.payload) frame.payload = self.deflate(frame)
return frame return frame
...@@ -65,23 +60,23 @@ class DeflateFrame(Extension): ...@@ -65,23 +60,23 @@ class DeflateFrame(Extension):
return frame return frame
def deflate(self, data): def deflate(self, frame):
if self.no_context_takeover: compressed = self.defl.compress(frame.payload)
defl = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
zlib.DEFLATED, -self.max_window_bits)
# FIXME: why the '\x00' below? This was borrowed from
# https://github.com/fancycode/tornado/blob/bc317b6dcf63608ff004ff1f57073be0504b6550/tornado/websocket.py#L91
return defl.compress(data) + defl.flush(zlib.Z_FINISH) + '\x00'
compressed = self.defl.compress(data) if frame.final or self.no_context_takeover:
compressed += self.defl.flush(zlib.Z_FINISH) + '\x00'
self.defl = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
zlib.DEFLATED, -self.max_window_bits)
else:
compressed += self.defl.flush(zlib.Z_SYNC_FLUSH) compressed += self.defl.flush(zlib.Z_SYNC_FLUSH)
assert compressed[-4:] == '\x00\x00\xff\xff' assert compressed[-4:] == '\x00\x00\xff\xff'
return compressed[:-4] compressed = compressed[:-4]
return compressed
def inflate(self, data): def inflate(self, data):
data = self.dec.decompress(str(data + '\x00\x00\xff\xff')) return self.dec.decompress(data + '\x00\x00\xff\xff') + \
assert not self.dec.unused_data self.dec.flush(zlib.Z_SYNC_FLUSH)
return data
class WebkitDeflateFrame(DeflateFrame): class WebkitDeflateFrame(DeflateFrame):
......
class SocketClosed(Exception): class SocketClosed(Exception):
def __init__(self, code=None, reason=''): def __init__(self, initialized):
self.code = code self.initialized = initialized
self.reason = reason
@property @property
def message(self): def message(self):
return ('' if self.code is None else '[%d] ' % self.code) + self.reason s = 'socket closed'
if self.initialized:
s += ' (initialized)'
return s
class HandshakeError(Exception): class HandshakeError(Exception):
......
...@@ -19,23 +19,31 @@ class Extension(object): ...@@ -19,23 +19,31 @@ class Extension(object):
self.request = dict(self.__class__.request) self.request = dict(self.__class__.request)
self.request.update(request) self.request.update(request)
self.init()
def __str__(self): def __str__(self):
return '<Extension "%s" defaults=%s request=%s>' \ return '<Extension "%s" defaults=%s request=%s>' \
% (self.name, self.defaults, self.request) % (self.name, self.defaults, self.request)
def init(self):
return NotImplemented
def create_hook(self, **kwargs): def create_hook(self, **kwargs):
params = {} params = {}
params.update(self.defaults) params.update(self.defaults)
params.update(kwargs) params.update(kwargs)
return self.Hook(self, **params) hook = self.Hook(**params)
hook.init(self)
return hook
class Hook: class Hook:
def __init__(self, extension, **kwargs): def __init__(self, **kwargs):
self.extension = extension
for param, value in kwargs.iteritems(): for param, value in kwargs.iteritems():
setattr(self, param, value) setattr(self, param, value)
def init(self, extension):
return NotImplemented
def send(self, frame): def send(self, frame):
return frame return frame
...@@ -43,28 +51,19 @@ class Extension(object): ...@@ -43,28 +51,19 @@ class Extension(object):
return frame return frame
def filter_extensions(extensions): def extension_conflicts(ext, existing):
"""
Remove extensions that use conflicting rsv bits and/or opcodes, with the
first options being the most preferable.
"""
rsv1_reserved = False rsv1_reserved = False
rsv2_reserved = False rsv2_reserved = False
rsv3_reserved = False rsv3_reserved = False
opcodes_reserved = [] reserved_opcodes = []
compat = []
for ext in extensions: for e in existing:
if ext.rsv1 and rsv1_reserved \ rsv1_reserved |= e.rsv1
rsv2_reserved |= e.rsv2
rsv3_reserved |= e.rsv3
reserved_opcodes.extend(e.opcodes)
return ext.rsv1 and rsv1_reserved \
or ext.rsv2 and rsv2_reserved \ or ext.rsv2 and rsv2_reserved \
or ext.rsv3 and rsv3_reserved \ or ext.rsv3 and rsv3_reserved \
or len(set(ext.opcodes) & set(opcodes_reserved)): or len(set(ext.opcodes) & set(reserved_opcodes))
continue
rsv1_reserved |= ext.rsv1
rsv2_reserved |= ext.rsv2
rsv3_reserved |= ext.rsv3
opcodes_reserved.extend(ext.opcodes)
compat.append(ext)
return compat
...@@ -21,9 +21,11 @@ CLOSE_MESSAGE_TOOBIG = 1009 ...@@ -21,9 +21,11 @@ CLOSE_MESSAGE_TOOBIG = 1009
CLOSE_MISSING_EXTENSIONS = 1010 CLOSE_MISSING_EXTENSIONS = 1010
CLOSE_UNABLE = 1011 CLOSE_UNABLE = 1011
line_printable = [c for c in printable if c not in '\r\n\x0b\x0c']
def printstr(s): def printstr(s):
return ''.join(c if c in printable else '.' for c in s) return ''.join(c if c in line_printable else '.' for c in str(s))
class Frame(object): class Frame(object):
...@@ -154,7 +156,18 @@ class Frame(object): ...@@ -154,7 +156,18 @@ class Frame(object):
if len(self.payload) > max_pl_disp: if len(self.payload) > max_pl_disp:
pl += '...' pl += '...'
return s + ' payload=%s>' % pl s += ' payload=%s' % pl
if self.rsv1:
s += ' rsv1'
if self.rsv2:
s += ' rsv2'
if self.rsv3:
s += ' rsv3'
return s + '>'
class ControlFrame(Frame): class ControlFrame(Frame):
...@@ -194,12 +207,8 @@ class ControlFrame(Frame): ...@@ -194,12 +207,8 @@ class ControlFrame(Frame):
return code, reason return code, reason
def receive_frame(sock): def decode_frame(reader):
""" b1, b2 = struct.unpack('!BB', reader.readn(2))
Receive a single frame on socket `sock`. The frame scheme is explained in
the docs of Frame.pack().
"""
b1, b2 = struct.unpack('!BB', recvn(sock, 2))
final = bool(b1 & 0x80) final = bool(b1 & 0x80)
rsv1 = bool(b1 & 0x40) rsv1 = bool(b1 & 0x40)
...@@ -211,16 +220,16 @@ def receive_frame(sock): ...@@ -211,16 +220,16 @@ def receive_frame(sock):
payload_len = b2 & 0x7F payload_len = b2 & 0x7F
if payload_len == 126: if payload_len == 126:
payload_len = struct.unpack('!H', recvn(sock, 2)) payload_len = struct.unpack('!H', reader.readn(2))
elif payload_len == 127: elif payload_len == 127:
payload_len = struct.unpack('!Q', recvn(sock, 8)) payload_len = struct.unpack('!Q', reader.readn(8))
if masked: if masked:
masking_key = recvn(sock, 4) masking_key = reader.readn(4)
payload = mask(masking_key, recvn(sock, payload_len)) payload = mask(masking_key, reader.readn(payload_len))
else: else:
masking_key = '' masking_key = ''
payload = recvn(sock, payload_len) payload = reader.readn(payload_len)
# Control frames have most significant bit 1 # Control frames have most significant bit 1
cls = ControlFrame if opcode & 0x8 else Frame cls = ControlFrame if opcode & 0x8 else Frame
...@@ -229,14 +238,44 @@ def receive_frame(sock): ...@@ -229,14 +238,44 @@ def receive_frame(sock):
rsv1=rsv1, rsv2=rsv2, rsv3=rsv3) rsv1=rsv1, rsv2=rsv2, rsv3=rsv3)
def recvn(sock, n): def receive_frame(sock):
return decode_frame(SocketReader(sock))
def read_frame(data):
reader = BufferReader(data)
frame = decode_frame(reader)
return frame, reader.offset
def pop_frame(data):
frame, size = read_frame(data)
return frame, data[size:]
class BufferReader(object):
def __init__(self, data):
self.data = data
self.offset = 0
def readn(self, n):
assert len(self.data) - self.offset >= n
self.offset += n
return self.data[self.offset - n:self.offset]
class SocketReader(object):
def __init__(self, sock):
self.sock = sock
def readn(self, n):
""" """
Keep receiving data from `sock` until exactly `n` bytes have been read. Keep receiving data until exactly `n` bytes have been read.
""" """
data = '' data = ''
while len(data) < n: while len(data) < n:
received = sock.recv(n - len(data)) received = self.sock.recv(n - len(data))
if not len(received): if not len(received):
raise socket.error('no data read from socket') raise socket.error('no data read from socket')
...@@ -246,6 +285,32 @@ def recvn(sock, n): ...@@ -246,6 +285,32 @@ def recvn(sock, n):
return data return data
def contains_frame(data):
"""
Read the frame length from the start of `data` and check if the data is
long enough to contain the entire frame.
"""
if len(data) < 2:
return False
b2 = struct.unpack('!B', data[1])[0]
payload_len = b2 & 0x7F
payload_start = 2
if payload_len == 126:
if len(data) > 4:
payload_len = struct.unpack('!H', data[2:4])
payload_start = 4
elif payload_len == 127:
if len(data) > 12:
payload_len = struct.unpack('!Q', data[4:12])
payload_start = 12
return len(data) >= payload_len + payload_start
def mask(key, original): def mask(key, original):
""" """
Mask an octet string using the given masking key. Mask an octet string using the given masking key.
...@@ -265,3 +330,8 @@ def mask(key, original): ...@@ -265,3 +330,8 @@ def mask(key, original):
masked[i] ^= key[i % 4] masked[i] ^= key[i % 4]
return masked return masked
def create_close_frame(code, reason):
payload = '' if code is None else struct.pack('!H', code) + reason
return ControlFrame(OPCODE_CLOSE, payload)
...@@ -7,7 +7,7 @@ from hashlib import sha1 ...@@ -7,7 +7,7 @@ from hashlib import sha1
from urlparse import urlparse from urlparse import urlparse
from errors import HandshakeError from errors import HandshakeError
from extension import filter_extensions from extension import extension_conflicts
from python_digest import build_authorization_request from python_digest import build_authorization_request
...@@ -15,7 +15,7 @@ WS_GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11' ...@@ -15,7 +15,7 @@ WS_GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
WS_VERSION = '13' WS_VERSION = '13'
MAX_REDIRECTS = 10 MAX_REDIRECTS = 10
HDR_TIMEOUT = 5 HDR_TIMEOUT = 5
MAX_HDR_LEN = 512 MAX_HDR_LEN = 1024
class Handshake(object): class Handshake(object):
...@@ -65,7 +65,11 @@ class Handshake(object): ...@@ -65,7 +65,11 @@ class Handshake(object):
start_time = time.time() start_time = time.time()
while hdr[-4:] != '\r\n\r\n' and len(hdr) < MAX_HDR_LEN: while hdr[-4:] != '\r\n\r\n':
if len(hdr) == MAX_HDR_LEN:
raise HandshakeError('request exceeds maximum header '
'length of %d' % MAX_HDR_LEN)
hdr += self.sock.recv(1) hdr += self.sock.recv(1)
time_diff = time.time() - start_time time_diff = time.time() - start_time
...@@ -169,23 +173,19 @@ class ServerHandshake(Handshake): ...@@ -169,23 +173,19 @@ class ServerHandshake(Handshake):
# Only supported extensions are returned # Only supported extensions are returned
if 'Sec-WebSocket-Extensions' in headers: if 'Sec-WebSocket-Extensions' in headers:
supported_ext = dict((e.name, e) for e in ssock.extensions) supported_ext = dict((e.name, e) for e in ssock.extensions)
self.wsock.extension_hooks = []
extensions = [] extensions = []
all_params = []
for ext in split_stripped(headers['Sec-WebSocket-Extensions']): for ext in split_stripped(headers['Sec-WebSocket-Extensions']):
name, params = parse_param_hdr(ext) name, params = parse_param_hdr(ext)
if name in supported_ext: if name in supported_ext:
extensions.append(supported_ext[name]) ext = supported_ext[name]
all_params.append(params)
self.wsock.extensions = filter_extensions(extensions)
for ext, params in zip(self.wsock.extensions, all_params): if not extension_conflicts(ext, extensions):
extensions.append(ext)
hook = ext.create_hook(**params) hook = ext.create_hook(**params)
self.wsock.add_hook(send=hook.send, recv=hook.recv) self.wsock.extension_hooks.append(hook)
else:
self.wsock.extensions = []
# Check if requested resource location is served by this server # Check if requested resource location is served by this server
if ssock.locations: if ssock.locations:
...@@ -212,7 +212,7 @@ class ServerHandshake(Handshake): ...@@ -212,7 +212,7 @@ class ServerHandshake(Handshake):
location = '%s://%s%s' % (scheme, host, self.wsock.location) location = '%s://%s%s' % (scheme, host, self.wsock.location)
# Construct HTTP response header # Construct HTTP response header
yield 'HTTP/1.1 101 Web Socket Protocol Handshake' yield 'HTTP/1.1 101 Switching Protocols'
yield 'Upgrade', 'websocket' yield 'Upgrade', 'websocket'
yield 'Connection', 'Upgrade' yield 'Connection', 'Upgrade'
yield 'Sec-WebSocket-Origin', origin yield 'Sec-WebSocket-Origin', origin
...@@ -274,7 +274,7 @@ class ClientHandshake(Handshake): ...@@ -274,7 +274,7 @@ class ClientHandshake(Handshake):
# Compare extensions, add hooks only for those returned by server # Compare extensions, add hooks only for those returned by server
if 'Sec-WebSocket-Extensions' in headers: if 'Sec-WebSocket-Extensions' in headers:
supported_ext = dict((e.name, e) for e in self.wsock.extensions) supported_ext = dict((e.name, e) for e in self.wsock.extensions)
self.wsock.extensions = [] self.wsock.extension_hooks = []
for ext in split_stripped(headers['Sec-WebSocket-Extensions']): for ext in split_stripped(headers['Sec-WebSocket-Extensions']):
name, params = parse_param_hdr(ext) name, params = parse_param_hdr(ext)
...@@ -284,8 +284,7 @@ class ClientHandshake(Handshake): ...@@ -284,8 +284,7 @@ class ClientHandshake(Handshake):
'unsupported extension "%s"' % name) 'unsupported extension "%s"' % name)
hook = supported_ext[name].create_hook(**params) hook = supported_ext[name].create_hook(**params)
self.wsock.extensions.append(supported_ext[name]) self.wsock.extension_hooks.append(hook)
self.wsock.add_hook(send=hook.send, recv=hook.recv)
# Assert that returned protocol (if any) is supported # Assert that returned protocol (if any) is supported
if 'Sec-WebSocket-Protocol' in headers: if 'Sec-WebSocket-Protocol' in headers:
......
...@@ -32,7 +32,7 @@ class Server(object): ...@@ -32,7 +32,7 @@ class Server(object):
""" """
def __init__(self, address, loglevel=logging.INFO, ssl_args=None, def __init__(self, address, loglevel=logging.INFO, ssl_args=None,
max_join_time=2.0, **kwargs): max_join_time=2.0, backlog_size=32, **kwargs):
""" """
Constructor for a simple web socket server. Constructor for a simple web socket server.
...@@ -53,6 +53,8 @@ class Server(object): ...@@ -53,6 +53,8 @@ class Server(object):
`max_join_time` is the maximum time (in seconds) to wait for client `max_join_time` is the maximum time (in seconds) to wait for client
responses after sending CLOSE frames, it defaults to 2 seconds. responses after sending CLOSE frames, it defaults to 2 seconds.
`backlog_size` is directly passed to `websocket.listen`.
""" """
logging.basicConfig(level=loglevel, logging.basicConfig(level=loglevel,
format='%(asctime)s: %(levelname)s: %(message)s', format='%(asctime)s: %(levelname)s: %(message)s',
...@@ -69,14 +71,14 @@ class Server(object): ...@@ -69,14 +71,14 @@ class Server(object):
self.sock.enable_ssl(server_side=True, **ssl_args) self.sock.enable_ssl(server_side=True, **ssl_args)
self.sock.bind(address) self.sock.bind(address)
self.sock.listen(5) self.sock.listen(backlog_size)
self.clients = []
self.client_threads = []
self.max_join_time = max_join_time self.max_join_time = max_join_time
def run(self): def run(self):
self.clients = []
self.client_threads = []
while True: while True:
try: try:
sock, address = self.sock.accept() sock, address = self.sock.accept()
...@@ -134,30 +136,22 @@ class Server(object): ...@@ -134,30 +136,22 @@ class Server(object):
self.onclose(client, code, reason) self.onclose(client, code, reason)
def onopen(self, client): def onopen(self, client):
logging.debug('Opened socket to %s', client) return NotImplemented
def onmessage(self, client, message): def onmessage(self, client, message):
logging.debug('Received %s from %s', message, client) return NotImplemented
def onping(self, client, payload): def onping(self, client, payload):
logging.debug('Sent ping "%s" to %s', payload, client) return NotImplemented
def onpong(self, client, payload): def onpong(self, client, payload):
logging.debug('Received pong "%s" from %s', payload, client) return NotImplemented
def onclose(self, client, code, reason): def onclose(self, client, code, reason):
msg = 'Closed socket to %s' % client return NotImplemented
if code is not None:
msg += ' [%d]' % code
if len(reason):
msg += ': ' + reason
logging.debug(msg)
def onerror(self, client, e): def onerror(self, client, e):
logging.error(format_exc(e)) return NotImplemented
class Client(Connection): class Client(Connection):
...@@ -176,21 +170,32 @@ class Client(Connection): ...@@ -176,21 +170,32 @@ class Client(Connection):
Connection.send(self, message, fragment_size=fragment_size, mask=mask) Connection.send(self, message, fragment_size=fragment_size, mask=mask)
def onopen(self): def onopen(self):
logging.debug('Opened socket to %s', self)
self.server.onopen(self) self.server.onopen(self)
def onmessage(self, message): def onmessage(self, message):
logging.debug('Received %s from %s', message, self)
self.server.onmessage(self, message) self.server.onmessage(self, message)
def onping(self, payload): def onping(self, payload):
logging.debug('Sent ping "%s" to %s', payload, self)
self.server.onping(self, payload) self.server.onping(self, payload)
def onpong(self, payload): def onpong(self, payload):
logging.debug('Received pong "%s" from %s', payload, self)
self.server.onpong(self, payload) self.server.onpong(self, payload)
def onclose(self, code, reason): def onclose(self, code, reason):
msg = 'Closed socket to %s' % self
if code is not None:
msg += ': [%d] %s' % (code, reason)
logging.debug(msg)
self.server.remove_client(self, code, reason) self.server.remove_client(self, code, reason)
def onerror(self, e): def onerror(self, e):
logging.error(format_exc(e))
self.server.onerror(self, e) self.server.onerror(self, e)
......
#!/usr/bin/env python #!/usr/bin/env python
import sys import sys
import ssl
from os.path import abspath, dirname from os.path import abspath, dirname
basepath = abspath(dirname(abspath(__file__)) + '/..') basepath = abspath(dirname(abspath(__file__)) + '/..')
...@@ -20,7 +21,7 @@ class EchoClient(Connection): ...@@ -20,7 +21,7 @@ class EchoClient(Connection):
def onmessage(self, msg): def onmessage(self, msg):
print 'Received', msg print 'Received', msg
raise SocketClosed(None, 'response received') self.close(None, 'response received')
def onerror(self, e): def onerror(self, e):
print 'Error:', e print 'Error:', e
...@@ -29,8 +30,15 @@ class EchoClient(Connection): ...@@ -29,8 +30,15 @@ class EchoClient(Connection):
print 'Connection closed' print 'Connection closed'
secure = True
if __name__ == '__main__': if __name__ == '__main__':
print 'Connecting to ws://%s:%d' % ADDR scheme = 'wss' if secure else 'ws'
print 'Connecting to %s://%s' % (scheme, '%s:%d' % ADDR)
sock = websocket() sock = websocket()
if secure:
sock.enable_ssl(ca_certs='cert.pem', cert_reqs=ssl.CERT_REQUIRED)
sock.connect(ADDR) sock.connect(ADDR)
EchoClient(sock).receive_forever() EchoClient(sock).receive_forever()
#!/usr/bin/env python
import sys
import socket
from os.path import abspath, dirname
basepath = abspath(dirname(abspath(__file__)) + '/..')
sys.path.insert(0, basepath)
from websocket import websocket
from connection import Connection
from message import TextMessage
from errors import SocketClosed
if __name__ == '__main__':
if len(sys.argv) < 3:
print >> sys.stderr, 'Usage: python %s HOST PORT' % sys.argv[0]
sys.exit(1)
host = sys.argv[1]
port = int(sys.argv[2])
sock = websocket()
sock.connect((host, port))
sock.settimeout(1.0)
conn = Connection(sock)
try:
try:
while True:
msg = TextMessage(raw_input())
print 'send:', msg
conn.send(msg)
try:
print 'recv:', conn.recv()
except socket.timeout:
print 'no response'
except EOFError:
conn.close()
except SocketClosed as e:
if e.initialized:
print 'closed connection'
else:
print 'other side closed connection'
import socket import socket
import ssl import ssl
from frame import receive_frame from frame import receive_frame, pop_frame, contains_frame
from handshake import ServerHandshake, ClientHandshake from handshake import ServerHandshake, ClientHandshake
from errors import SSLError from errors import SSLError
...@@ -11,7 +11,6 @@ INHERITED_ATTRS = ['bind', 'close', 'listen', 'fileno', 'getpeername', ...@@ -11,7 +11,6 @@ INHERITED_ATTRS = ['bind', 'close', 'listen', 'fileno', 'getpeername',
'settimeout', 'gettimeout', 'shutdown', 'family', 'type', 'settimeout', 'gettimeout', 'shutdown', 'family', 'type',
'proto'] 'proto']
class websocket(object): class websocket(object):
""" """
Implementation of web socket, upgrades a regular TCP socket to a websocket Implementation of web socket, upgrades a regular TCP socket to a websocket
...@@ -36,22 +35,23 @@ class websocket(object): ...@@ -36,22 +35,23 @@ class websocket(object):
>>> sock.connect(('', 8000)) >>> sock.connect(('', 8000))
>>> sock.send(wspy.Frame(wspy.OPCODE_TEXT, 'Hello, Server!')) >>> sock.send(wspy.Frame(wspy.OPCODE_TEXT, 'Hello, Server!'))
""" """
def __init__(self, sock=None, protocols=[], extensions=[], origin=None, def __init__(self, sock=None, origin=None, protocols=[], extensions=[],
location='/', trusted_origins=[], locations=[], auth=None, location='/', trusted_origins=[], locations=[], auth=None,
sfamily=socket.AF_INET, sproto=0): recv_callback=None, sfamily=socket.AF_INET, sproto=0):
""" """
Create a regular TCP socket of family `family` and protocol Create a regular TCP socket of family `family` and protocol
`sock` is an optional regular TCP socket to be used for sending binary `sock` is an optional regular TCP socket to be used for sending binary
data. If not specified, a new socket is created. data. If not specified, a new socket is created.
`protocols` is a list of supported protocol names.
`extensions` is a list of supported extensions (`Extension` instances).
`origin` (for client sockets) is the value for the "Origin" header sent `origin` (for client sockets) is the value for the "Origin" header sent
in a client handshake . in a client handshake .
`protocols` is a list of supported protocol names.
`extensions` (for server sockets) is a list of supported extensions
(`Extension` instances).
`location` (for client sockets) is optional, used to request a `location` (for client sockets) is optional, used to request a
particular resource in the HTTP handshake. In a URL, this would show as particular resource in the HTTP handshake. In a URL, this would show as
ws://host[:port]/<location>. Use this when the server serves multiple ws://host[:port]/<location>. Use this when the server serves multiple
...@@ -71,10 +71,17 @@ class websocket(object): ...@@ -71,10 +71,17 @@ class websocket(object):
`auth` is optional, used for HTTP Basic or Digest authentication during `auth` is optional, used for HTTP Basic or Digest authentication during
the handshake. It must be specified as a (username, password) tuple. the handshake. It must be specified as a (username, password) tuple.
`recv_callback` is the callback for received frames in asynchronous
sockets. Use in conjunction with setblocking(0). The callback itself
may for example change the recv_callback attribute to change the
behaviour for the next received message. Can be set when calling
`queue_send`.
`sfamily` and `sproto` are used for the regular socket constructor. `sfamily` and `sproto` are used for the regular socket constructor.
""" """
self.protocols = protocols self.protocols = protocols
self.extensions = extensions self.extensions = extensions
self.extension_hooks = []
self.origin = origin self.origin = origin
self.location = location self.location = location
self.trusted_origins = trusted_origins self.trusted_origins = trusted_origins
...@@ -85,11 +92,16 @@ class websocket(object): ...@@ -85,11 +92,16 @@ class websocket(object):
self.handshake_sent = False self.handshake_sent = False
self.hooks_send = [] self.sendbuf_frames = []
self.hooks_recv = [] self.sendbuf = ''
self.recvbuf = ''
self.recv_callback = recv_callback
self.sock = sock or socket.socket(sfamily, socket.SOCK_STREAM, sproto) self.sock = sock or socket.socket(sfamily, socket.SOCK_STREAM, sproto)
def set_extensions(self, extensions):
self.extensions = [ext.Hook() for ext in extensions]
def __getattr__(self, name): def __getattr__(self, name):
if name in INHERITED_ATTRS: if name in INHERITED_ATTRS:
return getattr(self.sock, name) return getattr(self.sock, name)
...@@ -122,29 +134,31 @@ class websocket(object): ...@@ -122,29 +134,31 @@ class websocket(object):
ClientHandshake(self).perform() ClientHandshake(self).perform()
self.handshake_sent = True self.handshake_sent = True
def apply_send_hooks(self, frame):
for hook in self.extension_hooks:
frame = hook.send(frame)
return frame
def apply_recv_hooks(self, frame):
for hook in reversed(self.extension_hooks):
frame = hook.recv(frame)
return frame
def send(self, *args): def send(self, *args):
""" """
Send a number of frames. Send a number of frames.
""" """
for frame in args: for frame in args:
for hook in self.hooks_send: self.sock.sendall(self.apply_send_hooks(frame).pack())
frame = hook(frame)
#print 'send frame:', frame, 'to %s:%d' % self.sock.getpeername()
self.sock.sendall(frame.pack())
def recv(self): def recv(self):
""" """
Receive a single frames. This can be either a data frame or a control Receive a single frames. This can be either a data frame or a control
frame. frame.
""" """
frame = receive_frame(self.sock) return self.apply_recv_hooks(receive_frame(self.sock))
for hook in self.hooks_recv:
frame = hook(frame)
#print 'receive frame:', frame, 'from %s:%d' % self.sock.getpeername()
return frame
def recvn(self, n): def recvn(self, n):
""" """
...@@ -153,47 +167,79 @@ class websocket(object): ...@@ -153,47 +167,79 @@ class websocket(object):
""" """
return [self.recv() for i in xrange(n)] return [self.recv() for i in xrange(n)]
def enable_ssl(self, *args, **kwargs): def queue_send(self, frame, callback=None, recv_callback=None):
""" """
Transforms the regular socket.socket to an ssl.SSLSocket for secure Enqueue `frame` to the send buffer so that it is send on the next
connections. Any arguments are passed to ssl.wrap_socket: `do_async_send`. `callback` is an optional callable to call when the
http://docs.python.org/dev/library/ssl.html#ssl.wrap_socket frame has been fully written. `recv_callback` is an optional callable
to quickly set the `recv_callback` attribute to.
""" """
if self.handshake_sent: frame = self.apply_send_hooks(frame)
raise SSLError('can only enable SSL before handshake') self.sendbuf += frame.pack()
self.sendbuf_frames.append([frame, len(self.sendbuf), callback])
self.secure = True if recv_callback:
self.sock = ssl.wrap_socket(self.sock, *args, **kwargs) self.recv_callback = recv_callback
def add_hook(self, send=None, recv=None, prepend=False): def do_async_send(self):
""" """
Add a pair of send and receive hooks that are called for each frame Send any queued data. This function should only be called after a write
that is sent or received. A hook is a function that receives a single event on a file descriptor.
argument - a Frame instance - and returns a `Frame` instance as well. """
assert len(self.sendbuf)
`prepend` is a flag indicating whether the send hook is prepended to nwritten = self.sock.send(self.sendbuf)
the other send hooks. This is expecially useful when a program uses nframes = 0
extensions such as the built-in `DeflateFrame` extension. These
extensions are installed using these hooks as well.
For example, the following code creates a `Frame` instance for data for entry in self.sendbuf_frames:
being sent and removes the instance for received data. This way, data frame, offset, callback = entry
can be sent and received as if on a regular socket.
>>> import wspy
>>> sock.add_hook(lambda data: tswpy.Frame(tswpy.OPCODE_TEXT, data),
>>> lambda frame: frame.payload)
To add base64 encoding to the example above: if offset <= nwritten:
>>> import base64 nframes += 1
>>> sock.add_hook(base64.encodestring, base64.decodestring, True)
Note that here `prepend=True`, so that data passed to `send()` is first if callback:
encoded and then packed into a frame. Of course, one could also decide callback()
to add the base64 hook first, or to return a new `Frame` instance with else:
base64-encoded data. entry[1] -= nwritten
self.sendbuf = self.sendbuf[nwritten:]
self.sendbuf_frames = self.sendbuf_frames[nframes:]
def do_async_recv(self, bufsize):
"""
Receive any completed frames from the socket. This function should only
be called after a read event on a file descriptor.
""" """
if send: data = self.sock.recv(bufsize)
self.hooks_send.insert(0 if prepend else -1, send)
if len(data) == 0:
raise socket.error('no data to receive')
self.recvbuf += data
while contains_frame(self.recvbuf):
frame, self.recvbuf = pop_frame(self.recvbuf)
frame = self.apply_recv_hooks(frame)
if not self.recv_callback:
raise ValueError('no callback installed for %s' % frame)
self.recv_callback(frame)
def can_send(self):
return len(self.sendbuf) > 0
if recv: def can_recv(self):
self.hooks_recv.insert(-1 if prepend else 0, recv) return self.recv_callback is not None
def enable_ssl(self, *args, **kwargs):
"""
Transforms the regular socket.socket to an ssl.SSLSocket for secure
connections. Any arguments are passed to ssl.wrap_socket:
http://docs.python.org/dev/library/ssl.html#ssl.wrap_socket
"""
if self.handshake_sent:
raise SSLError('can only enable SSL before handshake')
self.secure = True
self.sock = ssl.wrap_socket(self.sock, *args, **kwargs)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment