Browse Source

Fixed omission of n-ary multiplication sign in line generator.

Taddeus Kroes 14 years ago
parent
commit
df5100b220
2 changed files with 35 additions and 9 deletions
  1. 18 9
      line.py
  2. 17 0
      tests/test_line.py

+ 18 - 9
line.py

@@ -88,12 +88,12 @@ def generate_line(root):
                 node_pred = pred(node)
                 result = ''
                 sep = ' ' + op + ' '
-                prev = None
+                e = []
 
                 for child in node:
                     exp = traverse(child)
 
-                    # Check if there is a precedence conflict.
+                    # Check if there is a precedence conflict
                     # If so, add parentheses
                     child_pred = pred(child)
 
@@ -101,17 +101,26 @@ def generate_line(root):
                             (child_pred == node_pred and op != child.title()):
                         exp = '(' + exp + ')'
 
+                    e.append(exp)
+
+                # Check if a multiplication sign is nessecary
+                if op == '*':
+                    left, right = node
+
+                    # Get the previous multiplcation element if the arity is #
+                    # greater than 2
+                    if left.title() == '*':
+                        left = left[1]
+
                     # a * b -> ab
                     # a * 2 -> a * 2
-                    # a * (...) -> a(...)
+                    # a * (b) -> a(b)
                     # 2 * a -> 2a
-                    if prev and not (op == '*' \
-                                     and ((is_id(prev) and not is_int(child)) \
-                                          or (is_int(prev) and is_id(child)))):
-                        result += sep
+                    if (is_id(left) and (is_id(right) or e[1][0] == '(')) \
+                            or (is_int(left) and is_id(right)):
+                        sep = ''
 
-                    result += exp
-                    prev = child
+                result += sep.join(e)
         else:
             # Function
             result = op + '(' + ', '.join(map(traverse, node)) + ')'

+ 17 - 0
tests/test_line.py

@@ -90,3 +90,20 @@ class TestLine(unittest.TestCase):
         node_pow = N('^', node_bc, node_de)
         node_mul = N('*', a, node_pow)
         self.assertEquals(generate_line(node_mul), 'a(b + c) ^ (d + e)')
+
+    def test_multiplication_sign(self):
+        a, b, c, l2 = L('a'), L('b'), L('c'), L(2)
+        mul = N('*', a, b)
+        self.assertEquals(generate_line(mul), 'ab')
+
+        mul = N('*', mul, c)
+        self.assertEquals(generate_line(mul), 'abc')
+
+        mul = N('*', a, N('-', b, c))
+        self.assertEquals(generate_line(mul), 'a(b - c)')
+
+        mul = N('*', l2, a)
+        self.assertEquals(generate_line(mul), '2a')
+
+        mul = N('*', a, l2)
+        self.assertEquals(generate_line(mul), 'a * 2')