Przeglądaj źródła

Implemented HTTP redirect and authentication in client handshake

Taddeus Kroes 12 lat temu
rodzic
commit
667d611527
1 zmienionych plików z 108 dodań i 65 usunięć
  1. 108 65
      websocket.py

+ 108 - 65
websocket.py

@@ -4,6 +4,7 @@ import socket
 import ssl
 import ssl
 from hashlib import sha1
 from hashlib import sha1
 from base64 import b64encode
 from base64 import b64encode
+from urlparse import urlparse
 
 
 from frame import receive_frame
 from frame import receive_frame
 from errors import HandshakeError, SSLError
 from errors import HandshakeError, SSLError
@@ -79,7 +80,7 @@ class websocket(object):
         wsock.server_handshake()
         wsock.server_handshake()
         return wsock, address
         return wsock, address
 
 
-    def connect(self, address, path='/'):
+    def connect(self, address, path='/', auth=None):
         """
         """
         Equivalent to socket.connect(), but sends an client handshake request
         Equivalent to socket.connect(), but sends an client handshake request
         after connecting.
         after connecting.
@@ -88,9 +89,12 @@ class websocket(object):
 
 
         `path` is optional, used as the *location* part of the HTTP handshake.
         `path` is optional, used as the *location* part of the HTTP handshake.
         In a URL, this would show as ws://host[:port]/path.
         In a URL, this would show as ws://host[:port]/path.
+
+        `auth` is optional, used for the HTTP "Authorization" header of the
+        handshake request.
         """
         """
         self.sock.connect(address)
         self.sock.connect(address)
-        self.client_handshake(address, path)
+        self.client_handshake(address, path, auth)
 
 
     def send(self, *args):
     def send(self, *args):
         """
         """
@@ -215,7 +219,7 @@ class websocket(object):
         self.sock.sendall(shake + '\r\n')
         self.sock.sendall(shake + '\r\n')
         self.handshake_started = True
         self.handshake_started = True
 
 
-    def client_handshake(self, address, location):
+    def client_handshake(self, address, location, auth):
         """
         """
         Executes a handshake as the client end point of the socket. May raise a
         Executes a handshake as the client end point of the socket. May raise a
         HandshakeError if the server response is invalid.
         HandshakeError if the server response is invalid.
@@ -224,90 +228,129 @@ class websocket(object):
             self.sock.close()
             self.sock.close()
             raise HandshakeError(msg)
             raise HandshakeError(msg)
 
 
-        if len(location) == 0:
-            fail('request location is empty')
+        def send_request(location):
+            if len(location) == 0:
+                fail('request location is empty')
 
 
-        # Generate a 16-byte random base64-encoded key for this connection
-        key = b64encode(os.urandom(16))
+            # Generate a 16-byte random base64-encoded key for this connection
+            key = b64encode(os.urandom(16))
 
 
-        # Send client handshake
-        shake = 'GET %s HTTP/1.1\r\n' % location
-        shake += 'Host: %s:%d\r\n' % address
-        shake += 'Upgrade: websocket\r\n'
-        shake += 'Connection: keep-alive, Upgrade\r\n'
-        shake += 'Sec-WebSocket-Key: %s\r\n' % key
-        shake += 'Origin: null\r\n'  # FIXME: is this correct/necessary?
-        shake += 'Sec-WebSocket-Version: %s\r\n' % WS_VERSION
+            # Send client handshake
+            shake = 'GET %s HTTP/1.1\r\n' % location
+            shake += 'Host: %s:%d\r\n' % address
+            shake += 'Upgrade: websocket\r\n'
+            shake += 'Connection: keep-alive, Upgrade\r\n'
+            shake += 'Sec-WebSocket-Key: %s\r\n' % key
+            shake += 'Origin: null\r\n'  # FIXME: is this correct/necessary?
+            shake += 'Sec-WebSocket-Version: %s\r\n' % WS_VERSION
 
 
-        # These are for eagerly caching webservers
-        shake += 'Pragma: no-cache\r\n'
-        shake += 'Cache-Control: no-cache\r\n'
+            # These are for eagerly caching webservers
+            shake += 'Pragma: no-cache\r\n'
+            shake += 'Cache-Control: no-cache\r\n'
 
 
-        # Request protocols and extension, these are later checked with the
-        # actual supported values from the server's response
-        if self.protocols:
-            shake += 'Sec-WebSocket-Protocol: %s\r\n' % ', '.join(self.protocols)
+            # Request protocols and extension, these are later checked with the
+            # actual supported values from the server's response
+            if self.protocols:
+                shake += 'Sec-WebSocket-Protocol: %s\r\n' % ', '.join(self.protocols)
 
 
-        if self.extensions:
-            shake += 'Sec-WebSocket-Extensions: %s\r\n' % ', '.join(self.extensions)
+            if self.extensions:
+                shake += 'Sec-WebSocket-Extensions: %s\r\n' % ', '.join(self.extensions)
 
 
-        self.sock.sendall(shake + '\r\n')
-        self.handshake_started = True
+            if auth:
+                shake += 'Authorization: %s\r\n' % auth
 
 
-        # Receive and process server handshake
-        raw_headers = ''
+            self.sock.sendall(shake + '\r\n')
+            return key
 
 
-        while raw_headers[-4:] not in ('\r\n\r\n', '\n\n'):
-            raw_headers += self.sock.recv(512).decode('utf-8', 'ignore')
+        def receive_response(key):
+            # Receive and process server handshake
+            raw_headers = ''
 
 
-        # Response must be HTTP (at least 1.1) with status 101
-        if not raw_headers.startswith('HTTP/1.1 101'):
-            # TODO: implement HTTP authentication (401) and redirect (3xx)?
-            fail('not a valid HTTP 1.1 status 101 response')
+            while raw_headers[-4:] not in ('\r\n\r\n', '\n\n'):
+                raw_headers += self.sock.recv(512).decode('utf-8', 'ignore')
 
 
-        headers = re.findall(r'(.*?): ?(.*?)\r\n', raw_headers)
-        header_names = [name for name, value in headers]
+            # Response must be HTTP (at least 1.1) with status 101
+            match = re.search(r'^HTTP/1\.1 (\d{3})', raw_headers)
 
 
-        def header(name):
-            return ', '.join([v for n, v in headers if n == name])
+            if match is None:
+                fail('not a valid HTTP 1.1 response')
 
 
-        # Check if headers that MUST be present are actually present
-        for name in ('Upgrade', 'Connection', 'Sec-WebSocket-Accept'):
-            if name not in header_names:
-                fail('missing "%s" header' % name)
+            status = int(match.group(1))
+            headers = re.findall(r'(.*?): ?(.*?)\r\n', raw_headers)
+            header_names = [name for name, value in headers]
 
 
-        if 'websocket' not in header('Upgrade').lower():
-            fail('"Upgrade" header must contain "websocket"')
+            def header(name):
+                return ', '.join([v for n, v in headers if n == name])
 
 
-        if 'upgrade' not in header('Connection').lower():
-            fail('"Connection" header must contain "Upgrade"')
+            if status == 401:
+                # HTTP authentication is required in the request
+                raise HandshakeError('HTTP authentication required: %s'
+                                     % header('WWW-Authenticate'))
+
+            if status in (301, 302, 303, 307, 308):
+                # Handle HTTP redirect
+                url = urlparse(header('Location').strip())
+
+                # Reconnect socket if net location changed
+                if not url.port:
+                    url.port = 443 if self.secure else 80
+
+                addr = (url.netloc, url.port)
+
+                if addr != self.sock.getpeername():
+                    self.sock.close()
+                    self.sock.connect(addr)
+
+                # Send new handshake
+                receive_response(send_request(url.path))
+                return
 
 
-        # Verify accept header
-        accept = header('Sec-WebSocket-Accept').strip()
-        required_accept = b64encode(sha1(key + WS_GUID).digest())
+            if status != 101:
+                # 101 means server has accepted the connection and sent
+                # handshake headers
+                fail('invalid HTTP response status %d' % status)
 
 
-        if accept != required_accept:
-            fail('invalid websocket accept header "%s"' % accept)
+            # Check if headers that MUST be present are actually present
+            for name in ('Upgrade', 'Connection', 'Sec-WebSocket-Accept'):
+                if name not in header_names:
+                    fail('missing "%s" header' % name)
 
 
-        # Compare extensions
-        server_extensions = split_stripped(header('Sec-WebSocket-Extensions'))
+            if 'websocket' not in header('Upgrade').lower():
+                fail('"Upgrade" header must contain "websocket"')
 
 
-        for ext in server_extensions:
-            if ext not in self.extensions:
-                fail('server extension "%s" is unsupported by client' % ext)
+            if 'upgrade' not in header('Connection').lower():
+                fail('"Connection" header must contain "Upgrade"')
 
 
-        for ext in self.extensions:
-            if ext not in server_extensions:
-                fail('client extension "%s" is unsupported by server' % ext)
+            # Verify accept header
+            accept = header('Sec-WebSocket-Accept').strip()
+            required_accept = b64encode(sha1(key + WS_GUID).digest())
 
 
-        # Assert that returned protocol is supported
-        protocol = header('Sec-WebSocket-Protocol')
+            if accept != required_accept:
+                fail('invalid websocket accept header "%s"' % accept)
 
 
-        if protocol:
-            if protocol != 'null' and protocol not in self.protocols:
-                fail('unsupported protocol "%s"' % protocol)
+            # Compare extensions
+            server_ext = split_stripped(header('Sec-WebSocket-Extensions'))
+
+            for e in server_ext:
+                if e not in self.extensions:
+                    fail('server extension "%s" is unsupported by client' % e)
+
+            for e in self.extensions:
+                if e not in server_ext:
+                    fail('client extension "%s" is unsupported by server' % e)
+
+            # Assert that returned protocol (if any) is supported
+            protocol = header('Sec-WebSocket-Protocol')
+
+            if protocol:
+                if protocol != 'null' and protocol not in self.protocols:
+                    fail('unsupported protocol "%s"' % protocol)
+
+                self.protocol = protocol
+
+        self.handshake_started = True
+        receive_response(send_request(location))
 
 
-            self.protocol = protocol
 
 
     def enable_ssl(self, *args, **kwargs):
     def enable_ssl(self, *args, **kwargs):
         """
         """