22_cave.py 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. #!/usr/bin/env python3
  2. from heapq import heapify, heappush, heappop
  3. depth = 6084
  4. tx, ty = target = 14, 709
  5. #depth = 510
  6. #tx, ty = target = 10, 10
  7. def scan(pad):
  8. def erode(geo_index):
  9. return (geo_index + depth) % 20183
  10. w = tx + pad + 1
  11. h = ty + pad + 1
  12. grid = [erode(x * 16807) for x in range(w)]
  13. for y in range(1, h):
  14. grid.append(erode(y * 48271))
  15. for x in range(1, w):
  16. if x == tx and y == ty:
  17. idx = 0
  18. else:
  19. top = grid[(y - 1) * w + x]
  20. left = grid[y * w + x - 1]
  21. idx = top * left
  22. grid.append(erode(idx))
  23. for i in range(len(grid)):
  24. grid[i] %= 3
  25. return grid, w
  26. def shortest_path(graph, source, target):
  27. dist = len(graph) * [1 << 32]
  28. dist[source] = 0
  29. Q = [(dist[v], v) for v, nb in enumerate(graph) if nb]
  30. heapify(Q)
  31. while Q:
  32. udist, u = heappop(Q)
  33. if udist > dist[u]:
  34. continue
  35. if u == target:
  36. return udist
  37. for v, weight in graph[u]:
  38. alt = udist + weight
  39. if alt < dist[v]:
  40. dist[v] = alt
  41. heappush(Q, (alt, v))
  42. def rescue(grid, w):
  43. # approach:
  44. # - build graph with (x, y, tool) tuples as vertices
  45. # - record weighted edges including tolao witching between vertices where
  46. # switching is allowed
  47. # - do Dijkstra on the result
  48. h = len(grid) // w
  49. def neighbours(i):
  50. y, x = divmod(i, w)
  51. if y > 0: yield i - w
  52. if x > 0: yield i - 1
  53. if x < w - 1: yield i + 1
  54. if y < h - 1: yield i + w
  55. TORCH, GEAR, NEITHER = range(3)
  56. tools = [
  57. (GEAR, TORCH), # rocky
  58. (GEAR, NEITHER), # wet
  59. (TORCH, NEITHER), # narrow
  60. ]
  61. # for efficiency, graph is encoded as 3 concatenated lists of (y * w + x)
  62. # vertices, one for each tool
  63. off = len(grid)
  64. graph = [[] for i in range(3 * off)]
  65. for i, sty in enumerate(grid):
  66. for j in neighbours(i):
  67. tty = grid[j]
  68. for stool in tools[sty]:
  69. for ttool in tools[tty]:
  70. if ttool in tools[sty]:
  71. cost = 1 if ttool == stool else 8
  72. graph[i + off * stool].append((j + off * ttool, cost))
  73. # we start and end with a torch so in the first of 3 lists
  74. return shortest_path(graph, 0, ty * w + tx)
  75. grid, w = scan(15)
  76. # part 1
  77. print(sum(sum(grid[y * w:y * w + tx + 1]) for y in range(ty + 1)))
  78. # part 2
  79. print(rescue(grid, w))