Commit 0db87a73 authored by Taddeüs Kroes's avatar Taddeüs Kroes

Compute groups once and update them on each move

parent 2512a051
......@@ -3,7 +3,7 @@ import time
from collections import deque
from contextlib import redirect_stdout
from copy import copy
from itertools import combinations, islice
from itertools import combinations
from detection import COLUMNS, NOBLOCK, detect_blocks, detect_exa, \
detect_held, print_board, is_basic, is_bomb
......@@ -20,25 +20,22 @@ MOVE_DELAYS = (
)
MIN_BASIC_GROUP_SIZE = 4
MIN_BOMB_GROUP_SIZE = 2
POINTS_DEPTH = 3
FRAG_DEPTH = 4
DEFRAG_PRIO = 4
COLSIZE_PRIO = 5
COLSIZE_PANIC = 8
COLSIZE_MAX = 9
BOMB_POINTS = 5
class State:
def __init__(self, blocks, exa, held, colskip, busy, moves, placed, grabbed):
def __init__(self, blocks, exa, held, colskip, busy, moves, groups, groupsizes, maxgroup):
self.blocks = blocks
self.exa = exa
self.held = held
self.colskip = colskip
self.busy = busy
self.moves = moves
self.placed = placed
self.grabbed = grabbed
self.groups = groups
self.groupsizes = groupsizes
self.maxgroup = maxgroup
self.nrows = len(blocks) // COLUMNS
@classmethod
......@@ -48,7 +45,8 @@ class State:
held = detect_held(board, exa)
colskip = get_colskip(blocks)
busy = get_busy(blocks, colskip)
return cls(blocks, exa, held, colskip, busy, (), set(), {})
groups, groupsizes, maxgroup = get_groups(blocks)
return cls(blocks, exa, held, colskip, busy, (), groups, groupsizes, maxgroup)
def copy(self, deep):
mcopy = copy if deep else lambda x: x
......@@ -58,26 +56,18 @@ class State:
mcopy(self.colskip),
self.busy,
self.moves,
mcopy(self.placed),
mcopy(self.grabbed))
mcopy(self.groups),
mcopy(self.groupsizes),
self.maxgroup))
def colbusy(self, col):
return (self.busy >> col) & 1
def grabbing_or_dropping(self):
skip = self.colskip[self.exa]
i = (skip + 1) * COLUMNS + self.exa
return i < len(self.blocks) and self.blocks[i] == NOBLOCK
def iter_columns(self):
def gen_col(col):
for row in range(self.nrows):
i = row * COLUMNS + col
if self.blocks[i] != NOBLOCK:
yield i
def colrows(self, col):
return self.nrows - self.colskip[col]
for col in range(COLUMNS):
yield gen_col(col)
def maxrows(self):
return max(map(self.colrows, range(COLUMNS)))
def causes_panic(self):
return self.max_colsize() >= COLSIZE_PANIC
......@@ -101,9 +91,6 @@ class State:
score += row - start_row + 1
return score
def colrows(self, col):
return self.nrows - self.colskip[col]
def move(self, *moves):
deep = any(move in (GRAB, DROP, SWAP) for move in moves)
s = self.copy(deep)
......@@ -124,22 +111,21 @@ class State:
row = s.colskip[s.exa]
assert row < s.nrows
i = row * COLUMNS + s.exa
s.grabbed[i] = s.held = s.blocks[i]
s.held = s.blocks[i]
s.blocks[i] = NOBLOCK
s.grabbed[i] = s.held
s.colskip[s.exa] += 1
s.ungroup(i)
elif move == DROP:
assert not s.colbusy(s.exa)
assert s.held != NOBLOCK
row = s.colskip[s.exa]
assert row > 0
# XXX assert s.nrows - row < COLSIZE_MAX
i = (row - 1) * COLUMNS + s.exa
s.blocks[i] = s.held
s.held = NOBLOCK
s.placed.add(i)
s.colskip[s.exa] -= 1
s.regroup(i)
elif move == SWAP:
assert not s.colbusy(s.exa)
......@@ -150,45 +136,55 @@ class State:
bi = s.blocks[i]
bj = s.blocks[j]
if bi != bj:
s.blocks[i] = bj
s.blocks[i] = NOBLOCK
s.blocks[j] = NOBLOCK
s.ungroup(i)
s.ungroup(j)
s.blocks[j] = bi
s.grabbed[i] = bi
s.grabbed[j] = bj
s.placed.add(i)
s.placed.add(j)
s.regroup(j)
s.blocks[i] = bj
s.regroup(i)
return s
def find_groups(self, depth=POINTS_DEPTH, minsize=2):
def follow_group(i, block, group):
if self.blocks[i] == block and i not in visited:
group.append(i)
def ungroup(self, i):
assert self.blocks[i] == NOBLOCK
visited = set()
oldid = self.groups[i]
for nb in neighbors(i, self.blocks):
if self.groups[nb] == oldid:
newgroup = self.scan_group(nb, visited)
if newgroup:
self.maxgroup = newid = self.maxgroup + 1
for j in newgroup:
self.groups[j] = newid
self.groupsizes[j] = len(newgroup)
self.groups[i] = 0
self.groupsizes[i] = 0
def regroup(self, i):
assert self.blocks[i] != NOBLOCK
self.maxgroup = newid = self.maxgroup + 1
newgroup = self.scan_group(i, set())
for j in newgroup:
self.groups[j] = newid
self.groupsizes[j] = len(newgroup)
def scan_group(self, start, visited):
def scan(i):
if i not in visited:
yield i
visited.add(i)
for nb in self.neighbors(i):
follow_group(nb, block, group)
for nb in neighbors(i, self.blocks):
if self.blocks[nb] == block:
yield from scan(nb)
visited = set()
block = self.blocks[start]
return tuple(scan(start))
for col in self.iter_columns():
for i in islice(col, depth):
block = self.blocks[i]
group = []
follow_group(i, block, group)
if len(group) >= minsize:
yield block, group
def neighbors(self, i):
row, col = divmod(i, COLUMNS)
if col > 0 and self.blocks[i - 1] != NOBLOCK:
yield i - 1
if col < COLUMNS - 1 and self.blocks[i + 1] != NOBLOCK:
yield i + 1
if row > 0 and self.blocks[i - COLUMNS] != NOBLOCK:
yield i - COLUMNS
if row < self.nrows - 1 and self.blocks[i + COLUMNS] != NOBLOCK:
yield i + COLUMNS
def fragmentation(self, depth=FRAG_DEPTH):
def fragmentation(self):
"""
Minimize the sum of dist(i,j) between all blocks i,j of the same color.
Magnify vertical distances to avoid column stacking.
......@@ -199,64 +195,45 @@ class State:
# for blocks in the same group, only count vertical distance so
# that groups are spread out horizontally
if groups[i] == groups[j]:
if self.groups[i] == self.groups[j]:
return abs(yj - yi)
return abs(xj - xi) + abs(yj - yi) * 2 - 1
colors = {}
groups = {}
for groupid, (block, group) in enumerate(self.find_groups(depth, 1)):
colors.setdefault(block, []).extend(group)
for i in group:
groups[i] = groupid
for i, block in enumerate(self.blocks):
if block != NOBLOCK:
colors.setdefault(block, []).append(i)
return sum(dist(i, j)
for block, color in colors.items()
for i, j in combinations(color, 2))
for blocks in colors.values()
for i, j in combinations(blocks, 2))
def points(self):
def group_size(start):
work = [start]
visited.add(start)
size = 0
block = self.blocks[start]
while work:
i = work.pop()
# avoid giving points to moving a block within the same group
if self.grabbed.get(i, None) == block:
return 0
if self.blocks[i] == block:
size += 1
for nb in self.neighbors(i):
if nb not in visited:
visited.add(nb)
work.append(nb)
return size
def group_leaders(self):
seen = set()
for i, groupid in enumerate(self.groups):
if groupid > 0 and groupid not in seen:
seen.add(groupid)
yield i
def points(self):
points = 0
visited = set()
for i in self.placed:
if i not in visited:
block = self.blocks[i]
size = group_size(i)
for leader in self.group_leaders():
block = self.blocks[leader]
size = self.groupsizes[leader]
if is_basic(block) and size >= MIN_BASIC_GROUP_SIZE:
points += size
elif is_bomb(block) and size >= MIN_BOMB_GROUP_SIZE:
points += BOMB_POINTS
if is_basic(block) and size >= MIN_BASIC_GROUP_SIZE:
points += size
elif is_bomb(block) and size >= MIN_BOMB_GROUP_SIZE:
points += self.maxrows()
return -points
def gen_moves(self):
yield self
for src in self.gen_shift(not self.grabbing_or_dropping()):
for src in self.gen_shift(not self.colbusy(self.exa)):
yield from src.gen_stationary()
for get in src.gen_get():
......@@ -421,6 +398,42 @@ def get_busy(blocks, colskip):
return mask
def scan_group(blocks, i, block, visited):
yield i
visited.add(i)
for nb in neighbors(i, blocks):
if blocks[nb] == block and nb not in visited:
yield from scan_group(blocks, nb, block, visited)
def get_groups(blocks):
groupid = 0
groups = [0] * len(blocks)
groupsizes = [0] * len(blocks)
visited = set()
for i, block in enumerate(blocks):
if block != NOBLOCK and i not in visited:
groupid += 1
group = tuple(scan_group(blocks, i, block, visited))
for j in group:
groups[j] = groupid
return groups, groupsizes, groupid
def neighbors(i, blocks):
y, x = divmod(i, COLUMNS)
if x > 0:
yield i - 1
if x < COLUMNS - 1:
yield i + 1
if y > 0:
yield i - COLUMNS
if y < len(blocks) // COLUMNS - 1:
yield i + COLUMNS
def move_to_key(move):
return 'jjkadl'[move]
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment