فهرست منبع

Revised extension instantiation, now 'hooks' are installed which are cleaner and more flexible

Taddeus Kroes 12 سال پیش
والد
کامیت
e465862ff0
3فایلهای تغییر یافته به همراه84 افزوده شده و 80 حذف شده
  1. 56 70
      extension.py
  2. 14 5
      handshake.py
  3. 14 5
      websocket.py

+ 56 - 70
extension.py

@@ -1,44 +1,38 @@
-from errors import HandshakeError
-
-
 class Extension(object):
     name = ''
     rsv1 = False
     rsv2 = False
     rsv3 = False
     opcodes = []
-    parameters = []
-
-    def __init__(self, **kwargs):
-        for param in self.parameters:
-            setattr(self, param, None)
+    defaults = {}
+    request = {}
 
-        for param, value in kwargs.items():
-            if param not in self.parameters:
-                raise HandshakeError('unrecognized parameter "%s"' % param)
+    def __init__(self, defaults={}, request={}):
+        for param in defaults.keys() + request.keys():
+            if param not in self.defaults:
+                raise KeyError('unrecognized parameter "%s"' % param)
 
-            if value is None:
-                value = True
+        # Copy dict first to avoid duplicate references to the same object
+        self.defaults = dict(self.__class__.defaults)
+        self.defaults.update(defaults)
 
-            setattr(self, param, value)
+        self.request = dict(self.__class__.request)
+        self.request.update(request)
 
     def __str__(self, frame):
-        if len(self.parameters):
-            params = ' ' + ', '.join(p + '=' + str(getattr(self, p))
-                                     for p in self.parameters)
-        else:
-            params = ''
+        return '<Extension "%s" defaults=%s request=%s>' \
+               % (self.name, self.defaults, self.request)
 
-        return '<Extension "%s"%s>' % (self.name, params)
+    class Hook:
+        def __init__(self, **kwargs):
+            for param, value in kwargs.iteritems():
+                setattr(self, param, value)
 
-    def header_params(self, frame):
-        return {}
+        def send(self, frame):
+            return frame
 
-    def hook_send(self, frame):
-        return frame
-
-    def hook_receive(self, frame):
-        return frame
+        def recv(self, frame):
+            return frame
 
 
 class DeflateFrame(Extension):
@@ -57,49 +51,43 @@ class DeflateFrame(Extension):
 
     name = 'deflate-frame'
     rsv1 = True
-    parameters = ['max_window_bits', 'no_context_takeover']
-
-    # FIXME: is this correct?
-    default_max_window_bits = 32768
+    # FIXME: is 32768 (below) correct?
+    defaults = {'max_window_bits': 32768, 'no_context_takeover': True}
 
-    def __init__(self, **kwargs):
-        super(DeflateFrame, self).__init__(**kwargs)
+    def __init__(self, defaults={}, request={}):
+        Extension.__init__(self, defaults, request)
 
-        if self.max_window_bits is None:
-            self.max_window_bits = self.default_max_window_bits
-        elif not isinstance(self.max_window_bits, int):
-            raise HandshakeError('"max_window_bits" must be an integer')
-        elif self.max_window_bits > 32768:
-            raise HandshakeError('"max_window_bits" may not be larger than '
-                                 '32768')
+        mwb = self.defaults['max_window_bits']
+        cto = self.defaults['no_context_takeover']
 
-        if self.no_context_takeover is None:
-            self.no_context_takeover = False
-        elif self.no_context_takeover is not True:
-            raise HandshakeError('"no_context_takeover" must have no value')
+        if not isinstance(mwb, int):
+            raise ValueError('"max_window_bits" must be an integer')
+        elif mwb > 32768:
+            raise ValueError('"max_window_bits" may not be larger than 32768')
 
-    def hook_send(self, frame):
-        if not frame.rsv1:
-            frame.rsv1 = True
-            frame.payload = self.deflate(frame.payload)
+        if cto is not False and cto is not True:
+            raise ValueError('"no_context_takeover" must have no value')
 
-        return frame
+    class Hook:
+        def send(self, frame):
+            if not frame.rsv1:
+                frame.rsv1 = True
+                frame.payload = self.deflate(frame.payload)
 
-    def hook_recv(self, frame):
-        if frame.rsv1:
-            frame.rsv1 = False
-            frame.payload = self.inflate(frame.payload)
+            return frame
 
-        return frame
+        def recv(self, frame):
+            if frame.rsv1:
+                frame.rsv1 = False
+                frame.payload = self.inflate(frame.payload)
 
-    def header_params(self):
-        raise NotImplementedError  # TODO
+            return frame
 
-    def deflate(self, data):
-        raise NotImplementedError  # TODO
+        def deflate(self, data):
+            raise NotImplementedError  # TODO
 
-    def inflate(self, data):
-        raise NotImplementedError  # TODO
+        def inflate(self, data):
+            raise NotImplementedError  # TODO
 
 
 class Multiplex(Extension):
@@ -115,21 +103,19 @@ class Multiplex(Extension):
     rsv1 = True  # FIXME
     rsv2 = True  # FIXME
     rsv3 = True  # FIXME
-    parameters = ['quota']
+    defaults = {'quota': None}
 
-    def __init__(self, **kwargs):
-        super(Multiplex, self).__init__(**kwargs)
+    def __init__(self, defaults={}, request={}):
+        Extension.__init__(self, defaults, request)
 
         # TODO: check "quota" value
 
-    def hook_send(self, frame):
-        raise NotImplementedError  # TODO
-
-    def hook_recv(self, frame):
-        raise NotImplementedError  # TODO
+    class Hook:
+        def send(self, frame):
+            raise NotImplementedError  # TODO
 
-    def header_params(self):
-        raise NotImplementedError  # TODO
+        def recv(self, frame):
+            raise NotImplementedError  # TODO
 
 
 def filter_extensions(extensions):

+ 14 - 5
handshake.py

@@ -142,14 +142,20 @@ class ServerHandshake(Handshake):
         if 'Sec-WebSocket-Extensions' in headers:
             supported_ext = dict((e.name, e) for e in self.wsock.extensions)
             extensions = []
+            all_params = []
 
             for ext in split_stripped(headers['Sec-WebSocket-Extensions']):
                 name, params = parse_param_hdr(ext)
 
                 if name in supported_ext:
-                    extensions.append(supported_ext[name](**params))
+                    extensions.append(supported_ext[name])
+                    all_params.append(params)
 
             self.wsock.extensions = filter_extensions(extensions)
+
+            for ext, params in zip(self.wsock.extensions, all_params):
+                hook = ext.Hook(**params)
+                self.wsock.add_hook(send=hook.send, recv=hook.recv)
         else:
             self.wsock.extensions = []
 
@@ -183,10 +189,11 @@ class ServerHandshake(Handshake):
             yield 'Sec-WebSocket-Protocol', self.wsock.protocol
 
         if self.wsock.extensions:
-            values = [format_param_hdr(e.name, e.header_params())
+            values = [format_param_hdr(e.name, e.request)
                       for e in self.wsock.extensions]
             yield 'Sec-WebSocket-Extensions', ', '.join(values)
 
+
 class ClientHandshake(Handshake):
     """
     Executes a handshake as the client end point of the socket. May raise a
@@ -230,7 +237,7 @@ class ClientHandshake(Handshake):
         if accept != required_accept:
             self.fail('invalid websocket accept header "%s"' % accept)
 
-        # Compare extensions
+        # Compare extensions, add hooks only for those returned by server
         if 'Sec-WebSocket-Extensions' in headers:
             supported_ext = dict((e.name, e) for e in self.wsock.extensions)
             self.wsock.extensions = []
@@ -242,7 +249,9 @@ class ClientHandshake(Handshake):
                     raise HandshakeError('server handshake contains '
                                          'unsupported extension "%s"' % name)
 
-                self.wsock.extensions.append(supported_ext[name](**params))
+                hook = supported_ext[name].Hook(**params)
+                self.wsock.extensions.append(supported_ext[name])
+                self.wsock.add_hook(send=hook.send, recv=hook.recv)
 
         # Assert that returned protocol (if any) is supported
         if 'Sec-WebSocket-Protocol' in headers:
@@ -325,7 +334,7 @@ class ClientHandshake(Handshake):
             yield 'Sec-WebSocket-Protocol', ', '.join(self.wsock.protocols)
 
         if self.wsock.extensions:
-            values = [format_param_hdr(e.name, e.header_params())
+            values = [format_param_hdr(e.name, e.request)
                       for e in self.wsock.extensions]
             yield 'Sec-WebSocket-Extensions', ', '.join(values)
 

+ 14 - 5
websocket.py

@@ -41,7 +41,7 @@ class websocket(object):
 
         `protocols` is a list of supported protocol names.
 
-        `extensions` is a list of supported extension classes.
+        `extensions` is a list of supported extensions (`Extension` instances).
 
         `origin` (for client sockets) is the value for the "Origin" header sent
         in a client handshake .
@@ -68,6 +68,8 @@ class websocket(object):
         self.sock = sock or socket.socket(sfamily, socket.SOCK_STREAM, sproto)
         self.secure = False
         self.handshake_sent = False
+        self.hooks_send = []
+        self.hooks_recv = []
 
     def bind(self, address):
         self.sock.bind(address)
@@ -104,8 +106,8 @@ class websocket(object):
         Send a number of frames.
         """
         for frame in args:
-            for ext in self.extensions:
-                frame = ext.hook_send(frame)
+            for hook in self.hooks_send:
+                frame = hook(frame)
 
             #print 'send frame:', frame, 'to %s:%d' % self.sock.getpeername()
             self.sock.sendall(frame.pack())
@@ -117,8 +119,8 @@ class websocket(object):
         """
         frame = receive_frame(self.sock)
 
-        for ext in reversed(self.extensions):
-            frame = ext.hook_recv(frame)
+        for hook in self.hooks_recv:
+            frame = hook(frame)
 
         #print 'receive frame:', frame, 'from %s:%d' % self.sock.getpeername()
         return frame
@@ -156,3 +158,10 @@ class websocket(object):
 
         self.secure = True
         self.sock = ssl.wrap_socket(self.sock, *args, **kwargs)
+
+    def add_hook(self, send=None, recv=None):
+        if send:
+            self.hooks_send.append(send)
+
+        if recv:
+            self.hooks_recv.prepend(recv)