Ver Fonte

Optimize 2020 day 20 with the assumption that all tiles only connect to one possible tile on each side

Taddeus Kroes há 5 anos atrás
pai
commit
783f40c477
1 ficheiros alterados com 51 adições e 55 exclusões
  1. 51 55
      2020/20_jigsaw.py

+ 51 - 55
2020/20_jigsaw.py

@@ -1,10 +1,6 @@
 #!/usr/bin/env python3
 import sys
-from collections import defaultdict, deque
-from functools import reduce
-from itertools import permutations, product
-
-BOTTOM, RIGHT = range(2)
+from functools import lru_cache
 
 def parse(f):
     tile = ''
@@ -18,12 +14,6 @@ def parse(f):
             tile += line.rstrip()
     yield ident, tile
 
-def print_tile(tile, sz, highlight=set()):
-    for y in range(0, sz * sz, sz):
-        print(''.join('O' if i in highlight else tile[i]
-                      for i in range(y, y + sz)))
-    print()
-
 def vflip(tile, sz):
     return ''.join(tile[i:i + sz] for i in range((sz - 1) * sz, -1, -sz))
 
@@ -37,40 +27,53 @@ def fliprot(tile, sz):
             variant = lrotate(variant, sz)
             yield variant
 
-def connections(a, b, sz):
-    if a[:sz] == b[-sz:]: yield BOTTOM
-    if a[::sz] == b[sz - 1::sz]: yield RIGHT
-
-def make_graph(tiles, sz):
-    variants = [(ident, set(fliprot(tile, sz))) for ident, tile in tiles]
-    graph = {(i, v): defaultdict(set) for i, allv in variants for v in allv}
-    for (ia, va), (ib, vb) in permutations(variants, 2):
-        for a, b in product(va, vb):
-            for side in connections(a, b, sz):
-                graph[(ib, b)][side].add((ia, a))
-    return graph
-
-def find_grid(graph, sq):
-    corner = next(i for (i, _), sides in graph.items() if not sides)
-    worklist = deque(((), node) for node in graph if node[0] == corner)
-
-    while worklist:
-        grid, node = worklist.popleft()
-        y, x = divmod(len(grid), sq)
-        if y == 0 or node in graph[grid[-sq]][BOTTOM]:
-            grid = grid + (node,)
-            if x == sq - 1 and y == sq - 1:
-                return grid
-            neighbors = graph[node][RIGHT] if len(grid) % sq \
-                        else graph[grid[-sq]][BOTTOM]
-            for nb in neighbors:
-                worklist.append((grid, nb))
-
-def stitch(grid, sq, sz=10):
-    return ''.join(tile[x + 1:x + sz - 1]
-                   for i in range(0, sq * sq, sq)
-                   for x in range(sz, sz * sz - sz, sz)
-                   for ident, tile in grid[i:i + sq])
+def make_grid(tiles, sz=10):
+    tiles = dict(tiles)
+    placed = {}
+    unavail = set()
+
+    @lru_cache(maxsize=None)
+    def variants(ident):
+        return list(fliprot(tiles[ident], sz))
+
+    def place(ident, tile, x, y):
+        tile_t = placed.get((x, y - 1), (None, None))[1]
+        tile_b = placed.get((x, y + 1), (None, None))[1]
+        tile_r = placed.get((x + 1, y), (None, None))[1]
+        tile_l = placed.get((x - 1, y), (None, None))[1]
+
+        if tile_t and tile_t[-sz:] != tile[:sz] or \
+           tile_b and tile_b[:sz] != tile[-sz:] or \
+           tile_r and tile_r[::sz] != tile[sz - 1::sz] or \
+           tile_l and tile_l[sz - 1::sz] != tile[::sz]:
+            return False
+
+        placed[(x, y)] = ident, tile
+        unavail.add(ident)
+        for nb in ((x, y + 1), (x + 1, y), (x, y - 1), (x - 1, y)):
+            if nb not in placed:
+                for ident in tiles:
+                    if ident not in unavail:
+                        for variant in variants(ident):
+                            if place(ident, variant, *nb):
+                                break
+        return True
+
+    start = next(iter(tiles))
+    place(start, tiles[start], 0, 0)
+
+    minx = min(x for x, y in placed)
+    miny = min(y for x, y in placed)
+    maxx = max(x for x, y in placed) + 1
+    maxy = max(y for x, y in placed) + 1
+    return [[placed[(x, y)] for x in range(minx, maxx)]
+            for y in range(miny, maxy)]
+
+def stitch(grid, sz=10):
+    return ''.join(tile[y + 1:y + sz - 1]
+                   for row in grid
+                   for y in range(sz, sz * sz - sz, sz)
+                   for ident, tile in row)
 
 def findpat(tile, pattern, sz):
     diffs = [y * sz + x
@@ -94,13 +97,6 @@ def roughness(habitat, sz):
     variant, monsters = findpat(habitat, monster, sz)
     return variant.count('#') - len(monsters)
 
-sz, sq = 10, 12
-
-tiles = list(parse(sys.stdin))
-graph = make_graph(tiles, sz)
-corners = set(i for (i, _), sides in graph.items() if not sides)
-print(reduce(lambda a, b: a * b, corners))
-
-grid = find_grid(graph, sq)
-habitat = stitch(grid, sq, sz)
-print(roughness(habitat, sq * (sz - 2)))
+grid = make_grid(parse(sys.stdin))
+print(grid[0][0][0] * grid[0][-1][0] * grid[-1][0][0] * grid[-1][-1][0])
+print(roughness(stitch(grid), len(grid) * 8))