websocket.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388
  1. import os
  2. import re
  3. import socket
  4. import ssl
  5. from hashlib import sha1
  6. from base64 import b64encode
  7. from urlparse import urlparse
  8. from frame import receive_frame
  9. from errors import HandshakeError, SSLError
# Fixed GUID that is appended to the client's Sec-WebSocket-Key before the
# SHA-1/base64 encoding that yields the Sec-WebSocket-Accept value.
WS_GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
# The only WebSocket protocol version accepted by the server handshake, and
# the version advertised in the client handshake (Sec-WebSocket-Version).
WS_VERSION = '13'
  12. def split_stripped(value, delim=','):
  13. return map(str.strip, str(value).split(delim)) if value else []
  14. class websocket(object):
  15. """
  16. Implementation of web socket, upgrades a regular TCP socket to a websocket
  17. using the HTTP handshakes and frame (un)packing, as specified by RFC 6455.
  18. The API of a websocket is identical to that of a regular socket, as
  19. illustrated by the examples below.
  20. Server example:
  21. >>> import twspy, socket
  22. >>> sock = twspy.websocket()
  23. >>> sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
  24. >>> sock.bind(('', 8000))
  25. >>> sock.listen()
  26. >>> client = sock.accept()
  27. >>> client.send(twspy.Frame(twspy.OPCODE_TEXT, 'Hello, Client!'))
  28. >>> frame = client.recv()
  29. Client example:
  30. >>> import twspy
  31. >>> sock = twspy.websocket()
  32. >>> sock.connect(('', 8000))
  33. >>> sock.send(twspy.Frame(twspy.OPCODE_TEXT, 'Hello, Server!'))
  34. """
  35. def __init__(self, sock=None, protocols=[], extensions=[], origin=None,
  36. trusted_origins=[], sfamily=socket.AF_INET, sproto=0):
  37. """
  38. Create a regular TCP socket of family `family` and protocol
  39. `sock` is an optional regular TCP socket to be used for sending binary
  40. data. If not specified, a new socket is created.
  41. `protocols` is a list of supported protocol names.
  42. `extensions` is a list of supported extensions.
  43. `origin` (for client sockets) is the value for the "Origin" header sent
  44. in a client handshake .
  45. `trusted_origins` (for servere sockets) is a list of expected values
  46. for the "Origin" header sent by a client. If the received Origin header
  47. has value not in this list, a HandshakeError is raised. If the list is
  48. empty (default), all origins are excepted.
  49. `sfamily` and `sproto` are used for the regular socket constructor.
  50. """
  51. self.protocols = protocols
  52. self.extensions = extensions
  53. self.origin = origin
  54. self.trusted_origins = trusted_origins
  55. self.sock = sock or socket.socket(sfamily, socket.SOCK_STREAM, sproto)
  56. self.secure = False
  57. self.handshake_started = False
  58. def bind(self, address):
  59. self.sock.bind(address)
  60. def listen(self, backlog):
  61. self.sock.listen(backlog)
  62. def accept(self):
  63. """
  64. Equivalent to socket.accept(), but transforms the socket into a
  65. websocket instance and sends a server handshake (after receiving a
  66. client handshake). Note that the handshake may raise a HandshakeError
  67. exception.
  68. """
  69. sock, address = self.sock.accept()
  70. wsock = websocket(sock)
  71. wsock.server_handshake()
  72. return wsock, address
  73. def connect(self, address, path='/', auth=None):
  74. """
  75. Equivalent to socket.connect(), but sends an client handshake request
  76. after connecting.
  77. `address` is a (host, port) tuple of the server to connect to.
  78. `path` is optional, used as the *location* part of the HTTP handshake.
  79. In a URL, this would show as ws://host[:port]/path.
  80. `auth` is optional, used for the HTTP "Authorization" header of the
  81. handshake request.
  82. """
  83. self.sock.connect(address)
  84. self.client_handshake(address, path, auth)
  85. def send(self, *args):
  86. """
  87. Send a number of frames.
  88. """
  89. for frame in args:
  90. #print 'send frame:', frame, 'to %s:%d' % self.sock.getpeername()
  91. self.sock.sendall(frame.pack())
  92. def recv(self):
  93. """
  94. Receive a single frames. This can be either a data frame or a control
  95. frame.
  96. """
  97. frame = receive_frame(self.sock)
  98. #print 'receive frame:', frame, 'from %s:%d' % self.sock.getpeername()
  99. return frame
  100. def recvn(self, n):
  101. """
  102. Receive exactly `n` frames. These can be either data frames or control
  103. frames, or a combination of both.
  104. """
  105. return [self.recv() for i in xrange(n)]
  106. def getpeername(self):
  107. return self.sock.getpeername()
  108. def getsockname(self):
  109. return self.sock.getsockname()
  110. def setsockopt(self, level, optname, value):
  111. self.sock.setsockopt(level, optname, value)
  112. def getsockopt(self, level, optname):
  113. return self.sock.getsockopt(level, optname)
  114. def close(self):
  115. self.sock.close()
  116. def server_handshake(self):
  117. """
  118. Execute a handshake as the server end point of the socket. If the HTTP
  119. request headers sent by the client are invalid, a HandshakeError
  120. is raised.
  121. """
  122. def fail(msg):
  123. self.sock.close()
  124. raise HandshakeError(msg)
  125. # Receive HTTP header
  126. raw_headers = ''
  127. while raw_headers[-4:] not in ('\r\n\r\n', '\n\n'):
  128. raw_headers += self.sock.recv(512).decode('utf-8', 'ignore')
  129. # Request must be HTTP (at least 1.1) GET request, find the location
  130. match = re.search(r'^GET (.*) HTTP/1.1\r\n', raw_headers)
  131. if match is None:
  132. fail('not a valid HTTP 1.1 GET request')
  133. location = match.group(1)
  134. headers = re.findall(r'(.*?): ?(.*?)\r\n', raw_headers)
  135. header_names = [name for name, value in headers]
  136. def header(name):
  137. return ', '.join([v for n, v in headers if n == name])
  138. # Check if headers that MUST be present are actually present
  139. for name in ('Host', 'Upgrade', 'Connection', 'Sec-WebSocket-Key',
  140. 'Sec-WebSocket-Version'):
  141. if name not in header_names:
  142. fail('missing "%s" header' % name)
  143. # Check WebSocket version used by client
  144. version = header('Sec-WebSocket-Version')
  145. if version != WS_VERSION:
  146. fail('WebSocket version %s requested (only %s '
  147. 'is supported)' % (version, WS_VERSION))
  148. # Verify required header keywords
  149. if 'websocket' not in header('Upgrade').lower():
  150. fail('"Upgrade" header must contain "websocket"')
  151. if 'upgrade' not in header('Connection').lower():
  152. fail('"Connection" header must contain "Upgrade"')
  153. # Origin must be present if browser client, and must match the list of
  154. # trusted origins
  155. if 'Origin' not in header_names:
  156. if 'User-Agent' in header_names:
  157. fail('browser client must specify "Origin" header')
  158. if self.trusted_origins:
  159. fail('no "Origin" header specified, assuming untrusted')
  160. elif self.trusted_origins:
  161. origin = header('Origin')
  162. if origin not in self.trusted_origins:
  163. fail('untrusted origin "%s"' % origin)
  164. # Only supported protocols are returned
  165. client_protocols = split_stripped(header('Sec-WebSocket-Extensions'))
  166. protocol = 'null'
  167. for p in client_protocols:
  168. if p in self.protocols:
  169. protocol = p
  170. break
  171. # Only supported extensions are returned
  172. extensions = split_stripped(header('Sec-WebSocket-Extensions'))
  173. extensions = [e for e in extensions if e in self.extensions]
  174. # Encode acceptation key using the WebSocket GUID
  175. key = header('Sec-WebSocket-Key').strip()
  176. accept = b64encode(sha1(key + WS_GUID).digest())
  177. # Construct HTTP response header
  178. shake = 'HTTP/1.1 101 Web Socket Protocol Handshake\r\n'
  179. shake += 'Upgrade: websocket\r\n'
  180. shake += 'Connection: Upgrade\r\n'
  181. shake += 'WebSocket-Origin: %s\r\n' % header('Origin')
  182. shake += 'WebSocket-Location: ws://%s%s\r\n' \
  183. % (header('Host'), location)
  184. shake += 'Sec-WebSocket-Accept: %s\r\n' % accept
  185. shake += 'Sec-WebSocket-Protocol: %s\r\n' % protocol
  186. shake += 'Sec-WebSocket-Extensions: %s\r\n' % ', '.join(extensions)
  187. self.sock.sendall(shake + '\r\n')
  188. self.handshake_started = True
  189. def client_handshake(self, address, location, auth):
  190. """
  191. Executes a handshake as the client end point of the socket. May raise a
  192. HandshakeError if the server response is invalid.
  193. """
  194. def fail(msg):
  195. self.sock.close()
  196. raise HandshakeError(msg)
  197. def send_request(location):
  198. if len(location) == 0:
  199. fail('request location is empty')
  200. # Generate a 16-byte random base64-encoded key for this connection
  201. key = b64encode(os.urandom(16))
  202. # Send client handshake
  203. shake = 'GET %s HTTP/1.1\r\n' % location
  204. shake += 'Host: %s:%d\r\n' % address
  205. shake += 'Upgrade: websocket\r\n'
  206. shake += 'Connection: keep-alive, Upgrade\r\n'
  207. shake += 'Sec-WebSocket-Key: %s\r\n' % key
  208. shake += 'Sec-WebSocket-Version: %s\r\n' % WS_VERSION
  209. if self.origin:
  210. shake += 'Origin: %s\r\n' % self.origin
  211. # These are for eagerly caching webservers
  212. shake += 'Pragma: no-cache\r\n'
  213. shake += 'Cache-Control: no-cache\r\n'
  214. # Request protocols and extension, these are later checked with the
  215. # actual supported values from the server's response
  216. if self.protocols:
  217. shake += 'Sec-WebSocket-Protocol: %s\r\n' \
  218. % ', '.join(self.protocols)
  219. if self.extensions:
  220. shake += 'Sec-WebSocket-Extensions: %s\r\n' \
  221. % ', '.join(self.extensions)
  222. if auth:
  223. shake += 'Authorization: %s\r\n' % auth
  224. self.sock.sendall(shake + '\r\n')
  225. return key
  226. def receive_response(key):
  227. # Receive and process server handshake
  228. raw_headers = ''
  229. while raw_headers[-4:] not in ('\r\n\r\n', '\n\n'):
  230. raw_headers += self.sock.recv(512).decode('utf-8', 'ignore')
  231. # Response must be HTTP (at least 1.1) with status 101
  232. match = re.search(r'^HTTP/1\.1 (\d{3})', raw_headers)
  233. if match is None:
  234. fail('not a valid HTTP 1.1 response')
  235. status = int(match.group(1))
  236. headers = re.findall(r'(.*?): ?(.*?)\r\n', raw_headers)
  237. header_names = [name for name, value in headers]
  238. def header(name):
  239. return ', '.join([v for n, v in headers if n == name])
  240. if status == 401:
  241. # HTTP authentication is required in the request
  242. raise HandshakeError('HTTP authentication required: %s'
  243. % header('WWW-Authenticate'))
  244. if status in (301, 302, 303, 307, 308):
  245. # Handle HTTP redirect
  246. url = urlparse(header('Location').strip())
  247. # Reconnect socket if net location changed
  248. if not url.port:
  249. url.port = 443 if self.secure else 80
  250. addr = (url.netloc, url.port)
  251. if addr != self.sock.getpeername():
  252. self.sock.close()
  253. self.sock.connect(addr)
  254. # Send new handshake
  255. receive_response(send_request(url.path))
  256. return
  257. if status != 101:
  258. # 101 means server has accepted the connection and sent
  259. # handshake headers
  260. fail('invalid HTTP response status %d' % status)
  261. # Check if headers that MUST be present are actually present
  262. for name in ('Upgrade', 'Connection', 'Sec-WebSocket-Accept'):
  263. if name not in header_names:
  264. fail('missing "%s" header' % name)
  265. if 'websocket' not in header('Upgrade').lower():
  266. fail('"Upgrade" header must contain "websocket"')
  267. if 'upgrade' not in header('Connection').lower():
  268. fail('"Connection" header must contain "Upgrade"')
  269. # Verify accept header
  270. accept = header('Sec-WebSocket-Accept').strip()
  271. required_accept = b64encode(sha1(key + WS_GUID).digest())
  272. if accept != required_accept:
  273. fail('invalid websocket accept header "%s"' % accept)
  274. # Compare extensions
  275. server_ext = split_stripped(header('Sec-WebSocket-Extensions'))
  276. for e in server_ext:
  277. if e not in self.extensions:
  278. fail('server extension "%s" is unsupported by client' % e)
  279. for e in self.extensions:
  280. if e not in server_ext:
  281. fail('client extension "%s" is unsupported by server' % e)
  282. # Assert that returned protocol (if any) is supported
  283. protocol = header('Sec-WebSocket-Protocol')
  284. if protocol:
  285. if protocol != 'null' and protocol not in self.protocols:
  286. fail('unsupported protocol "%s"' % protocol)
  287. self.protocol = protocol
  288. self.handshake_started = True
  289. receive_response(send_request(location))
  290. def enable_ssl(self, *args, **kwargs):
  291. """
  292. Transforms the regular socket.socket to an ssl.SSLSocket for secure
  293. connections. Any arguments are passed to ssl.wrap_socket:
  294. http://docs.python.org/dev/library/ssl.html#ssl.wrap_socket
  295. """
  296. if self.handshake_started:
  297. raise SSLError('can only enable SSL before handshake')
  298. self.secure = True
  299. self.sock = ssl.wrap_socket(self.sock, *args, **kwargs)