groups.py 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. from itertools import combinations
  2. from .utils import evals_to_numeric
  3. from ..node import ExpressionLeaf as Leaf, Scope, OP_ADD, OP_MUL, nary_node, \
  4. negate
  5. from ..possibilities import Possibility as P, MESSAGES
  6. from ..translate import _
  7. def match_combine_groups(node):
  8. """
  9. Match possible combinations of groups of expressions using non-strict
  10. equivalence.
  11. Examples:
  12. a + a -> 2a
  13. a + 2a -> 3a
  14. ab + ab -> 2ab
  15. ab + 2ab -> 3ab
  16. ab + ba -> 2ab
  17. """
  18. assert node.is_op(OP_ADD)
  19. p = []
  20. groups = []
  21. scope = Scope(node)
  22. for n in scope:
  23. if not n.is_numeric():
  24. groups.append((Leaf(1), n, n, True))
  25. # Each number multiplication yields a group, multiple occurences of
  26. # the same group can be replaced by a single one
  27. if n.is_op(OP_MUL):
  28. n_scope = Scope(n)
  29. l = len(n_scope)
  30. for i, sub_node in enumerate(n_scope):
  31. if evals_to_numeric(sub_node):
  32. others = [n_scope[j] for j in range(i) + range(i + 1, l)]
  33. if len(others) == 1:
  34. g = others[0]
  35. else:
  36. g = nary_node(OP_MUL, others)
  37. groups.append((sub_node, g, n, False))
  38. for (c0, g0, n0, root0), (c1, g1, n1, root1) in combinations(groups, 2):
  39. if not root0:
  40. c0 = c0.negate(n0.negated)
  41. if not root1:
  42. c1 = c1.negate(n1.negated)
  43. if g0.equals(g1):
  44. p.append(P(node, combine_groups, (scope, c0, g0, n0, c1, g1, n1)))
  45. elif g0.equals(g1, ignore_negation=True):
  46. # Move negations to constants
  47. c0 = c0.negate(g0.negated)
  48. c1 = c1.negate(g1.negated)
  49. g0 = negate(g0, 0)
  50. g1 = negate(g1, 0)
  51. p.append(P(node, combine_groups, (scope, c0, g0, n0, c1, g1, n1)))
  52. return p
  53. def combine_groups(root, args):
  54. scope, c0, g0, n0, c1, g1, n1 = args
  55. # Replace the left node with the new expression
  56. scope.replace(n0, (c0 + c1) * g0)
  57. # Remove the right node
  58. scope.remove(n1)
  59. return scope.as_nary_node()
  60. MESSAGES[combine_groups] = \
  61. _('Group "{3}" is multiplied by {2} and {5}, combine them.')