# extension.py (2.2 KB)
  1. from errors import HandshakeError
  2. class Extension(object):
  3. rsv1 = False
  4. rsv2 = False
  5. rsv3 = False
  6. opcodes = []
  7. parameters = []
  8. def __init__(self, **kwargs):
  9. for param in self.parameters:
  10. setattr(self, param, None)
  11. for param, value in kwargs.items():
  12. if param not in self.parameters:
  13. raise HandshakeError('invalid parameter "%s"' % param)
  14. if value is None:
  15. value = True
  16. setattr(self, param, value)
  17. def hook_send(self, frame):
  18. return frame
  19. def hook_receive(self, frame):
  20. return frame
  21. class Deflate(Extension):
  22. rsv1 = True
  23. parameters = ['max_window_bits', 'no_context_takeover']
  24. def __init__(self, **kwargs):
  25. super(Deflate, self).__init__(**kwargs)
  26. self.max_window_bits = int(self.max_window_bits)
  27. def hook_send(self, frame):
  28. # FIXME: original `frame` is modified, maybe it should be copied?
  29. if not frame.rsv1:
  30. frame.rsv1 = True
  31. frame.payload = self.encode(frame.payload)
  32. return frame
  33. def hook_recv(self, frame):
  34. # FIXME: original `frame` is modified, maybe it should be copied?
  35. if frame.rsv1:
  36. frame.rsv1 = False
  37. frame.payload = self.decode(frame.payload)
  38. return frame
  39. def encode(self, data):
  40. raise NotImplementedError # TODO
  41. def decode(self, data):
  42. raise NotImplementedError # TODO
  43. def filter_compatible(extensions):
  44. """
  45. Remove extensions that use conflicting rsv bits and/or opcodes, with the
  46. first options being most preferable.
  47. """
  48. rsv1_reserved = True
  49. rsv2_reserved = True
  50. rsv3_reserved = True
  51. opcodes_reserved = []
  52. compat = []
  53. for ext in extensions:
  54. if ext.rsv1 and rsv1_reserved \
  55. or ext.rsv2 and rsv2_reserved \
  56. or ext.rsv3 and rsv3_reserved \
  57. or len(set(ext.opcodes) & set(opcodes_reserved)):
  58. continue
  59. rsv1_reserved |= ext.rsv1
  60. rsv2_reserved |= ext.rsv2
  61. rsv3_reserved |= ext.rsv3
  62. opcodes_reserved.extend(ext.opcodes)
  63. compat.append(ext)
  64. return compat