Skip to content
Snippets Groups Projects
Commit 6b59d5de authored by Taddeus Kroes's avatar Taddeus Kroes
Browse files

Derivative now supports multiple variables without an error.

parent 19b91caf
No related branches found
No related tags found
No related merge requests found
from .utils import find_variables from .utils import find_variables, first_sorted_variable
from .logarithmic import ln from .logarithmic import ln
from .goniometry import sin, cos from .goniometry import sin, cos
from ..node import ExpressionNode as N, ExpressionLeaf as L, Scope, OP_DER, \ from ..node import ExpressionNode as N, ExpressionLeaf as L, Scope, OP_DER, \
...@@ -32,16 +32,10 @@ def get_derivation_variable(node, variables=None): ...@@ -32,16 +32,10 @@ def get_derivation_variable(node, variables=None):
if not variables: if not variables:
variables = find_variables(node) variables = find_variables(node)
if len(variables) > 1:
# FIXME: Use first variable, sorted alphabetically?
#return sorted(variables)[0]
raise ValueError('More than 1 variable in implicit derivative: '
+ ', '.join(variables))
if not len(variables): if not len(variables):
return None return None
return list(variables)[0] return first_sorted_variable(variables)
def chain_rule(root, args): def chain_rule(root, args):
...@@ -168,7 +162,8 @@ def match_variable_power(node): ...@@ -168,7 +162,8 @@ def match_variable_power(node):
return [P(node, variable_root)] return [P(node, variable_root)]
return [P(node, chain_rule, (root, variable_root, ()))] return [P(node, chain_rule, (root, variable_root, ()))]
elif not x in rvars and x in evars:
if not x in rvars and x in evars:
if exponent.is_variable(): if exponent.is_variable():
return [P(node, variable_exponent)] return [P(node, variable_exponent)]
......
from ..node import ExpressionLeaf as L, OP_MUL, OP_DIV from ..node import ExpressionLeaf as L, OP_MUL, OP_DIV, INFINITY
def greatest_common_divisor(a, b): def greatest_common_divisor(a, b):
...@@ -84,3 +84,35 @@ def find_variables(node): ...@@ -84,3 +84,35 @@ def find_variables(node):
return reduce(lambda a, b: a | b, map(find_variables, node)) return reduce(lambda a, b: a | b, map(find_variables, node))
return set() return set()
def first_sorted_variable(variables):
"""
In a set of variables, find the main variable to be used in a derivation or
integral. The prioritized order is x, y, z, a, b, c, d, ...
"""
for x in 'xyz':
if x in variables:
return x
return sorted(variables)[0]
def infinity():
return L(INFINITY)
def replace_variable(f, x, replacement):
"""
Replace all occurences of variable x in function f with the specified
replacement.
"""
if f == x:
return replacement.clone()
if f.is_leaf:
return f
children = map(lambda c: replace_variable(c, x, replacement), f)
return N(f, *children)
...@@ -16,13 +16,12 @@ from tests.rulestestcase import RulesTestCase, tree ...@@ -16,13 +16,12 @@ from tests.rulestestcase import RulesTestCase, tree
class TestRulesDerivatives(RulesTestCase): class TestRulesDerivatives(RulesTestCase):
def test_get_derivation_variable(self): def test_get_derivation_variable(self):
xy, x, l1 = tree('der(xy, x), der(x), der(1)') xy0, xy1, x, l1 = tree('der(xy, x), der(xy), der(x), der(1)')
self.assertEqual(get_derivation_variable(xy), 'x') self.assertEqual(get_derivation_variable(xy0), 'x')
self.assertEqual(get_derivation_variable(xy1), 'x')
self.assertEqual(get_derivation_variable(x), 'x') self.assertEqual(get_derivation_variable(x), 'x')
self.assertIsNone(get_derivation_variable(l1)) self.assertIsNone(get_derivation_variable(l1))
self.assertRaises(ValueError, tree, 'der(xy)')
def test_match_zero_derivative(self): def test_match_zero_derivative(self):
root = tree('der(x, y)') root = tree('der(x, y)')
self.assertEqualPos(match_zero_derivative(root), self.assertEqualPos(match_zero_derivative(root),
......
import unittest import unittest
from src.rules.utils import least_common_multiple, is_fraction, partition, \ from src.rules.utils import least_common_multiple, is_fraction, partition, \
find_variables find_variables, first_sorted_variable
from tests.rulestestcase import tree from tests.rulestestcase import tree
...@@ -31,3 +31,10 @@ class TestRulesUtils(unittest.TestCase): ...@@ -31,3 +31,10 @@ class TestRulesUtils(unittest.TestCase):
self.assertSetEqual(find_variables(add), set(['x'])) self.assertSetEqual(find_variables(add), set(['x']))
self.assertSetEqual(find_variables(mul0), set(['x'])) self.assertSetEqual(find_variables(mul0), set(['x']))
self.assertSetEqual(find_variables(mul1), set(['x', 'y'])) self.assertSetEqual(find_variables(mul1), set(['x', 'y']))
def test_first_sorted_variable(self):
self.assertEqual(first_sorted_variable(set('ax')), 'x')
self.assertEqual(first_sorted_variable(set('ay')), 'y')
self.assertEqual(first_sorted_variable(set('az')), 'z')
self.assertEqual(first_sorted_variable(set('xz')), 'x')
self.assertEqual(first_sorted_variable(set('bac')), 'a')
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment