Преглед изворни кода

Rewritten old files to use new statement/block classes.

Taddeus Kroes пре 14 година
родитељ
комит
02c89c74c0
7 измењених фајлова са 200 додато и 176 уклоњено
  1. 0 51
      src/basic_block.py
  2. 24 0
      src/main.py
  3. 0 41
      src/optimize.py
  4. 37 7
      src/optimizer.py
  5. 8 6
      src/parser.py
  6. 106 41
      src/utils.py
  7. 25 30
      src/writer.py

+ 0 - 51
src/basic_block.py

@@ -1,51 +0,0 @@
-# TODO: JALR & JR
-JUMP_COMMANDS = ['j', 'jal', 'beq', 'bne', 'blez', 'bgtz', 'bltz', 'bgez', \
-                 'bc1f', 'bc1t']
-
-def is_jump(statement):
-    '''Check if a statement is a jump command.'''
-    return statement[0] == 'command' and statement[1] in JUMP_COMMANDS
-
-def find_leaders(statements):
-    '''Determine the leaders, which are:
-       1. The first statement.
-       2. Any statement that is the target of a jump.
-       3. Any statement that follows directly follows a jump.'''
-    leaders = [0]
-    jump_target_labels = []
-
-    # Append statements following jumps and save jump target labels
-    for i, statement in enumerate(statements[1:]):
-        if is_jump(statement):
-            leaders.append(i + 2)
-#            print statement[2]['args'][-1]
-            jump_target_labels.append(statement[2]['args'][-1])
-            #print 'found jump:', i, statement
-
-#    print 'target labels:', jump_target_labels
-#    print 'leaders:', leaders
-
-    # Append jump targets
-    for i, statement in enumerate(statements[1:]):
-        if i + 1 not in leaders \
-                and statement[0] == 'label' \
-                and statement[1] in jump_target_labels:
-            leaders.append(i + 1)
-            #print 'target:', i + 1, statements[i + 1]
-    
-    leaders.sort()
-    
-    return leaders
-
-def find_basic_blocks(statements):
-    '''Divide a statement list into basic blocks. Returns a list of basic
-    blocks, which are also statement lists.'''
-    leaders = find_leaders(statements)
-    blocks = []
-
-    for i in range(len(leaders) - 1):
-        blocks.append(statements[leaders[i]:leaders[i + 1]])
-
-    blocks.append(statements[leaders[-1]:])
-    
-    return blocks

+ 24 - 0
src/main.py

@@ -0,0 +1,24 @@
+#!/usr/bin/python
+from parser import parse_file
+from optimize import optimize
+from writer import write_statements
+
+if __name__ == '__main__':
+    from sys import argv, exit
+
+    if len(argv) < 2:
+        print 'Usage: python %s FILE' % argv[0]
+        exit(1)
+
+    # Parse File
+    statements = parse_file(argv[1])
+    statements = optimize(statements, verbose=1)
+
+    # Rewrite to assembly
+    out = write_statements(statements)
+
+    if len(argv) > 2:
+        # Save output assembly
+        f = open(argv[2], 'w+')
+        f.write(out)
+        f.close()

+ 0 - 41
src/optimize.py

@@ -1,41 +0,0 @@
-#!/usr/bin/python
-from parser import parse_file
-from basic_block import find_basic_blocks
-from optimizer import optimize_blocks, optimize_global
-from writer import write_statements
-
-if __name__ == '__main__':
-    from sys import argv, exit
-
-    if len(argv) < 2:
-        print 'Usage: python %s FILE' % argv[0]
-        exit(1)
-
-    # Parse File
-    statements = parse_file(argv[1])
-    st_original = len(statements)
-
-    # Optimize on a global level
-    statements = optimize_global(statements)
-
-    st_aft_global = len(statements)
-
-    # Create basic blocks
-    blocks = find_basic_blocks(statements)
-
-    # Optimize basic blocks
-    statements = optimize_blocks(blocks)
-
-    # Rewrite to assembly
-    out = write_statements(statements)
-
-    print "Optimization:"
-    print "Original statements:", st_original
-    print "After global optimization:", st_aft_global
-    print "After basic blocks optimization:", len(statements)
-
-    if len(argv) > 2:
-        # Save output assembly
-        f = open(argv[2], 'w+')
-        f.write(out)
-        f.close()

+ 37 - 7
src/optimizer.py

@@ -1,13 +1,18 @@
+from utils import Statement as S, Block, find_basic_blocks
+
+
 def equal_mov(s):
-    '''Check for useless move operations.'''
+    """Check for useless move operations."""
     return s.is_command() and s.name == 'move' and s[0] == s[1]
 
+
 def empty_shift(s):
-    '''Check for useless shift operations.'''
+    """Check for useless shift operations."""
     return s.is_shift() and s[0] == s[1] and s[2] == 0
 
+
 def optimize_branch_jump_label(statements):
-    '''Optimize jumps after branches.'''
+    """Optimize jumps after branches."""
     out_statements = []
 
     for i in xrange(len(statements)):
@@ -31,15 +36,17 @@ def optimize_branch_jump_label(statements):
 
     return out_statements
 
+
 def optimize_global(statements):
-    '''Optimize one-line statements in entire code.'''
+    """Optimize one-line statements in entire code."""
     statements = optimize_branch_jump_label(statements)
 
     return filter(lambda s: not equal_mov(s) and not empty_shift(s), statements)
 
+
 def optimize_blocks(blocks):
-    '''Call the optimizer for each basic block. Do this several times until
-    no more optimizations are achieved.'''
+    """Call the optimizer for each basic block. Do this several times until
+    no more optimizations are achieved."""
     changed = True
 
     while changed:
@@ -57,8 +64,9 @@ def optimize_blocks(blocks):
 
     return reduce(lambda a, b: a + b, blocks, [])
 
+
 def optimize_block(statements):
-    '''Optimize a basic block.'''
+    """Optimize a basic block."""
     changed = False
     output_statements = []
 
@@ -68,3 +76,25 @@ def optimize_block(statements):
         output_statements.append(new_statement)
 
     return changed, output_statements
+
+
+def optimize(original, verbose=0):
+    # Optimize on a global level
+    opt_global = optimize_global(original)
+
+    # Optimize basic blocks
+    basic_blocks = find_basic_blocks(opt_global)
+    blocks = optimize_blocks(basic_blocks)
+    opt_blocks = reduce(lambda a, b: a.statements + b.statements, blocks)
+
+    if verbose:
+        o = len(original)
+        g = len(opt_global)
+        b = len(opt_blocks)
+        print 'Original statements:             %d' % o
+        print 'After global optimization:       %d' % g
+        print 'After basic blocks optimization: %d' % b
+        print 'Speedup:                         %d (%d%%)' \
+                % (b - o, int((b - o) / o * 100))
+
+    return opt_blocks

+ 8 - 6
src/parser.py

@@ -1,6 +1,8 @@
 import ply.lex as lex
 import ply.yacc as yacc
 
+from utils import Statement as S, Block
+
 # Global statements administration
 statements = []
 
@@ -72,11 +74,11 @@ def p_line_instruction(p):
 
 def p_line_comment(p):
     'line : COMMENT NEWLINE'
-    statements.append(('comment', p[1], {'inline': False}))
+    statements.append(S('comment', p[1], inline=False))
 
 def p_line_inline_comment(p):
     'line : instruction COMMENT NEWLINE'
-    statements.append(('comment', p[2], {'inline': True}))
+    statements.append(S('comment', p[2], inline=True))
 
 def p_instruction_command(p):
     'instruction : command'
@@ -84,18 +86,18 @@ def p_instruction_command(p):
 
 def p_instruction_directive(p):
     'instruction : DIRECTIVE'
-    statements.append(('directive', p[1], None))
+    statements.append(S('directive', p[1]))
 
 def p_instruction_label(p):
     'instruction : WORD COLON'
-    statements.append(('label', p[1], None))
+    statements.append(S('label', p[1]))
 
 def p_command(p):
     '''command : WORD WORD COMMA WORD COMMA WORD
                | WORD WORD COMMA WORD
                | WORD WORD
                | WORD'''
-    statements.append(('command', p[1], {'args': list(p)[2::2]}))
+    statements.append(S('command', p[1], *list(p)[2::2]))
 
 def p_error(p):
     print 'Syntax error at "%s" on line %d' % (p.value, lexer.lineno)
@@ -113,4 +115,4 @@ def parse_file(filename):
     except IOError:
         print 'File "%s" could not be opened' % filename
 
-    return statements
+    return Block(statements)

+ 106 - 41
src/utils.py

@@ -1,34 +1,46 @@
 import re
 
 class Statement:
-    def __init__(self, stype, name, *args):
+    def __init__(self, stype, name, *args, **kwargs):
         self.stype = stype
         self.name = name
         self.args = args
+        self.options = kwargs
 
     def __getitem__(self, n):
         """Get an argument."""
         return self.args[n]
 
-    def jump_target(self, arg):
-        """Get the use[S] of this statement."""
-        if self.name in ['beq', 'bne', 'blez', 'bgtz', 'bltz', 'bgez', \
-                         'bct', 'bcf']:
-            return self[1]
-        else:
-            raise Exception('"%s" command has no jump target' % self.name)
+    def __eq__(self, other):
+        """Check if two statements are equal by comparing their type, name and
+        arguments."""
+        return self.stype == other.stype and self.name == other.name \
+                and self.args == other.args
+
+    def is_comment(self):
+        return self.stype == 'comment'
+
+    def is_inline_comment(self):
+        return self.is_comment() and self.options['inline']
+
+    def is_directive(self):
+        return self.stype == 'directive'
+
+    def is_label(self):
+        return self.stype == 'label'
 
     def is_command(self):
-        """Check if the statement is a command."""
         return self.stype == 'command'
 
     def is_jump(self):
         """Check if the statement is a jump."""
-        return self.is_command() and re.match('j|jal|jr|jalr', self.name)
+        return self.is_command() \
+               and re.match('^j|jal|beq|bne|blez|bgtz|bltz|bgez|bct|bcf$', \
+                            self.name)
 
     def is_shift(self):
         """Check if the statement is a shift operation."""
-        return self.is_command() and re.match('s(ll|la|rl|ra)', self.name
+        return self.is_command() and re.match('^s(ll|la|rl|ra)$', self.name)
 
     def is_load(self):
         """Check if the statement is a load instruction."""
@@ -37,7 +49,14 @@ class Statement:
     def is_arith(self):
         """Check if the statement is an arithmetic operation."""
         return self.is_command() \
-               and re.match('(add|sub|mult|div|abs|neg)(u|\.d)?', self.name)
+               and re.match('^(add|sub|mult|div|abs|neg)(u|\.d)?$', self.name)
+
+    def jump_target(self, arg):
+        """Get the jump target of this statement."""
+        if re.match('^beq|bne|blez|bgtz|bltz|bgez|bct|bcf$', self.name):
+            return self[1]
+        else:
+            raise Exception('Command "%s" has no jump target' % self.name)
 
     def get_def(self):
         """Get the def[S] of this statement."""
@@ -47,9 +66,6 @@ class Statement:
         if self.is_load() or self.is_arith():
             return [self[0]]
 
-        if self.arith():
-            return [self[0]]
-
     def get_use(self, arg):
         """Get the use[S] of this statement."""
         return []
@@ -62,40 +78,89 @@ class Statement:
         """Check if a variable is used by this statement."""
         return var in self.get_use()
 
-class StatementList:
-    def __init__(self, statement_list):
-        self.statement_list = statement_list
-        self.pointer = 0
 
-    def __getitem__(self, n):
-        return self.statement_list[n]
+class Block:
+    def __init__(self, statements=[]):
+        self.statements = statements
+        self.pointer = 0
 
     def __iter__(self):
-        return iter(self.statement_list)
+        return iter(self.statements)
 
     def __len__(self):
-        return len(self.statement_list)
-
-    def get_range(self, start, end):
-        return self.statement_list[start:end]
+        return len(self.statements)
 
     def replace(self, start, end, replacement):
-        before = self.statement_list[:start]
-        after = self.statement_list[end:]
-        self.statement_list = before + replacement + after
+        """Replace the given range start-end with the given statement list, and
+        move the pointer to the first statement after the replacement."""
+        before = self.statements[:start]
+        after = self.statements[end:]
+        self.statements = before + replacement + after
+        self.pointer = start + len(replacement)
+
+    def read(self, count=1):
+        """Read the statement at the current pointer position and move the
+        pointer one position to the right."""
+        s = statements[self.pointer]
+        self.pointer += 1
+
+        return s
+
+    def peek(self, count=1):
+        """Read the statements until an offset from the current pointer
+        position."""
+        i = self.pointer + offset
+
+        if i < len(self.statements):
+            return self.statements[self.pointer:i]
+
+    def end(self):
+        """Check if the pointer is at the end of the statement list."""
+        return self.pointer == len(self.statements) - 1
+
+
+def find_leaders(statements):
+    """Determine the leaders, which are:
+       1. The first statement.
+       2. Any statement that is the target of a jump.
+       3. Any statement that follows directly follows a jump."""
+    leaders = [0]
+    jump_target_labels = []
+
+    # Append statements following jumps and save jump target labels
+    for i, statement in enumerate(statements[1:]):
+        if statement.is_jump():
+            leaders.append(i + 2)
+            jump_target_labels.append(statement[-1])
+
+    # Append jump targets
+    for i, statement in enumerate(statements[1:]):
+        if i + 1 not in leaders \
+                and statement.is_label() \
+                and statement.name in jump_target_labels:
+            leaders.append(i + 1)
+
+    leaders.sort()
+
+    return leaders
+
+
+def find_basic_blocks(statements):
+    """Divide a statement list into basic blocks. Returns a list of basic
+    blocks, which are also statement lists."""
+    leaders = find_leaders(statements)
+    blocks = []
+
+    for i in range(len(leaders) - 1):
+        blocks.append(Block(statements[leaders[i]:leaders[i + 1]]))
 
-    def move_pointer(self, steps=1):
-        self.pointer += steps
+    blocks.append(Block(statements[leaders[-1]:]))
 
-    def liveness(self, stype):
-        """Check if the statement is of a given type."""
-        return self.stype == stype
+    return blocks
 
 
-while not block.end():
-    if (...):
-        i = block.current()
-        block.replace(i, i + 3, [nieuwe statements])
-        block.move_pointer(3)
-    else:
-        block.move_pointer(1)
+#while not block.end():
+#    i, s = block.read()
+#
+#    if block.peek():
+#        block.replace(i, i + 3, [nieuwe statements])

+ 25 - 30
src/writer.py

@@ -1,52 +1,47 @@
 from math import ceil
 
 def write_statements(statements):
-    '''Write a list of statements to valid assembly code.'''
-    s = ''
+    """Write a list of statements to valid assembly code."""
+    out = ''
     indent_level = 0
     prevline = ''
 
-    for i, statement in enumerate(statements):
-        statement_type, name, args = statement
+    for i, s in enumerate(statements):
         newline = '\n' if i else ''
 
-        if statement_type == 'label':
-            line = name + ':'
+        if s.is_label():
+            line = s.name + ':'
             indent_level = 1
-        elif statement_type == 'comment':
-            line = '#' + name
-
-            if args['inline']:
-                l = len(prevline.expandtabs(4))
-                tabs = int(ceil((24 - l) / 4.)) + 1
-                newline = '\t' * tabs
-            else:
-                line = '\t' * indent_level + line
-        elif statement_type == 'directive':
-            line = '\t' + name
-        elif statement_type == 'command':
-            line = '\t' + name
-
-            if len(args['args']):
-                l = len(name)
-
-                if l < 8:
+        elif s.is_inline_comment():
+            line = '#' + s.name
+            l = len(prevline.expandtabs(4))
+            tabs = int(ceil((24 - l) / 4.)) + 1
+            newline = '\t' * tabs
+        elif s.is_comment():
+            line = '\t' * indent_level + line
+        elif s.is_directive():
+            line = '\t' + s.name
+        elif s.is_command():
+            line = '\t' + s.name
+
+            if len(s):
+                if len(s.name) < 8:
                     line += '\t'
                 else:
                     line += ' '
 
-                line += ','.join(args['args'])
+                line += ','.join(s.args)
         else:
-            raise Exception('Unsupported statement type "%s"' % statement_type)
+            raise Exception('Unsupported statement type "%s"' % s.stype)
 
-        s += newline + line
+        out += newline + line
         prevline = line
 
-    return s
+    return out
 
 def write_to_file(filename, statements):
-    '''Convert a list of statements to valid assembly code and write it to a
-    file.'''
+    """Convert a list of statements to valid assembly code and write it to a
+    file."""
     s = write_statements(statements)
     f = open(filename, 'w+')
     f.write(s)