strategy.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463
  1. import io
  2. import time
  3. from collections import deque
  4. from contextlib import redirect_stdout
  5. from copy import copy
  6. from itertools import combinations, islice
  7. from detection import COLUMNS, NOBLOCK, detect_blocks, detect_exa, \
  8. detect_held, print_board, is_basic, is_bomb
  9. GRAB, DROP, SWAP, LEFT, RIGHT, SPEED = range(6)
  10. MOVE_DELAYS = (
  11. # in milliseconds
  12. 50, # GRAB
  13. 50, # DROP
  14. 50, # SWAP
  15. 30, # LEFT
  16. 30, # RIGHT
  17. 30, # SPEED
  18. )
  19. MIN_BASIC_GROUP_SIZE = 4
  20. MIN_BOMB_GROUP_SIZE = 2
  21. POINTS_DEPTH = 3
  22. FRAG_DEPTH = 4
  23. DEFRAG_PRIO = 4
  24. COLSIZE_PRIO = 5
  25. COLSIZE_PANIC = 8
  26. COLSIZE_MAX = 9
  27. BOMB_POINTS = 5
  28. class State:
  29. def __init__(self, blocks, exa, held, colskip, busy, moves, placed, grabbed):
  30. self.blocks = blocks
  31. self.exa = exa
  32. self.held = held
  33. self.colskip = colskip
  34. self.busy = busy
  35. self.moves = moves
  36. self.placed = placed
  37. self.grabbed = grabbed
  38. self.nrows = len(blocks) // COLUMNS
  39. @classmethod
  40. def detect(cls, board, pad=2):
  41. blocks = [NOBLOCK] * (COLUMNS * pad) + list(detect_blocks(board))
  42. exa = detect_exa(board)
  43. held = detect_held(board, exa)
  44. colskip = get_colskip(blocks)
  45. busy = get_busy(blocks, colskip)
  46. return cls(blocks, exa, held, colskip, busy, (), set(), {})
  47. def copy(self, deep):
  48. mcopy = copy if deep else lambda x: x
  49. return self.__class__(mcopy(self.blocks),
  50. self.exa,
  51. self.held,
  52. mcopy(self.colskip),
  53. self.busy,
  54. self.moves,
  55. mcopy(self.placed),
  56. mcopy(self.grabbed))
  57. def colbusy(self, col):
  58. return (self.busy >> col) & 1
  59. def grabbing_or_dropping(self):
  60. skip = self.colskip[self.exa]
  61. i = (skip + 1) * COLUMNS + self.exa
  62. return i < len(self.blocks) and self.blocks[i] == NOBLOCK
  63. def iter_columns(self):
  64. def gen_col(col):
  65. for row in range(self.nrows):
  66. i = row * COLUMNS + col
  67. if self.blocks[i] != NOBLOCK:
  68. yield i
  69. for col in range(COLUMNS):
  70. yield gen_col(col)
  71. def causes_panic(self):
  72. return self.max_colsize() >= COLSIZE_PANIC
  73. def max_colsize(self):
  74. return self.nrows - self.empty_rows()
  75. def empty_rows(self):
  76. for i, block in enumerate(self.blocks):
  77. if block != NOBLOCK:
  78. return i // COLUMNS
  79. return 0
  80. def holes(self):
  81. start_row = self.empty_rows()
  82. score = 0
  83. for col in range(COLUMNS):
  84. for row in range(start_row, self.nrows):
  85. if self.blocks[row * COLUMNS + col] != NOBLOCK:
  86. break
  87. score += row - start_row + 1
  88. return score
  89. def colrows(self, col):
  90. return self.nrows - self.colskip[col]
  91. def move(self, *moves):
  92. deep = any(move in (GRAB, DROP, SWAP) for move in moves)
  93. s = self.copy(deep)
  94. s.moves += moves
  95. for move in moves:
  96. if move == LEFT:
  97. assert s.exa > 0
  98. s.exa -= 1
  99. elif move == RIGHT:
  100. assert s.exa < COLUMNS - 1
  101. s.exa += 1
  102. elif move == GRAB:
  103. assert not s.colbusy(s.exa)
  104. assert s.held == NOBLOCK
  105. row = s.colskip[s.exa]
  106. assert row < s.nrows
  107. i = row * COLUMNS + s.exa
  108. s.grabbed[i] = s.held = s.blocks[i]
  109. s.blocks[i] = NOBLOCK
  110. s.grabbed[i] = s.held
  111. s.colskip[s.exa] += 1
  112. elif move == DROP:
  113. assert not s.colbusy(s.exa)
  114. assert s.held != NOBLOCK
  115. row = s.colskip[s.exa]
  116. assert row > 0
  117. # XXX assert s.nrows - row < COLSIZE_MAX
  118. i = (row - 1) * COLUMNS + s.exa
  119. s.blocks[i] = s.held
  120. s.held = NOBLOCK
  121. s.placed.add(i)
  122. s.colskip[s.exa] -= 1
  123. elif move == SWAP:
  124. assert not s.colbusy(s.exa)
  125. row = s.colskip[s.exa]
  126. i = row * COLUMNS + s.exa
  127. j = i + COLUMNS
  128. assert j < len(s.blocks)
  129. bi = s.blocks[i]
  130. bj = s.blocks[j]
  131. if bi != bj:
  132. s.blocks[i] = bj
  133. s.blocks[j] = bi
  134. s.grabbed[i] = bi
  135. s.grabbed[j] = bj
  136. s.placed.add(i)
  137. s.placed.add(j)
  138. return s
  139. def find_groups(self, depth=POINTS_DEPTH, minsize=2):
  140. def follow_group(i, block, group):
  141. if self.blocks[i] == block and i not in visited:
  142. group.append(i)
  143. visited.add(i)
  144. for nb in self.neighbors(i):
  145. follow_group(nb, block, group)
  146. visited = set()
  147. for col in self.iter_columns():
  148. for i in islice(col, depth):
  149. block = self.blocks[i]
  150. group = []
  151. follow_group(i, block, group)
  152. if len(group) >= minsize:
  153. yield block, group
  154. def neighbors(self, i):
  155. row, col = divmod(i, COLUMNS)
  156. if col > 0 and self.blocks[i - 1] != NOBLOCK:
  157. yield i - 1
  158. if col < COLUMNS - 1 and self.blocks[i + 1] != NOBLOCK:
  159. yield i + 1
  160. if row > 0 and self.blocks[i - COLUMNS] != NOBLOCK:
  161. yield i - COLUMNS
  162. if row < self.nrows - 1 and self.blocks[i + COLUMNS] != NOBLOCK:
  163. yield i + COLUMNS
  164. def fragmentation(self, depth=FRAG_DEPTH):
  165. """
  166. Minimize the sum of dist(i,j) between all blocks i,j of the same color.
  167. Magnify vertical distances to avoid column stacking.
  168. """
  169. def dist(i, j):
  170. yi, xi = divmod(i, COLUMNS)
  171. yj, xj = divmod(j, COLUMNS)
  172. # for blocks in the same group, only count vertical distance so
  173. # that groups are spread out horizontally
  174. if groups[i] == groups[j]:
  175. return abs(yj - yi)
  176. return abs(xj - xi) + abs(yj - yi) * 2 - 1
  177. colors = {}
  178. groups = {}
  179. groupsizes = {}
  180. for groupid, (block, group) in enumerate(self.find_groups(depth, 1)):
  181. colors.setdefault(block, []).extend(group)
  182. for i in group:
  183. groups[i] = groupid
  184. groupsizes[i] = len(group)
  185. return sum(dist(i, j)
  186. for block, color in colors.items()
  187. for i, j in combinations(color, 2))
  188. def points(self):
  189. def group_size(start):
  190. work = [start]
  191. visited.add(start)
  192. size = 0
  193. block = self.blocks[start]
  194. while work:
  195. i = work.pop()
  196. # avoid giving points to moving a block within the same group
  197. if self.grabbed.get(i, None) == block:
  198. return 0
  199. if self.blocks[i] == block:
  200. size += 1
  201. for nb in self.neighbors(i):
  202. if nb not in visited:
  203. visited.add(nb)
  204. work.append(nb)
  205. return size
  206. points = 0
  207. visited = set()
  208. for i in self.placed:
  209. if i not in visited:
  210. block = self.blocks[i]
  211. size = group_size(i)
  212. if is_basic(block) and size >= MIN_BASIC_GROUP_SIZE:
  213. points += size
  214. elif is_bomb(block) and size >= MIN_BOMB_GROUP_SIZE:
  215. points += BOMB_POINTS
  216. return -points
  217. def gen_moves(self):
  218. yield self
  219. for src in self.gen_shift(not self.grabbing_or_dropping()):
  220. yield from src.gen_stationary()
  221. for get in src.gen_get():
  222. for dst in get.gen_shift(False):
  223. yield from dst.gen_put()
  224. def gen_shift(self, allow_noshift):
  225. if allow_noshift:
  226. yield self
  227. left = self
  228. for i in range(self.exa):
  229. left = left.move(LEFT)
  230. yield left
  231. right = self
  232. for i in range(COLUMNS - self.exa - 1):
  233. right = right.move(RIGHT)
  234. yield right
  235. def gen_stationary(self):
  236. # SWAP
  237. # GRAB, SWAP, DROP
  238. # GRAB, SWAP, DROP, SWAP
  239. # SWAP, GRAB, SWAP, DROP
  240. # SWAP, GRAB, SWAP, DROP, SWAP
  241. if not self.colbusy(self.exa):
  242. avail = self.colrows(self.exa)
  243. if avail >= 2:
  244. swap = self.move(SWAP)
  245. yield swap
  246. if avail >= 3:
  247. grab = self.move(GRAB, SWAP, DROP)
  248. yield grab
  249. yield grab.move(SWAP)
  250. swap = swap.move(GRAB, SWAP, DROP)
  251. yield swap
  252. yield swap.move(SWAP)
  253. def gen_get(self):
  254. # GRAB
  255. # SWAP, GRAB
  256. # GRAB, SWAP, DROP, SWAP, GRAB
  257. if not self.colbusy(self.exa):
  258. avail = self.colrows(self.exa)
  259. if avail >= 1:
  260. grab = self.move(GRAB)
  261. yield grab
  262. if avail >= 2:
  263. yield self.move(SWAP, GRAB)
  264. if avail >= 3:
  265. yield grab.move(SWAP, DROP, SWAP, GRAB)
  266. def gen_put(self):
  267. # DROP
  268. # DROP, SWAP
  269. # DROP, SWAP, GRAB, SWAP, DROP
  270. if not self.colbusy(self.exa):
  271. avail = self.colrows(self.exa)
  272. drop = self.move(DROP)
  273. yield drop
  274. if avail >= 1:
  275. swap = drop.move(SWAP)
  276. yield swap
  277. if avail >= 2:
  278. yield swap.move(GRAB, SWAP, DROP)
  279. def force(self, *moves):
  280. state = self.move(*moves)
  281. state.score = ()
  282. return state
  283. def solve(self):
  284. assert self.exa is not None
  285. if self.held != NOBLOCK:
  286. return self.force(DROP)
  287. pool = deque(self.gen_moves())
  288. if len(pool) == 0:
  289. return self.force()
  290. best_score = ()
  291. for key in self.score_keys():
  292. if len(pool) == 1:
  293. break
  294. for state in pool:
  295. state.score = key(state)
  296. best = min(state.score for state in pool)
  297. best_score += (best,)
  298. for i in range(len(pool)):
  299. state = pool.popleft()
  300. if state.score == best:
  301. pool.append(state)
  302. best = pool.popleft()
  303. best.score = best_score
  304. return best
  305. def score_keys(self):
  306. cls = self.__class__
  307. colsize = self.nrows - 2
  308. if colsize >= COLSIZE_PANIC:
  309. return cls.holes, cls.nmoves, cls.points, cls.fragmentation
  310. if colsize >= COLSIZE_PRIO:
  311. return cls.causes_panic, cls.points, cls.holes, \
  312. cls.fragmentation, cls.nmoves
  313. return cls.points, cls.fragmentation, cls.holes, cls.nmoves
  314. def print(self):
  315. print_board(self.blocks, self.exa, self.held)
  316. def tostring(self):
  317. stream = io.StringIO()
  318. with redirect_stdout(stream):
  319. self.print()
  320. return stream.getvalue()
  321. def has_same_exa(self, state):
  322. return self.exa == state.exa and self.held == state.held
  323. def nmoves(self):
  324. return len(self.moves)
  325. def delay(self):
  326. return moves_delay(self.moves)
  327. def keys(self):
  328. return moves_to_keys(self.moves)
  329. def loops(self, prev):
  330. return self.moves and \
  331. self.exa == prev.exa and \
  332. self.moves == prev.moves and \
  333. self.score == prev.score
  334. def get_colskip(blocks):
  335. def colskip(col):
  336. for row, block in enumerate(blocks[col::COLUMNS]):
  337. if block != NOBLOCK:
  338. return row
  339. return len(blocks) // COLUMNS
  340. return list(map(colskip, range(COLUMNS)))
  341. def get_busy(blocks, colskip):
  342. mask = 0
  343. for col, skip in enumerate(colskip):
  344. start = (skip + 1) * COLUMNS + col
  345. colbusy = NOBLOCK in blocks[start::COLUMNS]
  346. mask |= colbusy << col
  347. return mask
  348. def move_to_key(move):
  349. return 'jjkadl'[move]
  350. def moves_to_keys(moves):
  351. return ''.join(move_to_key(move) for move in moves)
  352. def moves_delay(moves):
  353. return sum(MOVE_DELAYS[m] for m in moves)
  354. if __name__ == '__main__':
  355. import sys
  356. from PIL import Image
  357. board = Image.open('screens/board%d.png' % int(sys.argv[1])).convert('HSV')
  358. state = State.detect(board)
  359. print('parsed:')
  360. state.print()
  361. print()
  362. start = time.time()
  363. newstate = state.solve()
  364. end = time.time()
  365. print('best move:', newstate.keys())
  366. print('score:', newstate.score)
  367. print('elapsed:', round((end - start) * 1000, 1), 'ms')
  368. print()
  369. print('target after move:')
  370. newstate.print()
  371. #print()
  372. #print('generated moves:')
  373. #for state in state.gen_moves():
  374. # print(state.keys())