18_vault.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. #!/usr/bin/env python3
  2. import sys
  3. from heapq import heappop, heappush
  4. from itertools import chain, cycle
  5. def read_grid(f):
  6. rows = [line.rstrip() for line in f]
  7. return list(chain.from_iterable(rows)), len(rows[0])
  8. def build_graph(grid, w):
  9. def dfs(loc, prev, dist):
  10. nblocs = (loc - 1, loc + 1, loc - w, loc + w)
  11. neighbors = [nb for nb in nblocs if grid[nb] != '#']
  12. node = grid[loc]
  13. if node == '@' or node.isalpha() or len(neighbors) > 2:
  14. if node in doubles:
  15. seen = doubles[node]
  16. count = seen.setdefault(loc, len(seen))
  17. node += str(count)
  18. if prev is None:
  19. graph[node] = {}
  20. else:
  21. graph[prev][node] = dist
  22. graph.setdefault(node, {})[prev] = dist
  23. prev = node
  24. dist = 0
  25. if loc not in visited:
  26. visited.add(loc)
  27. for nb in neighbors:
  28. dfs(nb, prev, dist + 1)
  29. doubles = {'@': {}, '.': {}}
  30. graph = {}
  31. visited = set()
  32. for i, node in enumerate(grid):
  33. if node == '@':
  34. dfs(i, None, 0)
  35. return graph
  36. def split_grid(grid, w):
  37. e = grid.index('@')
  38. grid[e - w - 1:e - w + 2] = '@#@'
  39. grid[e - 1:e + 2] = '###'
  40. grid[e + w - 1:e + w + 2] = '@#@'
def collect_from(graph, root, remkeys, keys, dist):
    """Coroutine that collects keys for one bot starting at ``root``.

    Runs Dijkstra over (node, held keys) states; the heap is ordered by
    the tuple (dist, remkeys, node, keys). Stepping onto a lowercase key
    node that is not yet held adds that key. An uppercase door node is
    passable only when its lowercase key is held. Every other node is
    always passable.

    Protocol with the caller:
      * If a popped state has ``remkeys == 0``, it yields ``(keys, dist)``
        once and the generator finishes.
      * If the queue runs dry first, it yields ``(keys, dist)`` for the best
        state seen: fewest remaining keys, with ties going to the smaller
        dist. The caller must then ``send((newkeys, newdist))``, and the
        search resumes from that best state's node with ``newkeys`` at
        ``newdist``. ``newremkeys`` is derived from the change in key-string
        length, so ``newkeys`` is presumably expected to be a superset of
        ``keys`` -- verify against callers.

    :param graph: node -> {neighbour: distance}, as produced by build_graph.
    :param root: name of the node to start from.
    :param remkeys: number of keys not yet held.
    :param keys: sorted string of keys already held.
    :param dist: step count to start from.
    """
    # Sentinel for "no distance known"; assumes real distances stay below it.
    inf = 100000000
    work = [(dist, remkeys, root, keys)]
    # Shortest dist found so far for each (node, keys) state.
    best_dists = {(root, keys): dist}
    def visit(node, dist, keys, remkeys):
        # Queue the state only if it improves on the best dist known for it.
        ident = node, keys
        if best_dists.get(ident, inf) > dist:
            best_dists[ident] = dist
            heappush(work, (dist, remkeys, node, keys))
    # Most progress seen so far: fewest remkeys, then smallest dist.
    best = dist, remkeys, root, keys
    while True:
        while work:
            dist, remkeys, node, keys = heappop(work)
            if remkeys == 0:
                yield keys, dist
                return
            if remkeys < best[1] or (remkeys == best[1] and dist < best[0]):
                best = dist, remkeys, node, keys
            for nb, step in graph[node].items():
                # Names like '@0' / '.0' contain no cased letter, so
                # islower()/isupper() are False for them.
                if nb.islower() and nb not in keys:
                    # New key: keep the key string sorted so equal sets
                    # compare equal as state identifiers.
                    nbkeys = ''.join(sorted(keys + nb))
                    visit(nb, dist + step, nbkeys, remkeys - 1)
                elif not nb.isupper() or nb.lower() in keys:
                    # Non-door node, or a door whose key is held.
                    visit(nb, dist + step, keys, remkeys)
        # Stuck: report the best state and wait for keys gathered elsewhere.
        dist, remkeys, node, keys = best
        newkeys, newdist = yield keys, dist
        newremkeys = remkeys - (len(newkeys) - len(keys))
        # work is empty here, so a plain append keeps the heap invariant.
        # NOTE(review): best_dists is not updated for this resumed state.
        work.append((newdist, newremkeys, node, newkeys))
  69. def collect_keys(graph):
  70. entrances = [node for node in graph if node[0] == '@']
  71. nkeys = sum(node.islower() for node in graph)
  72. bots = []
  73. keys = ''
  74. dist = 0
  75. for entrance in entrances:
  76. bot = collect_from(graph, entrance, nkeys - len(keys), keys, dist)
  77. bots.append(bot)
  78. keys, dist = next(bot)
  79. if len(keys) == nkeys:
  80. return dist
  81. for bot in cycle(bots):
  82. keys, dist = bot.send((keys, dist))
  83. if len(keys) == nkeys:
  84. return dist
  85. # part 1
  86. grid, w = read_grid(sys.stdin)
  87. print(collect_keys(build_graph(grid, w)))
  88. # part 2
  89. split_grid(grid, w)
  90. print(collect_keys(build_graph(grid, w)))