Removed node type from graph drawer.

parent bea696ab
# vim: set fileencoding=utf-8 : # vim: set fileencoding=utf-8 :
# XXX Used in doctests (we should use them in the __main__ section below too).
from node import Leaf, Node
def generate_graph(root, node_type, separator=' ', verbose=False): def generate_graph(root, separator=' ', verbose=False):
""" """
Return a text-based, utf-8 graph of a tree-like structure. Each tree node Return a text-based, utf-8 graph of a tree-like structure. Each tree node
is represented by a length-2 list. If a node has an attribute called is represented by a length-2 list. If a node has an attribute called
...@@ -11,17 +13,17 @@ def generate_graph(root, node_type, separator=' ', verbose=False): ...@@ -11,17 +13,17 @@ def generate_graph(root, node_type, separator=' ', verbose=False):
>>> l0, l1 = Leaf(0), Leaf(1) >>> l0, l1 = Leaf(0), Leaf(1)
>>> n0 = Node('+', l0, l1) >>> n0 = Node('+', l0, l1)
>>> l2 = Leaf(2) >>> l2 = Leaf(2)
>>> print generate_graph(n0, Node) >>> print generate_graph(n0)
+ +
╭┴╮ ╭┴╮
0 1 0 1
>>> n1 = Node('-', l2) >>> n1 = Node('-', l2)
>>> print generate_graph(n1, Node) >>> print generate_graph(n1)
- -
2 2
>>> n2 = Node('*', n0, n1) >>> n2 = Node('*', n0, n1)
>>> print generate_graph(n2, Node) >>> print generate_graph(n2)
* *
╭─┴╮ ╭─┴╮
+ - + -
...@@ -30,45 +32,55 @@ def generate_graph(root, node_type, separator=' ', verbose=False): ...@@ -30,45 +32,55 @@ def generate_graph(root, node_type, separator=' ', verbose=False):
""" """
node_width = {} node_width = {}
node_middle = {}
separator_len = len(separator) separator_len = len(separator)
def calculate_width(node): def calculate_node_sizes(node):
title = node.title() title = node.title()
title_len = len(title) title_len = len(title)
# Leaves do not have children and therefore the length of its title is # Leaves do not have children and therefore the length of its title is
# the width of the leaf. # the width of the leaf.
if not isinstance(node, node_type): if not isinstance(node, Node):
node_width[node] = title_len node_width[node] = title_len
node_middle[node] = int(title_len / 2)
return title_len return title_len
node_len = len(node)
width = 0 width = 0
middle = 0
middle_pos = int(node_len / 2)
for i, child in enumerate(node):
tmp = calculate_node_sizes(child)
if i < middle_pos:
middle += tmp
for child in node: width += tmp
width += calculate_width(child)
middle += max(middle_pos - int(node_len % 2 == 0), 0) * separator_len
# Add a separator between each node (thus n - 1 separators). # Add a separator between each node (thus n - 1 separators).
width += separator_len * (len(node) - 1) width += separator_len * (len(node) - 1)
# Odd numbers of children should be minus 1, since the middle child
# can be placed directly below the parent. With even numbers, there
# is no middle child, so the space below the parent cannot be used.
#if len(node) % 2 == 1:
# width -= 1
# If the title of the node is wider than the sum of its children, the # If the title of the node is wider than the sum of its children, the
# title's width should be used. # title's width should be used.
width = max(title_len, width) if title_len > width:
width = title_len
middle = int(title_len / 2)
# print 'width of "%s":' % node.title(), width # print 'width of "%s":' % node.title(), width
node_width[node] = width node_width[node] = width
node_middle[node] = middle
return width return width
def format_lines(node): def format_lines(node):
if not isinstance(node, node_type): if not isinstance(node, Node):
# Leaf titles do not need to be centered, since the parent will # Leaf titles do not need to be centered, since the parent will
# center those lines. And if there are no parents, the entire graph # center those lines. And if there are no parents, the entire graph
# consists of a single leaf, so in that case there still is no # consists of a single leaf, so in that case there still is no
...@@ -98,11 +110,13 @@ def generate_graph(root, node_type, separator=' ', verbose=False): ...@@ -98,11 +110,13 @@ def generate_graph(root, node_type, separator=' ', verbose=False):
line_width = node_width[node] line_width = node_width[node]
# TODO: substitute box_widths with node_width
box_widths = [len(lines[0]) for lines in child_lines] box_widths = [len(lines[0]) for lines in child_lines]
node_len = len(node) node_len = len(node)
middle_node = int(node_len / 2) middle_node = int(node_len / 2)
middle = sum([box_widths[i] for i in range(middle_node)]) \ #middle = sum([box_widths[i] for i in range(middle_node)]) \
+ max(middle_node - int(node_len % 2 == 0), 0) * separator_len # + max(middle_node - int(node_len % 2 == 0), 0) * separator_len
middle = node_middle[node]
title_line = center_text(node.title(), line_width, middle) title_line = center_text(node.title(), line_width, middle)
...@@ -157,7 +171,8 @@ def generate_graph(root, node_type, separator=' ', verbose=False): ...@@ -157,7 +171,8 @@ def generate_graph(root, node_type, separator=' ', verbose=False):
if i < middle_node: if i < middle_node:
marker = (left_sign if i == 0 else tsplit_dn_sign) marker = (left_sign if i == 0 else tsplit_dn_sign)
edge_line += center_text(marker, box_widths[i], edge_line += center_text(marker, box_widths[i],
middle=0, right=dash_sign) middle=0,
right=dash_sign)
else: else:
if i == node_len - 1: if i == node_len - 1:
marker = right_sign marker = right_sign
...@@ -165,7 +180,8 @@ def generate_graph(root, node_type, separator=' ', verbose=False): ...@@ -165,7 +180,8 @@ def generate_graph(root, node_type, separator=' ', verbose=False):
marker = tsplit_dn_sign marker = tsplit_dn_sign
edge_line += center_text(marker, box_widths[i], edge_line += center_text(marker, box_widths[i],
middle=0, left=dash_sign) middle=0,
left=dash_sign)
try: try:
assert len(title_line) == len(edge_line) assert len(title_line) == len(edge_line)
...@@ -182,12 +198,12 @@ def generate_graph(root, node_type, separator=' ', verbose=False): ...@@ -182,12 +198,12 @@ def generate_graph(root, node_type, separator=' ', verbose=False):
# Add the line of this node before all child lines. # Add the line of this node before all child lines.
return [title_line, edge_line] + result return [title_line, edge_line] + result
calculate_width(root) calculate_node_sizes(root)
#if verbose: #if verbose:
# print '------- node_width ---------' # print '------- node_{width,middle} ---------'
# for node, width in node_width.iteritems(): # for node, width in node_width.iteritems():
# print node.title(), 'width:', width # print node.title(), 'width:', width, 'middle:', node_middle[node]
lines = format_lines(root) lines = format_lines(root)
...@@ -201,6 +217,19 @@ def center_text(text, width, middle=0, left=' ', right=' '): ...@@ -201,6 +217,19 @@ def center_text(text, width, middle=0, left=' ', right=' '):
>>> center_text('+', 15, 11) >>> center_text('+', 15, 11)
' + ' ' + '
>>> left = center_text('╭'.decode('utf-8'), 11, 8, right='─'.decode('utf-8'))
>>> len(left) == 11
True
>>> right = center_text('╮'.decode('utf-8'), 3, 2, left='─'.decode('utf-8'))
>>> len(right) == 3
True
>>> edge_line = left + '┴'.decode('utf-8') + right
>>> len(edge_line) == 15
True
>>> title_line = center_text('+', 15, 11)
>>> print '|%s|\\n|%s|' % (title_line, edge_line.encode('utf-8'))
| + |
| ╭──┴──╮|
""" """
text_len = len(text) text_len = len(text)
text_mid = text_len / 2 text_mid = text_len / 2
......
...@@ -15,7 +15,7 @@ class TestGraph(unittest.TestCase): ...@@ -15,7 +15,7 @@ class TestGraph(unittest.TestCase):
def test_simple_unary(self): def test_simple_unary(self):
uminus = Node('-', self.l1) uminus = Node('-', self.l1)
g = generate_graph(uminus, Node) g = generate_graph(uminus)
self.assertEqualGraphs(g, """ self.assertEqualGraphs(g, """
- -
...@@ -24,7 +24,7 @@ class TestGraph(unittest.TestCase): ...@@ -24,7 +24,7 @@ class TestGraph(unittest.TestCase):
def test_simple_binary(self): def test_simple_binary(self):
plus = Node('+', self.l0, self.l1) plus = Node('+', self.l0, self.l1)
g = generate_graph(plus, Node) g = generate_graph(plus)
self.assertEqualGraphs(g, """ self.assertEqualGraphs(g, """
+ +
╭┴╮ ╭┴╮
...@@ -33,7 +33,7 @@ class TestGraph(unittest.TestCase): ...@@ -33,7 +33,7 @@ class TestGraph(unittest.TestCase):
def test_multichar_unary(self): def test_multichar_unary(self):
uminus = Node('-', self.multi) uminus = Node('-', self.multi)
g = generate_graph(uminus, Node) g = generate_graph(uminus)
self.assertEqualGraphs(g, """ self.assertEqualGraphs(g, """
- -
...@@ -42,7 +42,7 @@ class TestGraph(unittest.TestCase): ...@@ -42,7 +42,7 @@ class TestGraph(unittest.TestCase):
def test_multichar_binary(self): def test_multichar_binary(self):
plus = Node('+', self.multi, self.l1) plus = Node('+', self.multi, self.l1)
g = generate_graph(plus, Node) g = generate_graph(plus)
self.assertEqualGraphs(g, """ self.assertEqualGraphs(g, """
+ +
╭──┴╮ ╭──┴╮
...@@ -54,7 +54,7 @@ class TestGraph(unittest.TestCase): ...@@ -54,7 +54,7 @@ class TestGraph(unittest.TestCase):
inf = Leaf('o') inf = Leaf('o')
minus_inf = Node('-', Leaf('L')) minus_inf = Node('-', Leaf('L'))
integral = Node('n', exp, minus_inf, inf) integral = Node('n', exp, minus_inf, inf)
g = generate_graph(integral, Node, verbose=True) g = generate_graph(integral)
self.assertEqualGraphs(g, """ self.assertEqualGraphs(g, """
n n
╭─┼─╮ ╭─┼─╮
...@@ -68,7 +68,7 @@ class TestGraph(unittest.TestCase): ...@@ -68,7 +68,7 @@ class TestGraph(unittest.TestCase):
inf = Leaf('oo') inf = Leaf('oo')
minus_inf = Node('-', Leaf('LL')) minus_inf = Node('-', Leaf('LL'))
integral = Node('int', exp, minus_inf, inf) integral = Node('int', exp, minus_inf, inf)
g = generate_graph(integral, Node, verbose=True) g = generate_graph(integral)
self.assertEqualGraphs(g, """ self.assertEqualGraphs(g, """
int int
╭─┼──╮ ╭─┼──╮
...@@ -82,7 +82,7 @@ class TestGraph(unittest.TestCase): ...@@ -82,7 +82,7 @@ class TestGraph(unittest.TestCase):
minus_99 = Node('-', Leaf('99')) minus_99 = Node('-', Leaf('99'))
minus_inf = Node('-', Leaf('oo')) minus_inf = Node('-', Leaf('oo'))
integral = Node('int', exp, minus_inf, minus_99) integral = Node('int', exp, minus_inf, minus_99)
g = generate_graph(integral, Node, verbose=True) g = generate_graph(integral)
self.assertEqualGraphs(g, """ self.assertEqualGraphs(g, """
int int
╭─┼──╮ ╭─┼──╮
...@@ -96,7 +96,7 @@ class TestGraph(unittest.TestCase): ...@@ -96,7 +96,7 @@ class TestGraph(unittest.TestCase):
minus_99 = Node('-', Leaf('99')) minus_99 = Node('-', Leaf('99'))
ten = Leaf('10') ten = Leaf('10')
integral = Node('int', exp, ten, minus_99) integral = Node('int', exp, ten, minus_99)
g = generate_graph(integral, Node, verbose=True) g = generate_graph(integral)
self.assertEqualGraphs(g, """ self.assertEqualGraphs(g, """
int int
╭─┼──╮ ╭─┼──╮
...@@ -108,7 +108,7 @@ class TestGraph(unittest.TestCase): ...@@ -108,7 +108,7 @@ class TestGraph(unittest.TestCase):
def test_quaternary(self): def test_quaternary(self):
a, b, c, d = Leaf(0), Leaf(1), Leaf(2), Leaf(3) a, b, c, d = Leaf(0), Leaf(1), Leaf(2), Leaf(3)
sum_node = Node('sum', a, b, c, d) sum_node = Node('sum', a, b, c, d)
g = generate_graph(sum_node, Node, verbose=True) g = generate_graph(sum_node)
self.assertEqualGraphs(g, """ self.assertEqualGraphs(g, """
sum sum
╭─┬┴┬─╮ ╭─┬┴┬─╮
...@@ -118,13 +118,58 @@ class TestGraph(unittest.TestCase): ...@@ -118,13 +118,58 @@ class TestGraph(unittest.TestCase):
def test_quinary(self): def test_quinary(self):
a, b, c, d, e = Leaf(0), Leaf(1), Leaf(2), Leaf(3), Leaf(4) a, b, c, d, e = Leaf(0), Leaf(1), Leaf(2), Leaf(3), Leaf(4)
sum_node = Node('sum', a, b, c, d, e) sum_node = Node('sum', a, b, c, d, e)
g = generate_graph(sum_node, Node, verbose=True) g = generate_graph(sum_node)
self.assertEqualGraphs(g, """ self.assertEqualGraphs(g, """
sum sum
╭─┬─┼─┬─╮ ╭─┬─┼─┬─╮
0 1 2 3 4 0 1 2 3 4
""") """)
def test_expression_small(self):
l0 = Leaf(3)
l1 = Leaf(4)
l2 = Leaf(5)
l3 = Leaf(7)
n0 = Node('+', l0, l1)
n1 = Node('+', l2, l3)
n2 = Node('*', n0, n1)
g = generate_graph(n2)
self.assertEqualGraphs(g, """
*
╭─┴─╮
+ +
╭┴╮ ╭┴╮
3 4 5 7
""")
def test_expression_larger(self):
a = Leaf(3)
b = Leaf(4)
c = Leaf(5)
d = Leaf(7)
ac = Node('*', a, c)
ad = Node('*', a, d)
bc = Node('*', b, c)
bd = Node('*', b, d)
root = Node('+', Node('+', Node('+', ac, ad), bc), bd)
g = generate_graph(root, verbose=True)
self.assertEqualGraphs(g, """
+
╭───┴─╮
+ *
╭───┴─╮ ╭┴╮
+ * 4 7
╭─┴─╮ ╭┴╮
* * 4 5
╭┴╮ ╭┴╮
3 5 3 7
""")
def strip(self, str): def strip(self, str):
return str.replace('\n ', '\n')[1:-1] return str.replace('\n ', '\n')[1:-1]
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment