Explorar el Código

websocket class now supports a list of supported extensions (framing part is not implemented yet)

Taddeus Kroes hace 13 años
padre
commit
ca2b552ac3
Se han modificado 1 ficheros con 38 adiciones y 22 borrados
  1. 38 22
      websocket.py

+ 38 - 22
websocket.py

@@ -10,6 +10,10 @@ WS_GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
 WS_VERSION = '13'
 
 
+def split_stripped(value, delim=','):
+    return map(str.strip, value.split(delim))
+
+
 class websocket(object):
     """
     Implementation of web socket, upgrades a regular TCP socket to a websocket
@@ -28,13 +32,20 @@ class websocket(object):
     >>> sock = websocket()
     >>> sock.connect(('kompiler.org', 80))
     """
-    def __init__(self, wsprotocols=[], family=socket.AF_INET, proto=0):
+    def __init__(self, protocols=[], extensions=[], family=socket.AF_INET,
+            proto=0):
         """
-        Create aregular TCP socket of family `family` and protocol
-        `wsprotocols` is a list of supported protocol names.
+        Create a regular TCP socket of family `family` and protocol
+
+        `protocols` is a list of supported protocol names.
+
+        `extensions` is a list of supported extensions.
+
+        `family` and `proto` are used for the regular socket constructor.
         """
+        self.protocols = protocols
+        self.extensions = extensions
         self.sock = socket.socket(family, socket.SOCK_STREAM, proto)
-        self.protocols = wsprotocols
 
     def bind(self, address):
         self.sock.bind(address)
@@ -98,48 +109,53 @@ class websocket(object):
 
         # request must be HTTP (at least 1.1) GET request, find the location
         location = re.search(r'^GET (.*) HTTP/1.1\r\n', raw_headers).group(1)
-        headers = dict(re.findall(r'(.*?): (.*?)\r\n', raw_headers))
+        headers = re.findall(r'(.*?): (.*?)\r\n', raw_headers)
+        header_names = [name for name, value in headers]
+
+        def header(name):
+            return ', '.join([v for n, v in headers if n == name])
 
         # Check if headers that MUST be present are actually present
         for name in ('Host', 'Upgrade', 'Connection', 'Sec-WebSocket-Key',
                      'Origin', 'Sec-WebSocket-Version'):
-            if name not in headers:
+            if name not in header_names:
                 raise InvalidRequest('missing "%s" header' % name)
 
         # Check WebSocket version used by client
-        version = headers['Sec-WebSocket-Version']
+        version = header('Sec-WebSocket-Version')
 
         if version != WS_VERSION:
             raise InvalidRequest('WebSocket version %s requested (only %s '
                                  'is supported)' % (version, WS_VERSION))
 
-        # Make sure the requested protocols are supported by this server
-        if 'Sec-WebSocket-Protocol' in headers:
-            parts = headers['Sec-WebSocket-Protocol'].split(',')
-            protocols = map(str.strip, parts)
+        # Only supported protocols are returned
+        proto = header('Sec-WebSocket-Extensions')
+        protocols = split_stripped(proto) if proto else []
+        protocols = [p for p in protocols if p in self.protocols]
 
-            for p in protocols:
-                if p not in self.protocols:
-                    raise InvalidRequest('unsupported protocol "%s"' % p)
-        else:
-            protocols = []
+        # Only supported extensions are returned
+        ext = header('Sec-WebSocket-Extensions')
+        extensions = split_stripped(ext) if ext else []
+        extensions = [e for e in extensions if e in self.extensions]
 
         # Encode acceptation key using the WebSocket GUID
-        key = headers['Sec-WebSocket-Key']
+        key = header('Sec-WebSocket-Key')
         accept = sha1(key + WS_GUID).digest().encode('base64')
 
         # Construct HTTP response header
         shake = 'HTTP/1.1 101 Web Socket Protocol Handshake\r\n'
         shake += 'Upgrade: WebSocket\r\n'
         shake += 'Connection: Upgrade\r\n'
-        shake += 'WebSocket-Origin: %s\r\n' % headers['Origin']
+        shake += 'WebSocket-Origin: %s\r\n' % header('Origin')
         shake += 'WebSocket-Location: ws://%s%s\r\n' \
-                 % (headers['Host'], location)
+                 % (header('Host'), location)
         shake += 'Sec-WebSocket-Accept: %s\r\n' % accept
 
-        if self.protocols:
-            shake += 'Sec-WebSocket-Protocol: %s\r\n' \
-                     % ', '.join(self.protocols)
+        if protocols:
+            shake += 'Sec-WebSocket-Protocol: %s\r\n' % ', '.join(protocols)
+
+        if extensions:
+            shake += 'Sec-WebSocket-Extensions: %s\r\n' % ', '.join(extensions)
 
         self.sock.send(shake + '\r\n')