Ver Fonte

Implemented extensions in handshakes and installed extension hooks in websocket

Taddeus Kroes há 12 anos atrás
pai
commit
7d246d5dba
4 ficheiros alterados com 98 adições e 47 exclusões
  1. 2 1
      TODO
  2. 17 24
      extension.py
  3. 72 22
      handshake.py
  4. 7 0
      websocket.py

+ 2 - 1
TODO

@@ -1,4 +1,5 @@
 - (Unit) tests
 - Mutual exclusion in Server/Client (multiple threads sending stuff at the same
   time will go wrong)
-- Extensions
+- Extensions: pass parameters for Extension.header_params() to websocket
+  constructor

+ 17 - 24
extension.py

@@ -2,6 +2,7 @@ from errors import HandshakeError
 
 
 class Extension(object):
+    name = ''
     rsv1 = False
     rsv2 = False
     rsv3 = False
@@ -21,7 +22,16 @@ class Extension(object):
 
             setattr(self, param, value)
 
-    def client_params(self, frame):
+    def __str__(self, frame):
+        if len(self.parameters):
+            params = ' ' + ', '.join(p + '=' + str(getattr(self, p))
+                                     for p in self.parameters)
+        else:
+            params = ''
+
+        return '<Extension "%s"%s>' % (self.name, params)
+
+    def header_params(self, frame):
         return {}
 
     def hook_send(self, frame):
@@ -38,7 +48,10 @@ class DeflateFrame(Extension):
 
     def __init__(self, **kwargs):
         super(DeflateFrame, self).__init__(**kwargs)
-        self.max_window_bits = int(self.max_window_bits)
+
+        if self.max_window_bits is None:
+            # FIXME: is this correct? None may actually be a better value
+            self.max_window_bits = 0
 
     def hook_send(self, frame):
         # FIXME: original `frame` is modified, maybe it should be copied?
@@ -58,7 +71,7 @@ class DeflateFrame(Extension):
 
         return frame
 
-    def client_params(self):
+    def header_params(self):
         raise NotImplementedError  # TODO
 
     def encode(self, data):
@@ -71,7 +84,7 @@ class DeflateFrame(Extension):
 def filter_extensions(extensions):
     """
     Remove extensions that use conflicting rsv bits and/or opcodes, with the
-    first options being most preferable.
+    first options being the most preferable.
     """
     rsv1_reserved = True
     rsv2_reserved = True
@@ -93,23 +106,3 @@ def filter_extensions(extensions):
         compat.append(ext)
 
     return compat
-
-
-"""
-Class map used to find contructors for client-specified extensions. Not to be
-modified manually, only through `register_extension`.
-"""
-extension_class_map = {}
-
-
-def register_extension(ext):
-    if not isinstance(ext, Extension):
-        raise ValueError('extensions should extend the `Extension` class')
-
-    if ext.name in extension_clas_map:
-        raise KeyError('extension "%s" has already been registered' % ext.name)
-
-    extension_class_map[ext.name] = ext
-
-
-register_extension(DeflateFrame)

+ 72 - 22
handshake.py

@@ -6,6 +6,7 @@ from urlparse import urlparse
 
 from python_digest import build_authorization_request
 from errors import HandshakeError
+from extension import filter_extensions
 
 
 WS_GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
@@ -13,10 +14,6 @@ WS_VERSION = '13'
 MAX_REDIRECTS = 10
 
 
-def split_stripped(value, delim=','):
-    return map(str.strip, str(value).split(delim)) if value else []
-
-
 class Handshake(object):
     def __init__(self, wsock):
         self.wsock = wsock
@@ -134,20 +131,28 @@ class ServerHandshake(Handshake):
         # Only a supported protocol can be returned
         client_proto = split_stripped(headers['Sec-WebSocket-Protocol']) \
                        if 'Sec-WebSocket-Protocol' in headers else []
-        protocol = None
+        self.wsock.protocol = None
 
         for p in client_proto:
-            if p in self.wsock.proto:
-                protocol = p
+            if p in self.wsock.protocols:
+                self.wsock.protocol = p
                 break
 
         # Only supported extensions are returned
         if 'Sec-WebSocket-Extensions' in headers:
-            client_ext = split_stripped(headers['Sec-WebSocket-Extensions'])
-            extensions = [e for e in client_ext if e in self.wsock.extensions]
-        else:
+            supported_ext = dict((e.name, e) for e in self.wsock.extensions)
             extensions = []
 
+            for ext in split_stripped(headers['Sec-WebSocket-Extensions']):
+                name, params = parse_param_hdr(ext)
+
+                if name in supported_ext:
+                    extensions.append(supported_ext[name](**params))
+
+            self.wsock.extensions = filter_extensions(extensions)
+        else:
+            self.wsock.extensions = []
+
         # Encode acceptation key using the WebSocket GUID
         key = headers['Sec-WebSocket-Key'].strip()
         accept = b64encode(sha1(key + WS_GUID).digest())
@@ -174,12 +179,13 @@ class ServerHandshake(Handshake):
         yield 'WebSocket-Location', location
         yield 'Sec-WebSocket-Accept', accept
 
-        if protocol:
-            yield 'Sec-WebSocket-Protocol', protocol
-
-        if extensions:
-            yield 'Sec-WebSocket-Extensions', ', '.join(extensions)
+        if self.wsock.protocol:
+            yield 'Sec-WebSocket-Protocol', self.wsock.protocol
 
+        if self.wsock.extensions:
+            values = [format_param_hdr(e.name, e.header_params())
+                      for e in self.wsock.extensions]
+            yield 'Sec-WebSocket-Extensions', ', '.join(values)
 
 class ClientHandshake(Handshake):
     """
@@ -226,13 +232,17 @@ class ClientHandshake(Handshake):
 
         # Compare extensions
         if 'Sec-WebSocket-Extensions' in headers:
-            server_ext = split_stripped(headers['Sec-WebSocket-Extensions'])
+            supported_ext = dict((e.name, e) for e in self.wsock.extensions)
+            self.wsock.extensions = []
 
-            for e in set(server_ext) - set(self.wsock.extensions):
-                self.fail('server extension "%s" unsupported by client' % e)
+            for ext in split_stripped(headers['Sec-WebSocket-Extensions']):
+                name, params = parse_param_hdr(ext)
 
-            for e in set(self.wsock.extensions) - set(server_ext):
-                self.fail('client extension "%s" unsupported by server' % e)
+                if name not in supported_ext:
+                    raise HandshakeError('server handshake contains '
+                                         'unsupported extension "%s"' % name)
+
+                self.wsock.extensions.append(supported_ext[name](**params))
 
         # Assert that returned protocol (if any) is supported
         if 'Sec-WebSocket-Protocol' in headers:
@@ -309,13 +319,15 @@ class ClientHandshake(Handshake):
         yield 'Pragma', 'no-cache'
         yield 'Cache-Control', 'no-cache'
 
-        # Request protocols and extension, these are later checked with the
+        # Request protocols and extensions, these are later checked with the
         # actual supported values from the server's response
         if self.wsock.protocols:
             yield 'Sec-WebSocket-Protocol', ', '.join(self.wsock.protocols)
 
         if self.wsock.extensions:
-            yield 'Sec-WebSocket-Extensions', ', '.join(self.wsock.extensions)
+            values = [format_param_hdr(e.name, e.header_params())
+                      for e in self.wsock.extensions]
+            yield 'Sec-WebSocket-Extensions', ', '.join(values)
 
     def http_auth_basic_headers(self, **kwargs):
         u, p = self.wsock.auth
@@ -334,3 +346,41 @@ class ClientHandshake(Handshake):
                                 nonce=kwargs['nonce'],
                                 opaque=kwargs['opaque'],
                                 password=password.encode('utf-8'))
+
+
+def split_stripped(value, delim=','):
+    return map(str.strip, str(value).split(delim)) if value else []
+
+
+def parse_param_hdr(hdr):
+    name, paramstr = split_stripped(hdr, ';')
+    params = {}
+
+    for param in split_stripped(paramstr):
+        if '=' in param:
+            key, value = split_stripped(param, '=')
+
+            if value.isdigit():
+                value = int(value)
+        else:
+            key = param
+            value = True
+
+        params[key] = value
+
+    yield name, params
+
+
+def format_param_hdr(value, params):
+    if not params:
+        return value
+
+    def fmt_param((k, v)):
+        if v is True:
+            return k
+
+        if v is not False and v is not None:
+            return k + '=' + str(v)
+
+    strparams = filter(None, map(fmt_param, params.items()))
+    return '%s; %s' % (value, ', '.join(strparams))

+ 7 - 0
websocket.py

@@ -104,6 +104,9 @@ class websocket(object):
         Send a number of frames.
         """
         for frame in args:
+            for ext in self.extensions:
+                frame = ext.hook_send(frame)
+
             #print 'send frame:', frame, 'to %s:%d' % self.sock.getpeername()
             self.sock.sendall(frame.pack())
 
@@ -113,6 +116,10 @@ class websocket(object):
         frame.
         """
         frame = receive_frame(self.sock)
+
+        for ext in reversed(self.extensions):
+            frame = ext.hook_recv(frame)
+
         #print 'receive frame:', frame, 'from %s:%d' % self.sock.getpeername()
         return frame