|
|
@@ -1,3 +1,4 @@
|
|
|
+import os
|
|
|
import re
|
|
|
import socket
|
|
|
import ssl
|
|
|
@@ -13,7 +14,7 @@ WS_VERSION = '13'
|
|
|
|
|
|
|
|
|
def split_stripped(value, delim=','):
|
|
|
- return map(str.strip, str(value).split(delim))
|
|
|
+ return map(str.strip, str(value).split(delim)) if value else []
|
|
|
|
|
|
|
|
|
class websocket(object):
|
|
|
@@ -78,13 +79,18 @@ class websocket(object):
|
|
|
wsock.server_handshake()
|
|
|
return wsock, address
|
|
|
|
|
|
- def connect(self, address):
|
|
|
+ def connect(self, address, path='/'):
|
|
|
"""
|
|
|
Equivalent to socket.connect(), but sends an client handshake request
|
|
|
after connecting.
|
|
|
+
|
|
|
+ `address` is a (host, port) tuple of the server to connect to.
|
|
|
+
|
|
|
+ `path` is optional, used as the *location* part of the HTTP handshake.
|
|
|
+ In a URL, this would show as ws://host[:port]/path.
|
|
|
"""
|
|
|
- self.sock.sonnect(address)
|
|
|
- self.client_handshake()
|
|
|
+ self.sock.connect(address)
|
|
|
+ self.client_handshake(address, path)
|
|
|
|
|
|
def send(self, *args):
|
|
|
"""
|
|
|
@@ -131,6 +137,10 @@ class websocket(object):
|
|
|
request headers sent by the client are invalid, a HandshakeError
|
|
|
is raised.
|
|
|
"""
|
|
|
+ def fail(msg):
|
|
|
+ self.sock.close()
|
|
|
+ raise HandshakeError(msg)
|
|
|
+
|
|
|
# Receive HTTP header
|
|
|
raw_headers = ''
|
|
|
|
|
|
@@ -138,7 +148,12 @@ class websocket(object):
|
|
|
raw_headers += self.sock.recv(512).decode('utf-8', 'ignore')
|
|
|
|
|
|
# Request must be HTTP (at least 1.1) GET request, find the location
|
|
|
- location = re.search(r'^GET (.*) HTTP/1.1\r\n', raw_headers).group(1)
|
|
|
+ match = re.search(r'^GET (.*) HTTP/1.1\r\n', raw_headers)
|
|
|
+
|
|
|
+ if match is None:
|
|
|
+ fail('not a valid HTTP 1.1 GET request')
|
|
|
+
|
|
|
+ location = match.group(1)
|
|
|
headers = re.findall(r'(.*?): ?(.*?)\r\n', raw_headers)
|
|
|
header_names = [name for name, value in headers]
|
|
|
|
|
|
@@ -147,25 +162,39 @@ class websocket(object):
|
|
|
|
|
|
# Check if headers that MUST be present are actually present
|
|
|
for name in ('Host', 'Upgrade', 'Connection', 'Sec-WebSocket-Key',
|
|
|
- 'Origin', 'Sec-WebSocket-Version'):
|
|
|
+ 'Sec-WebSocket-Version'):
|
|
|
if name not in header_names:
|
|
|
- raise HandshakeError('missing "%s" header' % name)
|
|
|
+ fail('missing "%s" header' % name)
|
|
|
|
|
|
# Check WebSocket version used by client
|
|
|
version = header('Sec-WebSocket-Version')
|
|
|
|
|
|
if version != WS_VERSION:
|
|
|
- raise HandshakeError('WebSocket version %s requested (only %s '
|
|
|
+ fail('WebSocket version %s requested (only %s '
|
|
|
'is supported)' % (version, WS_VERSION))
|
|
|
|
|
|
+ # Verify required header keywords
|
|
|
+ if 'websocket' not in header('Upgrade').lower():
|
|
|
+ fail('"Upgrade" header must contain "websocket"')
|
|
|
+
|
|
|
+ if 'upgrade' not in header('Connection').lower():
|
|
|
+ fail('"Connection" header must contain "Upgrade"')
|
|
|
+
|
|
|
+ # Origin must be present if browser client
|
|
|
+ if 'User-Agent' in header_names and 'Origin' not in header_names:
|
|
|
+ fail('browser client must specify "Origin" header')
|
|
|
+
|
|
|
# Only supported protocols are returned
|
|
|
- proto = header('Sec-WebSocket-Extensions')
|
|
|
- protocols = split_stripped(proto) if proto else []
|
|
|
- protocols = [p for p in protocols if p in self.protocols]
|
|
|
+ client_protocols = split_stripped(header('Sec-WebSocket-Extensions'))
|
|
|
+ protocol = 'null'
|
|
|
+
|
|
|
+ for p in client_protocols:
|
|
|
+ if p in self.protocols:
|
|
|
+ protocol = p
|
|
|
+ break
|
|
|
|
|
|
# Only supported extensions are returned
|
|
|
- ext = header('Sec-WebSocket-Extensions')
|
|
|
- extensions = split_stripped(ext) if ext else []
|
|
|
+ extensions = split_stripped(header('Sec-WebSocket-Extensions'))
|
|
|
extensions = [e for e in extensions if e in self.extensions]
|
|
|
|
|
|
# Encode acceptation key using the WebSocket GUID
|
|
|
@@ -174,34 +203,115 @@ class websocket(object):
|
|
|
|
|
|
# Construct HTTP response header
|
|
|
shake = 'HTTP/1.1 101 Web Socket Protocol Handshake\r\n'
|
|
|
- shake += 'Upgrade: WebSocket\r\n'
|
|
|
+ shake += 'Upgrade: websocket\r\n'
|
|
|
shake += 'Connection: Upgrade\r\n'
|
|
|
shake += 'WebSocket-Origin: %s\r\n' % header('Origin')
|
|
|
shake += 'WebSocket-Location: ws://%s%s\r\n' \
|
|
|
% (header('Host'), location)
|
|
|
shake += 'Sec-WebSocket-Accept: %s\r\n' % accept
|
|
|
-
|
|
|
- if protocols:
|
|
|
- shake += 'Sec-WebSocket-Protocol: %s\r\n' % ', '.join(protocols)
|
|
|
-
|
|
|
- if extensions:
|
|
|
- shake += 'Sec-WebSocket-Extensions: %s\r\n' % ', '.join(extensions)
|
|
|
+ shake += 'Sec-WebSocket-Protocol: %s\r\n' % protocol
|
|
|
+ shake += 'Sec-WebSocket-Extensions: %s\r\n' % ', '.join(extensions)
|
|
|
|
|
|
self.sock.sendall(shake + '\r\n')
|
|
|
self.handshake_started = True
|
|
|
|
|
|
- def client_handshake(self):
|
|
|
+ def client_handshake(self, address, location):
|
|
|
"""
|
|
|
- Execute a handshake as the client end point of the socket. May raise a
|
|
|
+ Executes a handshake as the client end point of the socket. May raise a
|
|
|
HandshakeError if the server response is invalid.
|
|
|
"""
|
|
|
- # TODO: implement HTTP request headers for client handshake
|
|
|
+ def fail(msg):
|
|
|
+ self.sock.close()
|
|
|
+ raise HandshakeError(msg)
|
|
|
+
|
|
|
+ if len(location) == 0:
|
|
|
+ fail('request location is empty')
|
|
|
+
|
|
|
+ # Generate a 16-byte random base64-encoded key for this connection
|
|
|
+ key = b64encode(os.urandom(16))
|
|
|
+
|
|
|
+ # Send client handshake
|
|
|
+ shake = 'GET %s HTTP/1.1\r\n' % location
|
|
|
+ shake += 'Host: %s:%d\r\n' % address
|
|
|
+ shake += 'Upgrade: websocket\r\n'
|
|
|
+ shake += 'Connection: keep-alive, Upgrade\r\n'
|
|
|
+ shake += 'Sec-WebSocket-Key: %s\r\n' % key
|
|
|
+ shake += 'Origin: null\r\n' # FIXME: is this correct/necessary?
|
|
|
+ shake += 'Sec-WebSocket-Version: %s\r\n' % WS_VERSION
|
|
|
+
|
|
|
+ # These are for eagerly caching webservers
|
|
|
+ shake += 'Pragma: no-cache\r\n'
|
|
|
+ shake += 'Cache-Control: no-cache\r\n'
|
|
|
+
|
|
|
+ # Request protocols and extension, these are later checked with the
|
|
|
+ # actual supported values from the server's response
|
|
|
+ if self.protocols:
|
|
|
+ shake += 'Sec-WebSocket-Protocol: %s\r\n' % ', '.join(self.protocols)
|
|
|
+
|
|
|
+ if self.extensions:
|
|
|
+ shake += 'Sec-WebSocket-Extensions: %s\r\n' % ', '.join(self.extensions)
|
|
|
+
|
|
|
+ self.sock.sendall(shake + '\r\n')
|
|
|
self.handshake_started = True
|
|
|
- raise NotImplementedError
|
|
|
+
|
|
|
+ # Receive and process server handshake
|
|
|
+ raw_headers = ''
|
|
|
+
|
|
|
+ while raw_headers[-4:] not in ('\r\n\r\n', '\n\n'):
|
|
|
+ raw_headers += self.sock.recv(512).decode('utf-8', 'ignore')
|
|
|
+
|
|
|
+ # Response must be HTTP (at least 1.1) with status 101
|
|
|
+ if not raw_headers.startswith('HTTP/1.1 101'):
|
|
|
+ # TODO: implement HTTP authentication (401) and redirect (3xx)?
|
|
|
+ fail('not a valid HTTP 1.1 status 101 response')
|
|
|
+
|
|
|
+ headers = re.findall(r'(.*?): ?(.*?)\r\n', raw_headers)
|
|
|
+ header_names = [name for name, value in headers]
|
|
|
+
|
|
|
+ def header(name):
|
|
|
+ return ', '.join([v for n, v in headers if n == name])
|
|
|
+
|
|
|
+ # Check if headers that MUST be present are actually present
|
|
|
+ for name in ('Upgrade', 'Connection', 'Sec-WebSocket-Accept'):
|
|
|
+ if name not in header_names:
|
|
|
+ fail('missing "%s" header' % name)
|
|
|
+
|
|
|
+ if 'websocket' not in header('Upgrade').lower():
|
|
|
+ fail('"Upgrade" header must contain "websocket"')
|
|
|
+
|
|
|
+ if 'upgrade' not in header('Connection').lower():
|
|
|
+ fail('"Connection" header must contain "Upgrade"')
|
|
|
+
|
|
|
+ # Verify accept header
|
|
|
+ accept = header('Sec-WebSocket-Accept').strip()
|
|
|
+ required_accept = b64encode(sha1(key + WS_GUID).digest())
|
|
|
+
|
|
|
+ if accept != required_accept:
|
|
|
+ fail('invalid websocket accept header "%s"' % accept)
|
|
|
+
|
|
|
+ # Compare extensions
|
|
|
+ server_extensions = split_stripped(header('Sec-WebSocket-Extensions'))
|
|
|
+
|
|
|
+ for ext in server_extensions:
|
|
|
+ if ext not in self.extensions:
|
|
|
+ fail('server extension "%s" is unsupported by client' % ext)
|
|
|
+
|
|
|
+ for ext in self.extensions:
|
|
|
+ if ext not in server_extensions:
|
|
|
+ fail('client extension "%s" is unsupported by server' % ext)
|
|
|
+
|
|
|
+ # Assert that returned protocol is supported
|
|
|
+ protocol = header('Sec-WebSocket-Protocol')
|
|
|
+
|
|
|
+ if protocol:
|
|
|
+ if protocol != 'null' and protocol not in self.protocols:
|
|
|
+ fail('unsupported protocol "%s"' % protocol)
|
|
|
+
|
|
|
+ self.protocol = protocol
|
|
|
|
|
|
def enable_ssl(self, *args, **kwargs):
|
|
|
"""
|
|
|
- Transform the regular socket.socket to an ssl.SSLSocket for secure
|
|
|
+ Transforms the regular socket.socket to an ssl.SSLSocket for secure
|
|
|
connections. Any arguments are passed to ssl.wrap_socket:
|
|
|
http://docs.python.org/dev/library/ssl.html#ssl.wrap_socket
|
|
|
"""
|