Successfully built n-ary graph drawer.

parent d9474ea0
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
from node import Node, Leaf from node import Node, Leaf
def generate_graph(root, node_type, separator=' '): def generate_graph(root, node_type, separator=' ', verbose=False):
""" """
Return a text-based, utf-8 graph of a tree-like structure. Each tree node Return a text-based, utf-8 graph of a tree-like structure. Each tree node
is represented by a length-2 list. If a node has an attribute called is represented by a length-2 list. If a node has an attribute called
...@@ -11,11 +11,11 @@ def generate_graph(root, node_type, separator=' '): ...@@ -11,11 +11,11 @@ def generate_graph(root, node_type, separator=' '):
>>> l0, l1 = Leaf(0), Leaf(1) >>> l0, l1 = Leaf(0), Leaf(1)
>>> n0 = Node('+', l0, l1) >>> n0 = Node('+', l0, l1)
>>> l2 = Leaf(2)
>>> print generate_graph(n0, Node) >>> print generate_graph(n0, Node)
+ +
╭┴╮ ╭┴╮
0 1 0 1
>>> l2 = Leaf(2)
>>> n1 = Node('-', l2) >>> n1 = Node('-', l2)
>>> print generate_graph(n1, Node) >>> print generate_graph(n1, Node)
- -
...@@ -23,8 +23,8 @@ def generate_graph(root, node_type, separator=' '): ...@@ -23,8 +23,8 @@ def generate_graph(root, node_type, separator=' '):
2 2
>>> n2 = Node('*', n0, n1) >>> n2 = Node('*', n0, n1)
>>> print generate_graph(n2, Node) >>> print generate_graph(n2, Node)
+ *
┴─ ─┴
+ - + -
╭┴╮ │ ╭┴╮ │
0 1 2 0 1 2
...@@ -50,28 +50,30 @@ def generate_graph(root, node_type, separator=' '): ...@@ -50,28 +50,30 @@ def generate_graph(root, node_type, separator=' '):
width += calculate_width(child) width += calculate_width(child)
# Add a separator between each node (thus n - 1 separators). # Add a separator between each node (thus n - 1 separators).
width += separator_len * len(node) - 1 width += separator_len * (len(node) - 1)
# Odd numbers of children should be minus 1, since the middle child # Odd numbers of children should be minus 1, since the middle child
# can be placed directly below the parent. With even numbers, there # can be placed directly below the parent. With even numbers, there
# is no middle child, so the space below the parent cannot be used. # is no middle child, so the space below the parent cannot be used.
if len(node) % 2 == 1: #if len(node) % 2 == 1:
width -= 1 # width -= 1
# If the title of the node is wider than the sum of its children, the # If the title of the node is wider than the sum of its children, the
# title's width should be used. # title's width should be used.
width = max(title_len, width) width = max(title_len, width)
# print 'width of "%s":' % node.title(), width
node_width[node] = width node_width[node] = width
return width return width
def node_middle(node): def node_middle(node):
node_len = len(node) node_len = len(node)
node_mid = node_len / 2 node_mid = int(node_len / 2)
if node_len % 2 == 1: #if node_len % 2 == 1:
node_mid += 1 # node_mid += 1
middle = 0 middle = 0
...@@ -79,9 +81,9 @@ def generate_graph(root, node_type, separator=' '): ...@@ -79,9 +81,9 @@ def generate_graph(root, node_type, separator=' '):
if i == node_mid: if i == node_mid:
break break
middle += separator_len + node_width[child] middle += node_width[child]
return middle - separator_len return middle + max(0, separator_len * (node_mid - 1))
def format_lines(node): def format_lines(node):
if not isinstance(node, node_type): if not isinstance(node, node_type):
...@@ -94,111 +96,138 @@ def generate_graph(root, node_type, separator=' '): ...@@ -94,111 +96,138 @@ def generate_graph(root, node_type, separator=' '):
# At least one child, otherwise it would be a leaf. # At least one child, otherwise it would be a leaf.
assert node[0] assert node[0]
line_width = node_width[node] child_lines = [format_lines(child) for child in node]
max_height = max(map(len, child_lines))
lines = format_lines(node[0])
for child in node[1:]: # Assert that all child boxes are of equal height
pos = len(lines[0]) for lines in child_lines:
additional_line = separator * len(lines[0])
lines += [additional_line for i in range(max_height - len(lines))]
child_lines = format_lines(child) assert len(child_lines[0]) == max_height
child_lines_len = len(child_lines) from copy import deepcopy
result = deepcopy(child_lines[0])
# A node cannot have zero child lines. for lines in child_lines[1:]:
assert child_lines assert len(lines) == max_height
for i, line in enumerate(lines): for i, line in enumerate(lines):
#print 'lines -> %d, "%s"' % (i, line) result[i] += separator + line
if i < child_lines_len:
padding_right = ' ' * (line_width - pos \
- len(child_lines[i]) \
- separator_len)
lines[i] += separator + child_lines[i] + padding_right
else:
# There are no more neighbor node on the right.
lines[i] += ' ' * (line_width - pos)
# Add the child nodes that do not have neighbor nodes on the left. line_width = node_width[node]
for i, line in enumerate(child_lines[i+1:]):
#print 'child_lines[i:] -> %d, "%s"' % (i, line)
line = ' ' * (pos + separator_len) + line \
+ ' ' * (line_width - separator_len - pos - len(line))
lines.append(line)
# Validate that each line has an equal width
try:
for line in lines:
assert len(line) == line_width
except AssertionError:
for l in lines:
l = l.encode('utf-8')
print l.replace(' ', '#'), len(l)
l = line.encode('utf-8')
print 'failed at:', l, len(l)
print 'current node:', node.title(), 'width:', line_width
raise
# Place the title above the middle two child nodes, or when there is a
# odd number of nodes, above the child node in the middle.
middle = node_middle(node)
#print 'node:', node.title(), 'middle:', middle, 'node_mid:', node_mid box_widths = [len(lines[0]) for lines in child_lines]
node_len = len(node)
middle_node = int(node_len / 2)
middle = sum([box_widths[i] for i in range(middle_node)]) \
+ max(middle_node - int(node_len % 2 == 0), 0) * separator_len
title_line = center_text(node.title(), line_width, middle) title_line = center_text(node.title(), line_width, middle)
node_len = len(node)
node_mid = node_len / 2 node_mid = node_len / 2
pipe_sign = '│'.decode('utf-8')
dash_sign = '─'.decode('utf-8')
cross_sign = '┼'.decode('utf-8')
tsplit_dn_sign = '┬'.decode('utf-8')
tsplit_up_sign = '┴'.decode('utf-8')
left_sign = '╭'.decode('utf-8')
right_sign = '╮'.decode('utf-8')
if node_len == 1:
# Unary operators
edge_line = center_text(pipe_sign, box_widths[0], middle)
elif node_len % 2:
# n-ary operators (n is odd)
edge_line = '' edge_line = ''
for i, child in enumerate(node): for i, child in enumerate(node):
if isinstance(child, node_type): if i > 0:
middle = node_middle(child) edge_line += dash_sign
if i < middle_node:
marker = (left_sign if i == 0 else tsplit_dn_sign)
edge_line += center_text(marker, box_widths[i],
middle=0, right=dash_sign)
else: else:
middle = 0 if i == middle_node:
marker = cross_sign
edge_line += center_text(marker, box_widths[i],
middle=0, right=dash_sign,
left=dash_sign)
else:
if i == node_len - 1:
marker = right_sign
else:
marker = tsplit_dn_sign
if i < node_mid: edge_line += center_text(marker, box_widths[i],
marker = ('╭' if i == 0 else '┬') middle=0, left=dash_sign)
edge_line += center_text(marker.decode('utf-8'),
node_width[child],
middle, right='─'.decode('utf-8'))
else: else:
if i == node_mid: # n-ary operators (n is even)
edge_line += '┴'.decode('utf-8') edge_line = ''
for i, child in enumerate(node):
if i > 0:
if i == middle_node:
edge_line += tsplit_up_sign
else:
edge_line += dash_sign
marker = ('╮' if i == node_len - 1 else '┬') if i < middle_node:
edge_line += center_text(marker.decode('utf-8'), marker = (left_sign if i == 0 else tsplit_sign)
node_width[child], edge_line += center_text(marker, box_widths[i],
middle, left='─'.decode('utf-8')) middle=0, right=dash_sign)
else:
marker = (right_sign if i == node_len - 1 else tsplit_sign)
edge_line += center_text(marker, box_widths[i],
middle=0, left=dash_sign)
#try: try:
# assert len(title_line) == len(edge_line) assert len(title_line) == len(edge_line)
#except AssertionError: except AssertionError:
# print 'title_line:', title_line, len(title_line) print '------------------'
# print 'edge_line:', edge_line, len(edge_line) print 'line_width:', line_width
# raise print 'title_line:', title_line, 'len:', len(title_line)
print 'edge_line: %s (%d)' % (edge_line.encode('utf-8'),
len(edge_line))
print 'lines:'
print '\n'.join(map(lambda x: x + ' %d' % len(x), lines))
raise Exception()
# Add the line of this node before all child lines. # Add the line of this node before all child lines.
return [title_line, edge_line] + lines return [title_line, edge_line] + result
calculate_width(root) calculate_width(root)
#if verbose:
# print '------- node_width ---------'
# for node, width in node_width.iteritems():
# print node.title(), 'width:', width
lines = format_lines(root) lines = format_lines(root)
return '\n'.join(lines).encode('utf-8') # Strip trailing whitespace.
return '\n'.join(map(lambda x: x.rstrip(), lines)).encode('utf-8')
def center_text(text, width, middle=0, left=' ', right=' '): def center_text(text, width, middle=0, left=' ', right=' '):
""" """
>>> print center_text('│', 1, 1)
>>> center_text('+', 15, 11) >>> center_text('+', 15, 11)
' + ' ' + '
""" """
text_len = len(text) text_len = len(text)
text_mid = text_len / 2 text_mid = text_len / 2
#print '---------'
#print 'text_len:', text_len
#print 'text_mid:', text_mid
#print 'width:', width
#print 'middle:', middle
#print '---------'
# TODO: this code requires cleanup. # TODO: this code requires cleanup.
if middle: if middle:
......
...@@ -16,59 +16,76 @@ class TestGraph(unittest.TestCase): ...@@ -16,59 +16,76 @@ class TestGraph(unittest.TestCase):
def test_simple_unary(self): def test_simple_unary(self):
uminus = Node('-', self.l1) uminus = Node('-', self.l1)
g = generate_graph(uminus, Node) g = generate_graph(uminus, Node)
expect = self.strip(""" self.assertEqualGraphs(g, """
- -
1 1
""") """)
assert g == expect
def test_simple_binary(self): def test_simple_binary(self):
plus = Node('+', self.l0, self.l1) plus = Node('+', self.l0, self.l1)
g = generate_graph(plus, Node) g = generate_graph(plus, Node)
expect = self.strip(""" self.assertEqualGraphs(g, """
+ +
╭┴╮ ╭┴╮
0 1 0 1
""") """)
assert g == expect
def test_multichar_unary(self): def test_multichar_unary(self):
uminus = Node('-', self.multi) uminus = Node('-', self.multi)
g = generate_graph(uminus, Node) g = generate_graph(uminus, Node)
expect = self.strip(""" self.assertEqualGraphs(g, """
- -
test test
""") """)
print g
print expect
assert g == expect
def test_multichar_binary(self): def test_multichar_binary(self):
plus = Node('+', self.multi, self.l1) plus = Node('+', self.multi, self.l1)
g = generate_graph(plus, Node) g = generate_graph(plus, Node)
expect = self.strip(""" self.assertEqualGraphs(g, """
+ +
╭───┴╮ ──┴╮
test 1 test 1
""") """)
assert g == expect
def test_function(self): def test_ternary(self):
exp = Leaf('x')
inf = Leaf('o')
minus_inf = Node('-', Leaf('L'))
integral = Node('n', exp, minus_inf, inf)
g = generate_graph(integral, Node, verbose=True)
self.assertEqualGraphs(g, """
n
╭─┼─╮
x - o
L
""")
def test_ternary_multichar(self):
exp = Leaf('x') exp = Leaf('x')
inf = Leaf('oo') inf = Leaf('oo')
minus_inf = Node('-', inf) minus_inf = Node('-', Leaf('LL'))
integral = Node('int', exp, minus_inf, inf) integral = Node('int', exp, minus_inf, inf)
g = generate_graph(integral, Node) g = generate_graph(integral, Node, verbose=True)
expect = self.strip(""" self.assertEqualGraphs(g, """
int int
╭─┼──╮ ╭─┼──╮
x - oo x - oo
oo LL
""") """)
assert g == expect
def strip(self, str): def strip(self, str):
return str.replace('\n ', '\n')[1:-1] return str.replace('\n ', '\n')[1:-1]
def assertEqualGraphs(self, g, expect):
expect = self.strip(expect)
if g != expect:
print 'Expected:'
print expect
print 'Got:'
print g
raise AssertionError('Graph does not match expected value')
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment