瀏覽代碼

Extensions are now passed to websocket.accept() properly, fixed origin/location checking

Taddeus Kroes 12 年之前
父節點
當前提交
6ec710a518
共有 2 個文件被更改,包括 41 次插入23 次删除
  1. 25 17
      handshake.py
  2. 16 6
      websocket.py

+ 25 - 17
handshake.py

@@ -27,7 +27,8 @@ class Handshake(object):
         raw, headers = self.receive_headers()
 
         # Request must be HTTP (at least 1.1) GET request, find the location
-        match = re.search(r'^GET (.*) HTTP/1.1\r\n', raw)
+        # (without trailing slash)
+        match = re.search(r'^GET (.*?)/* HTTP/1.1\r\n', raw)
 
         if match is None:
             self.fail('not a valid HTTP 1.1 GET request')
@@ -84,14 +85,18 @@ class ServerHandshake(Handshake):
     request headers sent by the client are invalid, a HandshakeError is raised.
     """
 
-    def perform(self):
+    def perform(self, ssock):
         # Receive and validate client handshake
-        self.wsock.location, headers = self.receive_request()
+        location, headers = self.receive_request()
+        self.wsock.location = location
+        self.wsock.request_headers = headers
 
         # Send server handshake in response
-        self.send_headers(self.response_headers(headers))
+        self.send_headers(self.response_headers(ssock))
+
+    def response_headers(self, ssock):
+        headers = self.wsock.request_headers
 
-    def response_headers(self, headers):
         # Check if headers that MUST be present are actually present
         for name in ('Host', 'Upgrade', 'Connection', 'Sec-WebSocket-Key',
                      'Sec-WebSocket-Version'):
@@ -114,19 +119,16 @@ class ServerHandshake(Handshake):
 
         # Origin must be present if browser client, and must match the list of
         # trusted origins
-        origin = 'null'
+        origin = headers.get('Origin', 'null')
 
-        if 'Origin' not in headers:
+        if origin == 'null':
             if 'User-Agent' in headers:
                 self.fail('browser client must specify "Origin" header')
 
-            if self.wsock.trusted_origins:
+            if ssock.trusted_origins:
                 self.fail('no "Origin" header specified, assuming untrusted')
-        elif self.wsock.trusted_origins:
-            origin = headers['Origin']
-
-            if origin not in self.wsock.trusted_origins:
-                self.fail('untrusted origin "%s"' % origin)
+        elif ssock.trusted_origins and origin not in ssock.trusted_origins:
+            self.fail('untrusted origin "%s"' % origin)
 
         # Only a supported protocol can be returned
         client_proto = split_stripped(headers['Sec-WebSocket-Protocol']) \
@@ -134,13 +136,13 @@ class ServerHandshake(Handshake):
         self.wsock.protocol = None
 
         for p in client_proto:
-            if p in self.wsock.protocols:
+            if p in ssock.protocols:
                 self.wsock.protocol = p
                 break
 
         # Only supported extensions are returned
         if 'Sec-WebSocket-Extensions' in headers:
-            supported_ext = dict((e.name, e) for e in self.wsock.extensions)
+            supported_ext = dict((e.name, e) for e in ssock.extensions)
             extensions = []
             all_params = []
 
@@ -159,6 +161,12 @@ class ServerHandshake(Handshake):
         else:
             self.wsock.extensions = []
 
+        # Check if requested resource location is served by this server
+        if ssock.locations:
+            if self.wsock.location not in ssock.locations:
+                raise HandshakeError('location "%s" is not supported by this '
+                                     'server' % self.wsock.location)
+
         # Encode acceptation key using the WebSocket GUID
         key = headers['Sec-WebSocket-Key'].strip()
         accept = b64encode(sha1(key + WS_GUID).digest())
@@ -181,8 +189,8 @@ class ServerHandshake(Handshake):
         yield 'HTTP/1.1 101 Web Socket Protocol Handshake'
         yield 'Upgrade', 'websocket'
         yield 'Connection', 'Upgrade'
-        yield 'WebSocket-Origin', origin
-        yield 'WebSocket-Location', location
+        yield 'Sec-WebSocket-Origin', origin
+        yield 'Sec-WebSocket-Location', location
         yield 'Sec-WebSocket-Accept', accept
 
         if self.wsock.protocol:

+ 16 - 6
websocket.py

@@ -31,7 +31,7 @@ class websocket(object):
     >>> sock.send(twspy.Frame(twspy.OPCODE_TEXT, 'Hello, Server!'))
     """
     def __init__(self, sock=None, protocols=[], extensions=[], origin=None,
-                 trusted_origins=[], location='/', auth=None,
+                 location='/', trusted_origins=[], locations=[], auth=None,
                  sfamily=socket.AF_INET, sproto=0):
         """
         Create a regular TCP socket of family `family` and protocol
@@ -46,13 +46,21 @@ class websocket(object):
         `origin` (for client sockets) is the value for the "Origin" header sent
         in a client handshake .
 
-        `trusted_origins` (for servere sockets) is a list of expected values
+        `location` (for client sockets) is optional, used to request a
+        particular resource in the HTTP handshake. In a URL, this would show as
+        ws://host[:port]/<location>. Use this when the server serves multiple
+        resources (see `locations`).
+
+        `trusted_origins` (for server sockets) is a list of expected values
         for the "Origin" header sent by a client. If the received Origin header
         has value not in this list, a HandshakeError is raised. If the list is
         empty (default), all origins are excepted.
 
-        `location` is optional, used for the HTTP handshake. In a URL, this
-        would show as ws://host[:port]/path.
+        `locations` (for server sockets) is an optional list of resources
+        serverd by this server. If specified (without trailing slashes), these
+        are used to verify the resource location requested by a client. The
+        requested location may be used to distinquish different services in a
+        server implementation.
 
         `auth` is optional, used for HTTP Basic or Digest authentication during
         the handshake. It must be specified as a (username, password) tuple.
@@ -62,8 +70,9 @@ class websocket(object):
         self.protocols = protocols
         self.extensions = extensions
         self.origin = origin
-        self.trusted_origins = trusted_origins
         self.location = location
+        self.trusted_origins = trusted_origins
+        self.locations = locations
         self.auth = auth
 
         self.secure = False
@@ -90,7 +99,8 @@ class websocket(object):
         """
         sock, address = self.sock.accept()
         wsock = websocket(sock)
-        ServerHandshake(wsock).perform()
+        wsock.secure = self.secure
+        ServerHandshake(wsock).perform(self)
         wsock.handshake_sent = True
         return wsock, address