Commit e17779f2 authored by Taddeüs Kroes's avatar Taddeüs Kroes

Reimplemented line printer using precedences in accordance with the parser.

- All special cases for negation have been removed, except when they are also
  present in the parser.
- Some callbacks have been added to be able to customize output for certain
  operators in an extension of the Node class. This prevents chaotic sceneries
  in the graph_drawing library.
- Some useful unit tests have been added to test the precedence- and
  callback-meschanisms.
parent 107e93a1
from traverse import traverse_depth_first from traverse import traverse_depth_first
from node import Node
all_parens = ('()', '[]', '||', '{}')
OPERATORS = ( OPERATORS = (
('vv', ), ('left', ('vv', )),
('^^', ), ('left', ('^^', )),
('=', ), ('left', ('=', )),
('+', '-'), ('left', ('+', '-')),
('*', 'mod'), ('nonassoc', ('int', 'd/d')),
('/', ), ('left', ('*', 'mod')),
('^', '_'), ('left', ('/', )),
('nonassoc', ('\'', )),
('nonassoc', ('neg', )),
('nonassoc', ('function', )),
('right', ('^', '_')),
('nonassoc', all_parens),
) )
NEG_PRED = 3 assocs = {}
preds = {}
for i, (assoc, ops) in enumerate(OPERATORS):
for op in ops:
assocs[op] = assoc
preds[op] = i
NEG_PRED = preds['neg']
FUNC_PRED = preds['function']
MAX_PRED = len(OPERATORS) MAX_PRED = len(OPERATORS)
def is_operator(node): def is_function(node):
""" """
Check if a given node is an operator (otherwise, it's a function). Check if a given node is a function. A node is considered a function if it
is not a leaf, and is not known in the precedences list above.
""" """
label = node.title() return not node.is_leaf and node.title() not in preds
return any(map(lambda x: label in x, OPERATORS))
def pred(node): def pred(node):
""" """
Get the precedence of an operator node. Get the operator precedence of a node. Leaf nodes have the highest
precedence for practical reasons.
""" """
# Check binary and n-ary operators # Check known operators
if not node.is_leaf and len(node) > 1: if not node.is_leaf:
op = node.title() op = node.title()
for i, group in enumerate(OPERATORS): if node.is_negation():
if op in group: if node[0].title() in '*/':
return i return preds['-']
return NEG_PRED
#if node.is_postfix() and not node[0].is_leaf \
# and node[0].title() in all_parens:
# return preds['()']
if op in preds:
return preds[op]
return FUNC_PRED
# Unary operator and leaves have highest precedence # Unary operator and leaves have highest precedence
return MAX_PRED return MAX_PRED
def rightmost_node(node):
if node.is_leaf or not len(node):
return node
return rightmost_node(node[-1])
def is_unary_prefix(node):
"""
Check if a node is a unary operator that is placed before the operand.
"""
return not node.is_leaf and len(node) == 1 and node.title() == '-'
def is_left_assoc(op):
return op in assocs and assocs[op] == 'left'
def is_right_assoc(op):
return op in assocs and assocs[op] == 'right'
def is_id(node): def is_id(node):
return node.is_leaf and not node.title().isdigit() return node.is_leaf and not node.title().isdigit()
...@@ -52,6 +101,24 @@ def is_power(node): ...@@ -52,6 +101,24 @@ def is_power(node):
return not node.is_leaf and node.title() == '^' return not node.is_leaf and node.title() == '^'
def preprocess_node(node):
node = node.clone()
node.preprocess_str_exp()
if node.negated:
node.negated -= 1
return Node('-', preprocess_node(node))
if not node.is_leaf:
for i, child in enumerate(node):
node[i] = preprocess_node(child)
if node.title() == '+' and node[1].is_negation():
return Node('-', node[0], node[1][0])
return node
def generate_line(root): def generate_line(root):
""" """
Print an expression tree in a single text line. Where needed, add Print an expression tree in a single text line. Where needed, add
...@@ -77,13 +144,6 @@ def generate_line(root): ...@@ -77,13 +144,6 @@ def generate_line(root):
>>> print generate_line(times) >>> print generate_line(times)
(1 + 2) * -3 (1 + 2) * -3
>>> exp = Leaf('x')
>>> inf = Leaf('oo')
>>> minus_inf = Node('-', inf)
>>> integral = Node('int', exp, minus_inf, inf)
>>> print generate_line(integral)
int(x, -oo, oo)
>>> minus = Node('-', Leaf(2), Node('-', Node('*', Leaf(15), Leaf('x')))) >>> minus = Node('-', Leaf(2), Node('-', Node('*', Leaf(15), Leaf('x'))))
>>> print generate_line(minus) >>> print generate_line(minus)
2 - -15x 2 - -15x
...@@ -107,124 +167,175 @@ def generate_line(root): ...@@ -107,124 +167,175 @@ def generate_line(root):
content = {} content = {}
def mult_sign(left, right, lparens, rparens):
# Get the previous multiplication element in an nary multiplication
if left.title() == '*':
left = rightmost_node(left)
# a * b -> ab
# a * 2 -> a * 2
# a * (b) -> a(b)
# (a) * b -> (a)b
# (a) * (b) -> (a)(b)
# 2 * a -> 2a
# a * sin(b) -> a sin(b)
left_char = content[left][-1]
right_char = content[right][0]
left_paren = lparens or left_char in ')]}'
right_paren = rparens or right_char in '([{'
right_alpha = right_char.isalpha()
left_simple = is_id(left) or is_int(left)
if left_paren or (right_paren and left_simple) \
or (is_id(left) and is_id(right)) \
or (is_int(left) and right_alpha):
return ''
if is_id(left) and right_alpha:
return ' '
return ' * '
def construct_unary(node): def construct_unary(node):
op = node.title() op = node.title()
value = node[0] value = node[0]
strval = content[value]
# -a if op in ('()', '[]', '||', '{}'):
# -3 * 4 return op[0] + strval + op[1]
# --a
if value.is_leaf \
or not ' ' in content[value] or pred(value) > NEG_PRED:
return op + content[value]
# -(a + b) parens = False
return '%s(%s)' % (op, content[value])
if pred(value) < pred(node):
parens = not value.is_negation()
elif pred(value) == pred(node):
parens = len(value) > 1
def construct_nary(node): if parens:
strval = '(' + strval + ')'
if node.is_postfix():
return strval + node.operator()
prefix = node.operator()
if prefix != '-' and not (strval[0] in '([|{' and is_function(node)):
prefix += ' '
return prefix + strval
def construct_binary(node):
op = node.title() op = node.title()
op_pred = pred(node)
# N-ary operator if node.no_spacing:
node_pred = pred(node) sep = node.operator()
sep = ' ' + op + ' ' else:
e = [] sep = ' ' + node.operator() + ' '
for i, child in enumerate(node): left, right = node
exp = content[child] lstr = content[left]
rstr = content[right]
lpred = pred(left)
rpred = pred(right)
lparens = rparens = False
unary_right = is_unary_prefix(right)
#if i and op == '+' and exp[:2] == '-(': if lpred < op_pred or (op in '*/' and left.is_negation()):
# exp = '-' + exp[2:-1] lparens = True
# print 'exp:', exp elif lpred == op_pred:
lparens = is_right_assoc(left.title()) or is_right_assoc(op)
# Check if there is a precedence conflict
# If so, add parentheses
child_pred = pred(child)
if child.negated:
# (-a) ^ b
# -a ^ -b
# (-a) * b
# a * -b
# (-a) / b
if node_pred > NEG_PRED:
exp = '(' + exp + ')'
elif child_pred < node_pred:
exp = '(' + exp + ')'
elif child_pred == node_pred:
if i and (op != child.title() or op == '/' \
or (op == '+' and child[1].negated)):
exp = '(' + exp + ')'
elif not i and op == '^':
exp = '(' + exp + ')'
e.append(exp) if rpred < op_pred:
rparens = not unary_right
elif rpred == op_pred and len(right) > 1:
if right.title() == op:
rparens = not is_right_assoc(op)
elif is_left_assoc(right.title()):
rparens = True
if op == '*': # Check if multiplication sig is necessary
# Check if an explicit multiplication sign is nessecary if op == '*' and not unary_right:
left, right = node sep = mult_sign(left, right, lparens, rparens)
# Get the previous multiplication element if the arity is if lparens:
# greater than 2 lstr = '(' + lstr + ')'
if left.title() == '*':
left = left[1]
# a * b -> ab if rparens:
# a * 2 -> a * 2 rstr = '(' + rstr + ')'
# a * (b) -> a(b)
# (a) * b -> (a)b
# (a) * (b) -> (a)(b)
# 2 * a -> 2a
l = e[0][-1]
r = e[1][0]
left_simple = is_id(left) or is_int(left)
if (r in ('(', '[') and left_simple) or (l == ')' and r != '-') \ return lstr + sep + rstr
or (left_simple and r.isalpha()):
sep = ''
exp = sep.join(e) def construct_nary_mult(node):
op_pred = pred(node)
lstr = content[node[0]]
lparens = pred(node[0]) < op_pred or node[0].is_negation()
if node.negated and op not in ('*', '/', '^'): if lparens:
exp = '(' + exp + ')' lstr = '(' + lstr + ')'
return exp for i, right in enumerate(node[1:]):
rparens = pred(right) < op_pred
rstr = content[right]
def construct_function(node): if rparens:
buf = [] rstr = '(' + rstr + ')'
sign = mult_sign(node[i], right, lparens, rparens)
lstr += sign + rstr
lparens = rparens
return lstr
def construct_nary(node):
if node.title() == '*':
return construct_nary_mult(node)
op_pred = pred(node)
e = []
for child in node: for child in node:
buf.append(content[child]) exp = content[child]
if pred(child) < op_pred:
exp = '(' + exp + ')'
return '%s(%s)' % (node.title(), ', '.join(buf)) e.append(exp)
return (' ' + node.operator() + ' ').join(e)
def construct_function(node):
children = [content[child] for child in node]
return '%s(%s)' % (node.operator(), ', '.join(children))
# Convert negations to unary nodes to be able to account for operator
# precedence
root = preprocess_node(root.clone())
# Traverse the expression tree and construct the mathematical expression in # Traverse the expression tree and construct the mathematical expression in
# the leafs and nodes in depth first order. # the leafs and nodes in depth first order.
for node in traverse_depth_first(root): for node in traverse_depth_first(root):
custom = node.custom_line()
if custom is not None:
content[node] = custom
continue
if node.is_leaf: if node.is_leaf:
content[node] = str(node) nodestr = str(node.value)
else: else:
arity = len(node) arity = node.arity()
if is_operator(node): if arity == 1:
if arity == 1: nodestr = construct_unary(node)
content[node] = construct_unary(node) elif is_function(node):
else: nodestr = construct_function(node)
content[node] = construct_nary(node) elif arity == 2:
nodestr = construct_binary(node)
else: else:
result = None nodestr = construct_nary(node)
if hasattr(node, 'construct_function'):
children = [content[c] for c in node]
result = node.construct_function(children)
if result == None:
result = construct_function(node)
content[node] = result
# Add negations content[node] = node.postprocess_str(nodestr)
content[node] = '-' * node.negated + content[node]
# Merge binary plus and unary minus signs into binary minus. # Merge binary plus and unary minus signs into a binary minus
return content[root].replace('+ -', '- ') return content[root].replace('+ -', '- ')
...@@ -7,6 +7,7 @@ class Node(object): ...@@ -7,6 +7,7 @@ class Node(object):
super(Node, self).__init__() super(Node, self).__init__()
self.value, self.nodes = value, list(nodes) self.value, self.nodes = value, list(nodes)
self.is_leaf = False self.is_leaf = False
self.no_spacing = kwargs.get('no_spacing', False)
self.negated = kwargs.get('negated', 0) self.negated = kwargs.get('negated', 0)
def __getitem__(self, n): def __getitem__(self, n):
...@@ -26,11 +27,17 @@ class Node(object): ...@@ -26,11 +27,17 @@ class Node(object):
and self.nodes == node.nodes and self.nodes == node.nodes
def __neg__(self): def __neg__(self):
copied = deepcopy(self) copied = self.clone()
copied.negated += 1 copied.negated += 1
return copied return copied
def __pos__(self):
copied = self.clone()
copied.negated = max(copied.negated - 1, 0)
return copied
def __str__(self): def __str__(self):
return '<Node value=%s nodes=%s negated=%d>' \ return '<Node value=%s nodes=%s negated=%d>' \
% (str(self.value), str(self.nodes), self.negated) % (str(self.value), str(self.nodes), self.negated)
...@@ -41,11 +48,39 @@ class Node(object): ...@@ -41,11 +48,39 @@ class Node(object):
def title(self): def title(self):
return str(self.value) return str(self.value)
def operator(self):
return self.value
def clone(self):
return deepcopy(self)
def arity(self):
return len(self)
def is_postfix(self):
return self.value == '\''
def is_negation(self):
return self.value == '-' and len(self) == 1
def custom_line(self):
pass
def preprocess_str_exp(self):
pass
def postprocess_str(self, string):
return string
class Leaf(Node): class Leaf(Node):
def __init__(self, value, **kwargs): def __init__(self, value, **kwargs):
super(Leaf, self).__init__(value, **kwargs) super(Leaf, self).__init__(value, **kwargs)
self.value = value
if type(value) in (int, float) and value < 0:
self.value = abs(value)
self.negated += 1
self.nodes = None self.nodes = None
self.is_leaf = True self.is_leaf = True
......
import unittest import unittest
import doctest import doctest
import new
import line import line
from node import Node as N, Leaf as L from node import Node as N, Leaf as L
...@@ -38,6 +39,8 @@ class TestLine(unittest.TestCase): ...@@ -38,6 +39,8 @@ class TestLine(unittest.TestCase):
minus = N('-', l0, plus) minus = N('-', l0, plus)
self.assertEquals(generate_line(minus), '1 - (2 + 3)') self.assertEquals(generate_line(minus), '1 - (2 + 3)')
power = N('^', l0, N('_', l1, l2))
self.assertEquals(generate_line(power), '1 ^ 2 _ 3')
power = N('^', l0, N('^', l1, l2)) power = N('^', l0, N('^', l1, l2))
self.assertEquals(generate_line(power), '1 ^ 2 ^ 3') self.assertEquals(generate_line(power), '1 ^ 2 ^ 3')
power = N('^', N('^', l0, l1), l2) power = N('^', N('^', l0, l1), l2)
...@@ -63,11 +66,8 @@ class TestLine(unittest.TestCase): ...@@ -63,11 +66,8 @@ class TestLine(unittest.TestCase):
self.assertEquals(generate_line(plus), '1 + 2 + 3') self.assertEquals(generate_line(plus), '1 + 2 + 3')
def test_function(self): def test_function(self):
exp = L('x') sin = N('sin', N('*', L(2), L('x')))
inf = L('oo') self.assertEquals(generate_line(sin), 'sin(2x)')
minus_inf = -L('oo')
integral = N('int', exp, minus_inf, inf)
self.assertEquals(generate_line(integral), 'int(x, -oo, oo)')
def test_mod(self): def test_mod(self):
l0, l1 = L(1), L(2) l0, l1 = L(1), L(2)
...@@ -77,7 +77,7 @@ class TestLine(unittest.TestCase): ...@@ -77,7 +77,7 @@ class TestLine(unittest.TestCase):
def test_multiplication_identifiers(self): def test_multiplication_identifiers(self):
a, b = L('a'), L('b') a, b = L('a'), L('b')
self.assertEquals(generate_line(N('*', a, b)), 'ab') self.assertEquals(generate_line(N('*', a, b)), 'ab')
self.assertEquals(generate_line(N('*', a, -b)), 'a(-b)') self.assertEquals(generate_line(N('*', a, -b)), 'a * -b')
def test_multiplication_constant_identifier(self): def test_multiplication_constant_identifier(self):
l0, a = L(2), L('a') l0, a = L(2), L('a')
...@@ -206,8 +206,14 @@ class TestLine(unittest.TestCase): ...@@ -206,8 +206,14 @@ class TestLine(unittest.TestCase):
neg = -N('-', L(1), L(2)) neg = -N('-', L(1), L(2))
self.assertEquals(generate_line(neg), '-(1 - 2)') self.assertEquals(generate_line(neg), '-(1 - 2)')
# FIXME: neg = N('+', L(1), N('+', L(1), L(2)))
# FIXME: self.assertEquals(generate_line(neg), '1 + 1 + 2')
neg = N('+', N('+', L(1), L(2)), L(3))
self.assertEquals(generate_line(neg), '1 + 2 + 3')
neg = N('+', L(1), N('+', L(1), L(2))) neg = N('+', L(1), N('+', L(1), L(2)))
self.assertEquals(generate_line(neg), '1 + 1 + 2') self.assertEquals(generate_line(neg), '1 + (1 + 2)')
neg = N('+', L(1), -N('+', L(1), L(2))) neg = N('+', L(1), -N('+', L(1), L(2)))
self.assertEquals(generate_line(neg), '1 - (1 + 2)') self.assertEquals(generate_line(neg), '1 - (1 + 2)')
...@@ -219,7 +225,7 @@ class TestLine(unittest.TestCase): ...@@ -219,7 +225,7 @@ class TestLine(unittest.TestCase):
self.assertEquals(generate_line(neg), '-4a') self.assertEquals(generate_line(neg), '-4a')
neg = N('*', L(4), -L('a')) neg = N('*', L(4), -L('a'))
self.assertEquals(generate_line(neg), '4(-a)') self.assertEquals(generate_line(neg), '4 * -a')
neg = -N('*', L(4), L(5)) neg = -N('*', L(4), L(5))
self.assertEquals(generate_line(neg), '-4 * 5') self.assertEquals(generate_line(neg), '-4 * 5')
...@@ -234,11 +240,15 @@ class TestLine(unittest.TestCase): ...@@ -234,11 +240,15 @@ class TestLine(unittest.TestCase):
self.assertEquals(generate_line(plus), 'a / b - c / d') self.assertEquals(generate_line(plus), 'a / b - c / d')
mul = N('*', N('+', L('a'), L('b')), -N('+', L('c'), L('d'))) mul = N('*', N('+', L('a'), L('b')), -N('+', L('c'), L('d')))
self.assertEquals(generate_line(mul), '(a + b)(-(c + d))') self.assertEquals(generate_line(mul), '(a + b) * -(c + d)')
def test_double_negation(self): def test_double_negation(self):
neg = --L(1) neg = --L(1)
self.assertEquals(generate_line(neg), '--1') self.assertEquals(generate_line(neg), '--1')
neg = --N('*', L('x'), L(2))
self.assertEquals(generate_line(neg), '--x * 2')
neg = --N('^', L('x'), L(2))
self.assertEquals(generate_line(neg), '--x ^ 2')
def test_divide_fractions(self): def test_divide_fractions(self):
a, b, c, d = L('a'), L('b'), L('c'), L('d') a, b, c, d = L('a'), L('b'), L('c'), L('d')
...@@ -246,3 +256,85 @@ class TestLine(unittest.TestCase): ...@@ -246,3 +256,85 @@ class TestLine(unittest.TestCase):
self.assertEquals(generate_line(div), 'a / (b / c)') self.assertEquals(generate_line(div), 'a / (b / c)')
div = N('/', N('/', a, b), N('/', c, d)) div = N('/', N('/', a, b), N('/', c, d))
self.assertEquals(generate_line(div), 'a / b / (c / d)') self.assertEquals(generate_line(div), 'a / b / (c / d)')
def test_prime(self):
a, b, c, d = L('a'), L('b'), L('c'), L('d')
root = N('*', a, N("'", b))
self.assertEquals(generate_line(root), "a b'")
root = N("'", -a)
self.assertEquals(generate_line(root), "-a'")
root = -N("'", a)
self.assertEquals(generate_line(root), "-(a')")
root = N("'", N('*', a, b))
self.assertEquals(generate_line(root), "(ab)'")
root = N("'", N('/', a, b))
def test_function(self):
root = N('sin', L('x'))
self.assertEquals(generate_line(root), 'sin x')
root = N('sin', N('+', L('x'), L(2)))
self.assertEquals(generate_line(root), 'sin(x + 2)')
root = N('dummyfunc', L('x'), L(2))
self.assertEquals(generate_line(root), 'dummyfunc(x, 2)')
def test_no_spacing(self):
root = N('+', L('x'), L(2), no_spacing=True)
self.assertEquals(generate_line(root), 'x+2')
def test_explicit_parentheses(self):
root = N('[]', L('x'))
self.assertEquals(generate_line(root), '[x]')
root = N('()', L('x'))
self.assertEquals(generate_line(root), '(x)')
root = N('{}', L('x'))
self.assertEquals(generate_line(root), '{x}')
root = N('^', N('[]', N('+', L('x'), L('y'))), L(2))
self.assertEquals(generate_line(root), '[x + y] ^ 2')
def test_abs(self):
root = N('||', L('x'))
self.assertEquals(generate_line(root), '|x|')
root = N('||', N('+', L('x'), L(1)))
self.assertEquals(generate_line(root), '|x + 1|')
root = N('ln', N('||', L('x')))
self.assertEquals(generate_line(root), 'ln|x|')
def test_postprocess_str(self):
root = N('int', N('^', L('x'), L(2)), L('x'))
root.arity = lambda: 1
root.postprocess_str = lambda s: s + ' dx'
self.assertEquals(generate_line(root), 'int x ^ 2 dx')
def test_concat_with_negation(self):
root = N('*', -L(2), L('x'))
self.assertEquals(generate_line(root), '(-2)x')
root = N('*', N('*', L(3), -L(2)), L('x'))
self.assertEquals(generate_line(root), '3 * -2x')
root = N('*', L(3), -L(2), L('x'))
self.assertEquals(generate_line(root), '3 * -2 * x')
def test_first_child_negation(self):
root = N('*', -L(1), L(2))
self.assertEquals(generate_line(root), '(-1)2')
root = -N('*', L(1), L(2))
self.assertEquals(generate_line(root), '-1 * 2')
root = N('/', -L(1), L(2))
self.assertEquals(generate_line(root), '(-1) / 2')
root = -N('/', L(1), L(2))
self.assertEquals(generate_line(root), '-1 / 2')
def test_postfix_brackets(self):
root = N('*', L('x'), N("'", N('[]', N('^', L('x'), L(2)))))
self.assertEquals(generate_line(root), "x[x ^ 2]'")
def test_custom_line(self):
root = N('*', L(1), L(2))
root.custom_line = lambda: 'test'
self.assertEquals(generate_line(root), 'test')
def test_preprocess_str_exp(self):
root = N('-', L(1))
def addbrackets(self): self[0] = N('[]', self[0])
root.preprocess_str_exp = new.instancemethod(addbrackets, root)
self.assertEquals(generate_line(root), '-[1]')
...@@ -51,3 +51,6 @@ class TestNode(unittest.TestCase): ...@@ -51,3 +51,6 @@ class TestNode(unittest.TestCase):
self.assertEqual(Node('+', l1, l2, negated=1).negated, 1) self.assertEqual(Node('+', l1, l2, negated=1).negated, 1)
self.assertEqual(Leaf(1, negated=2).negated, 2) self.assertEqual(Leaf(1, negated=2).negated, 2)
def test_negated_int_constructor(self):
self.assertEquals(-Leaf(2), Leaf(-2))
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment