Commit 58c6acb7 authored by Jayke Meijer's avatar Jayke Meijer

Fixed merge conflict.

parents fb1abf4b 668a93c9
#from copy import copy from copy import copy
from statement import Block from statement import Block
...@@ -26,38 +26,83 @@ class BasicBlock(Block): ...@@ -26,38 +26,83 @@ class BasicBlock(Block):
self.dominates.append(block) self.dominates.append(block)
block.dominated_by.append(self) block.dominated_by.append(self)
def create_gen_kill(self, defs):
used = set()
self_defs = {}
# def get_gen(self): # Get the last of each definition series and put in in the `def' set
# for s in self.statements: self.gen_set = set()
# if s.is_arith():
# self.gen_set.add(s[0]) for s in reversed(self):
# print 'added: ', s[0] for reg in s.get_def():
# if reg not in self_defs:
# return self.gen_set print 'Found def:', s
# self_defs[reg] = s.sid
# def get_kill(self): self.gen_set.add(s.sid)
## if self.edges_from != []:
# # Generate kill set
# for backw in self.edges_from: self.kill_set = set()
# self.kill_set = self.gen_set & backw.kill_set
# for reg, statement_ids in defs.iteritems():
# self.kill_set = self.kill_set - self.get_gen() if reg in self_defs:
# print 'get_kill_set', self.kill_set add = statement_ids - set([self_defs[reg]])
# return self.kill_set else:
add = statement_ids
self.kill_set |= add
def defs(blocks):
# Collect definitions of all registers
defs = {}
for b in blocks:
for s in b:
for reg in s.get_def():
if reg not in defs:
defs[reg] = set([s.sid])
else:
defs[reg].add(s.sid)
return defs
def reaching_definitions(blocks):
"""Generate the `in' and `out' sets of the given blocks using the iterative
algorithm from the slides."""
defs = defs(blocks)
for b in blocks:
b.create_gen_kill(defs)
b.out_set = b.gen_set
change = True
while change:
change = False
for b in blocks:
b.in_set = set()
for pred in b.edges_from:
b.in_set |= pred.out_set
oldout = copy(p.out_set)
p.out_set = b.gen_set | (b.in_set - b.kill_set)
if b.out_set != oldout:
change = True
# def get_in(self):
# for backw in self.edges_from:
# self.in_set = self.in_set | backw.out_set
# print 'in_set', self.in_set
# return self.in_set
# def get_out(self): def pred(n, known=[]):
# print 'gen_set', self.gen_set """Recursively find all predecessors of a node."""
# print 'get_in', self.get_in() direct = filter(lambda b: b not in known, n.edges_from)
# print 'get_kill', self.get_kill() p = copy(direct)
# self.out_set = self.gen_set | (self.get_in() - self.get_kill())
for ancestor in direct:
p += pred(ancestor, direct)
return p
def find_leaders(statements): def find_leaders(statements):
......
...@@ -2,12 +2,18 @@ import re ...@@ -2,12 +2,18 @@ import re
class Statement: class Statement:
sid = 1
def __init__(self, stype, name, *args, **kwargs): def __init__(self, stype, name, *args, **kwargs):
self.stype = stype self.stype = stype
self.name = name self.name = name
self.args = list(args) self.args = list(args)
self.options = kwargs self.options = kwargs
# Assign a unique ID to each satement
self.sid = Statement.sid
Statement.sid += 1
def __getitem__(self, n): def __getitem__(self, n):
"""Get an argument.""" """Get an argument."""
return self.args[n] return self.args[n]
...@@ -26,8 +32,8 @@ class Statement: ...@@ -26,8 +32,8 @@ class Statement:
return len(self.args) return len(self.args)
def __str__(self): # pragma: nocover def __str__(self): # pragma: nocover
return '<Statement type=%s name=%s args=%s>' \ return '<Statement sid=%d type=%s name=%s args=%s>' \
% (self.stype, self.name, self.args) % (self.sid, self.stype, self.name, self.args)
def __repr__(self): # pragma: nocover def __repr__(self): # pragma: nocover
return str(self) return str(self)
...@@ -64,6 +70,11 @@ class Statement: ...@@ -64,6 +70,11 @@ class Statement:
"""Check if the statement is a shift operation.""" """Check if the statement is a shift operation."""
return self.is_command() and re.match('^s(ll|rl|ra)$', self.name) return self.is_command() and re.match('^s(ll|rl|ra)$', self.name)
def is_load(self):
"""Check if the statement is a load instruction."""
return self.is_command() and self.name in ['lw', 'li', 'dlw', 'l.s', \
'l.d']
def is_arith(self): def is_arith(self):
"""Check if the statement is an arithmetic operation.""" """Check if the statement is an arithmetic operation."""
return self.is_command() \ return self.is_command() \
...@@ -81,12 +92,12 @@ class Statement: ...@@ -81,12 +92,12 @@ class Statement:
"""Check if the statement is an binary operation.""" """Check if the statement is an binary operation."""
return self.is_command() and len(self) == 3 and not self.is_jump() return self.is_command() and len(self) == 3 and not self.is_jump()
def is_load(self): def is_load_non_immediate(self):
"""Check if the statement is a load statement.""" """Check if the statement is a load statement."""
return self.is_command() \ return self.is_command() \
and re.match('^l(w|a|b|bu|\.d|\.s)|dlw$', \ and re.match('^l(w|a|b|bu|\.d|\.s)|dlw$', \
self.name) self.name)
def is_logical: def is_logical(self):
"""Check if the statement is a logical operator.""" """Check if the statement is a logical operator."""
return self.is_command() and re.match('^(xor|or|and)i?$', self.name) return self.is_command() and re.match('^(xor|or|and)i?$', self.name)
...@@ -108,11 +119,11 @@ class Statement: ...@@ -108,11 +119,11 @@ class Statement:
"""Check if the statement is a shift if less then.""" """Check if the statement is a shift if less then."""
return self.is_command() and self.name in ['slt', 'sltu'] return self.is_command() and self.name in ['slt', 'sltu']
def self.is_convert(self): def is_convert(self):
"""Check if the statement is a convert operator.""" """Check if the statement is a convert operator."""
return self.is_command() and re.match('^cvt\.[a-z\.]*$', self.name) return self.is_command() and re.match('^cvt\.[a-z\.]*$', self.name)
def self.is_truncate(self): def is_truncate(self):
"""Check if the statement is a convert operator.""" """Check if the statement is a convert operator."""
return self.is_command() and re.match('^trunc\.[a-z\.]*$', self.name) return self.is_command() and re.match('^trunc\.[a-z\.]*$', self.name)
...@@ -125,12 +136,13 @@ class Statement: ...@@ -125,12 +136,13 @@ class Statement:
def get_def(self): def get_def(self):
"""Get the variable that this statement defines, if any.""" """Get the variable that this statement defines, if any."""
inst = ['move', 'addu', 'subu', 'li', 'mtc1', 'dmfc1'] instr = ['move', 'addu', 'subu', 'li', 'mtc1', 'dmfc1']
if self.is_load() or self.is_arith() or self.is_logical() \ if self.is_load_non_immediate() or self.is_arith() \
or self.is_double_arithmetic() or self.is_move_from_spec() \ or self.is_logical() or self.is_double_arithmetic() \
or self.is_double_unary() or self.is_set_if_less() \ or self.is_move_from_spec() or self.is_double_unary() \
or self.is_convert() or self.is_truncate() \ or self.is_set_if_less() or self.is_convert() \
or self.is_truncate() or self.is_load() \
or (self.is_command and self.name in instr): or (self.is_command and self.name in instr):
return self[0] return self[0]
......
...@@ -2,7 +2,7 @@ import unittest ...@@ -2,7 +2,7 @@ import unittest
from src.statement import Statement as S from src.statement import Statement as S
from src.dataflow import BasicBlock as B, find_leaders, find_basic_blocks, \ from src.dataflow import BasicBlock as B, find_leaders, find_basic_blocks, \
generate_flow_graph, Dag, DagNode, DagLeaf generate_flow_graph, Dag, DagNode, DagLeaf, defs, reaching_definitions
class TestDataflow(unittest.TestCase): class TestDataflow(unittest.TestCase):
...@@ -112,6 +112,46 @@ class TestDataflow(unittest.TestCase): ...@@ -112,6 +112,46 @@ class TestDataflow(unittest.TestCase):
# #
# self.assertEqualDag(dag, expect) # self.assertEqualDag(dag, expect)
def test_defs(self):
s1 = S('command', 'addu', '$3', '$1', '$2')
s2 = S('command', 'addu', '$1', '$3', 10)
s3 = S('command', 'subu', '$3', '$1', 5)
s4 = S('command', 'li', '$4', '0x00000001')
block = B([s1, s2, s3, s4])
self.assertEqual(defs([block]), {
'$3': set([s1.sid, s3.sid]),
'$1': set([s2.sid]),
'$4': set([s4.sid])
})
#def test_defs(self):
# s1 = S('command', 'add', '$3', '$1', '$2')
# s2 = S('command', 'move', '$1', '$3')
# s3 = S('command', 'move', '$3', '$2')
# s4 = S('command', 'li', '$4', '0x00000001')
# block = B([s1, s2, s3, s4])
# self.assertEqual(defs([block]), {
# '$3': set([s1.sid, s3.sid]),
# '$1': set([s2.sid]),
# '$4': set([s4.sid])
# })
def test_create_gen_kill_gen(self):
s1 = S('command', 'addu', '$3', '$1', '$2')
s2 = S('command', 'addu', '$1', '$3', 10)
s3 = S('command', 'subu', '$3', '$1', 5)
s4 = S('command', 'li', '$4', '0x00000001')
block = B([s1, s2, s3, s4])
block.create_gen_kill(defs([block]))
self.assertEqual(block.gen_set, set([s2.sid, s3.sid, s4.sid]))
#def test_get_kill_used(self):
# block = B([S('command', 'move', '$1', '$3'),
# S('command', 'add', '$3', '$1', '$2'),
# S('command', 'move', '$1', '$3'),
# S('command', 'move', '$2', '$3')])
# self.assertEqual(block.get_kill(), set())
def assertEqualDag(self, dag1, dag2): def assertEqualDag(self, dag1, dag2):
self.assertEqual(len(dag1.nodes), len(dag2.nodes)) self.assertEqual(len(dag1.nodes), len(dag2.nodes))
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment