Commit 3cadb493 authored by Jayke Meijer's avatar Jayke Meijer

Added proper removal of redundant jumps.

parent ff70b648
...@@ -124,45 +124,45 @@ def add_lw(add, statements): ...@@ -124,45 +124,45 @@ def add_lw(add, statements):
return True return True
def remove_redundant_jumps(statements): def remove_redundant_jumps(statements):
"""Remove jump if label follows immediatly.""" """Remove jump if label follows immediately."""
old_len = -1 changed = False
while old_len != len(statements):
old_len = len(statements)
while not statements.end(): statements.reset()
s = statements.read() while not statements.end():
s = statements.read()
# j $Lx -> $Lx:
# $Lx: # j $Lx -> $Lx:
if s.is_command('j'): # $Lx:
following = statements.peek(2) if s.is_command('j'):
following = statements.peek()
if following.is_label(s[0]):
statements.replace(1, []) if following.is_label(s[0]):
statements.replace(1, [])
changed = True
return True
def remove_redundant_branch_jumps(statements): def remove_redundant_branch_jumps(statements):
"""Optimize statement sequences on a global level.""" """Optimize statement sequences on a global level."""
old_len = -1 changed = False
while old_len != len(statements):
old_len = len(statements)
while not statements.end():
s = statements.read()
# beq/bne ..., $Lx -> bne/beq ..., $Ly
# j $Ly $Lx:
# $Lx:
if s.is_command('beq', 'bne'):
following = statements.peek(2)
if len(following) == 2:
j, label = following
if j.is_command('j') and label.is_label(s[2]):
s.name = 'bne' if s.is_command('beq') else 'beq'
s[2] = j[0]
statements.replace(3, [s, label])
statements.reset() statements.reset()
while not statements.end():
s = statements.read()
# beq/bne ..., $Lx -> bne/beq ..., $Ly
# j $Ly $Lx:
# $Lx:
if s.is_command('beq', 'bne'):
following = statements.peek(2)
if len(following) == 2:
j, label = following
if j.is_command('j') and label.is_label(s[2]):
s.name = 'bne' if s.is_command('beq') else 'beq'
s[2] = j[0]
statements.replace(3, [s, label])
changed = True
return changed
from statement import Statement as S, Block from statement import Statement as S, Block
from dataflow import find_basic_blocks, generate_flow_graph from dataflow import find_basic_blocks, generate_flow_graph
from optimize.redundancies import remove_redundant_jumps, remove_redundancies from optimize.redundancies import remove_redundant_jumps, remove_redundancies,\
remove_redundant_branch_jumps
from optimize.advanced import eliminate_common_subexpressions, \ from optimize.advanced import eliminate_common_subexpressions, \
fold_constants, copy_propagation, eliminate_dead_code fold_constants, copy_propagation, eliminate_dead_code
from writer import write_statements from writer import write_statements
...@@ -63,6 +64,7 @@ class Program(Block): ...@@ -63,6 +64,7 @@ class Program(Block):
def optimize_global(self): def optimize_global(self):
"""Optimize on a global level.""" """Optimize on a global level."""
remove_redundant_jumps(self) remove_redundant_jumps(self)
remove_redundant_branch_jumps(self)
def optimize_blocks(self): def optimize_blocks(self):
"""Optimize on block level. Keep executing all optimizations until no """Optimize on block level. Keep executing all optimizations until no
......
import unittest import unittest
from src.optimize.redundancies import remove_redundancies, remove_redundant_jumps from src.optimize.redundancies import remove_redundancies, \
remove_redundant_jumps, remove_redundant_branch_jumps
from src.program import Program from src.program import Program
from src.statement import Statement as S, Block as B from src.statement import Statement as S, Block as B
...@@ -199,12 +200,12 @@ class TestOptimize(unittest.TestCase): ...@@ -199,12 +200,12 @@ class TestOptimize(unittest.TestCase):
S('command', 'j', '$L1'), S('command', 'j', '$L1'),
S('label', '$L1'), S('label', '$L1'),
self.bar]) self.bar])
remove_redundancies(block) remove_redundant_jumps(block)
self.assertEqual(block.statements, B([self.foo, self.assertEqual(block.statements, [self.foo,
S('command', 'j', '$L1'), S('label', '$L1'),
self.bar])) self.bar])
def test_remove_redundant_jumps_false(self): def test_remove_redundant_jumps_false(self):
arguments = [self.foo, arguments = [self.foo,
...@@ -213,9 +214,9 @@ class TestOptimize(unittest.TestCase): ...@@ -213,9 +214,9 @@ class TestOptimize(unittest.TestCase):
self.bar] self.bar]
block = B(arguments) block = B(arguments)
remove_redundancies(block) remove_redundant_jumps(block)
self.assertEqual(block.statements, arguments) self.assertEqual(block.statements, arguments)
def test_remove_redundant_branch_jumps_beq_j_true(self): def test_remove_redundant_branch_jumps_beq_j_true(self):
block = B([self.foo, block = B([self.foo,
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment