Commit 6b59d5de authored by Taddeus Kroes's avatar Taddeus Kroes

Derivative now supports multiple variables without an error.

parent 19b91caf
from .utils import find_variables from .utils import find_variables, first_sorted_variable
from .logarithmic import ln from .logarithmic import ln
from .goniometry import sin, cos from .goniometry import sin, cos
from ..node import ExpressionNode as N, ExpressionLeaf as L, Scope, OP_DER, \ from ..node import ExpressionNode as N, ExpressionLeaf as L, Scope, OP_DER, \
...@@ -32,16 +32,10 @@ def get_derivation_variable(node, variables=None): ...@@ -32,16 +32,10 @@ def get_derivation_variable(node, variables=None):
if not variables: if not variables:
variables = find_variables(node) variables = find_variables(node)
if len(variables) > 1:
# FIXME: Use first variable, sorted alphabetically?
#return sorted(variables)[0]
raise ValueError('More than 1 variable in implicit derivative: '
+ ', '.join(variables))
if not len(variables): if not len(variables):
return None return None
return list(variables)[0] return first_sorted_variable(variables)
def chain_rule(root, args): def chain_rule(root, args):
...@@ -168,7 +162,8 @@ def match_variable_power(node): ...@@ -168,7 +162,8 @@ def match_variable_power(node):
return [P(node, variable_root)] return [P(node, variable_root)]
return [P(node, chain_rule, (root, variable_root, ()))] return [P(node, chain_rule, (root, variable_root, ()))]
elif not x in rvars and x in evars:
if not x in rvars and x in evars:
if exponent.is_variable(): if exponent.is_variable():
return [P(node, variable_exponent)] return [P(node, variable_exponent)]
......
from ..node import ExpressionLeaf as L, OP_MUL, OP_DIV from ..node import ExpressionLeaf as L, OP_MUL, OP_DIV, INFINITY
def greatest_common_divisor(a, b): def greatest_common_divisor(a, b):
...@@ -84,3 +84,35 @@ def find_variables(node): ...@@ -84,3 +84,35 @@ def find_variables(node):
return reduce(lambda a, b: a | b, map(find_variables, node)) return reduce(lambda a, b: a | b, map(find_variables, node))
return set() return set()
def first_sorted_variable(variables):
"""
In a set of variables, find the main variable to be used in a derivation or
integral. The prioritized order is x, y, z, a, b, c, d, ...
"""
for x in 'xyz':
if x in variables:
return x
return sorted(variables)[0]
def infinity():
return L(INFINITY)
def replace_variable(f, x, replacement):
"""
Replace all occurences of variable x in function f with the specified
replacement.
"""
if f == x:
return replacement.clone()
if f.is_leaf:
return f
children = map(lambda c: replace_variable(c, x, replacement), f)
return N(f, *children)
...@@ -16,13 +16,12 @@ from tests.rulestestcase import RulesTestCase, tree ...@@ -16,13 +16,12 @@ from tests.rulestestcase import RulesTestCase, tree
class TestRulesDerivatives(RulesTestCase): class TestRulesDerivatives(RulesTestCase):
def test_get_derivation_variable(self): def test_get_derivation_variable(self):
xy, x, l1 = tree('der(xy, x), der(x), der(1)') xy0, xy1, x, l1 = tree('der(xy, x), der(xy), der(x), der(1)')
self.assertEqual(get_derivation_variable(xy), 'x') self.assertEqual(get_derivation_variable(xy0), 'x')
self.assertEqual(get_derivation_variable(xy1), 'x')
self.assertEqual(get_derivation_variable(x), 'x') self.assertEqual(get_derivation_variable(x), 'x')
self.assertIsNone(get_derivation_variable(l1)) self.assertIsNone(get_derivation_variable(l1))
self.assertRaises(ValueError, tree, 'der(xy)')
def test_match_zero_derivative(self): def test_match_zero_derivative(self):
root = tree('der(x, y)') root = tree('der(x, y)')
self.assertEqualPos(match_zero_derivative(root), self.assertEqualPos(match_zero_derivative(root),
......
import unittest import unittest
from src.rules.utils import least_common_multiple, is_fraction, partition, \ from src.rules.utils import least_common_multiple, is_fraction, partition, \
find_variables find_variables, first_sorted_variable
from tests.rulestestcase import tree from tests.rulestestcase import tree
...@@ -31,3 +31,10 @@ class TestRulesUtils(unittest.TestCase): ...@@ -31,3 +31,10 @@ class TestRulesUtils(unittest.TestCase):
self.assertSetEqual(find_variables(add), set(['x'])) self.assertSetEqual(find_variables(add), set(['x']))
self.assertSetEqual(find_variables(mul0), set(['x'])) self.assertSetEqual(find_variables(mul0), set(['x']))
self.assertSetEqual(find_variables(mul1), set(['x', 'y'])) self.assertSetEqual(find_variables(mul1), set(['x', 'y']))
def test_first_sorted_variable(self):
self.assertEqual(first_sorted_variable(set('ax')), 'x')
self.assertEqual(first_sorted_variable(set('ay')), 'y')
self.assertEqual(first_sorted_variable(set('az')), 'z')
self.assertEqual(first_sorted_variable(set('xz')), 'x')
self.assertEqual(first_sorted_variable(set('bac')), 'a')
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment