18_vault.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. #!/usr/bin/env python3
  2. import sys
  3. from heapq import heappop, heappush
  4. from itertools import chain, cycle
  5. def read_grid(f):
  6. rows = [line.rstrip() for line in f]
  7. return list(chain.from_iterable(rows)), len(rows[0])
  8. def build_graph(grid, w):
  9. def dfs(loc, prev, dist):
  10. nblocs = (loc - 1, loc + 1, loc - w, loc + w)
  11. neighbors = [nb for nb in nblocs if grid[nb] != '#']
  12. node = grid[loc]
  13. if node == '@' or node.isalpha() or len(neighbors) > 2:
  14. if node in doubles:
  15. seen = doubles[node]
  16. count = seen.setdefault(loc, len(seen))
  17. node += str(count)
  18. if prev is None:
  19. graph[node] = {}
  20. else:
  21. graph[prev][node] = dist
  22. graph.setdefault(node, {})[prev] = dist
  23. prev = node
  24. dist = 0
  25. if loc not in visited:
  26. visited.add(loc)
  27. for nb in neighbors:
  28. dfs(nb, prev, dist + 1)
  29. doubles = {'@': {}, '.': {}}
  30. graph = {}
  31. visited = set()
  32. for i, node in enumerate(grid):
  33. if node == '@':
  34. dfs(i, None, 0)
  35. return graph
  36. def collect_keys_alone(graph):
  37. totalkeys = sum(node.islower() for node in graph)
  38. inf = 10000000000
  39. work = [(0, totalkeys, '@0', '')]
  40. best_dists = {('@0', ''): 0}
  41. def visit(node, dist, keys, remkeys):
  42. ident = node, keys
  43. if best_dists.get(ident, inf) > dist:
  44. best_dists[ident] = dist
  45. heappush(work, (dist, remkeys, node, keys))
  46. while work:
  47. dist, remkeys, node, keys = heappop(work)
  48. if remkeys == 0:
  49. return dist
  50. for nb, step in graph[node].items():
  51. if nb.islower() and nb not in keys:
  52. nbkeys = ''.join(sorted(keys + nb))
  53. visit(nb, dist + step, nbkeys, remkeys - 1)
  54. elif not nb.isupper() or nb.lower() in keys:
  55. visit(nb, dist + step, keys, remkeys)
  56. def split_grid(grid, w):
  57. e = grid.index('@')
  58. grid[e - w - 1:e - w + 2] = '@#@'
  59. grid[e - 1:e + 2] = '###'
  60. grid[e + w - 1:e + w + 2] = '@#@'
def collect_from(graph, root, remkeys, keys, dist):
    """Generator for one robot's key search that can be resumed later.

    Runs Dijkstra over (node, collected-keys) states starting at *root*,
    using the same key/door rules as collect_keys_alone: lowercase names are
    keys, uppercase names are doors that need the matching key, and anything
    else is passable.

    Args:
        graph: dict of node name -> {neighbour name: distance}.
        root: node this robot starts from.
        remkeys: number of keys still missing.
        keys: sorted string of keys already held.
        dist: total distance travelled so far (by all robots).

    Yields:
        (done, keys, dist) tuples.
        - (True, keys, dist) once a state with no remaining keys is popped;
          the generator then finishes.
        - (False, keys, dist) when this robot's search is exhausted.  It
          reports the best state reached: fewest remaining keys, then
          smallest distance.  The caller must then ``send((newkeys,
          newdist))`` with the key set and total distance after the other
          robots have moved.  The search resumes from the best state's node
          with that key set.

    NOTE(review): only the single best partial state is handed on, so the
    combined result across robots looks greedy rather than provably optimal.
    Confirm this is acceptable for the inputs in use.
    """
    inf = 100000000
    work = [(dist, remkeys, root, keys)]
    best_dists = {(root, keys): dist}
    # Push a state only if it improves on the best distance seen for it.
    def visit(node, dist, keys, remkeys):
        ident = node, keys
        if best_dists.get(ident, inf) > dist:
            best_dists[ident] = dist
            heappush(work, (dist, remkeys, node, keys))
    # Best partial state so far. It is never reset, so it persists across
    # resumptions and is only replaced by a state with fewer remaining keys,
    # or equal keys and a smaller distance.
    best = dist, remkeys, root, keys
    while True:
        while work:
            dist, remkeys, node, keys = heappop(work)
            if remkeys == 0:
                yield True, keys, dist
                return
            if remkeys < best[1] or (remkeys == best[1] and dist < best[0]):
                best = dist, remkeys, node, keys
            for nb, step in graph[node].items():
                if nb.islower() and nb not in keys:
                    # New key: keep the key string sorted so equal sets match.
                    nbkeys = ''.join(sorted(keys + nb))
                    visit(nb, dist + step, nbkeys, remkeys - 1)
                elif not nb.isupper() or nb.lower() in keys:
                    # Non-door node, already-held key, or a door we can open.
                    visit(nb, dist + step, keys, remkeys)
        # Search exhausted: report the best state and wait for the other
        # robots' progress (the caller sends the new keys and total distance).
        dist, remkeys, node, keys = best
        newkeys, newdist = yield False, keys, dist
        # Keys gained elsewhere reduce this robot's remaining count.
        newremkeys = remkeys - (len(newkeys) - len(keys))
        # The inner loop only exits when work is empty, so a plain append
        # keeps the one-element list a valid heap.
        work.append((newdist, newremkeys, node, newkeys))
  89. def collect_keys(graph):
  90. entrances = tuple(node for node in graph if node[0] == '@')
  91. totalkeys = sum(node.islower() for node in graph)
  92. bots = []
  93. keys = ''
  94. dist = 0
  95. for entrance in entrances:
  96. bot = collect_from(graph, entrance, totalkeys - len(keys), keys, dist)
  97. bots.append(bot)
  98. done, keys, dist = next(bot)
  99. if done:
  100. return dist
  101. for bot in cycle(bots):
  102. done, keys, dist = bot.send((keys, dist))
  103. if done:
  104. return dist
  105. # part 1
  106. grid, w = read_grid(sys.stdin)
  107. graph = build_graph(grid, w)
  108. print(collect_keys(graph))
  109. # part 2
  110. split_grid(grid, w)
  111. graph = build_graph(grid, w)
  112. print(collect_keys(graph))