Commit 6ec710a5 authored by Taddeüs Kroes's avatar Taddeüs Kroes

Extensions are now passed to websocket.accept() properly, fixed origin/location checking

parent 8db0a06c
......@@ -27,7 +27,8 @@ class Handshake(object):
raw, headers = self.receive_headers()
# Request must be HTTP (at least 1.1) GET request, find the location
match = re.search(r'^GET (.*) HTTP/1.1\r\n', raw)
# (without trailing slash)
match = re.search(r'^GET (.*?)/* HTTP/1.1\r\n', raw)
if match is None:
self.fail('not a valid HTTP 1.1 GET request')
......@@ -84,14 +85,18 @@ class ServerHandshake(Handshake):
request headers sent by the client are invalid, a HandshakeError is raised.
"""
def perform(self):
def perform(self, ssock):
# Receive and validate client handshake
self.wsock.location, headers = self.receive_request()
location, headers = self.receive_request()
self.wsock.location = location
self.wsock.request_headers = headers
# Send server handshake in response
self.send_headers(self.response_headers(headers))
self.send_headers(self.response_headers(ssock))
def response_headers(self, ssock):
headers = self.wsock.request_headers
def response_headers(self, headers):
# Check if headers that MUST be present are actually present
for name in ('Host', 'Upgrade', 'Connection', 'Sec-WebSocket-Key',
'Sec-WebSocket-Version'):
......@@ -114,18 +119,15 @@ class ServerHandshake(Handshake):
# Origin must be present if browser client, and must match the list of
# trusted origins
origin = 'null'
origin = headers.get('Origin', 'null')
if 'Origin' not in headers:
if origin == 'null':
if 'User-Agent' in headers:
self.fail('browser client must specify "Origin" header')
if self.wsock.trusted_origins:
if ssock.trusted_origins:
self.fail('no "Origin" header specified, assuming untrusted')
elif self.wsock.trusted_origins:
origin = headers['Origin']
if origin not in self.wsock.trusted_origins:
elif ssock.trusted_origins and origin not in ssock.trusted_origins:
self.fail('untrusted origin "%s"' % origin)
# Only a supported protocol can be returned
......@@ -134,13 +136,13 @@ class ServerHandshake(Handshake):
self.wsock.protocol = None
for p in client_proto:
if p in self.wsock.protocols:
if p in ssock.protocols:
self.wsock.protocol = p
break
# Only supported extensions are returned
if 'Sec-WebSocket-Extensions' in headers:
supported_ext = dict((e.name, e) for e in self.wsock.extensions)
supported_ext = dict((e.name, e) for e in ssock.extensions)
extensions = []
all_params = []
......@@ -159,6 +161,12 @@ class ServerHandshake(Handshake):
else:
self.wsock.extensions = []
# Check if requested resource location is served by this server
if ssock.locations:
if self.wsock.location not in ssock.locations:
raise HandshakeError('location "%s" is not supported by this '
'server' % self.wsock.location)
# Encode acceptation key using the WebSocket GUID
key = headers['Sec-WebSocket-Key'].strip()
accept = b64encode(sha1(key + WS_GUID).digest())
......@@ -181,8 +189,8 @@ class ServerHandshake(Handshake):
yield 'HTTP/1.1 101 Web Socket Protocol Handshake'
yield 'Upgrade', 'websocket'
yield 'Connection', 'Upgrade'
yield 'WebSocket-Origin', origin
yield 'WebSocket-Location', location
yield 'Sec-WebSocket-Origin', origin
yield 'Sec-WebSocket-Location', location
yield 'Sec-WebSocket-Accept', accept
if self.wsock.protocol:
......
......@@ -31,7 +31,7 @@ class websocket(object):
>>> sock.send(twspy.Frame(twspy.OPCODE_TEXT, 'Hello, Server!'))
"""
def __init__(self, sock=None, protocols=[], extensions=[], origin=None,
trusted_origins=[], location='/', auth=None,
location='/', trusted_origins=[], locations=[], auth=None,
sfamily=socket.AF_INET, sproto=0):
"""
Create a regular TCP socket of family `family` and protocol
......@@ -46,13 +46,21 @@ class websocket(object):
`origin` (for client sockets) is the value for the "Origin" header sent
in a client handshake .
`trusted_origins` (for servere sockets) is a list of expected values
`location` (for client sockets) is optional, used to request a
particular resource in the HTTP handshake. In a URL, this would show as
ws://host[:port]/<location>. Use this when the server serves multiple
resources (see `locations`).
`trusted_origins` (for server sockets) is a list of expected values
for the "Origin" header sent by a client. If the received Origin header
has value not in this list, a HandshakeError is raised. If the list is
empty (default), all origins are excepted.
`location` is optional, used for the HTTP handshake. In a URL, this
would show as ws://host[:port]/path.
`locations` (for server sockets) is an optional list of resources
serverd by this server. If specified (without trailing slashes), these
are used to verify the resource location requested by a client. The
requested location may be used to distinquish different services in a
server implementation.
`auth` is optional, used for HTTP Basic or Digest authentication during
the handshake. It must be specified as a (username, password) tuple.
......@@ -62,8 +70,9 @@ class websocket(object):
self.protocols = protocols
self.extensions = extensions
self.origin = origin
self.trusted_origins = trusted_origins
self.location = location
self.trusted_origins = trusted_origins
self.locations = locations
self.auth = auth
self.secure = False
......@@ -90,7 +99,8 @@ class websocket(object):
"""
sock, address = self.sock.accept()
wsock = websocket(sock)
ServerHandshake(wsock).perform()
wsock.secure = self.secure
ServerHandshake(wsock).perform(self)
wsock.handshake_sent = True
return wsock, address
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment