# strategy.py (13 KB)


  1. import io
  2. import time
  3. from collections import deque
  4. from contextlib import redirect_stdout
  5. from copy import copy
  6. from itertools import combinations
  7. from detection import COLUMNS, NOBLOCK, detect_blocks, detect_exa, \
  8. detect_held, print_board, is_basic, is_bomb
  9. GRAB, DROP, SWAP, LEFT, RIGHT, SPEED = range(6)
  10. MOVE_DELAYS = (
  11. # in milliseconds
  12. 50, # GRAB
  13. 50, # DROP
  14. 50, # SWAP
  15. 30, # LEFT
  16. 30, # RIGHT
  17. 30, # SPEED
  18. )
  19. MIN_BASIC_GROUP_SIZE = 4
  20. MIN_BOMB_GROUP_SIZE = 2
  21. COLSIZE_PRIO = 5
  22. COLSIZE_PANIC = 8
  23. BOMB_POINTS = 5
  24. class State:
  25. def __init__(self, blocks, exa, held, colskip, busy, moves, groups, groupsizes, maxgroup):
  26. self.blocks = blocks
  27. self.exa = exa
  28. self.held = held
  29. self.colskip = colskip
  30. self.busy = busy
  31. self.moves = moves
  32. self.groups = groups
  33. self.groupsizes = groupsizes
  34. self.maxgroup = maxgroup
  35. self.nrows = len(blocks) // COLUMNS
  36. @classmethod
  37. def detect(cls, board, pad=2):
  38. blocks = [NOBLOCK] * (COLUMNS * pad) + list(detect_blocks(board))
  39. exa = detect_exa(board)
  40. held = detect_held(board, exa)
  41. colskip = get_colskip(blocks)
  42. busy = get_busy(blocks, colskip)
  43. groups, groupsizes, maxgroup = get_groups(blocks)
  44. return cls(blocks, exa, held, colskip, busy, (), groups, groupsizes, maxgroup)
  45. def copy(self, deep):
  46. mcopy = copy if deep else lambda x: x
  47. return self.__class__(mcopy(self.blocks),
  48. self.exa,
  49. self.held,
  50. mcopy(self.colskip),
  51. self.busy,
  52. self.moves,
  53. mcopy(self.groups),
  54. mcopy(self.groupsizes),
  55. self.maxgroup)
  56. def colbusy(self, col):
  57. return (self.busy >> col) & 1
  58. def colrows(self, col):
  59. return self.nrows - self.colskip[col]
  60. def maxrows(self):
  61. return max(map(self.colrows, range(COLUMNS)))
  62. def causes_panic(self):
  63. return self.max_colsize() >= COLSIZE_PANIC
  64. def max_colsize(self):
  65. return self.nrows - self.empty_rows()
  66. def empty_rows(self):
  67. for i, block in enumerate(self.blocks):
  68. if block != NOBLOCK:
  69. return i // COLUMNS
  70. return 0
  71. def holes(self):
  72. start_row = self.empty_rows()
  73. score = 0
  74. for col in range(COLUMNS):
  75. for row in range(start_row, self.nrows):
  76. if self.blocks[row * COLUMNS + col] != NOBLOCK:
  77. break
  78. score += row - start_row + 1
  79. return score
  80. def move(self, *moves):
  81. deep = any(move in (GRAB, DROP, SWAP) for move in moves)
  82. s = self.copy(deep)
  83. s.moves += moves
  84. for move in moves:
  85. if move == LEFT:
  86. assert s.exa > 0
  87. s.exa -= 1
  88. elif move == RIGHT:
  89. assert s.exa < COLUMNS - 1
  90. s.exa += 1
  91. elif move == GRAB:
  92. assert not s.colbusy(s.exa)
  93. assert s.held == NOBLOCK
  94. row = s.colskip[s.exa]
  95. assert row < s.nrows
  96. i = row * COLUMNS + s.exa
  97. s.held = s.blocks[i]
  98. s.blocks[i] = NOBLOCK
  99. s.colskip[s.exa] += 1
  100. s.ungroup(i)
  101. elif move == DROP:
  102. assert not s.colbusy(s.exa)
  103. assert s.held != NOBLOCK
  104. row = s.colskip[s.exa]
  105. assert row > 0
  106. i = (row - 1) * COLUMNS + s.exa
  107. s.blocks[i] = s.held
  108. s.held = NOBLOCK
  109. s.colskip[s.exa] -= 1
  110. s.regroup(i)
  111. elif move == SWAP:
  112. assert not s.colbusy(s.exa)
  113. row = s.colskip[s.exa]
  114. i = row * COLUMNS + s.exa
  115. j = i + COLUMNS
  116. assert j < len(s.blocks)
  117. bi = s.blocks[i]
  118. bj = s.blocks[j]
  119. if bi != bj:
  120. s.blocks[i] = NOBLOCK
  121. s.blocks[j] = NOBLOCK
  122. s.ungroup(i)
  123. s.ungroup(j)
  124. s.blocks[j] = bi
  125. s.regroup(j)
  126. s.blocks[i] = bj
  127. s.regroup(i)
  128. return s
  129. def ungroup(self, i):
  130. assert self.blocks[i] == NOBLOCK
  131. visited = set()
  132. oldid = self.groups[i]
  133. for nb in neighbors(i, self.blocks):
  134. if self.groups[nb] == oldid:
  135. newgroup = self.scan_group(nb, visited)
  136. if newgroup:
  137. self.maxgroup = newid = self.maxgroup + 1
  138. for j in newgroup:
  139. self.groups[j] = newid
  140. self.groupsizes[j] = len(newgroup)
  141. self.groups[i] = 0
  142. self.groupsizes[i] = 0
  143. def regroup(self, i):
  144. assert self.blocks[i] != NOBLOCK
  145. self.maxgroup = newid = self.maxgroup + 1
  146. newgroup = self.scan_group(i, set())
  147. for j in newgroup:
  148. self.groups[j] = newid
  149. self.groupsizes[j] = len(newgroup)
  150. def scan_group(self, start, visited):
  151. def scan(i):
  152. if i not in visited:
  153. yield i
  154. visited.add(i)
  155. for nb in neighbors(i, self.blocks):
  156. if self.blocks[nb] == block:
  157. yield from scan(nb)
  158. block = self.blocks[start]
  159. return tuple(scan(start))
  160. def fragmentation(self):
  161. """
  162. Minimize the sum of dist(i,j) between all blocks i,j of the same color.
  163. Magnify vertical distances to avoid column stacking.
  164. """
  165. def dist(i, j):
  166. yi, xi = divmod(i, COLUMNS)
  167. yj, xj = divmod(j, COLUMNS)
  168. # for blocks in the same group, only count vertical distance so
  169. # that groups are spread out horizontally
  170. if self.groups[i] == self.groups[j]:
  171. return abs(yj - yi)
  172. return abs(xj - xi) + abs(yj - yi) * 2 - 1
  173. colors = {}
  174. for i, block in enumerate(self.blocks):
  175. if block != NOBLOCK:
  176. colors.setdefault(block, []).append(i)
  177. return sum(dist(i, j)
  178. for blocks in colors.values()
  179. for i, j in combinations(blocks, 2))
  180. def group_leaders(self):
  181. seen = set()
  182. for i, groupid in enumerate(self.groups):
  183. if groupid > 0 and groupid not in seen:
  184. seen.add(groupid)
  185. yield i
  186. def points(self):
  187. points = 0
  188. for leader in self.group_leaders():
  189. block = self.blocks[leader]
  190. size = self.groupsizes[leader]
  191. if is_basic(block) and size >= MIN_BASIC_GROUP_SIZE:
  192. points += size
  193. elif is_bomb(block) and size >= MIN_BOMB_GROUP_SIZE:
  194. points += self.maxrows()
  195. return -points
  196. def gen_moves(self):
  197. yield self
  198. for src in self.gen_shift(not self.colbusy(self.exa)):
  199. yield from src.gen_stationary()
  200. for get in src.gen_get():
  201. for dst in get.gen_shift(False):
  202. yield from dst.gen_put()
  203. def gen_shift(self, allow_noshift):
  204. if allow_noshift:
  205. yield self
  206. left = self
  207. for i in range(self.exa):
  208. left = left.move(LEFT)
  209. yield left
  210. right = self
  211. for i in range(COLUMNS - self.exa - 1):
  212. right = right.move(RIGHT)
  213. yield right
  214. def gen_stationary(self):
  215. # SWAP
  216. # GRAB, SWAP, DROP
  217. # GRAB, SWAP, DROP, SWAP
  218. # SWAP, GRAB, SWAP, DROP
  219. # SWAP, GRAB, SWAP, DROP, SWAP
  220. if not self.colbusy(self.exa):
  221. avail = self.colrows(self.exa)
  222. if avail >= 2:
  223. swap = self.move(SWAP)
  224. yield swap
  225. if avail >= 3:
  226. grab = self.move(GRAB, SWAP, DROP)
  227. yield grab
  228. yield grab.move(SWAP)
  229. swap = swap.move(GRAB, SWAP, DROP)
  230. yield swap
  231. yield swap.move(SWAP)
  232. def gen_get(self):
  233. # GRAB
  234. # SWAP, GRAB
  235. # GRAB, SWAP, DROP, SWAP, GRAB
  236. if not self.colbusy(self.exa):
  237. avail = self.colrows(self.exa)
  238. if avail >= 1:
  239. grab = self.move(GRAB)
  240. yield grab
  241. if avail >= 2:
  242. yield self.move(SWAP, GRAB)
  243. if avail >= 3:
  244. yield grab.move(SWAP, DROP, SWAP, GRAB)
  245. def gen_put(self):
  246. # DROP
  247. # DROP, SWAP
  248. # DROP, SWAP, GRAB, SWAP, DROP
  249. if not self.colbusy(self.exa):
  250. avail = self.colrows(self.exa)
  251. drop = self.move(DROP)
  252. yield drop
  253. if avail >= 1:
  254. swap = drop.move(SWAP)
  255. yield swap
  256. if avail >= 2:
  257. yield swap.move(GRAB, SWAP, DROP)
  258. def force(self, *moves):
  259. state = self.move(*moves)
  260. state.score = ()
  261. return state
  262. def solve(self):
  263. assert self.exa is not None
  264. if self.held != NOBLOCK:
  265. return self.force(DROP)
  266. pool = deque(self.gen_moves())
  267. if len(pool) == 0:
  268. return self.force()
  269. best_score = ()
  270. for key in self.score_keys():
  271. if len(pool) == 1:
  272. break
  273. for state in pool:
  274. state.score = key(state)
  275. best = min(state.score for state in pool)
  276. best_score += (best,)
  277. for i in range(len(pool)):
  278. state = pool.popleft()
  279. if state.score == best:
  280. pool.append(state)
  281. best = pool.popleft()
  282. best.score = best_score
  283. return best
  284. def score_keys(self):
  285. cls = self.__class__
  286. colsize = self.nrows - 2
  287. if colsize >= COLSIZE_PANIC:
  288. return cls.holes, cls.nmoves, cls.points, cls.fragmentation
  289. if colsize >= COLSIZE_PRIO:
  290. return cls.causes_panic, cls.points, cls.holes, \
  291. cls.fragmentation, cls.nmoves
  292. return cls.points, cls.fragmentation, cls.holes, cls.nmoves
  293. def print(self):
  294. print_board(self.blocks, self.exa, self.held)
  295. def tostring(self):
  296. stream = io.StringIO()
  297. with redirect_stdout(stream):
  298. self.print()
  299. return stream.getvalue()
  300. def has_same_exa(self, state):
  301. return self.exa == state.exa and self.held == state.held
  302. def nmoves(self):
  303. return len(self.moves)
  304. def delay(self):
  305. return moves_delay(self.moves)
  306. def keys(self):
  307. return moves_to_keys(self.moves)
  308. def loops(self, prev):
  309. return self.moves and \
  310. self.exa == prev.exa and \
  311. self.moves == prev.moves and \
  312. self.score == prev.score
  313. def get_colskip(blocks):
  314. def colskip(col):
  315. for row, block in enumerate(blocks[col::COLUMNS]):
  316. if block != NOBLOCK:
  317. return row
  318. return len(blocks) // COLUMNS
  319. return list(map(colskip, range(COLUMNS)))
  320. def get_busy(blocks, colskip):
  321. mask = 0
  322. for col, skip in enumerate(colskip):
  323. start = (skip + 1) * COLUMNS + col
  324. colbusy = NOBLOCK in blocks[start::COLUMNS]
  325. mask |= colbusy << col
  326. return mask
  327. def scan_group(blocks, i, block, visited):
  328. yield i
  329. visited.add(i)
  330. for nb in neighbors(i, blocks):
  331. if blocks[nb] == block and nb not in visited:
  332. yield from scan_group(blocks, nb, block, visited)
  333. def get_groups(blocks):
  334. groupid = 0
  335. groups = [0] * len(blocks)
  336. groupsizes = [0] * len(blocks)
  337. visited = set()
  338. for i, block in enumerate(blocks):
  339. if block != NOBLOCK and i not in visited:
  340. groupid += 1
  341. group = tuple(scan_group(blocks, i, block, visited))
  342. for j in group:
  343. groups[j] = groupid
  344. return groups, groupsizes, groupid
  345. def neighbors(i, blocks):
  346. y, x = divmod(i, COLUMNS)
  347. if x > 0:
  348. yield i - 1
  349. if x < COLUMNS - 1:
  350. yield i + 1
  351. if y > 0:
  352. yield i - COLUMNS
  353. if y < len(blocks) // COLUMNS - 1:
  354. yield i + COLUMNS
  355. def move_to_key(move):
  356. return 'jjkadl'[move]
  357. def moves_to_keys(moves):
  358. return ''.join(move_to_key(move) for move in moves)
  359. def moves_delay(moves):
  360. return sum(MOVE_DELAYS[m] for m in moves)
  361. if __name__ == '__main__':
  362. import sys
  363. from PIL import Image
  364. board = Image.open('screens/board%d.png' % int(sys.argv[1])).convert('HSV')
  365. state = State.detect(board)
  366. print('parsed:')
  367. state.print()
  368. print()
  369. start = time.time()
  370. newstate = state.solve()
  371. end = time.time()
  372. print('best move:', newstate.keys())
  373. print('score:', newstate.score)
  374. print('elapsed:', round((end - start) * 1000, 1), 'ms')
  375. print()
  376. print('target after move:')
  377. newstate.print()
  378. #print()
  379. #print('generated moves:')
  380. #for state in state.gen_moves():
  381. # print(state.keys())