factors.py 2.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. # This file is part of TRS (http://math.kompiler.org)
  2. #
  3. # TRS is free software: you can redistribute it and/or modify it under the
  4. # terms of the GNU Affero General Public License as published by the Free
  5. # Software Foundation, either version 3 of the License, or (at your option) any
  6. # later version.
  7. #
  8. # TRS is distributed in the hope that it will be useful, but WITHOUT ANY
  9. # WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
  10. # A PARTICULAR PURPOSE. See the GNU Affero General Public License for more
  11. # details.
  12. #
  13. # You should have received a copy of the GNU Affero General Public License
  14. # along with TRS. If not, see <http://www.gnu.org/licenses/>.
  15. from itertools import product
  16. from .utils import is_numeric_node
  17. from ..node import ExpressionNode as N, Scope, OP_ADD, OP_MUL
  18. from ..possibilities import Possibility as P, MESSAGES
  19. from ..translate import _
  20. def is_expandable(node):
  21. """
  22. Check if a node is expandable. Only additions that consist of not only
  23. numerics can be expanded.
  24. """
  25. return node.is_op(OP_ADD) \
  26. and not all(map(is_numeric_node, Scope(node)))
  27. def match_expand(node):
  28. """
  29. Expand multiplication of non-numeric additions.
  30. Examples:
  31. (a + b)(c + d) -> ac + ad + bc + bd
  32. (b + c)a -> ab + ac
  33. a(b + c) -> ab + ac
  34. """
  35. assert node.is_op(OP_MUL)
  36. p = []
  37. scope = Scope(node)
  38. l = len(scope)
  39. for distance in range(1, l):
  40. for i, left in enumerate(scope[:-distance]):
  41. right = scope[i + distance]
  42. l_expandable = is_expandable(left)
  43. r_expandable = is_expandable(right)
  44. if l_expandable and r_expandable:
  45. p.append(P(node, expand_double, (scope, left, right)))
  46. elif l_expandable ^ r_expandable:
  47. p.append(P(node, expand_single, (scope, left, right)))
  48. return p
  49. def expand(root, args):
  50. """
  51. (a + b)(c + d) -> ac + ad + bc + bd
  52. (a + b)c -> ac + bc
  53. a(b + c) -> ab + ac
  54. etc..
  55. """
  56. scope, left, right = args
  57. left_scope = Scope(left) if left.is_op(OP_ADD) else [left]
  58. right_scope = Scope(right) if right.is_op(OP_ADD) else [right]
  59. add_scope = [l * r for l, r in product(left_scope, right_scope)]
  60. add = Scope(N(OP_ADD, *add_scope)).as_nary_node()
  61. add.negated = left.negated + right.negated
  62. scope.replace(left, add)
  63. scope.remove(right)
  64. return scope.as_nary_node()
  65. def expand_double(root, args):
  66. return expand(root, args)
  67. MESSAGES[expand_double] = _('Expand ({2})({3}).')
  68. def expand_single(root, args):
  69. return expand(root, args)
  70. MESSAGES[expand_single] = _('Expand ({2})({3}).')