# 16_decoder.py (1.4 KB)
  1. #!/usr/bin/env python3
import sys
from functools import reduce
from itertools import chain, islice
from operator import add, mul, gt, lt, eq
  6. def bitstream(hexdata):
  7. for quad in hexdata:
  8. i = int(quad, 16)
  9. yield i >> 3
  10. yield i >> 2 & 1
  11. yield i >> 1 & 1
  12. yield i & 1
  13. def consume(bits, num):
  14. i = 0
  15. for bit in islice(bits, num):
  16. i = i << 1 | bit
  17. return i
  18. def decode(bits):
  19. version = consume(bits, 3)
  20. ty = consume(bits, 3)
  21. if ty == 4:
  22. has_next = next(bits)
  23. arg = consume(bits, 4)
  24. while has_next:
  25. has_next = next(bits)
  26. arg = arg << 4 | consume(bits, 4)
  27. elif next(bits):
  28. arg = [decode(bits) for i in range(consume(bits, 11))]
  29. else:
  30. subbits = islice(bits, consume(bits, 15))
  31. arg = []
  32. while True:
  33. try:
  34. arg.append(decode(subbits))
  35. except StopIteration:
  36. break
  37. return ty, version, arg
  38. def version_sum(packet):
  39. ty, version, arg = packet
  40. return version if ty == 4 else version + sum(map(version_sum, arg))
  41. OPS = add, mul, min, max, None, gt, lt, eq
  42. def evaluate(packet):
  43. ty, version, arg = packet
  44. return arg if ty == 4 else reduce(OPS[ty], map(evaluate, arg))
  45. packet = decode(bitstream(sys.stdin.readline().rstrip()))
  46. print(version_sum(packet))
  47. print(evaluate(packet))