Explorar el Código

Added proper removal of redundant jumps.

Jayke Meijer hace 14 años
padre
commit
3cadb4930f
Se han modificado 3 ficheros con 48 adiciones y 45 borrados
  1. 36 36
      src/optimize/redundancies.py
  2. 3 1
      src/program.py
  3. 9 8
      tests/test_optimize.py

+ 36 - 36
src/optimize/redundancies.py

@@ -124,45 +124,45 @@ def add_lw(add, statements):
             return True
 
 def remove_redundant_jumps(statements):
-    """Remove jump if label follows immediatly."""
-    old_len = -1
-
-    while old_len != len(statements):
-        old_len = len(statements)
+    """Remove jump if label follows immediately."""
+    changed = False
         
-        while not statements.end():
-            s = statements.read()
-
-            #     j $Lx     ->             $Lx:
-            # $Lx:
-            if s.is_command('j'):
-                following = statements.peek(2)
-
-                if following.is_label(s[0]):
-                    statements.replace(1, [])
+    statements.reset()
+    while not statements.end():
+        s = statements.read()
+
+        #     j $Lx     ->             $Lx:
+        # $Lx:
+        if s.is_command('j'):
+            following = statements.peek()
+            
+            if following.is_label(s[0]):
+                statements.replace(1, [])
+                changed = True
+                    
+    return True
                         
 def remove_redundant_branch_jumps(statements):
     """Optimize statement sequences on a global level."""
-    old_len = -1
-
-    while old_len != len(statements):
-        old_len = len(statements)
-
-        while not statements.end():
-            s = statements.read()
-
-            #     beq/bne ..., $Lx      ->      bne/beq ..., $Ly
-            #     j $Ly                     $Lx:
-            # $Lx:
-            if s.is_command('beq', 'bne'):
-                following = statements.peek(2)
-
-                if len(following) == 2:
-                    j, label = following
-
-                    if j.is_command('j') and label.is_label(s[2]):
-                        s.name = 'bne' if s.is_command('beq') else 'beq'
-                        s[2] = j[0]
-                        statements.replace(3, [s, label])
+    changed = False
 
     statements.reset()
+    while not statements.end():
+        s = statements.read()
+
+        #     beq/bne ..., $Lx      ->      bne/beq ..., $Ly
+        #     j $Ly                     $Lx:
+        # $Lx:
+        if s.is_command('beq', 'bne'):
+            following = statements.peek(2)
+
+            if len(following) == 2:
+                j, label = following
+
+                if j.is_command('j') and label.is_label(s[2]):
+                    s.name = 'bne' if s.is_command('beq') else 'beq'
+                    s[2] = j[0]
+                    statements.replace(3, [s, label])
+                    changed = True
+                        
+    return changed

+ 3 - 1
src/program.py

@@ -1,6 +1,7 @@
 from statement import Statement as S, Block
 from dataflow import find_basic_blocks, generate_flow_graph
-from optimize.redundancies import remove_redundant_jumps, remove_redundancies
+from optimize.redundancies import remove_redundant_jumps, remove_redundancies,\
+        remove_redundant_branch_jumps
 from optimize.advanced import eliminate_common_subexpressions, \
         fold_constants, copy_propagation, eliminate_dead_code
 from writer import write_statements
@@ -63,6 +64,7 @@ class Program(Block):
     def optimize_global(self):
         """Optimize on a global level."""
         remove_redundant_jumps(self)
+        remove_redundant_branch_jumps(self)
 
     def optimize_blocks(self):
         """Optimize on block level. Keep executing all optimizations until no

+ 9 - 8
tests/test_optimize.py

@@ -1,6 +1,7 @@
 import unittest
 
-from src.optimize.redundancies import remove_redundancies, remove_redundant_jumps
+from src.optimize.redundancies import remove_redundancies, \
+    remove_redundant_jumps, remove_redundant_branch_jumps
 from src.program import Program
 from src.statement import Statement as S, Block as B
 
@@ -199,12 +200,12 @@ class TestOptimize(unittest.TestCase):
                    S('command', 'j', '$L1'),
                    S('label', '$L1'),
                    self.bar])
-                   
-       remove_redundancies(block)
+        
+        remove_redundant_jumps(block)
        
-       self.assertEqual(block.statements, B([self.foo, 
-                                             S('command', 'j', '$L1'),
-                                             self.bar]))
+        self.assertEqual(block.statements, [self.foo, 
+                                             S('label', '$L1'),
+                                             self.bar])
                                              
     def test_remove_redundant_jumps_false(self):
         arguments = [self.foo,
@@ -213,9 +214,9 @@ class TestOptimize(unittest.TestCase):
                    self.bar]
         block = B(arguments)
                    
-       remove_redundancies(block)
+        remove_redundant_jumps(block)
        
-       self.assertEqual(block.statements, arguments)
+        self.assertEqual(block.statements, arguments)
         
     def test_remove_redundant_branch_jumps_beq_j_true(self):
         block = B([self.foo,