Quellcode durchsuchen

Simplify day 12, less optimal but still fast and simpler code

Taddeus Kroes vor 2 Jahren
Ursprung
Commit
1ff5a99740
1 geänderte Dateien mit 14 neuen und 23 gelöschten Zeilen
  1. 14 23
      2023/12_records.py

+ 14 - 23
2023/12_records.py

@@ -2,37 +2,28 @@
 import sys
 from functools import cache
 
-def place_iter(size, record):
-    groupend = record.find('.', 1) + 1
-    if groupend != len(record):
-        group = record[:groupend]
-        for rem in place(size, group):
-            yield rem + record[groupend:]
-        if '#' not in group:
-            yield from place(size, record[groupend - 1:])
-    else:
-        for start in range(1, len(record) - size):
-            end = start + size
-            if record[start - 1] != '#' and record[end] != '#' and \
-                    all(record[i] != '.' for i in range(start, end)):
-                yield '.' + record[min(end + 1, len(record) - 1):]
-            if record[start] == '#':
-                break
-
 @cache
-def place(size, record):
-    return tuple(place_iter(size, record))
+def fit(size, record):
+    remainders = []
+    for start in range(1, len(record) - size):
+        end = start + size
+        if record[start - 1] != '#' and record[end] != '#' and \
+                all(record[i] != '.' for i in range(start, end)):
+            remainders.append('.' + record[end + 1:])
+        if record[start] == '#':
+            break
+    return tuple(remainders)
 
 @cache
-def arrangements(record, sizes):
+def repair(record, sizes):
     if not sizes:
         return int('#' not in record)
-    return sum(arrangements(r, sizes[1:]) for r in place(sizes[0], record))
+    return sum(repair(r, sizes[1:]) for r in fit(sizes[0], record))
 
 def normalize(record):
     return '.' + '.'.join(record.replace('.', ' ').split()) + '.'
 
 records = [(record, tuple(map(int, numbers.split(','))))
            for record, numbers in map(str.split, sys.stdin)]
-print(sum(arrangements(normalize(r), d) for r, d in records))
-print(sum(arrangements(normalize('?'.join([r] * 5)), d * 5) for r, d in records))
+print(sum(repair(normalize(r), d) for r, d in records))
+print(sum(repair(normalize('?'.join([r] * 5)), d * 5) for r, d in records))