| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200 |
- import socket
- import logging
- import time
- from traceback import format_exc
- from threading import Thread
- from ssl import SSLError
- from websocket import websocket
- from connection import Connection
- from errors import HandshakeError
class Server(object):
    """
    Websocket server, manages multiple client connections.

    Example usage:
    >>> import wspy

    >>> class EchoServer(wspy.Server):
    >>>     def onopen(self, client):
    >>>         print 'Client %s connected' % client

    >>>     def onmessage(self, client, message):
    >>>         print 'Received message "%s"' % message.payload
    >>>         client.send(wspy.TextMessage(message.payload))

    >>>     def onclose(self, client, code, reason):
    >>>         print 'Client %s disconnected' % client

    >>> EchoServer(('', 8000)).run()
    """

    def __init__(self, address, loglevel=logging.INFO, ssl_args=None,
                 max_join_time=2.0, **kwargs):
        """
        Constructor for a simple web socket server.

        `address` is a (hostname, port) tuple to bind the web socket to.

        `loglevel` values should be imported from the logging module.
        logging.INFO only shows server start/stop messages, logging.DEBUG shows
        clients (dis)connecting and messages being sent/received.

        `protocols` and `extensions` are passed directly to the websocket
        constructor (as is any other keyword argument in `kwargs`).

        `ssl_args` is a dictionary with arguments for `websocket.enable_ssl`
        (and thus to ssl.wrap_socket). If omitted, the server is not
        SSL-enabled. If specified, at least the dictionary keys "keyfile" and
        "certfile" must be present because these are required arguments for
        `websocket.enable_ssl` for a server socket.

        `max_join_time` is the maximum time (in seconds) to wait for client
        responses after sending CLOSE frames, it defaults to 2 seconds.
        """
        logging.basicConfig(level=loglevel,
                            format='%(asctime)s: %(levelname)s: %(message)s',
                            datefmt='%H:%M:%S')

        scheme = 'wss' if ssl_args else 'ws'
        hostname, port = address
        logging.info('Starting server at %s://%s:%d', scheme, hostname, port)

        self.sock = websocket(**kwargs)
        self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)

        if ssl_args:
            self.sock.enable_ssl(server_side=True, **ssl_args)

        self.sock.bind(address)
        self.sock.listen(5)

        # Currently registered clients and the threads running their
        # receive loops
        self.clients = []
        self.client_threads = []
        self.max_join_time = max_join_time

    def run(self):
        """
        Accept clients until interrupted, then shut down gracefully.

        Each accepted socket is wrapped in a `Client` whose
        `receive_forever` loop runs in its own daemon thread. Errors raised
        while accepting a client are logged and the loop continues; a
        KeyboardInterrupt ends the loop and triggers `quit_gracefully`.
        """
        while True:
            try:
                sock, address = self.sock.accept()

                client = Client(self, sock)
                self.clients.append(client)
                logging.debug('Registered client %s', client)

                thread = Thread(target=client.receive_forever)
                thread.daemon = True
                thread.start()
                self.client_threads.append(thread)
            except SSLError as e:
                logging.error('SSL error: %s', e)
            except HandshakeError as e:
                # `message` is not available on every exception type/Python
                # version, fall back to the exception itself
                logging.error('Invalid request: %s', getattr(e, 'message', e))
            except KeyboardInterrupt:
                logging.info('Received interrupt, stopping server...')
                break
            except Exception:
                # format_exc() formats the exception currently being handled,
                # its only positional argument is a traceback depth limit
                logging.error(format_exc())

        self.quit_gracefully()

    def quit_gracefully(self):
        """
        Close all client connections and wait for their threads to finish.

        Sends a CLOSE frame to every client, waits up to `max_join_time`
        seconds for the client threads to end, closes the sockets that are
        still registered and finally joins all client threads.
        """
        # Send a CLOSE frame so that the client connection will receive a
        # response CLOSE frame. Iterate over a snapshot because client threads
        # remove entries from self.clients through remove_client()
        for client in list(self.clients):
            client.send_close_frame()

        # Wait for the CLOSE frames to be received. Wait for all threads in one
        # loop instead of joining separately, so that timeouts are not
        # propagated
        start_time = time.time()

        while time.time() - start_time <= self.max_join_time \
                and any(t.is_alive() for t in self.client_threads):
            time.sleep(0.050)

        # Close remaining sockets, this will trigger a socket.error in the
        # receive_forever() thread, causing the Connection.onclose() handler to
        # be invoked. That handler shrinks self.clients, so iterate over a
        # snapshot to avoid skipping clients
        for client in list(self.clients):
            try:
                client.sock.close()
            except socket.error:
                pass

        # Wait for the onclose() handlers to finish
        for thread in self.client_threads:
            thread.join()

    def remove_client(self, client, code, reason):
        """
        Unregister `client` and invoke the `onclose` handler for it.

        Raises ValueError (from list.remove) if `client` is not registered.
        """
        self.clients.remove(client)
        self.onclose(client, code, reason)

    def onopen(self, client):
        """Log that a socket to `client` was opened."""
        logging.debug('Opened socket to %s', client)

    def onmessage(self, client, message):
        """Log a `message` received from `client`."""
        logging.debug('Received %s from %s', message, client)

    def onping(self, client, payload):
        """Log a ping with `payload` sent to `client`."""
        logging.debug('Sent ping "%s" to %s', payload, client)

    def onpong(self, client, payload):
        """Log a pong with `payload` received from `client`."""
        logging.debug('Received pong "%s" from %s', payload, client)

    def onclose(self, client, code, reason):
        """
        Log that the socket to `client` was closed.

        `code` is appended to the log message unless it is None, `reason` is
        appended unless it is empty or None.
        """
        msg = 'Closed socket to %s' % client

        if code is not None:
            msg += ' [%d]' % code

        # Truthiness check instead of len() so that a None reason is accepted
        if reason:
            msg += ': ' + reason

        logging.debug(msg)

    def onerror(self, client, e):
        """
        Log the traceback of the exception that is currently being handled.

        `e` is the exception instance; the traceback itself is taken from the
        active exception context.
        """
        logging.error(format_exc())
class Client(Connection):
    """
    Connection to a single client of a `Server`.

    Every event handler forwards to the handler with the same name on the
    owning server, passing this client as the first argument.
    """

    def __init__(self, server, sock):
        """
        `server` is the Server instance that handlers are forwarded to.

        `sock` is the client socket, it is passed on to the Connection
        constructor.
        """
        # NOTE(review): assigned before the base constructor runs, presumably
        # so that handlers triggered during construction can reach the server
        # -- confirm against Connection.__init__
        self.server = server
        super(Client, self).__init__(sock)

    def __str__(self):
        """
        Return '<Client at host:port>', or a placeholder if the peer address
        can no longer be retrieved from the socket.
        """
        try:
            # IPv6 peer names are 4-tuples (host, port, flowinfo, scopeid),
            # only the first two items fit the format string
            host, port = self.sock.getpeername()[:2]
        except socket.error:
            return '<Client on closed socket>'

        return '<Client at %s:%d>' % (host, port)

    def send(self, message, fragment_size=None, mask=False):
        """
        Log `message` and send it using Connection.send, forwarding
        `fragment_size` and `mask` unchanged.
        """
        logging.debug('Sending %s to %s', message, self)
        Connection.send(self, message, fragment_size=fragment_size, mask=mask)

    def onopen(self):
        """Forward to the server's onopen handler."""
        self.server.onopen(self)

    def onmessage(self, message):
        """Forward `message` to the server's onmessage handler."""
        self.server.onmessage(self, message)

    def onping(self, payload):
        """Forward `payload` to the server's onping handler."""
        self.server.onping(self, payload)

    def onpong(self, payload):
        """Forward `payload` to the server's onpong handler."""
        self.server.onpong(self, payload)

    def onclose(self, code, reason):
        """
        Unregister this client at the server, which in turn invokes the
        server's onclose handler with `code` and `reason`.
        """
        self.server.remove_client(self, code, reason)

    def onerror(self, e):
        """Forward the exception `e` to the server's onerror handler."""
        self.server.onerror(self, e)
if __name__ == '__main__':
    # Standalone mode: run a server with debug logging on all interfaces.
    # The optional first command-line argument selects the port, which
    # defaults to 8000.
    import sys

    if len(sys.argv) > 1:
        listen_port = int(sys.argv[1])
    else:
        listen_port = 8000

    debug_server = Server(('', listen_port), loglevel=logging.DEBUG)
    debug_server.run()
|