Просмотр исходного кода

Added check for redundant multiplication operator.

Taddeus Kroes 14 лет назад
Родитель
Сommit
567e8f5c26
2 измененных файлов с 77 добавлено и 39 удалено
  1. 28 10
      line.py
  2. 49 29
      tests/test_line.py

+ 28 - 10
line.py

@@ -72,26 +72,28 @@ def generate_line(root):
         if not node:
             return '<empty expression>'
 
-        s = node.title()
+        op = node.title()
 
         if not node.nodes:
-            return s
+            return op
 
         arity = len(node)
 
         if is_operator(node):
             if arity == 1:
                 # Unary operator
-                s += traverse(node[0])
+                result = op + traverse(node[0])
             else:
                 # N-ary operator
                 node_pred = pred(node)
-                e = []
+                result = ''
+                sep = ' ' + op + ' '
+                prev = None
 
                 for child in node:
                     exp = traverse(child)
 
-                    # Check if there is an precedence conflict.
+                    # Check if there is a precedence conflict.
                     # If so, add parentheses
                     child_pred = pred(child)
 
@@ -99,13 +101,29 @@ def generate_line(root):
                             (child_pred == node_pred and s != child.title()):
                         exp = '(' + exp + ')'
 
-                    e.append(exp)
-
-                s = (' ' + s + ' ').join(e)
+                    # a * b -> ab
+                    # a * 2 -> a * 2
+                    # a * (...) -> a(...)
+                    # 2 * a -> 2a
+                    if prev and not (op == '*' \
+                                     and ((is_id(prev) and not is_int(child)) \
+                                          or (is_int(prev) and is_id(child)))):
+                        result += sep
+
+                    result += exp
+                    prev = child
         else:
             # Function
-            s += '(' + ', '.join(map(traverse, node)) + ')'
+            result = op + '(' + ', '.join(map(traverse, node)) + ')'
 
-        return s
+        return result
 
     return traverse(root)
+
+
+def is_id(node):
+    return isinstance(node, Leaf) and not node.title().isdigit()
+
+
+def is_int(node):
+    return isinstance(node, Leaf) and node.title().isdigit()

+ 49 - 29
tests/test_line.py

@@ -1,6 +1,6 @@
 import unittest
 
-from node import Node, Leaf
+from node import Node as N, Leaf as L
 from line import generate_line
 
 
@@ -13,14 +13,14 @@ class TestLine(unittest.TestCase):
         pass
 
     def test_simple(self):
-        l0, l1 = Leaf(1), Leaf(2)
-        plus = Node('+', l0, l1)
+        l0, l1 = L(1), L(2)
+        plus = N('+', l0, l1)
         self.assertEquals(generate_line(plus), '1 + 2')
 
     def test_parentheses(self):
-        l0, l1 = Leaf(1), Leaf(2)
-        plus = Node('+', l0, l1)
-        times = Node('*', plus, plus)
+        l0, l1 = L(1), L(2)
+        plus = N('+', l0, l1)
+        times = N('*', plus, plus)
         self.assertEquals(generate_line(times), '(1 + 2) * (1 + 2)')
 
     def test_parentheses_equal_precedence(self):
@@ -35,41 +35,61 @@ class TestLine(unittest.TestCase):
         self.assertEquals(generate_line(plus), '1 + 2 + 3')
 
     def test_function(self):
-        exp = Leaf('x')
-        inf = Leaf('oo')
-        minus_inf = Node('-', inf)
-        integral = Node('int', exp, minus_inf, inf)
+        exp = L('x')
+        inf = L('oo')
+        minus_inf = N('-', inf)
+        integral = N('int', exp, minus_inf, inf)
         self.assertEquals(generate_line(integral), 'int(x, -oo, oo)')
 
     def test_mod(self):
-        l0, l1 = Leaf(1), Leaf(2)
-        mod = Node('mod', l1, l0)
+        l0, l1 = L(1), L(2)
+        mod = N('mod', l1, l0)
         self.assertEquals(generate_line(mod), '2 mod 1')
 
-    def test_n_ary(self):
-        l0, l1, l2 = Leaf(1), Leaf(2), Leaf(3)
-        plus = Node('+', l0, l1, l2)
+    def test_multiplication_identifiers(self):
+        a, b = L('a'), L('b')
+        mul = N('*', a, b)
+        self.assertEquals(generate_line(mul), 'ab')
+
+    def test_multiplication_constant_identifier(self):
+        l0, a = L(2), L('a')
+        mul = N('*', l0, a)
+        self.assertEquals(generate_line(mul), '2a')
+
+    def test_multiplication_identifier_constant(self):
+        l0, a = L(2), L('a')
+        mul = N('*', a, l0)
+        self.assertEquals(generate_line(mul), 'a * 2')
+
+    def test_multiplication_constants(self):
+        l0, l1 = L(1), L(2)
+        mul = N('*', l0, l1)
+        self.assertEquals(generate_line(mul), '1 * 2')
+
+    def test_nary(self):
+        l0, l1, l2 = L(1), L(2), L(3)
+        plus = N('+', l0, l1, l2)
         self.assertEquals(generate_line(plus), '1 + 2 + 3')
 
     def test_pow_basic(self):
-        a, b, c = Leaf('a'), Leaf('b'), Leaf('c')
-        node_pow = Node('^', a, Node('+', b, c))
+        a, b, c = L('a'), L('b'), L('c')
+        node_pow = N('^', a, N('+', b, c))
         self.assertEquals(generate_line(node_pow), 'a ^ (b + c)')
 
     def test_pow_intermediate1(self):
         # expression: (a(b+c))^(d+e)
-        a, b, c, d, e = Leaf('a'), Leaf('b'), Leaf('c'), Leaf('d'), Leaf('e')
-        node_bc = Node('+', b, c)
-        node_de = Node('+', d, e)
-        node_mul = Node('*', a, node_bc)
-        node_pow = Node('^', node_mul, node_de)
-        self.assertEquals(generate_line(node_pow), '(a * (b + c)) ^ (d + e)')
+        a, b, c, d, e = L('a'), L('b'), L('c'), L('d'), L('e')
+        node_bc = N('+', b, c)
+        node_de = N('+', d, e)
+        node_mul = N('*', a, node_bc)
+        node_pow = N('^', node_mul, node_de)
+        self.assertEquals(generate_line(node_pow), '(a(b + c)) ^ (d + e)')
 
     def test_pow_intermediate2(self):
         # expression: a(b+c)^(d+e)
-        a, b, c, d, e = Leaf('a'), Leaf('b'), Leaf('c'), Leaf('d'), Leaf('e')
-        node_bc = Node('+', b, c)
-        node_de = Node('+', d, e)
-        node_pow = Node('^', node_bc, node_de)
-        node_mul = Node('*', a, node_pow)
-        self.assertEquals(generate_line(node_mul), 'a * (b + c) ^ (d + e)')
+        a, b, c, d, e = L('a'), L('b'), L('c'), L('d'), L('e')
+        node_bc = N('+', b, c)
+        node_de = N('+', d, e)
+        node_pow = N('^', node_bc, node_de)
+        node_mul = N('*', a, node_pow)
+        self.assertEquals(generate_line(node_mul), 'a(b + c) ^ (d + e)')