|
|
@@ -6,7 +6,7 @@ from threading import Thread
|
|
|
from frame import ControlFrame, receive_fragments, receive_frame, \
|
|
|
OPCODE_CLOSE, OPCODE_PING, OPCODE_PONG
|
|
|
from message import create_message
|
|
|
-from exceptions import SocketClosed, PingError
|
|
|
+from exceptions import InvalidRequest, SocketClosed, PingError
|
|
|
|
|
|
|
|
|
WS_GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
|
|
|
@@ -14,9 +14,8 @@ WS_VERSION = '13'
|
|
|
|
|
|
|
|
|
class WebSocket(object):
|
|
|
- def __init__(self, sock, encoding=None):
|
|
|
+ def __init__(self, sock):
|
|
|
self.sock = sock
|
|
|
- self.encoding = encoding
|
|
|
|
|
|
self.received_close_params = None
|
|
|
self.close_frame_sent = False
|
|
|
@@ -57,40 +56,53 @@ class WebSocket(object):
|
|
|
payload = ''.join([f.payload for f in frames])
|
|
|
return create_message(frames[0].opcode, payload)
|
|
|
|
|
|
- def send_handshake(self):
|
|
|
- raw_headers = self.sock.recv(512)
|
|
|
-
|
|
|
- if self.encoding:
|
|
|
- raw_headers = raw_headers.decode(self.encoding, 'ignore')
|
|
|
+ def handshake(self):
|
|
|
+ """
|
|
|
+ Execute a handshake with the other end point of the socket. If the HTTP
|
|
|
+ request headers read from the socket are invalid, an InvalidRequest
|
|
|
+ exception will be raised.
|
|
|
+ """
|
|
|
+ raw_headers = self.sock.recv(512).decode('utf-8', 'ignore')
|
|
|
|
|
|
+ # request must be HTTP (at least 1.1) GET request, find the location
|
|
|
location = re.search(r'^GET (.*) HTTP/1.1\r\n', raw_headers).group(1)
|
|
|
headers = dict(re.findall(r'(.*?): (.*?)\r\n', raw_headers))
|
|
|
|
|
|
# Check if headers that MUST be present are actually present
|
|
|
for name in ('Host', 'Upgrade', 'Connection', 'Sec-WebSocket-Key',
|
|
|
'Origin', 'Sec-WebSocket-Version'):
|
|
|
- assert name in headers
|
|
|
+ if name not in headers:
|
|
|
+ raise InvalidRequest('missing "%s" header' % name)
|
|
|
|
|
|
# Check WebSocket version used by client
|
|
|
- assert headers['Sec-WebSocket-Version'] == WS_VERSION
|
|
|
+ version = headers['Sec-WebSocket-Version']
|
|
|
+
|
|
|
+ if version != WS_VERSION:
|
|
|
+ raise InvalidRequest('WebSocket version %s requested (only %s '
|
|
|
+ 'is supported)' % (version, WS_VERSION))
|
|
|
|
|
|
# Make sure the requested protocols are supported by this server
|
|
|
if 'Sec-WebSocket-Protocol' in headers:
|
|
|
parts = headers['Sec-WebSocket-Protocol'].split(',')
|
|
|
protocols = map(str.strip, parts)
|
|
|
|
|
|
- for protocol in protocols:
|
|
|
- assert protocol in self.protocols
|
|
|
+ for p in protocols:
|
|
|
+ if p not in self.protocols:
|
|
|
+ raise InvalidRequest('unsupported protocol "%s"' % p)
|
|
|
else:
|
|
|
protocols = []
|
|
|
|
|
|
+ # Encode acceptation key using the WebSocket GUID
|
|
|
key = headers['Sec-WebSocket-Key']
|
|
|
accept = sha1(key + WS_GUID).digest().encode('base64')
|
|
|
+
|
|
|
+ # Construct HTTP response header
|
|
|
shake = 'HTTP/1.1 101 Web Socket Protocol Handshake\r\n'
|
|
|
shake += 'Upgrade: WebSocket\r\n'
|
|
|
shake += 'Connection: Upgrade\r\n'
|
|
|
shake += 'WebSocket-Origin: %s\r\n' % headers['Origin']
|
|
|
- shake += 'WebSocket-Location: ws://%s%s\r\n' % (headers['Host'], location)
|
|
|
+ shake += 'WebSocket-Location: ws://%s%s\r\n' \
|
|
|
+ % (headers['Host'], location)
|
|
|
shake += 'Sec-WebSocket-Accept: %s\r\n' % accept
|
|
|
|
|
|
if self.protocols:
|
|
|
@@ -170,7 +182,7 @@ class WebSocket(object):
|
|
|
Called when a message is received. `message' is a Message object, which
|
|
|
can be constructed from a single frame or multiple fragmented frames.
|
|
|
"""
|
|
|
- raise NotImplemented
|
|
|
+ return NotImplemented
|
|
|
|
|
|
def onping(self, payload):
|
|
|
"""
|
|
|
@@ -178,13 +190,13 @@ class WebSocket(object):
|
|
|
used to start a timeout handler for a pong message that is not received
|
|
|
in time.
|
|
|
"""
|
|
|
- raise NotImplemented
|
|
|
+ pass
|
|
|
|
|
|
def onpong(self, payload):
|
|
|
"""
|
|
|
Called when a pong control frame is received.
|
|
|
"""
|
|
|
- raise NotImplemented
|
|
|
+ pass
|
|
|
|
|
|
def onclose(self, code, reason):
|
|
|
"""
|