Przeglądaj źródła

Implemented client handshake, and did some corresponding debugging

Taddeus Kroes 12 lat temu
rodzic
commit
8f15e28308
9 zmienionych plików z 198 dodań i 34 usunięć
  1. 1 1
      Makefile
  2. 1 0
      __init__.py
  3. 12 4
      connection.py
  4. 3 2
      frame.py
  5. 4 1
      server.py
  6. 1 1
      test/client.html
  7. 36 0
      test/client.py
  8. 5 0
      test/server.py
  9. 135 25
      websocket.py

+ 1 - 1
Makefile

@@ -1,7 +1,7 @@
 .PHONY: check clean
 
 check:
-	@python test.py
+	@python test/server.py
 
 clean:
 	find -name \*.pyc -delete

+ 1 - 0
__init__.py

@@ -7,3 +7,4 @@ from frame import Frame, ControlFrame, OPCODE_CONTINUATION, OPCODE_TEXT, \
         CLOSE_MISSING_EXTENSIONS, CLOSE_UNABLE
 from connection import Connection
 from message import Message, TextMessage, BinaryMessage, JSONMessage
+from errors import SocketClosed

+ 12 - 4
connection.py

@@ -1,4 +1,5 @@
 import struct
+import socket
 
 from frame import ControlFrame, OPCODE_CLOSE, OPCODE_PING, OPCODE_PONG, \
                   OPCODE_CONTINUATION
@@ -106,6 +107,16 @@ class Connection(object):
             except SocketClosed as e:
                 self.close(e.code, e.reason)
                 break
+            except socket.error as e:
+                self.onerror(e)
+
+                try:
+                    self.sock.close()
+                except socket.error:
+                    pass
+
+                self.onclose(None, '')
+                break
             except Exception as e:
                 self.onerror(e)
 
@@ -124,10 +135,7 @@ class Connection(object):
         close message, unless such a message has already been received earlier
         (prior to calling this function, for example). The onclose() handler is
         called after the response has been received, but before the socket is
-        actually closed. This order was chosen to prevent errors in
-        stringification in the onclose() handler. For example,
-        socket.getpeername() raises a Bad file descriptor error then the socket
-        is closed.
+        actually closed.
         """
         # Send CLOSE frame
         payload = '' if code is None else struct.pack('!H', code) + reason

+ 3 - 2
frame.py

@@ -1,4 +1,5 @@
 import struct
+import socket
 from os import urandom
 from string import printable
 
@@ -150,7 +151,7 @@ class Frame(object):
             s += ' masking_key=%4s' % printstr(self.masking_key)
 
         max_pl_disp = 30
-        pl = self.payload[:max_pl_disp]
+        pl = printstr(self.payload)[:max_pl_disp]
 
         if len(self.payload) > max_pl_disp:
              pl += '...'
@@ -240,7 +241,7 @@ def recvn(sock, n):
         received = sock.recv(n - len(data))
 
         if not len(received):
-            raise SocketClosed(None, 'no data read from socket')
+            raise socket.error('no data read from socket')
 
         data += received
 

+ 4 - 1
server.py

@@ -133,7 +133,10 @@ class Client(Connection):
         super(Client, self).__init__(sock)
 
     def __str__(self):
-        return '<Client at %s:%d>' % self.sock.getpeername()
+        try:
+            return '<Client at %s:%d>' % self.sock.getpeername()
+        except socket.error:
+            return '<Client on closed socket>'
 
     def send(self, message, fragment_size=None, mask=False):
         logging.debug('Sending %s to %s', message, self)

+ 1 - 1
test.html → test/client.html

@@ -15,7 +15,7 @@
             var ws = new WebSocket(URL);
 
             ws.onopen = function() {
-                log('Connection complete, sending "foo"');
+                log('Connection established, sending "foo"');
                 ws.send('foo');
             };
 

+ 36 - 0
test/client.py

@@ -0,0 +1,36 @@
+#!/usr/bin/env python
+import sys
+from os.path import abspath, dirname
+
+basepath = abspath(dirname(abspath(__file__)) + '/..')
+sys.path.insert(0, basepath)
+
+from websocket import websocket
+from connection import Connection
+from message import TextMessage
+from errors import SocketClosed
+
+ADDR = ('localhost', 8000)
+
+
+class EchoClient(Connection):
+    def onopen(self):
+        print 'Connection established, sending "foo"'
+        self.send(TextMessage('foo'))
+
+    def onmessage(self, msg):
+        print 'Received', msg
+        raise SocketClosed(None, 'response received')
+
+    def onerror(self, e):
+        print 'Error:', e
+
+    def onclose(self, code, reason):
+        print 'Connection closed'
+
+
+if __name__ == '__main__':
+    print 'Connecting to ws://%s:%d' % ADDR
+    sock = websocket()
+    sock.connect(ADDR)
+    EchoClient(sock).receive_forever()

+ 5 - 0
test.py → test/server.py

@@ -1,5 +1,10 @@
 #!/usr/bin/env python
+import sys
 import logging
+from os.path import abspath, dirname
+
+basepath = abspath(dirname(abspath(__file__)) + '/..')
+sys.path.insert(0, basepath)
 
 from server import Server
 

+ 135 - 25
websocket.py

@@ -1,3 +1,4 @@
+import os
 import re
 import socket
 import ssl
@@ -13,7 +14,7 @@ WS_VERSION = '13'
 
 
 def split_stripped(value, delim=','):
-    return map(str.strip, str(value).split(delim))
+    return map(str.strip, str(value).split(delim)) if value else []
 
 
 class websocket(object):
@@ -78,13 +79,18 @@ class websocket(object):
         wsock.server_handshake()
         return wsock, address
 
-    def connect(self, address):
+    def connect(self, address, path='/'):
         """
         Equivalent to socket.connect(), but sends an client handshake request
         after connecting.
+
+        `address` is a (host, port) tuple of the server to connect to.
+
+        `path` is optional, used as the *location* part of the HTTP handshake.
+        In a URL, this would show as ws://host[:port]/path.
         """
-        self.sock.sonnect(address)
-        self.client_handshake()
+        self.sock.connect(address)
+        self.client_handshake(address, path)
 
     def send(self, *args):
         """
@@ -131,6 +137,10 @@ class websocket(object):
         request headers sent by the client are invalid, a HandshakeError
         is raised.
         """
+        def fail(msg):
+            self.sock.close()
+            raise HandshakeError(msg)
+
         # Receive HTTP header
         raw_headers = ''
 
@@ -138,7 +148,12 @@ class websocket(object):
             raw_headers += self.sock.recv(512).decode('utf-8', 'ignore')
 
         # Request must be HTTP (at least 1.1) GET request, find the location
-        location = re.search(r'^GET (.*) HTTP/1.1\r\n', raw_headers).group(1)
+        match = re.search(r'^GET (.*) HTTP/1.1\r\n', raw_headers)
+
+        if match is None:
+            fail('not a valid HTTP 1.1 GET request')
+
+        location = match.group(1)
         headers = re.findall(r'(.*?): ?(.*?)\r\n', raw_headers)
         header_names = [name for name, value in headers]
 
@@ -147,25 +162,39 @@ class websocket(object):
 
         # Check if headers that MUST be present are actually present
         for name in ('Host', 'Upgrade', 'Connection', 'Sec-WebSocket-Key',
-                     'Origin', 'Sec-WebSocket-Version'):
+                     'Sec-WebSocket-Version'):
             if name not in header_names:
-                raise HandshakeError('missing "%s" header' % name)
+                fail('missing "%s" header' % name)
 
         # Check WebSocket version used by client
         version = header('Sec-WebSocket-Version')
 
         if version != WS_VERSION:
-            raise HandshakeError('WebSocket version %s requested (only %s '
+            fail('WebSocket version %s requested (only %s '
                                  'is supported)' % (version, WS_VERSION))
 
+        # Verify required header keywords
+        if 'websocket' not in header('Upgrade').lower():
+            fail('"Upgrade" header must contain "websocket"')
+
+        if 'upgrade' not in header('Connection').lower():
+            fail('"Connection" header must contain "Upgrade"')
+
+        # Origin must be present if browser client
+        if 'User-Agent' in header_names and 'Origin' not in header_names:
+            fail('browser client must specify "Origin" header')
+
         # Only supported protocols are returned
-        proto = header('Sec-WebSocket-Extensions')
-        protocols = split_stripped(proto) if proto else []
-        protocols = [p for p in protocols if p in self.protocols]
+        client_protocols = split_stripped(header('Sec-WebSocket-Extensions'))
+        protocol = 'null'
+
+        for p in client_protocols:
+            if p in self.protocols:
+                protocol = p
+                break
 
         # Only supported extensions are returned
-        ext = header('Sec-WebSocket-Extensions')
-        extensions = split_stripped(ext) if ext else []
+        extensions = split_stripped(header('Sec-WebSocket-Extensions'))
         extensions = [e for e in extensions if e in self.extensions]
 
         # Encode acceptation key using the WebSocket GUID
@@ -174,34 +203,115 @@ class websocket(object):
 
         # Construct HTTP response header
         shake = 'HTTP/1.1 101 Web Socket Protocol Handshake\r\n'
-        shake += 'Upgrade: WebSocket\r\n'
+        shake += 'Upgrade: websocket\r\n'
         shake += 'Connection: Upgrade\r\n'
         shake += 'WebSocket-Origin: %s\r\n' % header('Origin')
         shake += 'WebSocket-Location: ws://%s%s\r\n' \
                  % (header('Host'), location)
         shake += 'Sec-WebSocket-Accept: %s\r\n' % accept
-
-        if protocols:
-            shake += 'Sec-WebSocket-Protocol: %s\r\n' % ', '.join(protocols)
-
-        if extensions:
-            shake += 'Sec-WebSocket-Extensions: %s\r\n' % ', '.join(extensions)
+        shake += 'Sec-WebSocket-Protocol: %s\r\n' % protocol
+        shake += 'Sec-WebSocket-Extensions: %s\r\n' % ', '.join(extensions)
 
         self.sock.sendall(shake + '\r\n')
         self.handshake_started = True
 
-    def client_handshake(self):
+    def client_handshake(self, address, location):
         """
-        Execute a handshake as the client end point of the socket. May raise a
+        Executes a handshake as the client end point of the socket. May raise a
         HandshakeError if the server response is invalid.
         """
-        # TODO: implement HTTP request headers for client handshake
+        def fail(msg):
+            self.sock.close()
+            raise HandshakeError(msg)
+
+        if len(location) == 0:
+            fail('request location is empty')
+
+        # Generate a 16-byte random base64-encoded key for this connection
+        key = b64encode(os.urandom(16))
+
+        # Send client handshake
+        shake = 'GET %s HTTP/1.1\r\n' % location
+        shake += 'Host: %s:%d\r\n' % address
+        shake += 'Upgrade: websocket\r\n'
+        shake += 'Connection: keep-alive, Upgrade\r\n'
+        shake += 'Sec-WebSocket-Key: %s\r\n' % key
+        shake += 'Origin: null\r\n'  # FIXME: is this correct/necessary?
+        shake += 'Sec-WebSocket-Version: %s\r\n' % WS_VERSION
+
+        # These are for eagerly caching webservers
+        shake += 'Pragma: no-cache\r\n'
+        shake += 'Cache-Control: no-cache\r\n'
+
+        # Request protocols and extension, these are later checked with the
+        # actual supported values from the server's response
+        if self.protocols:
+            shake += 'Sec-WebSocket-Protocol: %s\r\n' % ', '.join(self.protocols)
+
+        if self.extensions:
+            shake += 'Sec-WebSocket-Extensions: %s\r\n' % ', '.join(self.extensions)
+
+        self.sock.sendall(shake + '\r\n')
         self.handshake_started = True
-        raise NotImplementedError
+
+        # Receive and process server handshake
+        raw_headers = ''
+
+        while raw_headers[-4:] not in ('\r\n\r\n', '\n\n'):
+            raw_headers += self.sock.recv(512).decode('utf-8', 'ignore')
+
+        # Response must be HTTP (at least 1.1) with status 101
+        if not raw_headers.startswith('HTTP/1.1 101'):
+            # TODO: implement HTTP authentication (401) and redirect (3xx)?
+            fail('not a valid HTTP 1.1 status 101 response')
+
+        headers = re.findall(r'(.*?): ?(.*?)\r\n', raw_headers)
+        header_names = [name for name, value in headers]
+
+        def header(name):
+            return ', '.join([v for n, v in headers if n == name])
+
+        # Check if headers that MUST be present are actually present
+        for name in ('Upgrade', 'Connection', 'Sec-WebSocket-Accept'):
+            if name not in header_names:
+                fail('missing "%s" header' % name)
+
+        if 'websocket' not in header('Upgrade').lower():
+            fail('"Upgrade" header must contain "websocket"')
+
+        if 'upgrade' not in header('Connection').lower():
+            fail('"Connection" header must contain "Upgrade"')
+
+        # Verify accept header
+        accept = header('Sec-WebSocket-Accept').strip()
+        required_accept = b64encode(sha1(key + WS_GUID).digest())
+
+        if accept != required_accept:
+            fail('invalid websocket accept header "%s"' % accept)
+
+        # Compare extensions
+        server_extensions = split_stripped(header('Sec-WebSocket-Extensions'))
+
+        for ext in server_extensions:
+            if ext not in self.extensions:
+                fail('server extension "%s" is unsupported by client' % ext)
+
+        for ext in self.extensions:
+            if ext not in server_extensions:
+                fail('client extension "%s" is unsupported by server' % ext)
+
+        # Assert that returned protocol is supported
+        protocol = header('Sec-WebSocket-Protocol')
+
+        if protocol:
+            if protocol != 'null' and protocol not in self.protocols:
+                fail('unsupported protocol "%s"' % protocol)
+
+            self.protocol = protocol
 
     def enable_ssl(self, *args, **kwargs):
         """
-        Transform the regular socket.socket to an ssl.SSLSocket for secure
+        Transforms the regular socket.socket to an ssl.SSLSocket for secure
         connections. Any arguments are passed to ssl.wrap_socket:
         http://docs.python.org/dev/library/ssl.html#ssl.wrap_socket
         """