Selaa lähdekoodia

Imprroved handshake in terms of execptions, some minor bugfixes too

Taddeus Kroes 13 vuotta sitten
vanhempi
sitoutus
a992f35289
3 muutettua tiedostoa jossa 38 lisäystä ja 21 poistoa
  1. 4 0
      exceptions.py
  2. 6 5
      server.py
  3. 28 16
      websocket.py

+ 4 - 0
exceptions.py

@@ -2,5 +2,9 @@ class SocketClosed(Exception):
     pass
 
 
+class InvalidRequest(ValueError):
+    pass
+
+
 class PingError(Exception):
     pass

+ 6 - 5
server.py

@@ -3,11 +3,11 @@ import logging
 from traceback import format_exc
 
 from websocket import WebSocket
+from exceptions import InvalidRequest
 
 
 class Server(object):
-    def __init__(self, port, address='', log_level=logging.INFO, protocols=[],
-            encoding=None):
+    def __init__(self, port, address='', log_level=logging.INFO, protocols=[]):
         logging.basicConfig(level=log_level,
                 format='%(asctime)s: %(levelname)s: %(message)s',
                 datefmt='%H:%M:%S')
@@ -20,17 +20,18 @@ class Server(object):
 
         self.clients = []
         self.protocols = protocols
-        self.encoding = encoding
 
     def run(self):
         while True:
             try:
                 client_socket, address = self.sock.accept()
                 client = Client(self, client_socket, address)
-                client.send_handshake()
+                client.handshake()
                 self.clients.append(client)
                 logging.info('Registered client %s', client)
                 client.run_threaded()
+            except InvalidRequest as e:
+                logging.error('Invalid request: %s', e.message)
             except KeyboardInterrupt:
                 logging.info('Received interrupt, stopping server...')
                 break
@@ -78,4 +79,4 @@ class Client(WebSocket):
 if __name__ == '__main__':
     import sys
     port = int(sys.argv[1]) if len(sys.argv) > 1 else 80
-    Server(port=port, log_level=logging.DEBUG, encoding='utf-8').run()
+    Server(port=port, log_level=logging.DEBUG).run()

+ 28 - 16
websocket.py

@@ -6,7 +6,7 @@ from threading import Thread
 from frame import ControlFrame, receive_fragments, receive_frame, \
         OPCODE_CLOSE, OPCODE_PING, OPCODE_PONG
 from message import create_message
-from exceptions import SocketClosed, PingError
+from exceptions import InvalidRequest, SocketClosed, PingError
 
 
 WS_GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
@@ -14,9 +14,8 @@ WS_VERSION = '13'
 
 
 class WebSocket(object):
-    def __init__(self, sock, encoding=None):
+    def __init__(self, sock):
         self.sock = sock
-        self.encoding = encoding
 
         self.received_close_params = None
         self.close_frame_sent = False
@@ -57,40 +56,53 @@ class WebSocket(object):
         payload = ''.join([f.payload for f in frames])
         return create_message(frames[0].opcode, payload)
 
-    def send_handshake(self):
-        raw_headers = self.sock.recv(512)
-
-        if self.encoding:
-            raw_headers = raw_headers.decode(self.encoding, 'ignore')
+    def handshake(self):
+        """
+        Execute a handshake with the other end point of the socket. If the HTTP
+        request headers read from the socket are invalid, an InvalidRequest
+        exception will be raised.
+        """
+        raw_headers = self.sock.recv(512).decode('utf-8', 'ignore')
 
+        # request must be HTTP (at least 1.1) GET request, find the location
         location = re.search(r'^GET (.*) HTTP/1.1\r\n', raw_headers).group(1)
         headers = dict(re.findall(r'(.*?): (.*?)\r\n', raw_headers))
 
         # Check if headers that MUST be present are actually present
         for name in ('Host', 'Upgrade', 'Connection', 'Sec-WebSocket-Key',
                      'Origin', 'Sec-WebSocket-Version'):
-            assert name in headers
+            if name not in headers:
+                raise InvalidRequest('missing "%s" header' % name)
 
         # Check WebSocket version used by client
-        assert headers['Sec-WebSocket-Version'] == WS_VERSION
+        version = headers['Sec-WebSocket-Version']
+
+        if version != WS_VERSION:
+            raise InvalidRequest('WebSocket version %s requested (only %s '
+                                 'is supported)' % (version, WS_VERSION))
 
         # Make sure the requested protocols are supported by this server
         if 'Sec-WebSocket-Protocol' in headers:
             parts = headers['Sec-WebSocket-Protocol'].split(',')
             protocols = map(str.strip, parts)
 
-            for protocol in protocols:
-                assert protocol in self.protocols
+            for p in protocols:
+                if p not in self.protocols:
+                    raise InvalidRequest('unsupported protocol "%s"' % p)
         else:
             protocols = []
 
+        # Encode acceptation key using the WebSocket GUID
         key = headers['Sec-WebSocket-Key']
         accept = sha1(key + WS_GUID).digest().encode('base64')
+
+        # Construct HTTP response header
         shake = 'HTTP/1.1 101 Web Socket Protocol Handshake\r\n'
         shake += 'Upgrade: WebSocket\r\n'
         shake += 'Connection: Upgrade\r\n'
         shake += 'WebSocket-Origin: %s\r\n' % headers['Origin']
-        shake += 'WebSocket-Location: ws://%s%s\r\n' % (headers['Host'], location)
+        shake += 'WebSocket-Location: ws://%s%s\r\n' \
+                 % (headers['Host'], location)
         shake += 'Sec-WebSocket-Accept: %s\r\n' % accept
 
         if self.protocols:
@@ -170,7 +182,7 @@ class WebSocket(object):
         Called when a message is received. `message' is a Message object, which
         can be constructed from a single frame or multiple fragmented frames.
         """
-        raise NotImplemented
+        return NotImplemented
 
     def onping(self, payload):
         """
@@ -178,13 +190,13 @@ class WebSocket(object):
         used to start a timeout handler for a pong message that is not received
         in time.
         """
-        raise NotImplemented
+        pass
 
     def onpong(self, payload):
         """
         Called when a pong control frame is received.
         """
-        raise NotImplemented
+        pass
 
     def onclose(self, code, reason):
         """