factors.py 1.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. from itertools import product, combinations
  2. from ..node import Scope, OP_ADD, OP_MUL, OP_NEG
  3. from ..possibilities import Possibility as P, MESSAGES
  4. from ..translate import _
  5. def match_expand(node):
  6. """
  7. a * (b + c) -> ab + ac
  8. (b + c) * a -> ab + ac
  9. (a + b) * (c + d) -> ac + ad + bc + bd
  10. """
  11. assert node.is_op(OP_MUL)
  12. p = []
  13. leaves = []
  14. additions = []
  15. for n in Scope(node):
  16. if n.is_leaf() or n.is_op(OP_NEG) and n[0].is_leaf():
  17. leaves.append(n)
  18. elif n.op == OP_ADD:
  19. additions.append(n)
  20. for args in product(leaves, additions):
  21. p.append(P(node, expand_single, args))
  22. for args in combinations(additions, 2):
  23. p.append(P(node, expand_double, args))
  24. return p
  25. def expand_single(root, args):
  26. """
  27. Combine a leaf (a) multiplied with an addition of two expressions
  28. (b + c) to an addition of two multiplications.
  29. a * (b + c) -> ab + ac
  30. (b + c) * a -> ab + ac
  31. """
  32. a, bc = args
  33. b, c = bc
  34. scope = Scope(root)
  35. # Replace 'a' with the new expression
  36. scope.remove(a, a * b + a * c)
  37. # Remove the addition
  38. scope.remove(bc)
  39. return scope.as_nary_node()
  40. MESSAGES[expand_single] = _('Expand {1}({2}).')
  41. def expand_double(root, args):
  42. """
  43. Rewrite two multiplied additions to an addition of four multiplications.
  44. (a + b) * (c + d) -> ac + ad + bc + bd
  45. """
  46. (a, b), (c, d) = ab, cd = args
  47. scope = Scope(root)
  48. # Replace 'a + b' with the new expression
  49. scope.remove(ab, a * c + a * d + b * c + b * d)
  50. # Remove the right addition
  51. scope.remove(cd)
  52. return scope.as_nary_node()
  53. MESSAGES[expand_double] = _('Expand ({1})({2}).')