Commit 667d6115 authored by Taddeüs Kroes's avatar Taddeüs Kroes

Implemented HTTP redirect and authentication in client handshake

parent 8f15e283
......@@ -4,6 +4,7 @@ import socket
import ssl
from hashlib import sha1
from base64 import b64encode
from urlparse import urlparse
from frame import receive_frame
from errors import HandshakeError, SSLError
......@@ -79,7 +80,7 @@ class websocket(object):
wsock.server_handshake()
return wsock, address
def connect(self, address, path='/'):
def connect(self, address, path='/', auth=None):
"""
Equivalent to socket.connect(), but sends an client handshake request
after connecting.
......@@ -88,9 +89,12 @@ class websocket(object):
`path` is optional, used as the *location* part of the HTTP handshake.
In a URL, this would show as ws://host[:port]/path.
`auth` is optional, used for the HTTP "Authorization" header of the
handshake request.
"""
self.sock.connect(address)
self.client_handshake(address, path)
self.client_handshake(address, path, auth)
def send(self, *args):
"""
......@@ -215,7 +219,7 @@ class websocket(object):
self.sock.sendall(shake + '\r\n')
self.handshake_started = True
def client_handshake(self, address, location):
def client_handshake(self, address, location, auth):
"""
Executes a handshake as the client end point of the socket. May raise a
HandshakeError if the server response is invalid.
......@@ -224,90 +228,129 @@ class websocket(object):
self.sock.close()
raise HandshakeError(msg)
if len(location) == 0:
fail('request location is empty')
def send_request(location):
if len(location) == 0:
fail('request location is empty')
# Generate a 16-byte random base64-encoded key for this connection
key = b64encode(os.urandom(16))
# Generate a 16-byte random base64-encoded key for this connection
key = b64encode(os.urandom(16))
# Send client handshake
shake = 'GET %s HTTP/1.1\r\n' % location
shake += 'Host: %s:%d\r\n' % address
shake += 'Upgrade: websocket\r\n'
shake += 'Connection: keep-alive, Upgrade\r\n'
shake += 'Sec-WebSocket-Key: %s\r\n' % key
shake += 'Origin: null\r\n' # FIXME: is this correct/necessary?
shake += 'Sec-WebSocket-Version: %s\r\n' % WS_VERSION
# Send client handshake
shake = 'GET %s HTTP/1.1\r\n' % location
shake += 'Host: %s:%d\r\n' % address
shake += 'Upgrade: websocket\r\n'
shake += 'Connection: keep-alive, Upgrade\r\n'
shake += 'Sec-WebSocket-Key: %s\r\n' % key
shake += 'Origin: null\r\n' # FIXME: is this correct/necessary?
shake += 'Sec-WebSocket-Version: %s\r\n' % WS_VERSION
# These are for eagerly caching webservers
shake += 'Pragma: no-cache\r\n'
shake += 'Cache-Control: no-cache\r\n'
# These are for eagerly caching webservers
shake += 'Pragma: no-cache\r\n'
shake += 'Cache-Control: no-cache\r\n'
# Request protocols and extension, these are later checked with the
# actual supported values from the server's response
if self.protocols:
shake += 'Sec-WebSocket-Protocol: %s\r\n' % ', '.join(self.protocols)
# Request protocols and extension, these are later checked with the
# actual supported values from the server's response
if self.protocols:
shake += 'Sec-WebSocket-Protocol: %s\r\n' % ', '.join(self.protocols)
if self.extensions:
shake += 'Sec-WebSocket-Extensions: %s\r\n' % ', '.join(self.extensions)
if self.extensions:
shake += 'Sec-WebSocket-Extensions: %s\r\n' % ', '.join(self.extensions)
self.sock.sendall(shake + '\r\n')
self.handshake_started = True
if auth:
shake += 'Authorization: %s\r\n' % auth
# Receive and process server handshake
raw_headers = ''
self.sock.sendall(shake + '\r\n')
return key
while raw_headers[-4:] not in ('\r\n\r\n', '\n\n'):
raw_headers += self.sock.recv(512).decode('utf-8', 'ignore')
def receive_response(key):
# Receive and process server handshake
raw_headers = ''
# Response must be HTTP (at least 1.1) with status 101
if not raw_headers.startswith('HTTP/1.1 101'):
# TODO: implement HTTP authentication (401) and redirect (3xx)?
fail('not a valid HTTP 1.1 status 101 response')
while raw_headers[-4:] not in ('\r\n\r\n', '\n\n'):
raw_headers += self.sock.recv(512).decode('utf-8', 'ignore')
headers = re.findall(r'(.*?): ?(.*?)\r\n', raw_headers)
header_names = [name for name, value in headers]
# Response must be HTTP (at least 1.1) with status 101
match = re.search(r'^HTTP/1\.1 (\d{3})', raw_headers)
def header(name):
return ', '.join([v for n, v in headers if n == name])
if match is None:
fail('not a valid HTTP 1.1 response')
# Check if headers that MUST be present are actually present
for name in ('Upgrade', 'Connection', 'Sec-WebSocket-Accept'):
if name not in header_names:
fail('missing "%s" header' % name)
status = int(match.group(1))
headers = re.findall(r'(.*?): ?(.*?)\r\n', raw_headers)
header_names = [name for name, value in headers]
if 'websocket' not in header('Upgrade').lower():
fail('"Upgrade" header must contain "websocket"')
def header(name):
return ', '.join([v for n, v in headers if n == name])
if 'upgrade' not in header('Connection').lower():
fail('"Connection" header must contain "Upgrade"')
if status == 401:
# HTTP authentication is required in the request
raise HandshakeError('HTTP authentication required: %s'
% header('WWW-Authenticate'))
if status in (301, 302, 303, 307, 308):
# Handle HTTP redirect
url = urlparse(header('Location').strip())
# Reconnect socket if net location changed
if not url.port:
url.port = 443 if self.secure else 80
addr = (url.netloc, url.port)
if addr != self.sock.getpeername():
self.sock.close()
self.sock.connect(addr)
# Send new handshake
receive_response(send_request(url.path))
return
# Verify accept header
accept = header('Sec-WebSocket-Accept').strip()
required_accept = b64encode(sha1(key + WS_GUID).digest())
if status != 101:
# 101 means server has accepted the connection and sent
# handshake headers
fail('invalid HTTP response status %d' % status)
if accept != required_accept:
fail('invalid websocket accept header "%s"' % accept)
# Check if headers that MUST be present are actually present
for name in ('Upgrade', 'Connection', 'Sec-WebSocket-Accept'):
if name not in header_names:
fail('missing "%s" header' % name)
# Compare extensions
server_extensions = split_stripped(header('Sec-WebSocket-Extensions'))
if 'websocket' not in header('Upgrade').lower():
fail('"Upgrade" header must contain "websocket"')
for ext in server_extensions:
if ext not in self.extensions:
fail('server extension "%s" is unsupported by client' % ext)
if 'upgrade' not in header('Connection').lower():
fail('"Connection" header must contain "Upgrade"')
for ext in self.extensions:
if ext not in server_extensions:
fail('client extension "%s" is unsupported by server' % ext)
# Verify accept header
accept = header('Sec-WebSocket-Accept').strip()
required_accept = b64encode(sha1(key + WS_GUID).digest())
# Assert that returned protocol is supported
protocol = header('Sec-WebSocket-Protocol')
if accept != required_accept:
fail('invalid websocket accept header "%s"' % accept)
if protocol:
if protocol != 'null' and protocol not in self.protocols:
fail('unsupported protocol "%s"' % protocol)
# Compare extensions
server_ext = split_stripped(header('Sec-WebSocket-Extensions'))
for e in server_ext:
if e not in self.extensions:
fail('server extension "%s" is unsupported by client' % e)
for e in self.extensions:
if e not in server_ext:
fail('client extension "%s" is unsupported by server' % e)
# Assert that returned protocol (if any) is supported
protocol = header('Sec-WebSocket-Protocol')
if protocol:
if protocol != 'null' and protocol not in self.protocols:
fail('unsupported protocol "%s"' % protocol)
self.protocol = protocol
self.handshake_started = True
receive_response(send_request(location))
self.protocol = protocol
def enable_ssl(self, *args, **kwargs):
"""
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment