# strategy.py (9.8 KB)
  1. import time
  2. from collections import deque
  3. from itertools import islice
  4. from parser import COLUMNS, NOBLOCK, detect_blocks, detect_exa, \
  5. detect_held, print_board, is_basic, is_bomb, bomb_to_basic
  6. GRAB, DROP, SWAP, LEFT, RIGHT = range(5)
  7. GET = ((GRAB,), (SWAP, GRAB), (GRAB, SWAP, DROP, SWAP, GRAB))
  8. PUT = ((DROP,), (DROP, SWAP), (DROP, SWAP, GRAB, SWAP, DROP))
  9. #REVERSE = [DROP, GRAB, SWAP, RIGHT, LEFT]
  10. MIN_BASIC_GROUP_SIZE = 4
  11. MIN_BOMB_GROUP_SIZE = 2
  12. FIND_GROUPS_DEPTH = 20
  13. FRAG_DEPTH = 20
  14. COLSIZE_PRIO = 5
  15. class State:
  16. def __init__(self, blocks, exa, held):
  17. assert exa is not None
  18. self.blocks = blocks
  19. self.exa = exa
  20. self.held = held
  21. def has_holes(self):
  22. return any(self.blocks[i] == NOBLOCK
  23. for col in self.iter_columns()
  24. for i in col)
  25. def iter_columns(self):
  26. nrows = self.nrows()
  27. def gen_col(col):
  28. for row in range(nrows):
  29. i = row * COLUMNS + col
  30. if self.blocks[i] != NOBLOCK:
  31. yield i
  32. for col in range(COLUMNS):
  33. yield gen_col(col)
  34. @classmethod
  35. def detect(cls, board):
  36. blocks = [NOBLOCK] * (COLUMNS * 2) + list(detect_blocks(board))
  37. exa = detect_exa(board)
  38. held = detect_held(board, exa)
  39. return cls(blocks, exa, held)
  40. def copy(self):
  41. return State(list(self.blocks), self.exa, self.held)
  42. def colsizes(self):
  43. for col in range(COLUMNS):
  44. yield self.nrows() - self.colskip(col)
  45. def hole_score(self):
  46. score = 0
  47. for col in range(COLUMNS):
  48. for row in range(self.nrows()):
  49. i = row * COLUMNS + col
  50. if self.blocks[i] != NOBLOCK:
  51. break
  52. score += row + 1
  53. return score
  54. def score(self, points, moves, prev):
  55. frag = self.fragmentation()
  56. colsizes = list(self.colsizes())
  57. #mincol = min(colsizes)
  58. maxcol = max(colsizes)
  59. colsize_score = maxcol, colsizes.count(maxcol) #, -mincol
  60. #colsize_score = tuple(sorted(colsizes, reverse=True))
  61. #if prev.nrows() >= 6:
  62. # return colsize_score, -points, frag, len(moves)
  63. colsize_score = maxcol, self.hole_score()
  64. prev_colsize = max(prev.colsizes())
  65. if prev_colsize >= COLSIZE_PRIO:
  66. return colsize_score, frag, -points, len(moves)
  67. else:
  68. return -points, frag, colsize_score, len(moves)
  69. def score_moves(self):
  70. # clear exploding blocks before computing colsize
  71. prev = self.copy()
  72. prev.score_points()
  73. for moves in self.gen_moves():
  74. try:
  75. points, newstate = self.simulate(moves)
  76. yield newstate.score(points, moves, prev), moves
  77. except AssertionError:
  78. pass
  79. def colskip(self, col):
  80. nrows = self.nrows()
  81. for row in range(nrows):
  82. if self.blocks[row * COLUMNS + col] != NOBLOCK:
  83. return row
  84. return nrows
  85. def find_unmovable_blocks(self):
  86. unmoveable = set()
  87. bombed = set()
  88. for block, group in self.find_groups():
  89. if is_basic(block) and len(group) >= MIN_BASIC_GROUP_SIZE:
  90. for i in group:
  91. unmoveable.add(i)
  92. elif is_bomb(block) and len(group) >= MIN_BOMB_GROUP_SIZE:
  93. bombed.add(bomb_to_basic(block))
  94. for i in group:
  95. unmoveable.add(i)
  96. for i, block in enumerate(self.blocks):
  97. if block in bombed:
  98. unmoveable.add(i)
  99. return unmoveable
  100. def simulate(self, moves):
  101. s = self.copy()
  102. points = 0
  103. # clear the current board before planning the next move
  104. #s.score_points()
  105. if not moves:
  106. return s.score_points(), s
  107. # avoid swapping/grabbing currently exploding items
  108. unmoveable = s.find_unmovable_blocks()
  109. for move in moves:
  110. if move == LEFT:
  111. assert s.exa > 0
  112. s.exa -= 1
  113. elif move == RIGHT:
  114. assert s.exa < COLUMNS - 1
  115. s.exa += 1
  116. elif move == GRAB:
  117. assert s.held == NOBLOCK
  118. row = s.colskip(s.exa)
  119. assert row < s.nrows()
  120. i = row * COLUMNS + s.exa
  121. assert i not in unmoveable
  122. s.held = s.blocks[i]
  123. s.blocks[i] = NOBLOCK
  124. elif move == DROP:
  125. assert s.held != NOBLOCK
  126. row = s.colskip(s.exa)
  127. assert row > 0
  128. i = row * COLUMNS + s.exa
  129. s.blocks[i - COLUMNS] = s.held
  130. s.held = NOBLOCK
  131. points += s.score_points()
  132. elif move == SWAP:
  133. row = s.colskip(s.exa)
  134. assert row < s.nrows() - 2
  135. i = row * COLUMNS + s.exa
  136. j = i + COLUMNS
  137. assert i not in unmoveable
  138. assert j not in unmoveable
  139. s.blocks[i], s.blocks[j] = s.blocks[j], s.blocks[i]
  140. points += s.score_points()
  141. return points, s
  142. def find_groups(self, depth=FIND_GROUPS_DEPTH):
  143. def follow_group(i, block, group):
  144. if self.blocks[i] == block and i not in visited:
  145. group.append(i)
  146. visited.add(i)
  147. for nb in self.neighbors(i):
  148. follow_group(nb, block, group)
  149. visited = set()
  150. for col in self.iter_columns():
  151. for i in islice(col, depth):
  152. block = self.blocks[i]
  153. group = []
  154. follow_group(i, block, group)
  155. if len(group) > 1:
  156. #for j in group:
  157. # visited.add(j)
  158. yield block, group
  159. def neighbors(self, i):
  160. def gen_indices():
  161. row, col = divmod(i, COLUMNS)
  162. if col > 0:
  163. yield i - 1
  164. if col < COLUMNS - 1:
  165. yield i + 1
  166. if row > 0:
  167. yield i - COLUMNS
  168. if row < self.nrows() - 1:
  169. yield i + COLUMNS
  170. for j in gen_indices():
  171. if self.blocks[j] != NOBLOCK:
  172. yield j
  173. def fragmentation(self, depth=FRAG_DEPTH):
  174. """
  175. Sum the minimum distance from every block in the first 3 layers to the
  176. closest block of the same color.
  177. """
  178. def find_closest(i):
  179. block = self.blocks[i]
  180. work = deque([(i, -1)])
  181. visited = {i}
  182. while work:
  183. i, dist = work.popleft()
  184. if dist >= 0 and self.blocks[i] == block:
  185. return dist
  186. for j in self.neighbors(i):
  187. if j not in visited:
  188. visited.add(j)
  189. work.append((j, dist + 1))
  190. # only one of this type -> don't count as fragmented
  191. return 0
  192. return sum(find_closest(i)
  193. for col in self.iter_columns()
  194. for i in islice(col, depth))
  195. def score_points(self, multiplier=1):
  196. remove = []
  197. points = 0
  198. for block, group in self.find_groups():
  199. if is_basic(block) and len(group) >= MIN_BASIC_GROUP_SIZE:
  200. remove.extend(group)
  201. points += len(group) * multiplier
  202. elif is_bomb(block) and len(group) >= MIN_BOMB_GROUP_SIZE:
  203. #points += 10
  204. remove.extend(group)
  205. for i, other in enumerate(self.blocks):
  206. if other == bomb_to_basic(block):
  207. remove.append(i)
  208. remove.sort()
  209. prev = None
  210. for i in remove:
  211. if i != prev:
  212. while self.blocks[i] != NOBLOCK:
  213. self.blocks[i] = self.blocks[i - COLUMNS]
  214. i -= COLUMNS
  215. prev = i
  216. if points:
  217. points += self.score_points(min(2, multiplier * 2))
  218. return points
  219. def gen_moves(self):
  220. yield ()
  221. for src in range(COLUMNS):
  222. diff = src - self.exa
  223. direction = RIGHT if diff > 0 else LEFT
  224. mov1 = abs(diff) * (direction,)
  225. yield mov1 + (SWAP,)
  226. yield mov1 + (GRAB, SWAP, DROP)
  227. for dst in range(COLUMNS):
  228. diff = dst - src
  229. direction = RIGHT if diff > 0 else LEFT
  230. mov2 = abs(diff) * (direction,)
  231. for i, get in enumerate(GET):
  232. if i > 1 or diff != 0:
  233. for put in PUT:
  234. yield mov1 + get + mov2 + put
  235. def solve(self):
  236. if self.held != NOBLOCK:
  237. return (DROP,)
  238. score, moves = min(self.score_moves())
  239. return moves
  240. def print(self):
  241. print_board(self.blocks, self.exa, self.held)
  242. def nrows(self):
  243. return len(self.blocks) // COLUMNS
  244. def moves_to_keys(moves):
  245. return ''.join('jjkad'[move] for move in moves)
  246. if __name__ == '__main__':
  247. import sys
  248. from PIL import Image
  249. #from pprint import pprint
  250. board = Image.open('screens/board%d.png' % int(sys.argv[1])).convert('HSV')
  251. state = State.detect(board)
  252. print('parsed:')
  253. state.print()
  254. print()
  255. start = time.time()
  256. moves = state.solve()
  257. print('moves:', moves_to_keys(moves))
  258. end = time.time()
  259. print('elapsed:', end - start)
  260. print()
  261. print('target after moves:')
  262. points, newstate = state.simulate(moves)
  263. newstate.print()
  264. print()
  265. #for score, moves in sorted(state.score_moves()):
  266. # print('move %18s:' % moves_to_keys(moves), score)
  267. # #print('moves:', moves_to_keys(moves), moves)
  268. # #print('score:', score)
  269. #print('\nmoves:', moves_to_keys(state.solve()))