Built in support for writing n-ary operator nodes.

parent 00abd696
...@@ -31,11 +31,9 @@ def generate_line(root): ...@@ -31,11 +31,9 @@ def generate_line(root):
int(x, -oo, oo) int(x, -oo, oo)
""" """
# FIXME: Binding order
operators = [ operators = [
('+', '-'), ('+', '-'),
('mod', ), ('*', '/', 'mod'),
('*', '/'),
('^', ) ('^', )
] ]
max_assoc = len(operators) max_assoc = len(operators)
...@@ -77,26 +75,24 @@ def generate_line(root): ...@@ -77,26 +75,24 @@ def generate_line(root):
if is_operator(node): if is_operator(node):
if arity == 1: if arity == 1:
# Unary expression # Unary operator
s += traverse(node[0]) s += traverse(node[0])
elif arity == 2: else:
# Binary expression # N-ary operator
left, right = map(traverse, node)
# Check if there is an assiociativity conflict on either side.
# If so, add parentheses
node_assoc = assoc(node) node_assoc = assoc(node)
e = []
if assoc(node[0]) < node_assoc: for child in node:
left = '(' + left + ')' exp = traverse(child)
# Check if there is an assiociativity conflict.
# If so, add parentheses
if assoc(child) < node_assoc:
exp = '(' + exp + ')'
if assoc(node[1]) < node_assoc: e.append(exp)
right = '(' + right + ')'
s = left + ' ' + s + ' ' + right s = (' ' + s + ' ').join(e)
else: # pragma: nocover
raise ValueError('arity = %d is currently not supported.' \
% arity)
else: else:
# Function # Function
s += '(' + ', '.join(map(traverse, node)) + ')' s += '(' + ', '.join(map(traverse, node)) + ')'
......
...@@ -15,22 +15,27 @@ class TestLine(unittest.TestCase): ...@@ -15,22 +15,27 @@ class TestLine(unittest.TestCase):
def test_simple(self): def test_simple(self):
l0, l1 = Leaf(1), Leaf(2) l0, l1 = Leaf(1), Leaf(2)
plus = Node('+', l0, l1) plus = Node('+', l0, l1)
assert generate_line(plus) == '1 + 2' self.assertEquals(generate_line(plus), '1 + 2')
def test_parentheses(self): def test_parentheses(self):
l0, l1 = Leaf(1), Leaf(2) l0, l1 = Leaf(1), Leaf(2)
plus = Node('+', l0, l1) plus = Node('+', l0, l1)
times = Node('*', plus, plus) times = Node('*', plus, plus)
assert generate_line(times) == '(1 + 2) * (1 + 2)' self.assertEquals(generate_line(times), '(1 + 2) * (1 + 2)')
def test_function(self): def test_function(self):
exp = Leaf('x') exp = Leaf('x')
inf = Leaf('oo') inf = Leaf('oo')
minus_inf = Node('-', inf) minus_inf = Node('-', inf)
integral = Node('int', exp, minus_inf, inf) integral = Node('int', exp, minus_inf, inf)
assert generate_line(integral) == 'int(x, -oo, oo)' self.assertEquals(generate_line(integral), 'int(x, -oo, oo)')
def test_mod(self): def test_mod(self):
l0, l1 = Leaf(1), Leaf(2) l0, l1 = Leaf(1), Leaf(2)
mod = Node('mod', l1, l0) mod = Node('mod', l1, l0)
assert generate_line(mod) == '2 mod 1' self.assertEquals(generate_line(mod), '2 mod 1')
def test_n_ary(self):
l0, l1, l2 = Leaf(1), Leaf(2), Leaf(3)
plus = Node('+', l0, l1, l2)
self.assertEquals(generate_line(plus), '1 + 2 + 3')
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment