Commit 667d6115 authored by Taddeüs Kroes's avatar Taddeüs Kroes

Implemented HTTP redirect and authentication in client handshake

parent 8f15e283
...@@ -4,6 +4,7 @@ import socket ...@@ -4,6 +4,7 @@ import socket
import ssl import ssl
from hashlib import sha1 from hashlib import sha1
from base64 import b64encode from base64 import b64encode
from urlparse import urlparse
from frame import receive_frame from frame import receive_frame
from errors import HandshakeError, SSLError from errors import HandshakeError, SSLError
...@@ -79,7 +80,7 @@ class websocket(object): ...@@ -79,7 +80,7 @@ class websocket(object):
wsock.server_handshake() wsock.server_handshake()
return wsock, address return wsock, address
def connect(self, address, path='/'): def connect(self, address, path='/', auth=None):
""" """
Equivalent to socket.connect(), but sends an client handshake request Equivalent to socket.connect(), but sends an client handshake request
after connecting. after connecting.
...@@ -88,9 +89,12 @@ class websocket(object): ...@@ -88,9 +89,12 @@ class websocket(object):
`path` is optional, used as the *location* part of the HTTP handshake. `path` is optional, used as the *location* part of the HTTP handshake.
In a URL, this would show as ws://host[:port]/path. In a URL, this would show as ws://host[:port]/path.
`auth` is optional, used for the HTTP "Authorization" header of the
handshake request.
""" """
self.sock.connect(address) self.sock.connect(address)
self.client_handshake(address, path) self.client_handshake(address, path, auth)
def send(self, *args): def send(self, *args):
""" """
...@@ -215,7 +219,7 @@ class websocket(object): ...@@ -215,7 +219,7 @@ class websocket(object):
self.sock.sendall(shake + '\r\n') self.sock.sendall(shake + '\r\n')
self.handshake_started = True self.handshake_started = True
def client_handshake(self, address, location): def client_handshake(self, address, location, auth):
""" """
Executes a handshake as the client end point of the socket. May raise a Executes a handshake as the client end point of the socket. May raise a
HandshakeError if the server response is invalid. HandshakeError if the server response is invalid.
...@@ -224,90 +228,129 @@ class websocket(object): ...@@ -224,90 +228,129 @@ class websocket(object):
self.sock.close() self.sock.close()
raise HandshakeError(msg) raise HandshakeError(msg)
if len(location) == 0: def send_request(location):
fail('request location is empty') if len(location) == 0:
fail('request location is empty')
# Generate a 16-byte random base64-encoded key for this connection # Generate a 16-byte random base64-encoded key for this connection
key = b64encode(os.urandom(16)) key = b64encode(os.urandom(16))
# Send client handshake # Send client handshake
shake = 'GET %s HTTP/1.1\r\n' % location shake = 'GET %s HTTP/1.1\r\n' % location
shake += 'Host: %s:%d\r\n' % address shake += 'Host: %s:%d\r\n' % address
shake += 'Upgrade: websocket\r\n' shake += 'Upgrade: websocket\r\n'
shake += 'Connection: keep-alive, Upgrade\r\n' shake += 'Connection: keep-alive, Upgrade\r\n'
shake += 'Sec-WebSocket-Key: %s\r\n' % key shake += 'Sec-WebSocket-Key: %s\r\n' % key
shake += 'Origin: null\r\n' # FIXME: is this correct/necessary? shake += 'Origin: null\r\n' # FIXME: is this correct/necessary?
shake += 'Sec-WebSocket-Version: %s\r\n' % WS_VERSION shake += 'Sec-WebSocket-Version: %s\r\n' % WS_VERSION
# These are for eagerly caching webservers # These are for eagerly caching webservers
shake += 'Pragma: no-cache\r\n' shake += 'Pragma: no-cache\r\n'
shake += 'Cache-Control: no-cache\r\n' shake += 'Cache-Control: no-cache\r\n'
# Request protocols and extension, these are later checked with the # Request protocols and extension, these are later checked with the
# actual supported values from the server's response # actual supported values from the server's response
if self.protocols: if self.protocols:
shake += 'Sec-WebSocket-Protocol: %s\r\n' % ', '.join(self.protocols) shake += 'Sec-WebSocket-Protocol: %s\r\n' % ', '.join(self.protocols)
if self.extensions: if self.extensions:
shake += 'Sec-WebSocket-Extensions: %s\r\n' % ', '.join(self.extensions) shake += 'Sec-WebSocket-Extensions: %s\r\n' % ', '.join(self.extensions)
self.sock.sendall(shake + '\r\n') if auth:
self.handshake_started = True shake += 'Authorization: %s\r\n' % auth
# Receive and process server handshake self.sock.sendall(shake + '\r\n')
raw_headers = '' return key
while raw_headers[-4:] not in ('\r\n\r\n', '\n\n'): def receive_response(key):
raw_headers += self.sock.recv(512).decode('utf-8', 'ignore') # Receive and process server handshake
raw_headers = ''
# Response must be HTTP (at least 1.1) with status 101 while raw_headers[-4:] not in ('\r\n\r\n', '\n\n'):
if not raw_headers.startswith('HTTP/1.1 101'): raw_headers += self.sock.recv(512).decode('utf-8', 'ignore')
# TODO: implement HTTP authentication (401) and redirect (3xx)?
fail('not a valid HTTP 1.1 status 101 response')
headers = re.findall(r'(.*?): ?(.*?)\r\n', raw_headers) # Response must be HTTP (at least 1.1) with status 101
header_names = [name for name, value in headers] match = re.search(r'^HTTP/1\.1 (\d{3})', raw_headers)
def header(name): if match is None:
return ', '.join([v for n, v in headers if n == name]) fail('not a valid HTTP 1.1 response')
# Check if headers that MUST be present are actually present status = int(match.group(1))
for name in ('Upgrade', 'Connection', 'Sec-WebSocket-Accept'): headers = re.findall(r'(.*?): ?(.*?)\r\n', raw_headers)
if name not in header_names: header_names = [name for name, value in headers]
fail('missing "%s" header' % name)
if 'websocket' not in header('Upgrade').lower(): def header(name):
fail('"Upgrade" header must contain "websocket"') return ', '.join([v for n, v in headers if n == name])
if 'upgrade' not in header('Connection').lower(): if status == 401:
fail('"Connection" header must contain "Upgrade"') # HTTP authentication is required in the request
raise HandshakeError('HTTP authentication required: %s'
% header('WWW-Authenticate'))
if status in (301, 302, 303, 307, 308):
# Handle HTTP redirect
url = urlparse(header('Location').strip())
# Reconnect socket if net location changed
if not url.port:
url.port = 443 if self.secure else 80
addr = (url.netloc, url.port)
if addr != self.sock.getpeername():
self.sock.close()
self.sock.connect(addr)
# Send new handshake
receive_response(send_request(url.path))
return
# Verify accept header if status != 101:
accept = header('Sec-WebSocket-Accept').strip() # 101 means server has accepted the connection and sent
required_accept = b64encode(sha1(key + WS_GUID).digest()) # handshake headers
fail('invalid HTTP response status %d' % status)
if accept != required_accept: # Check if headers that MUST be present are actually present
fail('invalid websocket accept header "%s"' % accept) for name in ('Upgrade', 'Connection', 'Sec-WebSocket-Accept'):
if name not in header_names:
fail('missing "%s" header' % name)
# Compare extensions if 'websocket' not in header('Upgrade').lower():
server_extensions = split_stripped(header('Sec-WebSocket-Extensions')) fail('"Upgrade" header must contain "websocket"')
for ext in server_extensions: if 'upgrade' not in header('Connection').lower():
if ext not in self.extensions: fail('"Connection" header must contain "Upgrade"')
fail('server extension "%s" is unsupported by client' % ext)
for ext in self.extensions: # Verify accept header
if ext not in server_extensions: accept = header('Sec-WebSocket-Accept').strip()
fail('client extension "%s" is unsupported by server' % ext) required_accept = b64encode(sha1(key + WS_GUID).digest())
# Assert that returned protocol is supported if accept != required_accept:
protocol = header('Sec-WebSocket-Protocol') fail('invalid websocket accept header "%s"' % accept)
if protocol: # Compare extensions
if protocol != 'null' and protocol not in self.protocols: server_ext = split_stripped(header('Sec-WebSocket-Extensions'))
fail('unsupported protocol "%s"' % protocol)
for e in server_ext:
if e not in self.extensions:
fail('server extension "%s" is unsupported by client' % e)
for e in self.extensions:
if e not in server_ext:
fail('client extension "%s" is unsupported by server' % e)
# Assert that returned protocol (if any) is supported
protocol = header('Sec-WebSocket-Protocol')
if protocol:
if protocol != 'null' and protocol not in self.protocols:
fail('unsupported protocol "%s"' % protocol)
self.protocol = protocol
self.handshake_started = True
receive_response(send_request(location))
self.protocol = protocol
def enable_ssl(self, *args, **kwargs): def enable_ssl(self, *args, **kwargs):
""" """
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment