Commit ca2b552a authored by Taddeüs Kroes's avatar Taddeüs Kroes

websocket class now supports a list of supported extensions (framing part is not implemented yet)

parent c7572ccb
...@@ -10,6 +10,10 @@ WS_GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11' ...@@ -10,6 +10,10 @@ WS_GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
WS_VERSION = '13' WS_VERSION = '13'
def split_stripped(value, delim=','):
return map(str.strip, value.split(delim))
class websocket(object): class websocket(object):
""" """
Implementation of web socket, upgrades a regular TCP socket to a websocket Implementation of web socket, upgrades a regular TCP socket to a websocket
...@@ -28,13 +32,20 @@ class websocket(object): ...@@ -28,13 +32,20 @@ class websocket(object):
>>> sock = websocket() >>> sock = websocket()
>>> sock.connect(('kompiler.org', 80)) >>> sock.connect(('kompiler.org', 80))
""" """
def __init__(self, wsprotocols=[], family=socket.AF_INET, proto=0): def __init__(self, protocols=[], extensions=[], family=socket.AF_INET,
proto=0):
""" """
Create aregular TCP socket of family `family` and protocol Create a regular TCP socket of family `family` and protocol
`wsprotocols` is a list of supported protocol names.
`protocols` is a list of supported protocol names.
`extensions` is a list of supported extensions.
`family` and `proto` are used for the regular socket constructor.
""" """
self.protocols = protocols
self.extensions = extensions
self.sock = socket.socket(family, socket.SOCK_STREAM, proto) self.sock = socket.socket(family, socket.SOCK_STREAM, proto)
self.protocols = wsprotocols
def bind(self, address): def bind(self, address):
self.sock.bind(address) self.sock.bind(address)
...@@ -98,48 +109,53 @@ class websocket(object): ...@@ -98,48 +109,53 @@ class websocket(object):
# request must be HTTP (at least 1.1) GET request, find the location # request must be HTTP (at least 1.1) GET request, find the location
location = re.search(r'^GET (.*) HTTP/1.1\r\n', raw_headers).group(1) location = re.search(r'^GET (.*) HTTP/1.1\r\n', raw_headers).group(1)
headers = dict(re.findall(r'(.*?): (.*?)\r\n', raw_headers)) headers = re.findall(r'(.*?): (.*?)\r\n', raw_headers)
header_names = [name for name, value in headers]
def header(name):
return ', '.join([v for n, v in headers if n == name])
# Check if headers that MUST be present are actually present # Check if headers that MUST be present are actually present
for name in ('Host', 'Upgrade', 'Connection', 'Sec-WebSocket-Key', for name in ('Host', 'Upgrade', 'Connection', 'Sec-WebSocket-Key',
'Origin', 'Sec-WebSocket-Version'): 'Origin', 'Sec-WebSocket-Version'):
if name not in headers: if name not in header_names:
raise InvalidRequest('missing "%s" header' % name) raise InvalidRequest('missing "%s" header' % name)
# Check WebSocket version used by client # Check WebSocket version used by client
version = headers['Sec-WebSocket-Version'] version = header('Sec-WebSocket-Version')
if version != WS_VERSION: if version != WS_VERSION:
raise InvalidRequest('WebSocket version %s requested (only %s ' raise InvalidRequest('WebSocket version %s requested (only %s '
'is supported)' % (version, WS_VERSION)) 'is supported)' % (version, WS_VERSION))
# Make sure the requested protocols are supported by this server # Only supported protocols are returned
if 'Sec-WebSocket-Protocol' in headers: proto = header('Sec-WebSocket-Extensions')
parts = headers['Sec-WebSocket-Protocol'].split(',') protocols = split_stripped(proto) if proto else []
protocols = map(str.strip, parts) protocols = [p for p in protocols if p in self.protocols]
for p in protocols: # Only supported extensions are returned
if p not in self.protocols: ext = header('Sec-WebSocket-Extensions')
raise InvalidRequest('unsupported protocol "%s"' % p) extensions = split_stripped(ext) if ext else []
else: extensions = [e for e in extensions if e in self.extensions]
protocols = []
# Encode acceptation key using the WebSocket GUID # Encode acceptation key using the WebSocket GUID
key = headers['Sec-WebSocket-Key'] key = header('Sec-WebSocket-Key')
accept = sha1(key + WS_GUID).digest().encode('base64') accept = sha1(key + WS_GUID).digest().encode('base64')
# Construct HTTP response header # Construct HTTP response header
shake = 'HTTP/1.1 101 Web Socket Protocol Handshake\r\n' shake = 'HTTP/1.1 101 Web Socket Protocol Handshake\r\n'
shake += 'Upgrade: WebSocket\r\n' shake += 'Upgrade: WebSocket\r\n'
shake += 'Connection: Upgrade\r\n' shake += 'Connection: Upgrade\r\n'
shake += 'WebSocket-Origin: %s\r\n' % headers['Origin'] shake += 'WebSocket-Origin: %s\r\n' % header('Origin')
shake += 'WebSocket-Location: ws://%s%s\r\n' \ shake += 'WebSocket-Location: ws://%s%s\r\n' \
% (headers['Host'], location) % (header('Host'), location)
shake += 'Sec-WebSocket-Accept: %s\r\n' % accept shake += 'Sec-WebSocket-Accept: %s\r\n' % accept
if self.protocols: if protocols:
shake += 'Sec-WebSocket-Protocol: %s\r\n' \ shake += 'Sec-WebSocket-Protocol: %s\r\n' % ', '.join(protocols)
% ', '.join(self.protocols)
if extensions:
shake += 'Sec-WebSocket-Extensions: %s\r\n' % ', '.join(extensions)
self.sock.send(shake + '\r\n') self.sock.send(shake + '\r\n')
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment