Commit 21f0710c authored by Taddeus Kroes's avatar Taddeus Kroes

Improved power multiplication and negation printing.

parent 7ec4d99c
...@@ -43,6 +43,10 @@ def is_int(node): ...@@ -43,6 +43,10 @@ def is_int(node):
return node.is_leaf and node.title().isdigit() return node.is_leaf and node.title().isdigit()
def is_power(node):
return not node.is_leaf and node.title() == '^'
def generate_line(root): def generate_line(root):
""" """
Print an expression tree in a single text line. Where needed, add Print an expression tree in a single text line. Where needed, add
...@@ -132,10 +136,14 @@ def generate_line(root): ...@@ -132,10 +136,14 @@ def generate_line(root):
# If so, add parentheses # If so, add parentheses
child_pred = pred(child) child_pred = pred(child)
if not child.negated and (child_pred < node_pred \ if child.negated:
# (-a) ^ b
if op == '^' and not i:
exp = '(' + exp + ')'
elif child_pred < node_pred \
or (i and child_pred == node_pred \ or (i and child_pred == node_pred \
and (op != child.title() \ and (op != child.title() \
or (op == '+' and child[1].negated)))): or (op == '+' and child[1].negated))):
exp = '(' + exp + ')' exp = '(' + exp + ')'
e.append(exp) e.append(exp)
...@@ -158,13 +166,18 @@ def generate_line(root): ...@@ -158,13 +166,18 @@ def generate_line(root):
left_paren = e[0][-1] == ')' left_paren = e[0][-1] == ')'
right_paren = e[1][0] == '(' right_paren = e[1][0] == '('
if (is_id(left) or left_paren or is_int(left)) \ left_suited = left_paren or is_id(left) or is_int(left)
and ((not right.negated and is_id(right)) or right_paren): right_suited = right_paren
if not right_suited and not right.negated:
right_suited = is_id(right) or is_power(right)
if left_suited and right_suited:
sep = '' sep = ''
exp = sep.join(e) exp = sep.join(e)
if node.negated and op not in ('*', '/'): if node.negated and op not in ('*', '/', '^'):
exp = '(' + exp + ')' exp = '(' + exp + ')'
return exp return exp
......
...@@ -3,7 +3,7 @@ import doctest ...@@ -3,7 +3,7 @@ import doctest
import line import line
from node import Node as N, Leaf as L from node import Node as N, Leaf as L
from line import generate_line, is_id, is_int from line import generate_line, is_id, is_int, is_power
class TestLine(unittest.TestCase): class TestLine(unittest.TestCase):
...@@ -105,6 +105,13 @@ class TestLine(unittest.TestCase): ...@@ -105,6 +105,13 @@ class TestLine(unittest.TestCase):
node_mul = N('*', a, node_pow) node_mul = N('*', a, node_pow)
self.assertEquals(generate_line(node_mul), 'a(b + c) ^ (d + e)') self.assertEquals(generate_line(node_mul), 'a(b + c) ^ (d + e)')
def test_pow_negated_root(self):
a, l2 = L('a'), L(2)
power = -N('^', a, l2)
self.assertEquals(generate_line(power), '-a ^ 2')
power = N('^', -a, l2)
self.assertEquals(generate_line(power), '(-a) ^ 2')
def test_multiplication_sign(self): def test_multiplication_sign(self):
a, b, c, l2 = L('a'), L('b'), L('c'), L(2) a, b, c, l2 = L('a'), L('b'), L('c'), L(2)
mul = N('*', a, b) mul = N('*', a, b)
...@@ -136,7 +143,7 @@ class TestLine(unittest.TestCase): ...@@ -136,7 +143,7 @@ class TestLine(unittest.TestCase):
self.assertEquals(generate_line(mul), 'a * 2') self.assertEquals(generate_line(mul), 'a * 2')
mul = N('*', l2, N('^', a, l2)) mul = N('*', l2, N('^', a, l2))
self.assertEquals(generate_line(mul), '2 * a ^ 2') self.assertEquals(generate_line(mul), '2a ^ 2')
def test_plus_to_minus(self): def test_plus_to_minus(self):
plus = N('+', L(1), -L(2)) plus = N('+', L(1), -L(2))
...@@ -151,6 +158,7 @@ class TestLine(unittest.TestCase): ...@@ -151,6 +158,7 @@ class TestLine(unittest.TestCase):
neg = -l1 neg = -l1
neg_a = -a neg_a = -a
plus = N('+', l1, a) plus = N('+', l1, a)
power = N('^', a, l1)
self.assertTrue(is_id(a)) self.assertTrue(is_id(a))
self.assertTrue(is_id(neg_a)) self.assertTrue(is_id(neg_a))
...@@ -162,6 +170,10 @@ class TestLine(unittest.TestCase): ...@@ -162,6 +170,10 @@ class TestLine(unittest.TestCase):
self.assertFalse(is_int(neg_a)) self.assertFalse(is_int(neg_a))
self.assertFalse(is_int(plus)) self.assertFalse(is_int(plus))
self.assertTrue(is_power(power))
self.assertFalse(is_power(l1))
self.assertFalse(is_power(plus))
def test_negated_operator(self): def test_negated_operator(self):
neg = -N('+', L(1), L(2)) neg = -N('+', L(1), L(2))
self.assertEquals(generate_line(neg), '-(1 + 2)') self.assertEquals(generate_line(neg), '-(1 + 2)')
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment