Explorar o código

Added SSL support (wss://...), updated some docs

Taddeus Kroes %!s(int64=12) %!d(string=hai) anos
pai
achega
6d2cf25d8d
Modificáronse 7 ficheiros con 84 adicións e 23 borrados
  1. 1 0
      .gitignore
  2. 7 7
      __init__.py
  3. 4 0
      errors.py
  4. 33 6
      server.py
  5. 2 2
      test.html
  6. 3 1
      test.py
  7. 34 7
      websocket.py

+ 1 - 0
.gitignore

@@ -1,3 +1,4 @@
 *.swp
 *.pyc
 *~
+cert.pem

+ 7 - 7
__init__.py

@@ -1,9 +1,9 @@
 from websocket import websocket
 from server import Server
-from frame import Frame, ControlFrame
-from Connection import Connection
-from message import Message, TextMesage, BinaryMessage, JSONMessage
-
-
-__all__ = ['websocket', 'Server', 'Frame', 'ControlFrame', 'Connection',
-           'Message', 'TextMessage', 'BinaryMessage', 'JSONMessage']
+from frame import Frame, ControlFrame, OPCODE_CONTINUATION, OPCODE_TEXT, \
+        OPCODE_BINARY, OPCODE_CLOSE, OPCODE_PING, OPCODE_PONG, CLOSE_NORMAL, \
+        CLOSE_GOING_AWAY, CLOSE_PROTOCOL_ERROR, CLOSE_NOACCEPT_DTYPE, \
+        CLOSE_INVALID_DATA, CLOSE_POLICY, CLOSE_MESSAGE_TOOBIG, \
+        CLOSE_MISSING_EXTENSIONS, CLOSE_UNABLE
+from connection import Connection
+from message import Message, TextMessage, BinaryMessage, JSONMessage

+ 4 - 0
errors.py

@@ -14,3 +14,7 @@ class HandshakeError(Exception):
 
 class PingError(Exception):
     pass
+
+
+class SSLError(Exception):
+    pass

+ 33 - 6
server.py

@@ -2,6 +2,7 @@ import socket
 import logging
 from traceback import format_exc
 from threading import Thread
+from ssl import SSLError
 
 from websocket import websocket
 from connection import Connection
@@ -11,12 +12,12 @@ from errors import HandshakeError
 
 class Server(object):
     """
-    Websocket server object, used to manage multiple client connections.
-    Example usage:
+    Websocket server, manages multiple client connections.
 
-    >>> import websocket
+    Example usage:
+    >>> import twspy
 
-    >>> class GameServer(websocket.Server):
+    >>> class GameServer(twspy.Server):
     >>>     def onopen(self, client):
     >>>         # client connected
 
@@ -29,14 +30,38 @@ class Server(object):
     >>> GameServer(8000).run()
     """
 
-    def __init__(self, port, hostname='', loglevel=logging.INFO, protocols=[]):
+    def __init__(self, port, hostname='', loglevel=logging.INFO, protocols=[],
+                 secure=False, **kwargs):
+        """
+        Constructor for a simple websocket server.
+
+        `hostname` and `port` form the address to bind the websocket to.
+
+        `loglevel` values should be imported from the logging module.
+        logging.INFO only shows server start/stop messages, logging.DEBUG shows
+        clients (dis)connecting and messages being sent/received.
+
+        `protocols` is a list of supported protocols, passed directly to the
+        websocket constructor.
+
+        `secure` is a flag indicating whether the server is SSL enabled. In
+        this case, `keyfile` and `certfile` must be specified. Any additional
+        keyword arguments are passed to websocket.enable_ssl (and thus to
+        ssl.wrap_socket).
+        """
         logging.basicConfig(level=loglevel,
                 format='%(asctime)s: %(levelname)s: %(message)s',
                 datefmt='%H:%M:%S')
 
+        scheme = 'wss' if secure else 'ws'
+        logging.info('Starting server at %s://%s:%d', scheme, hostname, port)
+
         self.sock = websocket()
         self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
-        logging.info('Starting server at %s:%d', hostname, port)
+
+        if secure:
+            self.sock.enable_ssl(server_side=True, **kwargs)
+
         self.sock.bind((hostname, port))
         self.sock.listen(5)
 
@@ -55,6 +80,8 @@ class Server(object):
                 thread = Thread(target=client.receive_forever)
                 thread.daemon = True
                 thread.start()
+            except SSLError as e:
+                logging.error('SSL error: %s', e)
             except HandshakeError as e:
                 logging.error('Invalid request: %s', e.message)
             except KeyboardInterrupt:

+ 2 - 2
test.html

@@ -10,9 +10,9 @@
                 document.getElementById('log').innerHTML += line + '\n';
             }
 
-            var URL = 'localhost:8000';
+            var URL = 'ws://localhost:8000';
             log('Connecting to ' + URL);
-            var ws = new WebSocket('ws://' + URL);
+            var ws = new WebSocket(URL);
 
             ws.onopen = function() {
                 log('Connection complete, sending "foo"');

+ 3 - 1
test.py

@@ -11,4 +11,6 @@ class EchoServer(Server):
 
 
 if __name__ == '__main__':
-    EchoServer(8000, 'localhost', loglevel=logging.DEBUG).run()
+    EchoServer(8000, 'localhost',
+               #secure=True, keyfile='cert.pem', certfile='cert.pem',
+               loglevel=logging.DEBUG).run()

+ 34 - 7
websocket.py

@@ -1,10 +1,11 @@
 import re
 import socket
+import ssl
 from hashlib import sha1
 from base64 import b64encode
 
 from frame import receive_frame
-from errors import HandshakeError
+from errors import HandshakeError, SSLError
 
 
 WS_GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
@@ -19,19 +20,25 @@ class websocket(object):
     """
     Implementation of web socket, upgrades a regular TCP socket to a websocket
     using the HTTP handshakes and frame (un)packing, as specified by RFC 6455.
+    The API of a websocket is identical to that of a regular socket, as
+    illustrated by the examples below.
 
     Server example:
-    >>> sock = websocket()
+    >>> import twspy, socket
+    >>> sock = twspy.websocket()
+    >>> sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
     >>> sock.bind(('', 8000))
     >>> sock.listen()
 
     >>> client = sock.accept()
-    >>> client.send(Frame(...))
+    >>> client.send(twspy.Frame(twspy.OPCODE_TEXT, 'Hello, Client!'))
     >>> frame = client.recv()
 
     Client example:
-    >>> sock = websocket()
+    >>> import twspy
+    >>> sock = twspy.websocket()
     >>> sock.connect(('', 8000))
+    >>> sock.send(twspy.Frame(twspy.OPCODE_TEXT, 'Hello, Server!'))
     """
     def __init__(self, sock=None, protocols=[], extensions=[], sfamily=socket.AF_INET,
                  sproto=0):
@@ -50,6 +57,8 @@ class websocket(object):
         self.protocols = protocols
         self.extensions = extensions
         self.sock = sock or socket.socket(sfamily, socket.SOCK_STREAM, sproto)
+        self.secure = False
+        self.handshake_started = False
 
     def bind(self, address):
         self.sock.bind(address)
@@ -122,9 +131,13 @@ class websocket(object):
         request headers sent by the client are invalid, a HandshakeError
         is raised.
         """
-        raw_headers = self.sock.recv(512).decode('utf-8', 'ignore')
+        # Receive HTTP header
+        raw_headers = ''
 
-        # request must be HTTP (at least 1.1) GET request, find the location
+        while raw_headers[-4:] not in ('\r\n\r\n', '\n\n'):
+            raw_headers += self.sock.recv(512).decode('utf-8', 'ignore')
+
+        # Request must be HTTP (at least 1.1) GET request, find the location
         location = re.search(r'^GET (.*) HTTP/1.1\r\n', raw_headers).group(1)
         headers = re.findall(r'(.*?): ?(.*?)\r\n', raw_headers)
         header_names = [name for name, value in headers]
@@ -175,6 +188,7 @@ class websocket(object):
             shake += 'Sec-WebSocket-Extensions: %s\r\n' % ', '.join(extensions)
 
         self.sock.sendall(shake + '\r\n')
+        self.handshake_started = True
 
     def client_handshake(self):
         """
@@ -182,4 +196,17 @@ class websocket(object):
         HandshakeError if the server response is invalid.
         """
         # TODO: implement HTTP request headers for client handshake
-        raise NotImplementedError()
+        self.handshake_started = True
+        raise NotImplementedError
+
+    def enable_ssl(self, *args, **kwargs):
+        """
+        Transform the regular socket.socket to an ssl.SSLSocket for secure
+        connections. Any arguments are passed to ssl.wrap_socket:
+        http://docs.python.org/dev/library/ssl.html#ssl.wrap_socket
+        """
+        if self.handshake_started:
+            raise SSLError('can only enable SSL before handshake')
+
+        self.secure = True
+        self.sock = ssl.wrap_socket(self.sock, *args, **kwargs)