Просмотр исходного кода

Merge branch 'master' of github.com:taddeus/peephole

Jayke Meijer 14 лет назад
Родитель
Сommit
a38a01fb66
3 измененных файлов с 49 добавлено и 13 удалено
  1. 1 1
      src/optimize/__init__.py
  2. 42 12
      src/optimize/advanced.py
  3. 6 0
      tests/test_statement.py

+ 1 - 1
src/optimize/__init__.py

@@ -32,7 +32,7 @@ def optimize_block(block):
             | eliminate_common_subexpressions(block) \
             | fold_constants(block) \
             | copy_propagation(block)\
-            | algebraic_transformations(block) \
+            #| algebraic_transformations(block) \
             | eliminate_dead_code(block):
         pass
 

+ 42 - 12
src/optimize/advanced.py

@@ -149,22 +149,44 @@ def fold_constants(block):
             register[s[0]] = constants[s[1]]
         elif s.name == 'mflo':
             # Move of `Lo' register to another register
-            register[s[0]] = register['Lo']
+            register[s[0]] = register['$lo']
         elif s.name == 'mfhi':
             # Move of `Hi' register to another register
-            register[s[0]] = register['Hi']
+            register[s[0]] = register['$hi']
         elif s.name in ['mult', 'div'] \
-                and s[0] in register and s[1] in register:
+                and s[0]in register and s[1] in register:
             # Multiplication/division with constants
             rs, rt = s
+            a, b = register[rs], register[rt]
 
             if s.name == 'mult':
-                binary = bin(register[rs] * register[rt])[2:]
-                binary = '0' * (64 - len(binary)) + binary
-                register['Hi'] = int(binary[:32], base=2)
-                register['Lo'] = int(binary[32:], base=2)
+                if not a or not b:
+                    # Multiplication by 0
+                    hi = lo = to_hex(0)
+                elif a == 1:
+                    # Multiplication by 1
+                    hi = to_hex(0)
+                    lo = to_hex(b)
+                elif b == 1:
+                    # Multiplication by 1
+                    hi = to_hex(0)
+                    lo = to_hex(a)
+                else:
+                    # Calculate result and fill Hi/Lo registers
+                    binary = bin(a * b)[2:]
+                    binary = '0' * (64 - len(binary)) + binary
+                    hi = int(binary[:32], base=2)
+                    lo = int(binary[32:], base=2)
+
+                # Replace the multiplication with two immidiate loads to the
+                # Hi/Lo registers
+                block.replace(1, [S('command', 'li', '$hi', hi),
+                                S('command', 'li', '$lo', li)])
             elif s.name == 'div':
-                register['Lo'], register['Hi'] = divmod(rs, rt)
+                lo, hi = divmod(rs, rt)
+
+            register['$lo'], register['$hi'] = lo, hi
+            changed = True
         elif s.name in ['addu', 'subu']:
             # Addition/subtraction with constants
             rd, rs, rt = s
@@ -187,7 +209,9 @@ def fold_constants(block):
                 block.replace(1, [S('command', 'li', rd, result)])
                 register[rd] = result
                 changed = True
-            elif rt_known:
+                continue
+
+            if rt_known:
                 # a = 10        ->  b = c + 10
                 # b = c + a
                 s[2] = register[rt]
@@ -198,9 +222,15 @@ def fold_constants(block):
                 s[1] = rt
                 s[2] = register[rs]
                 changed = True
-        elif len(s) and s[0] in register:
-            # Known register is overwritten, remove its value
-            del register[s[0]]
+
+            if s[2] == 0:
+                # Addition/subtraction with 0
+                block.replace(1, [S('command', 'move', rd, s[1])])
+        else:
+            for reg in s.get_def():
+                if reg in register
+                    # Known register is overwritten, remove its value
+                    del register[reg]
 
     return changed
 

+ 6 - 0
tests/test_statement.py

@@ -93,3 +93,9 @@ class TestStatement(unittest.TestCase):
         self.assertTrue(S('command', 'addu', '$1', '$2', '$3').is_arith())
         self.assertFalse(S('command', 'foo').is_arith())
         self.assertFalse(S('label', 'addu').is_arith())
+        
+#    def test_get_def(self):
+#        self.assertEqual(S('command', 'addu', '$1', '$2', '$3'), '$1')
+#        
+        
+