22_cave.py 2.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. #!/usr/bin/env python3
  2. from heapq import heapify, heappush, heappop
  3. depth = 6084
  4. tx, ty = target = 14, 709
  5. #depth = 510
  6. #tx, ty = target = 10, 10
  7. def scan(pad):
  8. def erode(geo_index):
  9. return (geo_index + depth) % 20183
  10. w = tx + pad + 1
  11. h = ty + pad + 1
  12. grid = [erode(x * 16807) for x in range(w)]
  13. for y in range(1, h):
  14. grid.append(erode(y * 48271))
  15. for x in range(1, w):
  16. if x == tx and y == ty:
  17. idx = 0
  18. else:
  19. top = grid[(y - 1) * w + x]
  20. left = grid[y * w + x - 1]
  21. idx = top * left
  22. grid.append(erode(idx))
  23. for i in range(len(grid)):
  24. grid[i] %= 3
  25. return grid, w
  26. def shortest_path(graph, source, target):
  27. dist = len(graph) * [1 << 32]
  28. dist[source] = 0
  29. Q = [(dist[v], v) for v, nb in enumerate(graph) if nb]
  30. heapify(Q)
  31. while Q:
  32. udist, u = heappop(Q)
  33. if udist == dist[u]:
  34. if u == target:
  35. return udist
  36. for v, weight in graph[u]:
  37. alt = udist + weight
  38. if alt < dist[v]:
  39. dist[v] = alt
  40. heappush(Q, (alt, v))
  41. def rescue(grid, w):
  42. # approach:
  43. # - build graph with (x, y, tool) tuples as vertices
  44. # - record weighted edges including tolao witching between vertices where
  45. # switching is allowed
  46. # - do Dijkstra on the result
  47. h = len(grid) // w
  48. def neighbours(i):
  49. y, x = divmod(i, w)
  50. if y > 0: yield i - w
  51. if x > 0: yield i - 1
  52. if x < w - 1: yield i + 1
  53. if y < h - 1: yield i + w
  54. TORCH, GEAR, NEITHER = range(3)
  55. tools = [
  56. (GEAR, TORCH), # rocky
  57. (GEAR, NEITHER), # wet
  58. (TORCH, NEITHER), # narrow
  59. ]
  60. # for efficiency, graph is encoded as 3 concatenated lists of (y * w + x)
  61. # vertices, one for each tool
  62. off = len(grid)
  63. graph = [[] for i in range(3 * off)]
  64. for i, sty in enumerate(grid):
  65. for j in neighbours(i):
  66. tty = grid[j]
  67. for stool in tools[sty]:
  68. for ttool in tools[tty]:
  69. if ttool in tools[sty]:
  70. cost = 1 if ttool == stool else 8
  71. graph[i + off * stool].append((j + off * ttool, cost))
  72. # we start and end with a torch so in the first of 3 lists
  73. return shortest_path(graph, 0, ty * w + tx)
  74. grid, w = scan(15)
  75. # part 1
  76. print(sum(sum(grid[y * w:y * w + tx + 1]) for y in range(ty + 1)))
  77. # part 2
  78. print(rescue(grid, w))