strategy.py 14 KB


  1. import io
  2. import time
  3. from collections import deque
  4. from contextlib import redirect_stdout
  5. from copy import copy
  6. from itertools import combinations
  7. from detection import COLUMNS, NOBLOCK, detect_blocks, detect_exa, \
  8. detect_held, print_board, is_basic, is_bomb
  # Move codes; each code indexes MOVE_DELAYS and the key map of move_to_key().
  GRAB, DROP, SWAP, LEFT, RIGHT, SPEED = range(6)

  # Cost of each move in milliseconds, indexed by move code (summed by
  # moves_delay()).
  MOVE_DELAYS = (
      50, 50, 50,  # GRAB, DROP, SWAP
      30, 30, 30,  # LEFT, RIGHT, SPEED
  )

  # Group sizes at which State.locked() treats basic blocks / bombs as locked.
  MIN_BASIC_GROUP_SIZE = 4
  MIN_BOMB_GROUP_SIZE = 2

  # Column-height thresholds at which State.score_keys() reorders its
  # priorities (COLSIZE_PANIC is also used by State.causes_panic()).
  COLSIZE_PRIO = 5
  COLSIZE_PANIC = 8

  # NOTE(review): not referenced anywhere else in this file.
  BOMB_POINTS = 5
  class State:
      """Board snapshot plus the move sequence that produced it.

      ``blocks`` is a flat, row-major grid with ``COLUMNS`` columns; row 0 is
      the topmost row and ``detect`` prepends empty padding rows. The state
      also tracks:

      - ``exa``: the EXA's column.
      - ``held``: the block it holds.
      - ``colskip``: the number of empty cells above each column's top block.
      - ``busy``: a bitmask of columns that still contain an empty cell below
        their top block.
      - The connected same-colour groups. ``blockgroups[i]`` is the group id
        of cell ``i`` and ``groups[id]`` is ``(size, block)``; id 0 means
        "no group".
      """

      def __init__(self, blocks, exa, held, colskip, busy, moves, blockgroups,
                   groups):
          self.blocks = blocks
          self.exa = exa
          self.held = held
          self.colskip = colskip
          self.busy = busy
          self.moves = moves
          self.blockgroups = blockgroups
          self.groups = groups
          self.nrows = len(blocks) // COLUMNS

      @classmethod
      def detect(cls, board, pad=2):
          """Build the initial state from the screenshot ``board``.

          ``pad`` empty rows are prepended to the detected grid. This leaves
          room above the tallest column, since DROP needs an empty cell above
          a column's top block.
          """
          blocks = [NOBLOCK] * (COLUMNS * pad) + list(detect_blocks(board))
          exa = detect_exa(board)
          held = detect_held(board, exa)
          colskip = get_colskip(blocks)
          busy = get_busy(blocks, colskip)
          blockgroups, groups = get_groups(blocks)
          # NOTE(review): blockgroups is stored in a bytearray, so group ids
          # must stay below 256; create_group() appends a new id every time a
          # region is relabelled.
          return cls(bytearray(blocks), exa, held, bytearray(colskip), busy, (),
                     bytearray(blockgroups), groups)

      def copy(self, deep):
          """Return a copy of this state.

          With ``deep``, the mutable containers (blocks, colskip, blockgroups,
          groups) are shallow-copied. Otherwise they are shared with ``self``,
          which is only safe when the copy will not mutate them.
          """
          # ``copy`` here is the module-level copy.copy, not this method.
          mcopy = copy if deep else (lambda x: x)
          return self.__class__(mcopy(self.blocks), self.exa, self.held,
                                mcopy(self.colskip), self.busy, self.moves,
                                mcopy(self.blockgroups), mcopy(self.groups))

      def colbusy(self, col):
          """Return 1 if column ``col`` is flagged in ``busy``, else 0."""
          return (self.busy >> col) & 1

      def colrows(self, col):
          """Return the number of rows from column ``col``'s top block down."""
          return self.nrows - self.colskip[col]

      def maxrows(self):
          """Return the largest ``colrows`` over all columns."""
          return max(map(self.colrows, range(COLUMNS)))

      def causes_panic(self):
          """Return whether the stack height reaches COLSIZE_PANIC rows."""
          return self.max_colsize() >= COLSIZE_PANIC

      def max_colsize(self):
          """Return the stack height, counted from its highest occupied row."""
          return self.nrows - self.empty_rows()

      def empty_rows(self):
          """Return the number of completely empty rows at the top."""
          for i, block in enumerate(self.blocks):
              if block != NOBLOCK:
                  return i // COLUMNS
          # Every row is empty. Returning 0 here would make max_colsize()
          # report a full board and causes_panic() fire on an empty one.
          return self.nrows

      def holes(self):
          """Penalty for columns that are lower than the highest occupied row.

          Each column is scanned downward from the highest occupied row.
          Every empty cell before the column's first block adds its 1-based
          depth, so deep gaps cost more than shallow ones.
          """
          # NOTE(review): restored from a copy that had lost its indentation;
          # the accumulation is read as per empty cell -- confirm against
          # history.
          start_row = self.empty_rows()
          score = 0
          for col in range(COLUMNS):
              for row in range(start_row, self.nrows):
                  if self.blocks[row * COLUMNS + col] != NOBLOCK:
                      break
                  score += row - start_row + 1
          return score

      def locked(self, i):
          """Return whether cell ``i`` belongs to a sufficiently large group.

          The threshold is MIN_BASIC_GROUP_SIZE for basic blocks and
          MIN_BOMB_GROUP_SIZE for bombs. ``move`` does not GRAB or SWAP
          locked blocks.
          """
          size, block = self.groups[self.blockgroups[i]]
          if block == NOBLOCK:
              return False
          if is_basic(block):
              return size >= MIN_BASIC_GROUP_SIZE
          assert is_bomb(block)
          return size >= MIN_BOMB_GROUP_SIZE

      def move(self, *moves):
          """Return a new state with ``moves`` applied to a copy of this one.

          The grids are only copied when a move can change them (GRAB, DROP,
          SWAP); pure LEFT/RIGHT sequences share them with ``self``.

          - GRAB is a no-op on a locked block.
          - SWAP is a no-op when the column has fewer than two blocks, either
            block is locked, or both blocks are equal.
          - DROP is a no-op when nothing is held.
          - Other move codes (SPEED) are only recorded in ``moves``.
          """
          deep = any(move in (GRAB, DROP, SWAP) for move in moves)
          s = self.copy(deep)
          s.moves += moves
          for move in moves:
              if move == LEFT:
                  assert s.exa > 0
                  s.exa -= 1
              elif move == RIGHT:
                  assert s.exa < COLUMNS - 1
                  s.exa += 1
              elif move == GRAB:
                  assert not s.colbusy(s.exa)
                  assert s.held == NOBLOCK
                  row = s.colskip[s.exa]
                  assert row < s.nrows
                  i = row * COLUMNS + s.exa
                  if not s.locked(i):
                      s.colskip[s.exa] += 1
                      s.held = s.remove(i)
              elif move == DROP:
                  if s.held != NOBLOCK:
                      row = s.colskip[s.exa]
                      assert row > 0
                      i = (row - 1) * COLUMNS + s.exa
                      s.colskip[s.exa] -= 1
                      s.place(i, s.held)
                      s.held = NOBLOCK
              elif move == SWAP:
                  assert not s.colbusy(s.exa)
                  row = s.colskip[s.exa]
                  # i is the column's top block, j the block right below it.
                  i = row * COLUMNS + s.exa
                  j = i + COLUMNS
                  if j < len(s.blocks) and not s.locked(i) and not s.locked(j):
                      bi = s.blocks[i]
                      bj = s.blocks[j]
                      assert bi != NOBLOCK
                      assert bj != NOBLOCK
                      if bi != bj:
                          s.blocks[i] = NOBLOCK
                          s.blocks[j] = NOBLOCK
                          visited = set()
                          s.ungroup(i, visited)
                          s.ungroup(j, visited)
                          s.place(j, bi)
                          s.place(i, bj)
          return s

      def remove(self, i):
          """Empty cell ``i``, update the grouping, and return its block."""
          block = self.blocks[i]
          assert block != NOBLOCK
          self.blocks[i] = NOBLOCK
          self.ungroup(i, set())
          return block

      def ungroup(self, i, visited):
          """Detach the already-emptied cell ``i`` from its group.

          The old group id is retired and every neighbour still labelled with
          it is relabelled through create_group. A group split by the removal
          therefore ends up as separate groups.
          """
          assert self.blocks[i] == NOBLOCK
          oldid = self.blockgroups[i]
          self.blockgroups[i] = 0
          self.groups[oldid] = 0, NOBLOCK
          for nb in neighbors(i, self.blocks):
              if self.blockgroups[nb] == oldid:
                  self.create_group(nb, visited)

      def place(self, i, block):
          """Put ``block`` into the empty cell ``i`` and regroup around it."""
          assert block != NOBLOCK
          assert self.blocks[i] == NOBLOCK
          self.blocks[i] = block
          self.create_group(i, set())

      def create_group(self, start, visited):
          """Give the same-colour region containing ``start`` a fresh group id.

          Cells already in ``visited`` are skipped; if nothing is left, no
          group is created. Ids previously held by cells of the region are
          retired, so that ``points`` does not count an absorbed group a
          second time.
          """
          def scan(i):
              if i not in visited:
                  yield i
                  visited.add(i)
                  for nb in neighbors(i, self.blocks):
                      if self.blocks[nb] == block:
                          yield from scan(nb)

          block = self.blocks[start]
          group = tuple(scan(start))
          if group:
              newid = len(self.groups)
              self.groups.append((len(group), block))
              for j in group:
                  oldid = self.blockgroups[j]
                  if oldid:
                      # The region absorbed this group (e.g. a block was
                      # placed next to it), so its old entry is now stale.
                      self.groups[oldid] = 0, NOBLOCK
                  self.blockgroups[j] = newid

      def fragmentation(self):
          """
          Minimize the sum of dist(i,j) between all blocks i,j of the same color.
          Magnify vertical distances to avoid column stacking.
          Add the vertical components of i and j to prioritize lower columns.
          """
          def dist(i, j):
              yi, xi = divmod(i, COLUMNS)
              yj, xj = divmod(j, COLUMNS)
              # for blocks in the same group, only count vertical distance so
              # that groups are spread out horizontally
              if self.blockgroups[i] == self.blockgroups[j]:
                  return abs(yj - yi)
              return (abs(xj - xi) + abs(yj - yi) * 2 - 1 +
                      (self.nrows - yi) + (self.nrows - yj))

          colors = {}
          for i, block in enumerate(self.blocks):
              if block != NOBLOCK:
                  colors.setdefault(block, []).append(i)
          return sum(dist(i, j)
                     for blocks in colors.values()
                     for i, j in combinations(blocks, 2))

      def points(self):
          """Return minus the points of all groups that reached their minimum size.

          Basic groups score their size; bomb groups score ``maxrows()``. The
          result is negated so that lower is better, like the other score
          keys minimised by ``solve``.
          """
          points = 0
          for size, block in self.groups:
              if is_basic(block) and size >= MIN_BASIC_GROUP_SIZE:
                  points += size
              elif is_bomb(block) and size >= MIN_BOMB_GROUP_SIZE:
                  points += self.maxrows()
          return -points

      def gen_moves(self):
          """Yield this state and every candidate state for one planned action.

          A candidate is reached by an optional shift followed by either a
          stationary swap sequence, or getting a block and putting it in
          another column.
          """
          yield self
          for src in self.gen_shift(not self.colbusy(self.exa)):
              yield from src.gen_stationary()
              for get in src.gen_get():
                  for dst in get.gen_shift(False):
                      yield from dst.gen_put()

      def gen_shift(self, allow_noshift):
          """Yield states with the EXA moved to every other column.

          Columns to the left come first, then columns to the right. When
          ``allow_noshift`` is set, this state itself is yielded first.
          """
          if allow_noshift:
              yield self
          left = self
          for i in range(self.exa):
              left = left.move(LEFT)
              yield left
          right = self
          for i in range(COLUMNS - self.exa - 1):
              right = right.move(RIGHT)
              yield right

      def gen_stationary(self):
          """Yield the swap sequences that stay in the current column.

          The sequences are:

          - SWAP
          - GRAB, SWAP, DROP
          - GRAB, SWAP, DROP, SWAP
          - SWAP, GRAB, SWAP, DROP
          - SWAP, GRAB, SWAP, DROP, SWAP
          """
          if not self.colbusy(self.exa):
              avail = self.colrows(self.exa)
              if avail >= 2:
                  swap = self.move(SWAP)
                  yield swap
                  if avail >= 3:
                      grab = self.move(GRAB, SWAP, DROP)
                      yield grab
                      yield grab.move(SWAP)
                      swap = swap.move(GRAB, SWAP, DROP)
                      yield swap
                      yield swap.move(SWAP)

      def gen_get(self):
          """Yield the ways to pick up a block from the current column.

          The sequences are:

          - GRAB
          - SWAP, GRAB
          - GRAB, SWAP, DROP, SWAP, GRAB
          """
          if not self.colbusy(self.exa):
              avail = self.colrows(self.exa)
              if avail >= 1:
                  grab = self.move(GRAB)
                  yield grab
                  if avail >= 2:
                      yield self.move(SWAP, GRAB)
                      if avail >= 3:
                          yield grab.move(SWAP, DROP, SWAP, GRAB)

      def gen_put(self):
          """Yield the ways to put the held block into the current column.

          The sequences are:

          - DROP
          - DROP, SWAP
          - DROP, SWAP, GRAB, SWAP, DROP
          """
          if not self.colbusy(self.exa):
              avail = self.colrows(self.exa)
              drop = self.move(DROP)
              yield drop
              if avail >= 1:
                  swap = drop.move(SWAP)
                  yield swap
                  if avail >= 2:
                      yield swap.move(GRAB, SWAP, DROP)

      def force(self, *moves):
          """Apply ``moves`` and mark the result with an empty score."""
          state = self.move(*moves)
          state.score = ()
          return state

      def solve(self):
          """Pick the best candidate from ``gen_moves`` by lexicographic tie-breaking.

          The keys from ``score_keys`` are applied in order, and only
          candidates with the minimal value survive each round. The first
          survivor is returned with ``score`` set to the tuple of winning
          values. While a block is held, the only option is to drop it.
          """
          assert self.exa is not None
          if self.held != NOBLOCK:
              return self.force(DROP)
          pool = list(self.gen_moves())
          if not pool:
              return self.force()
          best_score = ()
          for key in self.score_keys():
              if len(pool) == 1:
                  break
              for state in pool:
                  state.score = key(state)
              best = min(state.score for state in pool)
              best_score += (best,)
              # Keep the winners in their original generation order.
              pool = [state for state in pool if state.score == best]
          best = pool[0]
          best.score = best_score
          return best

      def score_keys(self):
          """Return the scoring functions for ``solve``, most important first.

          The order depends on ``nrows - 2``, presumably the detected board
          height without the default two padding rows of ``detect``.
          """
          cls = self.__class__
          colsize = self.nrows - 2
          if colsize >= COLSIZE_PANIC:
              return cls.holes, cls.points, cls.nmoves, cls.fragmentation
          if colsize >= COLSIZE_PRIO:
              return (cls.causes_panic, cls.points, cls.holes,
                      cls.fragmentation, cls.nmoves)
          return cls.points, cls.fragmentation, cls.holes, cls.nmoves

      def print_groupsizes(self):
          """Print the group size of every cell, last grid row first."""
          for start in range(len(self.blockgroups) - COLUMNS, -1, -COLUMNS):
              print(' '.join('%-2d' % self.groups[g][0]
                             for g in self.blockgroups[start:start + COLUMNS]))

      def print_groups(self):
          """Print the group id of every cell, last grid row first."""
          for start in range(len(self.blockgroups) - COLUMNS, -1, -COLUMNS):
              print(' '.join('%-2d' % g
                             for g in self.blockgroups[start:start + COLUMNS]))

      def print(self):
          """Print the board via detection.print_board."""
          print_board(self.blocks, self.exa, self.held)

      def tostring(self):
          """Return what ``print`` would write to stdout, as a string."""
          stream = io.StringIO()
          with redirect_stdout(stream):
              self.print()
          return stream.getvalue()

      def has_same_exa(self, state):
          """Return whether ``state`` has the same EXA column and held block."""
          return self.exa == state.exa and self.held == state.held

      def nmoves(self):
          """Return the number of moves recorded in this state."""
          return len(self.moves)

      def delay(self):
          """Return the summed MOVE_DELAYS of the recorded moves."""
          return moves_delay(self.moves)

      def keys(self):
          """Return the key string for the recorded moves."""
          return moves_to_keys(self.moves)

      def loops(self, prev):
          """Return a truthy value if this state repeats ``prev``.

          A repeat means the same non-empty plan: same EXA column, same moves
          and same score.
          """
          return (self.moves and
                  self.exa == prev.exa and
                  self.moves == prev.moves and
                  self.score == prev.score)
  320. self.exa == prev.exa and \
  321. self.moves == prev.moves and \
  322. self.score == prev.score
  def get_colskip(blocks):
      """Return, per column, the row index of its top block.

      This equals the number of empty cells above that block; an all-empty
      column gets the total row count.
      """
      nrows = len(blocks) // COLUMNS
      skips = []
      for col in range(COLUMNS):
          column = blocks[col::COLUMNS]
          first = next((row for row, block in enumerate(column)
                        if block != NOBLOCK), nrows)
          skips.append(first)
      return skips
  def get_busy(blocks, colskip):
      """Return a bitmask of columns that have a gap below their top block.

      Bit ``col`` is set when column ``col`` has an empty cell anywhere below
      its top block, where ``colskip[col]`` is that block's row.
      """
      busy = 0
      for col in range(len(colskip)):
          below_top = blocks[(colskip[col] + 1) * COLUMNS + col::COLUMNS]
          if NOBLOCK in below_top:
              busy |= 1 << col
      return busy
  def scan_group(blocks, i, block, visited):
      """Yield cell ``i`` and its same-``block`` region, depth-first.

      Every unvisited cell connected to ``i`` through cells holding
      ``block`` is yielded, and each yielded cell is added to ``visited``.
      """
      yield i
      visited.add(i)
      # Lazily filtered, so the visited check below sees cells marked by
      # earlier recursive calls.
      same = (nb for nb in neighbors(i, blocks) if blocks[nb] == block)
      for nb in same:
          if nb in visited:
              continue
          yield from scan_group(blocks, nb, block, visited)
  def get_groups(blocks):
      """Label the connected same-block regions of ``blocks``.

      Returns ``(blockgroups, groups)``. ``blockgroups[i]`` is the group id
      of cell ``i`` (0 for empty cells), and ``groups[id]`` is
      ``(size, block)``. ``groups[0] == (0, NOBLOCK)`` is the placeholder
      entry for "no group".
      """
      blockgroups = [0 for _ in blocks]
      groups = [(0, NOBLOCK)]
      seen = set()
      for start, block in enumerate(blocks):
          if block == NOBLOCK or start in seen:
              continue
          gid = len(groups)
          members = tuple(scan_group(blocks, start, block, seen))
          groups.append((len(members), block))
          for cell in members:
              blockgroups[cell] = gid
      return blockgroups, groups
  def neighbors(i, blocks):
      """Yield the grid cells adjacent to cell ``i``.

      The order is left, right, previous row, next row. Positions outside
      the grid described by ``blocks`` are skipped.
      """
      nrows = len(blocks) // COLUMNS
      row, col = divmod(i, COLUMNS)
      if col > 0:
          yield i - 1
      if col + 1 < COLUMNS:
          yield i + 1
      if row > 0:
          yield i - COLUMNS
      if row + 1 < nrows:
          yield i + COLUMNS
  def move_to_key(move):
      """Return the key character for move code ``move``.

      GRAB and DROP both map to 'j'.
      """
      keymap = 'jjkadl'  # indexed by GRAB, DROP, SWAP, LEFT, RIGHT, SPEED
      return keymap[move]
  def moves_to_keys(moves):
      """Return the concatenated key characters for a sequence of move codes."""
      return ''.join(map(move_to_key, moves))
  def moves_delay(moves):
      """Return the total MOVE_DELAYS cost, in milliseconds, of ``moves``."""
      total = 0
      for move in moves:
          total += MOVE_DELAYS[move]
      return total
  if __name__ == '__main__':
      # Debug driver: solve screens/board<N>.png and print the chosen move.
      import sys
      from PIL import Image

      if len(sys.argv) < 2:
          sys.exit('usage: %s BOARD_NUMBER' % sys.argv[0])
      board = Image.open('screens/board%d.png' % int(sys.argv[1])).convert('HSV')
      state = State.detect(board)
      print('parsed:')
      state.print()
      print()
      print('groups:')
      state.print_groups()
      print()
      print('groupsizes:')
      state.print_groupsizes()
      print()
      # perf_counter is monotonic and high-resolution, unlike time.time().
      start = time.perf_counter()
      newstate = state.solve()
      end = time.perf_counter()
      print('best move:', newstate.keys())
      print('score:', newstate.score)
      print('elapsed:', round((end - start) * 1000, 1), 'ms')
      print()
      print('target after move:')
      newstate.print()
      #print()
      #print('generated moves:')
      #for state in state.gen_moves():
      #    print(state.keys())