16_opcodes.py 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. #!/usr/bin/env python3
  2. import sys
  3. isa = {
  4. 'addr': lambda a, b, reg: reg[a] + reg[b],
  5. 'addi': lambda a, b, reg: reg[a] + b,
  6. 'mulr': lambda a, b, reg: reg[a] * reg[b],
  7. 'muli': lambda a, b, reg: reg[a] * b,
  8. 'banr': lambda a, b, reg: reg[a] & reg[b],
  9. 'bani': lambda a, b, reg: reg[a] & b,
  10. 'borr': lambda a, b, reg: reg[a] | reg[b],
  11. 'bori': lambda a, b, reg: reg[a] | b,
  12. 'setr': lambda a, b, reg: reg[a],
  13. 'seti': lambda a, b, reg: a,
  14. 'gtir': lambda a, b, reg: int(a > reg[b]),
  15. 'gtri': lambda a, b, reg: int(reg[a] > b),
  16. 'gtrr': lambda a, b, reg: int(reg[a] > reg[b]),
  17. 'eqir': lambda a, b, reg: int(a == reg[b]),
  18. 'eqri': lambda a, b, reg: int(reg[a] == b),
  19. 'eqrr': lambda a, b, reg: int(reg[a] == reg[b]),
  20. }
  21. def run(exe, a, b, reg):
  22. reg = list(reg)
  23. reg[out] = exe(a, b, reg)
  24. return tuple(reg)
  25. # part 1
  26. effects, program = sys.stdin.read().rstrip().split('\n\n\n\n')
  27. opcodes = [set(isa.keys()) for i in range(len(isa))]
  28. three = 0
  29. for effect in effects.split('\n\n'):
  30. before, inst, after = effect.split('\n')
  31. reg = tuple(map(int, before[9:-1].split(', ')))
  32. opcode, a, b, out = map(int, inst.split())
  33. expect = tuple(map(int, after[9:-1].split(', ')))
  34. mnems = set(m for m, exe in isa.items() if run(exe, a, b, reg) == expect)
  35. opcodes[opcode] &= mnems
  36. three += int(len(mnems) >= 3)
  37. print(three)
  38. # part 2
  39. while sum(map(len, opcodes)) > len(isa):
  40. for opcode, mnems in enumerate(opcodes):
  41. if len(mnems) == 1:
  42. certain = next(iter(mnems))
  43. for other, mnems in enumerate(opcodes):
  44. if other != opcode:
  45. mnems.discard(certain)
  46. opcodes = [mnems.pop() for mnems in opcodes]
  47. reg = [0, 0, 0, 0]
  48. for inst in program.split('\n'):
  49. opcode, a, b, out = map(int, inst.split())
  50. reg[out] = isa[opcodes[opcode]](a, b, reg)
  51. print(reg[0])