Commit a55b2fad authored by Taddeüs Kroes's avatar Taddeüs Kroes

Rewrote websocket + connection to add support for asynchronous sockets, added...

Rewrote websocket + connection to add support for asynchronous sockets, added asynchrounous connecten & server
parent bd447183
......@@ -22,6 +22,7 @@ Her is a quick overview of the features in this library:
- Secure sockets using SSL certificates (for 'wss://...' URLs).
- The possibility to add extensions to the web socket protocol. An included
implementation is [deflate-frame](http://tools.ietf.org/html/draft-tyoshino-hybi-websocket-perframe-deflate-06).
- Asynchronous sockets with an EPOLL-based server.
Installation
......
from websocket import websocket, STATE_INIT, STATE_READ, STATE_WRITE, \
STATE_CLOSE
from websocket import websocket
from server import Server
from frame import Frame, ControlFrame, OPCODE_CONTINUATION, OPCODE_TEXT, \
OPCODE_BINARY, OPCODE_CLOSE, OPCODE_PING, OPCODE_PONG, CLOSE_NORMAL, \
......@@ -12,3 +11,4 @@ from message import Message, TextMessage, BinaryMessage
from errors import SocketClosed, HandshakeError, PingError, SSLError
from extension import Extension
from deflate_frame import DeflateFrame, WebkitDeflateFrame
from async import AsyncConnection, AsyncServer
import socket
from select import epoll, EPOLLIN, EPOLLOUT, EPOLLHUP
from traceback import format_exc
import logging
from connection import Connection
from frame import ControlFrame, OPCODE_PING, OPCODE_CONTINUATION, \
create_close_frame
from server import Server, Client
from errors import HandshakeError, SocketClosed
class AsyncConnection(Connection):
def __init__(self, sock):
sock.recv_callback = self.contruct_message
sock.recv_close_callback = self.onclose
self.recvbuf = []
Connection.__init__(self, sock)
def contruct_message(self, frame):
if isinstance(frame, ControlFrame):
self.handle_control_frame(frame)
return
self.recvbuf.append(frame)
if frame.final:
message = self.concat_fragments(self.recvbuf)
self.recvbuf = []
self.onmessage(message)
elif len(self.recvbuf) > 1 and frame.opcode != OPCODE_CONTINUATION:
raise ValueError('expected continuation/control frame, got %s '
'instead' % frame)
def send(self, message, fragment_size=None, mask=False):
frames = self.message_to_frames(message, fragment_size, mask)
for frame in frames[:-1]:
self.sock.queue_send(frame)
self.sock.queue_send(frames[-1], lambda: self.onsend(message))
def send_frame(self, frame, callback):
self.sock.queue_send(frame, callback)
def do_async_send(self):
self.execute_controlled(self.sock.do_async_send)
def do_async_recv(self, bufsize):
self.execute_controlled(self.sock.do_async_recv, bufsize)
def execute_controlled(self, func, *args, **kwargs):
try:
func(*args, **kwargs)
except (KeyboardInterrupt, SystemExit, SocketClosed):
raise
except Exception as e:
self.onerror(e)
self.onclose(None, 'error: %s' % e)
try:
self.sock.close()
except socket.error:
pass
raise e
def send_close_frame(self, code, reason):
self.sock.queue_send(create_close_frame(code, reason),
self.shutdown_write)
self.close_frame_sent = True
def close(self, code=None, reason=''):
self.send_close_frame(code, reason)
def send_ping(self, payload=''):
self.sock.queue_send(ControlFrame(OPCODE_PING, payload),
lambda: self.onping(payload))
self.ping_payload = payload
self.ping_sent = True
def onsend(self, message):
"""
Called after a message has been written.
"""
return NotImplemented
class AsyncServer(Server):
def __init__(self, *args, **kwargs):
Server.__init__(self, *args, **kwargs)
self.recvbuf_size = kwargs.get('recvbuf_size', 2048)
self.epoll = epoll()
self.epoll.register(self.sock.fileno(), EPOLLIN)
self.conns = {}
@property
def clients(self):
return self.conns.itervalues()
def remove_client(self, client, code, reason):
self.epoll.unregister(client.fno)
del self.conns[client.fno]
self.onclose(client, code, reason)
def handle_events(self):
for fileno, event in self.epoll.poll(1):
if fileno == self.sock.fileno():
try:
sock, addr = self.sock.accept()
except HandshakeError as e:
logging.error('Invalid request: %s', e.message)
continue
client = AsyncClient(self, sock)
client.fno = sock.fileno()
sock.setblocking(0)
self.epoll.register(client.fno, EPOLLIN)
self.conns[client.fno] = client
logging.debug('Registered client %s', client)
elif event & EPOLLHUP:
self.epoll.unregister(fileno)
del self.conns[fileno]
else:
conn = self.conns[fileno]
try:
if event & EPOLLOUT:
conn.do_async_send()
elif event & EPOLLIN:
conn.do_async_recv(self.recvbuf_size)
except (KeyboardInterrupt, SystemExit):
raise
except SocketClosed:
continue
except Exception as e:
logging.error(format_exc(e))
continue
mask = 0
if conn.sock.can_send():
mask |= EPOLLOUT
if conn.sock.can_recv():
mask |= EPOLLIN
self.epoll.modify(fileno, mask)
def run(self):
try:
while True:
self.handle_events()
except (KeyboardInterrupt, SystemExit):
logging.info('Received interrupt, stopping server...')
finally:
self.epoll.unregister(self.sock.fileno())
self.epoll.close()
self.sock.close()
def onsend(self, client, message):
logging.debug('Written "%s" to %s', message, client)
class AsyncClient(Client, AsyncConnection):
def __init__(self, server, sock):
self.server = server
AsyncConnection.__init__(self, sock)
def send(self, message, fragment_size=None, mask=False):
logging.debug('Enqueueing %s to %s', message, self)
AsyncConnection.send(self, message, fragment_size, mask)
def onsend(self, message):
self.server.onsend(self, message)
if __name__ == '__main__':
import sys
port = int(sys.argv[1]) if len(sys.argv) > 1 else 8000
AsyncServer(('', port), loglevel=logging.DEBUG).run()
import struct
import socket
from frame import ControlFrame, OPCODE_CLOSE, OPCODE_PING, OPCODE_PONG, \
OPCODE_CONTINUATION
OPCODE_CONTINUATION, create_close_frame
from message import create_message
from errors import SocketClosed, PingError
......@@ -54,19 +53,30 @@ class Connection(object):
self.onopen()
def message_to_frames(self, message, fragment_size=None, mask=False):
for hook in self.hooks_send:
message = hook(message)
if fragment_size is None:
yield message.frame(mask=mask)
else:
for frame in message.fragment(fragment_size, mask=mask):
yield frame
def send(self, message, fragment_size=None, mask=False):
"""
Send a message. If `fragment_size` is specified, the message is
fragmented into multiple frames whose payload size does not extend
`fragment_size`.
"""
for hook in self.hooks_send:
message = hook(message)
for frame in self.message_to_frames(message, fragment_size, mask):
self.send_frame(frame)
if fragment_size is None:
self.sock.send(message.frame(mask=mask))
else:
self.sock.send(*message.fragment(fragment_size, mask=mask))
def send_frame(self, frame, callback=None):
self.sock.send(frame)
if callback:
callback()
def recv(self):
"""
......@@ -82,12 +92,15 @@ class Connection(object):
if isinstance(frame, ControlFrame):
self.handle_control_frame(frame)
elif len(fragments) and frame.opcode != OPCODE_CONTINUATION:
elif len(fragments) > 0 and frame.opcode != OPCODE_CONTINUATION:
raise ValueError('expected continuation/control frame, got %s '
'instead' % frame)
else:
fragments.append(frame)
return self.concat_fragments(fragments)
def concat_fragments(self, fragments):
payload = bytearray()
for f in fragments:
......@@ -105,16 +118,20 @@ class Connection(object):
Handle a control frame as defined by RFC 6455.
"""
if frame.opcode == OPCODE_CLOSE:
# Close the connection from this end as well
self.close_frame_received = True
code, reason = frame.unpack_close()
# No more receiving data after a close message
raise SocketClosed(code, reason)
if self.close_frame_sent:
self.onclose(code, reason)
self.sock.close()
raise SocketClosed(True)
else:
self.close_params = (code, reason)
self.send_close_frame(code, reason)
elif frame.opcode == OPCODE_PING:
# Respond with a pong message with identical payload
self.sock.send(ControlFrame(OPCODE_PONG, frame.payload))
self.send_frame(ControlFrame(OPCODE_PONG, frame.payload))
elif frame.opcode == OPCODE_PONG:
# Assert that the PONG payload is identical to that of the PING
......@@ -138,38 +155,40 @@ class Connection(object):
while True:
try:
self.onmessage(self.recv())
except SocketClosed as e:
self.close(e.code, e.reason)
except (KeyboardInterrupt, SystemExit, SocketClosed):
break
except socket.error as e:
except Exception as e:
self.onerror(e)
self.onclose(None, 'error: %s' % e)
try:
self.sock.close()
except socket.error:
pass
self.onclose(None, '')
break
except Exception as e:
self.onerror(e)
raise e
def send_ping(self, payload=''):
"""
Send a PING control frame with an optional payload.
"""
self.sock.send(ControlFrame(OPCODE_PING, payload))
self.send_frame(ControlFrame(OPCODE_PING, payload),
lambda: self.onping(payload))
self.ping_payload = payload
self.ping_sent = True
self.onping(payload)
def send_close_frame(self, code=None, reason=''):
"""
Send a CLOSE control frame.
"""
payload = '' if code is None else struct.pack('!H', code) + reason
self.sock.send(ControlFrame(OPCODE_CLOSE, payload))
def send_close_frame(self, code, reason):
self.send_frame(create_close_frame(code, reason))
self.close_frame_sent = True
self.shutdown_write()
def shutdown_write(self):
if self.close_frame_received:
self.onclose(*self.close_params)
self.sock.close()
raise SocketClosed(False)
else:
self.sock.shutdown(socket.SHUT_WR)
def close(self, code=None, reason=''):
"""
......@@ -179,28 +198,14 @@ class Connection(object):
called after the response has been received, but before the socket is
actually closed.
"""
# Send CLOSE frame
if not self.close_frame_sent:
self.send_close_frame(code, reason)
# Receive CLOSE frame
if not self.close_frame_received:
frame = self.sock.recv()
if frame.opcode != OPCODE_CLOSE:
raise ValueError('expected CLOSE frame, got %s' % frame)
self.close_frame_received = True
res_code, res_reason = frame.unpack_close()
self.send_close_frame(code, reason)
# FIXME: check if res_code == code and res_reason == reason?
frame = self.sock.recv()
# FIXME: alternatively, keep receiving frames in a loop until a
# CLOSE frame is received, so that a fragmented chain may arrive
# fully first
if frame.opcode != OPCODE_CLOSE:
raise ValueError('expected CLOSE frame, got %s' % frame)
self.onclose(code, reason)
self.sock.close()
self.handle_control_frame(frame)
def add_hook(self, send=None, recv=None, prepend=False):
"""
......
class SocketClosed(Exception):
def __init__(self, code=None, reason=''):
self.code = code
self.reason = reason
def __init__(self, initialized):
self.initialized = initialized
@property
def message(self):
return ('' if self.code is None else '[%d] ' % self.code) + self.reason
s = 'socket closed'
if self.initialized:
s += ' (initialized)'
return s
class HandshakeError(Exception):
......
......@@ -232,11 +232,12 @@ def receive_frame(sock):
def read_frame(data):
reader = BufferReader(data)
frame = decode_frame(reader)
return frame, len(data) - reader.offset
return frame, reader.offset
def pop_frame(data):
frame, l = read_frame(data)
frame, size = read_frame(data)
return frame, data[size:]
class BufferReader(object):
......@@ -279,7 +280,7 @@ def contains_frame(data):
if len(data) < 2:
return False
b2 = struct.unpack('!B', data[1])
b2 = struct.unpack('!B', data[1])[0]
payload_len = b2 & 0x7F
payload_start = 2
......@@ -316,3 +317,8 @@ def mask(key, original):
masked[i] ^= key[i % 4]
return masked
def create_close_frame(code, reason):
payload = '' if code is None else struct.pack('!H', code) + reason
return ControlFrame(OPCODE_CLOSE, payload)
......@@ -25,14 +25,14 @@ class Server(object):
>>> print 'Received message "%s"' % message.payload
>>> client.send(wspy.TextMessage(message.payload))
>>> def onclose(self, client):
>>> def onclose(self, client, code, reason):
>>> print 'Client %s disconnected' % client
>>> EchoServer(('', 8000)).run()
"""
def __init__(self, address, loglevel=logging.INFO, ssl_args=None,
max_join_time=2.0, **kwargs):
max_join_time=2.0, backlog_size=32, **kwargs):
"""
Constructor for a simple web socket server.
......@@ -53,6 +53,8 @@ class Server(object):
`max_join_time` is the maximum time (in seconds) to wait for client
responses after sending CLOSE frames, it defaults to 2 seconds.
`backlog_size` is directly passed to `websocket.listen`.
"""
logging.basicConfig(level=loglevel,
format='%(asctime)s: %(levelname)s: %(message)s',
......@@ -69,14 +71,14 @@ class Server(object):
self.sock.enable_ssl(server_side=True, **ssl_args)
self.sock.bind(address)
self.sock.listen(5)
self.clients = []
self.client_threads = []
self.sock.listen(backlog_size)
self.max_join_time = max_join_time
def run(self):
self.clients = []
self.client_threads = []
while True:
try:
sock, address = self.sock.accept()
......@@ -149,10 +151,7 @@ class Server(object):
msg = 'Closed socket to %s' % client
if code is not None:
msg += ' [%d]' % code
if len(reason):
msg += ': ' + reason
msg += ': [%d] %s' % (code, reason)
logging.debug(msg)
......
#!/usr/bin/env python
import sys
import ssl
from os.path import abspath, dirname
basepath = abspath(dirname(abspath(__file__)) + '/..')
......@@ -20,7 +21,7 @@ class EchoClient(Connection):
def onmessage(self, msg):
print 'Received', msg
raise SocketClosed(None, 'response received')
self.close(None, 'response received')
def onerror(self, e):
print 'Error:', e
......@@ -29,8 +30,15 @@ class EchoClient(Connection):
print 'Connection closed'
secure = True
if __name__ == '__main__':
print 'Connecting to ws://%s:%d' % ADDR
scheme = 'wss' if secure else 'ws'
print 'Connecting to %s://%s' % (scheme, '%s:%d' % ADDR)
sock = websocket()
if secure:
sock.enable_ssl(ca_certs='cert.pem', cert_reqs=ssl.CERT_REQUIRED)
sock.connect(ADDR)
EchoClient(sock).receive_forever()
......@@ -20,5 +20,5 @@ if __name__ == '__main__':
deflate = WebkitDeflateFrame()
#deflate = WebkitDeflateFrame(defaults={'no_context_takeover': True})
EchoServer(('localhost', 8000), extensions=[deflate],
#ssl_args=dict(keyfile='cert.pem', certfile='cert.pem'),
ssl_args=dict(keyfile='cert.pem', certfile='cert.pem'),
loglevel=logging.DEBUG).run()
......@@ -11,12 +11,6 @@ INHERITED_ATTRS = ['bind', 'close', 'listen', 'fileno', 'getpeername',
'settimeout', 'gettimeout', 'shutdown', 'family', 'type',
'proto']
STATE_INIT = 0
STATE_READ = 1
STATE_WRITE = 2
STATE_CLOSE = 4
class websocket(object):
"""
Implementation of web socket, upgrades a regular TCP socket to a websocket
......@@ -43,7 +37,7 @@ class websocket(object):
"""
def __init__(self, sock=None, protocols=[], extensions=[], origin=None,
location='/', trusted_origins=[], locations=[], auth=None,
sfamily=socket.AF_INET, sproto=0):
recv_callback=None, sfamily=socket.AF_INET, sproto=0):
"""
Create a regular TCP socket of family `family` and protocol
......@@ -76,6 +70,12 @@ class websocket(object):
`auth` is optional, used for HTTP Basic or Digest authentication during
the handshake. It must be specified as a (username, password) tuple.
`recv_callback` is the callback for received frames in asynchronous
sockets. Use in conjunction with setblocking(0). The callback itself
may for example change the recv_callback attribute to change the
behaviour for the next received message. Can be set when calling
`queue_send`.
`sfamily` and `sproto` are used for the regular socket constructor.
"""
self.protocols = protocols
......@@ -93,10 +93,10 @@ class websocket(object):
self.hooks_send = []
self.hooks_recv = []
self.state = STATE_INIT
self.sendbuf_frames = []
self.sendbuf = ''
self.recvbuf = ''
self.recv_callbacks = []
self.recv_callback = recv_callback
self.sock = sock or socket.socket(sfamily, socket.SOCK_STREAM, sproto)
......@@ -163,61 +163,69 @@ class websocket(object):
"""
return [self.recv() for i in xrange(n)]
def queue_send(self, frame):
def queue_send(self, frame, callback=None, recv_callback=None):
"""
Enqueue `frame` to the send buffer so that it is send on the next
`do_async_send`.
`do_async_send`. `callback` is an optional callable to call when the
frame has been fully written. `recv_callback` is an optional callable
to quickly set the `recv_callback` attribute to.
"""
for hook in self.hooks_send:
frame = hook(frame)
self.sendbuf += frame.pack()
self.state |= STATE_WRITE
self.sendbuf_frames.append([frame, len(self.sendbuf), callback])
def queue_recv(self, callback):
"""
Enqueue `callback` to be called when the next frame is recieved by
`do_async_recv`.
"""
self.recv_callbacks.push(callback)
self.state |= STATE_READ
def queue_close(self):
self.state |= STATE_CLOSE
if recv_callback:
self.recv_callback = recv_callback
def do_async_send(self):
"""
Send any queued data. If all data is sent, STATE_WRITE is removed from
the state mask.
Send any queued data.
"""
assert self.state & STATE_WRITE
assert len(self.sendbuf)
nwritten = self.sock.send(self.sendbuf)
self.sendbuf = self.sendbuf[nwritten:]
nframes = 0
if len(self.sendbuf) == 0:
self.state ^= STATE_WRITE
for entry in self.sendbuf_frames:
frame, offset, callback = entry
if offset <= nwritten:
nframes += 1
if callback:
print 'write cb'
callback()
else:
entry[1] -= nwritten
self.sendbuf = self.sendbuf[nwritten:]
self.sendbuf_frames = self.sendbuf_frames[nframes:]
def do_async_recv(self, bufsize):
"""
"""
assert self.state & STATE_READ
data = self.sock.recv(bufsize)
if len(data) == 0:
raise socket.error('no data to receive')
self.recvbuf += self.sock.recv(bufsize)
self.recvbuf += data
while contains_frame(self.recvbuf):
frame, self.recvbuf = pop_frame(self.recvbuf)
if len(self.recv_callbacks) == 0:
raise IndexError('no callback installed for received frame %s'
% frame)
if not self.recv_callback:
raise ValueError('no callback installed for %s' % frame)
self.recv_callback(frame)
cb = self.recv_callbacks.pop(0)
cb(frame)
def can_send(self):
return len(self.sendbuf) > 0
if len(self.recvbuf) == 0:
self.state ^= STATE_READ
def can_recv(self):
return self.recv_callback is not None
def enable_ssl(self, *args, **kwargs):
"""
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment