websocket.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212
  1. import re
  2. import socket
  3. import ssl
  4. from hashlib import sha1
  5. from base64 import b64encode
  6. from frame import receive_frame
  7. from errors import HandshakeError, SSLError
# Fixed GUID from RFC 6455. server_handshake() appends it to the client's
# Sec-WebSocket-Key, SHA-1 hashes the result and base64-encodes the digest to
# build the Sec-WebSocket-Accept response header.
WS_GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
# The only Sec-WebSocket-Version value accepted by server_handshake();
# any other requested version raises a HandshakeError.
WS_VERSION = '13'
  10. def split_stripped(value, delim=','):
  11. return map(str.strip, str(value).split(delim))
  12. class websocket(object):
  13. """
  14. Implementation of web socket, upgrades a regular TCP socket to a websocket
  15. using the HTTP handshakes and frame (un)packing, as specified by RFC 6455.
  16. The API of a websocket is identical to that of a regular socket, as
  17. illustrated by the examples below.
  18. Server example:
  19. >>> import twspy, socket
  20. >>> sock = twspy.websocket()
  21. >>> sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
  22. >>> sock.bind(('', 8000))
  23. >>> sock.listen()
  24. >>> client = sock.accept()
  25. >>> client.send(twspy.Frame(twspy.OPCODE_TEXT, 'Hello, Client!'))
  26. >>> frame = client.recv()
  27. Client example:
  28. >>> import twspy
  29. >>> sock = twspy.websocket()
  30. >>> sock.connect(('', 8000))
  31. >>> sock.send(twspy.Frame(twspy.OPCODE_TEXT, 'Hello, Server!'))
  32. """
  33. def __init__(self, sock=None, protocols=[], extensions=[], sfamily=socket.AF_INET,
  34. sproto=0):
  35. """
  36. Create a regular TCP socket of family `family` and protocol
  37. `sock` is an optional regular TCP socket to be used for sending binary
  38. data. If not specified, a new socket is created.
  39. `protocols` is a list of supported protocol names.
  40. `extensions` is a list of supported extensions.
  41. `sfamily` and `sproto` are used for the regular socket constructor.
  42. """
  43. self.protocols = protocols
  44. self.extensions = extensions
  45. self.sock = sock or socket.socket(sfamily, socket.SOCK_STREAM, sproto)
  46. self.secure = False
  47. self.handshake_started = False
  48. def bind(self, address):
  49. self.sock.bind(address)
  50. def listen(self, backlog):
  51. self.sock.listen(backlog)
  52. def accept(self):
  53. """
  54. Equivalent to socket.accept(), but transforms the socket into a
  55. websocket instance and sends a server handshake (after receiving a
  56. client handshake). Note that the handshake may raise a HandshakeError
  57. exception.
  58. """
  59. sock, address = self.sock.accept()
  60. wsock = websocket(sock)
  61. wsock.server_handshake()
  62. return wsock, address
  63. def connect(self, address):
  64. """
  65. Equivalent to socket.connect(), but sends an client handshake request
  66. after connecting.
  67. """
  68. self.sock.sonnect(address)
  69. self.client_handshake()
  70. def send(self, *args):
  71. """
  72. Send a number of frames.
  73. """
  74. for frame in args:
  75. #print 'send frame:', frame, 'to %s:%d' % self.sock.getpeername()
  76. self.sock.sendall(frame.pack())
  77. def recv(self):
  78. """
  79. Receive a single frames. This can be either a data frame or a control
  80. frame.
  81. """
  82. frame = receive_frame(self.sock)
  83. #print 'receive frame:', frame, 'from %s:%d' % self.sock.getpeername()
  84. return frame
  85. def recvn(self, n):
  86. """
  87. Receive exactly `n` frames. These can be either data frames or control
  88. frames, or a combination of both.
  89. """
  90. return [self.recv() for i in xrange(n)]
  91. def getpeername(self):
  92. return self.sock.getpeername()
  93. def getsockname(self):
  94. return self.sock.getsockname()
  95. def setsockopt(self, level, optname, value):
  96. self.sock.setsockopt(level, optname, value)
  97. def getsockopt(self, level, optname):
  98. return self.sock.getsockopt(level, optname)
  99. def close(self):
  100. self.sock.close()
  101. def server_handshake(self):
  102. """
  103. Execute a handshake as the server end point of the socket. If the HTTP
  104. request headers sent by the client are invalid, a HandshakeError
  105. is raised.
  106. """
  107. # Receive HTTP header
  108. raw_headers = ''
  109. while raw_headers[-4:] not in ('\r\n\r\n', '\n\n'):
  110. raw_headers += self.sock.recv(512).decode('utf-8', 'ignore')
  111. # Request must be HTTP (at least 1.1) GET request, find the location
  112. location = re.search(r'^GET (.*) HTTP/1.1\r\n', raw_headers).group(1)
  113. headers = re.findall(r'(.*?): ?(.*?)\r\n', raw_headers)
  114. header_names = [name for name, value in headers]
  115. def header(name):
  116. return ', '.join([v for n, v in headers if n == name])
  117. # Check if headers that MUST be present are actually present
  118. for name in ('Host', 'Upgrade', 'Connection', 'Sec-WebSocket-Key',
  119. 'Origin', 'Sec-WebSocket-Version'):
  120. if name not in header_names:
  121. raise HandshakeError('missing "%s" header' % name)
  122. # Check WebSocket version used by client
  123. version = header('Sec-WebSocket-Version')
  124. if version != WS_VERSION:
  125. raise HandshakeError('WebSocket version %s requested (only %s '
  126. 'is supported)' % (version, WS_VERSION))
  127. # Only supported protocols are returned
  128. proto = header('Sec-WebSocket-Extensions')
  129. protocols = split_stripped(proto) if proto else []
  130. protocols = [p for p in protocols if p in self.protocols]
  131. # Only supported extensions are returned
  132. ext = header('Sec-WebSocket-Extensions')
  133. extensions = split_stripped(ext) if ext else []
  134. extensions = [e for e in extensions if e in self.extensions]
  135. # Encode acceptation key using the WebSocket GUID
  136. key = header('Sec-WebSocket-Key').strip()
  137. accept = b64encode(sha1(key + WS_GUID).digest())
  138. # Construct HTTP response header
  139. shake = 'HTTP/1.1 101 Web Socket Protocol Handshake\r\n'
  140. shake += 'Upgrade: WebSocket\r\n'
  141. shake += 'Connection: Upgrade\r\n'
  142. shake += 'WebSocket-Origin: %s\r\n' % header('Origin')
  143. shake += 'WebSocket-Location: ws://%s%s\r\n' \
  144. % (header('Host'), location)
  145. shake += 'Sec-WebSocket-Accept: %s\r\n' % accept
  146. if protocols:
  147. shake += 'Sec-WebSocket-Protocol: %s\r\n' % ', '.join(protocols)
  148. if extensions:
  149. shake += 'Sec-WebSocket-Extensions: %s\r\n' % ', '.join(extensions)
  150. self.sock.sendall(shake + '\r\n')
  151. self.handshake_started = True
  152. def client_handshake(self):
  153. """
  154. Execute a handshake as the client end point of the socket. May raise a
  155. HandshakeError if the server response is invalid.
  156. """
  157. # TODO: implement HTTP request headers for client handshake
  158. self.handshake_started = True
  159. raise NotImplementedError
  160. def enable_ssl(self, *args, **kwargs):
  161. """
  162. Transform the regular socket.socket to an ssl.SSLSocket for secure
  163. connections. Any arguments are passed to ssl.wrap_socket:
  164. http://docs.python.org/dev/library/ssl.html#ssl.wrap_socket
  165. """
  166. if self.handshake_started:
  167. raise SSLError('can only enable SSL before handshake')
  168. self.secure = True
  169. self.sock = ssl.wrap_socket(self.sock, *args, **kwargs)