16_valves.py 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. #!/usr/bin/env python3
  2. import sys
  3. from functools import reduce
  4. from heapq import heappop, heappush
  5. from itertools import chain, product, tee
  6. from operator import or_
  7. def parse(line):
  8. _, node, _, _, flow, _, _, _, _, nbs = line.split(' ', 9)
  9. return node, int(flow[5:-1]), nbs.rstrip().split(', ')
  10. def bestcase(graph, opened, time, actors):
  11. # each minute each actor opens the closed valve with the highest flow rate
  12. unopened = sorted((rate for i, rate, _ in graph.values()
  13. if not opened & (1 << i)), reverse=True)
  14. times = chain.from_iterable(zip(*tee(range(time, 0, -1), actors)))
  15. return sum(rate * t for rate, t in zip(unopened, times))
  16. def move(graph, pos, opened, time):
  17. i, rate, nbs = graph[pos]
  18. mask = 1 << i
  19. if rate and not opened & mask:
  20. yield pos, rate * time, mask
  21. elif not nbs:
  22. yield pos, 0, 0
  23. for nb in nbs:
  24. yield nb, 0, 0
  25. def open_valves(graph, time, actors):
  26. work = [(0, time, ('AA',) * actors, 0, 0)]
  27. best = {}
  28. all_opened = reduce(or_, (1 << i for i, rate, _ in graph.values() if rate))
  29. while work:
  30. _, t, pos, opened, released = heappop(work)
  31. if t < 2 or opened == all_opened:
  32. return released
  33. t -= 1
  34. for moves in product(*(move(graph, p, opened, t) for p in pos)):
  35. newpos, newflow, newopen = zip(*moves)
  36. masks = list(filter(None, newopen))
  37. # don't open the same valve with multiple actors
  38. if len(masks) == len(set(masks)):
  39. newreleased = released + sum(newflow)
  40. newopened = opened | reduce(or_, newopen)
  41. potential = newreleased + bestcase(graph, newopened, t, actors)
  42. key = newpos, newopened
  43. if best.get(key, 0) < potential:
  44. best[key] = potential
  45. heappush(work, (-potential, t, newpos, newopened, newreleased))
  46. graph = {valve: (i, rate, nbs)
  47. for i, (valve, rate, nbs) in enumerate(map(parse, sys.stdin))}
  48. print(open_valves(graph, 30, 1))
  49. print(open_valves(graph, 26, 2))