|
|
@@ -1,6 +1,6 @@
|
|
|
#!/usr/bin/env python3
|
|
|
import sys
|
|
|
-from itertools import islice
|
|
|
+from itertools import islice, starmap
|
|
|
from operator import add, mul, gt, lt, eq
|
|
|
from functools import reduce
|
|
|
|
|
|
@@ -19,38 +19,34 @@ def consume(bits, num):
|
|
|
return i
|
|
|
|
|
|
def decode(bits):
|
|
|
- version = consume(bits, 3)
|
|
|
- ty = consume(bits, 3)
|
|
|
-
|
|
|
- if ty == 4:
|
|
|
- has_next = next(bits)
|
|
|
- arg = consume(bits, 4)
|
|
|
- while has_next:
|
|
|
- has_next = next(bits)
|
|
|
- arg = arg << 4 | consume(bits, 4)
|
|
|
- elif next(bits):
|
|
|
- arg = [decode(bits) for i in range(consume(bits, 11))]
|
|
|
- else:
|
|
|
- subbits = islice(bits, consume(bits, 15))
|
|
|
- arg = []
|
|
|
- while True:
|
|
|
- try:
|
|
|
- arg.append(decode(subbits))
|
|
|
- except StopIteration:
|
|
|
- break
|
|
|
-
|
|
|
- return ty, version, arg
|
|
|
-
|
|
|
-def version_sum(packet):
|
|
|
- ty, version, arg = packet
|
|
|
- return version if ty == 4 else version + sum(map(version_sum, arg))
|
|
|
+ while True:
|
|
|
+ try:
|
|
|
+ version = consume(bits, 3)
|
|
|
+ ty = consume(bits, 3)
|
|
|
+
|
|
|
+ if ty == 4:
|
|
|
+ has_next = next(bits)
|
|
|
+ arg = consume(bits, 4)
|
|
|
+ while has_next:
|
|
|
+ has_next = next(bits)
|
|
|
+ arg = arg << 4 | consume(bits, 4)
|
|
|
+ elif next(bits):
|
|
|
+ arg = list(islice(decode(bits), consume(bits, 11)))
|
|
|
+ else:
|
|
|
+ arg = list(decode(islice(bits, consume(bits, 15))))
|
|
|
+
|
|
|
+ yield ty, version, arg
|
|
|
+ except StopIteration:
|
|
|
+ break
|
|
|
+
|
|
|
+def version_sum(ty, version, arg):
|
|
|
+ return version if ty == 4 else version + sum(starmap(version_sum, arg))
|
|
|
|
|
|
OPS = add, mul, min, max, None, gt, lt, eq
|
|
|
|
|
|
-def evaluate(packet):
|
|
|
- ty, version, arg = packet
|
|
|
- return arg if ty == 4 else reduce(OPS[ty], map(evaluate, arg))
|
|
|
+def evaluate(ty, version, arg):
|
|
|
+ return arg if ty == 4 else reduce(OPS[ty], starmap(evaluate, arg))
|
|
|
|
|
|
-packet = decode(bitstream(sys.stdin.readline().rstrip()))
|
|
|
-print(version_sum(packet))
|
|
|
-print(evaluate(packet))
|
|
|
+packet = next(decode(bitstream(sys.stdin.readline().rstrip())))
|
|
|
+print(version_sum(*packet))
|
|
|
+print(evaluate(*packet))
|