Преглед изворни кода

Revised server-side connection closing and added a timeout to waiting for CLOSE frames

Taddeus Kroes пре 12 година
родитељ
комит
bfbe3934e5
2 измењених фајлова са 28 додато и 4 уклоњено
  1. 12 2
      connection.py
  2. 16 2
      server.py

+ 12 - 2
connection.py

@@ -22,6 +22,7 @@ class Connection(object):
         """
         self.sock = sock
 
+        self.close_frame_sent = False
         self.close_frame_received = False
         self.ping_sent = False
         self.ping_payload = None
@@ -129,6 +130,14 @@ class Connection(object):
         self.ping_sent = True
         self.onping(payload)
 
+    def send_close_frame(self, code=None, reason=''):
+        """
+        Send a CLOSE control frame.
+        """
+        payload = '' if code is None else struct.pack('!H', code) + reason
+        self.sock.send(ControlFrame(OPCODE_CLOSE, payload))
+        self.close_frame_sent = True
+
     def close(self, code=None, reason=''):
         """
         Close the socket by sending a CLOSE frame and waiting for a response
@@ -138,8 +147,8 @@ class Connection(object):
         actually closed.
         """
         # Send CLOSE frame
-        payload = '' if code is None else struct.pack('!H', code) + reason
-        self.sock.send(ControlFrame(OPCODE_CLOSE, payload))
+        if not self.close_frame_sent:
+            self.send_close_frame(code, reason)
 
         # Receive CLOSE frame
         if not self.close_frame_received:
@@ -148,6 +157,7 @@ class Connection(object):
             if frame.opcode != OPCODE_CLOSE:
                 raise ValueError('expected CLOSE frame, got %s' % frame)
 
+            self.close_frame_received = True
             res_code, res_reason = frame.unpack_close()
 
             # FIXME: check if res_code == code and res_reason == reason?

+ 16 - 2
server.py

@@ -1,5 +1,6 @@
 import socket
 import logging
+import time
 from traceback import format_exc
 from threading import Thread
 from ssl import SSLError
@@ -31,7 +32,7 @@ class Server(object):
     """
 
     def __init__(self, port, hostname='', loglevel=logging.INFO, protocols=[],
-                 secure=False, **kwargs):
+                 secure=False, max_join_time=2.0, **kwargs):
         """
         Constructor for a simple websocket server.
 
@@ -48,6 +49,9 @@ class Server(object):
         this case, `keyfile` and `certfile` must be specified. Any additional
         keyword arguments are passed to websocket.enable_ssl (and thus to
         ssl.wrap_socket).
+
+        `max_join_time` is the maximum time (in seconds) to wait for client
+        responses after sending CLOSE frames, it defaults to 2 seconds.
         """
         logging.basicConfig(level=loglevel,
                 format='%(asctime)s: %(levelname)s: %(message)s',
@@ -66,8 +70,11 @@ class Server(object):
         self.sock.listen(5)
 
         self.clients = []
+        self.client_threads = []
         self.protocols = protocols
 
+        self.max_join_time = max_join_time
+
     def run(self):
         while True:
             try:
@@ -80,6 +87,7 @@ class Server(object):
                 thread = Thread(target=client.receive_forever)
                 thread.daemon = True
                 thread.start()
+                self.client_threads.append(thread)
             except SSLError as e:
                 logging.error('SSL error: %s', e)
             except HandshakeError as e:
@@ -94,7 +102,13 @@ class Server(object):
 
     def quit_gracefully(self):
         for client in self.clients:
-            client.close(CLOSE_NORMAL)
+            client.send_close_frame()
+
+        start_time = time.time()
+
+        while time.time() - start_time <= self.max_join_time \
+                and any(t.is_alive() for t in self.client_threads):
+            time.sleep(0.050)
 
     def remove_client(self, client, code, reason):
         self.clients.remove(client)