Commit 8f15e283 authored by Taddeüs Kroes's avatar Taddeüs Kroes

Implemented client handshake, and did some corresponding debugging

parent 6d2cf25d
.PHONY: check clean .PHONY: check clean
check: check:
@python test.py @python test/server.py
clean: clean:
find -name \*.pyc -delete find -name \*.pyc -delete
...@@ -7,3 +7,4 @@ from frame import Frame, ControlFrame, OPCODE_CONTINUATION, OPCODE_TEXT, \ ...@@ -7,3 +7,4 @@ from frame import Frame, ControlFrame, OPCODE_CONTINUATION, OPCODE_TEXT, \
CLOSE_MISSING_EXTENSIONS, CLOSE_UNABLE CLOSE_MISSING_EXTENSIONS, CLOSE_UNABLE
from connection import Connection from connection import Connection
from message import Message, TextMessage, BinaryMessage, JSONMessage from message import Message, TextMessage, BinaryMessage, JSONMessage
from errors import SocketClosed
import struct import struct
import socket
from frame import ControlFrame, OPCODE_CLOSE, OPCODE_PING, OPCODE_PONG, \ from frame import ControlFrame, OPCODE_CLOSE, OPCODE_PING, OPCODE_PONG, \
OPCODE_CONTINUATION OPCODE_CONTINUATION
...@@ -106,6 +107,16 @@ class Connection(object): ...@@ -106,6 +107,16 @@ class Connection(object):
except SocketClosed as e: except SocketClosed as e:
self.close(e.code, e.reason) self.close(e.code, e.reason)
break break
except socket.error as e:
self.onerror(e)
try:
self.sock.close()
except socket.error:
pass
self.onclose(None, '')
break
except Exception as e: except Exception as e:
self.onerror(e) self.onerror(e)
...@@ -124,10 +135,7 @@ class Connection(object): ...@@ -124,10 +135,7 @@ class Connection(object):
close message, unless such a message has already been received earlier close message, unless such a message has already been received earlier
(prior to calling this function, for example). The onclose() handler is (prior to calling this function, for example). The onclose() handler is
called after the response has been received, but before the socket is called after the response has been received, but before the socket is
actually closed. This order was chosen to prevent errors in actually closed.
stringification in the onclose() handler. For example,
socket.getpeername() raises a Bad file descriptor error then the socket
is closed.
""" """
# Send CLOSE frame # Send CLOSE frame
payload = '' if code is None else struct.pack('!H', code) + reason payload = '' if code is None else struct.pack('!H', code) + reason
......
import struct import struct
import socket
from os import urandom from os import urandom
from string import printable from string import printable
...@@ -150,7 +151,7 @@ class Frame(object): ...@@ -150,7 +151,7 @@ class Frame(object):
s += ' masking_key=%4s' % printstr(self.masking_key) s += ' masking_key=%4s' % printstr(self.masking_key)
max_pl_disp = 30 max_pl_disp = 30
pl = self.payload[:max_pl_disp] pl = printstr(self.payload)[:max_pl_disp]
if len(self.payload) > max_pl_disp: if len(self.payload) > max_pl_disp:
pl += '...' pl += '...'
...@@ -240,7 +241,7 @@ def recvn(sock, n): ...@@ -240,7 +241,7 @@ def recvn(sock, n):
received = sock.recv(n - len(data)) received = sock.recv(n - len(data))
if not len(received): if not len(received):
raise SocketClosed(None, 'no data read from socket') raise socket.error('no data read from socket')
data += received data += received
......
...@@ -133,7 +133,10 @@ class Client(Connection): ...@@ -133,7 +133,10 @@ class Client(Connection):
super(Client, self).__init__(sock) super(Client, self).__init__(sock)
def __str__(self): def __str__(self):
return '<Client at %s:%d>' % self.sock.getpeername() try:
return '<Client at %s:%d>' % self.sock.getpeername()
except socket.error:
return '<Client on closed socket>'
def send(self, message, fragment_size=None, mask=False): def send(self, message, fragment_size=None, mask=False):
logging.debug('Sending %s to %s', message, self) logging.debug('Sending %s to %s', message, self)
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
var ws = new WebSocket(URL); var ws = new WebSocket(URL);
ws.onopen = function() { ws.onopen = function() {
log('Connection complete, sending "foo"'); log('Connection established, sending "foo"');
ws.send('foo'); ws.send('foo');
}; };
......
#!/usr/bin/env python
import sys
from os.path import abspath, dirname
basepath = abspath(dirname(abspath(__file__)) + '/..')
sys.path.insert(0, basepath)
from websocket import websocket
from connection import Connection
from message import TextMessage
from errors import SocketClosed
ADDR = ('localhost', 8000)
class EchoClient(Connection):
def onopen(self):
print 'Connection established, sending "foo"'
self.send(TextMessage('foo'))
def onmessage(self, msg):
print 'Received', msg
raise SocketClosed(None, 'response received')
def onerror(self, e):
print 'Error:', e
def onclose(self, code, reason):
print 'Connection closed'
if __name__ == '__main__':
print 'Connecting to ws://%s:%d' % ADDR
sock = websocket()
sock.connect(ADDR)
EchoClient(sock).receive_forever()
#!/usr/bin/env python #!/usr/bin/env python
import sys
import logging import logging
from os.path import abspath, dirname
basepath = abspath(dirname(abspath(__file__)) + '/..')
sys.path.insert(0, basepath)
from server import Server from server import Server
......
import os
import re import re
import socket import socket
import ssl import ssl
...@@ -13,7 +14,7 @@ WS_VERSION = '13' ...@@ -13,7 +14,7 @@ WS_VERSION = '13'
def split_stripped(value, delim=','): def split_stripped(value, delim=','):
return map(str.strip, str(value).split(delim)) return map(str.strip, str(value).split(delim)) if value else []
class websocket(object): class websocket(object):
...@@ -78,13 +79,18 @@ class websocket(object): ...@@ -78,13 +79,18 @@ class websocket(object):
wsock.server_handshake() wsock.server_handshake()
return wsock, address return wsock, address
def connect(self, address): def connect(self, address, path='/'):
""" """
Equivalent to socket.connect(), but sends an client handshake request Equivalent to socket.connect(), but sends an client handshake request
after connecting. after connecting.
`address` is a (host, port) tuple of the server to connect to.
`path` is optional, used as the *location* part of the HTTP handshake.
In a URL, this would show as ws://host[:port]/path.
""" """
self.sock.sonnect(address) self.sock.connect(address)
self.client_handshake() self.client_handshake(address, path)
def send(self, *args): def send(self, *args):
""" """
...@@ -131,6 +137,10 @@ class websocket(object): ...@@ -131,6 +137,10 @@ class websocket(object):
request headers sent by the client are invalid, a HandshakeError request headers sent by the client are invalid, a HandshakeError
is raised. is raised.
""" """
def fail(msg):
self.sock.close()
raise HandshakeError(msg)
# Receive HTTP header # Receive HTTP header
raw_headers = '' raw_headers = ''
...@@ -138,7 +148,12 @@ class websocket(object): ...@@ -138,7 +148,12 @@ class websocket(object):
raw_headers += self.sock.recv(512).decode('utf-8', 'ignore') raw_headers += self.sock.recv(512).decode('utf-8', 'ignore')
# Request must be HTTP (at least 1.1) GET request, find the location # Request must be HTTP (at least 1.1) GET request, find the location
location = re.search(r'^GET (.*) HTTP/1.1\r\n', raw_headers).group(1) match = re.search(r'^GET (.*) HTTP/1.1\r\n', raw_headers)
if match is None:
fail('not a valid HTTP 1.1 GET request')
location = match.group(1)
headers = re.findall(r'(.*?): ?(.*?)\r\n', raw_headers) headers = re.findall(r'(.*?): ?(.*?)\r\n', raw_headers)
header_names = [name for name, value in headers] header_names = [name for name, value in headers]
...@@ -147,25 +162,39 @@ class websocket(object): ...@@ -147,25 +162,39 @@ class websocket(object):
# Check if headers that MUST be present are actually present # Check if headers that MUST be present are actually present
for name in ('Host', 'Upgrade', 'Connection', 'Sec-WebSocket-Key', for name in ('Host', 'Upgrade', 'Connection', 'Sec-WebSocket-Key',
'Origin', 'Sec-WebSocket-Version'): 'Sec-WebSocket-Version'):
if name not in header_names: if name not in header_names:
raise HandshakeError('missing "%s" header' % name) fail('missing "%s" header' % name)
# Check WebSocket version used by client # Check WebSocket version used by client
version = header('Sec-WebSocket-Version') version = header('Sec-WebSocket-Version')
if version != WS_VERSION: if version != WS_VERSION:
raise HandshakeError('WebSocket version %s requested (only %s ' fail('WebSocket version %s requested (only %s '
'is supported)' % (version, WS_VERSION)) 'is supported)' % (version, WS_VERSION))
# Verify required header keywords
if 'websocket' not in header('Upgrade').lower():
fail('"Upgrade" header must contain "websocket"')
if 'upgrade' not in header('Connection').lower():
fail('"Connection" header must contain "Upgrade"')
# Origin must be present if browser client
if 'User-Agent' in header_names and 'Origin' not in header_names:
fail('browser client must specify "Origin" header')
# Only supported protocols are returned # Only supported protocols are returned
proto = header('Sec-WebSocket-Extensions') client_protocols = split_stripped(header('Sec-WebSocket-Extensions'))
protocols = split_stripped(proto) if proto else [] protocol = 'null'
protocols = [p for p in protocols if p in self.protocols]
for p in client_protocols:
if p in self.protocols:
protocol = p
break
# Only supported extensions are returned # Only supported extensions are returned
ext = header('Sec-WebSocket-Extensions') extensions = split_stripped(header('Sec-WebSocket-Extensions'))
extensions = split_stripped(ext) if ext else []
extensions = [e for e in extensions if e in self.extensions] extensions = [e for e in extensions if e in self.extensions]
# Encode acceptation key using the WebSocket GUID # Encode acceptation key using the WebSocket GUID
...@@ -174,34 +203,115 @@ class websocket(object): ...@@ -174,34 +203,115 @@ class websocket(object):
# Construct HTTP response header # Construct HTTP response header
shake = 'HTTP/1.1 101 Web Socket Protocol Handshake\r\n' shake = 'HTTP/1.1 101 Web Socket Protocol Handshake\r\n'
shake += 'Upgrade: WebSocket\r\n' shake += 'Upgrade: websocket\r\n'
shake += 'Connection: Upgrade\r\n' shake += 'Connection: Upgrade\r\n'
shake += 'WebSocket-Origin: %s\r\n' % header('Origin') shake += 'WebSocket-Origin: %s\r\n' % header('Origin')
shake += 'WebSocket-Location: ws://%s%s\r\n' \ shake += 'WebSocket-Location: ws://%s%s\r\n' \
% (header('Host'), location) % (header('Host'), location)
shake += 'Sec-WebSocket-Accept: %s\r\n' % accept shake += 'Sec-WebSocket-Accept: %s\r\n' % accept
shake += 'Sec-WebSocket-Protocol: %s\r\n' % protocol
if protocols: shake += 'Sec-WebSocket-Extensions: %s\r\n' % ', '.join(extensions)
shake += 'Sec-WebSocket-Protocol: %s\r\n' % ', '.join(protocols)
if extensions:
shake += 'Sec-WebSocket-Extensions: %s\r\n' % ', '.join(extensions)
self.sock.sendall(shake + '\r\n') self.sock.sendall(shake + '\r\n')
self.handshake_started = True self.handshake_started = True
def client_handshake(self): def client_handshake(self, address, location):
""" """
Execute a handshake as the client end point of the socket. May raise a Executes a handshake as the client end point of the socket. May raise a
HandshakeError if the server response is invalid. HandshakeError if the server response is invalid.
""" """
# TODO: implement HTTP request headers for client handshake def fail(msg):
self.sock.close()
raise HandshakeError(msg)
if len(location) == 0:
fail('request location is empty')
# Generate a 16-byte random base64-encoded key for this connection
key = b64encode(os.urandom(16))
# Send client handshake
shake = 'GET %s HTTP/1.1\r\n' % location
shake += 'Host: %s:%d\r\n' % address
shake += 'Upgrade: websocket\r\n'
shake += 'Connection: keep-alive, Upgrade\r\n'
shake += 'Sec-WebSocket-Key: %s\r\n' % key
shake += 'Origin: null\r\n' # FIXME: is this correct/necessary?
shake += 'Sec-WebSocket-Version: %s\r\n' % WS_VERSION
# These are for eagerly caching webservers
shake += 'Pragma: no-cache\r\n'
shake += 'Cache-Control: no-cache\r\n'
# Request protocols and extension, these are later checked with the
# actual supported values from the server's response
if self.protocols:
shake += 'Sec-WebSocket-Protocol: %s\r\n' % ', '.join(self.protocols)
if self.extensions:
shake += 'Sec-WebSocket-Extensions: %s\r\n' % ', '.join(self.extensions)
self.sock.sendall(shake + '\r\n')
self.handshake_started = True self.handshake_started = True
raise NotImplementedError
# Receive and process server handshake
raw_headers = ''
while raw_headers[-4:] not in ('\r\n\r\n', '\n\n'):
raw_headers += self.sock.recv(512).decode('utf-8', 'ignore')
# Response must be HTTP (at least 1.1) with status 101
if not raw_headers.startswith('HTTP/1.1 101'):
# TODO: implement HTTP authentication (401) and redirect (3xx)?
fail('not a valid HTTP 1.1 status 101 response')
headers = re.findall(r'(.*?): ?(.*?)\r\n', raw_headers)
header_names = [name for name, value in headers]
def header(name):
return ', '.join([v for n, v in headers if n == name])
# Check if headers that MUST be present are actually present
for name in ('Upgrade', 'Connection', 'Sec-WebSocket-Accept'):
if name not in header_names:
fail('missing "%s" header' % name)
if 'websocket' not in header('Upgrade').lower():
fail('"Upgrade" header must contain "websocket"')
if 'upgrade' not in header('Connection').lower():
fail('"Connection" header must contain "Upgrade"')
# Verify accept header
accept = header('Sec-WebSocket-Accept').strip()
required_accept = b64encode(sha1(key + WS_GUID).digest())
if accept != required_accept:
fail('invalid websocket accept header "%s"' % accept)
# Compare extensions
server_extensions = split_stripped(header('Sec-WebSocket-Extensions'))
for ext in server_extensions:
if ext not in self.extensions:
fail('server extension "%s" is unsupported by client' % ext)
for ext in self.extensions:
if ext not in server_extensions:
fail('client extension "%s" is unsupported by server' % ext)
# Assert that returned protocol is supported
protocol = header('Sec-WebSocket-Protocol')
if protocol:
if protocol != 'null' and protocol not in self.protocols:
fail('unsupported protocol "%s"' % protocol)
self.protocol = protocol
def enable_ssl(self, *args, **kwargs): def enable_ssl(self, *args, **kwargs):
""" """
Transform the regular socket.socket to an ssl.SSLSocket for secure Transforms the regular socket.socket to an ssl.SSLSocket for secure
connections. Any arguments are passed to ssl.wrap_socket: connections. Any arguments are passed to ssl.wrap_socket:
http://docs.python.org/dev/library/ssl.html#ssl.wrap_socket http://docs.python.org/dev/library/ssl.html#ssl.wrap_socket
""" """
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment