瀏覽代碼

Successfully built n-ary graph drawer.

Sander Mathijs van Veen 14 年之前
父節點
當前提交
7e89101870
共有 2 個文件被更改,包括 154 次插入108 次删除
  1. 120 91
      graph.py
  2. 34 17
      tests/test_graph.py

+ 120 - 91
graph.py

@@ -2,7 +2,7 @@
 from node import Node, Leaf
 
 
-def generate_graph(root, node_type, separator=' '):
+def generate_graph(root, node_type, separator=' ', verbose=False):
     """
     Return a text-based, utf-8 graph of a tree-like structure. Each tree node
     is represented by a length-2 list. If a node has an attribute called
@@ -11,11 +11,11 @@ def generate_graph(root, node_type, separator=' '):
 
     >>> l0, l1 = Leaf(0), Leaf(1)
     >>> n0 = Node('+', l0, l1)
+    >>> l2 = Leaf(2)
     >>> print generate_graph(n0, Node)
-     + 
+     +
     ╭┴╮
     0 1
-    >>> l2 = Leaf(2)
     >>> n1 = Node('-', l2)
     >>> print generate_graph(n1, Node)
     -
@@ -23,8 +23,8 @@ def generate_graph(root, node_type, separator=' '):
     2
     >>> n2 = Node('*', n0, n1)
     >>> print generate_graph(n2, Node)
-      +  
-     ╭─╮
+       *
+     ╭─
      +  -
     ╭┴╮ │
     0 1 2
@@ -50,28 +50,30 @@ def generate_graph(root, node_type, separator=' '):
             width += calculate_width(child)
 
         # Add a separator between each node (thus n - 1 separators).
-        width += separator_len * len(node) - 1
+        width += separator_len * (len(node) - 1)
 
         # Odd numbers of children should be minus 1, since the middle child
         # can be placed directly below the parent. With even numbers, there
         # is no middle child, so the space below the parent cannot be used.
-        if len(node) % 2 == 1:
-            width -= 1
+        #if len(node) % 2 == 1:
+        #    width -= 1
 
         # If the title of the node is wider than the sum of its children, the
         # title's width should be used.
         width = max(title_len, width)
 
+        # print 'width of "%s":' % node.title(), width
+
         node_width[node] = width
 
         return width
 
     def node_middle(node):
         node_len = len(node)
-        node_mid = node_len / 2
+        node_mid = int(node_len / 2)
 
-        if node_len % 2 == 1:
-            node_mid += 1
+        #if node_len % 2 == 1:
+        #    node_mid += 1
 
         middle = 0
 
@@ -79,9 +81,9 @@ def generate_graph(root, node_type, separator=' '):
             if i == node_mid:
                 break
 
-            middle += separator_len + node_width[child]
+            middle += node_width[child]
 
-        return middle - separator_len
+        return middle + max(0, separator_len * (node_mid - 1))
 
     def format_lines(node):
         if not isinstance(node, node_type):
@@ -94,111 +96,138 @@ def generate_graph(root, node_type, separator=' '):
         # At least one child, otherwise it would be a leaf.
         assert node[0]
 
-        line_width = node_width[node]
-
-        lines = format_lines(node[0])
+        child_lines = [format_lines(child) for child in node]
+        max_height = max(map(len, child_lines))
 
-        for child in node[1:]:
-            pos = len(lines[0])
+        # Assert that all child boxes are of equal height
+        for lines in child_lines:
+            additional_line = separator * len(lines[0])
+            lines += [additional_line for i in range(max_height - len(lines))]
 
-            child_lines = format_lines(child)
-            child_lines_len = len(child_lines)
+        assert len(child_lines[0]) == max_height
+        from copy import deepcopy
+        result = deepcopy(child_lines[0])
 
-            # A node cannot have zero child lines.
-            assert child_lines
+        for lines in child_lines[1:]:
+            assert len(lines) == max_height
 
             for i, line in enumerate(lines):
-                #print 'lines -> %d, "%s"' % (i, line)
-
-                if i < child_lines_len:
-                    padding_right = ' ' * (line_width - pos \
-                                           - len(child_lines[i]) \
-                                           - separator_len)
-
-                    lines[i] += separator + child_lines[i] + padding_right
-                else:
-                    # There are no more neighbor node on the right.
-                    lines[i] += ' '  * (line_width - pos)
-
-            # Add the child nodes that do not have neighbor nodes on the left.
-            for i, line in enumerate(child_lines[i+1:]):
-                #print 'child_lines[i:] -> %d, "%s"' % (i, line)
-                line = ' ' * (pos + separator_len) + line \
-                     + ' ' * (line_width - separator_len - pos - len(line))
-                lines.append(line)
-
-            # Validate that each line has an equal width
-            try:
-                for line in lines:
-                    assert len(line) == line_width
-            except AssertionError:
-
-                for l in lines:
-                    l = l.encode('utf-8')
-                    print l.replace(' ', '#'), len(l)
-
-                l = line.encode('utf-8')
-                print 'failed at:', l, len(l)
-                print 'current node:', node.title(), 'width:', line_width
-
-                raise
+                result[i] += separator + line
 
-        # Place the title above the middle two child nodes, or when there is a
-        # odd number of nodes, above the child node in the middle.
-        middle = node_middle(node)
+        line_width = node_width[node]
 
-        #print 'node:', node.title(), 'middle:', middle, 'node_mid:', node_mid
+        box_widths = [len(lines[0]) for lines in child_lines]
+        node_len = len(node)
+        middle_node = int(node_len / 2)
+        middle = sum([box_widths[i] for i in range(middle_node)]) \
+                 + max(middle_node - int(node_len % 2 == 0), 0) * separator_len
 
         title_line = center_text(node.title(), line_width, middle)
 
-        node_len = len(node)
         node_mid = node_len / 2
 
-        edge_line = ''
-
-        for i, child in enumerate(node):
-            if isinstance(child, node_type):
-                middle = node_middle(child)
-            else:
-                middle = 0
-
-            if i < node_mid:
-                marker = ('╭' if i == 0 else '┬')
-                edge_line += center_text(marker.decode('utf-8'),
-                                        node_width[child],
-                                        middle, right='─'.decode('utf-8'))
-            else:
-                if i == node_mid:
-                    edge_line += '┴'.decode('utf-8')
-
-                marker = ('╮' if i == node_len - 1 else '┬')
-                edge_line += center_text(marker.decode('utf-8'),
-                                        node_width[child],
-                                        middle, left='─'.decode('utf-8'))
-
-        #try:
-        #    assert len(title_line) == len(edge_line)
-        #except AssertionError:
-        #    print 'title_line:', title_line, len(title_line)
-        #    print 'edge_line:', edge_line, len(edge_line)
-        #    raise
+        pipe_sign = '│'.decode('utf-8')
+        dash_sign = '─'.decode('utf-8')
+        cross_sign = '┼'.decode('utf-8')
+        tsplit_dn_sign = '┬'.decode('utf-8')
+        tsplit_up_sign = '┴'.decode('utf-8')
+        left_sign = '╭'.decode('utf-8')
+        right_sign = '╮'.decode('utf-8')
+
+        if node_len == 1:
+            # Unary operators
+            edge_line = center_text(pipe_sign, box_widths[0], middle)
+        elif node_len % 2:
+            # n-ary operators (n is odd)
+            edge_line = ''
+
+            for i, child in enumerate(node):
+                if i > 0:
+                    edge_line += dash_sign
+
+                if i < middle_node:
+                    marker = (left_sign if i == 0 else tsplit_dn_sign)
+                    edge_line += center_text(marker, box_widths[i],
+                                             middle=0, right=dash_sign)
+                else:
+                    if i == middle_node:
+                        marker = cross_sign
+                        edge_line += center_text(marker, box_widths[i],
+                                                middle=0, right=dash_sign,
+                                                left=dash_sign)
+                    else:
+                        if i == node_len - 1:
+                            marker = right_sign
+                        else:
+                            marker = tsplit_dn_sign
+
+                        edge_line += center_text(marker, box_widths[i],
+                                                middle=0, left=dash_sign)
+        else:
+            # n-ary operators (n is even)
+            edge_line = ''
+
+            for i, child in enumerate(node):
+                if i > 0:
+                    if i == middle_node:
+                        edge_line += tsplit_up_sign
+                    else:
+                        edge_line += dash_sign
+
+                if i < middle_node:
+                    marker = (left_sign if i == 0 else tsplit_sign)
+                    edge_line += center_text(marker, box_widths[i],
+                                             middle=0, right=dash_sign)
+                else:
+                    marker = (right_sign if i == node_len - 1 else tsplit_sign)
+                    edge_line += center_text(marker, box_widths[i],
+                                             middle=0, left=dash_sign)
+
+        try:
+            assert len(title_line) == len(edge_line)
+        except AssertionError:
+            print '------------------'
+            print 'line_width:', line_width
+            print 'title_line:', title_line, 'len:', len(title_line)
+            print 'edge_line: %s (%d)' % (edge_line.encode('utf-8'),
+                                          len(edge_line))
+            print 'lines:'
+            print '\n'.join(map(lambda x: x + ' %d' % len(x), lines))
+            raise Exception()
 
         # Add the line of this node before all child lines.
-        return [title_line, edge_line] + lines
+        return [title_line, edge_line] + result
 
     calculate_width(root)
+
+    #if verbose:
+    #    print '------- node_width ---------'
+    #    for node, width in node_width.iteritems():
+    #        print node.title(), 'width:', width
+
     lines = format_lines(root)
 
-    return '\n'.join(lines).encode('utf-8')
+    # Strip trailing whitespace.
+    return '\n'.join(map(lambda x: x.rstrip(), lines)).encode('utf-8')
+
 
 def center_text(text, width, middle=0, left=' ', right=' '):
     """
+    >>> print center_text('│', 1, 1)
+    │
     >>> center_text('+', 15, 11)
     '           +   '
     """
     text_len = len(text)
     text_mid = text_len / 2
 
+    #print '---------'
+    #print 'text_len:', text_len
+    #print 'text_mid:', text_mid
+    #print 'width:', width
+    #print 'middle:', middle
+    #print '---------'
+
     # TODO: this code requires cleanup.
 
     if middle:

+ 34 - 17
tests/test_graph.py

@@ -16,59 +16,76 @@ class TestGraph(unittest.TestCase):
     def test_simple_unary(self):
         uminus = Node('-', self.l1)
         g = generate_graph(uminus, Node)
-        expect = self.strip("""
+        self.assertEqualGraphs(g, """
         -
         1
         """)
-        assert g == expect
 
     def test_simple_binary(self):
         plus = Node('+', self.l0, self.l1)
         g = generate_graph(plus, Node)
-        expect = self.strip("""
+        self.assertEqualGraphs(g, """
          +
         ╭┴╮
         0 1
         """)
-        assert g == expect
 
     def test_multichar_unary(self):
         uminus = Node('-', self.multi)
         g = generate_graph(uminus, Node)
-        expect = self.strip("""
+        self.assertEqualGraphs(g, """
          -
         test
         """)
-        print g
-        print expect
-        assert g == expect
 
     def test_multichar_binary(self):
         plus = Node('+', self.multi, self.l1)
         g = generate_graph(plus, Node)
-        expect = self.strip("""
+        self.assertEqualGraphs(g, """
             +
-        ╭──┴╮
+         ╭──┴╮
         test 1
         """)
-        assert g == expect
 
-    def test_function(self):
+    def test_ternary(self):
+        exp = Leaf('x')
+        inf = Leaf('o')
+        minus_inf = Node('-', Leaf('L'))
+        integral = Node('n', exp, minus_inf, inf)
+        g = generate_graph(integral, Node, verbose=True)
+        self.assertEqualGraphs(g, """
+          n
+        ╭─┼─╮
+        x - o
+          │
+          L
+        """)
+
+    def test_ternary_multichar(self):
         exp = Leaf('x')
         inf = Leaf('oo')
-        minus_inf = Node('-', inf)
+        minus_inf = Node('-', Leaf('LL'))
         integral = Node('int', exp, minus_inf, inf)
-        g = generate_graph(integral, Node)
-        expect = self.strip("""
+        g = generate_graph(integral, Node, verbose=True)
+        self.assertEqualGraphs(g, """
          int
         ╭─┼──╮
         x -  oo
-          oo
+          LL
         """)
-        assert g == expect
 
     def strip(self, str):
         return str.replace('\n        ', '\n')[1:-1]
+
+    def assertEqualGraphs(self, g, expect):
+        expect = self.strip(expect)
+
+        if g != expect:
+            print 'Expected:'
+            print expect
+            print 'Got:'
+            print g
+            raise AssertionError('Graph does not match expected value')