Explorar el Código

Implemented some basis rewrite rules, along with unit tests for node.py.

Taddeus Kroes hace 14 años
padre
commit
513c04f37f
Se han modificado 5 ficheros con 326 adiciones y 46 borrados
  1. 104 9
      src/node.py
  2. 5 0
      src/possibilities.py
  3. 92 37
      src/rules.py
  4. 86 0
      tests/test_node.py
  5. 39 0
      tests/test_rules.py

+ 104 - 9
src/node.py

@@ -8,27 +8,110 @@ from graph_drawing.line import generate_line
 from graph_drawing.node import Node, Leaf
 
 
-#NODE_TYPE = 0
-#NODE_
+TYPE_OPERATOR = 1
+TYPE_IDENTIFIER = 2
+TYPE_INTEGER = 4
+TYPE_FLOAT = 8
+TYPE_NUMERIC = TYPE_INTEGER | TYPE_FLOAT
+
+
+# Unary
+OP_NEG = 1
+
+# Binary
+OP_ADD = 2
+OP_SUB = 3
+OP_MUL = 4
+OP_DIV = 5
+OP_POW = 6
+OP_MOD = 7
+
+# N-ary (functions)
+OP_INT = 8
+OP_EXPAND = 9
+
+
+TYPE_MAP = {
+        int: TYPE_INTEGER,
+        float: TYPE_FLOAT,
+        str: TYPE_IDENTIFIER,
+        }
+
+
+OPT_MAP = {
+        '+': OP_ADD,
+        '-': OP_SUB,
+        '*': OP_MUL,
+        '/': OP_DIV,
+        '^': OP_POW,
+        'mod': OP_MOD,
+        'int': OP_INT,
+        'expand': OP_EXPAND,
+        }
+
 
 class ExpressionNode(Node):
     def __init__(self, *args, **kwargs):
         super(ExpressionNode, self).__init__(*args, **kwargs)
-        #self.type = NODE_TYPE
+        self.type = TYPE_OPERATOR
+        self.opt = OPT_MAP[args[0]]
 
-    def __str__(self):
+    def __str__(self):  # pragma: nocover
         return generate_line(self)
 
+    def graph(self):  # pragma: nocover
+        return generate_graph(self)
+
     def replace(self, node):
         pos = self.parent.nodes.index(self)
         self.parent.nodes[pos] = node
         node.parent = self.parent
         self.parent = None
 
-    def graph(self):
-        return generate_graph(self)
+    def is_power(self):
+        return self.opt == OP_POW
+
+    def is_nary(self):
+        return self.opt in [OP_ADD, OP_SUB, OP_MUL]
+
+    def get_order(self):
+        if self.is_power() and self[0].is_identifier() \
+                and isinstance(self[1], Leaf):
+            return (self[0].value, self[1].value, 1)
+
+        for n0, n1 in [(0, 1), (1, 0)]:
+            if self[n0].is_numeric() and not isinstance(self[n1], Leaf) \
+                    and self[n1].is_power():
+                coeff, power = self
+
+                if power[0].is_identifier() and isinstance(power[1], Leaf):
+                    return (power[0].value, power[1].value, coeff.value)
+
+    def get_scope(self):
+        scope = []
+
+        for child in self:
+            if not isinstance(child, Leaf) and child.opt == self.opt:
+                scope += child.get_scope()
+            else:
+                scope.append(child)
+
+        return scope
+
 
 class ExpressionLeaf(Leaf):
+    def __init__(self, *args, **kwargs):
+        super(ExpressionLeaf, self).__init__(*args, **kwargs)
+
+        for data_type, type_repr in TYPE_MAP.iteritems():
+            if isinstance(args[0], data_type):
+                self.type = type_repr
+                break
+
+    def get_order(self):
+        if self.is_identifier():
+            return (self.value, 1, 1)
+
     def replace(self, node):
         if not hasattr(self, 'parent'):
             return
@@ -38,8 +121,20 @@ class ExpressionLeaf(Leaf):
         node.parent = self.parent
         self.parent = None
 
+    def is_identifier(self):
+        return self.type & TYPE_IDENTIFIER
+
+    def is_int(self):
+        return self.type & TYPE_INTEGER
+
+    def is_float(self):
+        return self.type & TYPE_FLOAT
+
+    def is_numeric(self):
+        return self.type & TYPE_NUMERIC
+
 
-if __name__ == '__main__':
+if __name__ == '__main__':  # pragma: nocover
     l0 = ExpressionLeaf(3)
     l1 = ExpressionLeaf(4)
     l2 = ExpressionLeaf(5)
@@ -67,8 +162,8 @@ if __name__ == '__main__':
         return res
 
     possibilities = [
-            (n0, lambda (x,y): ExpressionLeaf(x.value + y.value)),
-            (n1, lambda (x,y): ExpressionLeaf(x.value + y.value)),
+            (n0, lambda (x, y): ExpressionLeaf(x.value + y.value)),
+            (n1, lambda (x, y): ExpressionLeaf(x.value + y.value)),
             (n2, rewrite_multiply),
             ]
 

+ 5 - 0
src/possibilities.py

@@ -0,0 +1,5 @@
+class Possibility(object):
+    def __init__(self, root, handler, args):
+        self.root = root
+        self.handler = handler
+        self.args = args

+ 92 - 37
src/rules.py

@@ -1,55 +1,110 @@
-from node import ExpressionLeaf as Leaf
+from itertools import combinations
 
-def get_factor_constants(operand):
-    op = operand.title()
-    res = []
+from node import ExpressionNode as Node, ExpressionLeaf as Leaf
+from possibilities import Possibility as P
 
-    if operand.type == OP_MUL:
-        if operand[0].type == LEAF_NUM:
-            fn()
 
-        if operand[1].type == LEAF_NUM:
-            res += operand[1]
+def match_combine_factors(node):
+    """
+    n + exp + m -> exp + (n + m)
+    k0 * v ^ n + exp + k1 * v ^ n -> exp + (k0 + k1) * v ^ n
+    """
+    p = []
 
-    return res
+    if node.is_nary():
+        # Collect all nodes that can be combined
+        # Numeric leaves
+        numerics = []
+
+        # Identifier leaves of all orders, tuple format is;
+        # (identifier, exponent, coefficient)
+        orders = []
+
+        # Nodes that cannot be combined
+        others = []
+
+        for n in node.get_scope():
+            if isinstance(n, Leaf):
+                if n.is_numeric():
+                    numerics.append(n)
+                elif n.is_identifier():
+                    orders.append((n.value, 1, 1))
+            else:
+                order = n.get_order()
+
+                if order:
+                    orders += order
+                else:
+                    others.append(n)
+
+        if len(numerics) > 1:
+            for num0, num1 in combinations(numerics, 2):
+                p.append(P(node, combine_numerics, (num0, num1, others)))
+
+        if len(orders) > 1:
+            for order0, order1 in combinations(orders, 2):
+                id0, exponent0, coeff0 = order0
+                id1, exponent1, coeff1 = order1
+
+                if id0 == id1 and exponent0 == exponent1:
+                    # Same identifier and exponent -> combine coefficients
+                    args = order0 + (coeff1,) + (others,)
+                    p.append(P(node, combine_orders, args))
 
-def combine_plus_factors(node):
-    p = []
+    return p
 
-    # Check if any numeric factors can be combined
-    def apply_numeric_factors(node, leaves):
-        return Leaf(reduce(lambda a, b: a.value + b.value, leaves))
 
-    num_nodes = []
+def combine_numerics(root, args):
+    """
+    Combine two numeric leaves in an n-ary plus.
 
-    for n in node:
-        # NUM + NUM -> NUM
-        if n.type == VAL_NUM:
-            num_nodes.append(n)
+    Example:
+    3 + 4 -> 7
+    """
+    numerics, others = args
+    value = sum([n.value for n in numerics])
 
-    if len(num_nodes) > 1:
-        p.append((node, apply_plus_factors, num_nodes))
+    return nary_node('+', others + [Leaf(value)])
 
-    # Check if any variable multiplcations/divisions can be combined
-    def apply_identifiers(node, operands):
-        apply_constant(lambda x: )
-        return Leaf(leaves[0].value + leaves[1].value)
 
-    id_nodes = []
+def combine_orders(root, args):
+    """
+    Combine two identifier multiplications of any order in an n-ary plus.
 
-    for n in node:
-        # NUM *  + NUM -> NUM
-        if n.type == OP_MUL:
-            consts = get_factor_constants(n)
+    Example:
+    3x + 4x -> 7x
+    """
+    identifier, exponent, coeff0, coeff1, others = args
 
-            if len(consts) > 1:
-                id_nodes += 
+    coeff = coeff0 + coeff1
 
-    if len(num_nodes) > 1:
-        p.append((node, apply_plus_factors, num_nodes))
+    if not exponent:
+        # a ^ 0 -> 1
+        ident = Leaf(1)
+    elif exponent == 1:
+        # a ^ 1 -> a
+        ident = Leaf(identifier)
+    else:
+        # a ^ n -> a ^ n
+        ident = Node('^', Leaf(identifier), Leaf(exponent))
+
+    if coeff == 1:
+        combined = ident
+    else:
+        combined = Node('*', Leaf(coeff), ident)
+
+    return nary_node('+', others + [combined])
+
+
+def nary_node(operator, scope):
+    """
+    Create a binary expression tree for an n-ary operator. Takes the operator
+    and a list of expression nodes as arguments.
+    """
+    return scope[0] if len(scope) == 1 \
+           else Node(operator, nary_node(operator, scope[:-1]), scope[-1])
 
-    return p
 
 rules = {
-        '+': [combine_plus_factors],
+        '+': [match_combine_factors],
         }

+ 86 - 0
tests/test_node.py

@@ -0,0 +1,86 @@
+import unittest
+
+from src.node import ExpressionNode as N, ExpressionLeaf as L
+
+
+class TestNode(unittest.TestCase):
+
+    def setUp(self):
+        self.l = [L(1), N('*', L(2), L(3)), L(4), L(5)]
+
+    def test_replace_node(self):
+        inner = N('+', L(1), L(2))
+        node = N('+', inner, L(3))
+        replacement = N('-', L(4), L(5))
+        inner.replace(replacement)
+        self.assertEqual(str(node), '4 - 5 + 3')
+
+    def test_replace_leaf(self):
+        inner = N('+', L(1), L(2))
+        node = N('+', inner, L(3))
+        replacement = L(4)
+        inner.replace(replacement)
+        self.assertEqual(str(node), '4 + 3')
+
+    def test_is_power_true(self):
+        self.assertTrue(N('^', *self.l[:2]).is_power())
+        self.assertFalse(N('+', *self.l[:2]).is_power())
+
+    def test_is_nary(self):
+        self.assertTrue(N('+', *self.l[:2]).is_nary())
+        self.assertTrue(N('-', *self.l[:2]).is_nary())
+        self.assertTrue(N('*', *self.l[:2]).is_nary())
+        self.assertFalse(N('^', *self.l[:2]).is_nary())
+
+    def test_is_identifier(self):
+        self.assertTrue(L('a').is_identifier())
+        self.assertFalse(L(1).is_identifier())
+
+    def test_is_int(self):
+        self.assertTrue(L(1).is_int())
+        self.assertFalse(L(1.5).is_int())
+        self.assertFalse(L('a').is_int())
+
+    def test_is_float(self):
+        self.assertTrue(L(1.5).is_float())
+        self.assertFalse(L(1).is_float())
+        self.assertFalse(L('a').is_float())
+
+    def test_is_numeric(self):
+        self.assertTrue(L(1).is_numeric())
+        self.assertTrue(L(1.5).is_numeric())
+        self.assertFalse(L('a').is_numeric())
+
+    def test_get_order_identifier(self):
+        self.assertEqual(L('a').get_order(), ('a', 1, 1))
+
+    def test_get_order_None(self):
+        self.assertIsNone(L(1).get_order())
+
+    def test_get_order_power(self):
+        power = N('^', L('a'), L(2))
+        self.assertEqual(power.get_order(), ('a', 2, 1))
+
+    def test_get_order_coefficient_exponent_int(self):
+        times = N('*', L(3), N('^', L('a'), L(2)))
+        self.assertEqual(times.get_order(), ('a', 2, 3))
+
+    def test_get_order_coefficient_exponent_id(self):
+        times = N('*', L(3), N('^', L('a'), L('b')))
+        self.assertEqual(times.get_order(), ('a', 'b', 3))
+
+    def test_get_scope_binary(self):
+        plus = N('+', *self.l[:2])
+        self.assertEqual(plus.get_scope(), self.l[:2])
+
+    def test_get_scope_nested_left(self):
+        plus = N('+', N('+', *self.l[:2]), self.l[2])
+        self.assertEqual(plus.get_scope(), self.l[:3])
+
+    def test_get_scope_nested_right(self):
+        plus = N('+', self.l[0], N('+', *self.l[1:3]))
+        self.assertEqual(plus.get_scope(), self.l[:3])
+
+    def test_get_scope_nested_deep(self):
+        plus = N('+', N('+', N('+', *self.l[:2]), self.l[2]), self.l[3])
+        self.assertEqual(plus.get_scope(), self.l)

+ 39 - 0
tests/test_rules.py

@@ -0,0 +1,39 @@
+import unittest
+
+from src.node import ExpressionNode as N, ExpressionLeaf as L
+from src.rules import match_combine_factors, combine_numerics, \
+        combine_orders, nary_node
+from src.possibilities import Possibility as P
+
+
+class TestRules(unittest.TestCase):
+
+    def test_nary_node_binary(self):
+        l0, l1 = L(1), L(2)
+        plus = N('+', l0, l1)
+        self.assertEqual(nary_node('+', [l0, l1]), plus)
+
+    def test_nary_node_ternary(self):
+        l0, l1, l2 = L(1), L(2), L(3)
+        plus = N('+', N('+', l0, l1), l2)
+        self.assertEqual(nary_node('+', [l0, l1, l2]), plus)
+
+    def test_match_combine_factors_numeric_simple(self):
+        l0, l1 = L(1), L(2)
+        plus = N('+', l0, l1)
+        p = match_combine_factors(plus)
+        self.assertEqualPos(p, [P(plus, combine_numerics, (l0, l1, []))])
+
+    def test_match_combine_factors_numeric_combinations(self):
+        l0, l1, l2 = L(1), L(2), L(2)
+        plus = N('+', N('+', l0, l1), l2)
+        p = match_combine_factors(plus)
+        self.assertEqualPos(p, [P(plus, combine_numerics, (l0, l1, [])),
+                                P(plus, combine_numerics, (l0, l2, [])),
+                                P(plus, combine_numerics, (l1, l2, []))])
+
+    def assertEqualPos(self, possibilities, expected):
+        for p, e in zip(possibilities, expected):
+            self.assertEqual(p.root, e.root)
+            self.assertEqual(p.handler, e.handler)
+            self.assertEqual(p.args, e.args)