Explorar o código

Fixed merge conflict.

Jayke Meijer %!s(int64=14) %!d(string=hai) anos
pai
achega
58c6acb7d7
Modificáronse 3 ficheiros con 152 adicións e 55 borrados
  1. 80 35
      src/dataflow.py
  2. 23 11
      src/statement.py
  3. 49 9
      tests/test_dataflow.py

+ 80 - 35
src/dataflow.py

@@ -1,4 +1,4 @@
-#from copy import copy
+from copy import copy
 
 from statement import Block
 
@@ -25,40 +25,85 @@ class BasicBlock(Block):
         if block not in self.dominates:
             self.dominates.append(block)
             block.dominated_by.append(self)
-            
-    
-#    def get_gen(self):
-#        for s in self.statements:       
-#            if s.is_arith():
-#                self.gen_set.add(s[0])
-#                print 'added: ', s[0]
-#        
-#        return self.gen_set
-#        
-#    def get_kill(self):
-##        if self.edges_from != []:
-#    
-#        for backw in self.edges_from:
-#            self.kill_set = self.gen_set & backw.kill_set
-#            
-#        self.kill_set = self.kill_set - self.get_gen()
-#        print 'get_kill_set', self.kill_set
-#        return self.kill_set
-
-#    def get_in(self):
-#        for backw in self.edges_from:
-#            self.in_set = self.in_set | backw.out_set
-#        print 'in_set', self.in_set
-#        return self.in_set
-
-#    def get_out(self):
-#        print 'gen_set', self.gen_set
-#        print 'get_in', self.get_in()
-#        print 'get_kill', self.get_kill()
-#        self.out_set = self.gen_set | (self.get_in() - self.get_kill())
-        
-
-    
+
+    def create_gen_kill(self, defs):
+        used = set()
+        self_defs = {}
+
+        # Get the last of each definition series and put in in the `def' set
+        self.gen_set = set()
+
+        for s in reversed(self):
+            for reg in s.get_def():
+                if reg not in self_defs:
+                    print 'Found def:', s
+                    self_defs[reg] = s.sid
+                    self.gen_set.add(s.sid)
+
+        # Generate kill set
+        self.kill_set = set()
+
+        for reg, statement_ids in defs.iteritems():
+            if reg in self_defs:
+                add = statement_ids - set([self_defs[reg]])
+            else:
+                add = statement_ids
+
+            self.kill_set |= add
+
+
+def defs(blocks):
+    # Collect definitions of all registers
+    defs = {}
+
+    for b in blocks:
+        for s in b:
+            for reg in s.get_def():
+                if reg not in defs:
+                    defs[reg] = set([s.sid])
+                else:
+                    defs[reg].add(s.sid)
+
+    return defs
+
+
+def reaching_definitions(blocks):
+    """Generate the `in' and `out' sets of the given blocks using the iterative
+    algorithm from the slides."""
+    defs = defs(blocks)
+
+    for b in blocks:
+        b.create_gen_kill(defs)
+        b.out_set = b.gen_set
+
+    change = True
+
+    while change:
+        change = False
+
+        for b in blocks:
+            b.in_set = set()
+
+            for pred in b.edges_from:
+                b.in_set |= pred.out_set
+
+            oldout = copy(p.out_set)
+            p.out_set = b.gen_set | (b.in_set - b.kill_set)
+
+            if b.out_set != oldout:
+                change = True
+
+
+def pred(n, known=[]):
+    """Recursively find all predecessors of a node."""
+    direct = filter(lambda b: b not in known, n.edges_from)
+    p = copy(direct)
+
+    for ancestor in direct:
+        p += pred(ancestor, direct)
+
+    return p
+
 
 def find_leaders(statements):
     """Determine the leaders, which are:

+ 23 - 11
src/statement.py

@@ -2,12 +2,18 @@ import re
 
 
 class Statement:
+    sid = 1
+
     def __init__(self, stype, name, *args, **kwargs):
         self.stype = stype
         self.name = name
         self.args = list(args)
         self.options = kwargs
 
+        # Assign a unique ID to each satement
+        self.sid = Statement.sid
+        Statement.sid += 1
+
     def __getitem__(self, n):
         """Get an argument."""
         return self.args[n]
@@ -26,8 +32,8 @@ class Statement:
         return len(self.args)
 
     def __str__(self):  # pragma: nocover
-        return '<Statement type=%s name=%s args=%s>' \
-                % (self.stype, self.name, self.args)
+        return '<Statement sid=%d type=%s name=%s args=%s>' \
+                % (self.sid, self.stype, self.name, self.args)
 
     def __repr__(self):  # pragma: nocover
         return str(self)
@@ -64,6 +70,11 @@ class Statement:
         """Check if the statement is a shift operation."""
         return self.is_command() and re.match('^s(ll|rl|ra)$', self.name)
 
+    def is_load(self):
+        """Check if the statement is a load instruction."""
+        return self.is_command() and self.name in ['lw', 'li', 'dlw', 'l.s', \
+                                                   'l.d']
+                                                   
     def is_arith(self):
         """Check if the statement is an arithmetic operation."""
         return self.is_command() \
@@ -81,12 +92,12 @@ class Statement:
         """Check if the statement is an binary operation."""
         return self.is_command() and len(self) == 3 and not self.is_jump()
         
-    def is_load(self):
+    def is_load_non_immediate(self):
         """Check if the statement is a load statement."""
         return self.is_command() \
                and re.match('^l(w|a|b|bu|\.d|\.s)|dlw$', \
                             self.name)
-    def is_logical:
+    def is_logical(self):
         """Check if the statement is a logical operator."""
         return self.is_command() and re.match('^(xor|or|and)i?$', self.name)
     
@@ -108,11 +119,11 @@ class Statement:
         """Check if the statement is a shift if less then."""
         return self.is_command() and self.name in ['slt', 'sltu']
         
-    def self.is_convert(self):
+    def is_convert(self):
         """Check if the statement is a convert operator."""
         return self.is_command() and re.match('^cvt\.[a-z\.]*$', self.name)
         
-    def self.is_truncate(self):
+    def is_truncate(self):
         """Check if the statement is a convert operator."""
         return self.is_command() and re.match('^trunc\.[a-z\.]*$', self.name)
         
@@ -125,12 +136,13 @@ class Statement:
     
     def get_def(self):
         """Get the variable that this statement defines, if any."""
-        inst = ['move', 'addu', 'subu', 'li', 'mtc1', 'dmfc1']
+        instr = ['move', 'addu', 'subu', 'li', 'mtc1', 'dmfc1']
         
-        if self.is_load() or self.is_arith() or self.is_logical() \
-                or self.is_double_arithmetic() or self.is_move_from_spec() \
-                or self.is_double_unary() or self.is_set_if_less() \
-                or self.is_convert() or self.is_truncate() \
+        if self.is_load_non_immediate() or self.is_arith() \
+                or self.is_logical() or self.is_double_arithmetic() \
+                or self.is_move_from_spec() or self.is_double_unary() \
+                or self.is_set_if_less() or self.is_convert() \
+                or self.is_truncate() or self.is_load() \
                 or (self.is_command and self.name in instr):
             return self[0]
 

+ 49 - 9
tests/test_dataflow.py

@@ -2,7 +2,7 @@ import unittest
 
 from src.statement import Statement as S
 from src.dataflow import BasicBlock as B, find_leaders, find_basic_blocks, \
-        generate_flow_graph, Dag, DagNode, DagLeaf
+        generate_flow_graph, Dag, DagNode, DagLeaf, defs, reaching_definitions
 
 
 class TestDataflow(unittest.TestCase):
@@ -23,12 +23,12 @@ class TestDataflow(unittest.TestCase):
         self.assertEqual(map(lambda b: b.statements, find_basic_blocks(s)), \
                 [B(s[:2]).statements, B(s[2:4]).statements, \
                  B(s[4:]).statements])
-                 
+
 #    def test_get_gen(self):
 #        b1 = B([S('command', 'add', '$1', '$2', '$3'), \
 #                S('command', 'add', '$2', '$3', '$4'), \
 #                S('command', 'add', '$1', '$4', '$5')])
-#        
+#
 #        self.assertEqual(b1.get_gen(), ['$1', '$2'])
 
 #    def test_get_out(self):
@@ -36,18 +36,18 @@ class TestDataflow(unittest.TestCase):
 #                S('command', 'add', '$2', '$3', '$4'), \
 #                S('command', 'add', '$1', '$4', '$5'), \
 #                S('command', 'j', 'b2')])
-#        
+#
 #        b2 = B([S('command', 'add', '$3', '$5', '$6'), \
 #                S('command', 'add', '$1', '$2', '$3'), \
-#                S('command', 'add', '$6', '$4', '$5')])      
-#                
+#                S('command', 'add', '$6', '$4', '$5')])
+#
 #        blocks = [b1, b2]
-#        
-#        # initialize  out[B] = gen[B] for every block        
+#
+#        # initialize  out[B] = gen[B] for every block
 #        for block in blocks:
 #            block.out_set = block.get_gen()
 #            print 'block.out_set', block.out_set
-#            
+#
 #        generate_flow_graph(blocks)
 
 #        change = True
@@ -112,6 +112,46 @@ class TestDataflow(unittest.TestCase):
 #
 #        self.assertEqualDag(dag, expect)
 
+    def test_defs(self):
+        s1 = S('command', 'addu', '$3', '$1', '$2')
+        s2 = S('command', 'addu', '$1', '$3', 10)
+        s3 = S('command', 'subu', '$3', '$1', 5)
+        s4 = S('command', 'li', '$4', '0x00000001')
+        block = B([s1, s2, s3, s4])
+        self.assertEqual(defs([block]), {
+            '$3': set([s1.sid, s3.sid]),
+            '$1': set([s2.sid]),
+            '$4': set([s4.sid])
+        })
+
+    #def test_defs(self):
+    #    s1 = S('command', 'add', '$3', '$1', '$2')
+    #    s2 = S('command', 'move', '$1', '$3')
+    #    s3 = S('command', 'move', '$3', '$2')
+    #    s4 = S('command', 'li', '$4', '0x00000001')
+    #    block = B([s1, s2, s3, s4])
+    #    self.assertEqual(defs([block]), {
+    #        '$3': set([s1.sid, s3.sid]),
+    #        '$1': set([s2.sid]),
+    #        '$4': set([s4.sid])
+    #    })
+
+    def test_create_gen_kill_gen(self):
+        s1 = S('command', 'addu', '$3', '$1', '$2')
+        s2 = S('command', 'addu', '$1', '$3', 10)
+        s3 = S('command', 'subu', '$3', '$1', 5)
+        s4 = S('command', 'li', '$4', '0x00000001')
+        block = B([s1, s2, s3, s4])
+        block.create_gen_kill(defs([block]))
+        self.assertEqual(block.gen_set, set([s2.sid, s3.sid, s4.sid]))
+
+    #def test_get_kill_used(self):
+    #    block = B([S('command', 'move', '$1', '$3'),
+    #               S('command', 'add', '$3', '$1', '$2'),
+    #               S('command', 'move', '$1', '$3'),
+    #               S('command', 'move', '$2', '$3')])
+    #    self.assertEqual(block.get_kill(), set())
+
     def assertEqualDag(self, dag1, dag2):
         self.assertEqual(len(dag1.nodes), len(dag2.nodes))