Просмотр исходного кода

Moved util functions to separate files.

Taddeus Kroes 14 лет назад
Родитель
Сommit
ef3125ad59
5 измененных файлов с 190 добавлено и 91 удалено
  1. 141 0
      src/dataflow.py
  2. 21 15
      src/optimize.py
  3. 1 60
      src/statement.py
  4. 25 0
      tests/test_dataflow.py
  5. 2 16
      tests/test_statement.py

+ 141 - 0
src/dataflow.py

@@ -0,0 +1,141 @@
+from copy import copy
+
+from utils import Block
+
+
+class BasicBlock(Block):
+    edges_to = []
+    edges_from = []
+
+    dominates = []
+    dominated_by = []
+
+    def add_edge_to(self, block):
+        self.edges_to.append(block)
+        block.edges_from.append(self)
+
+    def set_dominates(self, block):
+        self.dominates.append(block)
+        block.dominated_by.append(self)
+
+    def get_gen(self):
+        pass
+
+    def get_kill(self):
+        pass
+
+    def get_in(self):
+        pass
+
+    def get_out(self):
+        pass
+
+
+def find_leaders(statements):
+    """Determine the leaders, which are:
+       1. The first statement.
+       2. Any statement that is the target of a jump.
+       3. Any statement that follows directly follows a jump."""
+    leaders = [0]
+    jump_target_labels = []
+
+    # Append statements following jumps and save jump target labels
+    for i, statement in enumerate(statements[1:]):
+        if statement.is_jump():
+            leaders.append(i + 2)
+            jump_target_labels.append(statement[-1])
+
+    # Append jump targets
+    for i, statement in enumerate(statements[1:]):
+        if i + 1 not in leaders \
+                and statement.is_label() \
+                and statement.name in jump_target_labels:
+            leaders.append(i + 1)
+
+    leaders.sort()
+
+    return leaders
+
+
+def find_basic_blocks(statements):
+    """Divide a statement list into basic blocks. Returns a list of basic
+    blocks, which are also statement lists."""
+    leaders = find_leaders(statements)
+    blocks = []
+
+    for i in range(len(leaders) - 1):
+        blocks.append(BasicBlock(statements[leaders[i]:leaders[i + 1]]))
+
+    blocks.append(BasicBlock(statements[leaders[-1]:]))
+
+    return blocks
+
+
+def generate_flow_graph(blocks):
+    """Add flow graph edge administration of an ordered sequence of basic
+    blocks."""
+    for b in blocks:
+        last_statement = b[-1]
+
+        if last_statement.is_jump():
+            target = last_statement.jump_target()
+
+            # Compare the target to all leading labels, add an edge if the
+            # label matches the jump target
+            for other in blocks:
+                if other[0].is_label(target):
+                    b.add_edge_to(other)
+
+
+def generate_dominator_tree(nodes):
+    """Add dominator administration to the given flow graph nodes."""
+    # Dominator of the start node is the start itself
+    nodes[0].dom = set([nodes[0]])
+
+    # For all other nodes, set all nodes as the dominators
+    for n in nodes[1:]:
+        n.dom = set(copy(nodes))
+
+    def pred(n, known=[]):
+        """Recursively find all predecessors of a node."""
+        direct = filter(lambda x: x not in known, n.edges_from)
+        p = copy(direct)
+
+        for ancestor in direct:
+            p += pred(ancestor, direct)
+
+        return p
+
+    # Iteratively eliminate nodes that are not dominators
+    changed = True
+
+    while changed:
+        changed = False
+
+        for n in nodes[1:]:
+            old_dom = n.dom
+            intersection = lambda p1, p2: p1.dom & p2.dom
+            n.dom = set([n]) | reduce(intersection, pred(n), set([]))
+
+            if n.dom != old_dom:
+                changed = True
+
+    def idom(d, n):
+        """Check if d immediately dominates n."""
+        for b in n.dom:
+            if b != d and b != n and b in n.dom:
+                return False
+
+        return True
+
+    # Build tree using immediate dominators
+    for n in nodes:
+        for d in n.dom:
+            if idom(d, n):
+                d.set_dominates(n)
+                break
+
+# statements = parse_file(...)
+# b = find_basic_blocks(statements)
+# generate_flow_graph(b)  # nodes now have edges
+# generate_dominator_tree(b)  # nodes now have dominators

+ 21 - 15
src/optimize.py

@@ -1,8 +1,8 @@
-from utils import find_basic_blocks
+from dataflow import find_basic_blocks
 
 
 def optimize_global(statements):
-    """Optimize statement sequences in on a global level."""
+    """Optimize statement sequences on a global level."""
     old_len = -1
 
     while old_len != len(statements):
@@ -11,12 +11,12 @@ def optimize_global(statements):
         while not statements.end():
             s = statements.read()
 
-            # mov $regA,$regB           ->  --- remove it
+            # mov $regA, $regA          ->  --- remove it
             if s.is_command('move') and s[0] == s[1]:
                 statements.replace(1, [])
                 continue
 
-            # mov $regA,$regB           ->  instr $regA, $regB, ...
+            # mov $regA, $regB          ->  instr $regA, $regB, ...
             # instr $regA, $regA, ...
             if s.is_command('move'):
                 ins = statements.peek()
@@ -35,8 +35,8 @@ def optimize_global(statements):
                 if len(following) == 2:
                     mov, jal = following
 
-                    if mov.name == 'move' and mov[1] == s[0] \
-                            and jal.name == 'jal':
+                    if mov.is_command('move') and mov[1] == s[0] \
+                            and jal.is_command('jal'):
                         s[0] = mov[0]
                         statements.replace(1, [], start=statements.pointer + 1)
                         continue
@@ -78,8 +78,6 @@ def optimize_global(statements):
                         s[2] = label.name
                         statements.replace(3, [s, label])
 
-    return statements
-
 
 def optimize_blocks(blocks):
     """Call the optimizer for each basic block. Do this several times until
@@ -115,21 +113,29 @@ def optimize_block(statements):
     return changed, output_statements
 
 
-def optimize(original, verbose=0):
+def optimize(statements, verbose=0):
     """optimization wrapper function, calls global and basic-block level
     optimization functions."""
     # Optimize on a global level
-    opt_global = optimize_global(original)
+    o = len(statements)
+    optimize_global(statements)
+    g = len(statements)
 
     # Optimize basic blocks
-    basic_blocks = find_basic_blocks(opt_global)
+    basic_blocks = find_basic_blocks(statements)
     blocks = optimize_blocks(basic_blocks)
-    opt_blocks = reduce(lambda a, b: a.statements + b.statements, blocks)
+    block_statements = map(lambda b: b.statements, blocks)
+    opt_blocks = reduce(lambda a, b: a + b, block_statements)
+    b = len(opt_blocks)
+
+    # - Common subexpression elimination
+    # - Constant folding
+    # - Copy propagation
+    # - Dead-code elimination
+    # - Temporary variable renaming
+    # - Interchange of independent statements
 
     if verbose:
-        o = len(original)
-        g = len(opt_global)
-        b = len(opt_blocks)
         print 'Original statements:             %d' % o
         print 'After global optimization:       %d' % g
         print 'After basic blocks optimization: %d' % b

+ 1 - 60
src/utils.py → src/statement.py

@@ -1,5 +1,6 @@
 import re
 
+
 class Statement:
     def __init__(self, stype, name, *args, **kwargs):
         self.stype = stype
@@ -71,26 +72,6 @@ class Statement:
 
         return self[-1]
 
-    def get_def(self):
-        """Get the def[S] of this statement."""
-        if not self.is_command():
-            return []
-
-        if self.is_load() or self.is_arith():
-            return [self[0]]
-
-    def get_use(self):
-        """Get the use[S] of this statement."""
-        return []
-
-    def defines(self, var):
-        """Check if a variable is defined by this statement."""
-        return var in self.get_def()
-
-    def uses(self, var):
-        """Check if a variable is used by this statement."""
-        return var in self.get_use()
-
 
 class Block:
     def __init__(self, statements=[]):
@@ -137,43 +118,3 @@ class Block:
         """Apply a filter to the statement list. If the callback returns True,
         the statement will remain in the list.."""
         self.statements = filter(callback, self.statements)
-
-
-def find_leaders(statements):
-    """Determine the leaders, which are:
-       1. The first statement.
-       2. Any statement that is the target of a jump.
-       3. Any statement that follows directly follows a jump."""
-    leaders = [0]
-    jump_target_labels = []
-
-    # Append statements following jumps and save jump target labels
-    for i, statement in enumerate(statements[1:]):
-        if statement.is_jump():
-            leaders.append(i + 2)
-            jump_target_labels.append(statement[-1])
-
-    # Append jump targets
-    for i, statement in enumerate(statements[1:]):
-        if i + 1 not in leaders \
-                and statement.is_label() \
-                and statement.name in jump_target_labels:
-            leaders.append(i + 1)
-
-    leaders.sort()
-
-    return leaders
-
-
-def find_basic_blocks(statements):
-    """Divide a statement list into basic blocks. Returns a list of basic
-    blocks, which are also statement lists."""
-    leaders = find_leaders(statements)
-    blocks = []
-
-    for i in range(len(leaders) - 1):
-        blocks.append(Block(statements[leaders[i]:leaders[i + 1]]))
-
-    blocks.append(Block(statements[leaders[-1]:]))
-
-    return blocks

+ 25 - 0
tests/test_dataflow.py

@@ -0,0 +1,25 @@
+import unittest
+
+from src.statement import Statement as S, Block as B
+from src.dataflow import find_leaders, find_basic_blocks
+
+
+class TestDataflow(unittest.TestCase):
+
+    def setUp(self):
+        add = S('command', 'add', '$1', '$2', '$3')
+        self.statements = [add, S('command', 'j', 'foo'), add, add, \
+                S('label', 'foo')]
+
+    def tearDown(self):
+        del self.statements
+
+    def test_find_leaders(self):
+        self.assertEqual(find_leaders(self.statements), [0, 2, 4])
+
+    def test_find_basic_blocks(self):
+        s = self.statements
+        self.assertEqual(map(lambda b: b.statements, find_basic_blocks(s)), \
+                [B(s[:2]).statements, B(s[2:4]).statements, \
+                 B(s[4:]).statements])
+

+ 2 - 16
tests/test_utils.py → tests/test_statement.py

@@ -1,32 +1,18 @@
 import unittest
 
-from src.utils import Statement as S, Block as B, find_leaders, \
-        find_basic_blocks
+from src.statement import Statement as S, Block as B
 
 
-class TestUtils(unittest.TestCase):
+class TestStatement(unittest.TestCase):
 
     def setUp(self):
-        add = S('command', 'add', '$1', '$2', '$3')
-        self.statements = [add, S('command', 'j', 'foo'), add, add, \
-                S('label', 'foo')]
         self.block = B([S('command', 'foo'), \
                         S('comment', 'bar'),
                         S('command', 'baz')])
 
     def tearDown(self):
-        del self.statements
         del self.block
 
-    def test_find_leaders(self):
-        self.assertEqual(find_leaders(self.statements), [0, 2, 4])
-
-    def test_find_basic_blocks(self):
-        s = self.statements
-        self.assertEqual(map(lambda b: b.statements, find_basic_blocks(s)), \
-                [B(s[:2]).statements, B(s[2:4]).statements, \
-                 B(s[4:]).statements])
-
     def test_eq(self):
         self.assertTrue(S('command', 'foo') == S('command', 'foo'))
         self.assertFalse(S('command', 'foo') == S('command', 'bar'))