Ver Fonte

Cleanup

Taddeus Kroes há 4 anos atrás
pai
commit
f038a30a16
1 ficheiros alterados com 28 adições e 32 exclusões
  1. 28 32
      2021/16_decoder.py

+ 28 - 32
2021/16_decoder.py

@@ -1,6 +1,6 @@
 #!/usr/bin/env python3
 import sys
-from itertools import islice
+from itertools import islice, starmap
 from operator import add, mul, gt, lt, eq
 from functools import reduce
 
@@ -19,38 +19,34 @@ def consume(bits, num):
     return i
 
 def decode(bits):
-    version = consume(bits, 3)
-    ty = consume(bits, 3)
-
-    if ty == 4:
-        has_next = next(bits)
-        arg = consume(bits, 4)
-        while has_next:
-            has_next = next(bits)
-            arg = arg << 4 | consume(bits, 4)
-    elif next(bits):
-        arg = [decode(bits) for i in range(consume(bits, 11))]
-    else:
-        subbits = islice(bits, consume(bits, 15))
-        arg = []
-        while True:
-            try:
-                arg.append(decode(subbits))
-            except StopIteration:
-                break
-
-    return ty, version, arg
-
-def version_sum(packet):
-    ty, version, arg = packet
-    return version if ty == 4 else version + sum(map(version_sum, arg))
+    while True:
+        try:
+            version = consume(bits, 3)
+            ty = consume(bits, 3)
+
+            if ty == 4:
+                has_next = next(bits)
+                arg = consume(bits, 4)
+                while has_next:
+                    has_next = next(bits)
+                    arg = arg << 4 | consume(bits, 4)
+            elif next(bits):
+                arg = list(islice(decode(bits), consume(bits, 11)))
+            else:
+                arg = list(decode(islice(bits, consume(bits, 15))))
+
+            yield ty, version, arg
+        except StopIteration:
+            break
+
+def version_sum(ty, version, arg):
+    return version if ty == 4 else version + sum(starmap(version_sum, arg))
 
 OPS = add, mul, min, max, None, gt, lt, eq
 
-def evaluate(packet):
-    ty, version, arg = packet
-    return arg if ty == 4 else reduce(OPS[ty], map(evaluate, arg))
+def evaluate(ty, version, arg):
+    return arg if ty == 4 else reduce(OPS[ty], starmap(evaluate, arg))
 
-packet = decode(bitstream(sys.stdin.readline().rstrip()))
-print(version_sum(packet))
-print(evaluate(packet))
+packet = next(decode(bitstream(sys.stdin.readline().rstrip())))
+print(version_sum(*packet))
+print(evaluate(*packet))