Commit fa6f57d6 authored by Taddeüs Kroes's avatar Taddeüs Kroes

Started testing and debugging:

- Renamed exceptions.py to errors.py to avoid name collision with stdlib
- Fixed some variable name collisions
- Worked on improving close handshake
- More minor fixes
parent 6807fc3f
...@@ -3,7 +3,7 @@ import struct ...@@ -3,7 +3,7 @@ import struct
from frame import ControlFrame, OPCODE_CLOSE, OPCODE_PING, OPCODE_PONG, \ from frame import ControlFrame, OPCODE_CLOSE, OPCODE_PING, OPCODE_PONG, \
OPCODE_CONTINUATION OPCODE_CONTINUATION
from message import create_message from message import create_message
from exceptions import SocketClosed, PingError from errors import SocketClosed, PingError
class Connection(object): class Connection(object):
...@@ -21,8 +21,8 @@ class Connection(object): ...@@ -21,8 +21,8 @@ class Connection(object):
""" """
self.sock = sock self.sock = sock
self.received_close_params = None
self.close_frame_sent = False self.close_frame_sent = False
self.close_frame_received = False
self.ping_sent = False self.ping_sent = False
self.ping_payload = None self.ping_payload = None
...@@ -54,17 +54,17 @@ class Connection(object): ...@@ -54,17 +54,17 @@ class Connection(object):
if isinstance(frame, ControlFrame): if isinstance(frame, ControlFrame):
self.handle_control_frame(frame) self.handle_control_frame(frame)
# No more receiving data after a close message
if frame.opcode == OPCODE_CLOSE:
break
elif len(fragments) and frame.opcode != OPCODE_CONTINUATION: elif len(fragments) and frame.opcode != OPCODE_CONTINUATION:
raise ValueError('expected continuation/control frame, got %s ' raise ValueError('expected continuation/control frame, got %s '
'instead' % frame) 'instead' % frame)
else: else:
fragments.append(frame) fragments.append(frame)
payload = ''.join(f.payload for f in fragments) payload = bytearray()
for f in fragments:
payload += f.payload
return create_message(fragments[0].opcode, payload) return create_message(fragments[0].opcode, payload)
def handle_control_frame(self, frame): def handle_control_frame(self, frame):
...@@ -72,10 +72,20 @@ class Connection(object): ...@@ -72,10 +72,20 @@ class Connection(object):
Handle a control frame as defined by RFC 6455. Handle a control frame as defined by RFC 6455.
""" """
if frame.opcode == OPCODE_CLOSE: if frame.opcode == OPCODE_CLOSE:
# Set parameters and keep receiving the current fragmented frame # Handle a close message by sending a response close message if no
# chain, assuming that the CLOSE frame will be handled by # CLOSE frame was sent before, and closing the connection. The
# handle_close() as soon as possible # onclose() handler is called afterwards.
self.received_close_params = frame.unpack_close() self.close_frame_received = True
code, reason = frame.unpack_close()
if not self.close_frame_sent:
payload = '' if code is None else struct.pack('!H', code)
self.sock.send(ControlFrame(OPCODE_CLOSE, payload))
self.sock.close()
# No more receiving data after a close message
raise SocketClosed(code, reason)
elif frame.opcode == OPCODE_PING: elif frame.opcode == OPCODE_PING:
# Respond with a pong message with identical payload # Respond with a pong message with identical payload
...@@ -103,12 +113,8 @@ class Connection(object): ...@@ -103,12 +113,8 @@ class Connection(object):
while True: while True:
try: try:
self.onmessage(self.receive()) self.onmessage(self.receive())
except SocketClosed as e:
if self.received_close_params is not None: self.onclose(e.code, e.reason)
self.handle_close(*self.received_close_params)
break
except SocketClosed:
self.onclose(None, '')
break break
except Exception as e: except Exception as e:
self.onerror(e) self.onerror(e)
...@@ -150,14 +156,24 @@ class Connection(object): ...@@ -150,14 +156,24 @@ class Connection(object):
has been sent, but before the response has been received. has been sent, but before the response has been received.
""" """
self.send_close(code, reason) self.send_close(code, reason)
# FIXME: swap the two lines below?
self.onclose(code, reason) if not self.close_frame_received:
frame = self.sock.recv() frame = self.sock.recv()
self.sock.close()
if frame.opcode != OPCODE_CLOSE: if frame.opcode != OPCODE_CLOSE:
raise ValueError('expected CLOSE frame, got %s instead' % frame) raise ValueError('expected CLOSE frame, got %s instead' % frame)
res_code, res_reason = frame.unpack_close()
# FIXME: check if res_code == code and res_reason == reason?
# FIXME: alternatively, keep receiving frames in a loop until a
# CLOSE frame is received, so that a fragmented chain may arrive
# fully first
self.sock.close()
self.onclose(code, reason)
def onopen(self): def onopen(self):
""" """
Called after the connection is initialized. Called after the connection is initialized.
......
class SocketClosed(Exception): class SocketClosed(Exception):
pass def __init__(self, code=None, reason=''):
self.code = code
self.reason = reason
@property
def message(self):
return ('' if self.code is None else '[%d] ' % self.code) + self.reason
class HandshakeError(Exception): class HandshakeError(Exception):
......
import struct import struct
from os import urandom from os import urandom
from curses.ascii import isprint
from exceptions import SocketClosed from errors import SocketClosed
OPCODE_CONTINUATION = 0x0 OPCODE_CONTINUATION = 0x0
...@@ -82,7 +83,7 @@ class Frame(object): ...@@ -82,7 +83,7 @@ class Frame(object):
| Payload Data continued ... | | Payload Data continued ... |
+---------------------------------------------------------------+ +---------------------------------------------------------------+
""" """
header = struct.pack('!B', (self.fin << 7) | (self.rsv1 << 6) | header = struct.pack('!B', (self.final << 7) | (self.rsv1 << 6) |
(self.rsv2 << 5) | (self.rsv3 << 4) | self.opcode) (self.rsv2 << 5) | (self.rsv3 << 4) | self.opcode)
mask = bool(self.masking_key) << 7 mask = bool(self.masking_key) << 7
...@@ -142,7 +143,8 @@ class Frame(object): ...@@ -142,7 +143,8 @@ class Frame(object):
% (self.__class__.__name__, self.opcode, len(self.payload)) % (self.__class__.__name__, self.opcode, len(self.payload))
if self.masking_key: if self.masking_key:
s += ' masking_key=%4s' % self.masking_key key = ''.join(c if isprint(c) else '.' for c in self.masking_key)
s += ' masking_key=%4s' % key
return s + '>' return s + '>'
...@@ -197,7 +199,7 @@ def receive_frame(sock): ...@@ -197,7 +199,7 @@ def receive_frame(sock):
rsv3 = bool(b1 & 0x10) rsv3 = bool(b1 & 0x10)
opcode = b1 & 0x0F opcode = b1 & 0x0F
mask = bool(b2 & 0x80) masked = bool(b2 & 0x80)
payload_len = b2 & 0x7F payload_len = b2 & 0x7F
if payload_len == 126: if payload_len == 126:
...@@ -205,7 +207,7 @@ def receive_frame(sock): ...@@ -205,7 +207,7 @@ def receive_frame(sock):
elif payload_len == 127: elif payload_len == 127:
payload_len = struct.unpack('!Q', recvn(sock, 8)) payload_len = struct.unpack('!Q', recvn(sock, 8))
if mask: if masked:
masking_key = recvn(sock, 4) masking_key = recvn(sock, 4)
payload = mask(masking_key, recvn(sock, payload_len)) payload = mask(masking_key, recvn(sock, payload_len))
else: else:
...@@ -229,7 +231,7 @@ def recvn(sock, n): ...@@ -229,7 +231,7 @@ def recvn(sock, n):
received = sock.recv(n - len(data)) received = sock.recv(n - len(data))
if not len(received): if not len(received):
raise SocketClosed() raise SocketClosed(None, 'no data read from socket')
data += received data += received
......
...@@ -24,7 +24,15 @@ class Message(object): ...@@ -24,7 +24,15 @@ class Message(object):
class TextMessage(Message): class TextMessage(Message):
def __init__(self, payload): def __init__(self, payload):
super(TextMessage, self).__init__(OPCODE_TEXT, payload.encode('utf-8')) text = str(payload).encode('utf-8')
super(TextMessage, self).__init__(OPCODE_TEXT, text)
def __str__(self):
if len(self.payload) > 30:
return '<TextMessage "%s"... size=%d>' \
% (self.payload[:30], len(self.payload))
return '<TextMessage "%s" size=%d>' % (self.payload, len(self.payload))
class BinaryMessage(Message): class BinaryMessage(Message):
......
...@@ -6,7 +6,7 @@ from threading import Thread ...@@ -6,7 +6,7 @@ from threading import Thread
from websocket import websocket from websocket import websocket
from connection import Connection from connection import Connection
from frame import CLOSE_NORMAL from frame import CLOSE_NORMAL
from exceptions import HandshakeError from errors import HandshakeError
class Server(object): class Server(object):
...@@ -48,9 +48,9 @@ class Server(object): ...@@ -48,9 +48,9 @@ class Server(object):
try: try:
sock, address = self.sock.accept() sock, address = self.sock.accept()
client = Client(self, sock, address) client = Client(self, sock)
self.clients.append(client) self.clients.append(client)
logging.info('Registered client %s', client) logging.debug('Registered client %s', client)
thread = Thread(target=client.receive_forever) thread = Thread(target=client.receive_forever)
thread.daemon = True thread.daemon = True
...@@ -102,8 +102,15 @@ class Server(object): ...@@ -102,8 +102,15 @@ class Server(object):
class Client(Connection): class Client(Connection):
def __init__(self, server, sock): def __init__(self, server, sock):
super(Client, self).__init__(sock)
self.server = server self.server = server
super(Client, self).__init__(sock)
def __str__(self):
return '<Client at %s:%d>' % self.sock.getpeername()
def send(self, message, fragment_size=None, mask=False):
logging.debug('Sending %s to %s', message, self)
Connection.send(self, message, fragment_size=fragment_size, mask=mask)
def onopen(self): def onopen(self):
self.server.onopen(self) self.server.onopen(self)
...@@ -123,9 +130,6 @@ class Client(Connection): ...@@ -123,9 +130,6 @@ class Client(Connection):
def onerror(self, e): def onerror(self, e):
self.server.onerror(self, e) self.server.onerror(self, e)
def __str__(self):
return '<Client at %s:%d>' % self.sock.getpeername()
if __name__ == '__main__': if __name__ == '__main__':
import sys import sys
......
<!doctype html>
<html>
<head>
<title>twspy echo test client</title>
</head>
<body>
<textarea id="log" rows="20" cols="80" readonly="readonly"></textarea>
<script type="text/javascript">
function log(line) {
document.getElementById('log').innerHTML += line + '\n';
}
var URL = 'localhost:8000';
log('Connecting to ' + URL);
var ws = new WebSocket('ws://' + URL);
ws.onopen = function() {
log('Connection complete, sending "foo"');
ws.send('foo');
};
ws.onmessage = function(msg) {
log('Received "' + msg + '", closing connection');
ws.close();
};
ws.onerror = function(e) {
log('Error', e);
};
ws.onclose = function() {
log('Connection closed');
};
</script>
</body>
</html>
#!/usr/bin/env python #!/usr/bin/env python
import logging
from server import Server
class EchoServer(Server):
def onmessage(self, client, message):
Server.onmessage(self, client, message)
client.send(message)
if __name__ == '__main__':
EchoServer(8000, 'localhost', loglevel=logging.DEBUG).run()
...@@ -3,7 +3,7 @@ import socket ...@@ -3,7 +3,7 @@ import socket
from hashlib import sha1 from hashlib import sha1
from frame import receive_frame from frame import receive_frame
from exceptions import HandshakeError from errors import HandshakeError
WS_GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11' WS_GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
...@@ -11,7 +11,7 @@ WS_VERSION = '13' ...@@ -11,7 +11,7 @@ WS_VERSION = '13'
def split_stripped(value, delim=','): def split_stripped(value, delim=','):
return map(str.strip, value.split(delim)) return map(str.strip, str(value).split(delim))
class websocket(object): class websocket(object):
...@@ -32,11 +32,14 @@ class websocket(object): ...@@ -32,11 +32,14 @@ class websocket(object):
>>> sock = websocket() >>> sock = websocket()
>>> sock.connect(('', 8000)) >>> sock.connect(('', 8000))
""" """
def __init__(self, protocols=[], extensions=[], sfamily=socket.AF_INET, def __init__(self, sock=None, protocols=[], extensions=[], sfamily=socket.AF_INET,
sproto=0): sproto=0):
""" """
Create a regular TCP socket of family `family` and protocol Create a regular TCP socket of family `family` and protocol
`sock` is an optional regular TCP socket to be used for sending binary
data. If not specified, a new socket is created.
`protocols` is a list of supported protocol names. `protocols` is a list of supported protocol names.
`extensions` is a list of supported extensions. `extensions` is a list of supported extensions.
...@@ -45,7 +48,7 @@ class websocket(object): ...@@ -45,7 +48,7 @@ class websocket(object):
""" """
self.protocols = protocols self.protocols = protocols
self.extensions = extensions self.extensions = extensions
self.sock = socket.socket(sfamily, socket.SOCK_STREAM, sproto) self.sock = sock or socket.socket(sfamily, socket.SOCK_STREAM, sproto)
def bind(self, address): def bind(self, address):
self.sock.bind(address) self.sock.bind(address)
...@@ -78,14 +81,24 @@ class websocket(object): ...@@ -78,14 +81,24 @@ class websocket(object):
Send a number of frames. Send a number of frames.
""" """
for frame in args: for frame in args:
print 'send frame:', frame, 'to %s:%d' % self.sock.getpeername(), frame.payload
self.sock.sendall(frame.pack()) self.sock.sendall(frame.pack())
def recv(self, n=1): def recv(self):
"""
Receive a single frames. This can be either a data frame or a control
frame.
"""
frame = receive_frame(self.sock)
print 'receive frame:', frame, 'from %s:%d' % self.sock.getpeername()
return frame
def recvn(self, n):
""" """
Receive exactly `n` frames. These can be either data frames or control Receive exactly `n` frames. These can be either data frames or control
frames, or a combination of both. frames, or a combination of both.
""" """
return [receive_frame(self.sock) for i in xrange(n)] return [self.recv() for i in xrange(n)]
def getpeername(self): def getpeername(self):
return self.sock.getpeername() return self.sock.getpeername()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment