async.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
import logging
import socket
from select import epoll, EPOLLIN, EPOLLOUT, EPOLLHUP, EPOLLERR
from traceback import format_exc

from connection import Connection
from frame import (ControlFrame, OPCODE_PING, OPCODE_CONTINUATION,
                   create_close_frame)
from server import Server, Client
from errors import HandshakeError, SocketClosed
  10. class AsyncConnection(Connection):
  11. def __init__(self, sock):
  12. sock.recv_callback = self.contruct_message
  13. sock.recv_close_callback = self.onclose
  14. self.recvbuf = []
  15. Connection.__init__(self, sock)
  16. def contruct_message(self, frame):
  17. if isinstance(frame, ControlFrame):
  18. self.handle_control_frame(frame)
  19. return
  20. self.recvbuf.append(frame)
  21. if frame.final:
  22. message = self.concat_fragments(self.recvbuf)
  23. self.recvbuf = []
  24. self.onmessage(message)
  25. elif len(self.recvbuf) > 1 and frame.opcode != OPCODE_CONTINUATION:
  26. raise ValueError('expected continuation/control frame, got %s '
  27. 'instead' % frame)
  28. def send(self, message, fragment_size=None, mask=False):
  29. frames = list(self.message_to_frames(message, fragment_size, mask))
  30. for frame in frames[:-1]:
  31. self.sock.queue_send(frame)
  32. self.sock.queue_send(frames[-1], lambda: self.onsend(message))
  33. def send_frame(self, frame, callback):
  34. self.sock.queue_send(frame, callback)
  35. def do_async_send(self):
  36. self.execute_controlled(self.sock.do_async_send)
  37. def do_async_recv(self, bufsize):
  38. self.execute_controlled(self.sock.do_async_recv, bufsize)
  39. def execute_controlled(self, func, *args, **kwargs):
  40. try:
  41. func(*args, **kwargs)
  42. except (KeyboardInterrupt, SystemExit, SocketClosed):
  43. raise
  44. except Exception as e:
  45. self.onerror(e)
  46. self.onclose(None, 'error: %s' % e)
  47. try:
  48. self.sock.close()
  49. except socket.error:
  50. pass
  51. raise e
  52. def send_close_frame(self, code, reason):
  53. self.sock.queue_send(create_close_frame(code, reason),
  54. self.shutdown_write)
  55. self.close_frame_sent = True
  56. def close(self, code=None, reason=''):
  57. self.send_close_frame(code, reason)
  58. def send_ping(self, payload=''):
  59. self.sock.queue_send(ControlFrame(OPCODE_PING, payload),
  60. lambda: self.onping(payload))
  61. self.ping_payload = payload
  62. self.ping_sent = True
  63. def onsend(self, message):
  64. """
  65. Called after a message has been written.
  66. """
  67. return NotImplemented
  68. class AsyncServer(Server):
  69. def __init__(self, *args, **kwargs):
  70. Server.__init__(self, *args, **kwargs)
  71. self.recvbuf_size = kwargs.get('recvbuf_size', 2048)
  72. self.epoll = epoll()
  73. self.epoll.register(self.sock.fileno(), EPOLLIN)
  74. self.conns = {}
  75. @property
  76. def clients(self):
  77. return self.conns.values()
  78. def remove_client(self, client, code, reason):
  79. self.epoll.unregister(client.fno)
  80. del self.conns[client.fno]
  81. self.onclose(client, code, reason)
  82. def handle_events(self):
  83. for fileno, event in self.epoll.poll(1):
  84. if fileno == self.sock.fileno():
  85. try:
  86. sock, addr = self.sock.accept()
  87. except HandshakeError as e:
  88. logging.error('Invalid request: %s', e.message)
  89. continue
  90. client = AsyncClient(self, sock)
  91. client.fno = sock.fileno()
  92. sock.setblocking(0)
  93. self.epoll.register(client.fno, EPOLLIN)
  94. self.conns[client.fno] = client
  95. logging.debug('Registered client %s', client)
  96. elif event & EPOLLHUP:
  97. self.epoll.unregister(fileno)
  98. del self.conns[fileno]
  99. else:
  100. conn = self.conns[fileno]
  101. try:
  102. if event & EPOLLOUT:
  103. conn.do_async_send()
  104. elif event & EPOLLIN:
  105. conn.do_async_recv(self.recvbuf_size)
  106. except (KeyboardInterrupt, SystemExit):
  107. raise
  108. except SocketClosed:
  109. continue
  110. except Exception as e:
  111. logging.error(format_exc(e).rstrip())
  112. continue
  113. self.update_mask(conn)
  114. def run(self):
  115. try:
  116. while True:
  117. self.handle_events()
  118. except (KeyboardInterrupt, SystemExit):
  119. logging.info('Received interrupt, stopping server...')
  120. finally:
  121. self.epoll.unregister(self.sock.fileno())
  122. self.epoll.close()
  123. self.sock.close()
  124. def update_mask(self, conn):
  125. mask = 0
  126. if conn.sock.can_send():
  127. mask |= EPOLLOUT
  128. if conn.sock.can_recv():
  129. mask |= EPOLLIN
  130. self.epoll.modify(conn.sock.fileno(), mask)
  131. def onsend(self, client, message):
  132. return NotImplemented
  133. class AsyncClient(Client, AsyncConnection):
  134. def __init__(self, server, sock):
  135. self.server = server
  136. AsyncConnection.__init__(self, sock)
  137. def send(self, message, fragment_size=None, mask=False):
  138. logging.debug('Enqueueing %s to %s', message, self)
  139. AsyncConnection.send(self, message, fragment_size, mask)
  140. self.server.update_mask(self)
  141. def onsend(self, message):
  142. logging.debug('Finished sending %s to %s', message, self)
  143. self.server.onsend(self, message)
  144. if __name__ == '__main__':
  145. import sys
  146. port = int(sys.argv[1]) if len(sys.argv) > 1 else 8000
  147. AsyncServer(('', port), loglevel=logging.DEBUG).run()