websocket.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365
  1. import os
  2. import re
  3. import socket
  4. import ssl
  5. from hashlib import sha1
  6. from base64 import b64encode
  7. from urlparse import urlparse
  8. from frame import receive_frame
  9. from errors import HandshakeError, SSLError
# Fixed GUID that is appended to the client's Sec-WebSocket-Key before the
# result is SHA-1 hashed and base64-encoded into the Sec-WebSocket-Accept value.
WS_GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
# The only protocol version accepted from clients and advertised to servers in
# the Sec-WebSocket-Version header (kept as a string for direct comparison).
WS_VERSION = '13'
  12. def split_stripped(value, delim=','):
  13. return map(str.strip, str(value).split(delim)) if value else []
  14. class websocket(object):
  15. """
  16. Implementation of web socket, upgrades a regular TCP socket to a websocket
  17. using the HTTP handshakes and frame (un)packing, as specified by RFC 6455.
  18. The API of a websocket is identical to that of a regular socket, as
  19. illustrated by the examples below.
  20. Server example:
  21. >>> import twspy, socket
  22. >>> sock = twspy.websocket()
  23. >>> sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
  24. >>> sock.bind(('', 8000))
  25. >>> sock.listen()
  26. >>> client = sock.accept()
  27. >>> client.send(twspy.Frame(twspy.OPCODE_TEXT, 'Hello, Client!'))
  28. >>> frame = client.recv()
  29. Client example:
  30. >>> import twspy
  31. >>> sock = twspy.websocket()
  32. >>> sock.connect(('', 8000))
  33. >>> sock.send(twspy.Frame(twspy.OPCODE_TEXT, 'Hello, Server!'))
  34. """
  35. def __init__(self, sock=None, protocols=[], extensions=[], sfamily=socket.AF_INET,
  36. sproto=0):
  37. """
  38. Create a regular TCP socket of family `family` and protocol
  39. `sock` is an optional regular TCP socket to be used for sending binary
  40. data. If not specified, a new socket is created.
  41. `protocols` is a list of supported protocol names.
  42. `extensions` is a list of supported extensions.
  43. `sfamily` and `sproto` are used for the regular socket constructor.
  44. """
  45. self.protocols = protocols
  46. self.extensions = extensions
  47. self.sock = sock or socket.socket(sfamily, socket.SOCK_STREAM, sproto)
  48. self.secure = False
  49. self.handshake_started = False
  50. def bind(self, address):
  51. self.sock.bind(address)
  52. def listen(self, backlog):
  53. self.sock.listen(backlog)
  54. def accept(self):
  55. """
  56. Equivalent to socket.accept(), but transforms the socket into a
  57. websocket instance and sends a server handshake (after receiving a
  58. client handshake). Note that the handshake may raise a HandshakeError
  59. exception.
  60. """
  61. sock, address = self.sock.accept()
  62. wsock = websocket(sock)
  63. wsock.server_handshake()
  64. return wsock, address
  65. def connect(self, address, path='/', auth=None):
  66. """
  67. Equivalent to socket.connect(), but sends an client handshake request
  68. after connecting.
  69. `address` is a (host, port) tuple of the server to connect to.
  70. `path` is optional, used as the *location* part of the HTTP handshake.
  71. In a URL, this would show as ws://host[:port]/path.
  72. `auth` is optional, used for the HTTP "Authorization" header of the
  73. handshake request.
  74. """
  75. self.sock.connect(address)
  76. self.client_handshake(address, path, auth)
  77. def send(self, *args):
  78. """
  79. Send a number of frames.
  80. """
  81. for frame in args:
  82. #print 'send frame:', frame, 'to %s:%d' % self.sock.getpeername()
  83. self.sock.sendall(frame.pack())
  84. def recv(self):
  85. """
  86. Receive a single frames. This can be either a data frame or a control
  87. frame.
  88. """
  89. frame = receive_frame(self.sock)
  90. #print 'receive frame:', frame, 'from %s:%d' % self.sock.getpeername()
  91. return frame
  92. def recvn(self, n):
  93. """
  94. Receive exactly `n` frames. These can be either data frames or control
  95. frames, or a combination of both.
  96. """
  97. return [self.recv() for i in xrange(n)]
  98. def getpeername(self):
  99. return self.sock.getpeername()
  100. def getsockname(self):
  101. return self.sock.getsockname()
  102. def setsockopt(self, level, optname, value):
  103. self.sock.setsockopt(level, optname, value)
  104. def getsockopt(self, level, optname):
  105. return self.sock.getsockopt(level, optname)
  106. def close(self):
  107. self.sock.close()
  108. def server_handshake(self):
  109. """
  110. Execute a handshake as the server end point of the socket. If the HTTP
  111. request headers sent by the client are invalid, a HandshakeError
  112. is raised.
  113. """
  114. def fail(msg):
  115. self.sock.close()
  116. raise HandshakeError(msg)
  117. # Receive HTTP header
  118. raw_headers = ''
  119. while raw_headers[-4:] not in ('\r\n\r\n', '\n\n'):
  120. raw_headers += self.sock.recv(512).decode('utf-8', 'ignore')
  121. # Request must be HTTP (at least 1.1) GET request, find the location
  122. match = re.search(r'^GET (.*) HTTP/1.1\r\n', raw_headers)
  123. if match is None:
  124. fail('not a valid HTTP 1.1 GET request')
  125. location = match.group(1)
  126. headers = re.findall(r'(.*?): ?(.*?)\r\n', raw_headers)
  127. header_names = [name for name, value in headers]
  128. def header(name):
  129. return ', '.join([v for n, v in headers if n == name])
  130. # Check if headers that MUST be present are actually present
  131. for name in ('Host', 'Upgrade', 'Connection', 'Sec-WebSocket-Key',
  132. 'Sec-WebSocket-Version'):
  133. if name not in header_names:
  134. fail('missing "%s" header' % name)
  135. # Check WebSocket version used by client
  136. version = header('Sec-WebSocket-Version')
  137. if version != WS_VERSION:
  138. fail('WebSocket version %s requested (only %s '
  139. 'is supported)' % (version, WS_VERSION))
  140. # Verify required header keywords
  141. if 'websocket' not in header('Upgrade').lower():
  142. fail('"Upgrade" header must contain "websocket"')
  143. if 'upgrade' not in header('Connection').lower():
  144. fail('"Connection" header must contain "Upgrade"')
  145. # Origin must be present if browser client
  146. if 'User-Agent' in header_names and 'Origin' not in header_names:
  147. fail('browser client must specify "Origin" header')
  148. # Only supported protocols are returned
  149. client_protocols = split_stripped(header('Sec-WebSocket-Extensions'))
  150. protocol = 'null'
  151. for p in client_protocols:
  152. if p in self.protocols:
  153. protocol = p
  154. break
  155. # Only supported extensions are returned
  156. extensions = split_stripped(header('Sec-WebSocket-Extensions'))
  157. extensions = [e for e in extensions if e in self.extensions]
  158. # Encode acceptation key using the WebSocket GUID
  159. key = header('Sec-WebSocket-Key').strip()
  160. accept = b64encode(sha1(key + WS_GUID).digest())
  161. # Construct HTTP response header
  162. shake = 'HTTP/1.1 101 Web Socket Protocol Handshake\r\n'
  163. shake += 'Upgrade: websocket\r\n'
  164. shake += 'Connection: Upgrade\r\n'
  165. shake += 'WebSocket-Origin: %s\r\n' % header('Origin')
  166. shake += 'WebSocket-Location: ws://%s%s\r\n' \
  167. % (header('Host'), location)
  168. shake += 'Sec-WebSocket-Accept: %s\r\n' % accept
  169. shake += 'Sec-WebSocket-Protocol: %s\r\n' % protocol
  170. shake += 'Sec-WebSocket-Extensions: %s\r\n' % ', '.join(extensions)
  171. self.sock.sendall(shake + '\r\n')
  172. self.handshake_started = True
  173. def client_handshake(self, address, location, auth):
  174. """
  175. Executes a handshake as the client end point of the socket. May raise a
  176. HandshakeError if the server response is invalid.
  177. """
  178. def fail(msg):
  179. self.sock.close()
  180. raise HandshakeError(msg)
  181. def send_request(location):
  182. if len(location) == 0:
  183. fail('request location is empty')
  184. # Generate a 16-byte random base64-encoded key for this connection
  185. key = b64encode(os.urandom(16))
  186. # Send client handshake
  187. shake = 'GET %s HTTP/1.1\r\n' % location
  188. shake += 'Host: %s:%d\r\n' % address
  189. shake += 'Upgrade: websocket\r\n'
  190. shake += 'Connection: keep-alive, Upgrade\r\n'
  191. shake += 'Sec-WebSocket-Key: %s\r\n' % key
  192. shake += 'Origin: null\r\n' # FIXME: is this correct/necessary?
  193. shake += 'Sec-WebSocket-Version: %s\r\n' % WS_VERSION
  194. # These are for eagerly caching webservers
  195. shake += 'Pragma: no-cache\r\n'
  196. shake += 'Cache-Control: no-cache\r\n'
  197. # Request protocols and extension, these are later checked with the
  198. # actual supported values from the server's response
  199. if self.protocols:
  200. shake += 'Sec-WebSocket-Protocol: %s\r\n' % ', '.join(self.protocols)
  201. if self.extensions:
  202. shake += 'Sec-WebSocket-Extensions: %s\r\n' % ', '.join(self.extensions)
  203. if auth:
  204. shake += 'Authorization: %s\r\n' % auth
  205. self.sock.sendall(shake + '\r\n')
  206. return key
  207. def receive_response(key):
  208. # Receive and process server handshake
  209. raw_headers = ''
  210. while raw_headers[-4:] not in ('\r\n\r\n', '\n\n'):
  211. raw_headers += self.sock.recv(512).decode('utf-8', 'ignore')
  212. # Response must be HTTP (at least 1.1) with status 101
  213. match = re.search(r'^HTTP/1\.1 (\d{3})', raw_headers)
  214. if match is None:
  215. fail('not a valid HTTP 1.1 response')
  216. status = int(match.group(1))
  217. headers = re.findall(r'(.*?): ?(.*?)\r\n', raw_headers)
  218. header_names = [name for name, value in headers]
  219. def header(name):
  220. return ', '.join([v for n, v in headers if n == name])
  221. if status == 401:
  222. # HTTP authentication is required in the request
  223. raise HandshakeError('HTTP authentication required: %s'
  224. % header('WWW-Authenticate'))
  225. if status in (301, 302, 303, 307, 308):
  226. # Handle HTTP redirect
  227. url = urlparse(header('Location').strip())
  228. # Reconnect socket if net location changed
  229. if not url.port:
  230. url.port = 443 if self.secure else 80
  231. addr = (url.netloc, url.port)
  232. if addr != self.sock.getpeername():
  233. self.sock.close()
  234. self.sock.connect(addr)
  235. # Send new handshake
  236. receive_response(send_request(url.path))
  237. return
  238. if status != 101:
  239. # 101 means server has accepted the connection and sent
  240. # handshake headers
  241. fail('invalid HTTP response status %d' % status)
  242. # Check if headers that MUST be present are actually present
  243. for name in ('Upgrade', 'Connection', 'Sec-WebSocket-Accept'):
  244. if name not in header_names:
  245. fail('missing "%s" header' % name)
  246. if 'websocket' not in header('Upgrade').lower():
  247. fail('"Upgrade" header must contain "websocket"')
  248. if 'upgrade' not in header('Connection').lower():
  249. fail('"Connection" header must contain "Upgrade"')
  250. # Verify accept header
  251. accept = header('Sec-WebSocket-Accept').strip()
  252. required_accept = b64encode(sha1(key + WS_GUID).digest())
  253. if accept != required_accept:
  254. fail('invalid websocket accept header "%s"' % accept)
  255. # Compare extensions
  256. server_ext = split_stripped(header('Sec-WebSocket-Extensions'))
  257. for e in server_ext:
  258. if e not in self.extensions:
  259. fail('server extension "%s" is unsupported by client' % e)
  260. for e in self.extensions:
  261. if e not in server_ext:
  262. fail('client extension "%s" is unsupported by server' % e)
  263. # Assert that returned protocol (if any) is supported
  264. protocol = header('Sec-WebSocket-Protocol')
  265. if protocol:
  266. if protocol != 'null' and protocol not in self.protocols:
  267. fail('unsupported protocol "%s"' % protocol)
  268. self.protocol = protocol
  269. self.handshake_started = True
  270. receive_response(send_request(location))
  271. def enable_ssl(self, *args, **kwargs):
  272. """
  273. Transforms the regular socket.socket to an ssl.SSLSocket for secure
  274. connections. Any arguments are passed to ssl.wrap_socket:
  275. http://docs.python.org/dev/library/ssl.html#ssl.wrap_socket
  276. """
  277. if self.handshake_started:
  278. raise SSLError('can only enable SSL before handshake')
  279. self.secure = True
  280. self.sock = ssl.wrap_socket(self.sock, *args, **kwargs)