Quellcode durchsuchen

Added support for Origin checking (both client- and server-side)

Taddeus Kroes vor 12 Jahren
Ursprung
Commit
a215384a15
1 geänderte Dateien mit 28 neuen und 6 gelöschten Zeilen
  1. 28 6
      websocket.py

+ 28 - 6
websocket.py

@@ -42,8 +42,8 @@ class websocket(object):
     >>> sock.connect(('', 8000))
     >>> sock.send(twspy.Frame(twspy.OPCODE_TEXT, 'Hello, Server!'))
     """
-    def __init__(self, sock=None, protocols=[], extensions=[],
-                 sfamily=socket.AF_INET, sproto=0):
+    def __init__(self, sock=None, protocols=[], extensions=[], origin=None,
+                 trusted_origins=[], sfamily=socket.AF_INET, sproto=0):
         """
         Create a regular TCP socket of family `family` and protocol
 
@@ -54,10 +54,20 @@ class websocket(object):
 
         `extensions` is a list of supported extensions.
 
+        `origin` (for client sockets) is the value for the "Origin" header sent
+        in a client handshake .
+
+        `trusted_origins` (for servere sockets) is a list of expected values
+        for the "Origin" header sent by a client. If the received Origin header
+        has value not in this list, a HandshakeError is raised. If the list is
+        empty (default), all origins are excepted.
+
         `sfamily` and `sproto` are used for the regular socket constructor.
         """
         self.protocols = protocols
         self.extensions = extensions
+        self.origin = origin
+        self.trusted_origins = trusted_origins
         self.sock = sock or socket.socket(sfamily, socket.SOCK_STREAM, sproto)
         self.secure = False
         self.handshake_started = False
@@ -184,9 +194,19 @@ class websocket(object):
         if 'upgrade' not in header('Connection').lower():
             fail('"Connection" header must contain "Upgrade"')
 
-        # Origin must be present if browser client
-        if 'User-Agent' in header_names and 'Origin' not in header_names:
-            fail('browser client must specify "Origin" header')
+        # Origin must be present if browser client, and must match the list of
+        # trusted origins
+        if 'Origin' not in header_names:
+            if 'User-Agent' in header_names:
+                fail('browser client must specify "Origin" header')
+
+            if self.trusted_origins:
+                fail('no "Origin" header specified, assuming untrusted')
+        elif self.trusted_origins:
+            origin = header('Origin')
+
+            if origin not in self.trusted_origins:
+                fail('untrusted origin "%s"' % origin)
 
         # Only supported protocols are returned
         client_protocols = split_stripped(header('Sec-WebSocket-Extensions'))
@@ -241,9 +261,11 @@ class websocket(object):
             shake += 'Upgrade: websocket\r\n'
             shake += 'Connection: keep-alive, Upgrade\r\n'
             shake += 'Sec-WebSocket-Key: %s\r\n' % key
-            shake += 'Origin: null\r\n'  # FIXME: is this correct/necessary?
             shake += 'Sec-WebSocket-Version: %s\r\n' % WS_VERSION
 
+            if self.origin:
+                shake += 'Origin: %s\r\n' % self.origin
+
             # These are for eagerly caching webservers
             shake += 'Pragma: no-cache\r\n'
             shake += 'Cache-Control: no-cache\r\n'