Commit bfbe3934 authored by Taddeüs Kroes's avatar Taddeüs Kroes

Revised server-side connection closing and added a timeout to waiting for CLOSE frames

parent 64a1553c
...@@ -22,6 +22,7 @@ class Connection(object): ...@@ -22,6 +22,7 @@ class Connection(object):
""" """
self.sock = sock self.sock = sock
self.close_frame_sent = False
self.close_frame_received = False self.close_frame_received = False
self.ping_sent = False self.ping_sent = False
self.ping_payload = None self.ping_payload = None
...@@ -129,6 +130,14 @@ class Connection(object): ...@@ -129,6 +130,14 @@ class Connection(object):
self.ping_sent = True self.ping_sent = True
self.onping(payload) self.onping(payload)
def send_close_frame(self, code=None, reason=''):
"""
Send a CLOSE control frame.
"""
payload = '' if code is None else struct.pack('!H', code) + reason
self.sock.send(ControlFrame(OPCODE_CLOSE, payload))
self.close_frame_sent = True
def close(self, code=None, reason=''): def close(self, code=None, reason=''):
""" """
Close the socket by sending a CLOSE frame and waiting for a response Close the socket by sending a CLOSE frame and waiting for a response
...@@ -138,8 +147,8 @@ class Connection(object): ...@@ -138,8 +147,8 @@ class Connection(object):
actually closed. actually closed.
""" """
# Send CLOSE frame # Send CLOSE frame
payload = '' if code is None else struct.pack('!H', code) + reason if not self.close_frame_sent:
self.sock.send(ControlFrame(OPCODE_CLOSE, payload)) self.send_close_frame(code, reason)
# Receive CLOSE frame # Receive CLOSE frame
if not self.close_frame_received: if not self.close_frame_received:
...@@ -148,6 +157,7 @@ class Connection(object): ...@@ -148,6 +157,7 @@ class Connection(object):
if frame.opcode != OPCODE_CLOSE: if frame.opcode != OPCODE_CLOSE:
raise ValueError('expected CLOSE frame, got %s' % frame) raise ValueError('expected CLOSE frame, got %s' % frame)
self.close_frame_received = True
res_code, res_reason = frame.unpack_close() res_code, res_reason = frame.unpack_close()
# FIXME: check if res_code == code and res_reason == reason? # FIXME: check if res_code == code and res_reason == reason?
......
import socket import socket
import logging import logging
import time
from traceback import format_exc from traceback import format_exc
from threading import Thread from threading import Thread
from ssl import SSLError from ssl import SSLError
...@@ -31,7 +32,7 @@ class Server(object): ...@@ -31,7 +32,7 @@ class Server(object):
""" """
def __init__(self, port, hostname='', loglevel=logging.INFO, protocols=[], def __init__(self, port, hostname='', loglevel=logging.INFO, protocols=[],
secure=False, **kwargs): secure=False, max_join_time=2.0, **kwargs):
""" """
Constructor for a simple websocket server. Constructor for a simple websocket server.
...@@ -48,6 +49,9 @@ class Server(object): ...@@ -48,6 +49,9 @@ class Server(object):
this case, `keyfile` and `certfile` must be specified. Any additional this case, `keyfile` and `certfile` must be specified. Any additional
keyword arguments are passed to websocket.enable_ssl (and thus to keyword arguments are passed to websocket.enable_ssl (and thus to
ssl.wrap_socket). ssl.wrap_socket).
`max_join_time` is the maximum time (in seconds) to wait for client
responses after sending CLOSE frames, it defaults to 2 seconds.
""" """
logging.basicConfig(level=loglevel, logging.basicConfig(level=loglevel,
format='%(asctime)s: %(levelname)s: %(message)s', format='%(asctime)s: %(levelname)s: %(message)s',
...@@ -66,8 +70,11 @@ class Server(object): ...@@ -66,8 +70,11 @@ class Server(object):
self.sock.listen(5) self.sock.listen(5)
self.clients = [] self.clients = []
self.client_threads = []
self.protocols = protocols self.protocols = protocols
self.max_join_time = max_join_time
def run(self): def run(self):
while True: while True:
try: try:
...@@ -80,6 +87,7 @@ class Server(object): ...@@ -80,6 +87,7 @@ class Server(object):
thread = Thread(target=client.receive_forever) thread = Thread(target=client.receive_forever)
thread.daemon = True thread.daemon = True
thread.start() thread.start()
self.client_threads.append(thread)
except SSLError as e: except SSLError as e:
logging.error('SSL error: %s', e) logging.error('SSL error: %s', e)
except HandshakeError as e: except HandshakeError as e:
...@@ -94,7 +102,13 @@ class Server(object): ...@@ -94,7 +102,13 @@ class Server(object):
def quit_gracefully(self): def quit_gracefully(self):
for client in self.clients: for client in self.clients:
client.close(CLOSE_NORMAL) client.send_close_frame()
start_time = time.time()
while time.time() - start_time <= self.max_join_time \
and any(t.is_alive() for t in self.client_threads):
time.sleep(0.050)
def remove_client(self, client, code, reason): def remove_client(self, client, code, reason):
self.clients.remove(client) self.clients.remove(client)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment