Просмотр исходного кода

Implemented client handshake, and did some corresponding debugging

Taddeus Kroes 12 лет назад
Родитель
Сommit
8f15e28308
9 измененных файлов с 198 добавлено и 34 удалено
  1. 1 1
      Makefile
  2. 1 0
      __init__.py
  3. 12 4
      connection.py
  4. 3 2
      frame.py
  5. 4 1
      server.py
  6. 1 1
      test/client.html
  7. 36 0
      test/client.py
  8. 5 0
      test/server.py
  9. 135 25
      websocket.py

+ 1 - 1
Makefile

@@ -1,7 +1,7 @@
 .PHONY: check clean
 .PHONY: check clean
 
 
 check:
 check:
-	@python test.py
+	@python test/server.py
 
 
 clean:
 clean:
 	find -name \*.pyc -delete
 	find -name \*.pyc -delete

+ 1 - 0
__init__.py

@@ -7,3 +7,4 @@ from frame import Frame, ControlFrame, OPCODE_CONTINUATION, OPCODE_TEXT, \
         CLOSE_MISSING_EXTENSIONS, CLOSE_UNABLE
         CLOSE_MISSING_EXTENSIONS, CLOSE_UNABLE
 from connection import Connection
 from connection import Connection
 from message import Message, TextMessage, BinaryMessage, JSONMessage
 from message import Message, TextMessage, BinaryMessage, JSONMessage
+from errors import SocketClosed

+ 12 - 4
connection.py

@@ -1,4 +1,5 @@
 import struct
 import struct
+import socket
 
 
 from frame import ControlFrame, OPCODE_CLOSE, OPCODE_PING, OPCODE_PONG, \
 from frame import ControlFrame, OPCODE_CLOSE, OPCODE_PING, OPCODE_PONG, \
                   OPCODE_CONTINUATION
                   OPCODE_CONTINUATION
@@ -106,6 +107,16 @@ class Connection(object):
             except SocketClosed as e:
             except SocketClosed as e:
                 self.close(e.code, e.reason)
                 self.close(e.code, e.reason)
                 break
                 break
+            except socket.error as e:
+                self.onerror(e)
+
+                try:
+                    self.sock.close()
+                except socket.error:
+                    pass
+
+                self.onclose(None, '')
+                break
             except Exception as e:
             except Exception as e:
                 self.onerror(e)
                 self.onerror(e)
 
 
@@ -124,10 +135,7 @@ class Connection(object):
         close message, unless such a message has already been received earlier
         close message, unless such a message has already been received earlier
         (prior to calling this function, for example). The onclose() handler is
         (prior to calling this function, for example). The onclose() handler is
         called after the response has been received, but before the socket is
         called after the response has been received, but before the socket is
-        actually closed. This order was chosen to prevent errors in
-        stringification in the onclose() handler. For example,
-        socket.getpeername() raises a Bad file descriptor error then the socket
-        is closed.
+        actually closed.
         """
         """
         # Send CLOSE frame
         # Send CLOSE frame
         payload = '' if code is None else struct.pack('!H', code) + reason
         payload = '' if code is None else struct.pack('!H', code) + reason

+ 3 - 2
frame.py

@@ -1,4 +1,5 @@
 import struct
 import struct
+import socket
 from os import urandom
 from os import urandom
 from string import printable
 from string import printable
 
 
@@ -150,7 +151,7 @@ class Frame(object):
             s += ' masking_key=%4s' % printstr(self.masking_key)
             s += ' masking_key=%4s' % printstr(self.masking_key)
 
 
         max_pl_disp = 30
         max_pl_disp = 30
-        pl = self.payload[:max_pl_disp]
+        pl = printstr(self.payload)[:max_pl_disp]
 
 
         if len(self.payload) > max_pl_disp:
         if len(self.payload) > max_pl_disp:
              pl += '...'
              pl += '...'
@@ -240,7 +241,7 @@ def recvn(sock, n):
         received = sock.recv(n - len(data))
         received = sock.recv(n - len(data))
 
 
         if not len(received):
         if not len(received):
-            raise SocketClosed(None, 'no data read from socket')
+            raise socket.error('no data read from socket')
 
 
         data += received
         data += received
 
 

+ 4 - 1
server.py

@@ -133,7 +133,10 @@ class Client(Connection):
         super(Client, self).__init__(sock)
         super(Client, self).__init__(sock)
 
 
     def __str__(self):
     def __str__(self):
-        return '<Client at %s:%d>' % self.sock.getpeername()
+        try:
+            return '<Client at %s:%d>' % self.sock.getpeername()
+        except socket.error:
+            return '<Client on closed socket>'
 
 
     def send(self, message, fragment_size=None, mask=False):
     def send(self, message, fragment_size=None, mask=False):
         logging.debug('Sending %s to %s', message, self)
         logging.debug('Sending %s to %s', message, self)

+ 1 - 1
test.html → test/client.html

@@ -15,7 +15,7 @@
             var ws = new WebSocket(URL);
             var ws = new WebSocket(URL);
 
 
             ws.onopen = function() {
             ws.onopen = function() {
-                log('Connection complete, sending "foo"');
+                log('Connection established, sending "foo"');
                 ws.send('foo');
                 ws.send('foo');
             };
             };
 
 

+ 36 - 0
test/client.py

@@ -0,0 +1,36 @@
+#!/usr/bin/env python
+import sys
+from os.path import abspath, dirname
+
+basepath = abspath(dirname(abspath(__file__)) + '/..')
+sys.path.insert(0, basepath)
+
+from websocket import websocket
+from connection import Connection
+from message import TextMessage
+from errors import SocketClosed
+
+ADDR = ('localhost', 8000)
+
+
+class EchoClient(Connection):
+    def onopen(self):
+        print 'Connection established, sending "foo"'
+        self.send(TextMessage('foo'))
+
+    def onmessage(self, msg):
+        print 'Received', msg
+        raise SocketClosed(None, 'response received')
+
+    def onerror(self, e):
+        print 'Error:', e
+
+    def onclose(self, code, reason):
+        print 'Connection closed'
+
+
+if __name__ == '__main__':
+    print 'Connecting to ws://%s:%d' % ADDR
+    sock = websocket()
+    sock.connect(ADDR)
+    EchoClient(sock).receive_forever()

+ 5 - 0
test.py → test/server.py

@@ -1,5 +1,10 @@
 #!/usr/bin/env python
 #!/usr/bin/env python
+import sys
 import logging
 import logging
+from os.path import abspath, dirname
+
+basepath = abspath(dirname(abspath(__file__)) + '/..')
+sys.path.insert(0, basepath)
 
 
 from server import Server
 from server import Server
 
 

+ 135 - 25
websocket.py

@@ -1,3 +1,4 @@
+import os
 import re
 import re
 import socket
 import socket
 import ssl
 import ssl
@@ -13,7 +14,7 @@ WS_VERSION = '13'
 
 
 
 
 def split_stripped(value, delim=','):
 def split_stripped(value, delim=','):
-    return map(str.strip, str(value).split(delim))
+    return map(str.strip, str(value).split(delim)) if value else []
 
 
 
 
 class websocket(object):
 class websocket(object):
@@ -78,13 +79,18 @@ class websocket(object):
         wsock.server_handshake()
         wsock.server_handshake()
         return wsock, address
         return wsock, address
 
 
-    def connect(self, address):
+    def connect(self, address, path='/'):
         """
         """
         Equivalent to socket.connect(), but sends an client handshake request
         Equivalent to socket.connect(), but sends an client handshake request
         after connecting.
         after connecting.
+
+        `address` is a (host, port) tuple of the server to connect to.
+
+        `path` is optional, used as the *location* part of the HTTP handshake.
+        In a URL, this would show as ws://host[:port]/path.
         """
         """
-        self.sock.sonnect(address)
-        self.client_handshake()
+        self.sock.connect(address)
+        self.client_handshake(address, path)
 
 
     def send(self, *args):
     def send(self, *args):
         """
         """
@@ -131,6 +137,10 @@ class websocket(object):
         request headers sent by the client are invalid, a HandshakeError
         request headers sent by the client are invalid, a HandshakeError
         is raised.
         is raised.
         """
         """
+        def fail(msg):
+            self.sock.close()
+            raise HandshakeError(msg)
+
         # Receive HTTP header
         # Receive HTTP header
         raw_headers = ''
         raw_headers = ''
 
 
@@ -138,7 +148,12 @@ class websocket(object):
             raw_headers += self.sock.recv(512).decode('utf-8', 'ignore')
             raw_headers += self.sock.recv(512).decode('utf-8', 'ignore')
 
 
         # Request must be HTTP (at least 1.1) GET request, find the location
         # Request must be HTTP (at least 1.1) GET request, find the location
-        location = re.search(r'^GET (.*) HTTP/1.1\r\n', raw_headers).group(1)
+        match = re.search(r'^GET (.*) HTTP/1.1\r\n', raw_headers)
+
+        if match is None:
+            fail('not a valid HTTP 1.1 GET request')
+
+        location = match.group(1)
         headers = re.findall(r'(.*?): ?(.*?)\r\n', raw_headers)
         headers = re.findall(r'(.*?): ?(.*?)\r\n', raw_headers)
         header_names = [name for name, value in headers]
         header_names = [name for name, value in headers]
 
 
@@ -147,25 +162,39 @@ class websocket(object):
 
 
         # Check if headers that MUST be present are actually present
         # Check if headers that MUST be present are actually present
         for name in ('Host', 'Upgrade', 'Connection', 'Sec-WebSocket-Key',
         for name in ('Host', 'Upgrade', 'Connection', 'Sec-WebSocket-Key',
-                     'Origin', 'Sec-WebSocket-Version'):
+                     'Sec-WebSocket-Version'):
             if name not in header_names:
             if name not in header_names:
-                raise HandshakeError('missing "%s" header' % name)
+                fail('missing "%s" header' % name)
 
 
         # Check WebSocket version used by client
         # Check WebSocket version used by client
         version = header('Sec-WebSocket-Version')
         version = header('Sec-WebSocket-Version')
 
 
         if version != WS_VERSION:
         if version != WS_VERSION:
-            raise HandshakeError('WebSocket version %s requested (only %s '
+            fail('WebSocket version %s requested (only %s '
                                  'is supported)' % (version, WS_VERSION))
                                  'is supported)' % (version, WS_VERSION))
 
 
+        # Verify required header keywords
+        if 'websocket' not in header('Upgrade').lower():
+            fail('"Upgrade" header must contain "websocket"')
+
+        if 'upgrade' not in header('Connection').lower():
+            fail('"Connection" header must contain "Upgrade"')
+
+        # Origin must be present if browser client
+        if 'User-Agent' in header_names and 'Origin' not in header_names:
+            fail('browser client must specify "Origin" header')
+
         # Only supported protocols are returned
         # Only supported protocols are returned
-        proto = header('Sec-WebSocket-Extensions')
-        protocols = split_stripped(proto) if proto else []
-        protocols = [p for p in protocols if p in self.protocols]
+        client_protocols = split_stripped(header('Sec-WebSocket-Extensions'))
+        protocol = 'null'
+
+        for p in client_protocols:
+            if p in self.protocols:
+                protocol = p
+                break
 
 
         # Only supported extensions are returned
         # Only supported extensions are returned
-        ext = header('Sec-WebSocket-Extensions')
-        extensions = split_stripped(ext) if ext else []
+        extensions = split_stripped(header('Sec-WebSocket-Extensions'))
         extensions = [e for e in extensions if e in self.extensions]
         extensions = [e for e in extensions if e in self.extensions]
 
 
         # Encode acceptation key using the WebSocket GUID
         # Encode acceptation key using the WebSocket GUID
@@ -174,34 +203,115 @@ class websocket(object):
 
 
         # Construct HTTP response header
         # Construct HTTP response header
         shake = 'HTTP/1.1 101 Web Socket Protocol Handshake\r\n'
         shake = 'HTTP/1.1 101 Web Socket Protocol Handshake\r\n'
-        shake += 'Upgrade: WebSocket\r\n'
+        shake += 'Upgrade: websocket\r\n'
         shake += 'Connection: Upgrade\r\n'
         shake += 'Connection: Upgrade\r\n'
         shake += 'WebSocket-Origin: %s\r\n' % header('Origin')
         shake += 'WebSocket-Origin: %s\r\n' % header('Origin')
         shake += 'WebSocket-Location: ws://%s%s\r\n' \
         shake += 'WebSocket-Location: ws://%s%s\r\n' \
                  % (header('Host'), location)
                  % (header('Host'), location)
         shake += 'Sec-WebSocket-Accept: %s\r\n' % accept
         shake += 'Sec-WebSocket-Accept: %s\r\n' % accept
-
-        if protocols:
-            shake += 'Sec-WebSocket-Protocol: %s\r\n' % ', '.join(protocols)
-
-        if extensions:
-            shake += 'Sec-WebSocket-Extensions: %s\r\n' % ', '.join(extensions)
+        shake += 'Sec-WebSocket-Protocol: %s\r\n' % protocol
+        shake += 'Sec-WebSocket-Extensions: %s\r\n' % ', '.join(extensions)
 
 
         self.sock.sendall(shake + '\r\n')
         self.sock.sendall(shake + '\r\n')
         self.handshake_started = True
         self.handshake_started = True
 
 
-    def client_handshake(self):
+    def client_handshake(self, address, location):
         """
         """
-        Execute a handshake as the client end point of the socket. May raise a
+        Executes a handshake as the client end point of the socket. May raise a
         HandshakeError if the server response is invalid.
         HandshakeError if the server response is invalid.
         """
         """
-        # TODO: implement HTTP request headers for client handshake
+        def fail(msg):
+            self.sock.close()
+            raise HandshakeError(msg)
+
+        if len(location) == 0:
+            fail('request location is empty')
+
+        # Generate a 16-byte random base64-encoded key for this connection
+        key = b64encode(os.urandom(16))
+
+        # Send client handshake
+        shake = 'GET %s HTTP/1.1\r\n' % location
+        shake += 'Host: %s:%d\r\n' % address
+        shake += 'Upgrade: websocket\r\n'
+        shake += 'Connection: keep-alive, Upgrade\r\n'
+        shake += 'Sec-WebSocket-Key: %s\r\n' % key
+        shake += 'Origin: null\r\n'  # FIXME: is this correct/necessary?
+        shake += 'Sec-WebSocket-Version: %s\r\n' % WS_VERSION
+
+        # These are for eagerly caching webservers
+        shake += 'Pragma: no-cache\r\n'
+        shake += 'Cache-Control: no-cache\r\n'
+
+        # Request protocols and extension, these are later checked with the
+        # actual supported values from the server's response
+        if self.protocols:
+            shake += 'Sec-WebSocket-Protocol: %s\r\n' % ', '.join(self.protocols)
+
+        if self.extensions:
+            shake += 'Sec-WebSocket-Extensions: %s\r\n' % ', '.join(self.extensions)
+
+        self.sock.sendall(shake + '\r\n')
         self.handshake_started = True
         self.handshake_started = True
-        raise NotImplementedError
+
+        # Receive and process server handshake
+        raw_headers = ''
+
+        while raw_headers[-4:] not in ('\r\n\r\n', '\n\n'):
+            raw_headers += self.sock.recv(512).decode('utf-8', 'ignore')
+
+        # Response must be HTTP (at least 1.1) with status 101
+        if not raw_headers.startswith('HTTP/1.1 101'):
+            # TODO: implement HTTP authentication (401) and redirect (3xx)?
+            fail('not a valid HTTP 1.1 status 101 response')
+
+        headers = re.findall(r'(.*?): ?(.*?)\r\n', raw_headers)
+        header_names = [name for name, value in headers]
+
+        def header(name):
+            return ', '.join([v for n, v in headers if n == name])
+
+        # Check if headers that MUST be present are actually present
+        for name in ('Upgrade', 'Connection', 'Sec-WebSocket-Accept'):
+            if name not in header_names:
+                fail('missing "%s" header' % name)
+
+        if 'websocket' not in header('Upgrade').lower():
+            fail('"Upgrade" header must contain "websocket"')
+
+        if 'upgrade' not in header('Connection').lower():
+            fail('"Connection" header must contain "Upgrade"')
+
+        # Verify accept header
+        accept = header('Sec-WebSocket-Accept').strip()
+        required_accept = b64encode(sha1(key + WS_GUID).digest())
+
+        if accept != required_accept:
+            fail('invalid websocket accept header "%s"' % accept)
+
+        # Compare extensions
+        server_extensions = split_stripped(header('Sec-WebSocket-Extensions'))
+
+        for ext in server_extensions:
+            if ext not in self.extensions:
+                fail('server extension "%s" is unsupported by client' % ext)
+
+        for ext in self.extensions:
+            if ext not in server_extensions:
+                fail('client extension "%s" is unsupported by server' % ext)
+
+        # Assert that returned protocol is supported
+        protocol = header('Sec-WebSocket-Protocol')
+
+        if protocol:
+            if protocol != 'null' and protocol not in self.protocols:
+                fail('unsupported protocol "%s"' % protocol)
+
+            self.protocol = protocol
 
 
     def enable_ssl(self, *args, **kwargs):
     def enable_ssl(self, *args, **kwargs):
         """
         """
-        Transform the regular socket.socket to an ssl.SSLSocket for secure
+        Transforms the regular socket.socket to an ssl.SSLSocket for secure
         connections. Any arguments are passed to ssl.wrap_socket:
         connections. Any arguments are passed to ssl.wrap_socket:
         http://docs.python.org/dev/library/ssl.html#ssl.wrap_socket
         http://docs.python.org/dev/library/ssl.html#ssl.wrap_socket
         """
         """