# 16_instructions.py (1.9 KB)
  1. #!/usr/bin/env python3
  2. import sys
  3. from collections import defaultdict
  4. isa = {
  5. 'addr': lambda a, b, reg: reg[a] + reg[b],
  6. 'addi': lambda a, b, reg: reg[a] + b,
  7. 'mulr': lambda a, b, reg: reg[a] * reg[b],
  8. 'muli': lambda a, b, reg: reg[a] * b,
  9. 'banr': lambda a, b, reg: reg[a] & reg[b],
  10. 'bani': lambda a, b, reg: reg[a] & b,
  11. 'borr': lambda a, b, reg: reg[a] | reg[b],
  12. 'bori': lambda a, b, reg: reg[a] | b,
  13. 'setr': lambda a, b, reg: reg[a],
  14. 'seti': lambda a, b, reg: a,
  15. 'gtir': lambda a, b, reg: int(a > reg[b]),
  16. 'gtri': lambda a, b, reg: int(reg[a] > b),
  17. 'gtrr': lambda a, b, reg: int(reg[a] > reg[b]),
  18. 'eqir': lambda a, b, reg: int(a == reg[b]),
  19. 'eqri': lambda a, b, reg: int(reg[a] == b),
  20. 'eqrr': lambda a, b, reg: int(reg[a] == reg[b]),
  21. }
  22. def run(exe, a, b, reg):
  23. reg = list(reg)
  24. reg[out] = exe(a, b, reg)
  25. return tuple(reg)
  26. # part 1
  27. effects, program = sys.stdin.read().rstrip().split('\n\n\n\n')
  28. opcodes = [set(isa.keys()) for i in range(len(isa))]
  29. three = 0
  30. for effect in effects.split('\n\n'):
  31. before, inst, after = effect.split('\n')
  32. reg = tuple(map(int, before[9:-1].split(', ')))
  33. opcode, a, b, out = map(int, inst.split())
  34. expect = tuple(map(int, after[9:-1].split(', ')))
  35. mnems = set(mnem for mnem, exe in isa.items() if run(exe, a, b, reg) == expect)
  36. opcodes[opcode] &= mnems
  37. three += int(len(mnems) >= 3)
  38. print(three)
  39. # part 2
  40. while sum(map(len, opcodes)) > len(isa):
  41. for opcode, mnems in enumerate(opcodes):
  42. if len(mnems) == 1:
  43. certain = next(iter(mnems))
  44. for other, mnems in enumerate(opcodes):
  45. if other != opcode:
  46. mnems.discard(certain)
  47. opcodes = [next(iter(mnems)) for mnems in opcodes]
  48. reg = [0, 0, 0, 0]
  49. for inst in program.split('\n'):
  50. opcode, a, b, out = map(int, inst.split())
  51. reg[out] = isa[opcodes[opcode]](a, b, reg)
  52. print(reg[0])