07_amplifiers.py 1.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. #!/usr/bin/env python3
  2. import sys
  3. from itertools import permutations
  4. def run(p, inputs):
  5. pc = 0
  6. while p[pc] != 99:
  7. modes, opcode = divmod(p[pc], 100)
  8. def modeswitch(offset):
  9. value = p[pc + offset]
  10. mode = modes // (10 ** (offset - 1)) % 10
  11. return value if mode else p[value]
  12. if opcode in (1, 2):
  13. a = modeswitch(1)
  14. b = modeswitch(2)
  15. out = p[pc + 3]
  16. p[out] = a + b if opcode == 1 else a * b
  17. pc += 4
  18. elif opcode == 3:
  19. address = p[pc + 1]
  20. p[address] = inputs.pop()
  21. pc += 2
  22. elif opcode == 4:
  23. inputs = yield modeswitch(1)
  24. pc += 2
  25. elif opcode == 5:
  26. pc = modeswitch(2) if modeswitch(1) else pc + 3
  27. elif opcode == 6:
  28. pc = modeswitch(2) if not modeswitch(1) else pc + 3
  29. elif opcode in (7, 8):
  30. a = modeswitch(1)
  31. b = modeswitch(2)
  32. out = p[pc + 3]
  33. p[out] = int(a < b if opcode == 7 else a == b)
  34. pc += 4
  35. def amplify(p, phases):
  36. amps = []
  37. signal = 0
  38. for phase in phases:
  39. amp = run(list(p), [signal, phase])
  40. amps.append(amp)
  41. signal = next(amp)
  42. try:
  43. while True:
  44. for amp in amps:
  45. signal = amp.send([signal])
  46. except StopIteration:
  47. return signal
  48. program = list(map(int, sys.stdin.read().split(',')))
  49. print(max(amplify(program, phases) for phases in permutations(range(5))))
  50. print(max(amplify(program, phases) for phases in permutations(range(5, 10))))