Kaynağa Gözat

Merge branch 'master' of github.com:taddeus/peephole

Jayke Meijer 14 yıl önce
ebeveyn
işleme
59b9d767a8
6 değiştirilmiş dosya ile 253 ekleme ve 91 silme
  1. 151 0
      src/dataflow.py
  2. 21 15
      src/optimize.py
  3. 11 61
      src/statement.py
  4. 46 0
      tests/test_dataflow.py
  5. 1 1
      tests/test_optimize.py
  6. 23 14
      tests/test_statement.py

+ 151 - 0
src/dataflow.py

@@ -0,0 +1,151 @@
+from copy import copy
+
+from statement import Block
+
+
+class BasicBlock(Block):
+    def __init__(self, statements=[]):
+        Block.__init__(self, statements)
+        self.edges_to = []
+        self.edges_from = []
+
+        self.dominates = []
+        self.dominated_by = []
+
+    def add_edge_to(self, block):
+        if block not in self.edges_to:
+            self.edges_to.append(block)
+            block.edges_from.append(self)
+
+    def set_dominates(self, block):
+        if block not in self.dominates:
+            self.dominates.append(block)
+            block.dominated_by.append(self)
+
+    def get_gen(self):
+        pass
+
+    def get_kill(self):
+        pass
+
+    def get_in(self):
+        pass
+
+    def get_out(self):
+        pass
+
+
+def find_leaders(statements):
+    """Determine the leaders, which are:
+       1. The first statement.
+       2. Any statement that is the target of a jump.
+       3. Any statement that follows directly follows a jump."""
+    leaders = [0]
+    jump_target_labels = []
+
+    # Append statements following jumps and save jump target labels
+    for i, statement in enumerate(statements[1:]):
+        if statement.is_jump():
+            leaders.append(i + 2)
+            jump_target_labels.append(statement[-1])
+
+    # Append jump targets
+    for i, statement in enumerate(statements[1:]):
+        if i + 1 not in leaders \
+                and statement.is_label() \
+                and statement.name in jump_target_labels:
+            leaders.append(i + 1)
+
+    leaders.sort()
+
+    return leaders
+
+
+def find_basic_blocks(statements):
+    """Divide a statement list into basic blocks. Returns a list of basic
+    blocks, which are also statement lists."""
+    leaders = find_leaders(statements)
+    blocks = []
+
+    for i in range(len(leaders) - 1):
+        blocks.append(BasicBlock(statements[leaders[i]:leaders[i + 1]]))
+
+    blocks.append(BasicBlock(statements[leaders[-1]:]))
+
+    return blocks
+
+
+def generate_flow_graph(blocks):
+    """Add flow graph edge administration of an ordered sequence of basic
+    blocks."""
+    for i, b in enumerate(blocks):
+        last_statement = b[-1]
+
+        if last_statement.is_jump():
+            target = last_statement.jump_target()
+
+            # Compare the target to all leading labels, add an edge if the
+            # label matches the jump target
+            for other in blocks:
+                if other[0].is_label(target):
+                    b.add_edge_to(other)
+
+            # A branch instruction also creates an edge to the next block
+            if last_statement.is_branch() and i < len(blocks) - 1:
+                b.add_edge_to(blocks[i + 1])
+        elif i < len(blocks) - 1:
+            b.add_edge_to(blocks[i + 1])
+
+
+def generate_dominator_tree(nodes):
+    """Add dominator administration to the given flow graph nodes."""
+    # Dominator of the start node is the start itself
+    nodes[0].dom = set([nodes[0]])
+
+    # For all other nodes, set all nodes as the dominators
+    for n in nodes[1:]:
+        n.dom = set(copy(nodes))
+
+    def pred(n, known=[]):
+        """Recursively find all predecessors of a node."""
+        direct = filter(lambda x: x not in known, n.edges_from)
+        p = copy(direct)
+
+        for ancestor in direct:
+            p += pred(ancestor, direct)
+
+        return p
+
+    # Iteratively eliminate nodes that are not dominators
+    changed = True
+
+    while changed:
+        changed = False
+
+        for n in nodes[1:]:
+            old_dom = n.dom
+            intersection = lambda p1, p2: p1.dom & p2.dom
+            n.dom = set([n]) | reduce(intersection, pred(n), set([]))
+
+            if n.dom != old_dom:
+                changed = True
+
+    def idom(d, n):
+        """Check if d immediately dominates n."""
+        for b in n.dom:
+            if b != d and b != n and b in n.dom:
+                return False
+
+        return True
+
+    # Build tree using immediate dominators
+    for n in nodes:
+        for d in n.dom:
+            if idom(d, n):
+                d.set_dominates(n)
+                break
+
+# statements = parse_file(...)
+# b = find_basic_blocks(statements)
+# generate_flow_graph(b)  # nodes now have edges
+# generate_dominator_tree(b)  # nodes now have dominators

+ 21 - 15
src/optimize.py

@@ -1,8 +1,8 @@
-from utils import find_basic_blocks
+from dataflow import find_basic_blocks
 
 
 def optimize_global(statements):
-    """Optimize statement sequences in on a global level."""
+    """Optimize statement sequences on a global level."""
     old_len = -1
 
     while old_len != len(statements):
@@ -11,12 +11,12 @@ def optimize_global(statements):
         while not statements.end():
             s = statements.read()
 
-            # mov $regA,$regB           ->  --- remove it
+            # mov $regA, $regA          ->  --- remove it
             if s.is_command('move') and s[0] == s[1]:
                 statements.replace(1, [])
                 continue
 
-            # mov $regA,$regB           ->  instr $regA, $regB, ...
+            # mov $regA, $regB          ->  instr $regA, $regB, ...
             # instr $regA, $regA, ...
             if s.is_command('move'):
                 ins = statements.peek()
@@ -35,8 +35,8 @@ def optimize_global(statements):
                 if len(following) == 2:
                     mov, jal = following
 
-                    if mov.name == 'move' and mov[1] == s[0] \
-                            and jal.name == 'jal':
+                    if mov.is_command('move') and mov[1] == s[0] \
+                            and jal.is_command('jal'):
                         s[0] = mov[0]
                         statements.replace(1, [], start=statements.pointer + 1)
                         continue
@@ -78,8 +78,6 @@ def optimize_global(statements):
                         s[2] = label.name
                         statements.replace(3, [s, label])
 
-    return statements
-
 
 def optimize_blocks(blocks):
     """Call the optimizer for each basic block. Do this several times until
@@ -115,21 +113,29 @@ def optimize_block(statements):
     return changed, output_statements
 
 
-def optimize(original, verbose=0):
+def optimize(statements, verbose=0):
     """optimization wrapper function, calls global and basic-block level
     optimization functions."""
     # Optimize on a global level
-    opt_global = optimize_global(original)
+    o = len(statements)
+    optimize_global(statements)
+    g = len(statements)
 
     # Optimize basic blocks
-    basic_blocks = find_basic_blocks(opt_global)
+    basic_blocks = find_basic_blocks(statements)
     blocks = optimize_blocks(basic_blocks)
-    opt_blocks = reduce(lambda a, b: a.statements + b.statements, blocks)
+    block_statements = map(lambda b: b.statements, blocks)
+    opt_blocks = reduce(lambda a, b: a + b, block_statements)
+    b = len(opt_blocks)
+
+    # - Common subexpression elimination
+    # - Constant folding
+    # - Copy propagation
+    # - Dead-code elimination
+    # - Temporary variable renaming
+    # - Interchange of independent statements
 
     if verbose:
-        o = len(original)
-        g = len(opt_global)
-        b = len(opt_blocks)
         print 'Original statements:             %d' % o
         print 'After global optimization:       %d' % g
         print 'After basic blocks optimization: %d' % b

+ 11 - 61
src/utils.py → src/statement.py

@@ -1,10 +1,11 @@
 import re
 
+
 class Statement:
     def __init__(self, stype, name, *args, **kwargs):
         self.stype = stype
         self.name = name
-        self.args = args
+        self.args = list(args)
         self.options = kwargs
 
     def __getitem__(self, n):
@@ -51,6 +52,12 @@ class Statement:
                and re.match('^j|jal|beq|bne|blez|bgtz|bltz|bgez|bct|bcf$', \
                             self.name)
 
+    def is_branch(self):
+        """Check if the statement is a branch."""
+        return self.is_command() \
+               and re.match('^beq|bne|blez|bgtz|bltz|bgez|bct|bcf$', \
+                            self.name)
+
     def is_shift(self):
         """Check if the statement is a shift operation."""
         return self.is_command() and re.match('^s(ll|la|rl|ra)$', self.name)
@@ -71,26 +78,6 @@ class Statement:
 
         return self[-1]
 
-    def get_def(self):
-        """Get the def[S] of this statement."""
-        if not self.is_command():
-            return []
-
-        if self.is_load() or self.is_arith():
-            return [self[0]]
-
-    def get_use(self):
-        """Get the use[S] of this statement."""
-        return []
-
-    def defines(self, var):
-        """Check if a variable is defined by this statement."""
-        return var in self.get_def()
-
-    def uses(self, var):
-        """Check if a variable is used by this statement."""
-        return var in self.get_use()
-
 
 class Block:
     def __init__(self, statements=[]):
@@ -100,6 +87,9 @@ class Block:
     def __iter__(self):
         return iter(self.statements)
 
+    def __getitem__(self, n):
+        return self.statements[n]
+
     def __len__(self):
         return len(self.statements)
 
@@ -137,43 +127,3 @@ class Block:
         """Apply a filter to the statement list. If the callback returns True,
         the statement will remain in the list.."""
         self.statements = filter(callback, self.statements)
-
-
-def find_leaders(statements):
-    """Determine the leaders, which are:
-       1. The first statement.
-       2. Any statement that is the target of a jump.
-       3. Any statement that follows directly follows a jump."""
-    leaders = [0]
-    jump_target_labels = []
-
-    # Append statements following jumps and save jump target labels
-    for i, statement in enumerate(statements[1:]):
-        if statement.is_jump():
-            leaders.append(i + 2)
-            jump_target_labels.append(statement[-1])
-
-    # Append jump targets
-    for i, statement in enumerate(statements[1:]):
-        if i + 1 not in leaders \
-                and statement.is_label() \
-                and statement.name in jump_target_labels:
-            leaders.append(i + 1)
-
-    leaders.sort()
-
-    return leaders
-
-
-def find_basic_blocks(statements):
-    """Divide a statement list into basic blocks. Returns a list of basic
-    blocks, which are also statement lists."""
-    leaders = find_leaders(statements)
-    blocks = []
-
-    for i in range(len(leaders) - 1):
-        blocks.append(Block(statements[leaders[i]:leaders[i + 1]]))
-
-    blocks.append(Block(statements[leaders[-1]:]))
-
-    return blocks

+ 46 - 0
tests/test_dataflow.py

@@ -0,0 +1,46 @@
+import unittest
+
+from src.statement import Statement as S
+from src.dataflow import BasicBlock as B, find_leaders, find_basic_blocks, \
+        generate_flow_graph
+
+
+class TestDataflow(unittest.TestCase):
+
+    def setUp(self):
+        add = S('command', 'add', '$1', '$2', '$3')
+        self.statements = [add, S('command', 'j', 'foo'), add, add, \
+                S('label', 'foo')]
+
+    def tearDown(self):
+        del self.statements
+
+    def test_find_leaders(self):
+        self.assertEqual(find_leaders(self.statements), [0, 2, 4])
+
+    def test_find_basic_blocks(self):
+        s = self.statements
+        self.assertEqual(map(lambda b: b.statements, find_basic_blocks(s)), \
+                [B(s[:2]).statements, B(s[2:4]).statements, \
+                 B(s[4:]).statements])
+
+    def test_generate_flow_graph_simple(self):
+        b1 = B([S('command', 'foo'), S('command', 'j', 'b2')])
+        b2 = B([S('label', 'b2'), S('command', 'bar')])
+        generate_flow_graph([b1, b2])
+
+        self.assertEqual(b1.edges_to, [b2])
+        self.assertEqual(b2.edges_from, [b1])
+
+    def test_generate_flow_graph_branch(self):
+        b1 = B([S('command', 'foo'), S('command', 'beq', '$1', '$2', 'b3')])
+        b2 = B([S('command', 'bar')])
+        b3 = B([S('label', 'b3'), S('command', 'baz')])
+        generate_flow_graph([b1, b2, b3])
+
+        self.assertIn(b2, b1.edges_to)
+        self.assertIn(b3, b1.edges_to)
+        self.assertEqual(b2.edges_from, [b1])
+        self.assertEqual(b2.edges_to, [b3])
+        self.assertIn(b1, b3.edges_from)
+        self.assertIn(b2, b3.edges_from)

+ 1 - 1
tests/test_optimize.py

@@ -1,7 +1,7 @@
 import unittest
 
 from src.optimize import optimize_global
-from src.utils import Statement as S, Block as B
+from src.statement import Statement as S, Block as B
 
 
 class TestOptimize(unittest.TestCase):

+ 23 - 14
tests/test_utils.py → tests/test_statement.py

@@ -1,31 +1,25 @@
 import unittest
 
-from src.utils import Statement as S, Block as B, find_leaders, \
-        find_basic_blocks
+from src.statement import Statement as S, Block as B
 
 
-class TestUtils(unittest.TestCase):
+class TestStatement(unittest.TestCase):
 
     def setUp(self):
-        add = S('command', 'add', '$1', '$2', '$3')
-        self.statements = [add, S('command', 'j', 'foo'), add, add, \
-                S('label', 'foo')]
+        self.statement = S('command', 'foo', '$1')
         self.block = B([S('command', 'foo'), \
                         S('comment', 'bar'),
                         S('command', 'baz')])
 
     def tearDown(self):
-        del self.statements
         del self.block
 
-    def test_find_leaders(self):
-        self.assertEqual(find_leaders(self.statements), [0, 2, 4])
+    def test_getitem(self):
+        self.assertEqual(self.statement[0], '$1')
 
-    def test_find_basic_blocks(self):
-        s = self.statements
-        self.assertEqual(map(lambda b: b.statements, find_basic_blocks(s)), \
-                [B(s[:2]).statements, B(s[2:4]).statements, \
-                 B(s[4:]).statements])
+    def test_setitem(self):
+        self.statement[0] = '$2'
+        self.assertEqual(self.statement[0], '$2')
 
     def test_eq(self):
         self.assertTrue(S('command', 'foo') == S('command', 'foo'))
@@ -83,3 +77,18 @@ class TestUtils(unittest.TestCase):
         self.block.apply_filter(lambda s: s.is_command())
         self.assertEqual(self.block.statements, [S('command', 'foo'), \
                                                  S('command', 'baz')])
+
+    def test_is_shift(self):
+        self.assertTrue(S('command', 'sll').is_shift())
+        self.assertFalse(S('command', 'foo').is_shift())
+        self.assertFalse(S('label', 'sll').is_shift())
+
+    def test_is_load(self):
+        self.assertTrue(S('command', 'lw').is_load())
+        self.assertFalse(S('command', 'foo').is_load())
+        self.assertFalse(S('label', 'lw').is_load())
+
+    def test_is_arith(self):
+        self.assertTrue(S('command', 'add', '$1', '$2', '$3').is_arith())
+        self.assertFalse(S('command', 'foo').is_arith())
+        self.assertFalse(S('label', 'add').is_arith())