Browse Source

Multiplication sign is not omitted anymore for multiplication of equal identifiers

Taddeus Kroes 13 years ago
parent
commit
c8ecf7081c
2 changed files with 21 additions and 12 deletions
  1. 16 12
      line.py
  2. 5 0
      tests/test_line.py

+ 16 - 12
line.py

@@ -173,6 +173,7 @@ def generate_line(root):
             left = rightmost_node(left)
 
         # a * b -> ab
+        # a * a -> a * a
         # a * 2 -> a * 2
         # a * (b) -> a(b)
         # (a) * b -> (a)b
@@ -181,18 +182,21 @@ def generate_line(root):
         # a * sin(b) -> a sin(b)
         left_char = content[left][-1]
         right_char = content[right][0]
-        left_paren = lparens or left_char in ')]}'
-        right_paren = rparens or right_char in '([{'
-        right_alpha = right_char.isalpha()
-        left_simple = is_id(left) or is_int(left)
-
-        if left_paren or (right_paren and left_simple) \
-                or (is_id(left) and is_id(right)) \
-                or (is_int(left) and right_alpha):
-            return ''
-
-        if is_id(left) and right_alpha:
-            return ' '
+
+        if lparens or rparens or left_char != right_char:
+            left_paren = lparens or left_char in ')]}'
+            right_paren = rparens or right_char in '([{'
+            left_alpha = left_char.isalpha()
+            right_alpha = right_char.isalpha()
+            left_simple = is_id(left) or is_int(left)
+
+            if left_paren or (right_paren and left_simple) \
+                    or (is_id(left) and is_id(right)) \
+                    or (is_int(left) and right_alpha):
+                return ''
+
+            if is_id(left) and right_alpha:
+                return ' '
 
         return ' * '
 

+ 5 - 0
tests/test_line.py

@@ -138,6 +138,11 @@ class TestLine(unittest.TestCase):
         mul = N('*', N('*', a, l2), b)
         self.assertEquals(generate_line(mul), 'a * 2b')
 
+        mul = N('*', a, a)
+        self.assertEquals(generate_line(mul), 'a * a')
+        mul = N('*', N('+', a, b), N('+', b, a))
+        self.assertEquals(generate_line(mul), '(a + b)(b + a)')
+
         mul = -N('*', N('*', a, b), c)
         self.assertEquals(generate_line(mul), '-abc')
         mul = N('*', N('*', -a, b), c)