factors.py 1.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. from itertools import product, combinations
  2. from .utils import nary_node
  3. from ..node import OP_ADD, OP_MUL
  4. from ..possibilities import Possibility as P, MESSAGES
  5. def match_expand(node):
  6. """
  7. a * (b + c) -> ab + ac
  8. (b + c) * a -> ab + ac
  9. (a + b) * (c + d) -> ac + ad + bc + bd
  10. """
  11. assert node.is_op(OP_MUL)
  12. p = []
  13. leaves = []
  14. additions = []
  15. for n in node.get_scope():
  16. if n.is_leaf():
  17. leaves.append(n)
  18. elif n.op == OP_ADD:
  19. additions.append(n)
  20. for args in product(leaves, additions):
  21. p.append(P(node, expand_single, args))
  22. for args in combinations(additions, 2):
  23. p.append(P(node, expand_double, args))
  24. return p
  25. def expand_single(root, args):
  26. """
  27. Combine a leaf (a) multiplied with an addition of two expressions
  28. (b + c) to an addition of two multiplications.
  29. a * (b + c) -> ab + ac
  30. (b + c) * a -> ab + ac
  31. """
  32. a, bc = args
  33. b, c = bc
  34. scope = root.get_scope()
  35. # Replace 'a' with the new expression
  36. scope[scope.index(a)] = a * b + a * c
  37. # Remove the addition
  38. scope.remove(bc)
  39. return nary_node('*', scope)
  40. def expand_double(root, args):
  41. """
  42. Rewrite two multiplied additions to an addition of four multiplications.
  43. (a + b) * (c + d) -> ac + ad + bc + bd
  44. """
  45. (a, b), (c, d) = ab, cd = args
  46. scope = root.get_scope()
  47. # Replace 'b + c' with the new expression
  48. scope[scope.index(ab)] = a * c + a * d + b * c + b * d
  49. # Remove the right addition
  50. scope.remove(cd)
  51. return nary_node('*', scope)