Parcourir la source

Improved multiplication sign omission.

Taddeus Kroes il y a 14 ans
Parent
commit
01cf1342c3
2 fichiers modifiés avec 8 ajouts et 10 suppressions
  1. 4 9
      line.py
  2. 4 1
      tests/test_line.py

+ 4 - 9
line.py

@@ -166,16 +166,11 @@ def generate_line(root):
             # (a) * b -> (a)b
             # (a) * (b) -> (a)(b)
             # 2 * a -> 2a
-            left_paren = e[0][-1] == ')'
-            right_paren = e[1][0] == '('
+            l = e[0][-1]
+            r = e[1][0]
 
-            left_suited = left_paren or is_id(left) or is_int(left)
-            right_suited = right_paren
-
-            if not right_suited and not right.negated:
-                right_suited = is_id(right) or is_power(right)
-
-            if left_suited and right_suited:
+            if r == '(' or (l == ')' and r != '-') \
+                    or ((is_id(left) or is_int(left)) and r.isalpha()):
                 sep = ''
 
         exp = sep.join(e)

+ 4 - 1
tests/test_line.py

@@ -139,7 +139,7 @@ class TestLine(unittest.TestCase):
         mul = N('*', N('+', a, b), c)
         self.assertEquals(generate_line(mul), '(a + b)c')
         mul = N('*', N('+', a, b), l2)
-        self.assertEquals(generate_line(mul), '(a + b) * 2')
+        self.assertEquals(generate_line(mul), '(a + b)2')
 
         mul = N('*', N('+', a, b), N('+', c, l2))
         self.assertEquals(generate_line(mul), '(a + b)(c + 2)')
@@ -152,6 +152,9 @@ class TestLine(unittest.TestCase):
         mul = N('*', l2, N('^', a, l2))
         self.assertEquals(generate_line(mul), '2a ^ 2')
 
+        mul = N('*', l2, N('^', l2, a))
+        self.assertEquals(generate_line(mul), '2 * 2 ^ a')
+
     def test_plus_to_minus(self):
         plus = N('+', L(1), -L(2))
         self.assertEquals(generate_line(plus), '1 - 2')