20_donutmaze.py 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. #!/usr/bin/env python3
  2. import sys
  3. from collections import deque
  4. from itertools import chain
  5. def read_grid(f):
  6. rows = [line.replace('\n', '') for line in f]
  7. return list(chain.from_iterable(rows)), len(rows[0])
  8. def find_labels(grid, w):
  9. def get(i):
  10. return grid[i] if 0 <= i < len(grid) else ' '
  11. def try_label(i, spacediff, outer):
  12. if get(i + spacediff) == '.':
  13. labels.setdefault(label, []).append((i, i + spacediff, outer))
  14. labels = {}
  15. h = len(grid) // w
  16. for i, cell in enumerate(grid):
  17. if cell.isalpha():
  18. y, x = divmod(i, w)
  19. label = get(i) + get(i + 1)
  20. if label.isalpha():
  21. try_label(i, -1, x > w // 2)
  22. try_label(i + 1, 1, x < w // 2)
  23. label = get(i) + get(i + w)
  24. if label.isalpha():
  25. try_label(i, -w, y > h // 2)
  26. try_label(i + w, w, y < h // 2)
  27. return labels
  28. def find_portals(labels):
  29. assert all(len(v) == 2 for v in labels.values())
  30. inner = {}
  31. outer = {}
  32. for (l1, s1, o1), (l2, s2, o2) in labels.values():
  33. assert o1 ^ o2
  34. (outer if o1 else inner)[l1] = s2
  35. (outer if o2 else inner)[l2] = s1
  36. return inner, outer
  37. def shortest_path(grid, w, leveldiff):
  38. labels = find_labels(grid, w)
  39. src = labels.pop('AA')[0][1]
  40. dst = labels.pop('ZZ')[0][1]
  41. inner, outer = find_portals(labels)
  42. work = deque([(src, 0, 0)])
  43. visited = set()
  44. while work:
  45. i, dist, level = work.popleft()
  46. if i == dst and level == 0:
  47. return dist
  48. if (i, level) in visited:
  49. continue
  50. visited.add((i, level))
  51. for nb in (i - 1, i + 1, i - w, i + w):
  52. if grid[nb] == '.':
  53. work.append((nb, dist + 1, level))
  54. elif nb in inner:
  55. work.append((inner[nb], dist + 1, level + leveldiff))
  56. elif nb in outer and level > 0:
  57. work.append((outer[nb], dist + 1, level - leveldiff))
  58. grid, w = read_grid(sys.stdin)
  59. print(shortest_path(grid, w, 0))
  60. print(shortest_path(grid, w, 1))