Ver Fonte

Fixed Common Subexpression Elimination.

Taddeus Kroes há 14 anos atrás
pai
commit
950105c0b6
3 ficheiros alterados com 67 adições e 39 exclusões
  1. 47 35
      src/optimize/advanced.py
  2. 17 1
      src/statement.py
  3. 3 3
      tests/test_optimize_advanced.py

+ 47 - 35
src/optimize/advanced.py

@@ -2,22 +2,31 @@ from src.statement import Statement as S
 from math import log
 from math import log
 
 
 
 
-def reg_dead_in(var, context):
-    """Check if a register is `dead' in a given list of statements."""
-    # TODO: Finish
-    for s in context:
-        if s.defines(var) or s.uses(var):
+def reg_can_be_used_in(reg, block, start, end):
+    """Check if a register addres safely be used in a block section using local
+    dataflow analysis."""
+    # Check if the register used or defined in the block section
+    for s in block[start:end]:
+        if s.uses(reg) or s.defines(reg):
             return False
             return False
 
 
+    # Check if the register is used inside the block after the specified
+    # section, without having been re-assigned first
+    for s in block[end:]:
+        if s.uses(reg):
+            return False
+        elif s.defines(reg):
+            return True
+
     return True
     return True
 
 
 
 
-def find_free_reg(context):
+def find_free_reg(block, start, end):
     """Find a temporary register that is free in a given list of statements."""
     """Find a temporary register that is free in a given list of statements."""
-    for i in xrange(8):
-        tmp = '$t%d' % i
+    for i in xrange(8, 16):
+        tmp = '$%d' % i
 
 
-        if reg_dead_in(tmp, context):
+        if reg_can_be_used_in(tmp, block, start, end):
             return tmp
             return tmp
 
 
     raise Exception('No temporary register is available.')
     raise Exception('No temporary register is available.')
@@ -35,52 +44,55 @@ def eliminate_common_subexpressions(block):
     - If the statement can be possibly be eliminated, walk further collecting
     - If the statement can be possibly be eliminated, walk further collecting
       all other occurrences of the expression until one of the arguments is
       all other occurrences of the expression until one of the arguments is
       assigned in a statement, or the start of the block has been reached.
       assigned in a statement, or the start of the block has been reached.
-    - If one or more occurrences were found, insert the expression with a new
-      destination address before the last found occurrence and change all
+    - If one or more occurrences were changed, insert the expression with a new
+      destination address before the last changed occurrence and change all
       occurrences to a move instruction from that address.
       occurrences to a move instruction from that address.
     """
     """
-    found = False
-    block.reverse_statements()
+    changed = False
 
 
     while not block.end():
     while not block.end():
         s = block.read()
         s = block.read()
 
 
         if s.is_arith():
         if s.is_arith():
             pointer = block.pointer
             pointer = block.pointer
-            last = False
-            new_reg = False
+            occurrences = [pointer - 1]
             args = s[1:]
             args = s[1:]
 
 
             # Collect similar statements
             # Collect similar statements
             while not block.end():
             while not block.end():
                 s2 = block.read()
                 s2 = block.read()
 
 
+                if not s2.is_command():
+                    continue
+
                 # Stop if one of the arguments is assigned
                 # Stop if one of the arguments is assigned
                 if len(s2) and s2[0] in args:
                 if len(s2) and s2[0] in args:
                     break
                     break
 
 
                 # Replace a similar expression by a move instruction
                 # Replace a similar expression by a move instruction
                 if s2.name == s.name and s2[1:] == args:
                 if s2.name == s.name and s2[1:] == args:
-                    if not new_reg:
-                        new_reg = find_free_reg(block[:pointer])
+                    occurrences.append(block.pointer - 1)
 
 
-                    block.replace(1, [S('command', 'move', s2[0], new_reg)])
-                    last = block.pointer
+            if len(occurrences) > 1:
+                new_reg = find_free_reg(block, pointer, occurrences[-1])
 
 
-            # Reset pointer to and continue from the original statement
-            block.pointer = pointer
+                # Replace all occurrences with a move statement
+                for occurrence in occurrences:
+                    rd = block[occurrence][0]
+                    block.replace(1, [S('command', 'move', rd, new_reg)], \
+                            start=occurrence)
 
 
-            if last:
-                # Insert an additional expression with a new destination address
-                block.insert(S('command', s.name, *([new_reg] + args)), last)
+                # Insert the calculation before the original with the new
+                # destination address
+                block.insert(S('command', s.name, *([new_reg] + args)), \
+                             index=occurrences[0])
 
 
-                # Replace the original expression with a move statement
-                block.replace(1, [S('command', 'move', s[0], new_reg)])
-                found = True
+                changed = True
 
 
-    block.reverse_statements()
+            # Reset pointer to continue from the original statement
+            block.pointer = pointer
 
 
-    return found
+    return changed
 
 
 
 
 def to_hex(value):
 def to_hex(value):
@@ -103,7 +115,7 @@ def fold_constants(block):
     - When a variable is used, the following happens:
     - When a variable is used, the following happens:
         lw $reg, VAR    ->  register[$reg] = constants[VAR]
         lw $reg, VAR    ->  register[$reg] = constants[VAR]
     """
     """
-    found = False
+    changed = False
 
 
     # Variable values
     # Variable values
     constants = {}
     constants = {}
@@ -162,23 +174,23 @@ def fold_constants(block):
 
 
                 block.replace(1, [S('command', 'li', rd, result)])
                 block.replace(1, [S('command', 'li', rd, result)])
                 register[rd] = result
                 register[rd] = result
-                found = True
+                changed = True
             elif rt_known:
             elif rt_known:
                 # c = 10        ->  b = a + 10
                 # c = 10        ->  b = a + 10
                 # b = c + a
                 # b = c + a
                 s[2] = register[rt]
                 s[2] = register[rt]
-                found = True
-            elif rs_known and s.name in ['addu', 'mult']:
+                changed = True
+            elif rs_known and s.name == 'addu':
                 # a = 10        ->  b = c + 10
                 # a = 10        ->  b = c + 10
                 # b = c + a
                 # b = c + a
                 s[1] = rt
                 s[1] = rt
                 s[2] = register[rs]
                 s[2] = register[rs]
-                found = True
+                changed = True
         elif len(s) and s[0] in register:
         elif len(s) and s[0] in register:
             # Known register is overwritten, remove its value
             # Known register is overwritten, remove its value
             del register[s[0]]
             del register[s[0]]
 
 
-    return found
+    return changed
 
 
 
 
 def copy_propagation(block):
 def copy_propagation(block):

+ 17 - 1
src/statement.py

@@ -96,7 +96,23 @@ class Statement:
     def uses(self, reg):
     def uses(self, reg):
         """Check if this statement uses the given register."""
         """Check if this statement uses the given register."""
         # TODO: Finish
         # TODO: Finish
-        return (self.is_load() or self.is_arith()) and reg in self[1:]
+        if self.is_arith():
+            return reg in self[1:]
+
+        if self.is_command('move'):
+            return self[1] == reg
+
+        if self.is_command('lw', 'sb', 'sw', 'dsw'):
+            m = re.match('^\d+\(([^)]+)\)$', self[1])
+
+            if m:
+                return m.group(1) == reg
+
+            # 'sw' also uses its first argument
+            if self.name in ['sw', 'dsw']:
+                return self[0] == reg
+
+        return False
 
 
 
 
 class Block:
 class Block:

+ 3 - 3
tests/test_optimize_advanced.py

@@ -19,9 +19,9 @@ class TestOptimizeAdvanced(unittest.TestCase):
     def test_eliminate_common_subexpressions_simple(self):
     def test_eliminate_common_subexpressions_simple(self):
         b = B([S('command', 'addu', '$regC', '$regA', '$regB'),
         b = B([S('command', 'addu', '$regC', '$regA', '$regB'),
                S('command', 'addu', '$regD', '$regA', '$regB')])
                S('command', 'addu', '$regD', '$regA', '$regB')])
-        e = [S('command', 'addu', '$t0', '$regA', '$regB'), \
-             S('command', 'move', '$regC', '$t0'), \
-             S('command', 'move', '$regD', '$t0')]
+        e = [S('command', 'addu', '$8', '$regA', '$regB'), \
+             S('command', 'move', '$regC', '$8'), \
+             S('command', 'move', '$regD', '$8')]
         eliminate_common_subexpressions(b)
         eliminate_common_subexpressions(b)
         self.assertEqual(b.statements, e)
         self.assertEqual(b.statements, e)