Pārlūkot izejas kodu

Compute groups once and update them on each move

Taddeus Kroes 5 gadi atpakaļ
vecāks
revīzija
0db87a732d
1 mainītis faili ar 116 papildinājumiem un 103 dzēšanām
  1. 116 103
      strategy.py

+ 116 - 103
strategy.py

@@ -3,7 +3,7 @@ import time
 from collections import deque
 from contextlib import redirect_stdout
 from copy import copy
-from itertools import combinations, islice
+from itertools import combinations
 from detection import COLUMNS, NOBLOCK, detect_blocks, detect_exa, \
                       detect_held, print_board, is_basic, is_bomb
 
@@ -20,25 +20,22 @@ MOVE_DELAYS = (
 )
 MIN_BASIC_GROUP_SIZE = 4
 MIN_BOMB_GROUP_SIZE = 2
-POINTS_DEPTH = 3
-FRAG_DEPTH = 4
-DEFRAG_PRIO = 4
 COLSIZE_PRIO = 5
 COLSIZE_PANIC = 8
-COLSIZE_MAX = 9
 BOMB_POINTS = 5
 
 
 class State:
-    def __init__(self, blocks, exa, held, colskip, busy, moves, placed, grabbed):
+    def __init__(self, blocks, exa, held, colskip, busy, moves, groups, groupsizes, maxgroup):
         self.blocks = blocks
         self.exa = exa
         self.held = held
         self.colskip = colskip
         self.busy = busy
         self.moves = moves
-        self.placed = placed
-        self.grabbed = grabbed
+        self.groups = groups
+        self.groupsizes = groupsizes
+        self.maxgroup = maxgroup
         self.nrows = len(blocks) // COLUMNS
 
     @classmethod
@@ -48,7 +45,8 @@ class State:
         held = detect_held(board, exa)
         colskip = get_colskip(blocks)
         busy = get_busy(blocks, colskip)
-        return cls(blocks, exa, held, colskip, busy, (), set(), {})
+        groups, groupsizes, maxgroup = get_groups(blocks)
+        return cls(blocks, exa, held, colskip, busy, (), groups, groupsizes, maxgroup)
 
     def copy(self, deep):
         mcopy = copy if deep else lambda x: x
@@ -58,26 +56,18 @@ class State:
                               mcopy(self.colskip),
                               self.busy,
                               self.moves,
-                              mcopy(self.placed),
-                              mcopy(self.grabbed))
+                              mcopy(self.groups),
+                              mcopy(self.groupsizes),
+                              self.maxgroup))
 
     def colbusy(self, col):
         return (self.busy >> col) & 1
 
-    def grabbing_or_dropping(self):
-        skip = self.colskip[self.exa]
-        i = (skip + 1) * COLUMNS + self.exa
-        return i < len(self.blocks) and self.blocks[i] == NOBLOCK
-
-    def iter_columns(self):
-        def gen_col(col):
-            for row in range(self.nrows):
-                i = row * COLUMNS + col
-                if self.blocks[i] != NOBLOCK:
-                    yield i
+    def colrows(self, col):
+        return self.nrows - self.colskip[col]
 
-        for col in range(COLUMNS):
-            yield gen_col(col)
+    def maxrows(self):
+        return max(map(self.colrows, range(COLUMNS)))
 
     def causes_panic(self):
         return self.max_colsize() >= COLSIZE_PANIC
@@ -101,9 +91,6 @@ class State:
                 score += row - start_row + 1
         return score
 
-    def colrows(self, col):
-        return self.nrows - self.colskip[col]
-
     def move(self, *moves):
         deep = any(move in (GRAB, DROP, SWAP) for move in moves)
         s = self.copy(deep)
@@ -124,22 +111,21 @@ class State:
                 row = s.colskip[s.exa]
                 assert row < s.nrows
                 i = row * COLUMNS + s.exa
-                s.grabbed[i] = s.held = s.blocks[i]
+                s.held = s.blocks[i]
                 s.blocks[i] = NOBLOCK
-                s.grabbed[i] = s.held
                 s.colskip[s.exa] += 1
+                s.ungroup(i)
 
             elif move == DROP:
                 assert not s.colbusy(s.exa)
                 assert s.held != NOBLOCK
                 row = s.colskip[s.exa]
                 assert row > 0
-                # XXX assert s.nrows - row < COLSIZE_MAX
                 i = (row - 1) * COLUMNS + s.exa
                 s.blocks[i] = s.held
                 s.held = NOBLOCK
-                s.placed.add(i)
                 s.colskip[s.exa] -= 1
+                s.regroup(i)
 
             elif move == SWAP:
                 assert not s.colbusy(s.exa)
@@ -150,45 +136,55 @@ class State:
                 bi = s.blocks[i]
                 bj = s.blocks[j]
                 if bi != bj:
-                    s.blocks[i] = bj
+                    s.blocks[i] = NOBLOCK
+                    s.blocks[j] = NOBLOCK
+                    s.ungroup(i)
+                    s.ungroup(j)
                     s.blocks[j] = bi
-                    s.grabbed[i] = bi
-                    s.grabbed[j] = bj
-                    s.placed.add(i)
-                    s.placed.add(j)
+                    s.regroup(j)
+                    s.blocks[i] = bj
+                    s.regroup(i)
 
         return s
 
-    def find_groups(self, depth=POINTS_DEPTH, minsize=2):
-        def follow_group(i, block, group):
-            if self.blocks[i] == block and i not in visited:
-                group.append(i)
+    def ungroup(self, i):
+        assert self.blocks[i] == NOBLOCK
+        visited = set()
+        oldid = self.groups[i]
+
+        for nb in neighbors(i, self.blocks):
+            if self.groups[nb] == oldid:
+                newgroup = self.scan_group(nb, visited)
+                if newgroup:
+                    self.maxgroup = newid = self.maxgroup + 1
+                    for j in newgroup:
+                        self.groups[j] = newid
+                        self.groupsizes[j] = len(newgroup)
+
+        self.groups[i] = 0
+        self.groupsizes[i] = 0
+
+    def regroup(self, i):
+        assert self.blocks[i] != NOBLOCK
+        self.maxgroup = newid = self.maxgroup + 1
+        newgroup = self.scan_group(i, set())
+        for j in newgroup:
+            self.groups[j] = newid
+            self.groupsizes[j] = len(newgroup)
+
+    def scan_group(self, start, visited):
+        def scan(i):
+            if i not in visited:
+                yield i
                 visited.add(i)
-                for nb in self.neighbors(i):
-                    follow_group(nb, block, group)
+                for nb in neighbors(i, self.blocks):
+                    if self.blocks[nb] == block:
+                        yield from scan(nb)
 
-        visited = set()
+        block = self.blocks[start]
+        return tuple(scan(start))
 
-        for col in self.iter_columns():
-            for i in islice(col, depth):
-                block = self.blocks[i]
-                group = []
-                follow_group(i, block, group)
-                if len(group) >= minsize:
-                    yield block, group
-
-    def neighbors(self, i):
-        row, col = divmod(i, COLUMNS)
-        if col > 0 and self.blocks[i - 1] != NOBLOCK:
-            yield i - 1
-        if col < COLUMNS - 1 and self.blocks[i + 1] != NOBLOCK:
-            yield i + 1
-        if row > 0 and self.blocks[i - COLUMNS] != NOBLOCK:
-            yield i - COLUMNS
-        if row < self.nrows - 1 and self.blocks[i + COLUMNS] != NOBLOCK:
-            yield i + COLUMNS
-
-    def fragmentation(self, depth=FRAG_DEPTH):
+    def fragmentation(self):
         """
         Minimize the sum of dist(i,j) between all blocks i,j of the same color.
         Magnify vertical distances to avoid column stacking.
@@ -199,64 +195,45 @@ class State:
 
             # for blocks in the same group, only count vertical distance so
             # that groups are spread out horizontally
-            if groups[i] == groups[j]:
+            if self.groups[i] == self.groups[j]:
                 return abs(yj - yi)
 
             return abs(xj - xi) + abs(yj - yi) * 2 - 1
 
         colors = {}
-        groups = {}
-
-        for groupid, (block, group) in enumerate(self.find_groups(depth, 1)):
-            colors.setdefault(block, []).extend(group)
-            for i in group:
-                groups[i] = groupid
+        for i, block in enumerate(self.blocks):
+            if block != NOBLOCK:
+                colors.setdefault(block, []).append(i)
 
         return sum(dist(i, j)
-                   for block, color in colors.items()
-                   for i, j in combinations(color, 2))
+                   for blocks in colors.values()
+                   for i, j in combinations(blocks, 2))
 
-    def points(self):
-        def group_size(start):
-            work = [start]
-            visited.add(start)
-            size = 0
-            block = self.blocks[start]
-
-            while work:
-                i = work.pop()
-
-                # avoid giving points to moving a block within the same group
-                if self.grabbed.get(i, None) == block:
-                    return 0
-
-                if self.blocks[i] == block:
-                    size += 1
-                    for nb in self.neighbors(i):
-                        if nb not in visited:
-                            visited.add(nb)
-                            work.append(nb)
-            return size
+    def group_leaders(self):
+        seen = set()
+        for i, groupid in enumerate(self.groups):
+            if groupid > 0 and groupid not in seen:
+                seen.add(groupid)
+                yield i
 
+    def points(self):
         points = 0
-        visited = set()
 
-        for i in self.placed:
-            if i not in visited:
-                block = self.blocks[i]
-                size = group_size(i)
+        for leader in self.group_leaders():
+            block = self.blocks[leader]
+            size = self.groupsizes[leader]
 
-                if is_basic(block) and size >= MIN_BASIC_GROUP_SIZE:
-                    points += size
-                elif is_bomb(block) and size >= MIN_BOMB_GROUP_SIZE:
-                    points += BOMB_POINTS
+            if is_basic(block) and size >= MIN_BASIC_GROUP_SIZE:
+                points += size
+            elif is_bomb(block) and size >= MIN_BOMB_GROUP_SIZE:
+                points += self.maxrows()
 
         return -points
 
     def gen_moves(self):
         yield self
 
-        for src in self.gen_shift(not self.grabbing_or_dropping()):
+        for src in self.gen_shift(not self.colbusy(self.exa)):
             yield from src.gen_stationary()
 
             for get in src.gen_get():
@@ -421,6 +398,42 @@ def get_busy(blocks, colskip):
     return mask
 
 
+def scan_group(blocks, i, block, visited):
+    yield i
+    visited.add(i)
+    for nb in neighbors(i, blocks):
+        if blocks[nb] == block and nb not in visited:
+            yield from scan_group(blocks, nb, block, visited)
+
+
+def get_groups(blocks):
+    groupid = 0
+    groups = [0] * len(blocks)
+    groupsizes = [0] * len(blocks)
+    visited = set()
+
+    for i, block in enumerate(blocks):
+        if block != NOBLOCK and i not in visited:
+            groupid += 1
+            group = tuple(scan_group(blocks, i, block, visited))
+            for j in group:
+                groups[j] = groupid
+
+    return groups, groupsizes, groupid
+
+
+def neighbors(i, blocks):
+    y, x = divmod(i, COLUMNS)
+    if x > 0:
+        yield i - 1
+    if x < COLUMNS - 1:
+        yield i + 1
+    if y > 0:
+        yield i - COLUMNS
+    if y < len(blocks) // COLUMNS - 1:
+        yield i + COLUMNS
+
+
 def move_to_key(move):
     return 'jjkadl'[move]