Commit 75ad53c8 authored by Taddeus Kroes's avatar Taddeus Kroes

Added some tests for utilities, together with some small bugfixes.

parent c5624470
...@@ -31,8 +31,7 @@ def is_fraction(node, nominator, denominator): ...@@ -31,8 +31,7 @@ def is_fraction(node, nominator, denominator):
Check if a node represents the fraction of a given nominator and Check if a node represents the fraction of a given nominator and
denominator. denominator.
>>> from ..node import ExpressionLeaf as L >>> a, l1, l2 = L('a'), L(1), L(2)
>>> l1, l2, a = L('a'), L(1), L(2)
>>> is_fraction(a / l2, a, 2) >>> is_fraction(a / l2, a, 2)
True True
>>> is_fraction(l1 / l2 * a, a, 2) >>> is_fraction(l1 / l2 * a, a, 2)
...@@ -100,6 +99,11 @@ def first_sorted_variable(variables): ...@@ -100,6 +99,11 @@ def first_sorted_variable(variables):
def find_variable(exp): def find_variable(exp):
"""
Find the main (e.g. first prioritized) variable in an expression and return
it as an ExpressionNode object. If no variable is present, return 'x' by
default.
"""
variables = find_variables(exp) variables = find_variables(exp)
if not len(variables): if not len(variables):
...@@ -108,10 +112,6 @@ def find_variable(exp): ...@@ -108,10 +112,6 @@ def find_variable(exp):
return L(first_sorted_variable(variables)) return L(first_sorted_variable(variables))
def infinity():
return L(INFINITY)
def replace_variable(f, x, replacement): def replace_variable(f, x, replacement):
""" """
Replace all occurences of variable x in function f with the specified Replace all occurences of variable x in function f with the specified
...@@ -126,3 +126,10 @@ def replace_variable(f, x, replacement): ...@@ -126,3 +126,10 @@ def replace_variable(f, x, replacement):
children = map(lambda c: replace_variable(c, x, replacement), f) children = map(lambda c: replace_variable(c, x, replacement), f)
return N(f.op, *children) return N(f.op, *children)
def infinity():
"""
Return an infinity leaf node.
"""
return L(INFINITY)
import unittest import unittest
import doctest
from src.node import ExpressionNode from src.node import ExpressionNode
from src.parser import Parser from src.parser import Parser
from tests.parser import ParserWrapper from tests.parser import ParserWrapper
...@@ -14,6 +16,10 @@ def rewrite(exp, **kwargs): ...@@ -14,6 +16,10 @@ def rewrite(exp, **kwargs):
class RulesTestCase(unittest.TestCase): class RulesTestCase(unittest.TestCase):
def assertDoctests(self, module):
self.assertEqual(doctest.testmod(m=module)[0], 0,
'There are failed doctests.')
def assertEqualPos(self, possibilities, expected): def assertEqualPos(self, possibilities, expected):
self.assertEqual(len(possibilities), len(expected)) self.assertEqual(len(possibilities), len(expected))
......
import unittest from src.rules import utils
from src.rules.utils import least_common_multiple, is_fraction, partition, \ from src.rules.utils import least_common_multiple, is_fraction, partition, \
find_variables, first_sorted_variable find_variables, first_sorted_variable, find_variable, \
from tests.rulestestcase import tree replace_variable, infinity
from tests.rulestestcase import tree, RulesTestCase
class TestRulesUtils(RulesTestCase):
class TestRulesUtils(unittest.TestCase): def test_doctest(self):
self.assertDoctests(utils)
def test_least_common_multiple(self): def test_least_common_multiple(self):
self.assertEqual(least_common_multiple(5, 6), 30) self.assertEqual(least_common_multiple(5, 6), 30)
...@@ -19,11 +22,15 @@ class TestRulesUtils(unittest.TestCase): ...@@ -19,11 +22,15 @@ class TestRulesUtils(unittest.TestCase):
self.assertTrue(is_fraction(l1 / 2 * a, a, 2)) self.assertTrue(is_fraction(l1 / 2 * a, a, 2))
self.assertTrue(is_fraction(a * (l1 / 2), a, 2)) self.assertTrue(is_fraction(a * (l1 / 2), a, 2))
self.assertFalse(is_fraction(l1 / 3 * a, a, 2)) self.assertFalse(is_fraction(l1 / 3 * a, a, 2))
self.assertFalse(is_fraction(l1, a, 2))
def test_partition(self): def test_partition(self):
self.assertEqual(partition(lambda x: x & 1, range(6)), self.assertEqual(partition(lambda x: x & 1, range(6)),
([1, 3, 5], [0, 2, 4])) ([1, 3, 5], [0, 2, 4]))
def test_infinity(self):
self.assertEqual(infinity(), tree('oo'))
def test_find_variables(self): def test_find_variables(self):
x, l2, add, mul0, mul1 = tree('x, 2, x + 2, 2x, xy') x, l2, add, mul0, mul1 = tree('x, 2, x + 2, 2x, xy')
self.assertSetEqual(find_variables(x), set(['x'])) self.assertSetEqual(find_variables(x), set(['x']))
...@@ -38,3 +45,19 @@ class TestRulesUtils(unittest.TestCase): ...@@ -38,3 +45,19 @@ class TestRulesUtils(unittest.TestCase):
self.assertEqual(first_sorted_variable(set('az')), 'z') self.assertEqual(first_sorted_variable(set('az')), 'z')
self.assertEqual(first_sorted_variable(set('xz')), 'x') self.assertEqual(first_sorted_variable(set('xz')), 'x')
self.assertEqual(first_sorted_variable(set('bac')), 'a') self.assertEqual(first_sorted_variable(set('bac')), 'a')
def test_find_variable(self):
x, y = tree('x, y')
self.assertEqual(find_variable(tree('x')), x)
self.assertEqual(find_variable(tree('x ^ 2')), x)
self.assertEqual(find_variable(tree('1 + 2')), x)
self.assertEqual(find_variable(tree('y ^ 2')), y)
def test_replace_variable(self):
x, a = tree('x, a')
self.assertEqual(replace_variable(x, x, a), a)
self.assertEqual(replace_variable(tree('x2'), x, a), tree('a2'))
self.assertEqual(replace_variable(tree('y + x + 1'), x, a),
tree('y + a + 1'))
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment