# extension.py
  1. class Extension(object):
  2. name = ''
  3. rsv1 = False
  4. rsv2 = False
  5. rsv3 = False
  6. opcodes = ()
  7. before_fragmentation = False
  8. defaults = {}
  9. def __init__(self, **kwargs):
  10. for param in kwargs.iterkeys():
  11. if param not in self.defaults:
  12. raise KeyError('unrecognized parameter "%s"' % param)
  13. # Copy dict first to avoid duplicate references to the same object
  14. self.defaults = dict(self.__class__.defaults)
  15. self.defaults.update(kwargs)
  16. def __str__(self):
  17. return '<Extension "%s" defaults=%s request=%s>' \
  18. % (self.name, self.defaults, self.request)
  19. @property
  20. def names(self):
  21. return (self.name,) if self.name else ()
  22. def is_supported(self, name, other_instances):
  23. return name in self.names and not any(self.conflicts(other.extension)
  24. for other in other_instances)
  25. def conflicts(self, ext):
  26. """
  27. Check if the extension conflicts with an already accepted extension.
  28. This may be the case when the two extensions use the same reserved
  29. bits, or have the same name (when the same extension is negotiated
  30. multiple times with different parameters).
  31. """
  32. return ext.rsv1 and self.rsv1 \
  33. or ext.rsv2 and self.rsv2 \
  34. or ext.rsv3 and self.rsv3 \
  35. or set(ext.names) & set(self.names) \
  36. or set(ext.opcodes) & set(self.opcodes)
  37. def negotiate(self, name, params):
  38. """
  39. Same as `negotiate_safe`, but instead returns an iterator of (param,
  40. value) tuples and raises an exception on error.
  41. """
  42. raise NotImplementedError
  43. def negotiate_safe(self, name, params):
  44. """
  45. `name` and `params` are sent in the HTTP request by the client. Check
  46. if the extension name is supported by this extension, and validate the
  47. parameters. Returns a dict with accepted parameters, or None if not
  48. accepted.
  49. """
  50. for param in params.iterkeys():
  51. if param not in self.defaults:
  52. return
  53. try:
  54. return dict(self.negotiate(name, params))
  55. except (KeyError, ValueError, AssertionError):
  56. pass
  57. class Instance:
  58. def __init__(self, extension, name, params):
  59. self.extension = extension
  60. self.name = name
  61. self.params = params
  62. for param, value in extension.defaults.iteritems():
  63. setattr(self, param, value)
  64. for param, value in params.iteritems():
  65. setattr(self, param, value)
  66. self.init()
  67. def init(self):
  68. return NotImplemented
  69. def handle_send(self, frame):
  70. if self.extension.before_fragmentation:
  71. assert not frame.is_fragmented()
  72. replacement = self.onsend(frame)
  73. return frame if replacement is None else replacement
  74. def handle_recv(self, frame):
  75. if self.extension.before_fragmentation:
  76. assert not frame.is_fragmented()
  77. replacement = self.onrecv(frame)
  78. return frame if replacement is None else replacement
  79. def onsend(self, frame):
  80. raise NotImplementedError
  81. def onrecv(self, frame):
  82. raise NotImplementedError