Commit 667d6115 authored by Taddeüs Kroes's avatar Taddeüs Kroes

Implemented HTTP redirect and authentication in client handshake

parent 8f15e283
...@@ -4,6 +4,7 @@ import socket ...@@ -4,6 +4,7 @@ import socket
import ssl import ssl
from hashlib import sha1 from hashlib import sha1
from base64 import b64encode from base64 import b64encode
from urlparse import urlparse
from frame import receive_frame from frame import receive_frame
from errors import HandshakeError, SSLError from errors import HandshakeError, SSLError
...@@ -79,7 +80,7 @@ class websocket(object): ...@@ -79,7 +80,7 @@ class websocket(object):
wsock.server_handshake() wsock.server_handshake()
return wsock, address return wsock, address
def connect(self, address, path='/'): def connect(self, address, path='/', auth=None):
""" """
Equivalent to socket.connect(), but sends an client handshake request Equivalent to socket.connect(), but sends an client handshake request
after connecting. after connecting.
...@@ -88,9 +89,12 @@ class websocket(object): ...@@ -88,9 +89,12 @@ class websocket(object):
`path` is optional, used as the *location* part of the HTTP handshake. `path` is optional, used as the *location* part of the HTTP handshake.
In a URL, this would show as ws://host[:port]/path. In a URL, this would show as ws://host[:port]/path.
`auth` is optional, used for the HTTP "Authorization" header of the
handshake request.
""" """
self.sock.connect(address) self.sock.connect(address)
self.client_handshake(address, path) self.client_handshake(address, path, auth)
def send(self, *args): def send(self, *args):
""" """
...@@ -215,7 +219,7 @@ class websocket(object): ...@@ -215,7 +219,7 @@ class websocket(object):
self.sock.sendall(shake + '\r\n') self.sock.sendall(shake + '\r\n')
self.handshake_started = True self.handshake_started = True
def client_handshake(self, address, location): def client_handshake(self, address, location, auth):
""" """
Executes a handshake as the client end point of the socket. May raise a Executes a handshake as the client end point of the socket. May raise a
HandshakeError if the server response is invalid. HandshakeError if the server response is invalid.
...@@ -224,6 +228,7 @@ class websocket(object): ...@@ -224,6 +228,7 @@ class websocket(object):
self.sock.close() self.sock.close()
raise HandshakeError(msg) raise HandshakeError(msg)
def send_request(location):
if len(location) == 0: if len(location) == 0:
fail('request location is empty') fail('request location is empty')
...@@ -251,9 +256,13 @@ class websocket(object): ...@@ -251,9 +256,13 @@ class websocket(object):
if self.extensions: if self.extensions:
shake += 'Sec-WebSocket-Extensions: %s\r\n' % ', '.join(self.extensions) shake += 'Sec-WebSocket-Extensions: %s\r\n' % ', '.join(self.extensions)
if auth:
shake += 'Authorization: %s\r\n' % auth
self.sock.sendall(shake + '\r\n') self.sock.sendall(shake + '\r\n')
self.handshake_started = True return key
def receive_response(key):
# Receive and process server handshake # Receive and process server handshake
raw_headers = '' raw_headers = ''
...@@ -261,16 +270,46 @@ class websocket(object): ...@@ -261,16 +270,46 @@ class websocket(object):
raw_headers += self.sock.recv(512).decode('utf-8', 'ignore') raw_headers += self.sock.recv(512).decode('utf-8', 'ignore')
# Response must be HTTP (at least 1.1) with status 101 # Response must be HTTP (at least 1.1) with status 101
if not raw_headers.startswith('HTTP/1.1 101'): match = re.search(r'^HTTP/1\.1 (\d{3})', raw_headers)
# TODO: implement HTTP authentication (401) and redirect (3xx)?
fail('not a valid HTTP 1.1 status 101 response') if match is None:
fail('not a valid HTTP 1.1 response')
status = int(match.group(1))
headers = re.findall(r'(.*?): ?(.*?)\r\n', raw_headers) headers = re.findall(r'(.*?): ?(.*?)\r\n', raw_headers)
header_names = [name for name, value in headers] header_names = [name for name, value in headers]
def header(name): def header(name):
return ', '.join([v for n, v in headers if n == name]) return ', '.join([v for n, v in headers if n == name])
if status == 401:
# HTTP authentication is required in the request
raise HandshakeError('HTTP authentication required: %s'
% header('WWW-Authenticate'))
if status in (301, 302, 303, 307, 308):
# Handle HTTP redirect
url = urlparse(header('Location').strip())
# Reconnect socket if net location changed
if not url.port:
url.port = 443 if self.secure else 80
addr = (url.netloc, url.port)
if addr != self.sock.getpeername():
self.sock.close()
self.sock.connect(addr)
# Send new handshake
receive_response(send_request(url.path))
return
if status != 101:
# 101 means server has accepted the connection and sent
# handshake headers
fail('invalid HTTP response status %d' % status)
# Check if headers that MUST be present are actually present # Check if headers that MUST be present are actually present
for name in ('Upgrade', 'Connection', 'Sec-WebSocket-Accept'): for name in ('Upgrade', 'Connection', 'Sec-WebSocket-Accept'):
if name not in header_names: if name not in header_names:
...@@ -290,17 +329,17 @@ class websocket(object): ...@@ -290,17 +329,17 @@ class websocket(object):
fail('invalid websocket accept header "%s"' % accept) fail('invalid websocket accept header "%s"' % accept)
# Compare extensions # Compare extensions
server_extensions = split_stripped(header('Sec-WebSocket-Extensions')) server_ext = split_stripped(header('Sec-WebSocket-Extensions'))
for ext in server_extensions: for e in server_ext:
if ext not in self.extensions: if e not in self.extensions:
fail('server extension "%s" is unsupported by client' % ext) fail('server extension "%s" is unsupported by client' % e)
for ext in self.extensions: for e in self.extensions:
if ext not in server_extensions: if e not in server_ext:
fail('client extension "%s" is unsupported by server' % ext) fail('client extension "%s" is unsupported by server' % e)
# Assert that returned protocol is supported # Assert that returned protocol (if any) is supported
protocol = header('Sec-WebSocket-Protocol') protocol = header('Sec-WebSocket-Protocol')
if protocol: if protocol:
...@@ -309,6 +348,10 @@ class websocket(object): ...@@ -309,6 +348,10 @@ class websocket(object):
self.protocol = protocol self.protocol = protocol
self.handshake_started = True
receive_response(send_request(location))
def enable_ssl(self, *args, **kwargs): def enable_ssl(self, *args, **kwargs):
""" """
Transforms the regular socket.socket to an ssl.SSLSocket for secure Transforms the regular socket.socket to an ssl.SSLSocket for secure
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment