Kaynağa Gözat

Removed node type from graph drawer.

Sander Mathijs van Veen 14 yıl önce
ebeveyn
işleme
d851ef2b2b
2 değiştirilmiş dosya ile 108 ekleme ve 34 silme
  1. 53 24
      graph.py
  2. 55 10
      tests/test_graph.py

+ 53 - 24
graph.py

@@ -1,7 +1,9 @@
 # vim: set fileencoding=utf-8 :
+# XXX Used in doctests (we should use them in the __main__ section below too).
+from node import Leaf, Node
 
 
-def generate_graph(root, node_type, separator=' ', verbose=False):
+def generate_graph(root, separator=' ', verbose=False):
     """
     Return a text-based, utf-8 graph of a tree-like structure. Each tree node
     is represented by a length-2 list. If a node has an attribute called
@@ -11,17 +13,17 @@ def generate_graph(root, node_type, separator=' ', verbose=False):
     >>> l0, l1 = Leaf(0), Leaf(1)
     >>> n0 = Node('+', l0, l1)
     >>> l2 = Leaf(2)
-    >>> print generate_graph(n0, Node)
+    >>> print generate_graph(n0)
      +
     ╭┴╮
     0 1
     >>> n1 = Node('-', l2)
-    >>> print generate_graph(n1, Node)
+    >>> print generate_graph(n1)
     -
     2
     >>> n2 = Node('*', n0, n1)
-    >>> print generate_graph(n2, Node)
+    >>> print generate_graph(n2)
        *
      ╭─┴╮
      +  -
@@ -30,45 +32,55 @@ def generate_graph(root, node_type, separator=' ', verbose=False):
     """
 
     node_width = {}
+    node_middle = {}
 
     separator_len = len(separator)
 
-    def calculate_width(node):
+    def calculate_node_sizes(node):
         title = node.title()
         title_len = len(title)
 
         # Leaves do not have children and therefore the length of its title is
         # the width of the leaf.
-        if not isinstance(node, node_type):
+        if not isinstance(node, Node):
             node_width[node] = title_len
+            node_middle[node] = int(title_len / 2)
             return title_len
 
+        node_len = len(node)
+
         width = 0
+        middle = 0
+        middle_pos = int(node_len / 2)
+
+        for i, child in enumerate(node):
+            tmp = calculate_node_sizes(child)
+
+            if i < middle_pos:
+                middle += tmp
 
-        for child in node:
-            width += calculate_width(child)
+            width += tmp
+
+        middle += max(middle_pos - int(node_len % 2 == 0), 0) * separator_len
 
         # Add a separator between each node (thus n - 1 separators).
         width += separator_len * (len(node) - 1)
 
-        # Odd numbers of children should be minus 1, since the middle child
-        # can be placed directly below the parent. With even numbers, there
-        # is no middle child, so the space below the parent cannot be used.
-        #if len(node) % 2 == 1:
-        #    width -= 1
-
         # If the title of the node is wider than the sum of its children, the
         # title's width should be used.
-        width = max(title_len, width)
+        if title_len > width:
+            width = title_len
+            middle = int(title_len / 2)
 
         # print 'width of "%s":' % node.title(), width
 
         node_width[node] = width
+        node_middle[node] = middle
 
         return width
 
     def format_lines(node):
-        if not isinstance(node, node_type):
+        if not isinstance(node, Node):
             # Leaf titles do not need to be centered, since the parent will
             # center those lines. And if there are no parents, the entire graph
             # consists of a single leaf, so in that case there still is no
@@ -98,11 +110,13 @@ def generate_graph(root, node_type, separator=' ', verbose=False):
 
         line_width = node_width[node]
 
+        # TODO: substitute box_widths with node_width
         box_widths = [len(lines[0]) for lines in child_lines]
         node_len = len(node)
         middle_node = int(node_len / 2)
-        middle = sum([box_widths[i] for i in range(middle_node)]) \
-                 + max(middle_node - int(node_len % 2 == 0), 0) * separator_len
+        #middle = sum([box_widths[i] for i in range(middle_node)]) \
+        #         + max(middle_node - int(node_len % 2 == 0), 0) * separator_len
+        middle = node_middle[node]
 
         title_line = center_text(node.title(), line_width, middle)
 
@@ -142,7 +156,7 @@ def generate_graph(root, node_type, separator=' ', verbose=False):
                             marker = tsplit_dn_sign
 
                         edge_line += center_text(marker, box_widths[i],
-                                                middle=0, left=dash_sign)
+                                                 middle=0, left=dash_sign)
         else:
             # n-ary operators (n is even)
             edge_line = ''
@@ -157,7 +171,8 @@ def generate_graph(root, node_type, separator=' ', verbose=False):
                 if i < middle_node:
                     marker = (left_sign if i == 0 else tsplit_dn_sign)
                     edge_line += center_text(marker, box_widths[i],
-                                             middle=0, right=dash_sign)
+                                             middle=0,
+                                             right=dash_sign)
                 else:
                     if i == node_len - 1:
                         marker = right_sign
@@ -165,7 +180,8 @@ def generate_graph(root, node_type, separator=' ', verbose=False):
                         marker = tsplit_dn_sign
 
                     edge_line += center_text(marker, box_widths[i],
-                                             middle=0, left=dash_sign)
+                                             middle=0,
+                                             left=dash_sign)
 
         try:
             assert len(title_line) == len(edge_line)
@@ -182,12 +198,12 @@ def generate_graph(root, node_type, separator=' ', verbose=False):
         # Add the line of this node before all child lines.
         return [title_line, edge_line] + result
 
-    calculate_width(root)
+    calculate_node_sizes(root)
 
     #if verbose:
-    #    print '------- node_width ---------'
+    #    print '------- node_{width,middle} ---------'
     #    for node, width in node_width.iteritems():
-    #        print node.title(), 'width:', width
+    #        print node.title(), 'width:', width, 'middle:', node_middle[node]
 
     lines = format_lines(root)
 
@@ -201,6 +217,19 @@ def center_text(text, width, middle=0, left=' ', right=' '):
     >>> center_text('+', 15, 11)
     '           +   '
+    >>> left = center_text('╭'.decode('utf-8'), 11, 8, right='─'.decode('utf-8'))
+    >>> len(left) == 11
+    True
+    >>> right = center_text('╮'.decode('utf-8'), 3, 2, left='─'.decode('utf-8'))
+    >>> len(right) == 3
+    True
+    >>> edge_line = left + '┴'.decode('utf-8') + right
+    >>> len(edge_line) == 15
+    True
+    >>> title_line = center_text('+', 15, 11)
+    >>> print '|%s|\\n|%s|' % (title_line, edge_line.encode('utf-8'))
+    |           +   |
+    |        ╭──┴──╮|
     """
     text_len = len(text)
     text_mid = text_len / 2

+ 55 - 10
tests/test_graph.py

@@ -15,7 +15,7 @@ class TestGraph(unittest.TestCase):
 
     def test_simple_unary(self):
         uminus = Node('-', self.l1)
-        g = generate_graph(uminus, Node)
+        g = generate_graph(uminus)
         self.assertEqualGraphs(g, """
         -
@@ -24,7 +24,7 @@ class TestGraph(unittest.TestCase):
 
     def test_simple_binary(self):
         plus = Node('+', self.l0, self.l1)
-        g = generate_graph(plus, Node)
+        g = generate_graph(plus)
         self.assertEqualGraphs(g, """
          +
         ╭┴╮
@@ -33,7 +33,7 @@ class TestGraph(unittest.TestCase):
 
     def test_multichar_unary(self):
         uminus = Node('-', self.multi)
-        g = generate_graph(uminus, Node)
+        g = generate_graph(uminus)
         self.assertEqualGraphs(g, """
          -
@@ -42,7 +42,7 @@ class TestGraph(unittest.TestCase):
 
     def test_multichar_binary(self):
         plus = Node('+', self.multi, self.l1)
-        g = generate_graph(plus, Node)
+        g = generate_graph(plus)
         self.assertEqualGraphs(g, """
             +
          ╭──┴╮
@@ -54,7 +54,7 @@ class TestGraph(unittest.TestCase):
         inf = Leaf('o')
         minus_inf = Node('-', Leaf('L'))
         integral = Node('n', exp, minus_inf, inf)
-        g = generate_graph(integral, Node, verbose=True)
+        g = generate_graph(integral)
         self.assertEqualGraphs(g, """
           n
         ╭─┼─╮
@@ -68,7 +68,7 @@ class TestGraph(unittest.TestCase):
         inf = Leaf('oo')
         minus_inf = Node('-', Leaf('LL'))
         integral = Node('int', exp, minus_inf, inf)
-        g = generate_graph(integral, Node, verbose=True)
+        g = generate_graph(integral)
         self.assertEqualGraphs(g, """
          int
         ╭─┼──╮
@@ -82,7 +82,7 @@ class TestGraph(unittest.TestCase):
         minus_99 = Node('-', Leaf('99'))
         minus_inf = Node('-', Leaf('oo'))
         integral = Node('int', exp, minus_inf, minus_99)
-        g = generate_graph(integral, Node, verbose=True)
+        g = generate_graph(integral)
         self.assertEqualGraphs(g, """
          int
         ╭─┼──╮
@@ -96,7 +96,7 @@ class TestGraph(unittest.TestCase):
         minus_99 = Node('-', Leaf('99'))
         ten = Leaf('10')
         integral = Node('int', exp, ten, minus_99)
-        g = generate_graph(integral, Node, verbose=True)
+        g = generate_graph(integral)
         self.assertEqualGraphs(g, """
          int
         ╭─┼──╮
@@ -108,7 +108,7 @@ class TestGraph(unittest.TestCase):
     def test_quaternary(self):
         a, b, c, d = Leaf(0), Leaf(1), Leaf(2), Leaf(3)
         sum_node = Node('sum', a, b, c, d)
-        g = generate_graph(sum_node, Node, verbose=True)
+        g = generate_graph(sum_node)
         self.assertEqualGraphs(g, """
           sum
         ╭─┬┴┬─╮
@@ -118,13 +118,58 @@ class TestGraph(unittest.TestCase):
     def test_quinary(self):
         a, b, c, d, e = Leaf(0), Leaf(1), Leaf(2), Leaf(3), Leaf(4)
         sum_node = Node('sum', a, b, c, d, e)
-        g = generate_graph(sum_node, Node, verbose=True)
+        g = generate_graph(sum_node)
         self.assertEqualGraphs(g, """
            sum
         ╭─┬─┼─┬─╮
         0 1 2 3 4
         """)
 
+    def test_expression_small(self):
+        l0 = Leaf(3)
+        l1 = Leaf(4)
+        l2 = Leaf(5)
+        l3 = Leaf(7)
+
+        n0 = Node('+', l0, l1)
+        n1 = Node('+', l2, l3)
+        n2 = Node('*', n0, n1)
+
+        g = generate_graph(n2)
+        self.assertEqualGraphs(g, """
+           *
+         ╭─┴─╮
+         +   +
+        ╭┴╮ ╭┴╮
+        3 4 5 7
+        """)
+
+    def test_expression_larger(self):
+        a = Leaf(3)
+        b = Leaf(4)
+        c = Leaf(5)
+        d = Leaf(7)
+
+        ac = Node('*', a, c)
+        ad = Node('*', a, d)
+        bc = Node('*', b, c)
+        bd = Node('*', b, d)
+
+        root = Node('+', Node('+', Node('+', ac, ad), bc), bd)
+
+        g = generate_graph(root, verbose=True)
+        self.assertEqualGraphs(g, """
+                   +
+               ╭───┴─╮
+               +     *
+           ╭───┴─╮  ╭┴╮
+           +     *  4 7
+         ╭─┴─╮  ╭┴╮
+         *   *  4 5
+        ╭┴╮ ╭┴╮
+        3 5 3 7
+        """)
+
     def strip(self, str):
         return str.replace('\n        ', '\n')[1:-1]