extension.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159
  1. from errors import HandshakeError
  2. class Extension(object):
  3. name = ''
  4. rsv1 = False
  5. rsv2 = False
  6. rsv3 = False
  7. opcodes = []
  8. parameters = []
  9. def __init__(self, **kwargs):
  10. for param in self.parameters:
  11. setattr(self, param, None)
  12. for param, value in kwargs.items():
  13. if param not in self.parameters:
  14. raise HandshakeError('unrecognized parameter "%s"' % param)
  15. if value is None:
  16. value = True
  17. setattr(self, param, value)
  18. def __str__(self, frame):
  19. if len(self.parameters):
  20. params = ' ' + ', '.join(p + '=' + str(getattr(self, p))
  21. for p in self.parameters)
  22. else:
  23. params = ''
  24. return '<Extension "%s"%s>' % (self.name, params)
  25. def header_params(self, frame):
  26. return {}
  27. def hook_send(self, frame):
  28. return frame
  29. def hook_receive(self, frame):
  30. return frame
  31. class DeflateFrame(Extension):
  32. """
  33. This is an implementation of the "deflate-frame" extension, as defined by
  34. http://tools.ietf.org/html/draft-tyoshino-hybi-websocket-perframe-deflate-06.
  35. Supported parameters are:
  36. - max_window_size: maximum size for the LZ77 sliding window.
  37. - no_context_takeover: disallows usage of LZ77 sliding window from
  38. previously built frames for the current frame.
  39. Note that the deflate and inflate hooks modify the RSV1 bit and payload of
  40. existing `Frame` objects.
  41. """
  42. name = 'deflate-frame'
  43. rsv1 = True
  44. parameters = ['max_window_bits', 'no_context_takeover']
  45. # FIXME: is this correct?
  46. default_max_window_bits = 32768
  47. def __init__(self, **kwargs):
  48. super(DeflateFrame, self).__init__(**kwargs)
  49. if self.max_window_bits is None:
  50. self.max_window_bits = self.default_max_window_bits
  51. elif not isinstance(self.max_window_bits, int):
  52. raise HandshakeError('"max_window_bits" must be an integer')
  53. elif self.max_window_bits > 32768:
  54. raise HandshakeError('"max_window_bits" may not be larger than '
  55. '32768')
  56. if self.no_context_takeover is None:
  57. self.no_context_takeover = False
  58. elif self.no_context_takeover is not True:
  59. raise HandshakeError('"no_context_takeover" must have no value')
  60. def hook_send(self, frame):
  61. if not frame.rsv1:
  62. frame.rsv1 = True
  63. frame.payload = self.deflate(frame.payload)
  64. return frame
  65. def hook_recv(self, frame):
  66. if frame.rsv1:
  67. frame.rsv1 = False
  68. frame.payload = self.inflate(frame.payload)
  69. return frame
  70. def header_params(self):
  71. raise NotImplementedError # TODO
  72. def deflate(self, data):
  73. raise NotImplementedError # TODO
  74. def inflate(self, data):
  75. raise NotImplementedError # TODO
  76. class Multiplex(Extension):
  77. """
  78. This is an implementation of the "mux" extension, as defined by
  79. http://tools.ietf.org/html/draft-ietf-hybi-websocket-multiplexing-11.
  80. Supported parameters are:
  81. - quota: TODO
  82. """
  83. name = 'mux'
  84. rsv1 = True # FIXME
  85. rsv2 = True # FIXME
  86. rsv3 = True # FIXME
  87. parameters = ['quota']
  88. def __init__(self, **kwargs):
  89. super(Multiplex, self).__init__(**kwargs)
  90. # TODO: check "quota" value
  91. def hook_send(self, frame):
  92. raise NotImplementedError # TODO
  93. def hook_recv(self, frame):
  94. raise NotImplementedError # TODO
  95. def header_params(self):
  96. raise NotImplementedError # TODO
  97. def filter_extensions(extensions):
  98. """
  99. Remove extensions that use conflicting rsv bits and/or opcodes, with the
  100. first options being the most preferable.
  101. """
  102. rsv1_reserved = True
  103. rsv2_reserved = True
  104. rsv3_reserved = True
  105. opcodes_reserved = []
  106. compat = []
  107. for ext in extensions:
  108. if ext.rsv1 and rsv1_reserved \
  109. or ext.rsv2 and rsv2_reserved \
  110. or ext.rsv3 and rsv3_reserved \
  111. or len(set(ext.opcodes) & set(opcodes_reserved)):
  112. continue
  113. rsv1_reserved |= ext.rsv1
  114. rsv2_reserved |= ext.rsv2
  115. rsv3_reserved |= ext.rsv3
  116. opcodes_reserved.extend(ext.opcodes)
  117. compat.append(ext)
  118. return compat