Selaa lähdekoodia

Added 'function' writing to graph-to-line writer.

Sander Mathijs van Veen 14 vuotta sitten
vanhempi
sitoutus
c123d2c411
2 muutettua tiedostoa jossa 51 lisäystä ja 26 poistoa
  1. 49 25
      line.py
  2. 2 1
      node.py

+ 49 - 25
line.py

@@ -4,30 +4,49 @@ def generate_line(root, node_type):
     parentheses.
     parentheses.
 
 
     >>> from node import Node, Leaf
     >>> from node import Node, Leaf
-    >>> #from graph import generate_graph
+    >>> from graph import generate_graph
+
     >>> l0, l1 = Leaf(1), Leaf(2)
     >>> l0, l1 = Leaf(1), Leaf(2)
-    >>> n0 = Node('+', l0, l1)
-    >>> print generate_line(n0, Node)
+    >>> plus = Node('+', l0, l1)
+    >>> print generate_line(plus, Node)
     1 + 2
     1 + 2
-    >>> n1 = Node('+', l0, l1)
-    >>> n2 = Node('*', n0, n1)
-    >>> print generate_line(n2, Node)
+
+    >>> plus2 = Node('+', l0, l1)
+    >>> times = Node('*', plus, plus2)
+    >>> print generate_line(times, Node)
     (1 + 2) * (1 + 2)
     (1 + 2) * (1 + 2)
+
     >>> l2 = Leaf(3)
     >>> l2 = Leaf(3)
-    >>> n3 = Node('-', l2)
-    >>> n4 = Node('*', n1, n3)
-    >>> print generate_line(n4, Node)
+    >>> uminus = Node('-', l2)
+    >>> times = Node('*', plus, uminus)
+    >>> print generate_line(times, Node)
     (1 + 2) * -3
     (1 + 2) * -3
-    >>> #integral = Node('int', n0, n1)
+
+    >>> exp = Leaf('x')
+    >>> inf = Leaf('oo')
+    >>> minus_inf = Node('-', inf)
+    >>> integral = Node('int', exp, minus_inf, inf)
+    >>> print generate_line(integral, Node)
+    int(x, -oo, oo)
     """
     """
 
 
     operators = [
     operators = [
             ('+', '-'),
             ('+', '-'),
+            ('mod', ),
             ('*', '/'),
             ('*', '/'),
             ('^', )
             ('^', )
             ]
             ]
     max_assoc = len(operators)
     max_assoc = len(operators)
 
 
+    def is_operator(node):
+        """
+        Check if a given node is an operator (otherwise, it's a function).
+        """
+        label = node.title()
+        either = lambda a, b: a or b
+
+        return reduce(either, map(lambda x: label in x, operators))
+
     def assoc(node):
     def assoc(node):
         """
         """
         Get the associativity of an operator node.
         Get the associativity of an operator node.
@@ -45,8 +64,7 @@ def generate_line(root, node_type):
         """
         """
         The expression tree is traversed using preorder traversal:
         The expression tree is traversed using preorder traversal:
         1. Visit the root
         1. Visit the root
-        2. Traverse the left subtree
-        3. Traverse the right subtree
+        2. Traverse the subtrees in left-to-right order
         """
         """
         s = node.title()
         s = node.title()
 
 
@@ -55,22 +73,28 @@ def generate_line(root, node_type):
 
 
         arity = len(node)
         arity = len(node)
 
 
-        if arity == 1:
-            # Unary expression
-            s += traverse(node[0])
-        elif arity == 2:
-            # Binary expression
-            left, right = map(traverse, node)
+        if is_operator(node):
+            if arity == 1:
+                # Unary expression
+                s += traverse(node[0])
+            elif arity == 2:
+                # Binary expression
+                left, right = map(traverse, node)
+
+                # Check if there is an assiociativity conflict on either side.
+                # If so, add parentheses
+                node_assoc = assoc(node)
 
 
-            # Check if there is an assiociativity conflict on either side. If
-            # so, add parentheses
-            if assoc(node[0]) < assoc(node):
-                left = '(' + left + ')'
+                if assoc(node[0]) < node_assoc:
+                    left = '(' + left + ')'
 
 
-            if assoc(node[1]) < assoc(node):
-                right = '(' + right + ')'
+                if assoc(node[1]) < node_assoc:
+                    right = '(' + right + ')'
 
 
-            s = left + ' ' + s + ' ' + right
+                s = left + ' ' + s + ' ' + right
+        else:
+            # Function
+            s += '(' + ', '.join(map(traverse, node.nodes)) + ')'
 
 
         return s
         return s
 
 

+ 2 - 1
node.py

@@ -1,8 +1,9 @@
 # vim: set fileencoding=utf-8 :
 # vim: set fileencoding=utf-8 :
 
 
+
 class Node(object):
 class Node(object):
     def __init__(self, label, *nodes):
     def __init__(self, label, *nodes):
-        self.label, self.nodes = label, nodes
+        self.label, self.nodes = label, list(nodes)
 
 
     def __getitem__(self, n):
     def __getitem__(self, n):
         return self.nodes[n]
         return self.nodes[n]