Commit a992f352 authored by Taddeüs Kroes's avatar Taddeüs Kroes

Imprroved handshake in terms of execptions, some minor bugfixes too

parent f094886a
......@@ -2,5 +2,9 @@ class SocketClosed(Exception):
pass
class InvalidRequest(ValueError):
pass
class PingError(Exception):
pass
......@@ -3,11 +3,11 @@ import logging
from traceback import format_exc
from websocket import WebSocket
from exceptions import InvalidRequest
class Server(object):
def __init__(self, port, address='', log_level=logging.INFO, protocols=[],
encoding=None):
def __init__(self, port, address='', log_level=logging.INFO, protocols=[]):
logging.basicConfig(level=log_level,
format='%(asctime)s: %(levelname)s: %(message)s',
datefmt='%H:%M:%S')
......@@ -20,17 +20,18 @@ class Server(object):
self.clients = []
self.protocols = protocols
self.encoding = encoding
def run(self):
while True:
try:
client_socket, address = self.sock.accept()
client = Client(self, client_socket, address)
client.send_handshake()
client.handshake()
self.clients.append(client)
logging.info('Registered client %s', client)
client.run_threaded()
except InvalidRequest as e:
logging.error('Invalid request: %s', e.message)
except KeyboardInterrupt:
logging.info('Received interrupt, stopping server...')
break
......@@ -78,4 +79,4 @@ class Client(WebSocket):
if __name__ == '__main__':
import sys
port = int(sys.argv[1]) if len(sys.argv) > 1 else 80
Server(port=port, log_level=logging.DEBUG, encoding='utf-8').run()
Server(port=port, log_level=logging.DEBUG).run()
......@@ -6,7 +6,7 @@ from threading import Thread
from frame import ControlFrame, receive_fragments, receive_frame, \
OPCODE_CLOSE, OPCODE_PING, OPCODE_PONG
from message import create_message
from exceptions import SocketClosed, PingError
from exceptions import InvalidRequest, SocketClosed, PingError
WS_GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
......@@ -14,9 +14,8 @@ WS_VERSION = '13'
class WebSocket(object):
def __init__(self, sock, encoding=None):
def __init__(self, sock):
self.sock = sock
self.encoding = encoding
self.received_close_params = None
self.close_frame_sent = False
......@@ -57,40 +56,53 @@ class WebSocket(object):
payload = ''.join([f.payload for f in frames])
return create_message(frames[0].opcode, payload)
def send_handshake(self):
raw_headers = self.sock.recv(512)
if self.encoding:
raw_headers = raw_headers.decode(self.encoding, 'ignore')
def handshake(self):
"""
Execute a handshake with the other end point of the socket. If the HTTP
request headers read from the socket are invalid, an InvalidRequest
exception will be raised.
"""
raw_headers = self.sock.recv(512).decode('utf-8', 'ignore')
# request must be HTTP (at least 1.1) GET request, find the location
location = re.search(r'^GET (.*) HTTP/1.1\r\n', raw_headers).group(1)
headers = dict(re.findall(r'(.*?): (.*?)\r\n', raw_headers))
# Check if headers that MUST be present are actually present
for name in ('Host', 'Upgrade', 'Connection', 'Sec-WebSocket-Key',
'Origin', 'Sec-WebSocket-Version'):
assert name in headers
if name not in headers:
raise InvalidRequest('missing "%s" header' % name)
# Check WebSocket version used by client
assert headers['Sec-WebSocket-Version'] == WS_VERSION
version = headers['Sec-WebSocket-Version']
if version != WS_VERSION:
raise InvalidRequest('WebSocket version %s requested (only %s '
'is supported)' % (version, WS_VERSION))
# Make sure the requested protocols are supported by this server
if 'Sec-WebSocket-Protocol' in headers:
parts = headers['Sec-WebSocket-Protocol'].split(',')
protocols = map(str.strip, parts)
for protocol in protocols:
assert protocol in self.protocols
for p in protocols:
if p not in self.protocols:
raise InvalidRequest('unsupported protocol "%s"' % p)
else:
protocols = []
# Encode acceptation key using the WebSocket GUID
key = headers['Sec-WebSocket-Key']
accept = sha1(key + WS_GUID).digest().encode('base64')
# Construct HTTP response header
shake = 'HTTP/1.1 101 Web Socket Protocol Handshake\r\n'
shake += 'Upgrade: WebSocket\r\n'
shake += 'Connection: Upgrade\r\n'
shake += 'WebSocket-Origin: %s\r\n' % headers['Origin']
shake += 'WebSocket-Location: ws://%s%s\r\n' % (headers['Host'], location)
shake += 'WebSocket-Location: ws://%s%s\r\n' \
% (headers['Host'], location)
shake += 'Sec-WebSocket-Accept: %s\r\n' % accept
if self.protocols:
......@@ -170,7 +182,7 @@ class WebSocket(object):
Called when a message is received. `message' is a Message object, which
can be constructed from a single frame or multiple fragmented frames.
"""
raise NotImplemented
return NotImplemented
def onping(self, payload):
"""
......@@ -178,13 +190,13 @@ class WebSocket(object):
used to start a timeout handler for a pong message that is not received
in time.
"""
raise NotImplemented
pass
def onpong(self, payload):
"""
Called when a pong control frame is received.
"""
raise NotImplemented
pass
def onclose(self, code, reason):
"""
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment