Explorar o código

Added possibility to pass an operator as an integer to the ExprexxionNode constructor.

Taddeus Kroes %!s(int64=14) %!d(string=hai) anos
pai
achega
9146b782d3
Modificáronse 6 ficheiros con 44 adicións e 33 borrados
  1. 21 11
      src/node.py
  2. 9 13
      src/parser.py
  3. 1 1
      src/rules/derivatives.py
  4. 10 5
      src/rules/goniometry.py
  5. 1 1
      src/rules/logarithmic.py
  6. 2 2
      src/rules/powers.py

+ 21 - 11
src/node.py

@@ -78,7 +78,6 @@ OP_MAP = {
         'der': OP_DER,
         'solve': OP_SOLVE,
         'log': OP_LOG,
-        'ln': OP_LOG,
         '=': OP_EQ,
         '??': OP_POSSIBILITIES,
         '?': OP_HINT,
@@ -86,6 +85,9 @@ OP_MAP = {
         '@': OP_REWRITE,
         }
 
+OP_VALUE_MAP = dict([(v, k) for k, v in OP_MAP.iteritems()])
+OP_MAP['ln'] = OP_LOG
+
 TOKEN_MAP = {
         OP_COMMA: 'COMMA',
         OP_ADD: 'PLUS',
@@ -192,19 +194,20 @@ class ExpressionBase(object):
         return self.type & (TYPE_FLOAT | TYPE_INTEGER)
 
     def __add__(self, other):
-        return ExpressionNode('+', self, to_expression(other))
+        return ExpressionNode(OP_ADD, self, to_expression(other))
 
     def __sub__(self, other):
-        return ExpressionNode('-', self, to_expression(other))
+        #FIXME: return ExpressionNode(OP_ADD, self, -to_expression(other))
+        return ExpressionNode(OP_SUB, self, to_expression(other))
 
     def __mul__(self, other):
-        return ExpressionNode('*', self, to_expression(other))
+        return ExpressionNode(OP_MUL, self, to_expression(other))
 
     def __div__(self, other):
-        return ExpressionNode('/', self, to_expression(other))
+        return ExpressionNode(OP_DIV, self, to_expression(other))
 
     def __pow__(self, other):
-        return ExpressionNode('^', self, to_expression(other))
+        return ExpressionNode(OP_POW, self, to_expression(other))
 
     def __pos__(self):
         return self.reduce_negation()
@@ -238,7 +241,14 @@ class ExpressionNode(Node, ExpressionBase):
     def __init__(self, *args, **kwargs):
         super(ExpressionNode, self).__init__(*args, **kwargs)
         self.type = TYPE_OPERATOR
-        self.op = OP_MAP[args[0]]
+        op = args[0]
+
+        if isinstance(op, str):
+            self.value = op
+            self.op = OP_MAP[op]
+        else:
+            self.value = OP_VALUE_MAP[op]
+            self.op = op
 
     def construct_function(self, children):
         if self.op == OP_DER:
@@ -296,13 +306,13 @@ class ExpressionNode(Node, ExpressionBase):
 
         >>> from src.node import ExpressionNode as N, ExpressionLeaf as L
         >>> c, r, e = L('c'), L('r'), L('e')
-        >>> n1 = N('*', c, N('^', r, e))
+        >>> n1 = N(OP_MUL), c, N('^', r, e))
         >>> n1.extract_polynome()
         (c, r, e)
-        >>> n2 = N('*', N('^', r, e), c)
+        >>> n2 = N(OP_MUL, N('^', r, e), c)
         >>> n2.extract_polynome()
         (c, r, e)
-        >>> n3 = N('-', r)
+        >>> n3 = -r
         >>> n3.extract_polynome()
         (1, -r, 1)
         """
@@ -512,7 +522,7 @@ class Scope(object):
         self.remove(node, replacement=replacement)
 
     def as_nary_node(self):
-        return nary_node(self.node.value, self.nodes).negate(self.node.negated)
+        return nary_node(self.node.op, self.nodes).negate(self.node.negated)
 
 
 def nary_node(operator, scope):

+ 9 - 13
src/parser.py

@@ -16,7 +16,7 @@ from graph_drawing.graph import generate_graph
 
 from node import ExpressionNode as Node, ExpressionLeaf as Leaf, OP_MAP, \
         OP_DER, TOKEN_MAP, TYPE_OPERATOR, OP_COMMA, OP_NEG, OP_MUL, OP_DIV, \
-        OP_LOG, Scope, PI, E, DEFAULT_LOGARITHM_BASE
+        OP_LOG, OP_ADD, Scope, PI, E, DEFAULT_LOGARITHM_BASE, OP_VALUE_MAP
 from rules import RULES
 from strategy import pick_suggestion
 from possibilities import filter_duplicates, apply_suggestion
@@ -389,26 +389,24 @@ class Parser(BisonParser):
             op = values[0].split(' ', 1)[0]
 
             if op == 'ln':
-                return Node('log', values[1], Leaf(E))
+                return Node(OP_LOG, values[1], Leaf(E))
 
             if values[1].is_op(OP_COMMA):
                 return Node(op, *values[1])
 
-            if op == 'log':
-                return Node('log', values[1], Leaf(DEFAULT_LOGARITHM_BASE))
+            if op == OP_VALUE_MAP[OP_LOG]:
+                return Node(OP_LOG, values[1], Leaf(DEFAULT_LOGARITHM_BASE))
 
             m = re.match(r'^log_([0-9]+)', op)
 
             if m:
-                return Node('log', values[1], Leaf(int(m.group(1))))
+                return Node(OP_LOG, values[1], Leaf(int(m.group(1))))
 
             return Node(op, values[1])
 
         if option == 3:  # rule: DERIVATIVE exp
-            op = [k for k, v in OP_MAP.iteritems() if v == OP_DER][0]
-
             # DERIVATIVE looks like 'd/d*x*' -> extract the 'x'
-            return Node(op, values[1], Leaf(values[0][-2]))
+            return Node(OP_DER, values[1], Leaf(values[0][-2]))
 
         if option == 4:  # rule: bracket_derivative
             return values[0]
@@ -422,13 +420,11 @@ class Parser(BisonParser):
                            | bracket_derivative APOSTROPH
         """
 
-        op = [k for k, v in OP_MAP.iteritems() if v == OP_DER][0]
-
         if option == 0:  # rule: LBRACKET exp RBRACKET APOSTROPH
-            return Node(op, values[1])
+            return Node(OP_DER, values[1])
 
         if option == 1:  # rule: bracket_derivative APOSTROPH
-            return Node(op, values[0])
+            return Node(OP_DER, values[0])
 
         raise BisonSyntaxError('Unsupported option %d in target "%s".'
                                % (option, target))  # pragma: nocover
@@ -459,7 +455,7 @@ class Parser(BisonParser):
             # Explicit call the hook handler on the created unary negation.
             node = self.hook_handler('binary', 4, names, values, node)
 
-            return Node('+', values[0], values[2])
+            return Node(OP_ADD, values[0], values[2])
 
         raise BisonSyntaxError('Unsupported option %d in target "%s".'
                                % (option, target))  # pragma: nocover

+ 1 - 1
src/rules/derivatives.py

@@ -10,7 +10,7 @@ from ..translate import _
 
 
 def der(f, x=None):
-    return N('der', f, x) if x else N('der', f)
+    return N(OP_DER, f, x) if x else N(OP_DER, f)
 
 
 def second_arg(node):

+ 10 - 5
src/rules/goniometry.py

@@ -1,20 +1,21 @@
 from .utils import is_fraction
 from ..node import ExpressionNode as N, ExpressionLeaf as L, Scope, OP_ADD, \
-        OP_POW, OP_MUL, OP_DIV, OP_SIN, OP_COS, OP_TAN, PI, TYPE_OPERATOR
+        OP_POW, OP_MUL, OP_DIV, OP_SIN, OP_COS, OP_TAN, OP_SQRT, PI, \
+        TYPE_OPERATOR
 from ..possibilities import Possibility as P, MESSAGES
 from ..translate import _
 
 
 def sin(*args):
-    return N('sin', *args)
+    return N(OP_SIN, *args)
 
 
 def cos(*args):
-    return N('cos', *args)
+    return N(OP_COS, *args)
 
 
 def tan(*args):
-    return N('tan', *args)
+    return N(OP_TAN, *args)
 
 
 def match_add_quadrants(node):
@@ -102,6 +103,10 @@ def match_half_pi_subtraction(node):
     return []
 
 
+def half_pi_subtraction_sinus(root, args):
+    pass
+
+
 def is_pi_frac(node, denominator):
     """
     Check if a node is a fraction of 1 multiplied with PI.
@@ -124,7 +129,7 @@ def is_pi_frac(node, denominator):
 
 
 def sqrt(value):
-    return N('sqrt', L(value))
+    return N(OP_SQRT, L(value))
 
 
 l0, l1, sq2, sq3 = L(0), L(1), sqrt(2), sqrt(3)

+ 1 - 1
src/rules/logarithmic.py

@@ -7,7 +7,7 @@ def log(exponent, base=10):
     if not isinstance(base, L):
         base = L(base)
 
-    return N('log', exponent, base)
+    return N(OP_LOG, exponent, base)
 
 
 def ln(exponent):

+ 2 - 2
src/rules/powers.py

@@ -1,7 +1,7 @@
 from itertools import combinations
 
 from ..node import ExpressionNode as N, ExpressionLeaf as L, Scope, \
-                   OP_MUL, OP_DIV, OP_POW, OP_ADD, negate
+                   OP_MUL, OP_DIV, OP_POW, OP_ADD, OP_SQRT, negate
 from ..possibilities import Possibility as P, MESSAGES
 from ..translate import _
 
@@ -238,7 +238,7 @@ def exponent_to_root(root, args):
     """
     a, n, m = args
 
-    return N('sqrt', a if n == 1 else a ** n, m)
+    return N(OP_SQRT, a if n == 1 else a ** n, m)
 
 
 def match_extend_exponent(node):