extension.py 2.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. from errors import HandshakeError
  2. class Extension(object):
  3. rsv1 = False
  4. rsv2 = False
  5. rsv3 = False
  6. opcodes = []
  7. parameters = []
  8. def __init__(self, **kwargs):
  9. for param in self.parameters:
  10. setattr(self, param, None)
  11. for param, value in kwargs.items():
  12. if param not in self.parameters:
  13. raise HandshakeError('invalid parameter "%s"' % param)
  14. if value is None:
  15. value = True
  16. setattr(self, param, value)
  17. def client_params(self, frame):
  18. return {}
  19. def hook_send(self, frame):
  20. return frame
  21. def hook_receive(self, frame):
  22. return frame
  23. class DeflateFrame(Extension):
  24. name = 'deflate-frame'
  25. rsv1 = True
  26. parameters = ['max_window_bits', 'no_context_takeover']
  27. def __init__(self, **kwargs):
  28. super(DeflateFrame, self).__init__(**kwargs)
  29. self.max_window_bits = int(self.max_window_bits)
  30. def hook_send(self, frame):
  31. # FIXME: original `frame` is modified, maybe it should be copied?
  32. if not frame.rsv1:
  33. frame.rsv1 = True
  34. frame.payload = self.encode(frame.payload)
  35. return frame
  36. def hook_recv(self, frame):
  37. # FIXME: original `frame` is modified, maybe it should be copied?
  38. if frame.rsv1:
  39. frame.rsv1 = False
  40. frame.payload = self.decode(frame.payload)
  41. return frame
  42. def client_params(self):
  43. raise NotImplementedError # TODO
  44. def encode(self, data):
  45. raise NotImplementedError # TODO
  46. def decode(self, data):
  47. raise NotImplementedError # TODO
  48. def filter_extensions(extensions):
  49. """
  50. Remove extensions that use conflicting rsv bits and/or opcodes, with the
  51. first options being most preferable.
  52. """
  53. rsv1_reserved = True
  54. rsv2_reserved = True
  55. rsv3_reserved = True
  56. opcodes_reserved = []
  57. compat = []
  58. for ext in extensions:
  59. if ext.rsv1 and rsv1_reserved \
  60. or ext.rsv2 and rsv2_reserved \
  61. or ext.rsv3 and rsv3_reserved \
  62. or len(set(ext.opcodes) & set(opcodes_reserved)):
  63. continue
  64. rsv1_reserved |= ext.rsv1
  65. rsv2_reserved |= ext.rsv2
  66. rsv3_reserved |= ext.rsv3
  67. opcodes_reserved.extend(ext.opcodes)
  68. compat.append(ext)
  69. return compat