Commit 7b56e728 authored by Taddeüs Kroes's avatar Taddeüs Kroes

Server now closes all client connections after keyboard interrupt

parent 9f7b6c86
import socket import socket
import logging import logging
from traceback import format_exc from traceback import format_exc
from threading import Thread
from websocket import WebSocket from websocket import WebSocket
from exceptions import InvalidRequest from exceptions import InvalidRequest
from frame import CLOSE_NORMAL
class Server(object): class Server(object):
...@@ -26,10 +28,14 @@ class Server(object): ...@@ -26,10 +28,14 @@ class Server(object):
try: try:
sock, address = self.sock.accept() sock, address = self.sock.accept()
client = Client(self, sock, address) client = Client(self, sock, address)
client.handshake()
client.server_handshake()
self.clients.append(client) self.clients.append(client)
logging.info('Registered client %s', client) logging.info('Registered client %s', client)
client.run_threaded()
thread = Thread(target=client.receive_forever)
thread.daemon = True
thread.start()
except InvalidRequest as e: except InvalidRequest as e:
logging.error('Invalid request: %s', e.message) logging.error('Invalid request: %s', e.message)
except KeyboardInterrupt: except KeyboardInterrupt:
...@@ -38,6 +44,12 @@ class Server(object): ...@@ -38,6 +44,12 @@ class Server(object):
except Exception as e: except Exception as e:
logging.error(format_exc(e)) logging.error(format_exc(e))
self.quit_gracefully()
def quit_gracefully(self):
for client in self.clients:
client.close(CLOSE_NORMAL)
def remove_client(self, client, code, reason): def remove_client(self, client, code, reason):
self.clients.remove(client) self.clients.remove(client)
self.onclose(client, code, reason) self.onclose(client, code, reason)
...@@ -71,6 +83,7 @@ class Client(WebSocket): ...@@ -71,6 +83,7 @@ class Client(WebSocket):
super(Client, self).__init__(sock) super(Client, self).__init__(sock)
self.server = server self.server = server
self.address = address self.address = address
self.send_lock = Lock()
def onopen(self): def onopen(self):
self.server.onopen(self) self.server.onopen(self)
......
import re import re
import struct import struct
from hashlib import sha1 from hashlib import sha1
from threading import Thread
from frame import ControlFrame, receive_fragments, receive_frame, \ from frame import ControlFrame, receive_fragments, receive_frame, \
OPCODE_CLOSE, OPCODE_PING, OPCODE_PONG OPCODE_CLOSE, OPCODE_PING, OPCODE_PONG
...@@ -65,10 +64,10 @@ class WebSocket(object): ...@@ -65,10 +64,10 @@ class WebSocket(object):
payload = ''.join([f.payload for f in frames]) payload = ''.join([f.payload for f in frames])
return create_message(frames[0].opcode, payload) return create_message(frames[0].opcode, payload)
def handshake(self): def server_handshake(self):
""" """
Execute a handshake with the other end point of the socket. If the HTTP Execute a handshake as the server end point of the socket. If the HTTP
request headers read from the socket are invalid, an InvalidRequest request headers sent by the client are invalid, an InvalidRequest
exception is raised. exception is raised.
""" """
raw_headers = self.sock.recv(512).decode('utf-8', 'ignore') raw_headers = self.sock.recv(512).decode('utf-8', 'ignore')
...@@ -141,15 +140,6 @@ class WebSocket(object): ...@@ -141,15 +140,6 @@ class WebSocket(object):
except Exception as e: except Exception as e:
self.onexception(e) self.onexception(e)
def run_threaded(self, daemon=True):
"""
Spawn a new thread that receives messages in an endless loop.
"""
thread = Thread(target=self.receive_forever)
thread.daemon = daemon
thread.start()
return thread
def send_close(self, code, reason): def send_close(self, code, reason):
""" """
Send a close control frame. Send a close control frame.
...@@ -193,7 +183,7 @@ class WebSocket(object): ...@@ -193,7 +183,7 @@ class WebSocket(object):
self.sock.close() self.sock.close()
if frame.opcode != OPCODE_CLOSE: if frame.opcode != OPCODE_CLOSE:
raise ValueError('Expected close frame, got %s instead' % frame) raise ValueError('expected close frame, got %s instead' % frame)
def onopen(self): def onopen(self):
""" """
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment