strategy.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504
  1. import io
  2. import time
  3. from collections import deque
  4. from contextlib import redirect_stdout
  5. from copy import copy
  6. from itertools import combinations
  7. from detection import COLUMNS, NOBLOCK, detect_blocks, detect_exa, \
  8. detect_held, print_board, is_basic, is_bomb
  9. GRAB, DROP, SWAP, LEFT, RIGHT, SPEED = range(6)
  10. MOVE_DELAYS = (
  11. # in milliseconds
  12. 50, # GRAB
  13. 50, # DROP
  14. 50, # SWAP
  15. 30, # LEFT
  16. 30, # RIGHT
  17. 30, # SPEED
  18. )
  19. MIN_BASIC_GROUP_SIZE = 4
  20. MIN_BOMB_GROUP_SIZE = 2
  21. COLSIZE_PRIO = 5
  22. COLSIZE_PANIC = 8
  23. BOMB_POINTS = 5
  24. class State:
  25. def __init__(self, blocks, exa, held, colskip, busy, moves, groups, groupsizes, maxgroup):
  26. self.blocks = blocks
  27. self.exa = exa
  28. self.held = held
  29. self.colskip = colskip
  30. self.busy = busy
  31. self.moves = moves
  32. self.groups = groups
  33. self.groupsizes = groupsizes
  34. self.maxgroup = maxgroup
  35. self.nrows = len(blocks) // COLUMNS
  36. @classmethod
  37. def detect(cls, board, pad=2):
  38. blocks = [NOBLOCK] * (COLUMNS * pad) + list(detect_blocks(board))
  39. exa = detect_exa(board)
  40. held = detect_held(board, exa)
  41. colskip = get_colskip(blocks)
  42. busy = get_busy(blocks, colskip)
  43. groups, groupsizes, maxgroup = get_groups(blocks)
  44. return cls(bytearray(blocks), exa, held, bytearray(colskip), busy, (),
  45. bytearray(groups), bytearray(groupsizes), maxgroup)
  46. def copy(self, deep):
  47. mcopy = copy if deep else lambda x: x
  48. return self.__class__(mcopy(self.blocks),
  49. self.exa,
  50. self.held,
  51. mcopy(self.colskip),
  52. self.busy,
  53. self.moves,
  54. mcopy(self.groups),
  55. mcopy(self.groupsizes),
  56. self.maxgroup)
  57. def colbusy(self, col):
  58. return (self.busy >> col) & 1
  59. def colrows(self, col):
  60. return self.nrows - self.colskip[col]
  61. def maxrows(self):
  62. return max(map(self.colrows, range(COLUMNS)))
  63. def causes_panic(self):
  64. return self.max_colsize() >= COLSIZE_PANIC
  65. def max_colsize(self):
  66. return self.nrows - self.empty_rows()
  67. def empty_rows(self):
  68. for i, block in enumerate(self.blocks):
  69. if block != NOBLOCK:
  70. return i // COLUMNS
  71. return 0
  72. def holes(self):
  73. start_row = self.empty_rows()
  74. score = 0
  75. for col in range(COLUMNS):
  76. for row in range(start_row, self.nrows):
  77. if self.blocks[row * COLUMNS + col] != NOBLOCK:
  78. break
  79. score += row - start_row + 1
  80. return score
  81. def locked(self, i):
  82. block = self.blocks[i]
  83. if block == NOBLOCK:
  84. return False
  85. if is_basic(block):
  86. return self.groupsizes[i] >= MIN_BASIC_GROUP_SIZE
  87. assert is_bomb(block)
  88. return self.groupsizes[i] >= MIN_BOMB_GROUP_SIZE
  89. def move(self, *moves):
  90. deep = any(move in (GRAB, DROP, SWAP) for move in moves)
  91. s = self.copy(deep)
  92. s.moves += moves
  93. for move in moves:
  94. if move == LEFT:
  95. assert s.exa > 0
  96. s.exa -= 1
  97. elif move == RIGHT:
  98. assert s.exa < COLUMNS - 1
  99. s.exa += 1
  100. elif move == GRAB:
  101. assert not s.colbusy(s.exa)
  102. assert s.held == NOBLOCK
  103. row = s.colskip[s.exa]
  104. assert row < s.nrows
  105. i = row * COLUMNS + s.exa
  106. if not s.locked(i):
  107. s.colskip[s.exa] += 1
  108. s.held = s.remove(i)
  109. elif move == DROP:
  110. if s.held != NOBLOCK:
  111. row = s.colskip[s.exa]
  112. assert row > 0
  113. i = (row - 1) * COLUMNS + s.exa
  114. s.colskip[s.exa] -= 1
  115. s.place(i, s.held)
  116. s.held = NOBLOCK
  117. elif move == SWAP:
  118. assert not s.colbusy(s.exa)
  119. row = s.colskip[s.exa]
  120. i = row * COLUMNS + s.exa
  121. j = i + COLUMNS
  122. if j < len(s.blocks) and not s.locked(i) and not s.locked(j):
  123. bi = s.blocks[i]
  124. bj = s.blocks[j]
  125. assert bi != NOBLOCK
  126. assert bj != NOBLOCK
  127. if bi != bj:
  128. s.blocks[i] = NOBLOCK
  129. s.blocks[j] = NOBLOCK
  130. visited = set()
  131. s.ungroup(i, visited)
  132. s.ungroup(j, visited)
  133. s.place(j, bi)
  134. s.place(i, bj)
  135. return s
  136. def remove(self, i):
  137. block = self.blocks[i]
  138. assert block != NOBLOCK
  139. self.blocks[i] = NOBLOCK
  140. self.ungroup(i, set())
  141. return block
  142. def ungroup(self, i, visited):
  143. assert self.blocks[i] == NOBLOCK
  144. oldid = self.groups[i]
  145. for nb in neighbors(i, self.blocks):
  146. if self.groups[nb] == oldid:
  147. self.create_group(nb, visited)
  148. self.groups[i] = 0
  149. self.groupsizes[i] = 0
  150. def place(self, i, block):
  151. assert block != NOBLOCK
  152. assert self.blocks[i] == NOBLOCK
  153. self.blocks[i] = block
  154. self.create_group(i, set())
  155. def create_group(self, start, visited):
  156. def scan(i):
  157. if i not in visited:
  158. yield i
  159. visited.add(i)
  160. for nb in neighbors(i, self.blocks):
  161. if self.blocks[nb] == block:
  162. yield from scan(nb)
  163. block = self.blocks[start]
  164. group = tuple(scan(start))
  165. if group:
  166. self.maxgroup = newid = self.maxgroup + 1
  167. for j in group:
  168. self.groups[j] = newid
  169. self.groupsizes[j] = len(group)
  170. def fragmentation(self):
  171. """
  172. Minimize the sum of dist(i,j) between all blocks i,j of the same color.
  173. Magnify vertical distances to avoid column stacking.
  174. """
  175. def dist(i, j):
  176. yi, xi = divmod(i, COLUMNS)
  177. yj, xj = divmod(j, COLUMNS)
  178. # for blocks in the same group, only count vertical distance so
  179. # that groups are spread out horizontally
  180. if self.groups[i] == self.groups[j]:
  181. return abs(yj - yi)
  182. return abs(xj - xi) + abs(yj - yi) * 2 - 1
  183. colors = {}
  184. for i, block in enumerate(self.blocks):
  185. if block != NOBLOCK:
  186. colors.setdefault(block, []).append(i)
  187. return sum(dist(i, j)
  188. for blocks in colors.values()
  189. for i, j in combinations(blocks, 2))
  190. def group_leaders(self):
  191. seen = set()
  192. for i, groupid in enumerate(self.groups):
  193. if groupid > 0 and groupid not in seen:
  194. seen.add(groupid)
  195. yield i
  196. def points(self):
  197. points = 0
  198. for leader in self.group_leaders():
  199. block = self.blocks[leader]
  200. size = self.groupsizes[leader]
  201. if is_basic(block) and size >= MIN_BASIC_GROUP_SIZE:
  202. points += size
  203. elif is_bomb(block) and size >= MIN_BOMB_GROUP_SIZE:
  204. points += self.maxrows()
  205. return -points
  206. def gen_moves(self):
  207. yield self
  208. for src in self.gen_shift(not self.colbusy(self.exa)):
  209. yield from src.gen_stationary()
  210. for get in src.gen_get():
  211. for dst in get.gen_shift(False):
  212. yield from dst.gen_put()
  213. def gen_shift(self, allow_noshift):
  214. if allow_noshift:
  215. yield self
  216. left = self
  217. for i in range(self.exa):
  218. left = left.move(LEFT)
  219. yield left
  220. right = self
  221. for i in range(COLUMNS - self.exa - 1):
  222. right = right.move(RIGHT)
  223. yield right
  224. def gen_stationary(self):
  225. # SWAP
  226. # GRAB, SWAP, DROP
  227. # GRAB, SWAP, DROP, SWAP
  228. # SWAP, GRAB, SWAP, DROP
  229. # SWAP, GRAB, SWAP, DROP, SWAP
  230. if not self.colbusy(self.exa):
  231. avail = self.colrows(self.exa)
  232. if avail >= 2:
  233. swap = self.move(SWAP)
  234. yield swap
  235. if avail >= 3:
  236. grab = self.move(GRAB, SWAP, DROP)
  237. yield grab
  238. yield grab.move(SWAP)
  239. swap = swap.move(GRAB, SWAP, DROP)
  240. yield swap
  241. yield swap.move(SWAP)
  242. def gen_get(self):
  243. # GRAB
  244. # SWAP, GRAB
  245. # GRAB, SWAP, DROP, SWAP, GRAB
  246. if not self.colbusy(self.exa):
  247. avail = self.colrows(self.exa)
  248. if avail >= 1:
  249. grab = self.move(GRAB)
  250. yield grab
  251. if avail >= 2:
  252. yield self.move(SWAP, GRAB)
  253. if avail >= 3:
  254. yield grab.move(SWAP, DROP, SWAP, GRAB)
  255. def gen_put(self):
  256. # DROP
  257. # DROP, SWAP
  258. # DROP, SWAP, GRAB, SWAP, DROP
  259. if not self.colbusy(self.exa):
  260. avail = self.colrows(self.exa)
  261. drop = self.move(DROP)
  262. yield drop
  263. if avail >= 1:
  264. swap = drop.move(SWAP)
  265. yield swap
  266. if avail >= 2:
  267. yield swap.move(GRAB, SWAP, DROP)
  268. def force(self, *moves):
  269. state = self.move(*moves)
  270. state.score = ()
  271. return state
  272. def solve(self):
  273. assert self.exa is not None
  274. if self.held != NOBLOCK:
  275. return self.force(DROP)
  276. pool = deque(self.gen_moves())
  277. if len(pool) == 0:
  278. return self.force()
  279. best_score = ()
  280. for key in self.score_keys():
  281. if len(pool) == 1:
  282. break
  283. for state in pool:
  284. state.score = key(state)
  285. best = min(state.score for state in pool)
  286. best_score += (best,)
  287. for i in range(len(pool)):
  288. state = pool.popleft()
  289. if state.score == best:
  290. pool.append(state)
  291. best = pool.popleft()
  292. best.score = best_score
  293. return best
  294. def score_keys(self):
  295. cls = self.__class__
  296. colsize = self.nrows - 2
  297. if colsize >= COLSIZE_PANIC:
  298. return cls.holes, cls.points, cls.nmoves, cls.fragmentation
  299. if colsize >= COLSIZE_PRIO:
  300. return cls.causes_panic, cls.points, cls.holes, \
  301. cls.fragmentation, cls.nmoves
  302. return cls.points, cls.fragmentation, cls.holes, cls.nmoves
  303. def print_groupsizes(self):
  304. for start in range(len(self.groupsizes) - COLUMNS, -1, -COLUMNS):
  305. print(' '.join('%-2d' % g for g in self.groupsizes[start:start + COLUMNS]))
  306. def print_groups(self):
  307. for start in range(len(self.groups) - COLUMNS, -1, -COLUMNS):
  308. print(' '.join('%-2d' % g for g in self.groups[start:start + COLUMNS]))
  309. def print(self):
  310. print_board(self.blocks, self.exa, self.held)
  311. def tostring(self):
  312. stream = io.StringIO()
  313. with redirect_stdout(stream):
  314. self.print()
  315. return stream.getvalue()
  316. def has_same_exa(self, state):
  317. return self.exa == state.exa and self.held == state.held
  318. def nmoves(self):
  319. return len(self.moves)
  320. def delay(self):
  321. return moves_delay(self.moves)
  322. def keys(self):
  323. return moves_to_keys(self.moves)
  324. def loops(self, prev):
  325. return self.moves and \
  326. self.exa == prev.exa and \
  327. self.moves == prev.moves and \
  328. self.score == prev.score
  329. def get_colskip(blocks):
  330. def colskip(col):
  331. for row, block in enumerate(blocks[col::COLUMNS]):
  332. if block != NOBLOCK:
  333. return row
  334. return len(blocks) // COLUMNS
  335. return list(map(colskip, range(COLUMNS)))
  336. def get_busy(blocks, colskip):
  337. mask = 0
  338. for col, skip in enumerate(colskip):
  339. start = (skip + 1) * COLUMNS + col
  340. colbusy = NOBLOCK in blocks[start::COLUMNS]
  341. mask |= colbusy << col
  342. return mask
  343. def scan_group(blocks, i, block, visited):
  344. yield i
  345. visited.add(i)
  346. for nb in neighbors(i, blocks):
  347. if blocks[nb] == block and nb not in visited:
  348. yield from scan_group(blocks, nb, block, visited)
  349. def get_groups(blocks):
  350. groupid = 0
  351. groups = [0] * len(blocks)
  352. groupsizes = [0] * len(blocks)
  353. visited = set()
  354. for i, block in enumerate(blocks):
  355. if block != NOBLOCK and i not in visited:
  356. groupid += 1
  357. group = tuple(scan_group(blocks, i, block, visited))
  358. for j in group:
  359. groups[j] = groupid
  360. groupsizes[j] = len(group)
  361. return groups, groupsizes, groupid
  362. def neighbors(i, blocks):
  363. y, x = divmod(i, COLUMNS)
  364. if x > 0:
  365. yield i - 1
  366. if x < COLUMNS - 1:
  367. yield i + 1
  368. if y > 0:
  369. yield i - COLUMNS
  370. if y < len(blocks) // COLUMNS - 1:
  371. yield i + COLUMNS
  372. def move_to_key(move):
  373. return 'jjkadl'[move]
  374. def moves_to_keys(moves):
  375. return ''.join(move_to_key(move) for move in moves)
  376. def moves_delay(moves):
  377. return sum(MOVE_DELAYS[m] for m in moves)
  378. if __name__ == '__main__':
  379. import sys
  380. from PIL import Image
  381. board = Image.open('screens/board%d.png' % int(sys.argv[1])).convert('HSV')
  382. state = State.detect(board)
  383. print('parsed:')
  384. state.print()
  385. print()
  386. print('groups:')
  387. state.print_groups()
  388. print()
  389. print('groupsizes:')
  390. state.print_groupsizes()
  391. print()
  392. start = time.time()
  393. newstate = state.solve()
  394. end = time.time()
  395. print('best move:', newstate.keys())
  396. print('score:', newstate.score)
  397. print('elapsed:', round((end - start) * 1000, 1), 'ms')
  398. print()
  399. print('target after move:')
  400. newstate.print()
  401. #print()
  402. #print('generated moves:')
  403. #for state in state.gen_moves():
  404. # print(state.keys())