Procházet zdrojové kódy

Started testing and debugging:

- Renamed exceptions.py to errors.py to avoid name collision with stdlib
- Fixed some variable name collisions
- Worked on improving close handshake
- More minor fixes
Taddeus Kroes před 12 roky
rodič
revize
fa6f57d655
9 změnil soubory, kde provedl 152 přidání a 54 odebrání
  1. 39 23
      connection.py
  2. 16 0
      errors.py
  3. 0 10
      exceptions.py
  4. 8 6
      frame.py
  5. 9 1
      message.py
  6. 11 7
      server.py
  7. 36 0
      test.html
  8. 13 0
      test.py
  9. 20 7
      websocket.py

+ 39 - 23
connection.py

@@ -3,7 +3,7 @@ import struct
 from frame import ControlFrame, OPCODE_CLOSE, OPCODE_PING, OPCODE_PONG, \
                   OPCODE_CONTINUATION
 from message import create_message
-from exceptions import SocketClosed, PingError
+from errors import SocketClosed, PingError
 
 
 class Connection(object):
@@ -21,8 +21,8 @@ class Connection(object):
         """
         self.sock = sock
 
-        self.received_close_params = None
         self.close_frame_sent = False
+        self.close_frame_received = False
 
         self.ping_sent = False
         self.ping_payload = None
@@ -54,17 +54,17 @@ class Connection(object):
 
             if isinstance(frame, ControlFrame):
                 self.handle_control_frame(frame)
-
-                # No more receiving data after a close message
-                if frame.opcode == OPCODE_CLOSE:
-                    break
             elif len(fragments) and frame.opcode != OPCODE_CONTINUATION:
                 raise ValueError('expected continuation/control frame, got %s '
                                  'instead' % frame)
             else:
                 fragments.append(frame)
 
-        payload = ''.join(f.payload for f in fragments)
+        payload = bytearray()
+
+        for f in fragments:
+            payload += f.payload
+
         return create_message(fragments[0].opcode, payload)
 
     def handle_control_frame(self, frame):
@@ -72,10 +72,20 @@ class Connection(object):
         Handle a control frame as defined by RFC 6455.
         """
         if frame.opcode == OPCODE_CLOSE:
-            # Set parameters and keep receiving the current fragmented frame
-            # chain, assuming that the CLOSE frame will be handled by
-            # handle_close() as soon as possible
-            self.received_close_params = frame.unpack_close()
+            # Handle a close message by sending a response close message if no
+            # CLOSE frame was sent before, and closing the connection. The
+            # onclose() handler is called afterwards.
+            self.close_frame_received = True
+            code, reason = frame.unpack_close()
+
+            if not self.close_frame_sent:
+                payload = '' if code is None else struct.pack('!H', code)
+                self.sock.send(ControlFrame(OPCODE_CLOSE, payload))
+
+            self.sock.close()
+
+            # No more receiving data after a close message
+            raise SocketClosed(code, reason)
 
         elif frame.opcode == OPCODE_PING:
             # Respond with a pong message with identical payload
@@ -103,12 +113,8 @@ class Connection(object):
         while True:
             try:
                 self.onmessage(self.receive())
-
-                if self.received_close_params is not None:
-                    self.handle_close(*self.received_close_params)
-                    break
-            except SocketClosed:
-                self.onclose(None, '')
+            except SocketClosed as e:
+                self.onclose(e.code, e.reason)
                 break
             except Exception as e:
                 self.onerror(e)
@@ -150,13 +156,23 @@ class Connection(object):
         has been sent, but before the response has been received.
         """
         self.send_close(code, reason)
-        # FIXME: swap the two lines below?
-        self.onclose(code, reason)
-        frame = self.sock.recv()
-        self.sock.close()
 
-        if frame.opcode != OPCODE_CLOSE:
-            raise ValueError('expected CLOSE frame, got %s instead' % frame)
+        if not self.close_frame_received:
+            frame = self.sock.recv()
+
+            if frame.opcode != OPCODE_CLOSE:
+                raise ValueError('expected CLOSE frame, got %s instead' % frame)
+
+            res_code, res_reason = frame.unpack_close()
+
+            # FIXME: check if res_code == code and res_reason == reason?
+
+            # FIXME: alternatively, keep receiving frames in a loop until a
+            # CLOSE frame is received, so that a fragmented chain may arrive
+            # fully first
+
+        self.sock.close()
+        self.onclose(code, reason)
 
     def onopen(self):
         """

+ 16 - 0
errors.py

@@ -0,0 +1,16 @@
+class SocketClosed(Exception):
+    def __init__(self, code=None, reason=''):
+        self.code = code
+        self.reason = reason
+
+    @property
+    def message(self):
+        return ('' if self.code is None else '[%d] ' % self.code) + self.reason
+
+
+class HandshakeError(Exception):
+    pass
+
+
+class PingError(Exception):
+    pass

+ 0 - 10
exceptions.py

@@ -1,10 +0,0 @@
-class SocketClosed(Exception):
-    pass
-
-
-class HandshakeError(Exception):
-    pass
-
-
-class PingError(Exception):
-    pass

+ 8 - 6
frame.py

@@ -1,7 +1,8 @@
 import struct
 from os import urandom
+from curses.ascii import isprint
 
-from exceptions import SocketClosed
+from errors import SocketClosed
 
 
 OPCODE_CONTINUATION = 0x0
@@ -82,7 +83,7 @@ class Frame(object):
         |                     Payload Data continued ...                |
         +---------------------------------------------------------------+
         """
-        header = struct.pack('!B', (self.fin << 7) | (self.rsv1 << 6) |
+        header = struct.pack('!B', (self.final << 7) | (self.rsv1 << 6) |
                              (self.rsv2 << 5) | (self.rsv3 << 4) | self.opcode)
 
         mask = bool(self.masking_key) << 7
@@ -142,7 +143,8 @@ class Frame(object):
             % (self.__class__.__name__, self.opcode, len(self.payload))
 
         if self.masking_key:
-            s += ' masking_key=%4s' % self.masking_key
+            key = ''.join(c if isprint(c) else '.' for c in self.masking_key)
+            s += ' masking_key=%4s' % key
 
         return s + '>'
 
@@ -197,7 +199,7 @@ def receive_frame(sock):
     rsv3 = bool(b1 & 0x10)
     opcode = b1 & 0x0F
 
-    mask = bool(b2 & 0x80)
+    masked = bool(b2 & 0x80)
     payload_len = b2 & 0x7F
 
     if payload_len == 126:
@@ -205,7 +207,7 @@ def receive_frame(sock):
     elif payload_len == 127:
         payload_len = struct.unpack('!Q', recvn(sock, 8))
 
-    if mask:
+    if masked:
         masking_key = recvn(sock, 4)
         payload = mask(masking_key, recvn(sock, payload_len))
     else:
@@ -229,7 +231,7 @@ def recvn(sock, n):
         received = sock.recv(n - len(data))
 
         if not len(received):
-            raise SocketClosed()
+            raise SocketClosed(None, 'no data read from socket')
 
         data += received
 

+ 9 - 1
message.py

@@ -24,7 +24,15 @@ class Message(object):
 
 class TextMessage(Message):
     def __init__(self, payload):
-        super(TextMessage, self).__init__(OPCODE_TEXT, payload.encode('utf-8'))
+        text = str(payload).encode('utf-8')
+        super(TextMessage, self).__init__(OPCODE_TEXT, text)
+
+    def __str__(self):
+        if len(self.payload) > 30:
+            return '<TextMessage "%s"... size=%d>' \
+                    % (self.payload[:30], len(self.payload))
+
+        return '<TextMessage "%s" size=%d>' % (self.payload, len(self.payload))
 
 
 class BinaryMessage(Message):

+ 11 - 7
server.py

@@ -6,7 +6,7 @@ from threading import Thread
 from websocket import websocket
 from connection import Connection
 from frame import CLOSE_NORMAL
-from exceptions import HandshakeError
+from errors import HandshakeError
 
 
 class Server(object):
@@ -48,9 +48,9 @@ class Server(object):
             try:
                 sock, address = self.sock.accept()
 
-                client = Client(self, sock, address)
+                client = Client(self, sock)
                 self.clients.append(client)
-                logging.info('Registered client %s', client)
+                logging.debug('Registered client %s', client)
 
                 thread = Thread(target=client.receive_forever)
                 thread.daemon = True
@@ -102,8 +102,15 @@ class Server(object):
 
 class Client(Connection):
     def __init__(self, server, sock):
-        super(Client, self).__init__(sock)
         self.server = server
+        super(Client, self).__init__(sock)
+
+    def __str__(self):
+        return '<Client at %s:%d>' % self.sock.getpeername()
+
+    def send(self, message, fragment_size=None, mask=False):
+        logging.debug('Sending %s to %s', message, self)
+        Connection.send(self, message, fragment_size=fragment_size, mask=mask)
 
     def onopen(self):
         self.server.onopen(self)
@@ -123,9 +130,6 @@ class Client(Connection):
     def onerror(self, e):
         self.server.onerror(self, e)
 
-    def __str__(self):
-        return '<Client at %s:%d>' % self.sock.getpeername()
-
 
 if __name__ == '__main__':
     import sys

+ 36 - 0
test.html

@@ -0,0 +1,36 @@
+<!doctype html>
+<html>
+    <head>
+        <title>twspy echo test client</title>
+    </head>
+    <body>
+        <textarea id="log" rows="20" cols="80" readonly="readonly"></textarea>
+        <script type="text/javascript">
+            function log(line) {
+                document.getElementById('log').innerHTML += line + '\n';
+            }
+
+            var URL = 'localhost:8000';
+            log('Connecting to ' + URL);
+            var ws = new WebSocket('ws://' + URL);
+
+            ws.onopen = function() {
+                log('Connection complete, sending "foo"');
+                ws.send('foo');
+            };
+
+            ws.onmessage = function(msg) {
+                log('Received "' + msg + '", closing connection');
+                ws.close();
+            };
+
+            ws.onerror = function(e) {
+                log('Error', e);
+            };
+
+            ws.onclose = function() {
+                log('Connection closed');
+            };
+        </script>
+    </body>
+</html>

+ 13 - 0
test.py

@@ -1 +1,14 @@
 #!/usr/bin/env python
+import logging
+
+from server import Server
+
+
+class EchoServer(Server):
+    def onmessage(self, client, message):
+        Server.onmessage(self, client, message)
+        client.send(message)
+
+
+if __name__ == '__main__':
+    EchoServer(8000, 'localhost', loglevel=logging.DEBUG).run()

+ 20 - 7
websocket.py

@@ -3,7 +3,7 @@ import socket
 from hashlib import sha1
 
 from frame import receive_frame
-from exceptions import HandshakeError
+from errors import HandshakeError
 
 
 WS_GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
@@ -11,7 +11,7 @@ WS_VERSION = '13'
 
 
 def split_stripped(value, delim=','):
-    return map(str.strip, value.split(delim))
+    return map(str.strip, str(value).split(delim))
 
 
 class websocket(object):
@@ -32,11 +32,14 @@ class websocket(object):
     >>> sock = websocket()
     >>> sock.connect(('', 8000))
     """
-    def __init__(self, protocols=[], extensions=[], sfamily=socket.AF_INET,
-            sproto=0):
+    def __init__(self, sock=None, protocols=[], extensions=[], sfamily=socket.AF_INET,
+                 sproto=0):
         """
         Create a regular TCP socket of family `family` and protocol
 
+        `sock` is an optional regular TCP socket to be used for sending binary
+        data. If not specified, a new socket is created.
+
         `protocols` is a list of supported protocol names.
 
         `extensions` is a list of supported extensions.
@@ -45,7 +48,7 @@ class websocket(object):
         """
         self.protocols = protocols
         self.extensions = extensions
-        self.sock = socket.socket(sfamily, socket.SOCK_STREAM, sproto)
+        self.sock = sock or socket.socket(sfamily, socket.SOCK_STREAM, sproto)
 
     def bind(self, address):
         self.sock.bind(address)
@@ -78,14 +81,24 @@ class websocket(object):
         Send a number of frames.
         """
         for frame in args:
+            print 'send frame:', frame, 'to %s:%d' % self.sock.getpeername(), frame.payload
             self.sock.sendall(frame.pack())
 
-    def recv(self, n=1):
+    def recv(self):
+        """
+        Receive a single frames. This can be either a data frame or a control
+        frame.
+        """
+        frame = receive_frame(self.sock)
+        print 'receive frame:', frame, 'from %s:%d' % self.sock.getpeername()
+        return frame
+
+    def recvn(self, n):
         """
         Receive exactly `n` frames. These can be either data frames or control
         frames, or a combination of both.
         """
-        return [receive_frame(self.sock) for i in xrange(n)]
+        return [self.recv() for i in xrange(n)]
 
     def getpeername(self):
         return self.sock.getpeername()