extension.py 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. from errors import HandshakeError
  2. class Extension(object):
  3. name = ''
  4. rsv1 = False
  5. rsv2 = False
  6. rsv3 = False
  7. opcodes = []
  8. parameters = []
  9. def __init__(self, **kwargs):
  10. for param in self.parameters:
  11. setattr(self, param, None)
  12. for param, value in kwargs.items():
  13. if param not in self.parameters:
  14. raise HandshakeError('invalid parameter "%s"' % param)
  15. if value is None:
  16. value = True
  17. setattr(self, param, value)
  18. def __str__(self, frame):
  19. if len(self.parameters):
  20. params = ' ' + ', '.join(p + '=' + str(getattr(self, p))
  21. for p in self.parameters)
  22. else:
  23. params = ''
  24. return '<Extension "%s"%s>' % (self.name, params)
  25. def header_params(self, frame):
  26. return {}
  27. def hook_send(self, frame):
  28. return frame
  29. def hook_receive(self, frame):
  30. return frame
  31. class DeflateFrame(Extension):
  32. name = 'deflate-frame'
  33. rsv1 = True
  34. parameters = ['max_window_bits', 'no_context_takeover']
  35. def __init__(self, **kwargs):
  36. super(DeflateFrame, self).__init__(**kwargs)
  37. if self.max_window_bits is None:
  38. # FIXME: is this correct? None may actually be a better value
  39. self.max_window_bits = 0
  40. def hook_send(self, frame):
  41. # FIXME: original `frame` is modified, maybe it should be copied?
  42. if not frame.rsv1:
  43. frame.rsv1 = True
  44. frame.payload = self.encode(frame.payload)
  45. return frame
  46. def hook_recv(self, frame):
  47. # FIXME: original `frame` is modified, maybe it should be copied?
  48. if frame.rsv1:
  49. frame.rsv1 = False
  50. frame.payload = self.decode(frame.payload)
  51. return frame
  52. def header_params(self):
  53. raise NotImplementedError # TODO
  54. def encode(self, data):
  55. raise NotImplementedError # TODO
  56. def decode(self, data):
  57. raise NotImplementedError # TODO
  58. def filter_extensions(extensions):
  59. """
  60. Remove extensions that use conflicting rsv bits and/or opcodes, with the
  61. first options being the most preferable.
  62. """
  63. rsv1_reserved = True
  64. rsv2_reserved = True
  65. rsv3_reserved = True
  66. opcodes_reserved = []
  67. compat = []
  68. for ext in extensions:
  69. if ext.rsv1 and rsv1_reserved \
  70. or ext.rsv2 and rsv2_reserved \
  71. or ext.rsv3 and rsv3_reserved \
  72. or len(set(ext.opcodes) & set(opcodes_reserved)):
  73. continue
  74. rsv1_reserved |= ext.rsv1
  75. rsv2_reserved |= ext.rsv2
  76. rsv3_reserved |= ext.rsv3
  77. opcodes_reserved.extend(ext.opcodes)
  78. compat.append(ext)
  79. return compat