Commit 667d6115 authored by Taddeüs Kroes's avatar Taddeüs Kroes

Implemented HTTP redirect and authentication in client handshake

parent 8f15e283
......@@ -4,6 +4,7 @@ import socket
import ssl
from hashlib import sha1
from base64 import b64encode
from urlparse import urlparse
from frame import receive_frame
from errors import HandshakeError, SSLError
......@@ -79,7 +80,7 @@ class websocket(object):
wsock.server_handshake()
return wsock, address
def connect(self, address, path='/'):
def connect(self, address, path='/', auth=None):
"""
Equivalent to socket.connect(), but sends an client handshake request
after connecting.
......@@ -88,9 +89,12 @@ class websocket(object):
`path` is optional, used as the *location* part of the HTTP handshake.
In a URL, this would show as ws://host[:port]/path.
`auth` is optional, used for the HTTP "Authorization" header of the
handshake request.
"""
self.sock.connect(address)
self.client_handshake(address, path)
self.client_handshake(address, path, auth)
def send(self, *args):
"""
......@@ -215,7 +219,7 @@ class websocket(object):
self.sock.sendall(shake + '\r\n')
self.handshake_started = True
def client_handshake(self, address, location):
def client_handshake(self, address, location, auth):
"""
Executes a handshake as the client end point of the socket. May raise a
HandshakeError if the server response is invalid.
......@@ -224,6 +228,7 @@ class websocket(object):
self.sock.close()
raise HandshakeError(msg)
def send_request(location):
if len(location) == 0:
fail('request location is empty')
......@@ -251,9 +256,13 @@ class websocket(object):
if self.extensions:
shake += 'Sec-WebSocket-Extensions: %s\r\n' % ', '.join(self.extensions)
if auth:
shake += 'Authorization: %s\r\n' % auth
self.sock.sendall(shake + '\r\n')
self.handshake_started = True
return key
def receive_response(key):
# Receive and process server handshake
raw_headers = ''
......@@ -261,16 +270,46 @@ class websocket(object):
raw_headers += self.sock.recv(512).decode('utf-8', 'ignore')
# Response must be HTTP (at least 1.1) with status 101
if not raw_headers.startswith('HTTP/1.1 101'):
# TODO: implement HTTP authentication (401) and redirect (3xx)?
fail('not a valid HTTP 1.1 status 101 response')
match = re.search(r'^HTTP/1\.1 (\d{3})', raw_headers)
if match is None:
fail('not a valid HTTP 1.1 response')
status = int(match.group(1))
headers = re.findall(r'(.*?): ?(.*?)\r\n', raw_headers)
header_names = [name for name, value in headers]
def header(name):
return ', '.join([v for n, v in headers if n == name])
if status == 401:
# HTTP authentication is required in the request
raise HandshakeError('HTTP authentication required: %s'
% header('WWW-Authenticate'))
if status in (301, 302, 303, 307, 308):
# Handle HTTP redirect
url = urlparse(header('Location').strip())
# Reconnect socket if net location changed
if not url.port:
url.port = 443 if self.secure else 80
addr = (url.netloc, url.port)
if addr != self.sock.getpeername():
self.sock.close()
self.sock.connect(addr)
# Send new handshake
receive_response(send_request(url.path))
return
if status != 101:
# 101 means server has accepted the connection and sent
# handshake headers
fail('invalid HTTP response status %d' % status)
# Check if headers that MUST be present are actually present
for name in ('Upgrade', 'Connection', 'Sec-WebSocket-Accept'):
if name not in header_names:
......@@ -290,17 +329,17 @@ class websocket(object):
fail('invalid websocket accept header "%s"' % accept)
# Compare extensions
server_extensions = split_stripped(header('Sec-WebSocket-Extensions'))
server_ext = split_stripped(header('Sec-WebSocket-Extensions'))
for ext in server_extensions:
if ext not in self.extensions:
fail('server extension "%s" is unsupported by client' % ext)
for e in server_ext:
if e not in self.extensions:
fail('server extension "%s" is unsupported by client' % e)
for ext in self.extensions:
if ext not in server_extensions:
fail('client extension "%s" is unsupported by server' % ext)
for e in self.extensions:
if e not in server_ext:
fail('client extension "%s" is unsupported by server' % e)
# Assert that returned protocol is supported
# Assert that returned protocol (if any) is supported
protocol = header('Sec-WebSocket-Protocol')
if protocol:
......@@ -309,6 +348,10 @@ class websocket(object):
self.protocol = protocol
self.handshake_started = True
receive_response(send_request(location))
def enable_ssl(self, *args, **kwargs):
"""
Transforms the regular socket.socket to an ssl.SSLSocket for secure
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment