# websocket.py (13 KB)
import os
import re
import socket
import ssl
from base64 import b64encode
from hashlib import sha1
from urlparse import urlparse

from errors import HandshakeError, SSLError
from frame import receive_frame
  10. WS_GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
  11. WS_VERSION = '13'
  12. def split_stripped(value, delim=','):
  13. return map(str.strip, str(value).split(delim)) if value else []
  14. class websocket(object):
  15. """
  16. Implementation of web socket, upgrades a regular TCP socket to a websocket
  17. using the HTTP handshakes and frame (un)packing, as specified by RFC 6455.
  18. The API of a websocket is identical to that of a regular socket, as
  19. illustrated by the examples below.
  20. Server example:
  21. >>> import twspy, socket
  22. >>> sock = twspy.websocket()
  23. >>> sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
  24. >>> sock.bind(('', 8000))
  25. >>> sock.listen()
  26. >>> client = sock.accept()
  27. >>> client.send(twspy.Frame(twspy.OPCODE_TEXT, 'Hello, Client!'))
  28. >>> frame = client.recv()
  29. Client example:
  30. >>> import twspy
  31. >>> sock = twspy.websocket()
  32. >>> sock.connect(('', 8000))
  33. >>> sock.send(twspy.Frame(twspy.OPCODE_TEXT, 'Hello, Server!'))
  34. """
  35. def __init__(self, sock=None, protocols=[], extensions=[],
  36. sfamily=socket.AF_INET, sproto=0):
  37. """
  38. Create a regular TCP socket of family `family` and protocol
  39. `sock` is an optional regular TCP socket to be used for sending binary
  40. data. If not specified, a new socket is created.
  41. `protocols` is a list of supported protocol names.
  42. `extensions` is a list of supported extensions.
  43. `sfamily` and `sproto` are used for the regular socket constructor.
  44. """
  45. self.protocols = protocols
  46. self.extensions = extensions
  47. self.sock = sock or socket.socket(sfamily, socket.SOCK_STREAM, sproto)
  48. self.secure = False
  49. self.handshake_started = False
  50. def bind(self, address):
  51. self.sock.bind(address)
  52. def listen(self, backlog):
  53. self.sock.listen(backlog)
  54. def accept(self):
  55. """
  56. Equivalent to socket.accept(), but transforms the socket into a
  57. websocket instance and sends a server handshake (after receiving a
  58. client handshake). Note that the handshake may raise a HandshakeError
  59. exception.
  60. """
  61. sock, address = self.sock.accept()
  62. wsock = websocket(sock)
  63. wsock.server_handshake()
  64. return wsock, address
  65. def connect(self, address, path='/', auth=None):
  66. """
  67. Equivalent to socket.connect(), but sends an client handshake request
  68. after connecting.
  69. `address` is a (host, port) tuple of the server to connect to.
  70. `path` is optional, used as the *location* part of the HTTP handshake.
  71. In a URL, this would show as ws://host[:port]/path.
  72. `auth` is optional, used for the HTTP "Authorization" header of the
  73. handshake request.
  74. """
  75. self.sock.connect(address)
  76. self.client_handshake(address, path, auth)
  77. def send(self, *args):
  78. """
  79. Send a number of frames.
  80. """
  81. for frame in args:
  82. #print 'send frame:', frame, 'to %s:%d' % self.sock.getpeername()
  83. self.sock.sendall(frame.pack())
  84. def recv(self):
  85. """
  86. Receive a single frames. This can be either a data frame or a control
  87. frame.
  88. """
  89. frame = receive_frame(self.sock)
  90. #print 'receive frame:', frame, 'from %s:%d' % self.sock.getpeername()
  91. return frame
  92. def recvn(self, n):
  93. """
  94. Receive exactly `n` frames. These can be either data frames or control
  95. frames, or a combination of both.
  96. """
  97. return [self.recv() for i in xrange(n)]
  98. def getpeername(self):
  99. return self.sock.getpeername()
  100. def getsockname(self):
  101. return self.sock.getsockname()
  102. def setsockopt(self, level, optname, value):
  103. self.sock.setsockopt(level, optname, value)
  104. def getsockopt(self, level, optname):
  105. return self.sock.getsockopt(level, optname)
  106. def close(self):
  107. self.sock.close()
  108. def server_handshake(self):
  109. """
  110. Execute a handshake as the server end point of the socket. If the HTTP
  111. request headers sent by the client are invalid, a HandshakeError
  112. is raised.
  113. """
  114. def fail(msg):
  115. self.sock.close()
  116. raise HandshakeError(msg)
  117. # Receive HTTP header
  118. raw_headers = ''
  119. while raw_headers[-4:] not in ('\r\n\r\n', '\n\n'):
  120. raw_headers += self.sock.recv(512).decode('utf-8', 'ignore')
  121. # Request must be HTTP (at least 1.1) GET request, find the location
  122. match = re.search(r'^GET (.*) HTTP/1.1\r\n', raw_headers)
  123. if match is None:
  124. fail('not a valid HTTP 1.1 GET request')
  125. location = match.group(1)
  126. headers = re.findall(r'(.*?): ?(.*?)\r\n', raw_headers)
  127. header_names = [name for name, value in headers]
  128. def header(name):
  129. return ', '.join([v for n, v in headers if n == name])
  130. # Check if headers that MUST be present are actually present
  131. for name in ('Host', 'Upgrade', 'Connection', 'Sec-WebSocket-Key',
  132. 'Sec-WebSocket-Version'):
  133. if name not in header_names:
  134. fail('missing "%s" header' % name)
  135. # Check WebSocket version used by client
  136. version = header('Sec-WebSocket-Version')
  137. if version != WS_VERSION:
  138. fail('WebSocket version %s requested (only %s '
  139. 'is supported)' % (version, WS_VERSION))
  140. # Verify required header keywords
  141. if 'websocket' not in header('Upgrade').lower():
  142. fail('"Upgrade" header must contain "websocket"')
  143. if 'upgrade' not in header('Connection').lower():
  144. fail('"Connection" header must contain "Upgrade"')
  145. # Origin must be present if browser client
  146. if 'User-Agent' in header_names and 'Origin' not in header_names:
  147. fail('browser client must specify "Origin" header')
  148. # Only supported protocols are returned
  149. client_protocols = split_stripped(header('Sec-WebSocket-Extensions'))
  150. protocol = 'null'
  151. for p in client_protocols:
  152. if p in self.protocols:
  153. protocol = p
  154. break
  155. # Only supported extensions are returned
  156. extensions = split_stripped(header('Sec-WebSocket-Extensions'))
  157. extensions = [e for e in extensions if e in self.extensions]
  158. # Encode acceptation key using the WebSocket GUID
  159. key = header('Sec-WebSocket-Key').strip()
  160. accept = b64encode(sha1(key + WS_GUID).digest())
  161. # Construct HTTP response header
  162. shake = 'HTTP/1.1 101 Web Socket Protocol Handshake\r\n'
  163. shake += 'Upgrade: websocket\r\n'
  164. shake += 'Connection: Upgrade\r\n'
  165. shake += 'WebSocket-Origin: %s\r\n' % header('Origin')
  166. shake += 'WebSocket-Location: ws://%s%s\r\n' \
  167. % (header('Host'), location)
  168. shake += 'Sec-WebSocket-Accept: %s\r\n' % accept
  169. shake += 'Sec-WebSocket-Protocol: %s\r\n' % protocol
  170. shake += 'Sec-WebSocket-Extensions: %s\r\n' % ', '.join(extensions)
  171. self.sock.sendall(shake + '\r\n')
  172. self.handshake_started = True
  173. def client_handshake(self, address, location, auth):
  174. """
  175. Executes a handshake as the client end point of the socket. May raise a
  176. HandshakeError if the server response is invalid.
  177. """
  178. def fail(msg):
  179. self.sock.close()
  180. raise HandshakeError(msg)
  181. def send_request(location):
  182. if len(location) == 0:
  183. fail('request location is empty')
  184. # Generate a 16-byte random base64-encoded key for this connection
  185. key = b64encode(os.urandom(16))
  186. # Send client handshake
  187. shake = 'GET %s HTTP/1.1\r\n' % location
  188. shake += 'Host: %s:%d\r\n' % address
  189. shake += 'Upgrade: websocket\r\n'
  190. shake += 'Connection: keep-alive, Upgrade\r\n'
  191. shake += 'Sec-WebSocket-Key: %s\r\n' % key
  192. shake += 'Origin: null\r\n' # FIXME: is this correct/necessary?
  193. shake += 'Sec-WebSocket-Version: %s\r\n' % WS_VERSION
  194. # These are for eagerly caching webservers
  195. shake += 'Pragma: no-cache\r\n'
  196. shake += 'Cache-Control: no-cache\r\n'
  197. # Request protocols and extension, these are later checked with the
  198. # actual supported values from the server's response
  199. if self.protocols:
  200. shake += 'Sec-WebSocket-Protocol: %s\r\n' \
  201. % ', '.join(self.protocols)
  202. if self.extensions:
  203. shake += 'Sec-WebSocket-Extensions: %s\r\n' \
  204. % ', '.join(self.extensions)
  205. if auth:
  206. shake += 'Authorization: %s\r\n' % auth
  207. self.sock.sendall(shake + '\r\n')
  208. return key
  209. def receive_response(key):
  210. # Receive and process server handshake
  211. raw_headers = ''
  212. while raw_headers[-4:] not in ('\r\n\r\n', '\n\n'):
  213. raw_headers += self.sock.recv(512).decode('utf-8', 'ignore')
  214. # Response must be HTTP (at least 1.1) with status 101
  215. match = re.search(r'^HTTP/1\.1 (\d{3})', raw_headers)
  216. if match is None:
  217. fail('not a valid HTTP 1.1 response')
  218. status = int(match.group(1))
  219. headers = re.findall(r'(.*?): ?(.*?)\r\n', raw_headers)
  220. header_names = [name for name, value in headers]
  221. def header(name):
  222. return ', '.join([v for n, v in headers if n == name])
  223. if status == 401:
  224. # HTTP authentication is required in the request
  225. raise HandshakeError('HTTP authentication required: %s'
  226. % header('WWW-Authenticate'))
  227. if status in (301, 302, 303, 307, 308):
  228. # Handle HTTP redirect
  229. url = urlparse(header('Location').strip())
  230. # Reconnect socket if net location changed
  231. if not url.port:
  232. url.port = 443 if self.secure else 80
  233. addr = (url.netloc, url.port)
  234. if addr != self.sock.getpeername():
  235. self.sock.close()
  236. self.sock.connect(addr)
  237. # Send new handshake
  238. receive_response(send_request(url.path))
  239. return
  240. if status != 101:
  241. # 101 means server has accepted the connection and sent
  242. # handshake headers
  243. fail('invalid HTTP response status %d' % status)
  244. # Check if headers that MUST be present are actually present
  245. for name in ('Upgrade', 'Connection', 'Sec-WebSocket-Accept'):
  246. if name not in header_names:
  247. fail('missing "%s" header' % name)
  248. if 'websocket' not in header('Upgrade').lower():
  249. fail('"Upgrade" header must contain "websocket"')
  250. if 'upgrade' not in header('Connection').lower():
  251. fail('"Connection" header must contain "Upgrade"')
  252. # Verify accept header
  253. accept = header('Sec-WebSocket-Accept').strip()
  254. required_accept = b64encode(sha1(key + WS_GUID).digest())
  255. if accept != required_accept:
  256. fail('invalid websocket accept header "%s"' % accept)
  257. # Compare extensions
  258. server_ext = split_stripped(header('Sec-WebSocket-Extensions'))
  259. for e in server_ext:
  260. if e not in self.extensions:
  261. fail('server extension "%s" is unsupported by client' % e)
  262. for e in self.extensions:
  263. if e not in server_ext:
  264. fail('client extension "%s" is unsupported by server' % e)
  265. # Assert that returned protocol (if any) is supported
  266. protocol = header('Sec-WebSocket-Protocol')
  267. if protocol:
  268. if protocol != 'null' and protocol not in self.protocols:
  269. fail('unsupported protocol "%s"' % protocol)
  270. self.protocol = protocol
  271. self.handshake_started = True
  272. receive_response(send_request(location))
  273. def enable_ssl(self, *args, **kwargs):
  274. """
  275. Transforms the regular socket.socket to an ssl.SSLSocket for secure
  276. connections. Any arguments are passed to ssl.wrap_socket:
  277. http://docs.python.org/dev/library/ssl.html#ssl.wrap_socket
  278. """
  279. if self.handshake_started:
  280. raise SSLError('can only enable SSL before handshake')
  281. self.secure = True
  282. self.sock = ssl.wrap_socket(self.sock, *args, **kwargs)