Kaynağa Gözat

Implemented constant folding.

Taddeus Kroes 14 yıl önce
ebeveyn
işleme
f7e3d0c66b
2 değiştirilmiş dosya ile 75 ekleme ve 27 silme
  1. 73 26
      src/optimize/advanced.py
  2. 2 1
      tests/test_optimize_advanced.py

+ 73 - 26
src/optimize/advanced.py

@@ -65,43 +65,90 @@ def to_hex(value):
 def fold_constants(block):
     """
     Constant folding:
+    - An immidiate load defines a register value:
+        li $reg, XX     ->  register[$reg] = XX
     - Integer variable definition is of the following form:
-        li $reg, XX
-        sw $reg, VAR
-      save this as:
-        reg[$reg] = XX
-        constants[VAR] = XX
+        li $reg, XX     ->  constants[VAR] = XX
+        sw $reg, VAR    ->  register[$reg] = XX
     - When a variable is used, the following happens:
-        lw $reg, VAR
-      save this as:
-        reg[$reg] = constants[VAR]
+        lw $reg, VAR    ->  register[$reg] = constants[VAR]
     """
+    found = False
+
+    # Variable values
     constants = {}
-    reg = {}
+
+    # Current known values in register
+    register = {}
 
     while not block.end():
         s = block.read()
 
-        if s.is_load():
-            constants[s[0]] = s[1]
-        elif s.is_command() and len(s) == 3:
-            d, s, t = s
+        if not s.is_command():
+            continue
+
+        if s.name == 'li':
+            # Save value in register
+            register[s[0]] = int(s[1], 16)
+        elif s.name == 'move' and s[0] in register:
+            reg_to, reg_from = s
+
+            if reg_from in register:
+                # Other value is also known, copy its value
+                register[reg_to] = register[reg_to]
+            else:
+                # Other value is unknown, delete the value
+                del register[reg_to]
+        elif s.name == 'sw' and s[0] in register:
+            # Constant variable definition, e.g. 'int a = 1;'
+            constants[s[1]] = register[s[0]]
+        elif s.name == 'lw' and s[1] in constants:
+            # Usage of variable with constant value
+            register[s[0]] = constants[s[1]]
+        elif s.name in ['addu', 'subu', 'mult', 'div']:
+            # Calculation with constants
+            rd, rs, rt = s
+            rs_known = rs in register
+            rt_known = rt in register
+
+            if rs_known and rt_known:
+                # a = 5         ->  b = 15
+                # c = 10
+                # b = a + c
+                rs_val = register[rs]
+                rt_val = register[rt]
 
-            if s in constants and t in constants:
                 if s.name == 'addu':
-                    result = s + t
-                elif s.name == 'subu':
-                    result = s - t
-                elif s.name == 'mult':
-                    result = s * t
-                elif s.name == 'div':
-                    result = s / t
-
-                block.replace(1, [S('command', 'li', to_hex(result))])
-                constants[d] = result
-            #else:
+                    result = to_hex(rs_val + rt_val)
 
-    return False
+                if s.name == 'subu':
+                    result = to_hex(rs_val - rt_val)
+
+                if s.name == 'mult':
+                    result = to_hex(rs_val * rt_val)
+
+                if s.name == 'div':
+                    result = to_hex(rs_val / rt_val)
+
+                block.replace(1, [S('command', 'li', result)])
+                register[rd] = result
+                found = True
+            elif rt_known:
+                # c = 10        ->  b = a + 10
+                # b = c + a
+                s[2] = register[rt]
+                found = True
+            elif rs_known and s.name in ['addu', 'mult']:
+                # a = 10        ->  b = c + 10
+                # b = c + a
+                s[1] = rt
+                s[2] = register[rs]
+                found = True
+        elif len(s) and s[0] in register:
+            # Known register is overwritten, remove its value
+            del register[s[0]]
+
+    return found
 
 
 def copy_propagation(block):

+ 2 - 1
tests/test_optimize_advanced.py

@@ -1,6 +1,7 @@
 import unittest
 
-from src.optimize.advanced import eliminate_common_subexpressions
+from src.optimize.advanced import eliminate_common_subexpressions, \
+        fold_constants, copy_propagation
 from src.statement import Statement as S, Block as B