Commit 6ec710a5 authored by Taddeüs Kroes's avatar Taddeüs Kroes

Extensions are now passed to websocket.accept() properly, fixed origin/location checking

parent 8db0a06c
...@@ -27,7 +27,8 @@ class Handshake(object): ...@@ -27,7 +27,8 @@ class Handshake(object):
raw, headers = self.receive_headers() raw, headers = self.receive_headers()
# Request must be HTTP (at least 1.1) GET request, find the location # Request must be HTTP (at least 1.1) GET request, find the location
match = re.search(r'^GET (.*) HTTP/1.1\r\n', raw) # (without trailing slash)
match = re.search(r'^GET (.*?)/* HTTP/1.1\r\n', raw)
if match is None: if match is None:
self.fail('not a valid HTTP 1.1 GET request') self.fail('not a valid HTTP 1.1 GET request')
...@@ -84,14 +85,18 @@ class ServerHandshake(Handshake): ...@@ -84,14 +85,18 @@ class ServerHandshake(Handshake):
request headers sent by the client are invalid, a HandshakeError is raised. request headers sent by the client are invalid, a HandshakeError is raised.
""" """
def perform(self): def perform(self, ssock):
# Receive and validate client handshake # Receive and validate client handshake
self.wsock.location, headers = self.receive_request() location, headers = self.receive_request()
self.wsock.location = location
self.wsock.request_headers = headers
# Send server handshake in response # Send server handshake in response
self.send_headers(self.response_headers(headers)) self.send_headers(self.response_headers(ssock))
def response_headers(self, ssock):
headers = self.wsock.request_headers
def response_headers(self, headers):
# Check if headers that MUST be present are actually present # Check if headers that MUST be present are actually present
for name in ('Host', 'Upgrade', 'Connection', 'Sec-WebSocket-Key', for name in ('Host', 'Upgrade', 'Connection', 'Sec-WebSocket-Key',
'Sec-WebSocket-Version'): 'Sec-WebSocket-Version'):
...@@ -114,18 +119,15 @@ class ServerHandshake(Handshake): ...@@ -114,18 +119,15 @@ class ServerHandshake(Handshake):
# Origin must be present if browser client, and must match the list of # Origin must be present if browser client, and must match the list of
# trusted origins # trusted origins
origin = 'null' origin = headers.get('Origin', 'null')
if 'Origin' not in headers: if origin == 'null':
if 'User-Agent' in headers: if 'User-Agent' in headers:
self.fail('browser client must specify "Origin" header') self.fail('browser client must specify "Origin" header')
if self.wsock.trusted_origins: if ssock.trusted_origins:
self.fail('no "Origin" header specified, assuming untrusted') self.fail('no "Origin" header specified, assuming untrusted')
elif self.wsock.trusted_origins: elif ssock.trusted_origins and origin not in ssock.trusted_origins:
origin = headers['Origin']
if origin not in self.wsock.trusted_origins:
self.fail('untrusted origin "%s"' % origin) self.fail('untrusted origin "%s"' % origin)
# Only a supported protocol can be returned # Only a supported protocol can be returned
...@@ -134,13 +136,13 @@ class ServerHandshake(Handshake): ...@@ -134,13 +136,13 @@ class ServerHandshake(Handshake):
self.wsock.protocol = None self.wsock.protocol = None
for p in client_proto: for p in client_proto:
if p in self.wsock.protocols: if p in ssock.protocols:
self.wsock.protocol = p self.wsock.protocol = p
break break
# Only supported extensions are returned # Only supported extensions are returned
if 'Sec-WebSocket-Extensions' in headers: if 'Sec-WebSocket-Extensions' in headers:
supported_ext = dict((e.name, e) for e in self.wsock.extensions) supported_ext = dict((e.name, e) for e in ssock.extensions)
extensions = [] extensions = []
all_params = [] all_params = []
...@@ -159,6 +161,12 @@ class ServerHandshake(Handshake): ...@@ -159,6 +161,12 @@ class ServerHandshake(Handshake):
else: else:
self.wsock.extensions = [] self.wsock.extensions = []
# Check if requested resource location is served by this server
if ssock.locations:
if self.wsock.location not in ssock.locations:
raise HandshakeError('location "%s" is not supported by this '
'server' % self.wsock.location)
# Encode acceptation key using the WebSocket GUID # Encode acceptation key using the WebSocket GUID
key = headers['Sec-WebSocket-Key'].strip() key = headers['Sec-WebSocket-Key'].strip()
accept = b64encode(sha1(key + WS_GUID).digest()) accept = b64encode(sha1(key + WS_GUID).digest())
...@@ -181,8 +189,8 @@ class ServerHandshake(Handshake): ...@@ -181,8 +189,8 @@ class ServerHandshake(Handshake):
yield 'HTTP/1.1 101 Web Socket Protocol Handshake' yield 'HTTP/1.1 101 Web Socket Protocol Handshake'
yield 'Upgrade', 'websocket' yield 'Upgrade', 'websocket'
yield 'Connection', 'Upgrade' yield 'Connection', 'Upgrade'
yield 'WebSocket-Origin', origin yield 'Sec-WebSocket-Origin', origin
yield 'WebSocket-Location', location yield 'Sec-WebSocket-Location', location
yield 'Sec-WebSocket-Accept', accept yield 'Sec-WebSocket-Accept', accept
if self.wsock.protocol: if self.wsock.protocol:
......
...@@ -31,7 +31,7 @@ class websocket(object): ...@@ -31,7 +31,7 @@ class websocket(object):
>>> sock.send(twspy.Frame(twspy.OPCODE_TEXT, 'Hello, Server!')) >>> sock.send(twspy.Frame(twspy.OPCODE_TEXT, 'Hello, Server!'))
""" """
def __init__(self, sock=None, protocols=[], extensions=[], origin=None, def __init__(self, sock=None, protocols=[], extensions=[], origin=None,
trusted_origins=[], location='/', auth=None, location='/', trusted_origins=[], locations=[], auth=None,
sfamily=socket.AF_INET, sproto=0): sfamily=socket.AF_INET, sproto=0):
""" """
Create a regular TCP socket of family `family` and protocol Create a regular TCP socket of family `family` and protocol
...@@ -46,13 +46,21 @@ class websocket(object): ...@@ -46,13 +46,21 @@ class websocket(object):
`origin` (for client sockets) is the value for the "Origin" header sent `origin` (for client sockets) is the value for the "Origin" header sent
in a client handshake . in a client handshake .
`trusted_origins` (for servere sockets) is a list of expected values `location` (for client sockets) is optional, used to request a
particular resource in the HTTP handshake. In a URL, this would show as
ws://host[:port]/<location>. Use this when the server serves multiple
resources (see `locations`).
`trusted_origins` (for server sockets) is a list of expected values
for the "Origin" header sent by a client. If the received Origin header for the "Origin" header sent by a client. If the received Origin header
has value not in this list, a HandshakeError is raised. If the list is has value not in this list, a HandshakeError is raised. If the list is
empty (default), all origins are excepted. empty (default), all origins are excepted.
`location` is optional, used for the HTTP handshake. In a URL, this `locations` (for server sockets) is an optional list of resources
would show as ws://host[:port]/path. serverd by this server. If specified (without trailing slashes), these
are used to verify the resource location requested by a client. The
requested location may be used to distinquish different services in a
server implementation.
`auth` is optional, used for HTTP Basic or Digest authentication during `auth` is optional, used for HTTP Basic or Digest authentication during
the handshake. It must be specified as a (username, password) tuple. the handshake. It must be specified as a (username, password) tuple.
...@@ -62,8 +70,9 @@ class websocket(object): ...@@ -62,8 +70,9 @@ class websocket(object):
self.protocols = protocols self.protocols = protocols
self.extensions = extensions self.extensions = extensions
self.origin = origin self.origin = origin
self.trusted_origins = trusted_origins
self.location = location self.location = location
self.trusted_origins = trusted_origins
self.locations = locations
self.auth = auth self.auth = auth
self.secure = False self.secure = False
...@@ -90,7 +99,8 @@ class websocket(object): ...@@ -90,7 +99,8 @@ class websocket(object):
""" """
sock, address = self.sock.accept() sock, address = self.sock.accept()
wsock = websocket(sock) wsock = websocket(sock)
ServerHandshake(wsock).perform() wsock.secure = self.secure
ServerHandshake(wsock).perform(self)
wsock.handshake_sent = True wsock.handshake_sent = True
return wsock, address return wsock, address
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment