Commit e17779f2 authored by Taddeüs Kroes's avatar Taddeüs Kroes

Reimplemented line printer using precedences in accordance with the parser.

- All special cases for negation have been removed, except when they are also
  present in the parser.
- Some callbacks have been added to be able to customize output for certain
  operators in an extension of the Node class. This prevents chaotic sceneries
  in the graph_drawing library.
- Some useful unit tests have been added to test the precedence- and
  callback-meschanisms.
parent 107e93a1
from traverse import traverse_depth_first
from node import Node
all_parens = ('()', '[]', '||', '{}')
OPERATORS = (
('vv', ),
('^^', ),
('=', ),
('+', '-'),
('*', 'mod'),
('/', ),
('^', '_'),
('left', ('vv', )),
('left', ('^^', )),
('left', ('=', )),
('left', ('+', '-')),
('nonassoc', ('int', 'd/d')),
('left', ('*', 'mod')),
('left', ('/', )),
('nonassoc', ('\'', )),
('nonassoc', ('neg', )),
('nonassoc', ('function', )),
('right', ('^', '_')),
('nonassoc', all_parens),
)
NEG_PRED = 3
assocs = {}
preds = {}
for i, (assoc, ops) in enumerate(OPERATORS):
for op in ops:
assocs[op] = assoc
preds[op] = i
NEG_PRED = preds['neg']
FUNC_PRED = preds['function']
MAX_PRED = len(OPERATORS)
def is_operator(node):
def is_function(node):
"""
Check if a given node is an operator (otherwise, it's a function).
Check if a given node is a function. A node is considered a function if it
is not a leaf, and is not known in the precedences list above.
"""
label = node.title()
return any(map(lambda x: label in x, OPERATORS))
return not node.is_leaf and node.title() not in preds
def pred(node):
"""
Get the precedence of an operator node.
Get the operator precedence of a node. Leaf nodes have the highest
precedence for practical reasons.
"""
# Check binary and n-ary operators
if not node.is_leaf and len(node) > 1:
# Check known operators
if not node.is_leaf:
op = node.title()
for i, group in enumerate(OPERATORS):
if op in group:
return i
if node.is_negation():
if node[0].title() in '*/':
return preds['-']
return NEG_PRED
#if node.is_postfix() and not node[0].is_leaf \
# and node[0].title() in all_parens:
# return preds['()']
if op in preds:
return preds[op]
return FUNC_PRED
# Unary operator and leaves have highest precedence
return MAX_PRED
def rightmost_node(node):
if node.is_leaf or not len(node):
return node
return rightmost_node(node[-1])
def is_unary_prefix(node):
"""
Check if a node is a unary operator that is placed before the operand.
"""
return not node.is_leaf and len(node) == 1 and node.title() == '-'
def is_left_assoc(op):
return op in assocs and assocs[op] == 'left'
def is_right_assoc(op):
return op in assocs and assocs[op] == 'right'
def is_id(node):
return node.is_leaf and not node.title().isdigit()
......@@ -52,6 +101,24 @@ def is_power(node):
return not node.is_leaf and node.title() == '^'
def preprocess_node(node):
node = node.clone()
node.preprocess_str_exp()
if node.negated:
node.negated -= 1
return Node('-', preprocess_node(node))
if not node.is_leaf:
for i, child in enumerate(node):
node[i] = preprocess_node(child)
if node.title() == '+' and node[1].is_negation():
return Node('-', node[0], node[1][0])
return node
def generate_line(root):
"""
Print an expression tree in a single text line. Where needed, add
......@@ -77,13 +144,6 @@ def generate_line(root):
>>> print generate_line(times)
(1 + 2) * -3
>>> exp = Leaf('x')
>>> inf = Leaf('oo')
>>> minus_inf = Node('-', inf)
>>> integral = Node('int', exp, minus_inf, inf)
>>> print generate_line(integral)
int(x, -oo, oo)
>>> minus = Node('-', Leaf(2), Node('-', Node('*', Leaf(15), Leaf('x'))))
>>> print generate_line(minus)
2 - -15x
......@@ -107,124 +167,175 @@ def generate_line(root):
content = {}
def mult_sign(left, right, lparens, rparens):
# Get the previous multiplication element in an nary multiplication
if left.title() == '*':
left = rightmost_node(left)
# a * b -> ab
# a * 2 -> a * 2
# a * (b) -> a(b)
# (a) * b -> (a)b
# (a) * (b) -> (a)(b)
# 2 * a -> 2a
# a * sin(b) -> a sin(b)
left_char = content[left][-1]
right_char = content[right][0]
left_paren = lparens or left_char in ')]}'
right_paren = rparens or right_char in '([{'
right_alpha = right_char.isalpha()
left_simple = is_id(left) or is_int(left)
if left_paren or (right_paren and left_simple) \
or (is_id(left) and is_id(right)) \
or (is_int(left) and right_alpha):
return ''
if is_id(left) and right_alpha:
return ' '
return ' * '
def construct_unary(node):
op = node.title()
value = node[0]
strval = content[value]
# -a
# -3 * 4
# --a
if value.is_leaf \
or not ' ' in content[value] or pred(value) > NEG_PRED:
return op + content[value]
if op in ('()', '[]', '||', '{}'):
return op[0] + strval + op[1]
# -(a + b)
return '%s(%s)' % (op, content[value])
parens = False
if pred(value) < pred(node):
parens = not value.is_negation()
elif pred(value) == pred(node):
parens = len(value) > 1
def construct_nary(node):
op = node.title()
if parens:
strval = '(' + strval + ')'
# N-ary operator
node_pred = pred(node)
sep = ' ' + op + ' '
e = []
if node.is_postfix():
return strval + node.operator()
for i, child in enumerate(node):
exp = content[child]
prefix = node.operator()
#if i and op == '+' and exp[:2] == '-(':
# exp = '-' + exp[2:-1]
# print 'exp:', exp
# Check if there is a precedence conflict
# If so, add parentheses
child_pred = pred(child)
if child.negated:
# (-a) ^ b
# -a ^ -b
# (-a) * b
# a * -b
# (-a) / b
if node_pred > NEG_PRED:
exp = '(' + exp + ')'
elif child_pred < node_pred:
exp = '(' + exp + ')'
elif child_pred == node_pred:
if i and (op != child.title() or op == '/' \
or (op == '+' and child[1].negated)):
exp = '(' + exp + ')'
elif not i and op == '^':
exp = '(' + exp + ')'
if prefix != '-' and not (strval[0] in '([|{' and is_function(node)):
prefix += ' '
e.append(exp)
return prefix + strval
def construct_binary(node):
op = node.title()
op_pred = pred(node)
if node.no_spacing:
sep = node.operator()
else:
sep = ' ' + node.operator() + ' '
if op == '*':
# Check if an explicit multiplication sign is nessecary
left, right = node
lstr = content[left]
rstr = content[right]
lpred = pred(left)
rpred = pred(right)
lparens = rparens = False
unary_right = is_unary_prefix(right)
# Get the previous multiplication element if the arity is
# greater than 2
if left.title() == '*':
left = left[1]
if lpred < op_pred or (op in '*/' and left.is_negation()):
lparens = True
elif lpred == op_pred:
lparens = is_right_assoc(left.title()) or is_right_assoc(op)
# a * b -> ab
# a * 2 -> a * 2
# a * (b) -> a(b)
# (a) * b -> (a)b
# (a) * (b) -> (a)(b)
# 2 * a -> 2a
l = e[0][-1]
r = e[1][0]
left_simple = is_id(left) or is_int(left)
if rpred < op_pred:
rparens = not unary_right
elif rpred == op_pred and len(right) > 1:
if right.title() == op:
rparens = not is_right_assoc(op)
elif is_left_assoc(right.title()):
rparens = True
if (r in ('(', '[') and left_simple) or (l == ')' and r != '-') \
or (left_simple and r.isalpha()):
sep = ''
# Check if multiplication sig is necessary
if op == '*' and not unary_right:
sep = mult_sign(left, right, lparens, rparens)
exp = sep.join(e)
if lparens:
lstr = '(' + lstr + ')'
if node.negated and op not in ('*', '/', '^'):
exp = '(' + exp + ')'
if rparens:
rstr = '(' + rstr + ')'
return exp
return lstr + sep + rstr
def construct_function(node):
buf = []
def construct_nary_mult(node):
op_pred = pred(node)
lstr = content[node[0]]
lparens = pred(node[0]) < op_pred or node[0].is_negation()
if lparens:
lstr = '(' + lstr + ')'
for i, right in enumerate(node[1:]):
rparens = pred(right) < op_pred
rstr = content[right]
if rparens:
rstr = '(' + rstr + ')'
sign = mult_sign(node[i], right, lparens, rparens)
lstr += sign + rstr
lparens = rparens
return lstr
def construct_nary(node):
if node.title() == '*':
return construct_nary_mult(node)
op_pred = pred(node)
e = []
for child in node:
buf.append(content[child])
exp = content[child]
return '%s(%s)' % (node.title(), ', '.join(buf))
if pred(child) < op_pred:
exp = '(' + exp + ')'
e.append(exp)
return (' ' + node.operator() + ' ').join(e)
def construct_function(node):
children = [content[child] for child in node]
return '%s(%s)' % (node.operator(), ', '.join(children))
# Convert negations to unary nodes to be able to account for operator
# precedence
root = preprocess_node(root.clone())
# Traverse the expression tree and construct the mathematical expression in
# the leafs and nodes in depth first order.
for node in traverse_depth_first(root):
custom = node.custom_line()
if custom is not None:
content[node] = custom
continue
if node.is_leaf:
content[node] = str(node)
nodestr = str(node.value)
else:
arity = len(node)
arity = node.arity()
if is_operator(node):
if arity == 1:
content[node] = construct_unary(node)
else:
content[node] = construct_nary(node)
nodestr = construct_unary(node)
elif is_function(node):
nodestr = construct_function(node)
elif arity == 2:
nodestr = construct_binary(node)
else:
result = None
if hasattr(node, 'construct_function'):
children = [content[c] for c in node]
result = node.construct_function(children)
if result == None:
result = construct_function(node)
content[node] = result
nodestr = construct_nary(node)
# Add negations
content[node] = '-' * node.negated + content[node]
content[node] = node.postprocess_str(nodestr)
# Merge binary plus and unary minus signs into binary minus.
# Merge binary plus and unary minus signs into a binary minus
return content[root].replace('+ -', '- ')
......@@ -7,6 +7,7 @@ class Node(object):
super(Node, self).__init__()
self.value, self.nodes = value, list(nodes)
self.is_leaf = False
self.no_spacing = kwargs.get('no_spacing', False)
self.negated = kwargs.get('negated', 0)
def __getitem__(self, n):
......@@ -26,11 +27,17 @@ class Node(object):
and self.nodes == node.nodes
def __neg__(self):
copied = deepcopy(self)
copied = self.clone()
copied.negated += 1
return copied
def __pos__(self):
copied = self.clone()
copied.negated = max(copied.negated - 1, 0)
return copied
def __str__(self):
return '<Node value=%s nodes=%s negated=%d>' \
% (str(self.value), str(self.nodes), self.negated)
......@@ -41,11 +48,39 @@ class Node(object):
def title(self):
return str(self.value)
def operator(self):
return self.value
def clone(self):
return deepcopy(self)
def arity(self):
return len(self)
def is_postfix(self):
return self.value == '\''
def is_negation(self):
return self.value == '-' and len(self) == 1
def custom_line(self):
pass
def preprocess_str_exp(self):
pass
def postprocess_str(self, string):
return string
class Leaf(Node):
def __init__(self, value, **kwargs):
super(Leaf, self).__init__(value, **kwargs)
self.value = value
if type(value) in (int, float) and value < 0:
self.value = abs(value)
self.negated += 1
self.nodes = None
self.is_leaf = True
......
import unittest
import doctest
import new
import line
from node import Node as N, Leaf as L
......@@ -38,6 +39,8 @@ class TestLine(unittest.TestCase):
minus = N('-', l0, plus)
self.assertEquals(generate_line(minus), '1 - (2 + 3)')
power = N('^', l0, N('_', l1, l2))
self.assertEquals(generate_line(power), '1 ^ 2 _ 3')
power = N('^', l0, N('^', l1, l2))
self.assertEquals(generate_line(power), '1 ^ 2 ^ 3')
power = N('^', N('^', l0, l1), l2)
......@@ -63,11 +66,8 @@ class TestLine(unittest.TestCase):
self.assertEquals(generate_line(plus), '1 + 2 + 3')
def test_function(self):
exp = L('x')
inf = L('oo')
minus_inf = -L('oo')
integral = N('int', exp, minus_inf, inf)
self.assertEquals(generate_line(integral), 'int(x, -oo, oo)')
sin = N('sin', N('*', L(2), L('x')))
self.assertEquals(generate_line(sin), 'sin(2x)')
def test_mod(self):
l0, l1 = L(1), L(2)
......@@ -77,7 +77,7 @@ class TestLine(unittest.TestCase):
def test_multiplication_identifiers(self):
a, b = L('a'), L('b')
self.assertEquals(generate_line(N('*', a, b)), 'ab')
self.assertEquals(generate_line(N('*', a, -b)), 'a(-b)')
self.assertEquals(generate_line(N('*', a, -b)), 'a * -b')
def test_multiplication_constant_identifier(self):
l0, a = L(2), L('a')
......@@ -206,8 +206,14 @@ class TestLine(unittest.TestCase):
neg = -N('-', L(1), L(2))
self.assertEquals(generate_line(neg), '-(1 - 2)')
# FIXME: neg = N('+', L(1), N('+', L(1), L(2)))
# FIXME: self.assertEquals(generate_line(neg), '1 + 1 + 2')
neg = N('+', N('+', L(1), L(2)), L(3))
self.assertEquals(generate_line(neg), '1 + 2 + 3')
neg = N('+', L(1), N('+', L(1), L(2)))
self.assertEquals(generate_line(neg), '1 + 1 + 2')
self.assertEquals(generate_line(neg), '1 + (1 + 2)')
neg = N('+', L(1), -N('+', L(1), L(2)))
self.assertEquals(generate_line(neg), '1 - (1 + 2)')
......@@ -219,7 +225,7 @@ class TestLine(unittest.TestCase):
self.assertEquals(generate_line(neg), '-4a')
neg = N('*', L(4), -L('a'))
self.assertEquals(generate_line(neg), '4(-a)')
self.assertEquals(generate_line(neg), '4 * -a')
neg = -N('*', L(4), L(5))
self.assertEquals(generate_line(neg), '-4 * 5')
......@@ -234,11 +240,15 @@ class TestLine(unittest.TestCase):
self.assertEquals(generate_line(plus), 'a / b - c / d')
mul = N('*', N('+', L('a'), L('b')), -N('+', L('c'), L('d')))
self.assertEquals(generate_line(mul), '(a + b)(-(c + d))')
self.assertEquals(generate_line(mul), '(a + b) * -(c + d)')
def test_double_negation(self):
neg = --L(1)
self.assertEquals(generate_line(neg), '--1')
neg = --N('*', L('x'), L(2))
self.assertEquals(generate_line(neg), '--x * 2')
neg = --N('^', L('x'), L(2))
self.assertEquals(generate_line(neg), '--x ^ 2')
def test_divide_fractions(self):
a, b, c, d = L('a'), L('b'), L('c'), L('d')
......@@ -246,3 +256,85 @@ class TestLine(unittest.TestCase):
self.assertEquals(generate_line(div), 'a / (b / c)')
div = N('/', N('/', a, b), N('/', c, d))
self.assertEquals(generate_line(div), 'a / b / (c / d)')
def test_prime(self):
a, b, c, d = L('a'), L('b'), L('c'), L('d')
root = N('*', a, N("'", b))
self.assertEquals(generate_line(root), "a b'")
root = N("'", -a)
self.assertEquals(generate_line(root), "-a'")
root = -N("'", a)
self.assertEquals(generate_line(root), "-(a')")
root = N("'", N('*', a, b))
self.assertEquals(generate_line(root), "(ab)'")
root = N("'", N('/', a, b))
def test_function(self):
root = N('sin', L('x'))
self.assertEquals(generate_line(root), 'sin x')
root = N('sin', N('+', L('x'), L(2)))
self.assertEquals(generate_line(root), 'sin(x + 2)')
root = N('dummyfunc', L('x'), L(2))
self.assertEquals(generate_line(root), 'dummyfunc(x, 2)')
def test_no_spacing(self):
root = N('+', L('x'), L(2), no_spacing=True)
self.assertEquals(generate_line(root), 'x+2')
def test_explicit_parentheses(self):
root = N('[]', L('x'))
self.assertEquals(generate_line(root), '[x]')
root = N('()', L('x'))
self.assertEquals(generate_line(root), '(x)')
root = N('{}', L('x'))
self.assertEquals(generate_line(root), '{x}')
root = N('^', N('[]', N('+', L('x'), L('y'))), L(2))
self.assertEquals(generate_line(root), '[x + y] ^ 2')
def test_abs(self):
root = N('||', L('x'))
self.assertEquals(generate_line(root), '|x|')
root = N('||', N('+', L('x'), L(1)))
self.assertEquals(generate_line(root), '|x + 1|')
root = N('ln', N('||', L('x')))
self.assertEquals(generate_line(root), 'ln|x|')
def test_postprocess_str(self):
root = N('int', N('^', L('x'), L(2)), L('x'))
root.arity = lambda: 1
root.postprocess_str = lambda s: s + ' dx'
self.assertEquals(generate_line(root), 'int x ^ 2 dx')
def test_concat_with_negation(self):
root = N('*', -L(2), L('x'))
self.assertEquals(generate_line(root), '(-2)x')
root = N('*', N('*', L(3), -L(2)), L('x'))
self.assertEquals(generate_line(root), '3 * -2x')
root = N('*', L(3), -L(2), L('x'))
self.assertEquals(generate_line(root), '3 * -2 * x')
def test_first_child_negation(self):
root = N('*', -L(1), L(2))
self.assertEquals(generate_line(root), '(-1)2')
root = -N('*', L(1), L(2))
self.assertEquals(generate_line(root), '-1 * 2')
root = N('/', -L(1), L(2))
self.assertEquals(generate_line(root), '(-1) / 2')
root = -N('/', L(1), L(2))
self.assertEquals(generate_line(root), '-1 / 2')
def test_postfix_brackets(self):
root = N('*', L('x'), N("'", N('[]', N('^', L('x'), L(2)))))
self.assertEquals(generate_line(root), "x[x ^ 2]'")
def test_custom_line(self):
root = N('*', L(1), L(2))
root.custom_line = lambda: 'test'
self.assertEquals(generate_line(root), 'test')
def test_preprocess_str_exp(self):
root = N('-', L(1))
def addbrackets(self): self[0] = N('[]', self[0])
root.preprocess_str_exp = new.instancemethod(addbrackets, root)
self.assertEquals(generate_line(root), '-[1]')
......@@ -51,3 +51,6 @@ class TestNode(unittest.TestCase):
self.assertEqual(Node('+', l1, l2, negated=1).negated, 1)
self.assertEqual(Leaf(1, negated=2).negated, 2)
def test_negated_int_constructor(self):
self.assertEquals(-Leaf(2), Leaf(-2))
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment