extension.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. import zlib
  2. from frame import ControlFrame
  3. from errors import SocketClosed
  4. class Extension(object):
  5. name = ''
  6. rsv1 = False
  7. rsv2 = False
  8. rsv3 = False
  9. opcodes = []
  10. defaults = {}
  11. request = {}
  12. def __init__(self, defaults={}, request={}):
  13. for param in defaults.keys() + request.keys():
  14. if param not in self.defaults:
  15. raise KeyError('unrecognized parameter "%s"' % param)
  16. # Copy dict first to avoid duplicate references to the same object
  17. self.defaults = dict(self.__class__.defaults)
  18. self.defaults.update(defaults)
  19. self.request = dict(self.__class__.request)
  20. self.request.update(request)
  21. def __str__(self):
  22. return '<Extension "%s" defaults=%s request=%s>' \
  23. % (self.name, self.defaults, self.request)
  24. def create_hook(self, **kwargs):
  25. params = {}
  26. params.update(self.defaults)
  27. params.update(kwargs)
  28. return self.Hook(self, **params)
  29. class Hook:
  30. def __init__(self, extension, **kwargs):
  31. self.extension = extension
  32. for param, value in kwargs.iteritems():
  33. setattr(self, param, value)
  34. def send(self, frame):
  35. return frame
  36. def recv(self, frame):
  37. return frame
  38. class DeflateFrame(Extension):
  39. """
  40. This is an implementation of the "deflate-frame" extension, as defined by
  41. http://tools.ietf.org/html/draft-tyoshino-hybi-websocket-perframe-deflate-06.
  42. Supported parameters are:
  43. - max_window_size: maximum size for the LZ77 sliding window.
  44. - no_context_takeover: disallows usage of LZ77 sliding window from
  45. previously built frames for the current frame.
  46. Note that the deflate and inflate hooks modify the RSV1 bit and payload of
  47. existing `Frame` objects.
  48. """
  49. name = 'deflate-frame'
  50. rsv1 = True
  51. # FIXME: is 32768 (below) correct?
  52. defaults = {'max_window_bits': 15, 'no_context_takeover': False}
  53. def __init__(self, defaults={}, request={}):
  54. Extension.__init__(self, defaults, request)
  55. mwb = self.defaults['max_window_bits']
  56. cto = self.defaults['no_context_takeover']
  57. if not isinstance(mwb, int):
  58. raise ValueError('"max_window_bits" must be an integer')
  59. elif mwb > 32768:
  60. raise ValueError('"max_window_bits" may not be larger than 32768')
  61. if cto is not False and cto is not True:
  62. raise ValueError('"no_context_takeover" must have no value')
  63. class Hook(Extension.Hook):
  64. def __init__(self, extension, **kwargs):
  65. Extension.Hook.__init__(self, extension, **kwargs)
  66. if not self.no_context_takeover:
  67. self.defl = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
  68. zlib.DEFLATED,
  69. -self.max_window_bits)
  70. other_wbits = self.extension.request.get('max_window_bits', 15)
  71. self.dec = zlib.decompressobj(-other_wbits)
  72. def send(self, frame):
  73. if not frame.rsv1 and not isinstance(frame, ControlFrame):
  74. frame.rsv1 = True
  75. frame.payload = self.deflate(frame.payload)
  76. return frame
  77. def recv(self, frame):
  78. if frame.rsv1:
  79. if isinstance(frame, ControlFrame):
  80. raise SocketClosed('received compressed control frame')
  81. frame.rsv1 = False
  82. frame.payload = self.inflate(frame.payload)
  83. return frame
  84. def deflate(self, data):
  85. if self.no_context_takeover:
  86. defl = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
  87. zlib.DEFLATED, -self.max_window_bits)
  88. # FIXME: why the '\x00' below? This was borrowed from
  89. # https://github.com/fancycode/tornado/blob/bc317b6dcf63608ff004ff1f57073be0504b6550/tornado/websocket.py#L91
  90. return defl.compress(data) + defl.flush(zlib.Z_FINISH) + '\x00'
  91. compressed = self.defl.compress(data)
  92. compressed += self.defl.flush(zlib.Z_SYNC_FLUSH)
  93. assert compressed[-4:] == '\x00\x00\xff\xff'
  94. return compressed[:-4]
  95. def inflate(self, data):
  96. data = self.dec.decompress(str(data + '\x00\x00\xff\xff'))
  97. assert not self.dec.unused_data
  98. return data
  99. class Multiplex(Extension):
  100. """
  101. This is an implementation of the "mux" extension, as defined by
  102. http://tools.ietf.org/html/draft-ietf-hybi-websocket-multiplexing-11.
  103. Supported parameters are:
  104. - quota: TODO
  105. """
  106. name = 'mux'
  107. rsv1 = True # FIXME
  108. rsv2 = True # FIXME
  109. rsv3 = True # FIXME
  110. defaults = {'quota': None}
  111. def __init__(self, defaults={}, request={}):
  112. Extension.__init__(self, defaults, request)
  113. # TODO: check "quota" value
  114. class Hook(Extension.Hook):
  115. def send(self, frame):
  116. raise NotImplementedError # TODO
  117. def recv(self, frame):
  118. raise NotImplementedError # TODO
  119. def filter_extensions(extensions):
  120. """
  121. Remove extensions that use conflicting rsv bits and/or opcodes, with the
  122. first options being the most preferable.
  123. """
  124. rsv1_reserved = False
  125. rsv2_reserved = False
  126. rsv3_reserved = False
  127. opcodes_reserved = []
  128. compat = []
  129. for ext in extensions:
  130. if ext.rsv1 and rsv1_reserved \
  131. or ext.rsv2 and rsv2_reserved \
  132. or ext.rsv3 and rsv3_reserved \
  133. or len(set(ext.opcodes) & set(opcodes_reserved)):
  134. continue
  135. rsv1_reserved |= ext.rsv1
  136. rsv2_reserved |= ext.rsv2
  137. rsv3_reserved |= ext.rsv3
  138. opcodes_reserved.extend(ext.opcodes)
  139. compat.append(ext)
  140. return compat