Commit a215384a authored by Taddeüs Kroes's avatar Taddeüs Kroes

Added support for Origin checking (both client- and server-side)

parent b5619042
...@@ -42,8 +42,8 @@ class websocket(object): ...@@ -42,8 +42,8 @@ class websocket(object):
>>> sock.connect(('', 8000)) >>> sock.connect(('', 8000))
>>> sock.send(twspy.Frame(twspy.OPCODE_TEXT, 'Hello, Server!')) >>> sock.send(twspy.Frame(twspy.OPCODE_TEXT, 'Hello, Server!'))
""" """
def __init__(self, sock=None, protocols=[], extensions=[], def __init__(self, sock=None, protocols=[], extensions=[], origin=None,
sfamily=socket.AF_INET, sproto=0): trusted_origins=[], sfamily=socket.AF_INET, sproto=0):
""" """
Create a regular TCP socket of family `family` and protocol Create a regular TCP socket of family `family` and protocol
...@@ -54,10 +54,20 @@ class websocket(object): ...@@ -54,10 +54,20 @@ class websocket(object):
`extensions` is a list of supported extensions. `extensions` is a list of supported extensions.
`origin` (for client sockets) is the value for the "Origin" header sent
in a client handshake .
`trusted_origins` (for servere sockets) is a list of expected values
for the "Origin" header sent by a client. If the received Origin header
has value not in this list, a HandshakeError is raised. If the list is
empty (default), all origins are excepted.
`sfamily` and `sproto` are used for the regular socket constructor. `sfamily` and `sproto` are used for the regular socket constructor.
""" """
self.protocols = protocols self.protocols = protocols
self.extensions = extensions self.extensions = extensions
self.origin = origin
self.trusted_origins = trusted_origins
self.sock = sock or socket.socket(sfamily, socket.SOCK_STREAM, sproto) self.sock = sock or socket.socket(sfamily, socket.SOCK_STREAM, sproto)
self.secure = False self.secure = False
self.handshake_started = False self.handshake_started = False
...@@ -184,9 +194,19 @@ class websocket(object): ...@@ -184,9 +194,19 @@ class websocket(object):
if 'upgrade' not in header('Connection').lower(): if 'upgrade' not in header('Connection').lower():
fail('"Connection" header must contain "Upgrade"') fail('"Connection" header must contain "Upgrade"')
# Origin must be present if browser client # Origin must be present if browser client, and must match the list of
if 'User-Agent' in header_names and 'Origin' not in header_names: # trusted origins
fail('browser client must specify "Origin" header') if 'Origin' not in header_names:
if 'User-Agent' in header_names:
fail('browser client must specify "Origin" header')
if self.trusted_origins:
fail('no "Origin" header specified, assuming untrusted')
elif self.trusted_origins:
origin = header('Origin')
if origin not in self.trusted_origins:
fail('untrusted origin "%s"' % origin)
# Only supported protocols are returned # Only supported protocols are returned
client_protocols = split_stripped(header('Sec-WebSocket-Extensions')) client_protocols = split_stripped(header('Sec-WebSocket-Extensions'))
...@@ -241,9 +261,11 @@ class websocket(object): ...@@ -241,9 +261,11 @@ class websocket(object):
shake += 'Upgrade: websocket\r\n' shake += 'Upgrade: websocket\r\n'
shake += 'Connection: keep-alive, Upgrade\r\n' shake += 'Connection: keep-alive, Upgrade\r\n'
shake += 'Sec-WebSocket-Key: %s\r\n' % key shake += 'Sec-WebSocket-Key: %s\r\n' % key
shake += 'Origin: null\r\n' # FIXME: is this correct/necessary?
shake += 'Sec-WebSocket-Version: %s\r\n' % WS_VERSION shake += 'Sec-WebSocket-Version: %s\r\n' % WS_VERSION
if self.origin:
shake += 'Origin: %s\r\n' % self.origin
# These are for eagerly caching webservers # These are for eagerly caching webservers
shake += 'Pragma: no-cache\r\n' shake += 'Pragma: no-cache\r\n'
shake += 'Cache-Control: no-cache\r\n' shake += 'Cache-Control: no-cache\r\n'
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment