Przeglądaj źródła

Compute groups once and update them on each move

Taddeus Kroes 5 lat temu
rodzic
commit
0db87a732d
1 zmienionych plików z 116 dodań i 103 usunięć
  1. 116 103
      strategy.py

+ 116 - 103
strategy.py

@@ -3,7 +3,7 @@ import time
 from collections import deque
 from collections import deque
 from contextlib import redirect_stdout
 from contextlib import redirect_stdout
 from copy import copy
 from copy import copy
-from itertools import combinations, islice
+from itertools import combinations
 from detection import COLUMNS, NOBLOCK, detect_blocks, detect_exa, \
 from detection import COLUMNS, NOBLOCK, detect_blocks, detect_exa, \
                       detect_held, print_board, is_basic, is_bomb
                       detect_held, print_board, is_basic, is_bomb
 
 
@@ -20,25 +20,22 @@ MOVE_DELAYS = (
 )
 )
 MIN_BASIC_GROUP_SIZE = 4
 MIN_BASIC_GROUP_SIZE = 4
 MIN_BOMB_GROUP_SIZE = 2
 MIN_BOMB_GROUP_SIZE = 2
-POINTS_DEPTH = 3
-FRAG_DEPTH = 4
-DEFRAG_PRIO = 4
 COLSIZE_PRIO = 5
 COLSIZE_PRIO = 5
 COLSIZE_PANIC = 8
 COLSIZE_PANIC = 8
-COLSIZE_MAX = 9
 BOMB_POINTS = 5
 BOMB_POINTS = 5
 
 
 
 
 class State:
 class State:
-    def __init__(self, blocks, exa, held, colskip, busy, moves, placed, grabbed):
+    def __init__(self, blocks, exa, held, colskip, busy, moves, groups, groupsizes, maxgroup):
         self.blocks = blocks
         self.blocks = blocks
         self.exa = exa
         self.exa = exa
         self.held = held
         self.held = held
         self.colskip = colskip
         self.colskip = colskip
         self.busy = busy
         self.busy = busy
         self.moves = moves
         self.moves = moves
-        self.placed = placed
-        self.grabbed = grabbed
+        self.groups = groups
+        self.groupsizes = groupsizes
+        self.maxgroup = maxgroup
         self.nrows = len(blocks) // COLUMNS
         self.nrows = len(blocks) // COLUMNS
 
 
     @classmethod
     @classmethod
@@ -48,7 +45,8 @@ class State:
         held = detect_held(board, exa)
         held = detect_held(board, exa)
         colskip = get_colskip(blocks)
         colskip = get_colskip(blocks)
         busy = get_busy(blocks, colskip)
         busy = get_busy(blocks, colskip)
-        return cls(blocks, exa, held, colskip, busy, (), set(), {})
+        groups, groupsizes, maxgroup = get_groups(blocks)
+        return cls(blocks, exa, held, colskip, busy, (), groups, groupsizes, maxgroup)
 
 
     def copy(self, deep):
     def copy(self, deep):
         mcopy = copy if deep else lambda x: x
         mcopy = copy if deep else lambda x: x
@@ -58,26 +56,18 @@ class State:
                               mcopy(self.colskip),
                               mcopy(self.colskip),
                               self.busy,
                               self.busy,
                               self.moves,
                               self.moves,
-                              mcopy(self.placed),
-                              mcopy(self.grabbed))
+                              mcopy(self.groups),
+                              mcopy(self.groupsizes),
+                              self.maxgroup))
 
 
     def colbusy(self, col):
     def colbusy(self, col):
         return (self.busy >> col) & 1
         return (self.busy >> col) & 1
 
 
-    def grabbing_or_dropping(self):
-        skip = self.colskip[self.exa]
-        i = (skip + 1) * COLUMNS + self.exa
-        return i < len(self.blocks) and self.blocks[i] == NOBLOCK
-
-    def iter_columns(self):
-        def gen_col(col):
-            for row in range(self.nrows):
-                i = row * COLUMNS + col
-                if self.blocks[i] != NOBLOCK:
-                    yield i
+    def colrows(self, col):
+        return self.nrows - self.colskip[col]
 
 
-        for col in range(COLUMNS):
-            yield gen_col(col)
+    def maxrows(self):
+        return max(map(self.colrows, range(COLUMNS)))
 
 
     def causes_panic(self):
     def causes_panic(self):
         return self.max_colsize() >= COLSIZE_PANIC
         return self.max_colsize() >= COLSIZE_PANIC
@@ -101,9 +91,6 @@ class State:
                 score += row - start_row + 1
                 score += row - start_row + 1
         return score
         return score
 
 
-    def colrows(self, col):
-        return self.nrows - self.colskip[col]
-
     def move(self, *moves):
     def move(self, *moves):
         deep = any(move in (GRAB, DROP, SWAP) for move in moves)
         deep = any(move in (GRAB, DROP, SWAP) for move in moves)
         s = self.copy(deep)
         s = self.copy(deep)
@@ -124,22 +111,21 @@ class State:
                 row = s.colskip[s.exa]
                 row = s.colskip[s.exa]
                 assert row < s.nrows
                 assert row < s.nrows
                 i = row * COLUMNS + s.exa
                 i = row * COLUMNS + s.exa
-                s.grabbed[i] = s.held = s.blocks[i]
+                s.held = s.blocks[i]
                 s.blocks[i] = NOBLOCK
                 s.blocks[i] = NOBLOCK
-                s.grabbed[i] = s.held
                 s.colskip[s.exa] += 1
                 s.colskip[s.exa] += 1
+                s.ungroup(i)
 
 
             elif move == DROP:
             elif move == DROP:
                 assert not s.colbusy(s.exa)
                 assert not s.colbusy(s.exa)
                 assert s.held != NOBLOCK
                 assert s.held != NOBLOCK
                 row = s.colskip[s.exa]
                 row = s.colskip[s.exa]
                 assert row > 0
                 assert row > 0
-                # XXX assert s.nrows - row < COLSIZE_MAX
                 i = (row - 1) * COLUMNS + s.exa
                 i = (row - 1) * COLUMNS + s.exa
                 s.blocks[i] = s.held
                 s.blocks[i] = s.held
                 s.held = NOBLOCK
                 s.held = NOBLOCK
-                s.placed.add(i)
                 s.colskip[s.exa] -= 1
                 s.colskip[s.exa] -= 1
+                s.regroup(i)
 
 
             elif move == SWAP:
             elif move == SWAP:
                 assert not s.colbusy(s.exa)
                 assert not s.colbusy(s.exa)
@@ -150,45 +136,55 @@ class State:
                 bi = s.blocks[i]
                 bi = s.blocks[i]
                 bj = s.blocks[j]
                 bj = s.blocks[j]
                 if bi != bj:
                 if bi != bj:
-                    s.blocks[i] = bj
+                    s.blocks[i] = NOBLOCK
+                    s.blocks[j] = NOBLOCK
+                    s.ungroup(i)
+                    s.ungroup(j)
                     s.blocks[j] = bi
                     s.blocks[j] = bi
-                    s.grabbed[i] = bi
-                    s.grabbed[j] = bj
-                    s.placed.add(i)
-                    s.placed.add(j)
+                    s.regroup(j)
+                    s.blocks[i] = bj
+                    s.regroup(i)
 
 
         return s
         return s
 
 
-    def find_groups(self, depth=POINTS_DEPTH, minsize=2):
-        def follow_group(i, block, group):
-            if self.blocks[i] == block and i not in visited:
-                group.append(i)
+    def ungroup(self, i):
+        assert self.blocks[i] == NOBLOCK
+        visited = set()
+        oldid = self.groups[i]
+
+        for nb in neighbors(i, self.blocks):
+            if self.groups[nb] == oldid:
+                newgroup = self.scan_group(nb, visited)
+                if newgroup:
+                    self.maxgroup = newid = self.maxgroup + 1
+                    for j in newgroup:
+                        self.groups[j] = newid
+                        self.groupsizes[j] = len(newgroup)
+
+        self.groups[i] = 0
+        self.groupsizes[i] = 0
+
+    def regroup(self, i):
+        assert self.blocks[i] != NOBLOCK
+        self.maxgroup = newid = self.maxgroup + 1
+        newgroup = self.scan_group(i, set())
+        for j in newgroup:
+            self.groups[j] = newid
+            self.groupsizes[j] = len(newgroup)
+
+    def scan_group(self, start, visited):
+        def scan(i):
+            if i not in visited:
+                yield i
                 visited.add(i)
                 visited.add(i)
-                for nb in self.neighbors(i):
-                    follow_group(nb, block, group)
+                for nb in neighbors(i, self.blocks):
+                    if self.blocks[nb] == block:
+                        yield from scan(nb)
 
 
-        visited = set()
+        block = self.blocks[start]
+        return tuple(scan(start))
 
 
-        for col in self.iter_columns():
-            for i in islice(col, depth):
-                block = self.blocks[i]
-                group = []
-                follow_group(i, block, group)
-                if len(group) >= minsize:
-                    yield block, group
-
-    def neighbors(self, i):
-        row, col = divmod(i, COLUMNS)
-        if col > 0 and self.blocks[i - 1] != NOBLOCK:
-            yield i - 1
-        if col < COLUMNS - 1 and self.blocks[i + 1] != NOBLOCK:
-            yield i + 1
-        if row > 0 and self.blocks[i - COLUMNS] != NOBLOCK:
-            yield i - COLUMNS
-        if row < self.nrows - 1 and self.blocks[i + COLUMNS] != NOBLOCK:
-            yield i + COLUMNS
-
-    def fragmentation(self, depth=FRAG_DEPTH):
+    def fragmentation(self):
         """
         """
         Minimize the sum of dist(i,j) between all blocks i,j of the same color.
         Minimize the sum of dist(i,j) between all blocks i,j of the same color.
         Magnify vertical distances to avoid column stacking.
         Magnify vertical distances to avoid column stacking.
@@ -199,64 +195,45 @@ class State:
 
 
             # for blocks in the same group, only count vertical distance so
             # for blocks in the same group, only count vertical distance so
             # that groups are spread out horizontally
             # that groups are spread out horizontally
-            if groups[i] == groups[j]:
+            if self.groups[i] == self.groups[j]:
                 return abs(yj - yi)
                 return abs(yj - yi)
 
 
             return abs(xj - xi) + abs(yj - yi) * 2 - 1
             return abs(xj - xi) + abs(yj - yi) * 2 - 1
 
 
         colors = {}
         colors = {}
-        groups = {}
-
-        for groupid, (block, group) in enumerate(self.find_groups(depth, 1)):
-            colors.setdefault(block, []).extend(group)
-            for i in group:
-                groups[i] = groupid
+        for i, block in enumerate(self.blocks):
+            if block != NOBLOCK:
+                colors.setdefault(block, []).append(i)
 
 
         return sum(dist(i, j)
         return sum(dist(i, j)
-                   for block, color in colors.items()
-                   for i, j in combinations(color, 2))
+                   for blocks in colors.values()
+                   for i, j in combinations(blocks, 2))
 
 
-    def points(self):
-        def group_size(start):
-            work = [start]
-            visited.add(start)
-            size = 0
-            block = self.blocks[start]
-
-            while work:
-                i = work.pop()
-
-                # avoid giving points to moving a block within the same group
-                if self.grabbed.get(i, None) == block:
-                    return 0
-
-                if self.blocks[i] == block:
-                    size += 1
-                    for nb in self.neighbors(i):
-                        if nb not in visited:
-                            visited.add(nb)
-                            work.append(nb)
-            return size
+    def group_leaders(self):
+        seen = set()
+        for i, groupid in enumerate(self.groups):
+            if groupid > 0 and groupid not in seen:
+                seen.add(groupid)
+                yield i
 
 
+    def points(self):
         points = 0
         points = 0
-        visited = set()
 
 
-        for i in self.placed:
-            if i not in visited:
-                block = self.blocks[i]
-                size = group_size(i)
+        for leader in self.group_leaders():
+            block = self.blocks[leader]
+            size = self.groupsizes[leader]
 
 
-                if is_basic(block) and size >= MIN_BASIC_GROUP_SIZE:
-                    points += size
-                elif is_bomb(block) and size >= MIN_BOMB_GROUP_SIZE:
-                    points += BOMB_POINTS
+            if is_basic(block) and size >= MIN_BASIC_GROUP_SIZE:
+                points += size
+            elif is_bomb(block) and size >= MIN_BOMB_GROUP_SIZE:
+                points += self.maxrows()
 
 
         return -points
         return -points
 
 
     def gen_moves(self):
     def gen_moves(self):
         yield self
         yield self
 
 
-        for src in self.gen_shift(not self.grabbing_or_dropping()):
+        for src in self.gen_shift(not self.colbusy(self.exa)):
             yield from src.gen_stationary()
             yield from src.gen_stationary()
 
 
             for get in src.gen_get():
             for get in src.gen_get():
@@ -421,6 +398,42 @@ def get_busy(blocks, colskip):
     return mask
     return mask
 
 
 
 
+def scan_group(blocks, i, block, visited):
+    yield i
+    visited.add(i)
+    for nb in neighbors(i, blocks):
+        if blocks[nb] == block and nb not in visited:
+            yield from scan_group(blocks, nb, block, visited)
+
+
+def get_groups(blocks):
+    groupid = 0
+    groups = [0] * len(blocks)
+    groupsizes = [0] * len(blocks)
+    visited = set()
+
+    for i, block in enumerate(blocks):
+        if block != NOBLOCK and i not in visited:
+            groupid += 1
+            group = tuple(scan_group(blocks, i, block, visited))
+            for j in group:
+                groups[j] = groupid
+
+    return groups, groupsizes, groupid
+
+
+def neighbors(i, blocks):
+    y, x = divmod(i, COLUMNS)
+    if x > 0:
+        yield i - 1
+    if x < COLUMNS - 1:
+        yield i + 1
+    if y > 0:
+        yield i - COLUMNS
+    if y < len(blocks) // COLUMNS - 1:
+        yield i + COLUMNS
+
+
 def move_to_key(move):
 def move_to_key(move):
     return 'jjkadl'[move]
     return 'jjkadl'[move]