Przeglądaj źródła

Merged changes of Sander.

Taddeus Kroes 14 lat temu
rodzic
commit
dad8585cc2
5 zmienionych plików z 80 dodań i 65 usunięć
  1. 42 27
      src/node.py
  2. 11 7
      src/parser.py
  3. 7 0
      src/possibilities.py
  4. 20 8
      src/rules/poly.py
  5. 0 23
      tests/test_rules.py

+ 42 - 27
src/node.py

@@ -49,7 +49,30 @@ OP_MAP = {
         }
 
 
-class ExpressionNode(Node):
+class ExpressionBase(object):
+    def is_leaf(self):
+        return self.type != TYPE_OPERATOR
+
+    def is_power(self):
+        return not self.is_leaf() and self.op == OP_POW
+
+    def is_nary(self):
+        return not self.is_leaf() and self.op in [OP_ADD, OP_SUB, OP_MUL]
+
+    def is_identifier(self):
+        return self.is_leaf() and self.type & TYPE_IDENTIFIER
+
+    def is_int(self):
+        return self.is_leaf() and self.type & TYPE_INTEGER
+
+    def is_float(self):
+        return self.is_leaf() and self.type & TYPE_FLOAT
+
+    def is_numeric(self):
+        return self.is_leaf() and self.type & (TYPE_FLOAT | TYPE_INTEGER)
+
+
+class ExpressionNode(Node, ExpressionBase):
     def __init__(self, *args, **kwargs):
         super(ExpressionNode, self).__init__(*args, **kwargs)
         self.type = TYPE_OPERATOR
@@ -70,24 +93,25 @@ class ExpressionNode(Node):
         node.parent = self.parent
         self.parent = None
 
-    def is_power(self):
-        return self.op == OP_POW
-
-    def is_nary(self):
-        return self.op in [OP_ADD, OP_SUB, OP_MUL]
-
     def get_order(self):
-        if self.is_power() and self[0].is_identifier() \
-                and isinstance(self[1], Leaf):
+        if self.is_power() and self[0].is_identifier() and self[1].is_leaf():
+            # a ^ 3
             return (self[0].value, self[1].value, 1)
 
+        if not self.op & OP_MUL:
+            return
+
         for n0, n1 in [(0, 1), (1, 0)]:
-            if self[n0].is_numeric() and not isinstance(self[n1], Leaf) \
-                    and self[n1].is_power():
-                coeff, power = self
+            if self[n0].is_numeric():
+                if self[n1].is_identifier():
+                    # 2 * a
+                    return (self[n1].value, 1, self[n0].value)
+                elif self[n1].is_power():
+                    # 2 * a ^ 3
+                    coeff, power = self
 
-                if power[0].is_identifier() and isinstance(power[1], Leaf):
-                    return (power[0].value, power[1].value, coeff.value)
+                    if power[0].is_identifier() and power[1].is_leaf():
+                        return (power[0].value, power[1].value, coeff.value)
 
     def get_scope(self):
         scope = []
@@ -100,8 +124,11 @@ class ExpressionNode(Node):
 
         return scope
 
+    def get_scope_except(self, *args):
+        return list(set(self.get_scope()) - set(args))
+
 
-class ExpressionLeaf(Leaf):
+class ExpressionLeaf(Leaf, ExpressionBase):
     def __init__(self, *args, **kwargs):
         super(ExpressionLeaf, self).__init__(*args, **kwargs)
 
@@ -120,18 +147,6 @@ class ExpressionLeaf(Leaf):
         node.parent = self.parent
         self.parent = None
 
-    def is_identifier(self):
-        return self.type & TYPE_IDENTIFIER
-
-    def is_int(self):
-        return self.type & TYPE_INTEGER
-
-    def is_float(self):
-        return self.type & TYPE_FLOAT
-
-    def is_numeric(self):
-        return self.type & (TYPE_FLOAT | TYPE_INTEGER)
-
 
 if __name__ == '__main__':  # pragma: nocover
     l0 = ExpressionLeaf(3)

+ 11 - 7
src/parser.py

@@ -19,8 +19,9 @@ sys.path.insert(1, EXTERNAL_MODS)
 from pybison import BisonParser, BisonSyntaxError
 from graph_drawing.graph import generate_graph
 
-from node import TYPE_OPERATOR, OP_ADD, OP_MUL, OP_SUB
+from node import TYPE_OPERATOR
 from rules import RULES
+from possibilities import filter_duplicates
 
 
 ## Check for n-ary operator in child nodes
@@ -90,7 +91,10 @@ class Parser(BisonParser):
     def hook_read_before(self):
         if self.interactive and self.possibilities:
             print 'possibilities:'
-            print self.possibilities
+            items = filter_duplicates(self.possibilities)
+            print '  ' + '\n  '.join(map(str, items))
+
+        self.possibilities = []
 
     def hook_read_after(self, data):
         """
@@ -152,16 +156,17 @@ class Parser(BisonParser):
         return data
 
     def hook_handler(self, target, option, names, values, retval):
-        if not retval or retval.type not in RULES:
+        if target in ['exp', 'line', 'input'] or not retval \
+                or retval.type != TYPE_OPERATOR or retval.op not in RULES:
             return retval
 
-        for handler in RULES[retval.type]:
+        for handler in RULES[retval.op]:
             self.possibilities.extend(handler(retval))
 
         return retval
 
-    def hook_run(self, filename, retval):
-        return retval
+    #def hook_run(self, filename, retval):
+    #    return retval
 
     # ---------------------------------------------------------------
     # These methods are the python handlers for the bison targets.
@@ -398,5 +403,4 @@ def main():
     if interactive:
         print
 
-
     return node

+ 7 - 0
src/possibilities.py

@@ -3,3 +3,10 @@ class Possibility(object):
         self.root = root
         self.handler = handler
         self.args = args
+
+    def __str__(self):
+        return '<possibility root="%s" handler=%s args=%s>' \
+                % (self.root, self.handler, self.args)
+
+    def __repr__(self):
+        return str(self)

+ 20 - 8
src/rules/poly.py

@@ -1,6 +1,6 @@
 from itertools import combinations
 
-from ..node import ExpressionLeaf as Leaf, TYPE_OPERATOR, OP_ADD
+from ..node import ExpressionLeaf as Leaf, TYPE_OPERATOR, OP_ADD, OP_MUL
 from ..possibilities import Possibility as P
 
 
@@ -8,8 +8,8 @@ def match_expand(node):
     """
     a * (b + c) -> ab + ac
     """
-    if node.type != TYPE_OPERATOR or not node.op & OP_MUL:
-        return []
+    assert node.type == TYPE_OPERATOR
+    assert node.op & OP_MUL
 
     p = []
 
@@ -33,6 +33,7 @@ def match_expand(node):
 
     return p
 
+
 def expand_single(root, args):
     """
     Combine a leaf (left) multiplied with an addition of two expressions
@@ -41,15 +42,25 @@ def expand_single(root, args):
     >>> a * (b + c) -> a * b + a * c
     """
     left, right = args
-    others = list(set(root.get_scope()) - {left, right})
+    scope = root.get_scope_except(right)
+
+    replacement = Node('+', Node('*', left, right[0]), Node('*', left, right[1]))
+
+    for i, n in enumerate(scope):
+        if n == left:
+            scope[i] = replacement
+            break
+
+    return nary_node('*', scope)
+
 
 def match_combine_factors(node):
     """
     n + exp + m -> exp + (n + m)
     k0 * v ^ n + exp + k1 * v ^ n -> exp + (k0 + k1) * v ^ n
     """
-    if node.type != TYPE_OPERATOR or not node.op & OP_ADD:
-        return []
+    assert node.type == TYPE_OPERATOR
+    assert node.op & OP_ADD
 
     p = []
 
@@ -62,17 +73,18 @@ def match_combine_factors(node):
     orders = []
 
     for n in node.get_scope():
-        if node.type == TYPE_OPERATOR:
+        if n.type == TYPE_OPERATOR:
             order = n.get_order()
 
             if order:
-                orders += order
+                orders.append(order)
         else:
             if n.is_numeric():
                 numerics.append(n)
             elif n.is_identifier():
                 orders.append((n.value, 1, 1))
 
+    print 'numerics:', numerics
     if len(numerics) > 1:
         for num0, num1 in combinations(numerics, 2):
             p.append(P(node, combine_numerics, (num0, num1)))

+ 0 - 23
tests/test_rules.py

@@ -1,10 +1,7 @@
 import unittest
 
 from src.node import ExpressionNode as N, ExpressionLeaf as L
-from src.rules.poly import match_combine_factors, combine_numerics, \
-        combine_orders
 from src.rules.utils import nary_node
-from src.possibilities import Possibility as P
 
 
 class TestRules(unittest.TestCase):
@@ -18,23 +15,3 @@ class TestRules(unittest.TestCase):
         l0, l1, l2 = L(1), L(2), L(3)
         plus = N('+', N('+', l0, l1), l2)
         self.assertEqual(nary_node('+', [l0, l1, l2]), plus)
-
-    def test_match_combine_factors_numeric_simple(self):
-        l0, l1 = L(1), L(2)
-        plus = N('+', l0, l1)
-        p = match_combine_factors(plus)
-        self.assertEqualPos(p, [P(plus, combine_numerics, (l0, l1))])
-
-    def test_match_combine_factors_numeric_combinations(self):
-        l0, l1, l2 = L(1), L(2), L(2)
-        plus = N('+', N('+', l0, l1), l2)
-        p = match_combine_factors(plus)
-        self.assertEqualPos(p, [P(plus, combine_numerics, (l0, l1)),
-                                P(plus, combine_numerics, (l0, l2)),
-                                P(plus, combine_numerics, (l1, l2))])
-
-    def assertEqualPos(self, possibilities, expected):
-        for p, e in zip(possibilities, expected):
-            self.assertEqual(p.root, e.root)
-            self.assertEqual(p.handler, e.handler)
-            self.assertEqual(p.args, e.args)