Prechádzať zdrojové kódy

Built in support for writing n-ary operator nodes.

Sander Mathijs van Veen 14 rokov pred
rodič
commit
35dfadad12
2 zmenil súbory, kde vykonal 23 pridanie a 22 odobranie
  1. 14 18
      line.py
  2. 9 4
      tests/test_line.py

+ 14 - 18
line.py

@@ -31,11 +31,9 @@ def generate_line(root):
     int(x, -oo, oo)
     """
 
-    # FIXME: Binding order
     operators = [
             ('+', '-'),
-            ('mod', ),
-            ('*', '/'),
+            ('*', '/', 'mod'),
             ('^', )
             ]
     max_assoc = len(operators)
@@ -77,26 +75,24 @@ def generate_line(root):
 
         if is_operator(node):
             if arity == 1:
-                # Unary expression
+                # Unary operator
                 s += traverse(node[0])
-            elif arity == 2:
-                # Binary expression
-                left, right = map(traverse, node)
-
-                # Check if there is an assiociativity conflict on either side.
-                # If so, add parentheses
+            else:
+                # N-ary operator
                 node_assoc = assoc(node)
+                e = []
+
+                for child in node:
+                    exp = traverse(child)
 
-                if assoc(node[0]) < node_assoc:
-                    left = '(' + left + ')'
+                    # Check if there is an assiociativity conflict.
+                    # If so, add parentheses
+                    if assoc(child) < node_assoc:
+                        exp = '(' + exp + ')'
 
-                if assoc(node[1]) < node_assoc:
-                    right = '(' + right + ')'
+                    e.append(exp)
 
-                s = left + ' ' + s + ' ' + right
-            else:  # pragma: nocover
-                raise ValueError('arity = %d is currently not supported.' \
-                                 % arity)
+                s = (' ' + s + ' ').join(e)
         else:
             # Function
             s += '(' + ', '.join(map(traverse, node)) + ')'

+ 9 - 4
tests/test_line.py

@@ -15,22 +15,27 @@ class TestLine(unittest.TestCase):
     def test_simple(self):
         l0, l1 = Leaf(1), Leaf(2)
         plus = Node('+', l0, l1)
-        assert generate_line(plus) == '1 + 2'
+        self.assertEquals(generate_line(plus), '1 + 2')
 
     def test_parentheses(self):
         l0, l1 = Leaf(1), Leaf(2)
         plus = Node('+', l0, l1)
         times = Node('*', plus, plus)
-        assert generate_line(times) == '(1 + 2) * (1 + 2)'
+        self.assertEquals(generate_line(times), '(1 + 2) * (1 + 2)')
 
     def test_function(self):
         exp = Leaf('x')
         inf = Leaf('oo')
         minus_inf = Node('-', inf)
         integral = Node('int', exp, minus_inf, inf)
-        assert generate_line(integral) == 'int(x, -oo, oo)'
+        self.assertEquals(generate_line(integral), 'int(x, -oo, oo)')
 
     def test_mod(self):
         l0, l1 = Leaf(1), Leaf(2)
         mod = Node('mod', l1, l0)
-        assert generate_line(mod) == '2 mod 1'
+        self.assertEquals(generate_line(mod), '2 mod 1')
+
+    def test_n_ary(self):
+        l0, l1, l2 = Leaf(1), Leaf(2), Leaf(3)
+        plus = Node('+', l0, l1, l2)
+        self.assertEquals(generate_line(plus), '1 + 2 + 3')