Commit 6d2cf25d authored by Taddeüs Kroes's avatar Taddeüs Kroes

Added SSL support (wss://...), updated some docs

parent a942e6bb
*.swp *.swp
*.pyc *.pyc
*~ *~
cert.pem
from websocket import websocket from websocket import websocket
from server import Server from server import Server
from frame import Frame, ControlFrame from frame import Frame, ControlFrame, OPCODE_CONTINUATION, OPCODE_TEXT, \
from Connection import Connection OPCODE_BINARY, OPCODE_CLOSE, OPCODE_PING, OPCODE_PONG, CLOSE_NORMAL, \
from message import Message, TextMesage, BinaryMessage, JSONMessage CLOSE_GOING_AWAY, CLOSE_PROTOCOL_ERROR, CLOSE_NOACCEPT_DTYPE, \
CLOSE_INVALID_DATA, CLOSE_POLICY, CLOSE_MESSAGE_TOOBIG, \
CLOSE_MISSING_EXTENSIONS, CLOSE_UNABLE
__all__ = ['websocket', 'Server', 'Frame', 'ControlFrame', 'Connection', from connection import Connection
'Message', 'TextMessage', 'BinaryMessage', 'JSONMessage'] from message import Message, TextMessage, BinaryMessage, JSONMessage
...@@ -14,3 +14,7 @@ class HandshakeError(Exception): ...@@ -14,3 +14,7 @@ class HandshakeError(Exception):
class PingError(Exception): class PingError(Exception):
pass pass
class SSLError(Exception):
pass
...@@ -2,6 +2,7 @@ import socket ...@@ -2,6 +2,7 @@ import socket
import logging import logging
from traceback import format_exc from traceback import format_exc
from threading import Thread from threading import Thread
from ssl import SSLError
from websocket import websocket from websocket import websocket
from connection import Connection from connection import Connection
...@@ -11,12 +12,12 @@ from errors import HandshakeError ...@@ -11,12 +12,12 @@ from errors import HandshakeError
class Server(object): class Server(object):
""" """
Websocket server object, used to manage multiple client connections. Websocket server, manages multiple client connections.
Example usage:
>>> import websocket Example usage:
>>> import twspy
>>> class GameServer(websocket.Server): >>> class GameServer(twspy.Server):
>>> def onopen(self, client): >>> def onopen(self, client):
>>> # client connected >>> # client connected
...@@ -29,14 +30,38 @@ class Server(object): ...@@ -29,14 +30,38 @@ class Server(object):
>>> GameServer(8000).run() >>> GameServer(8000).run()
""" """
def __init__(self, port, hostname='', loglevel=logging.INFO, protocols=[]): def __init__(self, port, hostname='', loglevel=logging.INFO, protocols=[],
secure=False, **kwargs):
"""
Constructor for a simple websocket server.
`hostname` and `port` form the address to bind the websocket to.
`loglevel` values should be imported from the logging module.
logging.INFO only shows server start/stop messages, logging.DEBUG shows
clients (dis)connecting and messages being sent/received.
`protocols` is a list of supported protocols, passed directly to the
websocket constructor.
`secure` is a flag indicating whether the server is SSL enabled. In
this case, `keyfile` and `certfile` must be specified. Any additional
keyword arguments are passed to websocket.enable_ssl (and thus to
ssl.wrap_socket).
"""
logging.basicConfig(level=loglevel, logging.basicConfig(level=loglevel,
format='%(asctime)s: %(levelname)s: %(message)s', format='%(asctime)s: %(levelname)s: %(message)s',
datefmt='%H:%M:%S') datefmt='%H:%M:%S')
scheme = 'wss' if secure else 'ws'
logging.info('Starting server at %s://%s:%d', scheme, hostname, port)
self.sock = websocket() self.sock = websocket()
self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
logging.info('Starting server at %s:%d', hostname, port)
if secure:
self.sock.enable_ssl(server_side=True, **kwargs)
self.sock.bind((hostname, port)) self.sock.bind((hostname, port))
self.sock.listen(5) self.sock.listen(5)
...@@ -55,6 +80,8 @@ class Server(object): ...@@ -55,6 +80,8 @@ class Server(object):
thread = Thread(target=client.receive_forever) thread = Thread(target=client.receive_forever)
thread.daemon = True thread.daemon = True
thread.start() thread.start()
except SSLError as e:
logging.error('SSL error: %s', e)
except HandshakeError as e: except HandshakeError as e:
logging.error('Invalid request: %s', e.message) logging.error('Invalid request: %s', e.message)
except KeyboardInterrupt: except KeyboardInterrupt:
......
...@@ -10,9 +10,9 @@ ...@@ -10,9 +10,9 @@
document.getElementById('log').innerHTML += line + '\n'; document.getElementById('log').innerHTML += line + '\n';
} }
var URL = 'localhost:8000'; var URL = 'ws://localhost:8000';
log('Connecting to ' + URL); log('Connecting to ' + URL);
var ws = new WebSocket('ws://' + URL); var ws = new WebSocket(URL);
ws.onopen = function() { ws.onopen = function() {
log('Connection complete, sending "foo"'); log('Connection complete, sending "foo"');
......
...@@ -11,4 +11,6 @@ class EchoServer(Server): ...@@ -11,4 +11,6 @@ class EchoServer(Server):
if __name__ == '__main__': if __name__ == '__main__':
EchoServer(8000, 'localhost', loglevel=logging.DEBUG).run() EchoServer(8000, 'localhost',
#secure=True, keyfile='cert.pem', certfile='cert.pem',
loglevel=logging.DEBUG).run()
import re import re
import socket import socket
import ssl
from hashlib import sha1 from hashlib import sha1
from base64 import b64encode from base64 import b64encode
from frame import receive_frame from frame import receive_frame
from errors import HandshakeError from errors import HandshakeError, SSLError
WS_GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11' WS_GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
...@@ -19,19 +20,25 @@ class websocket(object): ...@@ -19,19 +20,25 @@ class websocket(object):
""" """
Implementation of web socket, upgrades a regular TCP socket to a websocket Implementation of web socket, upgrades a regular TCP socket to a websocket
using the HTTP handshakes and frame (un)packing, as specified by RFC 6455. using the HTTP handshakes and frame (un)packing, as specified by RFC 6455.
The API of a websocket is identical to that of a regular socket, as
illustrated by the examples below.
Server example: Server example:
>>> sock = websocket() >>> import twspy, socket
>>> sock = twspy.websocket()
>>> sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
>>> sock.bind(('', 8000)) >>> sock.bind(('', 8000))
>>> sock.listen() >>> sock.listen()
>>> client = sock.accept() >>> client = sock.accept()
>>> client.send(Frame(...)) >>> client.send(twspy.Frame(twspy.OPCODE_TEXT, 'Hello, Client!'))
>>> frame = client.recv() >>> frame = client.recv()
Client example: Client example:
>>> sock = websocket() >>> import twspy
>>> sock = twspy.websocket()
>>> sock.connect(('', 8000)) >>> sock.connect(('', 8000))
>>> sock.send(twspy.Frame(twspy.OPCODE_TEXT, 'Hello, Server!'))
""" """
def __init__(self, sock=None, protocols=[], extensions=[], sfamily=socket.AF_INET, def __init__(self, sock=None, protocols=[], extensions=[], sfamily=socket.AF_INET,
sproto=0): sproto=0):
...@@ -50,6 +57,8 @@ class websocket(object): ...@@ -50,6 +57,8 @@ class websocket(object):
self.protocols = protocols self.protocols = protocols
self.extensions = extensions self.extensions = extensions
self.sock = sock or socket.socket(sfamily, socket.SOCK_STREAM, sproto) self.sock = sock or socket.socket(sfamily, socket.SOCK_STREAM, sproto)
self.secure = False
self.handshake_started = False
def bind(self, address): def bind(self, address):
self.sock.bind(address) self.sock.bind(address)
...@@ -122,9 +131,13 @@ class websocket(object): ...@@ -122,9 +131,13 @@ class websocket(object):
request headers sent by the client are invalid, a HandshakeError request headers sent by the client are invalid, a HandshakeError
is raised. is raised.
""" """
raw_headers = self.sock.recv(512).decode('utf-8', 'ignore') # Receive HTTP header
raw_headers = ''
# request must be HTTP (at least 1.1) GET request, find the location while raw_headers[-4:] not in ('\r\n\r\n', '\n\n'):
raw_headers += self.sock.recv(512).decode('utf-8', 'ignore')
# Request must be HTTP (at least 1.1) GET request, find the location
location = re.search(r'^GET (.*) HTTP/1.1\r\n', raw_headers).group(1) location = re.search(r'^GET (.*) HTTP/1.1\r\n', raw_headers).group(1)
headers = re.findall(r'(.*?): ?(.*?)\r\n', raw_headers) headers = re.findall(r'(.*?): ?(.*?)\r\n', raw_headers)
header_names = [name for name, value in headers] header_names = [name for name, value in headers]
...@@ -175,6 +188,7 @@ class websocket(object): ...@@ -175,6 +188,7 @@ class websocket(object):
shake += 'Sec-WebSocket-Extensions: %s\r\n' % ', '.join(extensions) shake += 'Sec-WebSocket-Extensions: %s\r\n' % ', '.join(extensions)
self.sock.sendall(shake + '\r\n') self.sock.sendall(shake + '\r\n')
self.handshake_started = True
def client_handshake(self): def client_handshake(self):
""" """
...@@ -182,4 +196,17 @@ class websocket(object): ...@@ -182,4 +196,17 @@ class websocket(object):
HandshakeError if the server response is invalid. HandshakeError if the server response is invalid.
""" """
# TODO: implement HTTP request headers for client handshake # TODO: implement HTTP request headers for client handshake
raise NotImplementedError() self.handshake_started = True
raise NotImplementedError
def enable_ssl(self, *args, **kwargs):
"""
Transform the regular socket.socket to an ssl.SSLSocket for secure
connections. Any arguments are passed to ssl.wrap_socket:
http://docs.python.org/dev/library/ssl.html#ssl.wrap_socket
"""
if self.handshake_started:
raise SSLError('can only enable SSL before handshake')
self.secure = True
self.sock = ssl.wrap_socket(self.sock, *args, **kwargs)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment