extension.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. from errors import HandshakeError
  2. class Extension(object):
  3. rsv1 = False
  4. rsv2 = False
  5. rsv3 = False
  6. opcodes = []
  7. parameters = []
  8. def __init__(self, **kwargs):
  9. for param in self.parameters:
  10. setattr(self, param, None)
  11. for param, value in kwargs.items():
  12. if param not in self.parameters:
  13. raise HandshakeError('invalid parameter "%s"' % param)
  14. if value is None:
  15. value = True
  16. setattr(self, param, value)
  17. def client_params(self, frame):
  18. return {}
  19. def hook_send(self, frame):
  20. return frame
  21. def hook_receive(self, frame):
  22. return frame
  23. class DeflateFrame(Extension):
  24. name = 'deflate-frame'
  25. rsv1 = True
  26. parameters = ['max_window_bits', 'no_context_takeover']
  27. def __init__(self, **kwargs):
  28. super(DeflateFrame, self).__init__(**kwargs)
  29. self.max_window_bits = int(self.max_window_bits)
  30. def hook_send(self, frame):
  31. # FIXME: original `frame` is modified, maybe it should be copied?
  32. if not frame.rsv1:
  33. frame.rsv1 = True
  34. frame.payload = self.encode(frame.payload)
  35. return frame
  36. def hook_recv(self, frame):
  37. # FIXME: original `frame` is modified, maybe it should be copied?
  38. if frame.rsv1:
  39. frame.rsv1 = False
  40. frame.payload = self.decode(frame.payload)
  41. return frame
  42. def client_params(self):
  43. raise NotImplementedError # TODO
  44. def encode(self, data):
  45. raise NotImplementedError # TODO
  46. def decode(self, data):
  47. raise NotImplementedError # TODO
  48. def filter_extensions(extensions):
  49. """
  50. Remove extensions that use conflicting rsv bits and/or opcodes, with the
  51. first options being most preferable.
  52. """
  53. rsv1_reserved = True
  54. rsv2_reserved = True
  55. rsv3_reserved = True
  56. opcodes_reserved = []
  57. compat = []
  58. for ext in extensions:
  59. if ext.rsv1 and rsv1_reserved \
  60. or ext.rsv2 and rsv2_reserved \
  61. or ext.rsv3 and rsv3_reserved \
  62. or len(set(ext.opcodes) & set(opcodes_reserved)):
  63. continue
  64. rsv1_reserved |= ext.rsv1
  65. rsv2_reserved |= ext.rsv2
  66. rsv3_reserved |= ext.rsv3
  67. opcodes_reserved.extend(ext.opcodes)
  68. compat.append(ext)
  69. return compat
  70. """
  71. Class map used to find contructors for client-specified extensions. Not to be
  72. modified manually, only through `register_extension`.
  73. """
  74. extension_class_map = {}
  75. def register_extension(ext):
  76. if not isinstance(ext, Extension):
  77. raise ValueError('extensions should extend the `Extension` class')
  78. if ext.name in extension_clas_map:
  79. raise KeyError('extension "%s" has already been registered' % ext.name)
  80. extension_class_map[ext.name] = ext
  81. register_extension(DeflateFrame)