Commit dedcf309 authored by Taddeus Kroes's avatar Taddeus Kroes

Fixed nominator term extraction rule for single-term nominators.

parent c9d72b47
from itertools import combinations, product from itertools import combinations, product, ifilterfalse
import copy import copy
from .utils import least_common_multiple, partition, is_numeric_node, \ from .utils import least_common_multiple, partition, is_numeric_node, \
...@@ -289,7 +289,7 @@ MESSAGES[divide_by_fraction] = \ ...@@ -289,7 +289,7 @@ MESSAGES[divide_by_fraction] = \
_('Move {3} to nominator of fraction {1} / {2}.') _('Move {3} to nominator of fraction {1} / {2}.')
def is_power_combination(a, b): def is_power_combination(pair):
""" """
Check if two nodes are powers that can be combined in a fraction, for Check if two nodes are powers that can be combined in a fraction, for
example: example:
...@@ -298,6 +298,8 @@ def is_power_combination(a, b): ...@@ -298,6 +298,8 @@ def is_power_combination(a, b):
a^2 and a^2 a^2 and a^2
a^2 and a a^2 and a
""" """
a, b = pair
if a.is_power(): if a.is_power():
a = a[0] a = a[0]
...@@ -343,24 +345,22 @@ def match_extract_fraction_terms(node): ...@@ -343,24 +345,22 @@ def match_extract_fraction_terms(node):
n_scope, d_scope = map(mult_scope, node) n_scope, d_scope = map(mult_scope, node)
p = [] p = []
if len(n_scope) == 1 and len(d_scope) == 1:
return p
nominator, denominator = node nominator, denominator = node
for n in n_scope: # ac / b
# ac / b for n in ifilterfalse(evals_to_numeric, n_scope):
if not evals_to_numeric(n): a_scope = mult_scope(nominator)
a_scope = mult_scope(nominator) a = remove_from_mult_scope(a_scope, n)
a = remove_from_mult_scope(a_scope, n)
if evals_to_numeric(a / denominator): if evals_to_numeric(a / denominator):
p.append(P(node, extract_nominator_term, (a, n))) p.append(P(node, extract_nominator_term, (a, n)))
# a ^ b * c / (a ^ d * e) if len(n_scope) == 1 and len(d_scope) == 1:
for d in [d for d in d_scope if is_power_combination(n, d)]: return p
p.append(P(node, extract_fraction_terms, (n_scope, d_scope, n, d)))
# a ^ b * c / (a ^ d * e)
for n, d in filter(is_power_combination, product(n_scope, d_scope)):
p.append(P(node, extract_fraction_terms, (n_scope, d_scope, n, d)))
return p return p
...@@ -374,6 +374,10 @@ def extract_nominator_term(root, args): ...@@ -374,6 +374,10 @@ def extract_nominator_term(root, args):
return a / root[1] * c return a / root[1] * c
MESSAGES[extract_nominator_term] = \
_('Extract {2} from the nominator of fraction {0}.')
def extract_fraction_terms(root, args): def extract_fraction_terms(root, args):
""" """
ab / a -> a / a * (b / 1) ab / a -> a / a * (b / 1)
......
...@@ -238,19 +238,26 @@ class TestRulesFractions(RulesTestCase): ...@@ -238,19 +238,26 @@ class TestRulesFractions(RulesTestCase):
self.assertEqualPos(match_extract_fraction_terms(root), self.assertEqualPos(match_extract_fraction_terms(root),
[P(root, extract_nominator_term, (2, a))]) [P(root, extract_nominator_term, (2, a))])
a, l3 = n, d = root = tree('a / 3')
self.assertEqualPos(match_extract_fraction_terms(root),
[P(root, extract_nominator_term, (1, a))])
root = tree('2*4 / 3') root = tree('2*4 / 3')
self.assertEqualPos(match_extract_fraction_terms(root), []) self.assertEqualPos(match_extract_fraction_terms(root), [])
n, d = root = tree('2a / 2') n, d = root = tree('2a / 2')
self.assertEqualPos(match_extract_fraction_terms(root), self.assertEqualPos(match_extract_fraction_terms(root),
[P(root, extract_fraction_terms, (Scope(n), lscp(d), 2, 2)), [P(root, extract_nominator_term, (2, a)),
P(root, extract_nominator_term, (2, a))]) P(root, extract_fraction_terms, (Scope(n), lscp(d), 2, 2))])
def test_extract_nominator_term(self): def test_extract_nominator_term(self):
root, expect = tree('2a / 3, 2 / 3 * a') root, expect = tree('2a / 3, 2 / 3 * a')
l2, a = root[0] l2, a = root[0]
self.assertEqual(extract_nominator_term(root, (l2, a)), expect) self.assertEqual(extract_nominator_term(root, (l2, a)), expect)
root, expect, l1 = tree('a / 3, 1 / 3 * a, 1')
self.assertEqual(extract_nominator_term(root, (l1, root[0])), expect)
def test_extract_fraction_terms_basic(self): def test_extract_fraction_terms_basic(self):
root, expect = tree('ab / (ca), a / a * (b / c)') root, expect = tree('ab / (ca), a / a * (b / c)')
n, d = root n, d = root
......
...@@ -75,8 +75,7 @@ class TestRulesLineq(RulesTestCase): ...@@ -75,8 +75,7 @@ class TestRulesLineq(RulesTestCase):
'5x = 0 - 5', '5x = 0 - 5',
'5x = -5', '5x = -5',
'5x / 5 = (-5) / 5', '5x / 5 = (-5) / 5',
'5 / 5 * (x / 1) = (-5) / 5', '5 / 5 * x = (-5) / 5',
'1(x / 1) = (-5) / 5',
'1x = (-5) / 5', '1x = (-5) / 5',
'x = (-5) / 5', 'x = (-5) / 5',
'x = -5 / 5', 'x = -5 / 5',
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment