Просмотр исходного кода

Implemented integral notation, still produces one shift/reduce conflict.

Taddeus Kroes 14 лет назад
Родитель
Сommit
cbd37711fd
6 измененных файлов с 126 добавлено и 73 удалено
  1. 5 4
      src/node.py
  2. 69 50
      src/parser.py
  3. 17 6
      src/rules/integrals.py
  4. 9 0
      src/rules/utils.py
  5. 15 2
      tests/test_parser.py
  6. 11 11
      tests/test_rules_integrals.py

+ 5 - 4
src/node.py

@@ -207,8 +207,8 @@ class ExpressionBase(object):
         return ExpressionNode(OP_ADD, self, to_expression(other))
         return ExpressionNode(OP_ADD, self, to_expression(other))
 
 
     def __sub__(self, other):
     def __sub__(self, other):
-        #FIXME: return ExpressionNode(OP_ADD, self, -to_expression(other))
-        return ExpressionNode(OP_SUB, self, to_expression(other))
+        return ExpressionNode(OP_ADD, self, -to_expression(other))
+        #FIXME: return ExpressionNode(OP_SUB, self, to_expression(other))
 
 
     def __mul__(self, other):
     def __mul__(self, other):
         return ExpressionNode(OP_MUL, self, to_expression(other))
         return ExpressionNode(OP_MUL, self, to_expression(other))
@@ -428,7 +428,7 @@ class ExpressionLeaf(Leaf, ExpressionBase):
         other_type = type(other)
         other_type = type(other)
 
 
         if other_type in TYPE_MAP:
         if other_type in TYPE_MAP:
-            return TYPE_MAP[other_type] == self.type \
+            return self.type == TYPE_MAP[other_type] \
                    and self.actual_value() == other
                    and self.actual_value() == other
 
 
         return self.negated == other.negated and self.type == other.type \
         return self.negated == other.negated and self.type == other.type \
@@ -466,7 +466,8 @@ class ExpressionLeaf(Leaf, ExpressionBase):
         return (ExpressionLeaf(1), self, ExpressionLeaf(1))
         return (ExpressionLeaf(1), self, ExpressionLeaf(1))
 
 
     def actual_value(self):
     def actual_value(self):
-        assert self.is_numeric()
+        if self.type == TYPE_IDENTIFIER:
+            return self.value
 
 
         return (1 - 2 * (self.negated & 1)) * self.value
         return (1 - 2 * (self.negated & 1)) * self.value
 
 

+ 69 - 50
src/parser.py

@@ -14,11 +14,13 @@ sys.path.insert(1, EXTERNAL_MODS)
 from pybison import BisonParser, BisonSyntaxError
 from pybison import BisonParser, BisonSyntaxError
 from graph_drawing.graph import generate_graph
 from graph_drawing.graph import generate_graph
 
 
-from node import ExpressionNode as Node, ExpressionLeaf as Leaf, OP_MAP, \
-        OP_DER, TOKEN_MAP, TYPE_OPERATOR, OP_COMMA, OP_NEG, OP_MUL, OP_DIV, \
-        OP_LOG, OP_ADD, Scope, E, DEFAULT_LOGARITHM_BASE, OP_VALUE_MAP, \
-        SPECIAL_TOKENS, OP_INT, OP_INT_INDEF
+from node import ExpressionBase, ExpressionNode as Node, \
+        ExpressionLeaf as Leaf, OP_MAP, OP_DER, TOKEN_MAP, TYPE_OPERATOR, \
+        OP_COMMA, OP_NEG, OP_MUL, OP_DIV, OP_POW, OP_LOG, OP_ADD, Scope, E, \
+        DEFAULT_LOGARITHM_BASE, OP_VALUE_MAP, SPECIAL_TOKENS, OP_INT, \
+        OP_INT_INDEF
 from rules import RULES
 from rules import RULES
+from rules.utils import find_variable
 from strategy import pick_suggestion
 from strategy import pick_suggestion
 from possibilities import filter_duplicates, apply_suggestion
 from possibilities import filter_duplicates, apply_suggestion
 
 
@@ -43,16 +45,17 @@ def combine(op, op_type, *nodes):
 
 
 def find_integration_variable(exp):
 def find_integration_variable(exp):
     if not exp.is_op(OP_MUL):
     if not exp.is_op(OP_MUL):
-        return exp
+        return exp, find_variable(exp)
 
 
     scope = Scope(exp)
     scope = Scope(exp)
 
 
-    if len(scope) < 3 or scope[-2] != 'd' or not scope[-1].is_identifier():
-        return exp
+    if len(scope) > 2 and scope[-2] == 'd' and scope[-1].is_identifier():
+        x = scope[-1]
+        scope.nodes = scope[:-2]
 
 
-    scope.nodes = scope[:-2]
+        return scope.as_nary_node(), x
 
 
-    return scope.as_nary_node()
+    return exp, find_variable(exp)
 
 
 
 
 class Parser(BisonParser):
 class Parser(BisonParser):
@@ -62,8 +65,8 @@ class Parser(BisonParser):
     """
     """
 
 
     # Words to be ignored by preprocessor
     # Words to be ignored by preprocessor
-    words = tuple(filter(lambda w: len(w) > 1, OP_MAP.iterkeys())) \
-             + ('raise', 'graph')+ tuple(SPECIAL_TOKENS)
+    words = tuple(filter(lambda w: w.isalpha(), OP_MAP.iterkeys())) \
+             + ('raise', 'graph') + tuple(SPECIAL_TOKENS)
 
 
     # Output directory of generated pybison files, including a trailing slash.
     # Output directory of generated pybison files, including a trailing slash.
     buildDirectory = PYBISON_BUILD + '/'
     buildDirectory = PYBISON_BUILD + '/'
@@ -75,7 +78,7 @@ class Parser(BisonParser):
     # of tokens of the lex script.
     # of tokens of the lex script.
     tokens = ['NUMBER', 'IDENTIFIER', 'NEWLINE', 'QUIT', 'RAISE', 'GRAPH',
     tokens = ['NUMBER', 'IDENTIFIER', 'NEWLINE', 'QUIT', 'RAISE', 'GRAPH',
               'LPAREN', 'RPAREN', 'FUNCTION', 'FUNCTION_LPAREN', 'LBRACKET',
               'LPAREN', 'RPAREN', 'FUNCTION', 'FUNCTION_LPAREN', 'LBRACKET',
-              'RBRACKET', 'APOSTROPH', 'DERIVATIVE', 'SUB'] \
+              'RBRACKET', 'PRIME', 'DERIVATIVE', 'SUB'] \
              + filter(lambda t: t != 'FUNCTION', TOKEN_MAP.values())
              + filter(lambda t: t != 'FUNCTION', TOKEN_MAP.values())
 
 
     # ------------------------------
     # ------------------------------
@@ -83,12 +86,14 @@ class Parser(BisonParser):
     # ------------------------------
     # ------------------------------
     precedences = (
     precedences = (
         ('left', ('COMMA', )),
         ('left', ('COMMA', )),
+        ('right', ('INTEGRAL', 'DERIVATIVE')),
         ('left', ('MINUS', 'PLUS')),
         ('left', ('MINUS', 'PLUS')),
         ('left', ('TIMES', 'DIVIDE')),
         ('left', ('TIMES', 'DIVIDE')),
-        ('right', ('FUNCTION', 'DERIVATIVE')),
+        ('right', ('FUNCTION', )),
         ('left', ('EQ', )),
         ('left', ('EQ', )),
         ('left', ('NEG', )),
         ('left', ('NEG', )),
-        ('right', ('POW', 'SUB')),
+        ('right', ('POW', )),
+        ('right', ('SUB', )),
         ('right', ('FUNCTION_LPAREN', )),
         ('right', ('FUNCTION_LPAREN', )),
         )
         )
 
 
@@ -228,7 +233,8 @@ class Parser(BisonParser):
         return data
         return data
 
 
     def hook_handler(self, target, option, names, values, retval):
     def hook_handler(self, target, option, names, values, retval):
-        if target in ['exp', 'line', 'input'] or not retval:
+        if target in ['exp', 'line', 'input'] \
+                or not isinstance(retval, ExpressionBase):
             return retval
             return retval
 
 
         if not retval.negated and retval.type != TYPE_OPERATOR:
         if not retval.negated and retval.type != TYPE_OPERATOR:
@@ -390,7 +396,8 @@ class Parser(BisonParser):
               | FUNCTION exp
               | FUNCTION exp
               | DERIVATIVE exp
               | DERIVATIVE exp
               | bracket_derivative
               | bracket_derivative
-              | integral
+              | INTEGRAL exp
+              | INTEGRAL bounds exp %prec INTEGRAL
         """
         """
 
 
         if option == 0:  # rule: NEG exp
         if option == 0:  # rule: NEG exp
@@ -407,6 +414,11 @@ class Parser(BisonParser):
         if option in (1, 2):  # rule: FUNCTION_LPAREN exp RPAREN | FUNCTION exp
         if option in (1, 2):  # rule: FUNCTION_LPAREN exp RPAREN | FUNCTION exp
             op = values[0].split(' ', 1)[0]
             op = values[0].split(' ', 1)[0]
 
 
+            if op == 'int':
+                fx, x = find_integration_variable(values[1])
+
+                return Node(OP_INT, fx, x)
+
             if op == 'ln':
             if op == 'ln':
                 return Node(OP_LOG, values[1], Leaf(E))
                 return Node(OP_LOG, values[1], Leaf(E))
 
 
@@ -432,54 +444,57 @@ class Parser(BisonParser):
             # DERIVATIVE looks like 'd/d*x*' -> extract the 'x'
             # DERIVATIVE looks like 'd/d*x*' -> extract the 'x'
             return Node(OP_DER, values[1], Leaf(values[0][-2]))
             return Node(OP_DER, values[1], Leaf(values[0][-2]))
 
 
-        if option in (4, 5):  # rule: bracket_derivative | integral
+        if option == 4:  # rule: bracket_derivative
             return values[0]
             return values[0]
 
 
+
+        if option == 5:  # rule: INTEGRAL exp
+            fx, x = find_integration_variable(values[1])
+
+            return Node(OP_INT, fx, x)
+
+        if option == 6:  # rule: INTEGRAL bounds exp
+            lbnd, ubnd = values[1]
+            fx, x = find_integration_variable(values[2])
+
+            return Node(OP_INT, fx, x, lbnd, ubnd)
+
         raise BisonSyntaxError('Unsupported option %d in target "%s".'
         raise BisonSyntaxError('Unsupported option %d in target "%s".'
                                % (option, target))  # pragma: nocover
                                % (option, target))  # pragma: nocover
 
 
-    def on_bracket_derivative(self, target, option, names, values):
+    def on_bounds(self, target, option, names, values):
         """
         """
-        bracket_derivative : LBRACKET exp RBRACKET APOSTROPH
-                           | bracket_derivative APOSTROPH
+        bounds : SUB power TIMES
         """
         """
 
 
-        if option == 0:  # rule: LBRACKET exp RBRACKET APOSTROPH
-            return Node(OP_DER, values[1])
-
-        if option == 1:  # rule: bracket_derivative APOSTROPH
-            return Node(OP_DER, values[0])
+        if option == 0:  # rule: SUB power
+            return values[1]
 
 
         raise BisonSyntaxError('Unsupported option %d in target "%s".'
         raise BisonSyntaxError('Unsupported option %d in target "%s".'
                                % (option, target))  # pragma: nocover
                                % (option, target))  # pragma: nocover
 
 
-    def on_integral(self, target, option, names, values):
+    def on_power(self, target, option, names, values):
         """
         """
-        integral : INTEGRAL exp
+        power : exp POW exp
         """
         """
-                 #| INTEGRAL SUB exp exp
-                 #| LBRACKET exp RBRACKET SUB exp exp
 
 
-        if option == 0:  # rule: INTEGRAL exp
-            fx, x = find_integration_variable(values[1])
-
-            return N(OP_INT, fx, x)
-
-        if option == 1:  # rule: INTEGRAL SUB exp exp
-            if not values[2].is_power():  # pragma: nocover
-                raise BisonSyntaxError('No upper bound specified in "%s".'
-                                       % values[2])
+        if option == 0:  # rule: exp POW exp
+            return values[0], values[2]
 
 
-            lbnd, ubnd = values[2]
-            fx, x = find_integration_variable(values[3])
+        raise BisonSyntaxError('Unsupported option %d in target "%s".'
+                               % (option, target))  # pragma: nocover
 
 
-            return N(OP_INT, fx, x, lbnd, ubnd)
+    def on_bracket_derivative(self, target, option, names, values):
+        """
+        bracket_derivative : LBRACKET exp RBRACKET PRIME
+                           | bracket_derivative PRIME
+        """
 
 
-        if option == 2:  # rule: LBRACKET exp RBRACKET SUB exp POWER exp
-            exp = values[1]
-            fx, x = find_integration_variable(values[1])
+        if option == 0:  # rule: LBRACKET exp RBRACKET PRIME
+            return Node(OP_DER, values[1])
 
 
-            return N(OP_INT_INDEF, fx, x, values[4], values[6])
+        if option == 1:  # rule: bracket_derivative PRIME
+            return Node(OP_DER, values[0])
 
 
         raise BisonSyntaxError('Unsupported option %d in target "%s".'
         raise BisonSyntaxError('Unsupported option %d in target "%s".'
                                % (option, target))  # pragma: nocover
                                % (option, target))  # pragma: nocover
@@ -489,15 +504,15 @@ class Parser(BisonParser):
         binary : exp PLUS exp
         binary : exp PLUS exp
                | exp TIMES exp
                | exp TIMES exp
                | exp DIVIDE exp
                | exp DIVIDE exp
-               | exp POW exp
                | exp EQ exp
                | exp EQ exp
                | exp MINUS exp
                | exp MINUS exp
+               | power
         """
         """
 
 
-        if 0 <= option < 5:  # rule: exp {PLUS,TIMES,DIVIDES,POW,EQ} exp
+        if 0 <= option < 4:  # rule: exp {PLUS,TIMES,DIVIDE,EQ} exp
             return Node(values[1], values[0], values[2])
             return Node(values[1], values[0], values[2])
 
 
-        if option == 5:  # rule: exp MINUS exp
+        if option == 4:  # rule: exp MINUS exp
             node = values[2]
             node = values[2]
 
 
             # Add negation to the left-most child
             # Add negation to the left-most child
@@ -508,10 +523,13 @@ class Parser(BisonParser):
                 node.negated += 1
                 node.negated += 1
 
 
             # Explicit call the hook handler on the created unary negation.
             # Explicit call the hook handler on the created unary negation.
-            node = self.hook_handler('binary', 4, names, values, node)
+            self.hook_handler('binary', 3, names, values, node)
 
 
             return Node(OP_ADD, values[0], values[2])
             return Node(OP_ADD, values[0], values[2])
 
 
+        if option == 5:  # rule: power
+            return Node(OP_POW, *values[0])
+
         raise BisonSyntaxError('Unsupported option %d in target "%s".'
         raise BisonSyntaxError('Unsupported option %d in target "%s".'
                                % (option, target))  # pragma: nocover
                                % (option, target))  # pragma: nocover
 
 
@@ -591,7 +609,7 @@ class Parser(BisonParser):
     ")"       { returntoken(RPAREN); }
     ")"       { returntoken(RPAREN); }
     "["       { returntoken(LBRACKET); }
     "["       { returntoken(LBRACKET); }
     "]"       { returntoken(RBRACKET); }
     "]"       { returntoken(RBRACKET); }
-    "'"       { returntoken(APOSTROPH); }
+    "'"       { returntoken(PRIME); }
     log_([0-9]+|[a-zA-Z])"*(" { returntoken(FUNCTION_LPAREN); }
     log_([0-9]+|[a-zA-Z])"*(" { returntoken(FUNCTION_LPAREN); }
     log_([0-9]+|[a-zA-Z])"*" { returntoken(FUNCTION); }
     log_([0-9]+|[a-zA-Z])"*" { returntoken(FUNCTION); }
     """ + operators + r"""
     """ + operators + r"""
@@ -607,3 +625,4 @@ class Parser(BisonParser):
 
 
     yywrap() { return(1); }
     yywrap() { return(1); }
     """
     """
+    #int[ ]*"(" { returntoken(FUNCTION_LPAREN); }

+ 17 - 6
src/rules/integrals.py

@@ -2,16 +2,27 @@ from .utils import find_variables, first_sorted_variable, infinity, \
         replace_variable
         replace_variable
 from .logarithmic import ln
 from .logarithmic import ln
 #from .goniometry import sin, cos
 #from .goniometry import sin, cos
-from ..node import ExpressionLeaf as L, OP_INT
+from ..node import ExpressionNode as N, ExpressionLeaf as L, OP_INT
 from ..possibilities import Possibility as P, MESSAGES
 from ..possibilities import Possibility as P, MESSAGES
 from ..translate import _
 from ..translate import _
 
 
 
 
-#def ader(f, x=None):
-#    """
-#    Anti-derivative.
-#    """
-#    return N(OP_INT, f, x)
+def integral(f, x=None, lbnd=None, ubnd=None):
+    """
+    Anti-derivative.
+    """
+    params = [f]
+
+    if x:
+        params.append(x)
+
+    if lbnd:
+        params.append(lbnd)
+
+    if ubnd:
+        params.append(ubnd)
+
+    return N(OP_INT, *params)
 
 
 
 
 def integral_params(integral):
 def integral_params(integral):

+ 9 - 0
src/rules/utils.py

@@ -98,6 +98,15 @@ def first_sorted_variable(variables):
     return sorted(variables)[0]
     return sorted(variables)[0]
 
 
 
 
+def find_variable(exp):
+    variables = find_variables(exp)
+
+    if not len(variables):
+        variables.add('x')
+
+    return L(first_sorted_variable(variables))
+
+
 def infinity():
 def infinity():
     return L(INFINITY)
     return L(INFINITY)
 
 

+ 15 - 2
tests/test_parser.py

@@ -9,6 +9,7 @@ from tests.rulestestcase import tree
 from src.rules.goniometry import sin, cos
 from src.rules.goniometry import sin, cos
 from src.rules.derivatives import der
 from src.rules.derivatives import der
 from src.rules.logarithmic import log, ln
 from src.rules.logarithmic import log, ln
+from src.rules.integrals import integral
 
 
 
 
 class TestParser(unittest.TestCase):
 class TestParser(unittest.TestCase):
@@ -73,8 +74,8 @@ class TestParser(unittest.TestCase):
 
 
         self.assertEqual(tree('d/dx x ^ 2'), der(exp, x))
         self.assertEqual(tree('d/dx x ^ 2'), der(exp, x))
         self.assertEqual(tree('d / dx x ^ 2'), der(exp, x))
         self.assertEqual(tree('d / dx x ^ 2'), der(exp, x))
-        self.assertEqual(tree('d/dx x ^ 2 + x'), der(exp, x) + x)
-        self.assertEqual(tree('d/dx (x ^ 2 + x)'), der(exp + x, x))
+        self.assertEqual(tree('d/dx x ^ 2 + x'), der(exp + x, x))
+        self.assertEqual(tree('(d/dx x ^ 2) + x'), der(exp, x) + x)
         self.assertEqual(tree('d/d'), d / d)
         self.assertEqual(tree('d/d'), d / d)
         # FIXME: self.assertEqual(tree('d(x ^ 2)/dx'), der(exp, x))
         # FIXME: self.assertEqual(tree('d(x ^ 2)/dx'), der(exp, x))
 
 
@@ -97,3 +98,15 @@ class TestParser(unittest.TestCase):
             a, t = Leaf('a'), Leaf(token)
             a, t = Leaf('a'), Leaf(token)
             self.assertEqual(tree('a' + token), a * t)
             self.assertEqual(tree('a' + token), a * t)
             # FIXME: self.assertEqual(tree('a' + token + 'a'), a * t * a)
             # FIXME: self.assertEqual(tree('a' + token + 'a'), a * t * a)
+
+    def test_integral(self):
+        x, y, dx, a, b = tree('x, y, dx, a, b')
+
+        self.assertEqual(tree('int x'), integral(x, x))
+        self.assertEqual(tree('int x2'), integral(x ** 2, x))
+        self.assertEqual(tree('int x2 dx'), integral(x ** 2, x))
+        self.assertEqual(tree('int x2 dy'), integral(x ** 2, y))
+
+        self.assertEqual(tree('int_a^b x2 dy'), integral(x ** 2, y, a, b))
+        self.assertEqual(tree('int_(a-b)^(a+b) x2'),
+                         integral(x ** 2, x, a - b, a + b))

+ 11 - 11
tests/test_rules_integrals.py

@@ -10,36 +10,36 @@ from tests.rulestestcase import RulesTestCase, tree
 class TestRulesIntegrals(RulesTestCase):
 class TestRulesIntegrals(RulesTestCase):
 
 
     def test_integral_params(self):
     def test_integral_params(self):
-        f, x = root = tree('int(fx, x)')
+        f, x = root = tree('int fx dx')
         self.assertEqual(integral_params(root), (f, x))
         self.assertEqual(integral_params(root), (f, x))
 
 
-        root = tree('int(fx)')
+        root = tree('int fx')
         self.assertEqual(integral_params(root), (f, x))
         self.assertEqual(integral_params(root), (f, x))
 
 
-        root = tree('int(3)')
-        self.assertEqual(integral_params(root), (3, None))
+        root = tree('int 3')
+        self.assertEqual(integral_params(root), (3, x))
 
 
     def test_choose_constant(self):
     def test_choose_constant(self):
         a, b, c = tree('a, b, c')
         a, b, c = tree('a, b, c')
-        self.assertEqual(choose_constant(tree('int(x ^ n, x)')), c)
-        self.assertEqual(choose_constant(tree('int(x ^ c, x)')), a)
-        self.assertEqual(choose_constant(tree('int(a ^ c, a)')), b)
+        self.assertEqual(choose_constant(tree('int x ^ n')), c)
+        self.assertEqual(choose_constant(tree('int x ^ c')), a)
+        self.assertEqual(choose_constant(tree('int a ^ c da')), b)
 
 
     def test_match_integrate_variable_power(self):
     def test_match_integrate_variable_power(self):
-        for root in tree('int(x ^ n, x), int(x ^ n)'):
+        for root in tree('int x ^ n, int x ^ n'):
             self.assertEqualPos(match_integrate_variable_power(root),
             self.assertEqualPos(match_integrate_variable_power(root),
                     [P(root, integrate_variable_root)])
                     [P(root, integrate_variable_root)])
 
 
-        for root in tree('int(g ^ x, x), int(g ^ x)'):
+        for root in tree('int g ^ x, int g ^ x'):
             self.assertEqualPos(match_integrate_variable_power(root),
             self.assertEqualPos(match_integrate_variable_power(root),
                     [P(root, integrate_variable_exponent)])
                     [P(root, integrate_variable_exponent)])
 
 
     def test_integrate_variable_root(self):
     def test_integrate_variable_root(self):
-        ((x, n),), c = root, c = tree('int(x ^ n), c')
+        ((x, n), x), c = root, c = tree('int x ^ n, c')
         self.assertEqual(integrate_variable_root(root, ()),
         self.assertEqual(integrate_variable_root(root, ()),
                          x ** (n + 1) / (n + 1) + c)
                          x ** (n + 1) / (n + 1) + c)
 
 
     def test_integrate_variable_exponent(self):
     def test_integrate_variable_exponent(self):
-        ((g, x),), c = root, c = tree('int(g ^ x), c')
+        ((g, x), x), c = root, c = tree('int g ^ x, c')
         self.assertEqual(integrate_variable_exponent(root, ()),
         self.assertEqual(integrate_variable_exponent(root, ()),
                          g ** x / ln(g) + c)
                          g ** x / ln(g) + c)