Commit 8f15e283 authored by Taddeüs Kroes's avatar Taddeüs Kroes

Implemented client handshake, and did some corresponding debugging

parent 6d2cf25d
.PHONY: check clean
check:
@python test.py
@python test/server.py
clean:
find -name \*.pyc -delete
......@@ -7,3 +7,4 @@ from frame import Frame, ControlFrame, OPCODE_CONTINUATION, OPCODE_TEXT, \
CLOSE_MISSING_EXTENSIONS, CLOSE_UNABLE
from connection import Connection
from message import Message, TextMessage, BinaryMessage, JSONMessage
from errors import SocketClosed
import struct
import socket
from frame import ControlFrame, OPCODE_CLOSE, OPCODE_PING, OPCODE_PONG, \
OPCODE_CONTINUATION
......@@ -106,6 +107,16 @@ class Connection(object):
except SocketClosed as e:
self.close(e.code, e.reason)
break
except socket.error as e:
self.onerror(e)
try:
self.sock.close()
except socket.error:
pass
self.onclose(None, '')
break
except Exception as e:
self.onerror(e)
......@@ -124,10 +135,7 @@ class Connection(object):
close message, unless such a message has already been received earlier
(prior to calling this function, for example). The onclose() handler is
called after the response has been received, but before the socket is
actually closed. This order was chosen to prevent errors in
stringification in the onclose() handler. For example,
socket.getpeername() raises a Bad file descriptor error then the socket
is closed.
actually closed.
"""
# Send CLOSE frame
payload = '' if code is None else struct.pack('!H', code) + reason
......
import struct
import socket
from os import urandom
from string import printable
......@@ -150,7 +151,7 @@ class Frame(object):
s += ' masking_key=%4s' % printstr(self.masking_key)
max_pl_disp = 30
pl = self.payload[:max_pl_disp]
pl = printstr(self.payload)[:max_pl_disp]
if len(self.payload) > max_pl_disp:
pl += '...'
......@@ -240,7 +241,7 @@ def recvn(sock, n):
received = sock.recv(n - len(data))
if not len(received):
raise SocketClosed(None, 'no data read from socket')
raise socket.error('no data read from socket')
data += received
......
......@@ -133,7 +133,10 @@ class Client(Connection):
super(Client, self).__init__(sock)
def __str__(self):
return '<Client at %s:%d>' % self.sock.getpeername()
try:
return '<Client at %s:%d>' % self.sock.getpeername()
except socket.error:
return '<Client on closed socket>'
def send(self, message, fragment_size=None, mask=False):
logging.debug('Sending %s to %s', message, self)
......
......@@ -15,7 +15,7 @@
var ws = new WebSocket(URL);
ws.onopen = function() {
log('Connection complete, sending "foo"');
log('Connection established, sending "foo"');
ws.send('foo');
};
......
#!/usr/bin/env python
import sys
from os.path import abspath, dirname
basepath = abspath(dirname(abspath(__file__)) + '/..')
sys.path.insert(0, basepath)
from websocket import websocket
from connection import Connection
from message import TextMessage
from errors import SocketClosed
ADDR = ('localhost', 8000)
class EchoClient(Connection):
def onopen(self):
print 'Connection established, sending "foo"'
self.send(TextMessage('foo'))
def onmessage(self, msg):
print 'Received', msg
raise SocketClosed(None, 'response received')
def onerror(self, e):
print 'Error:', e
def onclose(self, code, reason):
print 'Connection closed'
if __name__ == '__main__':
print 'Connecting to ws://%s:%d' % ADDR
sock = websocket()
sock.connect(ADDR)
EchoClient(sock).receive_forever()
#!/usr/bin/env python
import sys
import logging
from os.path import abspath, dirname
basepath = abspath(dirname(abspath(__file__)) + '/..')
sys.path.insert(0, basepath)
from server import Server
......
import os
import re
import socket
import ssl
......@@ -13,7 +14,7 @@ WS_VERSION = '13'
def split_stripped(value, delim=','):
return map(str.strip, str(value).split(delim))
return map(str.strip, str(value).split(delim)) if value else []
class websocket(object):
......@@ -78,13 +79,18 @@ class websocket(object):
wsock.server_handshake()
return wsock, address
def connect(self, address):
def connect(self, address, path='/'):
"""
Equivalent to socket.connect(), but sends an client handshake request
after connecting.
`address` is a (host, port) tuple of the server to connect to.
`path` is optional, used as the *location* part of the HTTP handshake.
In a URL, this would show as ws://host[:port]/path.
"""
self.sock.sonnect(address)
self.client_handshake()
self.sock.connect(address)
self.client_handshake(address, path)
def send(self, *args):
"""
......@@ -131,6 +137,10 @@ class websocket(object):
request headers sent by the client are invalid, a HandshakeError
is raised.
"""
def fail(msg):
self.sock.close()
raise HandshakeError(msg)
# Receive HTTP header
raw_headers = ''
......@@ -138,7 +148,12 @@ class websocket(object):
raw_headers += self.sock.recv(512).decode('utf-8', 'ignore')
# Request must be HTTP (at least 1.1) GET request, find the location
location = re.search(r'^GET (.*) HTTP/1.1\r\n', raw_headers).group(1)
match = re.search(r'^GET (.*) HTTP/1.1\r\n', raw_headers)
if match is None:
fail('not a valid HTTP 1.1 GET request')
location = match.group(1)
headers = re.findall(r'(.*?): ?(.*?)\r\n', raw_headers)
header_names = [name for name, value in headers]
......@@ -147,25 +162,39 @@ class websocket(object):
# Check if headers that MUST be present are actually present
for name in ('Host', 'Upgrade', 'Connection', 'Sec-WebSocket-Key',
'Origin', 'Sec-WebSocket-Version'):
'Sec-WebSocket-Version'):
if name not in header_names:
raise HandshakeError('missing "%s" header' % name)
fail('missing "%s" header' % name)
# Check WebSocket version used by client
version = header('Sec-WebSocket-Version')
if version != WS_VERSION:
raise HandshakeError('WebSocket version %s requested (only %s '
fail('WebSocket version %s requested (only %s '
'is supported)' % (version, WS_VERSION))
# Verify required header keywords
if 'websocket' not in header('Upgrade').lower():
fail('"Upgrade" header must contain "websocket"')
if 'upgrade' not in header('Connection').lower():
fail('"Connection" header must contain "Upgrade"')
# Origin must be present if browser client
if 'User-Agent' in header_names and 'Origin' not in header_names:
fail('browser client must specify "Origin" header')
# Only supported protocols are returned
proto = header('Sec-WebSocket-Extensions')
protocols = split_stripped(proto) if proto else []
protocols = [p for p in protocols if p in self.protocols]
client_protocols = split_stripped(header('Sec-WebSocket-Extensions'))
protocol = 'null'
for p in client_protocols:
if p in self.protocols:
protocol = p
break
# Only supported extensions are returned
ext = header('Sec-WebSocket-Extensions')
extensions = split_stripped(ext) if ext else []
extensions = split_stripped(header('Sec-WebSocket-Extensions'))
extensions = [e for e in extensions if e in self.extensions]
# Encode acceptation key using the WebSocket GUID
......@@ -174,34 +203,115 @@ class websocket(object):
# Construct HTTP response header
shake = 'HTTP/1.1 101 Web Socket Protocol Handshake\r\n'
shake += 'Upgrade: WebSocket\r\n'
shake += 'Upgrade: websocket\r\n'
shake += 'Connection: Upgrade\r\n'
shake += 'WebSocket-Origin: %s\r\n' % header('Origin')
shake += 'WebSocket-Location: ws://%s%s\r\n' \
% (header('Host'), location)
shake += 'Sec-WebSocket-Accept: %s\r\n' % accept
if protocols:
shake += 'Sec-WebSocket-Protocol: %s\r\n' % ', '.join(protocols)
if extensions:
shake += 'Sec-WebSocket-Extensions: %s\r\n' % ', '.join(extensions)
shake += 'Sec-WebSocket-Protocol: %s\r\n' % protocol
shake += 'Sec-WebSocket-Extensions: %s\r\n' % ', '.join(extensions)
self.sock.sendall(shake + '\r\n')
self.handshake_started = True
def client_handshake(self):
def client_handshake(self, address, location):
"""
Execute a handshake as the client end point of the socket. May raise a
Executes a handshake as the client end point of the socket. May raise a
HandshakeError if the server response is invalid.
"""
# TODO: implement HTTP request headers for client handshake
def fail(msg):
self.sock.close()
raise HandshakeError(msg)
if len(location) == 0:
fail('request location is empty')
# Generate a 16-byte random base64-encoded key for this connection
key = b64encode(os.urandom(16))
# Send client handshake
shake = 'GET %s HTTP/1.1\r\n' % location
shake += 'Host: %s:%d\r\n' % address
shake += 'Upgrade: websocket\r\n'
shake += 'Connection: keep-alive, Upgrade\r\n'
shake += 'Sec-WebSocket-Key: %s\r\n' % key
shake += 'Origin: null\r\n' # FIXME: is this correct/necessary?
shake += 'Sec-WebSocket-Version: %s\r\n' % WS_VERSION
# These are for eagerly caching webservers
shake += 'Pragma: no-cache\r\n'
shake += 'Cache-Control: no-cache\r\n'
# Request protocols and extension, these are later checked with the
# actual supported values from the server's response
if self.protocols:
shake += 'Sec-WebSocket-Protocol: %s\r\n' % ', '.join(self.protocols)
if self.extensions:
shake += 'Sec-WebSocket-Extensions: %s\r\n' % ', '.join(self.extensions)
self.sock.sendall(shake + '\r\n')
self.handshake_started = True
raise NotImplementedError
# Receive and process server handshake
raw_headers = ''
while raw_headers[-4:] not in ('\r\n\r\n', '\n\n'):
raw_headers += self.sock.recv(512).decode('utf-8', 'ignore')
# Response must be HTTP (at least 1.1) with status 101
if not raw_headers.startswith('HTTP/1.1 101'):
# TODO: implement HTTP authentication (401) and redirect (3xx)?
fail('not a valid HTTP 1.1 status 101 response')
headers = re.findall(r'(.*?): ?(.*?)\r\n', raw_headers)
header_names = [name for name, value in headers]
def header(name):
return ', '.join([v for n, v in headers if n == name])
# Check if headers that MUST be present are actually present
for name in ('Upgrade', 'Connection', 'Sec-WebSocket-Accept'):
if name not in header_names:
fail('missing "%s" header' % name)
if 'websocket' not in header('Upgrade').lower():
fail('"Upgrade" header must contain "websocket"')
if 'upgrade' not in header('Connection').lower():
fail('"Connection" header must contain "Upgrade"')
# Verify accept header
accept = header('Sec-WebSocket-Accept').strip()
required_accept = b64encode(sha1(key + WS_GUID).digest())
if accept != required_accept:
fail('invalid websocket accept header "%s"' % accept)
# Compare extensions
server_extensions = split_stripped(header('Sec-WebSocket-Extensions'))
for ext in server_extensions:
if ext not in self.extensions:
fail('server extension "%s" is unsupported by client' % ext)
for ext in self.extensions:
if ext not in server_extensions:
fail('client extension "%s" is unsupported by server' % ext)
# Assert that returned protocol is supported
protocol = header('Sec-WebSocket-Protocol')
if protocol:
if protocol != 'null' and protocol not in self.protocols:
fail('unsupported protocol "%s"' % protocol)
self.protocol = protocol
def enable_ssl(self, *args, **kwargs):
"""
Transform the regular socket.socket to an ssl.SSLSocket for secure
Transforms the regular socket.socket to an ssl.SSLSocket for secure
connections. Any arguments are passed to ssl.wrap_socket:
http://docs.python.org/dev/library/ssl.html#ssl.wrap_socket
"""
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment