|
@@ -1,5 +1,6 @@
|
|
|
import socket
|
|
import socket
|
|
|
import logging
|
|
import logging
|
|
|
|
|
+import time
|
|
|
from traceback import format_exc
|
|
from traceback import format_exc
|
|
|
from threading import Thread
|
|
from threading import Thread
|
|
|
from ssl import SSLError
|
|
from ssl import SSLError
|
|
@@ -31,7 +32,7 @@ class Server(object):
|
|
|
"""
|
|
"""
|
|
|
|
|
|
|
|
def __init__(self, port, hostname='', loglevel=logging.INFO, protocols=[],
|
|
def __init__(self, port, hostname='', loglevel=logging.INFO, protocols=[],
|
|
|
- secure=False, **kwargs):
|
|
|
|
|
|
|
+ secure=False, max_join_time=2.0, **kwargs):
|
|
|
"""
|
|
"""
|
|
|
Constructor for a simple websocket server.
|
|
Constructor for a simple websocket server.
|
|
|
|
|
|
|
@@ -48,6 +49,9 @@ class Server(object):
|
|
|
this case, `keyfile` and `certfile` must be specified. Any additional
|
|
this case, `keyfile` and `certfile` must be specified. Any additional
|
|
|
keyword arguments are passed to websocket.enable_ssl (and thus to
|
|
keyword arguments are passed to websocket.enable_ssl (and thus to
|
|
|
ssl.wrap_socket).
|
|
ssl.wrap_socket).
|
|
|
|
|
+
|
|
|
|
|
+ `max_join_time` is the maximum time (in seconds) to wait for client
|
|
|
|
|
+ responses after sending CLOSE frames, it defaults to 2 seconds.
|
|
|
"""
|
|
"""
|
|
|
logging.basicConfig(level=loglevel,
|
|
logging.basicConfig(level=loglevel,
|
|
|
format='%(asctime)s: %(levelname)s: %(message)s',
|
|
format='%(asctime)s: %(levelname)s: %(message)s',
|
|
@@ -66,8 +70,11 @@ class Server(object):
|
|
|
self.sock.listen(5)
|
|
self.sock.listen(5)
|
|
|
|
|
|
|
|
self.clients = []
|
|
self.clients = []
|
|
|
|
|
+ self.client_threads = []
|
|
|
self.protocols = protocols
|
|
self.protocols = protocols
|
|
|
|
|
|
|
|
|
|
+ self.max_join_time = max_join_time
|
|
|
|
|
+
|
|
|
def run(self):
|
|
def run(self):
|
|
|
while True:
|
|
while True:
|
|
|
try:
|
|
try:
|
|
@@ -80,6 +87,7 @@ class Server(object):
|
|
|
thread = Thread(target=client.receive_forever)
|
|
thread = Thread(target=client.receive_forever)
|
|
|
thread.daemon = True
|
|
thread.daemon = True
|
|
|
thread.start()
|
|
thread.start()
|
|
|
|
|
+ self.client_threads.append(thread)
|
|
|
except SSLError as e:
|
|
except SSLError as e:
|
|
|
logging.error('SSL error: %s', e)
|
|
logging.error('SSL error: %s', e)
|
|
|
except HandshakeError as e:
|
|
except HandshakeError as e:
|
|
@@ -94,7 +102,13 @@ class Server(object):
|
|
|
|
|
|
|
|
def quit_gracefully(self):
|
|
def quit_gracefully(self):
|
|
|
for client in self.clients:
|
|
for client in self.clients:
|
|
|
- client.close(CLOSE_NORMAL)
|
|
|
|
|
|
|
+ client.send_close_frame()
|
|
|
|
|
+
|
|
|
|
|
+ start_time = time.time()
|
|
|
|
|
+
|
|
|
|
|
+ while time.time() - start_time <= self.max_join_time \
|
|
|
|
|
+ and any(t.is_alive() for t in self.client_threads):
|
|
|
|
|
+ time.sleep(0.050)
|
|
|
|
|
|
|
|
def remove_client(self, client, code, reason):
|
|
def remove_client(self, client, code, reason):
|
|
|
self.clients.remove(client)
|
|
self.clients.remove(client)
|