Przeglądaj źródła

Manage group sizes per group instead of per block

Taddeus Kroes 5 lat temu
rodzic
commit
bbbfc5be63
1 zmienionych plików z 32 dodań i 44 usunięć
  1. 32 44
      strategy.py

+ 32 - 44
strategy.py

@@ -26,16 +26,15 @@ BOMB_POINTS = 5
 
 
 class State:
-    def __init__(self, blocks, exa, held, colskip, busy, moves, groups, groupsizes, maxgroup):
+    def __init__(self, blocks, exa, held, colskip, busy, moves, blockgroups, groups):
         self.blocks = blocks
         self.exa = exa
         self.held = held
         self.colskip = colskip
         self.busy = busy
         self.moves = moves
+        self.blockgroups = blockgroups
         self.groups = groups
-        self.groupsizes = groupsizes
-        self.maxgroup = maxgroup
         self.nrows = len(blocks) // COLUMNS
 
     @classmethod
@@ -45,9 +44,9 @@ class State:
         held = detect_held(board, exa)
         colskip = get_colskip(blocks)
         busy = get_busy(blocks, colskip)
-        groups, groupsizes, maxgroup = get_groups(blocks)
+        blockgroups, groups = get_groups(blocks)
         return cls(bytearray(blocks), exa, held, bytearray(colskip), busy, (),
-                   bytearray(groups), bytearray(groupsizes), maxgroup)
+                   bytearray(blockgroups), groups)
 
     def copy(self, deep):
         mcopy = copy if deep else lambda x: x
@@ -57,9 +56,8 @@ class State:
                               mcopy(self.colskip),
                               self.busy,
                               self.moves,
-                              mcopy(self.groups),
-                              mcopy(self.groupsizes),
-                              self.maxgroup)
+                              mcopy(self.blockgroups),
+                              mcopy(self.groups))
 
     def colbusy(self, col):
         return (self.busy >> col) & 1
@@ -93,13 +91,13 @@ class State:
         return score
 
     def locked(self, i):
-        block = self.blocks[i]
+        size, block = self.groups[self.blockgroups[i]]
         if block == NOBLOCK:
             return False
         if is_basic(block):
-            return self.groupsizes[i] >= MIN_BASIC_GROUP_SIZE
+            return size >= MIN_BASIC_GROUP_SIZE
         assert is_bomb(block)
-        return self.groupsizes[i] >= MIN_BOMB_GROUP_SIZE
+        return size >= MIN_BOMB_GROUP_SIZE
 
     def move(self, *moves):
         deep = any(move in (GRAB, DROP, SWAP) for move in moves)
@@ -165,15 +163,14 @@ class State:
 
     def ungroup(self, i, visited):
         assert self.blocks[i] == NOBLOCK
-        oldid = self.groups[i]
+        oldid = self.blockgroups[i]
+        self.blockgroups[i] = 0
+        self.groups[oldid] = 0, NOBLOCK
 
         for nb in neighbors(i, self.blocks):
-            if self.groups[nb] == oldid:
+            if self.blockgroups[nb] == oldid:
                 self.create_group(nb, visited)
 
-        self.groups[i] = 0
-        self.groupsizes[i] = 0
-
     def place(self, i, block):
         assert block != NOBLOCK
         assert self.blocks[i] == NOBLOCK
@@ -192,10 +189,10 @@ class State:
         block = self.blocks[start]
         group = tuple(scan(start))
         if group:
-            self.maxgroup = newid = self.maxgroup + 1
+            newid = len(self.groups)
+            self.groups.append((len(group), block))
             for j in group:
-                self.groups[j] = newid
-                self.groupsizes[j] = len(group)
+                self.blockgroups[j] = newid
 
     def fragmentation(self):
         """
@@ -207,9 +204,9 @@ class State:
             yi, xi = divmod(i, COLUMNS)
             yj, xj = divmod(j, COLUMNS)
 
-            # for blocks in the same group, only count vertical distance so
-            # that groups are spread out horizontally
-            if self.groups[i] == self.groups[j]:
+            # for blocks in the same group, only count vertical distance so that
+            # groups are spread out horizontally
+            if self.blockgroups[i] == self.blockgroups[j]:
                 return abs(yj - yi)
 
             return abs(xj - xi) + abs(yj - yi) * 2 - 1 + \
@@ -224,20 +221,10 @@ class State:
                    for blocks in colors.values()
                    for i, j in combinations(blocks, 2))
 
-    def group_leaders(self):
-        seen = set()
-        for i, groupid in enumerate(self.groups):
-            if groupid > 0 and groupid not in seen:
-                seen.add(groupid)
-                yield i
-
     def points(self):
         points = 0
 
-        for leader in self.group_leaders():
-            block = self.blocks[leader]
-            size = self.groupsizes[leader]
-
+        for size, block in self.groups:
             if is_basic(block) and size >= MIN_BASIC_GROUP_SIZE:
                 points += size
             elif is_bomb(block) and size >= MIN_BOMB_GROUP_SIZE:
@@ -367,12 +354,14 @@ class State:
         return cls.points, cls.fragmentation, cls.holes, cls.nmoves
 
     def print_groupsizes(self):
-        for start in range(len(self.groupsizes) - COLUMNS, -1, -COLUMNS):
-            print(' '.join('%-2d' % g for g in self.groupsizes[start:start + COLUMNS]))
+        for start in range(len(self.blockgroups) - COLUMNS, -1, -COLUMNS):
+            print(' '.join('%-2d' % self.groups[g][0]
+                           for g in self.blockgroups[start:start + COLUMNS]))
 
     def print_groups(self):
-        for start in range(len(self.groups) - COLUMNS, -1, -COLUMNS):
-            print(' '.join('%-2d' % g for g in self.groups[start:start + COLUMNS]))
+        for start in range(len(self.blockgroups) - COLUMNS, -1, -COLUMNS):
+            print(' '.join('%-2d' % g
+                           for g in self.blockgroups[start:start + COLUMNS]))
 
     def print(self):
         print_board(self.blocks, self.exa, self.held)
@@ -430,20 +419,19 @@ def scan_group(blocks, i, block, visited):
 
 
 def get_groups(blocks):
-    groupid = 0
-    groups = [0] * len(blocks)
-    groupsizes = [0] * len(blocks)
+    blockgroups = [0] * len(blocks)
+    groups = [(0, NOBLOCK)]
     visited = set()
 
     for i, block in enumerate(blocks):
         if block != NOBLOCK and i not in visited:
-            groupid += 1
+            groupid = len(groups)
             group = tuple(scan_group(blocks, i, block, visited))
+            groups.append((len(group), block))
             for j in group:
-                groups[j] = groupid
-                groupsizes[j] = len(group)
+                blockgroups[j] = groupid
 
-    return groups, groupsizes, groupid
+    return blockgroups, groups
 
 
 def neighbors(i, blocks):