|
@@ -4,6 +4,7 @@ import socket
|
|
|
import ssl
|
|
import ssl
|
|
|
from hashlib import sha1
|
|
from hashlib import sha1
|
|
|
from base64 import b64encode
|
|
from base64 import b64encode
|
|
|
|
|
+from urlparse import urlparse
|
|
|
|
|
|
|
|
from frame import receive_frame
|
|
from frame import receive_frame
|
|
|
from errors import HandshakeError, SSLError
|
|
from errors import HandshakeError, SSLError
|
|
@@ -79,7 +80,7 @@ class websocket(object):
|
|
|
wsock.server_handshake()
|
|
wsock.server_handshake()
|
|
|
return wsock, address
|
|
return wsock, address
|
|
|
|
|
|
|
|
- def connect(self, address, path='/'):
|
|
|
|
|
|
|
+ def connect(self, address, path='/', auth=None):
|
|
|
"""
|
|
"""
|
|
|
Equivalent to socket.connect(), but sends an client handshake request
|
|
Equivalent to socket.connect(), but sends an client handshake request
|
|
|
after connecting.
|
|
after connecting.
|
|
@@ -88,9 +89,12 @@ class websocket(object):
|
|
|
|
|
|
|
|
`path` is optional, used as the *location* part of the HTTP handshake.
|
|
`path` is optional, used as the *location* part of the HTTP handshake.
|
|
|
In a URL, this would show as ws://host[:port]/path.
|
|
In a URL, this would show as ws://host[:port]/path.
|
|
|
|
|
+
|
|
|
|
|
+ `auth` is optional, used for the HTTP "Authorization" header of the
|
|
|
|
|
+ handshake request.
|
|
|
"""
|
|
"""
|
|
|
self.sock.connect(address)
|
|
self.sock.connect(address)
|
|
|
- self.client_handshake(address, path)
|
|
|
|
|
|
|
+ self.client_handshake(address, path, auth)
|
|
|
|
|
|
|
|
def send(self, *args):
|
|
def send(self, *args):
|
|
|
"""
|
|
"""
|
|
@@ -215,7 +219,7 @@ class websocket(object):
|
|
|
self.sock.sendall(shake + '\r\n')
|
|
self.sock.sendall(shake + '\r\n')
|
|
|
self.handshake_started = True
|
|
self.handshake_started = True
|
|
|
|
|
|
|
|
- def client_handshake(self, address, location):
|
|
|
|
|
|
|
+ def client_handshake(self, address, location, auth):
|
|
|
"""
|
|
"""
|
|
|
Executes a handshake as the client end point of the socket. May raise a
|
|
Executes a handshake as the client end point of the socket. May raise a
|
|
|
HandshakeError if the server response is invalid.
|
|
HandshakeError if the server response is invalid.
|
|
@@ -224,90 +228,129 @@ class websocket(object):
|
|
|
self.sock.close()
|
|
self.sock.close()
|
|
|
raise HandshakeError(msg)
|
|
raise HandshakeError(msg)
|
|
|
|
|
|
|
|
- if len(location) == 0:
|
|
|
|
|
- fail('request location is empty')
|
|
|
|
|
|
|
+ def send_request(location):
|
|
|
|
|
+ if len(location) == 0:
|
|
|
|
|
+ fail('request location is empty')
|
|
|
|
|
|
|
|
- # Generate a 16-byte random base64-encoded key for this connection
|
|
|
|
|
- key = b64encode(os.urandom(16))
|
|
|
|
|
|
|
+ # Generate a 16-byte random base64-encoded key for this connection
|
|
|
|
|
+ key = b64encode(os.urandom(16))
|
|
|
|
|
|
|
|
- # Send client handshake
|
|
|
|
|
- shake = 'GET %s HTTP/1.1\r\n' % location
|
|
|
|
|
- shake += 'Host: %s:%d\r\n' % address
|
|
|
|
|
- shake += 'Upgrade: websocket\r\n'
|
|
|
|
|
- shake += 'Connection: keep-alive, Upgrade\r\n'
|
|
|
|
|
- shake += 'Sec-WebSocket-Key: %s\r\n' % key
|
|
|
|
|
- shake += 'Origin: null\r\n' # FIXME: is this correct/necessary?
|
|
|
|
|
- shake += 'Sec-WebSocket-Version: %s\r\n' % WS_VERSION
|
|
|
|
|
|
|
+ # Send client handshake
|
|
|
|
|
+ shake = 'GET %s HTTP/1.1\r\n' % location
|
|
|
|
|
+ shake += 'Host: %s:%d\r\n' % address
|
|
|
|
|
+ shake += 'Upgrade: websocket\r\n'
|
|
|
|
|
+ shake += 'Connection: keep-alive, Upgrade\r\n'
|
|
|
|
|
+ shake += 'Sec-WebSocket-Key: %s\r\n' % key
|
|
|
|
|
+ shake += 'Origin: null\r\n' # FIXME: is this correct/necessary?
|
|
|
|
|
+ shake += 'Sec-WebSocket-Version: %s\r\n' % WS_VERSION
|
|
|
|
|
|
|
|
- # These are for eagerly caching webservers
|
|
|
|
|
- shake += 'Pragma: no-cache\r\n'
|
|
|
|
|
- shake += 'Cache-Control: no-cache\r\n'
|
|
|
|
|
|
|
+ # These are for eagerly caching webservers
|
|
|
|
|
+ shake += 'Pragma: no-cache\r\n'
|
|
|
|
|
+ shake += 'Cache-Control: no-cache\r\n'
|
|
|
|
|
|
|
|
- # Request protocols and extension, these are later checked with the
|
|
|
|
|
- # actual supported values from the server's response
|
|
|
|
|
- if self.protocols:
|
|
|
|
|
- shake += 'Sec-WebSocket-Protocol: %s\r\n' % ', '.join(self.protocols)
|
|
|
|
|
|
|
+ # Request protocols and extension, these are later checked with the
|
|
|
|
|
+ # actual supported values from the server's response
|
|
|
|
|
+ if self.protocols:
|
|
|
|
|
+ shake += 'Sec-WebSocket-Protocol: %s\r\n' % ', '.join(self.protocols)
|
|
|
|
|
|
|
|
- if self.extensions:
|
|
|
|
|
- shake += 'Sec-WebSocket-Extensions: %s\r\n' % ', '.join(self.extensions)
|
|
|
|
|
|
|
+ if self.extensions:
|
|
|
|
|
+ shake += 'Sec-WebSocket-Extensions: %s\r\n' % ', '.join(self.extensions)
|
|
|
|
|
|
|
|
- self.sock.sendall(shake + '\r\n')
|
|
|
|
|
- self.handshake_started = True
|
|
|
|
|
|
|
+ if auth:
|
|
|
|
|
+ shake += 'Authorization: %s\r\n' % auth
|
|
|
|
|
|
|
|
- # Receive and process server handshake
|
|
|
|
|
- raw_headers = ''
|
|
|
|
|
|
|
+ self.sock.sendall(shake + '\r\n')
|
|
|
|
|
+ return key
|
|
|
|
|
|
|
|
- while raw_headers[-4:] not in ('\r\n\r\n', '\n\n'):
|
|
|
|
|
- raw_headers += self.sock.recv(512).decode('utf-8', 'ignore')
|
|
|
|
|
|
|
+ def receive_response(key):
|
|
|
|
|
+ # Receive and process server handshake
|
|
|
|
|
+ raw_headers = ''
|
|
|
|
|
|
|
|
- # Response must be HTTP (at least 1.1) with status 101
|
|
|
|
|
- if not raw_headers.startswith('HTTP/1.1 101'):
|
|
|
|
|
- # TODO: implement HTTP authentication (401) and redirect (3xx)?
|
|
|
|
|
- fail('not a valid HTTP 1.1 status 101 response')
|
|
|
|
|
|
|
+ while raw_headers[-4:] not in ('\r\n\r\n', '\n\n'):
|
|
|
|
|
+ raw_headers += self.sock.recv(512).decode('utf-8', 'ignore')
|
|
|
|
|
|
|
|
- headers = re.findall(r'(.*?): ?(.*?)\r\n', raw_headers)
|
|
|
|
|
- header_names = [name for name, value in headers]
|
|
|
|
|
|
|
+ # Response must be HTTP (at least 1.1) with status 101
|
|
|
|
|
+ match = re.search(r'^HTTP/1\.1 (\d{3})', raw_headers)
|
|
|
|
|
|
|
|
- def header(name):
|
|
|
|
|
- return ', '.join([v for n, v in headers if n == name])
|
|
|
|
|
|
|
+ if match is None:
|
|
|
|
|
+ fail('not a valid HTTP 1.1 response')
|
|
|
|
|
|
|
|
- # Check if headers that MUST be present are actually present
|
|
|
|
|
- for name in ('Upgrade', 'Connection', 'Sec-WebSocket-Accept'):
|
|
|
|
|
- if name not in header_names:
|
|
|
|
|
- fail('missing "%s" header' % name)
|
|
|
|
|
|
|
+ status = int(match.group(1))
|
|
|
|
|
+ headers = re.findall(r'(.*?): ?(.*?)\r\n', raw_headers)
|
|
|
|
|
+ header_names = [name for name, value in headers]
|
|
|
|
|
|
|
|
- if 'websocket' not in header('Upgrade').lower():
|
|
|
|
|
- fail('"Upgrade" header must contain "websocket"')
|
|
|
|
|
|
|
+ def header(name):
|
|
|
|
|
+ return ', '.join([v for n, v in headers if n == name])
|
|
|
|
|
|
|
|
- if 'upgrade' not in header('Connection').lower():
|
|
|
|
|
- fail('"Connection" header must contain "Upgrade"')
|
|
|
|
|
|
|
+ if status == 401:
|
|
|
|
|
+ # HTTP authentication is required in the request
|
|
|
|
|
+ raise HandshakeError('HTTP authentication required: %s'
|
|
|
|
|
+ % header('WWW-Authenticate'))
|
|
|
|
|
+
|
|
|
|
|
+ if status in (301, 302, 303, 307, 308):
|
|
|
|
|
+ # Handle HTTP redirect
|
|
|
|
|
+ url = urlparse(header('Location').strip())
|
|
|
|
|
+
|
|
|
|
|
+ # Reconnect socket if net location changed
|
|
|
|
|
+ if not url.port:
|
|
|
|
|
+ url.port = 443 if self.secure else 80
|
|
|
|
|
+
|
|
|
|
|
+ addr = (url.netloc, url.port)
|
|
|
|
|
+
|
|
|
|
|
+ if addr != self.sock.getpeername():
|
|
|
|
|
+ self.sock.close()
|
|
|
|
|
+ self.sock.connect(addr)
|
|
|
|
|
+
|
|
|
|
|
+ # Send new handshake
|
|
|
|
|
+ receive_response(send_request(url.path))
|
|
|
|
|
+ return
|
|
|
|
|
|
|
|
- # Verify accept header
|
|
|
|
|
- accept = header('Sec-WebSocket-Accept').strip()
|
|
|
|
|
- required_accept = b64encode(sha1(key + WS_GUID).digest())
|
|
|
|
|
|
|
+ if status != 101:
|
|
|
|
|
+ # 101 means server has accepted the connection and sent
|
|
|
|
|
+ # handshake headers
|
|
|
|
|
+ fail('invalid HTTP response status %d' % status)
|
|
|
|
|
|
|
|
- if accept != required_accept:
|
|
|
|
|
- fail('invalid websocket accept header "%s"' % accept)
|
|
|
|
|
|
|
+ # Check if headers that MUST be present are actually present
|
|
|
|
|
+ for name in ('Upgrade', 'Connection', 'Sec-WebSocket-Accept'):
|
|
|
|
|
+ if name not in header_names:
|
|
|
|
|
+ fail('missing "%s" header' % name)
|
|
|
|
|
|
|
|
- # Compare extensions
|
|
|
|
|
- server_extensions = split_stripped(header('Sec-WebSocket-Extensions'))
|
|
|
|
|
|
|
+ if 'websocket' not in header('Upgrade').lower():
|
|
|
|
|
+ fail('"Upgrade" header must contain "websocket"')
|
|
|
|
|
|
|
|
- for ext in server_extensions:
|
|
|
|
|
- if ext not in self.extensions:
|
|
|
|
|
- fail('server extension "%s" is unsupported by client' % ext)
|
|
|
|
|
|
|
+ if 'upgrade' not in header('Connection').lower():
|
|
|
|
|
+ fail('"Connection" header must contain "Upgrade"')
|
|
|
|
|
|
|
|
- for ext in self.extensions:
|
|
|
|
|
- if ext not in server_extensions:
|
|
|
|
|
- fail('client extension "%s" is unsupported by server' % ext)
|
|
|
|
|
|
|
+ # Verify accept header
|
|
|
|
|
+ accept = header('Sec-WebSocket-Accept').strip()
|
|
|
|
|
+ required_accept = b64encode(sha1(key + WS_GUID).digest())
|
|
|
|
|
|
|
|
- # Assert that returned protocol is supported
|
|
|
|
|
- protocol = header('Sec-WebSocket-Protocol')
|
|
|
|
|
|
|
+ if accept != required_accept:
|
|
|
|
|
+ fail('invalid websocket accept header "%s"' % accept)
|
|
|
|
|
|
|
|
- if protocol:
|
|
|
|
|
- if protocol != 'null' and protocol not in self.protocols:
|
|
|
|
|
- fail('unsupported protocol "%s"' % protocol)
|
|
|
|
|
|
|
+ # Compare extensions
|
|
|
|
|
+ server_ext = split_stripped(header('Sec-WebSocket-Extensions'))
|
|
|
|
|
+
|
|
|
|
|
+ for e in server_ext:
|
|
|
|
|
+ if e not in self.extensions:
|
|
|
|
|
+ fail('server extension "%s" is unsupported by client' % e)
|
|
|
|
|
+
|
|
|
|
|
+ for e in self.extensions:
|
|
|
|
|
+ if e not in server_ext:
|
|
|
|
|
+ fail('client extension "%s" is unsupported by server' % e)
|
|
|
|
|
+
|
|
|
|
|
+ # Assert that returned protocol (if any) is supported
|
|
|
|
|
+ protocol = header('Sec-WebSocket-Protocol')
|
|
|
|
|
+
|
|
|
|
|
+ if protocol:
|
|
|
|
|
+ if protocol != 'null' and protocol not in self.protocols:
|
|
|
|
|
+ fail('unsupported protocol "%s"' % protocol)
|
|
|
|
|
+
|
|
|
|
|
+ self.protocol = protocol
|
|
|
|
|
+
|
|
|
|
|
+ self.handshake_started = True
|
|
|
|
|
+ receive_response(send_request(location))
|
|
|
|
|
|
|
|
- self.protocol = protocol
|
|
|
|
|
|
|
|
|
|
def enable_ssl(self, *args, **kwargs):
|
|
def enable_ssl(self, *args, **kwargs):
|
|
|
"""
|
|
"""
|