|
|
@@ -10,6 +10,10 @@ WS_GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
|
|
|
WS_VERSION = '13'
|
|
|
|
|
|
|
|
|
+def split_stripped(value, delim=','):
|
|
|
+ return map(str.strip, value.split(delim))
|
|
|
+
|
|
|
+
|
|
|
class websocket(object):
|
|
|
"""
|
|
|
Implementation of web socket, upgrades a regular TCP socket to a websocket
|
|
|
@@ -28,13 +32,20 @@ class websocket(object):
|
|
|
>>> sock = websocket()
|
|
|
>>> sock.connect(('kompiler.org', 80))
|
|
|
"""
|
|
|
- def __init__(self, wsprotocols=[], family=socket.AF_INET, proto=0):
|
|
|
+ def __init__(self, protocols=[], extensions=[], family=socket.AF_INET,
|
|
|
+ proto=0):
|
|
|
"""
|
|
|
- Create aregular TCP socket of family `family` and protocol
|
|
|
- `wsprotocols` is a list of supported protocol names.
|
|
|
+ Create a regular TCP socket of family `family` and protocol
|
|
|
+
|
|
|
+ `protocols` is a list of supported protocol names.
|
|
|
+
|
|
|
+ `extensions` is a list of supported extensions.
|
|
|
+
|
|
|
+ `family` and `proto` are used for the regular socket constructor.
|
|
|
"""
|
|
|
+ self.protocols = protocols
|
|
|
+ self.extensions = extensions
|
|
|
self.sock = socket.socket(family, socket.SOCK_STREAM, proto)
|
|
|
- self.protocols = wsprotocols
|
|
|
|
|
|
def bind(self, address):
|
|
|
self.sock.bind(address)
|
|
|
@@ -98,48 +109,53 @@ class websocket(object):
|
|
|
|
|
|
# request must be HTTP (at least 1.1) GET request, find the location
|
|
|
location = re.search(r'^GET (.*) HTTP/1.1\r\n', raw_headers).group(1)
|
|
|
- headers = dict(re.findall(r'(.*?): (.*?)\r\n', raw_headers))
|
|
|
+ headers = re.findall(r'(.*?): (.*?)\r\n', raw_headers)
|
|
|
+ header_names = [name for name, value in headers]
|
|
|
+
|
|
|
+ def header(name):
|
|
|
+ return ', '.join([v for n, v in headers if n == name])
|
|
|
|
|
|
# Check if headers that MUST be present are actually present
|
|
|
for name in ('Host', 'Upgrade', 'Connection', 'Sec-WebSocket-Key',
|
|
|
'Origin', 'Sec-WebSocket-Version'):
|
|
|
- if name not in headers:
|
|
|
+ if name not in header_names:
|
|
|
raise InvalidRequest('missing "%s" header' % name)
|
|
|
|
|
|
# Check WebSocket version used by client
|
|
|
- version = headers['Sec-WebSocket-Version']
|
|
|
+ version = header('Sec-WebSocket-Version')
|
|
|
|
|
|
if version != WS_VERSION:
|
|
|
raise InvalidRequest('WebSocket version %s requested (only %s '
|
|
|
'is supported)' % (version, WS_VERSION))
|
|
|
|
|
|
- # Make sure the requested protocols are supported by this server
|
|
|
- if 'Sec-WebSocket-Protocol' in headers:
|
|
|
- parts = headers['Sec-WebSocket-Protocol'].split(',')
|
|
|
- protocols = map(str.strip, parts)
|
|
|
+ # Only supported protocols are returned
|
|
|
+ proto = header('Sec-WebSocket-Extensions')
|
|
|
+ protocols = split_stripped(proto) if proto else []
|
|
|
+ protocols = [p for p in protocols if p in self.protocols]
|
|
|
|
|
|
- for p in protocols:
|
|
|
- if p not in self.protocols:
|
|
|
- raise InvalidRequest('unsupported protocol "%s"' % p)
|
|
|
- else:
|
|
|
- protocols = []
|
|
|
+ # Only supported extensions are returned
|
|
|
+ ext = header('Sec-WebSocket-Extensions')
|
|
|
+ extensions = split_stripped(ext) if ext else []
|
|
|
+ extensions = [e for e in extensions if e in self.extensions]
|
|
|
|
|
|
# Encode acceptation key using the WebSocket GUID
|
|
|
- key = headers['Sec-WebSocket-Key']
|
|
|
+ key = header('Sec-WebSocket-Key')
|
|
|
accept = sha1(key + WS_GUID).digest().encode('base64')
|
|
|
|
|
|
# Construct HTTP response header
|
|
|
shake = 'HTTP/1.1 101 Web Socket Protocol Handshake\r\n'
|
|
|
shake += 'Upgrade: WebSocket\r\n'
|
|
|
shake += 'Connection: Upgrade\r\n'
|
|
|
- shake += 'WebSocket-Origin: %s\r\n' % headers['Origin']
|
|
|
+ shake += 'WebSocket-Origin: %s\r\n' % header('Origin')
|
|
|
shake += 'WebSocket-Location: ws://%s%s\r\n' \
|
|
|
- % (headers['Host'], location)
|
|
|
+ % (header('Host'), location)
|
|
|
shake += 'Sec-WebSocket-Accept: %s\r\n' % accept
|
|
|
|
|
|
- if self.protocols:
|
|
|
- shake += 'Sec-WebSocket-Protocol: %s\r\n' \
|
|
|
- % ', '.join(self.protocols)
|
|
|
+ if protocols:
|
|
|
+ shake += 'Sec-WebSocket-Protocol: %s\r\n' % ', '.join(protocols)
|
|
|
+
|
|
|
+ if extensions:
|
|
|
+ shake += 'Sec-WebSocket-Extensions: %s\r\n' % ', '.join(extensions)
|
|
|
|
|
|
self.sock.send(shake + '\r\n')
|
|
|
|