12_fences.py 1.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  1. #!/usr/bin/env python3
  2. import sys
  3. from itertools import pairwise
  4. def neighbors(x, y):
  5. return ((x - 1, y), (x + 1, y), (x, y - 1), (x, y + 1))
  6. def regions(grid):
  7. h = len(grid)
  8. w = len(grid[0])
  9. visited = set()
  10. def bfs(x, y):
  11. ty = grid[y][x]
  12. visited.add((x, y))
  13. work = [(x, y)]
  14. while work:
  15. x, y = work.pop()
  16. yield x, y
  17. for nb in neighbors(x, y):
  18. nx, ny = nb
  19. if 0 <= nx < w and 0 <= ny < h and \
  20. grid[ny][nx] == ty and nb not in visited:
  21. visited.add(nb)
  22. work.append(nb)
  23. for y in range(h):
  24. for x in range(w):
  25. if (x, y) not in visited:
  26. yield set(bfs(x, y))
  27. def fences(region):
  28. for xy in region:
  29. for nb in neighbors(*xy):
  30. if nb not in region:
  31. yield xy, nb
  32. def sides(region):
  33. indexed = {}
  34. for (x1, y1), (x2, y2) in fences(region):
  35. if x1 == x2:
  36. indexed.setdefault((True, y1, y2), []).append(x1)
  37. else:
  38. indexed.setdefault((False, x1, x2), []).append(y1)
  39. return sum(1 + sum(b - a > 1 for a, b in pairwise(sorted(line)))
  40. for line in indexed.values())
  41. def perimeter(region):
  42. return sum(1 for _ in fences(region))
  43. grid = [line[:-1] for line in sys.stdin]
  44. reg = list(regions(grid))
  45. print(sum(len(r) * perimeter(r) for r in reg))
  46. print(sum(len(r) * sides(r) for r in reg))