Ver código fonte

Converted negation to counter instead of unary node.

Taddeus Kroes 14 anos atrás
pai
commit
84ad376b81
3 arquivos alterados com 41 adições e 45 exclusões
  1. 19 24
      line.py
  2. 7 2
      node.py
  3. 15 19
      tests/test_line.py

+ 19 - 24
line.py

@@ -1,6 +1,7 @@
 from node import Leaf
 from traverse import traverse_depth_first
 
+
 OPERATORS = [
         ('+', '-'),
         ('*', '/', 'mod'),
@@ -167,24 +168,11 @@ def generate_line_old(root):
     return traverse(root).replace('+ -', '- ')
 
 
-def is_negation(node):
-    if node.is_leaf:
-        return False
-
-    return node.title() == '-' and len(node) == 1
-
-
 def is_id(node):
-    if is_negation(node):
-        return is_id(node[0])
-
     return node.is_leaf and not node.title().isdigit()
 
 
 def is_int(node):
-    if is_negation(node):
-        return is_int(node[0])
-
     return node.is_leaf and node.title().isdigit()
 
 
@@ -236,7 +224,7 @@ def generate_line(root):
         return '<empty expression>'
 
     if root.is_leaf:
-        return root.title()
+        return '-' * root.negated + root.title()
 
     content = {}
 
@@ -298,7 +286,12 @@ def generate_line(root):
                     and (is_id(right) or right_paren):
                 sep = ''
 
-        return sep.join(e)
+        exp = sep.join(e)
+
+        if node.negated and op != '*':
+            exp = '(' + exp + ')'
+
+        return exp
 
     def construct_function(node):
         buf = []
@@ -313,17 +306,19 @@ def generate_line(root):
     for node in traverse_depth_first(root):
         if node.is_leaf:
             content[node] = node.title()
-            continue
-
-        arity = len(node)
+        else:
+            arity = len(node)
 
-        if is_operator(node):
-            if arity == 1:
-                content[node] = construct_unary(node)
+            if is_operator(node):
+                if arity == 1:
+                    content[node] = construct_unary(node)
+                else:
+                    content[node] = construct_nary(node)
             else:
-                content[node] = construct_nary(node)
-        else:
-            content[node] = construct_function(node)
+                content[node] = construct_function(node)
+
+        # Add negations
+        content[node] = '-' * node.negated + content[node]
 
     # Merge binary plus and unary minus signs into binary minus.
     return content[root].replace('+ -', '- ')

+ 7 - 2
node.py

@@ -1,4 +1,5 @@
 # vim: set fileencoding=utf-8 :
+from copy import deepcopy
 
 
 class Node(object):
@@ -6,6 +7,7 @@ class Node(object):
         super(Node, self).__init__()
         self.value, self.nodes = value, list(nodes)
         self.is_leaf = False
+        self.negated = 0
 
     def __getitem__(self, n):
         return self.nodes[n]
@@ -23,8 +25,11 @@ class Node(object):
         return isinstance(node, Node) \
                and self.value == node.value and self.nodes == node.nodes
 
-    #def __repr__(self):
-    #    return repr(self.value)
+    def __neg__(self):
+        copy = deepcopy(self)
+        copy.negated += 1
+
+        return copy
 
     def __str__(self):
         return '<Node value=%s nodes=%s>' % (str(self.value), str(self.nodes))

+ 15 - 19
tests/test_line.py

@@ -1,7 +1,7 @@
 import unittest
 
 from node import Node as N, Leaf as L
-from line import generate_line, is_negation, is_id, is_int
+from line import generate_line, is_id, is_int
 
 
 class TestLine(unittest.TestCase):
@@ -39,7 +39,7 @@ class TestLine(unittest.TestCase):
     def test_function(self):
         exp = L('x')
         inf = L('oo')
-        minus_inf = N('-', inf)
+        minus_inf = -L('oo')
         integral = N('int', exp, minus_inf, inf)
         self.assertEquals(generate_line(integral), 'int(x, -oo, oo)')
 
@@ -105,7 +105,7 @@ class TestLine(unittest.TestCase):
         mul = N('*', N('*', a, l2), b)
         self.assertEquals(generate_line(mul), 'a * 2b')
 
-        plus = N('*', N('*', N('-', a), b), c)
+        plus = N('*', N('*', -a, b), c)
         self.assertEquals(generate_line(plus), '-abc')
 
         mul = N('*', a, N('-', b, c))
@@ -127,23 +127,19 @@ class TestLine(unittest.TestCase):
         self.assertEquals(generate_line(mul), 'a * 2')
 
     def test_plus_to_minus(self):
-        plus = N('+', L(1), N('-', L(2)))
+        plus = N('+', L(1), -L(2))
         self.assertEquals(generate_line(plus), '1 - 2')
 
         l1, a, b, c = L(1), L('a'), L('b'), L('c')
-        plus = N('+', l1, N('*', N('*', N('-', a), b), c))
+        plus = N('+', l1, N('*', N('*', -a, b), c))
         self.assertEquals(generate_line(plus), '1 - abc')
 
     def test_helper_functions(self):
         l1, a = L(1), L('a')
-        neg = N('-', l1)
-        neg_a = N('-', a)
+        neg = -l1
+        neg_a = -a
         plus = N('+', l1, a)
 
-        self.assertTrue(is_negation(neg))
-        self.assertFalse(is_negation(l1))
-        self.assertFalse(is_negation(plus))
-
         self.assertTrue(is_id(a))
         self.assertTrue(is_id(neg_a))
         self.assertFalse(is_id(neg))
@@ -155,26 +151,26 @@ class TestLine(unittest.TestCase):
         self.assertFalse(is_int(plus))
 
     def test_negated_addition_subtraction(self):
-        neg = N('-', N('+', L(1), L(2)))
+        neg = -N('+', L(1), L(2))
         self.assertEquals(generate_line(neg), '-(1 + 2)')
 
-        neg = N('-', N('-', L(1), L(2)))
+        neg = -N('-', L(1), L(2))
         self.assertEquals(generate_line(neg), '-(1 - 2)')
 
-        neg = N('+', L(1), N('-', N('+', L(1), L(2))))
+        neg = N('+', L(1), -N('+', L(1), L(2)))
         self.assertEquals(generate_line(neg), '1 - (1 + 2)')
 
-        neg = N('-', N('*', L(4), L('a')))
+        neg = -N('*', L(4), L('a'))
         self.assertEquals(generate_line(neg), '-4a')
 
-        neg = N('-', N('*', L(4), L(5)))
+        neg = -N('*', L(4), L(5))
         self.assertEquals(generate_line(neg), '-4 * 5')
 
-        plus = N('+', L(1), N('-', N('*', L(4), L(5))))
+        plus = N('+', L(1), -N('*', L(4), L(5)))
         self.assertEquals(generate_line(plus), '1 - 4 * 5')
 
-        plus = N('+', L(1), N('-', L(4)))
+        plus = N('+', L(1), -L(4))
         self.assertEquals(generate_line(plus), '1 - 4')
 
-        neg = N('-', N('-', L(1)))
+        neg = --L(1)
         self.assertEquals(generate_line(neg), '--1')