Ver código fonte

Derivative now supports multiple variables without an error.

Taddeus Kroes 14 anos atrás
pai
commit
6b59d5de1f

+ 4 - 9
src/rules/derivatives.py

@@ -1,4 +1,4 @@
-from .utils import find_variables
+from .utils import find_variables, first_sorted_variable
 from .logarithmic import ln
 from .goniometry import sin, cos
 from ..node import ExpressionNode as N, ExpressionLeaf as L, Scope, OP_DER, \
@@ -32,16 +32,10 @@ def get_derivation_variable(node, variables=None):
     if not variables:
         variables = find_variables(node)
 
-    if len(variables) > 1:
-        # FIXME: Use first variable, sorted alphabetically?
-        #return sorted(variables)[0]
-        raise ValueError('More than 1 variable in implicit derivative: '
-                         + ', '.join(variables))
-
     if not len(variables):
         return None
 
-    return list(variables)[0]
+    return first_sorted_variable(variables)
 
 
 def chain_rule(root, args):
@@ -168,7 +162,8 @@ def match_variable_power(node):
             return [P(node, variable_root)]
 
         return [P(node, chain_rule, (root, variable_root, ()))]
-    elif not x in rvars and x in evars:
+
+    if not x in rvars and x in evars:
         if exponent.is_variable():
             return [P(node, variable_exponent)]
 

+ 33 - 1
src/rules/utils.py

@@ -1,4 +1,4 @@
-from ..node import ExpressionLeaf as L, OP_MUL, OP_DIV
+from ..node import ExpressionLeaf as L, OP_MUL, OP_DIV, INFINITY
 
 
 def greatest_common_divisor(a, b):
@@ -84,3 +84,35 @@ def find_variables(node):
         return reduce(lambda a, b: a | b, map(find_variables, node))
 
     return set()
+
+
+def first_sorted_variable(variables):
+    """
+    In a set of variables, find the main variable to be used in a derivation or
+    integral. The prioritized order is x, y, z, a, b, c, d, ...
+    """
+    for x in 'xyz':
+        if x in variables:
+            return x
+
+    return sorted(variables)[0]
+
+
+def infinity():
+    return L(INFINITY)
+
+
+def replace_variable(f, x, replacement):
+    """
+    Replace all occurences of variable x in function f with the specified
+    replacement.
+    """
+    if f == x:
+        return replacement.clone()
+
+    if f.is_leaf:
+        return f
+
+    children = map(lambda c: replace_variable(c, x, replacement), f)
+
+    return N(f, *children)

+ 3 - 4
tests/test_rules_derivatives.py

@@ -16,13 +16,12 @@ from tests.rulestestcase import RulesTestCase, tree
 class TestRulesDerivatives(RulesTestCase):
 
     def test_get_derivation_variable(self):
-        xy, x, l1 = tree('der(xy, x), der(x), der(1)')
-        self.assertEqual(get_derivation_variable(xy), 'x')
+        xy0, xy1, x, l1 = tree('der(xy, x), der(xy), der(x), der(1)')
+        self.assertEqual(get_derivation_variable(xy0), 'x')
+        self.assertEqual(get_derivation_variable(xy1), 'x')
         self.assertEqual(get_derivation_variable(x), 'x')
         self.assertIsNone(get_derivation_variable(l1))
 
-        self.assertRaises(ValueError, tree, 'der(xy)')
-
     def test_match_zero_derivative(self):
         root = tree('der(x, y)')
         self.assertEqualPos(match_zero_derivative(root),

+ 8 - 1
tests/test_rules_utils.py

@@ -1,7 +1,7 @@
 import unittest
 
 from src.rules.utils import least_common_multiple, is_fraction, partition, \
-        find_variables
+        find_variables, first_sorted_variable
 from tests.rulestestcase import tree
 
 
@@ -31,3 +31,10 @@ class TestRulesUtils(unittest.TestCase):
         self.assertSetEqual(find_variables(add), set(['x']))
         self.assertSetEqual(find_variables(mul0), set(['x']))
         self.assertSetEqual(find_variables(mul1), set(['x', 'y']))
+
+    def test_first_sorted_variable(self):
+        self.assertEqual(first_sorted_variable(set('ax')), 'x')
+        self.assertEqual(first_sorted_variable(set('ay')), 'y')
+        self.assertEqual(first_sorted_variable(set('az')), 'z')
+        self.assertEqual(first_sorted_variable(set('xz')), 'x')
+        self.assertEqual(first_sorted_variable(set('bac')), 'a')