Bläddra i källkod

Optimize strategy implementation:

- Instead of computing a complete score for each state, compute+compare
  the first component for each one, then the second, etc.
- Set colskip (and nrows) once and update occasionally once instead of
  constantly recomputing it.
- Add some more actions to do on the destination column.
- Fix move loop detection.
- Do not wait on bomb explosion (since they no longer cause loops).
Taddeus Kroes 5 år sedan
förälder
incheckning
213ebb7655
2 ändrade filer med 172 tillägg och 130 borttagningar
  1. 18 17
      bot.py
  2. 154 113
      strategy.py

+ 18 - 17
bot.py

@@ -4,9 +4,10 @@ import time
 from collections import deque
 from itertools import count
 from Xlib import error
-from strategy import State
+from detection import NOBLOCK
 from interaction import get_exapunks_window, focus_window, screenshot_board, \
                         press_keys, listen_keys, KEY_DELAY
+from strategy import State
 
 
 MAX_SPEED_ROWS = 3
@@ -33,7 +34,7 @@ if __name__ == '__main__':
 
         listen_keys({'s': lambda: save_screenshot(win)})
 
-        solutions = deque([], maxlen=3)
+        buf = deque([], maxlen=3)
 
         def vprint(*args, **kwargs):
             if verbose:
@@ -56,10 +57,10 @@ if __name__ == '__main__':
                 vprint()
 
                 start = time.time()
-                solution = state.solve()
+                newstate = state.solve()
                 end = time.time()
                 vprint('thought for', round((end - start) * 1000, 1), 'ms')
-            except (TypeError, AssertionError):
+            except (TypeError, AssertionError) as e:
                 vprint('\rerror during parsing, wait for a bit...', end='')
                 time.sleep(0.050)
                 continue
@@ -68,33 +69,33 @@ if __name__ == '__main__':
                 time.sleep(0.500)
                 continue
 
-            if len(solutions) == 3 and solution.loops(solutions.popleft()):
+            if state.held == NOBLOCK and any(map(newstate.loops, buf)):
                 vprint('\rloop detected, wait for a bit...', end='')
                 time.sleep(0.03)
-            elif solution.moves:
-                vprint('moves:', solution.keys())
-                vprint('     score:', solution.score)
-                if solutions:
-                    vprint('prev score:', solutions[-1].score)
+            elif newstate.moves:
+                vprint('moves:', newstate.keys())
+                vprint('     score:', newstate.score)
+                if buf:
+                    vprint('prev score:', buf[-1].score)
                 vprint()
 
                 vprint('target after moves:')
-                vprint_state(solution.newstate)
+                vprint_state(newstate)
 
-                press_keys(win, solution.keys())
+                press_keys(win, newstate.keys())
 
-                #keys_delay = len(solution.moves) * 2 * KEY_DELAY
-                #moves_delay = max(0, solution.delay() - keys_delay)
+                #keys_delay = len(newstate.moves) * 2 * KEY_DELAY
+                #moves_delay = max(0, newstate.delay() - keys_delay)
                 #vprint('wait for',  moves_delay, 'ms')
                 #time.sleep(moves_delay / 1000)
-                time.sleep(0.070)
-            elif state.nrows() - 2 <= MAX_SPEED_ROWS:
+                time.sleep(0.075)
+            elif state.nrows - 2 <= MAX_SPEED_ROWS:
                 vprint('no moves, speed up')
                 press_keys(win, 'l')
                 time.sleep(0.030)
             else:
                 vprint('no moves')
 
-            solutions.append(solution)
+            buf.append(newstate)
     except KeyboardInterrupt:
         print('interrupted, quitting')

+ 154 - 113
strategy.py

@@ -1,5 +1,6 @@
 import io
 import time
+from collections import deque
 from contextlib import redirect_stdout
 from itertools import combinations, islice
 from detection import COLUMNS, NOBLOCK, detect_blocks, detect_exa, \
@@ -22,31 +23,43 @@ PUT = ((DROP,), (DROP, SWAP), (DROP, SWAP, GRAB, SWAP, DROP))
 MIN_BASIC_GROUP_SIZE = 4
 MIN_BOMB_GROUP_SIZE = 2
 POINTS_DEPTH = 3
-FRAG_DEPTH = 5
+FRAG_DEPTH = 4
 DEFRAG_PRIO = 4
 COLSIZE_PRIO = 5
 COLSIZE_PANIC = 8
 COLSIZE_MAX = 9
-BOMB_POINTS = 1
-MIN_ROWS = 2
+BOMB_POINTS = 5
 
 
 class State:
-    def __init__(self, blocks, exa, held):
+    def __init__(self, blocks, exa, held, colskip=None):
         self.blocks = blocks
         self.exa = exa
         self.held = held
+        self.moves = ()
+        self.score = ()
+        self.nrows = len(self.blocks) // COLUMNS
+
+        if colskip is None:
+            colskip = []
+            for col in range(COLUMNS):
+                for row in range(self.nrows):
+                    if self.blocks[row * COLUMNS + col] != NOBLOCK:
+                        colskip.append(row)
+                        break
+                else:
+                    colskip.append(self.nrows)
+
+        self.colskip = colskip
 
     def grabbing_or_dropping(self):
-        skip = self.colskip(self.exa)
+        skip = self.colskip[self.exa]
         i = (skip + 1) * COLUMNS + self.exa
         return i < len(self.blocks) and self.blocks[i] == NOBLOCK
 
     def iter_columns(self):
-        nrows = self.nrows()
-
         def gen_col(col):
-            for row in range(nrows):
+            for row in range(self.nrows):
                 i = row * COLUMNS + col
                 if self.blocks[i] != NOBLOCK:
                     yield i
@@ -62,13 +75,13 @@ class State:
         return cls(blocks, exa, held)
 
     def copy(self):
-        return State(list(self.blocks), self.exa, self.held)
+        return State(list(self.blocks), self.exa, self.held, list(self.colskip))
 
     def causes_panic(self):
         return self.max_colsize() >= COLSIZE_PANIC
 
     def max_colsize(self):
-        return self.nrows() - self.empty_rows()
+        return self.nrows - self.empty_rows()
 
     def empty_rows(self):
         for i, block in enumerate(self.blocks):
@@ -78,17 +91,16 @@ class State:
 
     def holes(self):
         start_row = self.empty_rows()
-        total_rows = self.nrows()
         score = 0
         for col in range(COLUMNS):
-            for row in range(start_row, total_rows):
+            for row in range(start_row, self.nrows):
                 if self.blocks[row * COLUMNS + col] != NOBLOCK:
                     break
                 score += row - start_row + 1
         return score
 
     def score(self, points, moves, prev):
-        prev_colsize = prev.nrows() - 2
+        prev_colsize = prev.nrows - 2
 
         #delay = moves_delay(moves)
         delay = len(moves)
@@ -100,7 +112,7 @@ class State:
             return -points, delay
 
         holes = self.holes()
-        frag = self.fragmentation()
+        frag = 0 if points else self.fragmentation()
 
         # When rows start stacking up, start defragmenting colors to make
         # opportunities for scoring points.
@@ -116,20 +128,6 @@ class State:
         # Column heights are getting out of hand, just move shit DOWN.
         return holes, delay, -points, frag
 
-    def solutions(self):
-        for moves in self.gen_moves():
-            try:
-                yield Solution(self, moves)
-            except AssertionError:
-                pass
-
-    def colskip(self, col):
-        nrows = self.nrows()
-        for row in range(nrows):
-            if self.blocks[row * COLUMNS + col] != NOBLOCK:
-                return row
-        return nrows
-
     def find_unmovable_blocks(self):
         unmoveable = set()
         bombed = set()
@@ -149,13 +147,16 @@ class State:
 
         return unmoveable
 
-    def simulate(self, moves):
-        s = self.copy()
-        #points = 0
+    def move(self, moves):
+        s = self.copy() if moves else self
+        s.moves = moves
 
         # avoid swapping/grabbing currently exploding items
         #unmoveable = s.find_unmovable_blocks()
 
+        s.placed = set()
+        s.grabbed = {}
+
         for move in moves:
             if move == LEFT:
                 assert s.exa > 0
@@ -165,35 +166,44 @@ class State:
                 s.exa += 1
             elif move == GRAB:
                 assert s.held == NOBLOCK
-                row = s.colskip(s.exa)
-                assert row < s.nrows()
+                row = s.colskip[s.exa]
+                assert row < s.nrows
                 i = row * COLUMNS + s.exa
                 #assert i not in unmoveable
                 s.held = s.blocks[i]
                 s.blocks[i] = NOBLOCK
+                s.grabbed[i] = s.held
+                s.colskip[s.exa] += 1
             elif move == DROP:
                 assert s.held != NOBLOCK
-                row = s.colskip(s.exa)
+                row = s.colskip[s.exa]
                 assert row > 0
-                i = row * COLUMNS + s.exa
-                s.blocks[i - COLUMNS] = s.held
+                i = (row - 1) * COLUMNS + s.exa
+                s.blocks[i] = s.held
                 s.held = NOBLOCK
-                #points += s.remove_blocks()
+                s.placed.add(i)
+                s.colskip[s.exa] -= 1
             elif move == SWAP:
-                row = s.colskip(s.exa)
-                assert row < s.nrows() - 2
+                row = s.colskip[s.exa]
                 i = row * COLUMNS + s.exa
                 j = i + COLUMNS
+                assert j < len(s.blocks)
                 #assert i not in unmoveable
                 #assert j not in unmoveable
-                s.blocks[i], s.blocks[j] = s.blocks[j], s.blocks[i]
-                #points += s.remove_blocks()
+                bi = s.blocks[i]
+                bj = s.blocks[j]
+                if bi != bj:
+                    s.blocks[i] = bj
+                    s.blocks[j] = bi
+                    s.grabbed[i] = bi
+                    s.grabbed[j] = bj
+                    s.placed.add(i)
+                    s.placed.add(j)
 
         if moves and self.max_colsize() < COLSIZE_MAX:
             assert s.max_colsize() <= COLSIZE_MAX
 
-        points = s.remove_blocks()
-        return points, s
+        return s
 
     def find_groups(self, depth=POINTS_DEPTH, minsize=2):
         def follow_group(i, block, group):
@@ -221,7 +231,7 @@ class State:
             yield i + 1
         if row > 0 and self.blocks[i - COLUMNS] != NOBLOCK:
             yield i - COLUMNS
-        if row < self.nrows() - 1 and self.blocks[i + COLUMNS] != NOBLOCK:
+        if row < self.nrows - 1 and self.blocks[i + COLUMNS] != NOBLOCK:
             yield i + COLUMNS
 
     def fragmentation(self, depth=FRAG_DEPTH):
@@ -250,47 +260,46 @@ class State:
                 groups[i] = groupid
                 groupsizes[i] = len(group)
 
-        return sum(dist(i, j)  # * (1 + 2 * is_bomb(block))
+        return sum(dist(i, j)
                    for block, color in colors.items()
                    for i, j in combinations(color, 2))
 
-    def remove_blocks(self):
-        removed = 0
-
-        for block, group in self.find_groups():
-            if is_basic(block) and len(group) >= MIN_BASIC_GROUP_SIZE:
-                removed += len(group)
-            elif is_bomb(block) and len(group) >= MIN_BOMB_GROUP_SIZE:
-                removed += BOMB_POINTS
-
-        return removed
-
-    #def remove_blocks(self):
-    #    remove = []
-
-    #    for block, group in self.find_groups():
-    #        if is_basic(block) and len(group) >= MIN_BASIC_GROUP_SIZE:
-    #            remove.extend(group)
-    #        elif is_bomb(block) and len(group) >= MIN_BOMB_GROUP_SIZE:
-    #            remove.extend(group)
-    #            remove.extend(i for i, other in enumerate(self.blocks)
-    #                          if other == bomb_to_basic(block))
+    def points(self):
+        def group_size(start):
+            work = [start]
+            visited.add(start)
+            size = 0
+            block = self.blocks[start]
+
+            while work:
+                i = work.pop()
+
+                # avoid giving points to moving a block within the same group
+                if self.grabbed.get(i, None) == block:
+                    return 0
+
+                if self.blocks[i] == block:
+                    size += 1
+                    for nb in self.neighbors(i):
+                        if nb not in visited:
+                            visited.add(nb)
+                            work.append(nb)
+            return size
+
+        points = 0
+        visited = set()
 
-    #    remove.sort()
-    #    removed = 0
-    #    prev = None
-    #    for i in remove:
-    #        if i != prev:
-    #            while self.blocks[i] != NOBLOCK:
-    #                self.blocks[i] = self.blocks[i - COLUMNS]
-    #                i -= COLUMNS
-    #            removed += 1
-    #        prev = i
+        for i in self.placed:
+            if i not in visited:
+                block = self.blocks[i]
+                size = group_size(i)
 
-    #    if removed:
-    #        removed += self.remove_blocks()
+                if is_basic(block) and size >= MIN_BASIC_GROUP_SIZE:
+                    points += size
+                elif is_bomb(block) and size >= MIN_BOMB_GROUP_SIZE:
+                    points += BOMB_POINTS
 
-    #    return removed
+        return -points
 
     def has_explosion(self):
         return any(is_bomb(block) and
@@ -300,39 +309,78 @@ class State:
     def gen_moves(self):
         yield ()
 
-        def make_move(diff):
+        def shift_exa(diff):
             direction = RIGHT if diff > 0 else LEFT
             return abs(diff) * (direction,)
 
         ignore_exa_column = self.grabbing_or_dropping()
 
         for src in range(COLUMNS):
-            mov1 = make_move(src - self.exa)
+            mov1 = shift_exa(src - self.exa)
             if mov1 or not ignore_exa_column:
                 yield mov1 + (SWAP,)
                 yield mov1 + (GRAB, SWAP, DROP)
                 yield mov1 + (SWAP, GRAB, SWAP, DROP)
+                yield mov1 + (GRAB, SWAP, DROP, SWAP)
+                yield mov1 + (SWAP, GRAB, SWAP, DROP, SWAP)
 
                 for dst in range(COLUMNS):
                     if dst != src:
-                        mov2 = make_move(dst - src)
+                        mov2 = shift_exa(dst - src)
                         for get in GET:
                             for put in PUT:
                                 yield mov1 + get + mov2 + put
 
+    def gen_valid_moves(self):
+        for moves in self.gen_moves():
+            try:
+                yield self.move(moves)
+            except AssertionError:
+                pass
+
     def solve(self):
         assert self.exa is not None
 
         if self.held != NOBLOCK:
-            return Solution(self, (DROP,))
+            return self.move((DROP,))
+
+        valid = deque(self.gen_valid_moves())
+
+        if len(valid) == 0:
+            return self.move(())
+
+        best_score = ()
+
+        for key in self.score_keys():
+            if len(valid) == 1:
+                break
+
+            for state in valid:
+                state.score = key(state)
 
-        if self.nrows() < MIN_ROWS:
-            return Solution(self, ())
+            best = min(state.score for state in valid)
+            best_score += (best,)
 
-        if self.has_explosion():
-            return Solution(self, ())
+            for i in range(len(valid)):
+                state = valid.popleft()
+                if state.score == best:
+                    valid.append(state)
 
-        return min(self.solutions())
+        best = valid.popleft()
+        best.score = best_score
+        return best
+
+    def score_keys(self):
+        cls = self.__class__
+        colsize = self.nrows - 2
+
+        if colsize >= COLSIZE_PANIC:
+            return cls.holes, cls.nmoves, cls.points, cls.fragmentation
+
+        if colsize >= COLSIZE_PRIO:
+            return cls.causes_panic, cls.points, cls.holes, cls.fragmentation, cls.nmoves
+
+        return cls.points, cls.fragmentation, cls.holes, cls.nmoves
 
     def print(self):
         print_board(self.blocks, self.exa, self.held)
@@ -343,28 +391,11 @@ class State:
             self.print()
         return stream.getvalue()
 
-    def nrows(self):
-        return len(self.blocks) // COLUMNS
-
     def has_same_exa(self, state):
         return self.exa == state.exa and self.held == state.held
 
-
-class Solution:
-    def __init__(self, state, moves):
-        self.state = state
-        self.moves = moves
-        points, self.newstate = state.simulate(moves)
-        self.score = self.newstate.score(points, moves, state)
-
-    def __lt__(self, other):
-        return self.score < other.score
-
-    def loops(self, prev_prev):
-        return self.moves and \
-               self.state.exa == prev_prev.state.exa and \
-               self.moves == prev_prev.moves and \
-               self.score == prev_prev.score
+    def nmoves(self):
+        return len(self.moves)
 
     def delay(self):
         return moves_delay(self.moves)
@@ -372,6 +403,15 @@ class Solution:
     def keys(self):
         return moves_to_keys(self.moves)
 
+    def __lt__(self, other):
+        return self.score < other.score
+
+    def loops(self, prev):
+        return self.moves and \
+               self.exa == prev.exa and \
+               self.moves == prev.moves and \
+               self.score == prev.score
+
 
 def move_to_key(move):
     return 'jjkadl'[move]
@@ -397,15 +437,16 @@ if __name__ == '__main__':
     print()
 
     start = time.time()
-    solution = state.solve()
+    newstate = state.solve()
     end = time.time()
-    print('best moves:', solution.keys())
+    print('best move:', newstate.keys())
+    print('score:', newstate.score)
     print('elapsed:', round((end - start) * 1000, 1), 'ms')
     print()
 
-    print('target after moves:')
-    solution.newstate.print()
+    print('target after move:')
+    newstate.print()
     print()
 
-    for solution in sorted(state.solutions()):
-        print('move %18s:' % solution.keys(), solution.score)
+    #for solution in sorted(state.solutions()):
+    #    print('move %18s:' % solution.keys(), solution.score)