07_towers.py 1.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. #!/usr/bin/env python3
  2. import sys
  3. from collections import Counter
  4. class Program:
  5. def __init__(self, name, weight, children):
  6. self.name = name
  7. self.weight = weight
  8. self.children = children
  9. self.accweight = None
  10. @classmethod
  11. def parse(cls, f):
  12. parents = {}
  13. progs = {}
  14. for line in f:
  15. if '->' in line:
  16. left, right = line.rstrip().split(' -> ')
  17. children = right.split(', ')
  18. else:
  19. left = line.rstrip()
  20. children = []
  21. name, weight = left[:-1].split(' (')
  22. prog = cls(name, int(weight), children)
  23. for child in children:
  24. parents[child] = name
  25. progs[name] = prog
  26. for p in progs.values():
  27. p.children = [progs[child] for child in p.children]
  28. return next(p for name, p in progs.items() if name not in parents)
  29. def postorder(self):
  30. for child in self.children:
  31. yield from child.postorder()
  32. yield self
  33. def tower_weight(self):
  34. if self.accweight is None:
  35. self.accweight = self.weight + sum(c.tower_weight()
  36. for c in self.children)
  37. return self.accweight
  38. def find_unbalanced(self):
  39. for prog in self.postorder():
  40. weights = Counter(c.tower_weight() for c in prog.children)
  41. if len(weights) > 1:
  42. assert len(weights) == 2
  43. unba, ba = (p[1] for p in sorted((n, w) for w, n in weights.items()))
  44. u = next(c for c in prog.children if c.tower_weight() == unba)
  45. return u.weight + ba - unba
  46. root = Program.parse(sys.stdin)
  47. print(root.name)
  48. print(root.find_unbalanced())