20_jigsaw.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. #!/usr/bin/env python3
  2. import sys
  3. from collections import defaultdict, deque
  4. from functools import reduce
  5. from itertools import permutations, product
  6. BOTTOM, RIGHT = range(2)
  7. def parse(f):
  8. tile = ''
  9. for line in f:
  10. if line.startswith('Tile'):
  11. ident = int(line.split()[1][:-1])
  12. elif line == '\n':
  13. yield ident, tile
  14. tile = ''
  15. else:
  16. tile += line.rstrip()
  17. yield ident, tile
  18. def print_tile(tile, sz, highlight=set()):
  19. for y in range(0, sz * sz, sz):
  20. print(''.join('O' if i in highlight else tile[i]
  21. for i in range(y, y + sz)))
  22. print()
  23. def vflip(tile, sz):
  24. return ''.join(tile[i:i + sz] for i in range((sz - 1) * sz, -1, -sz))
  25. def lrotate(tile, sz):
  26. return ''.join(tile[sz - 1 - i::sz] for i in range(sz))
  27. def fliprot(tile, sz):
  28. for variant in (tile, vflip(tile, sz)):
  29. yield variant
  30. for i in range(3):
  31. variant = lrotate(variant, sz)
  32. yield variant
  33. def connections(a, b, sz):
  34. if a[:sz] == b[-sz:]: yield BOTTOM
  35. if a[::sz] == b[sz - 1::sz]: yield RIGHT
  36. def make_graph(tiles, sz):
  37. variants = [(ident, set(fliprot(tile, sz))) for ident, tile in tiles]
  38. graph = {(i, v): defaultdict(set) for i, allv in variants for v in allv}
  39. for (ia, va), (ib, vb) in permutations(variants, 2):
  40. for a, b in product(va, vb):
  41. for side in connections(a, b, sz):
  42. graph[(ib, b)][side].add((ia, a))
  43. return graph
  44. def find_grid(graph, sq):
  45. corner = next(i for (i, _), sides in graph.items() if not sides)
  46. worklist = deque(((), node) for node in graph if node[0] == corner)
  47. while worklist:
  48. grid, node = worklist.popleft()
  49. y, x = divmod(len(grid), sq)
  50. if y == 0 or node in graph[grid[-sq]][BOTTOM]:
  51. grid = grid + (node,)
  52. if x == sq - 1 and y == sq - 1:
  53. return grid
  54. neighbors = graph[node][RIGHT] if len(grid) % sq \
  55. else graph[grid[-sq]][BOTTOM]
  56. for nb in neighbors:
  57. worklist.append((grid, nb))
  58. def stitch(grid, sq, sz=10):
  59. return ''.join(tile[x + 1:x + sz - 1]
  60. for i in range(0, sq * sq, sq)
  61. for x in range(sz, sz * sz - sz, sz)
  62. for ident, tile in grid[i:i + sq])
  63. def findpat(tile, pattern, sz):
  64. diffs = [y * sz + x
  65. for y, line in enumerate(pattern)
  66. for x, char in enumerate(line)
  67. if char == '#']
  68. found = set()
  69. for variant in fliprot(tile, sz):
  70. for y in range(sz - len(pattern) + 1):
  71. for x in range(sz - len(pattern[0]) + 1):
  72. base = y * sz + x
  73. if all(variant[base + i] == '#' for i in diffs):
  74. found |= {base + i for i in diffs}
  75. if found:
  76. return variant, found
  77. def roughness(habitat, sz):
  78. monster = (' # ',
  79. '# ## ## ###',
  80. ' # # # # # # ')
  81. variant, monsters = findpat(habitat, monster, sz)
  82. return variant.count('#') - len(monsters)
  83. sz, sq = 10, 12
  84. tiles = list(parse(sys.stdin))
  85. graph = make_graph(tiles, sz)
  86. corners = set(i for (i, _), sides in graph.items() if not sides)
  87. print(reduce(lambda a, b: a * b, corners))
  88. grid = find_grid(graph, sq)
  89. habitat = stitch(grid, sq, sz)
  90. print(roughness(habitat, sq * (sz - 2)))