Commit 6c79550e authored by Taddeüs Kroes's avatar Taddeüs Kroes

Merge branch 'async'

parents 4d9fbb0c 447ee6fa
......@@ -22,6 +22,7 @@ Her is a quick overview of the features in this library:
- Secure sockets using SSL certificates (for 'wss://...' URLs).
- The possibility to add extensions to the web socket protocol. An included
implementation is [deflate-frame](http://tools.ietf.org/html/draft-tyoshino-hybi-websocket-perframe-deflate-06).
- Asynchronous sockets with an EPOLL-based server.
Installation
......
......@@ -4,10 +4,11 @@ from frame import Frame, ControlFrame, OPCODE_CONTINUATION, OPCODE_TEXT, \
OPCODE_BINARY, OPCODE_CLOSE, OPCODE_PING, OPCODE_PONG, CLOSE_NORMAL, \
CLOSE_GOING_AWAY, CLOSE_PROTOCOL_ERROR, CLOSE_NOACCEPT_DTYPE, \
CLOSE_INVALID_DATA, CLOSE_POLICY, CLOSE_MESSAGE_TOOBIG, \
CLOSE_MISSING_EXTENSIONS, CLOSE_UNABLE
CLOSE_MISSING_EXTENSIONS, CLOSE_UNABLE, read_frame, pop_frame, \
contains_frame
from connection import Connection
from message import Message, TextMessage, BinaryMessage
from errors import SocketClosed, HandshakeError, PingError, SSLError
from extension import Extension
from deflate_frame import DeflateFrame, WebkitDeflateFrame
#from multiplex import Multiplex
from async import AsyncConnection, AsyncServer
import socket
from select import epoll, EPOLLIN, EPOLLOUT, EPOLLHUP
from traceback import format_exc
import logging
from connection import Connection
from frame import ControlFrame, OPCODE_PING, OPCODE_CONTINUATION, \
create_close_frame
from server import Server, Client
from errors import HandshakeError, SocketClosed
class AsyncConnection(Connection):
def __init__(self, sock):
sock.recv_callback = self.contruct_message
sock.recv_close_callback = self.onclose
self.recvbuf = []
Connection.__init__(self, sock)
def contruct_message(self, frame):
if isinstance(frame, ControlFrame):
self.handle_control_frame(frame)
return
self.recvbuf.append(frame)
if frame.final:
message = self.concat_fragments(self.recvbuf)
self.recvbuf = []
self.onmessage(message)
elif len(self.recvbuf) > 1 and frame.opcode != OPCODE_CONTINUATION:
raise ValueError('expected continuation/control frame, got %s '
'instead' % frame)
def send(self, message, fragment_size=None, mask=False):
frames = list(self.message_to_frames(message, fragment_size, mask))
for frame in frames[:-1]:
self.sock.queue_send(frame)
self.sock.queue_send(frames[-1], lambda: self.onsend(message))
def send_frame(self, frame, callback):
self.sock.queue_send(frame, callback)
def do_async_send(self):
self.execute_controlled(self.sock.do_async_send)
def do_async_recv(self, bufsize):
self.execute_controlled(self.sock.do_async_recv, bufsize)
def execute_controlled(self, func, *args, **kwargs):
try:
func(*args, **kwargs)
except (KeyboardInterrupt, SystemExit, SocketClosed):
raise
except Exception as e:
self.onerror(e)
self.onclose(None, 'error: %s' % e)
try:
self.sock.close()
except socket.error:
pass
raise e
def send_close_frame(self, code, reason):
self.sock.queue_send(create_close_frame(code, reason),
self.shutdown_write)
self.close_frame_sent = True
def close(self, code=None, reason=''):
self.send_close_frame(code, reason)
def send_ping(self, payload=''):
self.sock.queue_send(ControlFrame(OPCODE_PING, payload),
lambda: self.onping(payload))
self.ping_payload = payload
self.ping_sent = True
def onsend(self, message):
"""
Called after a message has been written.
"""
return NotImplemented
class AsyncServer(Server):
def __init__(self, *args, **kwargs):
Server.__init__(self, *args, **kwargs)
self.recvbuf_size = kwargs.get('recvbuf_size', 2048)
self.epoll = epoll()
self.epoll.register(self.sock.fileno(), EPOLLIN)
self.conns = {}
@property
def clients(self):
return self.conns.values()
def remove_client(self, client, code, reason):
self.epoll.unregister(client.fno)
del self.conns[client.fno]
self.onclose(client, code, reason)
def handle_events(self):
for fileno, event in self.epoll.poll(1):
if fileno == self.sock.fileno():
try:
sock, addr = self.sock.accept()
except HandshakeError as e:
logging.error('Invalid request: %s', e.message)
continue
client = AsyncClient(self, sock)
client.fno = sock.fileno()
sock.setblocking(0)
self.epoll.register(client.fno, EPOLLIN)
self.conns[client.fno] = client
logging.debug('Registered client %s', client)
elif event & EPOLLHUP:
self.epoll.unregister(fileno)
del self.conns[fileno]
else:
conn = self.conns[fileno]
try:
if event & EPOLLOUT:
conn.do_async_send()
elif event & EPOLLIN:
conn.do_async_recv(self.recvbuf_size)
except (KeyboardInterrupt, SystemExit):
raise
except SocketClosed:
continue
except Exception as e:
logging.error(format_exc(e).rstrip())
continue
self.update_mask(conn)
def run(self):
try:
while True:
self.handle_events()
except (KeyboardInterrupt, SystemExit):
logging.info('Received interrupt, stopping server...')
finally:
self.epoll.unregister(self.sock.fileno())
self.epoll.close()
self.sock.close()
def update_mask(self, conn):
mask = 0
if conn.sock.can_send():
mask |= EPOLLOUT
if conn.sock.can_recv():
mask |= EPOLLIN
self.epoll.modify(conn.sock.fileno(), mask)
def onsend(self, client, message):
return NotImplemented
class AsyncClient(Client, AsyncConnection):
def __init__(self, server, sock):
self.server = server
AsyncConnection.__init__(self, sock)
def send(self, message, fragment_size=None, mask=False):
logging.debug('Enqueueing %s to %s', message, self)
AsyncConnection.send(self, message, fragment_size, mask)
self.server.update_mask(self)
def onsend(self, message):
logging.debug('Finished sending %s to %s', message, self)
self.server.onsend(self, message)
if __name__ == '__main__':
import sys
port = int(sys.argv[1]) if len(sys.argv) > 1 else 8000
AsyncServer(('', port), loglevel=logging.DEBUG).run()
import struct
import socket
from frame import ControlFrame, OPCODE_CLOSE, OPCODE_PING, OPCODE_PONG, \
OPCODE_CONTINUATION
OPCODE_CONTINUATION, create_close_frame
from message import create_message
from errors import SocketClosed, PingError
......@@ -54,19 +53,30 @@ class Connection(object):
self.onopen()
def message_to_frames(self, message, fragment_size=None, mask=False):
for hook in self.hooks_send:
message = hook(message)
if fragment_size is None:
yield message.frame(mask=mask)
else:
for frame in message.fragment(fragment_size, mask=mask):
yield frame
def send(self, message, fragment_size=None, mask=False):
"""
Send a message. If `fragment_size` is specified, the message is
fragmented into multiple frames whose payload size does not extend
`fragment_size`.
"""
for hook in self.hooks_send:
message = hook(message)
for frame in self.message_to_frames(message, fragment_size, mask):
self.send_frame(frame)
if fragment_size is None:
self.sock.send(message.frame(mask=mask))
else:
self.sock.send(*message.fragment(fragment_size, mask=mask))
def send_frame(self, frame, callback=None):
self.sock.send(frame)
if callback:
callback()
def recv(self):
"""
......@@ -82,12 +92,15 @@ class Connection(object):
if isinstance(frame, ControlFrame):
self.handle_control_frame(frame)
elif len(fragments) and frame.opcode != OPCODE_CONTINUATION:
elif len(fragments) > 0 and frame.opcode != OPCODE_CONTINUATION:
raise ValueError('expected continuation/control frame, got %s '
'instead' % frame)
else:
fragments.append(frame)
return self.concat_fragments(fragments)
def concat_fragments(self, fragments):
payload = bytearray()
for f in fragments:
......@@ -105,16 +118,20 @@ class Connection(object):
Handle a control frame as defined by RFC 6455.
"""
if frame.opcode == OPCODE_CLOSE:
# Close the connection from this end as well
self.close_frame_received = True
code, reason = frame.unpack_close()
# No more receiving data after a close message
raise SocketClosed(code, reason)
if self.close_frame_sent:
self.onclose(code, reason)
self.sock.close()
raise SocketClosed(True)
else:
self.close_params = (code, reason)
self.send_close_frame(code, reason)
elif frame.opcode == OPCODE_PING:
# Respond with a pong message with identical payload
self.sock.send(ControlFrame(OPCODE_PONG, frame.payload))
self.send_frame(ControlFrame(OPCODE_PONG, frame.payload))
elif frame.opcode == OPCODE_PONG:
# Assert that the PONG payload is identical to that of the PING
......@@ -138,38 +155,40 @@ class Connection(object):
while True:
try:
self.onmessage(self.recv())
except SocketClosed as e:
self.close(e.code, e.reason)
except (KeyboardInterrupt, SystemExit, SocketClosed):
break
except socket.error as e:
except Exception as e:
self.onerror(e)
self.onclose(None, 'error: %s' % e)
try:
self.sock.close()
except socket.error:
pass
self.onclose(None, '')
break
except Exception as e:
self.onerror(e)
raise e
def send_ping(self, payload=''):
"""
Send a PING control frame with an optional payload.
"""
self.sock.send(ControlFrame(OPCODE_PING, payload))
self.send_frame(ControlFrame(OPCODE_PING, payload),
lambda: self.onping(payload))
self.ping_payload = payload
self.ping_sent = True
self.onping(payload)
def send_close_frame(self, code=None, reason=''):
"""
Send a CLOSE control frame.
"""
payload = '' if code is None else struct.pack('!H', code) + reason
self.sock.send(ControlFrame(OPCODE_CLOSE, payload))
def send_close_frame(self, code, reason):
self.send_frame(create_close_frame(code, reason))
self.close_frame_sent = True
self.shutdown_write()
def shutdown_write(self):
if self.close_frame_received:
self.onclose(*self.close_params)
self.sock.close()
raise SocketClosed(False)
else:
self.sock.shutdown(socket.SHUT_WR)
def close(self, code=None, reason=''):
"""
......@@ -179,28 +198,14 @@ class Connection(object):
called after the response has been received, but before the socket is
actually closed.
"""
# Send CLOSE frame
if not self.close_frame_sent:
self.send_close_frame(code, reason)
# Receive CLOSE frame
if not self.close_frame_received:
frame = self.sock.recv()
if frame.opcode != OPCODE_CLOSE:
raise ValueError('expected CLOSE frame, got %s' % frame)
self.close_frame_received = True
res_code, res_reason = frame.unpack_close()
# FIXME: check if res_code == code and res_reason == reason?
# FIXME: alternatively, keep receiving frames in a loop until a
# CLOSE frame is received, so that a fragmented chain may arrive
# fully first
self.onclose(code, reason)
self.sock.close()
self.handle_control_frame(frame)
def add_hook(self, send=None, recv=None, prepend=False):
"""
......
......@@ -20,38 +20,33 @@ class DeflateFrame(Extension):
name = 'deflate-frame'
rsv1 = True
defaults = {'max_window_bits': 15, 'no_context_takeover': False}
defaults = {'max_window_bits': zlib.MAX_WBITS, 'no_context_takeover': False}
def __init__(self, defaults={}, request={}):
Extension.__init__(self, defaults, request)
COMPRESSION_THRESHOLD = 64 # minimal payload size for compression
def init(self):
mwb = self.defaults['max_window_bits']
cto = self.defaults['no_context_takeover']
if not isinstance(mwb, int):
raise ValueError('"max_window_bits" must be an integer')
elif mwb > 15:
raise ValueError('"max_window_bits" may not be larger than 15')
if not isinstance(mwb, int) or mwb < 1 or mwb > zlib.MAX_WBITS:
raise ValueError('"max_window_bits" must be in range 1-15')
if cto is not False and cto is not True:
raise ValueError('"no_context_takeover" must have no value')
class Hook(Extension.Hook):
def __init__(self, extension, **kwargs):
Extension.Hook.__init__(self, extension, **kwargs)
if not self.no_context_takeover:
def init(self, extension):
self.defl = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
zlib.DEFLATED,
-self.max_window_bits)
other_wbits = self.extension.request.get('max_window_bits', 15)
zlib.DEFLATED, -self.max_window_bits)
other_wbits = extension.request.get('max_window_bits', zlib.MAX_WBITS)
self.dec = zlib.decompressobj(-other_wbits)
def send(self, frame):
if not frame.rsv1 and not isinstance(frame, ControlFrame):
# FIXME: this does not seem to work properly on Android
if not frame.rsv1 and not isinstance(frame, ControlFrame) and \
len(frame.payload) > DeflateFrame.COMPRESSION_THRESHOLD:
frame.rsv1 = True
frame.payload = self.deflate(frame.payload)
frame.payload = self.deflate(frame)
return frame
......@@ -65,23 +60,23 @@ class DeflateFrame(Extension):
return frame
def deflate(self, data):
if self.no_context_takeover:
defl = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
zlib.DEFLATED, -self.max_window_bits)
# FIXME: why the '\x00' below? This was borrowed from
# https://github.com/fancycode/tornado/blob/bc317b6dcf63608ff004ff1f57073be0504b6550/tornado/websocket.py#L91
return defl.compress(data) + defl.flush(zlib.Z_FINISH) + '\x00'
def deflate(self, frame):
compressed = self.defl.compress(frame.payload)
compressed = self.defl.compress(data)
if frame.final or self.no_context_takeover:
compressed += self.defl.flush(zlib.Z_FINISH) + '\x00'
self.defl = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
zlib.DEFLATED, -self.max_window_bits)
else:
compressed += self.defl.flush(zlib.Z_SYNC_FLUSH)
assert compressed[-4:] == '\x00\x00\xff\xff'
return compressed[:-4]
compressed = compressed[:-4]
return compressed
def inflate(self, data):
data = self.dec.decompress(str(data + '\x00\x00\xff\xff'))
assert not self.dec.unused_data
return data
return self.dec.decompress(data + '\x00\x00\xff\xff') + \
self.dec.flush(zlib.Z_SYNC_FLUSH)
class WebkitDeflateFrame(DeflateFrame):
......
class SocketClosed(Exception):
def __init__(self, code=None, reason=''):
self.code = code
self.reason = reason
def __init__(self, initialized):
self.initialized = initialized
@property
def message(self):
return ('' if self.code is None else '[%d] ' % self.code) + self.reason
s = 'socket closed'
if self.initialized:
s += ' (initialized)'
return s
class HandshakeError(Exception):
......
......@@ -19,23 +19,31 @@ class Extension(object):
self.request = dict(self.__class__.request)
self.request.update(request)
self.init()
def __str__(self):
return '<Extension "%s" defaults=%s request=%s>' \
% (self.name, self.defaults, self.request)
def init(self):
return NotImplemented
def create_hook(self, **kwargs):
params = {}
params.update(self.defaults)
params.update(kwargs)
return self.Hook(self, **params)
hook = self.Hook(**params)
hook.init(self)
return hook
class Hook:
def __init__(self, extension, **kwargs):
self.extension = extension
def __init__(self, **kwargs):
for param, value in kwargs.iteritems():
setattr(self, param, value)
def init(self, extension):
return NotImplemented
def send(self, frame):
return frame
......@@ -43,28 +51,19 @@ class Extension(object):
return frame
def filter_extensions(extensions):
"""
Remove extensions that use conflicting rsv bits and/or opcodes, with the
first options being the most preferable.
"""
def extension_conflicts(ext, existing):
rsv1_reserved = False
rsv2_reserved = False
rsv3_reserved = False
opcodes_reserved = []
compat = []
reserved_opcodes = []
for ext in extensions:
if ext.rsv1 and rsv1_reserved \
for e in existing:
rsv1_reserved |= e.rsv1
rsv2_reserved |= e.rsv2
rsv3_reserved |= e.rsv3
reserved_opcodes.extend(e.opcodes)
return ext.rsv1 and rsv1_reserved \
or ext.rsv2 and rsv2_reserved \
or ext.rsv3 and rsv3_reserved \
or len(set(ext.opcodes) & set(opcodes_reserved)):
continue
rsv1_reserved |= ext.rsv1
rsv2_reserved |= ext.rsv2
rsv3_reserved |= ext.rsv3
opcodes_reserved.extend(ext.opcodes)
compat.append(ext)
return compat
or len(set(ext.opcodes) & set(reserved_opcodes))
......@@ -21,9 +21,11 @@ CLOSE_MESSAGE_TOOBIG = 1009
CLOSE_MISSING_EXTENSIONS = 1010
CLOSE_UNABLE = 1011
line_printable = [c for c in printable if c not in '\r\n\x0b\x0c']
def printstr(s):
return ''.join(c if c in printable else '.' for c in s)
return ''.join(c if c in line_printable else '.' for c in str(s))
class Frame(object):
......@@ -154,7 +156,18 @@ class Frame(object):
if len(self.payload) > max_pl_disp:
pl += '...'
return s + ' payload=%s>' % pl
s += ' payload=%s' % pl
if self.rsv1:
s += ' rsv1'
if self.rsv2:
s += ' rsv2'
if self.rsv3:
s += ' rsv3'
return s + '>'
class ControlFrame(Frame):
......@@ -194,12 +207,8 @@ class ControlFrame(Frame):
return code, reason
def receive_frame(sock):
"""
Receive a single frame on socket `sock`. The frame scheme is explained in
the docs of Frame.pack().
"""
b1, b2 = struct.unpack('!BB', recvn(sock, 2))
def decode_frame(reader):
b1, b2 = struct.unpack('!BB', reader.readn(2))
final = bool(b1 & 0x80)
rsv1 = bool(b1 & 0x40)
......@@ -211,16 +220,16 @@ def receive_frame(sock):
payload_len = b2 & 0x7F
if payload_len == 126:
payload_len = struct.unpack('!H', recvn(sock, 2))
payload_len = struct.unpack('!H', reader.readn(2))
elif payload_len == 127:
payload_len = struct.unpack('!Q', recvn(sock, 8))
payload_len = struct.unpack('!Q', reader.readn(8))
if masked:
masking_key = recvn(sock, 4)
payload = mask(masking_key, recvn(sock, payload_len))
masking_key = reader.readn(4)
payload = mask(masking_key, reader.readn(payload_len))
else:
masking_key = ''
payload = recvn(sock, payload_len)
payload = reader.readn(payload_len)
# Control frames have most significant bit 1
cls = ControlFrame if opcode & 0x8 else Frame
......@@ -229,14 +238,44 @@ def receive_frame(sock):
rsv1=rsv1, rsv2=rsv2, rsv3=rsv3)
def recvn(sock, n):
def receive_frame(sock):
return decode_frame(SocketReader(sock))
def read_frame(data):
reader = BufferReader(data)
frame = decode_frame(reader)
return frame, reader.offset
def pop_frame(data):
frame, size = read_frame(data)
return frame, data[size:]
class BufferReader(object):
def __init__(self, data):
self.data = data
self.offset = 0
def readn(self, n):
assert len(self.data) - self.offset >= n
self.offset += n
return self.data[self.offset - n:self.offset]
class SocketReader(object):
def __init__(self, sock):
self.sock = sock
def readn(self, n):
"""
Keep receiving data from `sock` until exactly `n` bytes have been read.
Keep receiving data until exactly `n` bytes have been read.
"""
data = ''
while len(data) < n:
received = sock.recv(n - len(data))
received = self.sock.recv(n - len(data))
if not len(received):
raise socket.error('no data read from socket')
......@@ -246,6 +285,32 @@ def recvn(sock, n):
return data
def contains_frame(data):
"""
Read the frame length from the start of `data` and check if the data is
long enough to contain the entire frame.
"""
if len(data) < 2:
return False
b2 = struct.unpack('!B', data[1])[0]
payload_len = b2 & 0x7F
payload_start = 2
if payload_len == 126:
if len(data) > 4:
payload_len = struct.unpack('!H', data[2:4])
payload_start = 4
elif payload_len == 127:
if len(data) > 12:
payload_len = struct.unpack('!Q', data[4:12])
payload_start = 12
return len(data) >= payload_len + payload_start
def mask(key, original):
"""
Mask an octet string using the given masking key.
......@@ -265,3 +330,8 @@ def mask(key, original):
masked[i] ^= key[i % 4]
return masked
def create_close_frame(code, reason):
payload = '' if code is None else struct.pack('!H', code) + reason
return ControlFrame(OPCODE_CLOSE, payload)
......@@ -7,7 +7,7 @@ from hashlib import sha1
from urlparse import urlparse
from errors import HandshakeError
from extension import filter_extensions
from extension import extension_conflicts
from python_digest import build_authorization_request
......@@ -15,7 +15,7 @@ WS_GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
WS_VERSION = '13'
MAX_REDIRECTS = 10
HDR_TIMEOUT = 5
MAX_HDR_LEN = 512
MAX_HDR_LEN = 1024
class Handshake(object):
......@@ -65,7 +65,11 @@ class Handshake(object):
start_time = time.time()
while hdr[-4:] != '\r\n\r\n' and len(hdr) < MAX_HDR_LEN:
while hdr[-4:] != '\r\n\r\n':
if len(hdr) == MAX_HDR_LEN:
raise HandshakeError('request exceeds maximum header '
'length of %d' % MAX_HDR_LEN)
hdr += self.sock.recv(1)
time_diff = time.time() - start_time
......@@ -169,23 +173,19 @@ class ServerHandshake(Handshake):
# Only supported extensions are returned
if 'Sec-WebSocket-Extensions' in headers:
supported_ext = dict((e.name, e) for e in ssock.extensions)
self.wsock.extension_hooks = []
extensions = []
all_params = []
for ext in split_stripped(headers['Sec-WebSocket-Extensions']):
name, params = parse_param_hdr(ext)
if name in supported_ext:
extensions.append(supported_ext[name])
all_params.append(params)
self.wsock.extensions = filter_extensions(extensions)
ext = supported_ext[name]
for ext, params in zip(self.wsock.extensions, all_params):
if not extension_conflicts(ext, extensions):
extensions.append(ext)
hook = ext.create_hook(**params)
self.wsock.add_hook(send=hook.send, recv=hook.recv)
else:
self.wsock.extensions = []
self.wsock.extension_hooks.append(hook)
# Check if requested resource location is served by this server
if ssock.locations:
......@@ -212,7 +212,7 @@ class ServerHandshake(Handshake):
location = '%s://%s%s' % (scheme, host, self.wsock.location)
# Construct HTTP response header
yield 'HTTP/1.1 101 Web Socket Protocol Handshake'
yield 'HTTP/1.1 101 Switching Protocols'
yield 'Upgrade', 'websocket'
yield 'Connection', 'Upgrade'
yield 'Sec-WebSocket-Origin', origin
......@@ -274,7 +274,7 @@ class ClientHandshake(Handshake):
# Compare extensions, add hooks only for those returned by server
if 'Sec-WebSocket-Extensions' in headers:
supported_ext = dict((e.name, e) for e in self.wsock.extensions)
self.wsock.extensions = []
self.wsock.extension_hooks = []
for ext in split_stripped(headers['Sec-WebSocket-Extensions']):
name, params = parse_param_hdr(ext)
......@@ -284,8 +284,7 @@ class ClientHandshake(Handshake):
'unsupported extension "%s"' % name)
hook = supported_ext[name].create_hook(**params)
self.wsock.extensions.append(supported_ext[name])
self.wsock.add_hook(send=hook.send, recv=hook.recv)
self.wsock.extension_hooks.append(hook)
# Assert that returned protocol (if any) is supported
if 'Sec-WebSocket-Protocol' in headers:
......
......@@ -32,7 +32,7 @@ class Server(object):
"""
def __init__(self, address, loglevel=logging.INFO, ssl_args=None,
max_join_time=2.0, **kwargs):
max_join_time=2.0, backlog_size=32, **kwargs):
"""
Constructor for a simple web socket server.
......@@ -53,6 +53,8 @@ class Server(object):
`max_join_time` is the maximum time (in seconds) to wait for client
responses after sending CLOSE frames, it defaults to 2 seconds.
`backlog_size` is directly passed to `websocket.listen`.
"""
logging.basicConfig(level=loglevel,
format='%(asctime)s: %(levelname)s: %(message)s',
......@@ -69,14 +71,14 @@ class Server(object):
self.sock.enable_ssl(server_side=True, **ssl_args)
self.sock.bind(address)
self.sock.listen(5)
self.clients = []
self.client_threads = []
self.sock.listen(backlog_size)
self.max_join_time = max_join_time
def run(self):
self.clients = []
self.client_threads = []
while True:
try:
sock, address = self.sock.accept()
......@@ -134,30 +136,22 @@ class Server(object):
self.onclose(client, code, reason)
def onopen(self, client):
logging.debug('Opened socket to %s', client)
return NotImplemented
def onmessage(self, client, message):
logging.debug('Received %s from %s', message, client)
return NotImplemented
def onping(self, client, payload):
logging.debug('Sent ping "%s" to %s', payload, client)
return NotImplemented
def onpong(self, client, payload):
logging.debug('Received pong "%s" from %s', payload, client)
return NotImplemented
def onclose(self, client, code, reason):
msg = 'Closed socket to %s' % client
if code is not None:
msg += ' [%d]' % code
if len(reason):
msg += ': ' + reason
logging.debug(msg)
return NotImplemented
def onerror(self, client, e):
logging.error(format_exc(e))
return NotImplemented
class Client(Connection):
......@@ -176,21 +170,32 @@ class Client(Connection):
Connection.send(self, message, fragment_size=fragment_size, mask=mask)
def onopen(self):
logging.debug('Opened socket to %s', self)
self.server.onopen(self)
def onmessage(self, message):
logging.debug('Received %s from %s', message, self)
self.server.onmessage(self, message)
def onping(self, payload):
logging.debug('Sent ping "%s" to %s', payload, self)
self.server.onping(self, payload)
def onpong(self, payload):
logging.debug('Received pong "%s" from %s', payload, self)
self.server.onpong(self, payload)
def onclose(self, code, reason):
msg = 'Closed socket to %s' % self
if code is not None:
msg += ': [%d] %s' % (code, reason)
logging.debug(msg)
self.server.remove_client(self, code, reason)
def onerror(self, e):
logging.error(format_exc(e))
self.server.onerror(self, e)
......
#!/usr/bin/env python
import sys
import ssl
from os.path import abspath, dirname
basepath = abspath(dirname(abspath(__file__)) + '/..')
......@@ -20,7 +21,7 @@ class EchoClient(Connection):
def onmessage(self, msg):
print 'Received', msg
raise SocketClosed(None, 'response received')
self.close(None, 'response received')
def onerror(self, e):
print 'Error:', e
......@@ -29,8 +30,15 @@ class EchoClient(Connection):
print 'Connection closed'
secure = True
if __name__ == '__main__':
print 'Connecting to ws://%s:%d' % ADDR
scheme = 'wss' if secure else 'ws'
print 'Connecting to %s://%s' % (scheme, '%s:%d' % ADDR)
sock = websocket()
if secure:
sock.enable_ssl(ca_certs='cert.pem', cert_reqs=ssl.CERT_REQUIRED)
sock.connect(ADDR)
EchoClient(sock).receive_forever()
#!/usr/bin/env python
import sys
import socket
from os.path import abspath, dirname
basepath = abspath(dirname(abspath(__file__)) + '/..')
sys.path.insert(0, basepath)
from websocket import websocket
from connection import Connection
from message import TextMessage
from errors import SocketClosed
if __name__ == '__main__':
if len(sys.argv) < 3:
print >> sys.stderr, 'Usage: python %s HOST PORT' % sys.argv[0]
sys.exit(1)
host = sys.argv[1]
port = int(sys.argv[2])
sock = websocket()
sock.connect((host, port))
sock.settimeout(1.0)
conn = Connection(sock)
try:
try:
while True:
msg = TextMessage(raw_input())
print 'send:', msg
conn.send(msg)
try:
print 'recv:', conn.recv()
except socket.timeout:
print 'no response'
except EOFError:
conn.close()
except SocketClosed as e:
if e.initialized:
print 'closed connection'
else:
print 'other side closed connection'
import socket
import ssl
from frame import receive_frame
from frame import receive_frame, pop_frame, contains_frame
from handshake import ServerHandshake, ClientHandshake
from errors import SSLError
......@@ -11,7 +11,6 @@ INHERITED_ATTRS = ['bind', 'close', 'listen', 'fileno', 'getpeername',
'settimeout', 'gettimeout', 'shutdown', 'family', 'type',
'proto']
class websocket(object):
"""
Implementation of web socket, upgrades a regular TCP socket to a websocket
......@@ -36,22 +35,23 @@ class websocket(object):
>>> sock.connect(('', 8000))
>>> sock.send(wspy.Frame(wspy.OPCODE_TEXT, 'Hello, Server!'))
"""
def __init__(self, sock=None, protocols=[], extensions=[], origin=None,
def __init__(self, sock=None, origin=None, protocols=[], extensions=[],
location='/', trusted_origins=[], locations=[], auth=None,
sfamily=socket.AF_INET, sproto=0):
recv_callback=None, sfamily=socket.AF_INET, sproto=0):
"""
Create a regular TCP socket of family `family` and protocol
`sock` is an optional regular TCP socket to be used for sending binary
data. If not specified, a new socket is created.
`protocols` is a list of supported protocol names.
`extensions` is a list of supported extensions (`Extension` instances).
`origin` (for client sockets) is the value for the "Origin" header sent
in a client handshake .
`protocols` is a list of supported protocol names.
`extensions` (for server sockets) is a list of supported extensions
(`Extension` instances).
`location` (for client sockets) is optional, used to request a
particular resource in the HTTP handshake. In a URL, this would show as
ws://host[:port]/<location>. Use this when the server serves multiple
......@@ -71,10 +71,17 @@ class websocket(object):
`auth` is optional, used for HTTP Basic or Digest authentication during
the handshake. It must be specified as a (username, password) tuple.
`recv_callback` is the callback for received frames in asynchronous
sockets. Use in conjunction with setblocking(0). The callback itself
may for example change the recv_callback attribute to change the
behaviour for the next received message. Can be set when calling
`queue_send`.
`sfamily` and `sproto` are used for the regular socket constructor.
"""
self.protocols = protocols
self.extensions = extensions
self.extension_hooks = []
self.origin = origin
self.location = location
self.trusted_origins = trusted_origins
......@@ -85,11 +92,16 @@ class websocket(object):
self.handshake_sent = False
self.hooks_send = []
self.hooks_recv = []
self.sendbuf_frames = []
self.sendbuf = ''
self.recvbuf = ''
self.recv_callback = recv_callback
self.sock = sock or socket.socket(sfamily, socket.SOCK_STREAM, sproto)
def set_extensions(self, extensions):
self.extensions = [ext.Hook() for ext in extensions]
def __getattr__(self, name):
if name in INHERITED_ATTRS:
return getattr(self.sock, name)
......@@ -122,29 +134,31 @@ class websocket(object):
ClientHandshake(self).perform()
self.handshake_sent = True
def apply_send_hooks(self, frame):
for hook in self.extension_hooks:
frame = hook.send(frame)
return frame
def apply_recv_hooks(self, frame):
for hook in reversed(self.extension_hooks):
frame = hook.recv(frame)
return frame
def send(self, *args):
"""
Send a number of frames.
"""
for frame in args:
for hook in self.hooks_send:
frame = hook(frame)
#print 'send frame:', frame, 'to %s:%d' % self.sock.getpeername()
self.sock.sendall(frame.pack())
self.sock.sendall(self.apply_send_hooks(frame).pack())
def recv(self):
"""
Receive a single frames. This can be either a data frame or a control
frame.
"""
frame = receive_frame(self.sock)
for hook in self.hooks_recv:
frame = hook(frame)
#print 'receive frame:', frame, 'from %s:%d' % self.sock.getpeername()
return frame
return self.apply_recv_hooks(receive_frame(self.sock))
def recvn(self, n):
"""
......@@ -153,47 +167,79 @@ class websocket(object):
"""
return [self.recv() for i in xrange(n)]
def enable_ssl(self, *args, **kwargs):
def queue_send(self, frame, callback=None, recv_callback=None):
"""
Transforms the regular socket.socket to an ssl.SSLSocket for secure
connections. Any arguments are passed to ssl.wrap_socket:
http://docs.python.org/dev/library/ssl.html#ssl.wrap_socket
Enqueue `frame` to the send buffer so that it is send on the next
`do_async_send`. `callback` is an optional callable to call when the
frame has been fully written. `recv_callback` is an optional callable
to quickly set the `recv_callback` attribute to.
"""
if self.handshake_sent:
raise SSLError('can only enable SSL before handshake')
frame = self.apply_send_hooks(frame)
self.sendbuf += frame.pack()
self.sendbuf_frames.append([frame, len(self.sendbuf), callback])
self.secure = True
self.sock = ssl.wrap_socket(self.sock, *args, **kwargs)
if recv_callback:
self.recv_callback = recv_callback
def add_hook(self, send=None, recv=None, prepend=False):
def do_async_send(self):
"""
Add a pair of send and receive hooks that are called for each frame
that is sent or received. A hook is a function that receives a single
argument - a Frame instance - and returns a `Frame` instance as well.
Send any queued data. This function should only be called after a write
event on a file descriptor.
"""
assert len(self.sendbuf)
`prepend` is a flag indicating whether the send hook is prepended to
the other send hooks. This is expecially useful when a program uses
extensions such as the built-in `DeflateFrame` extension. These
extensions are installed using these hooks as well.
nwritten = self.sock.send(self.sendbuf)
nframes = 0
For example, the following code creates a `Frame` instance for data
being sent and removes the instance for received data. This way, data
can be sent and received as if on a regular socket.
>>> import wspy
>>> sock.add_hook(lambda data: tswpy.Frame(tswpy.OPCODE_TEXT, data),
>>> lambda frame: frame.payload)
for entry in self.sendbuf_frames:
frame, offset, callback = entry
To add base64 encoding to the example above:
>>> import base64
>>> sock.add_hook(base64.encodestring, base64.decodestring, True)
if offset <= nwritten:
nframes += 1
Note that here `prepend=True`, so that data passed to `send()` is first
encoded and then packed into a frame. Of course, one could also decide
to add the base64 hook first, or to return a new `Frame` instance with
base64-encoded data.
if callback:
callback()
else:
entry[1] -= nwritten
self.sendbuf = self.sendbuf[nwritten:]
self.sendbuf_frames = self.sendbuf_frames[nframes:]
def do_async_recv(self, bufsize):
"""
Receive any completed frames from the socket. This function should only
be called after a read event on a file descriptor.
"""
if send:
self.hooks_send.insert(0 if prepend else -1, send)
data = self.sock.recv(bufsize)
if len(data) == 0:
raise socket.error('no data to receive')
self.recvbuf += data
while contains_frame(self.recvbuf):
frame, self.recvbuf = pop_frame(self.recvbuf)
frame = self.apply_recv_hooks(frame)
if not self.recv_callback:
raise ValueError('no callback installed for %s' % frame)
self.recv_callback(frame)
def can_send(self):
return len(self.sendbuf) > 0
if recv:
self.hooks_recv.insert(-1 if prepend else 0, recv)
def can_recv(self):
return self.recv_callback is not None
def enable_ssl(self, *args, **kwargs):
"""
Transforms the regular socket.socket to an ssl.SSLSocket for secure
connections. Any arguments are passed to ssl.wrap_socket:
http://docs.python.org/dev/library/ssl.html#ssl.wrap_socket
"""
if self.handshake_sent:
raise SSLError('can only enable SSL before handshake')
self.secure = True
self.sock = ssl.wrap_socket(self.sock, *args, **kwargs)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment