| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190 |
- import socket
- from select import epoll, EPOLLIN, EPOLLOUT, EPOLLHUP
- from traceback import format_exc
- import logging
- from connection import Connection
- from frame import ControlFrame, OPCODE_PING, OPCODE_CONTINUATION, \
- create_close_frame
- from server import Server, Client
- from errors import HandshakeError, SocketClosed
class AsyncConnection(Connection):
    """
    Connection whose traffic goes through the socket's send queue and
    receive callbacks instead of direct blocking calls.

    Received frames are delivered one at a time to `contruct_message`, and
    outgoing frames are handed to `sock.queue_send`.
    """

    def __init__(self, sock):
        """
        Install the receive callbacks on `sock`, start with an empty fragment
        buffer and then run the regular Connection initialisation.
        """
        sock.recv_callback = self.contruct_message
        sock.recv_close_callback = self.onclose
        # Data frames of the message that is currently being received.
        self.recvbuf = []
        Connection.__init__(self, sock)

    def contruct_message(self, frame):
        """
        Process a single received frame.

        Control frames are passed to `handle_control_frame` immediately. Other
        frames are buffered; once a final frame arrives, the buffered
        fragments are concatenated and handed to `onmessage`.

        Raises ValueError when a non-final frame that is not the first buffered
        fragment has an opcode other than OPCODE_CONTINUATION.
        """
        if isinstance(frame, ControlFrame):
            self.handle_control_frame(frame)
            return

        self.recvbuf.append(frame)

        if not frame.final:
            is_first_fragment = len(self.recvbuf) <= 1
            if not is_first_fragment and frame.opcode != OPCODE_CONTINUATION:
                raise ValueError('expected continuation/control frame, got %s '
                                 'instead' % frame)
            return

        # Final fragment: assemble the message, then reset the buffer before
        # notifying the handler.
        message = self.concat_fragments(self.recvbuf)
        self.recvbuf = []
        self.onmessage(message)

    def send(self, message, fragment_size=None, mask=False):
        """
        Split `message` into frames and queue all of them for sending.

        `onsent(message)` is attached as callback to the last frame only.
        """
        fragments = list(self.message_to_frames(message, fragment_size, mask))
        last_fragment = fragments[-1]

        def notify_sent():
            self.onsent(message)

        for fragment in fragments[:-1]:
            self.sock.queue_send(fragment)

        self.sock.queue_send(last_fragment, notify_sent)

    def send_frame(self, frame, callback):
        """
        Queue a single frame for sending, passing `callback` along to the
        socket's send queue.
        """
        self.sock.queue_send(frame, callback)

    def do_async_send(self):
        """
        Run the socket's `do_async_send` under `execute_controlled`.
        """
        self.execute_controlled(self.sock.do_async_send)

    def do_async_recv(self, bufsize):
        """
        Run the socket's `do_async_recv(bufsize)` under `execute_controlled`.
        """
        self.execute_controlled(self.sock.do_async_recv, bufsize)

    def execute_controlled(self, func, *args, **kwargs):
        """
        Call `func(*args, **kwargs)` with connection-level error handling.

        KeyboardInterrupt, SystemExit and SocketClosed propagate untouched.
        Any other exception is reported through `onerror` and `onclose`, the
        socket is closed (a socket.error while closing is ignored) and the
        exception is raised again.
        """
        try:
            func(*args, **kwargs)
        except (KeyboardInterrupt, SystemExit, SocketClosed):
            raise
        except Exception as failure:
            self.onerror(failure)
            self.onclose(None, 'error: %s' % failure)

            try:
                self.sock.close()
            except socket.error:
                pass

            # Re-raise by name: the nested handler above may have replaced
            # the "current" exception.
            raise failure

    def send_close_frame(self, code, reason):
        """
        Queue a close frame with `shutdown_write` as its callback, and record
        that a close frame has been sent.
        """
        close_frame = create_close_frame(code, reason)
        self.sock.queue_send(close_frame, self.shutdown_write)
        self.close_frame_sent = True

    def close(self, code=None, reason=''):
        """
        Start closing the connection by queueing a close frame.
        """
        self.send_close_frame(code, reason)

    def send_ping(self, payload=''):
        """
        Queue a PING control frame carrying `payload`, with `onping(payload)`
        as its callback, and remember the payload of the pending ping.
        """
        def notify_ping():
            self.onping(payload)

        ping_frame = ControlFrame(OPCODE_PING, payload)
        self.sock.queue_send(ping_frame, notify_ping)
        self.ping_payload = payload
        self.ping_sent = True

    def onsent(self, message):
        """
        Called after a message has been written.
        """
        return NotImplemented
class AsyncServer(Server):
    """
    Server that handles the listening socket and all client connections from
    a single epoll-based event loop.
    """

    def __init__(self, *args, **kwargs):
        """
        Initialise the base Server with all given arguments and set up epoll.

        The optional keyword argument `recvbuf_size` (default 2048) is the
        buffer size passed to each connection's `do_async_recv`.
        """
        # NOTE(review): `recvbuf_size`, when given, is also forwarded to
        # Server.__init__ -- confirm that the base class accepts it.
        Server.__init__(self, *args, **kwargs)

        self.recvbuf_size = kwargs.get('recvbuf_size', 2048)

        self.epoll = epoll()
        self.epoll.register(self.sock.fileno(), EPOLLIN)

        # Registered clients, keyed by the file number of their socket.
        self.conns = {}

    @property
    def clients(self):
        """
        The currently registered client objects.
        """
        return self.conns.values()

    def remove_client(self, client, code, reason):
        """
        Unregister `client` from epoll, forget it and report the close through
        `onclose(client, code, reason)`.
        """
        self.epoll.unregister(client.fno)
        del self.conns[client.fno]
        self.onclose(client, code, reason)

    def handle_events(self):
        """
        Poll epoll once (timeout of 1) and handle every reported event.

        An event on the listening socket accepts and registers a new client.
        A hangup unregisters the client. Any other event triggers an
        asynchronous send or receive on the corresponding connection, after
        which the connection's event mask is refreshed.
        """
        for fileno, event in self.epoll.poll(1):
            if fileno == self.sock.fileno():
                try:
                    sock, _addr = self.sock.accept()
                except HandshakeError as e:
                    # `message` is not available on every exception (it is
                    # gone in Python 3), so fall back to the exception itself.
                    logging.error('Invalid request: %s',
                                  getattr(e, 'message', e))
                    continue

                client = AsyncClient(self, sock)
                client.fno = sock.fileno()
                sock.setblocking(0)
                self.epoll.register(client.fno, EPOLLIN)
                self.conns[client.fno] = client
                logging.debug('Registered client %s', client)
            elif event & EPOLLHUP:
                # NOTE(review): unlike remove_client(), this path does not
                # call onclose() -- confirm that this is intended.
                self.epoll.unregister(fileno)
                del self.conns[fileno]
            else:
                conn = self.conns[fileno]

                try:
                    if event & EPOLLOUT:
                        conn.do_async_send()
                    elif event & EPOLLIN:
                        conn.do_async_recv(self.recvbuf_size)
                except (KeyboardInterrupt, SystemExit):
                    raise
                except SocketClosed:
                    continue
                except Exception:
                    # format_exc() formats the exception currently being
                    # handled; its first parameter is a traceback depth limit,
                    # not the exception object.
                    logging.error(format_exc().rstrip())
                    continue

                self.update_mask(conn)

    def run(self):
        """
        Handle events until interrupted, then unregister and close the
        listening socket and close the epoll object.
        """
        try:
            while True:
                self.handle_events()
        except (KeyboardInterrupt, SystemExit):
            logging.info('Received interrupt, stopping server...')
        finally:
            self.epoll.unregister(self.sock.fileno())
            self.epoll.close()
            self.sock.close()

    def update_mask(self, conn):
        """
        Set the epoll event mask of `conn` to match what its socket reports
        through `can_send()` (EPOLLOUT) and `can_recv()` (EPOLLIN).
        """
        mask = 0

        if conn.sock.can_send():
            mask |= EPOLLOUT

        if conn.sock.can_recv():
            mask |= EPOLLIN

        self.epoll.modify(conn.sock.fileno(), mask)

    def onsent(self, client, message):
        """
        Called by a client after `message` has been sent to it.
        """
        return NotImplemented
class AsyncClient(Client, AsyncConnection):
    """
    Client of an AsyncServer: an AsyncConnection that keeps its server
    informed about queued and completed sends.
    """

    def __init__(self, server, sock):
        """
        Remember the owning `server` and initialise the asynchronous
        connection on `sock`.
        """
        self.server = server
        AsyncConnection.__init__(self, sock)

    def send(self, message, fragment_size=None, mask=False):
        """
        Queue `message` for sending, then let the server refresh this
        client's epoll event mask.
        """
        logging.debug('Enqueueing %s to %s', message, self)
        AsyncConnection.send(self, message, fragment_size=fragment_size,
                             mask=mask)
        owner = self.server
        owner.update_mask(self)

    def onsent(self, message):
        """
        Log that `message` has been sent and forward the event to the
        server's `onsent` handler.
        """
        logging.debug('Finished sending %s to %s', message, self)
        owner = self.server
        owner.onsent(self, message)
if __name__ == '__main__':
    import sys

    # The optional first command-line argument selects the port; without it
    # the server listens on port 8000.
    cli_args = sys.argv[1:]

    if cli_args:
        port = int(cli_args[0])
    else:
        port = 8000

    server = AsyncServer(('', port), loglevel=logging.DEBUG)
    server.run()
|