strategy.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351
  1. import io
  2. import time
  3. from contextlib import redirect_stdout
  4. from itertools import combinations, islice
  5. from parse import COLUMNS, NOBLOCK, detect_blocks, detect_exa, \
  6. detect_held, print_board, is_basic, is_bomb, bomb_to_basic
  7. GRAB, DROP, SWAP, LEFT, RIGHT, SPEED = range(6)
  8. GET = ((GRAB,), (SWAP, GRAB), (GRAB, SWAP, DROP, SWAP, GRAB))
  9. PUT = ((DROP,), (DROP, SWAP), (DROP, SWAP, GRAB, SWAP, DROP))
  10. MIN_BASIC_GROUP_SIZE = 4
  11. MIN_BOMB_GROUP_SIZE = 2
  12. FIND_GROUPS_DEPTH = 3
  13. FRAG_DEPTH = 3
  14. COLSIZE_PRIO = 5
  15. COLSIZE_PANIC = 7
  16. COLSIZE_MAX = 8
  17. BOMB_POINTS = 2
  18. MIN_ROWS = 2
  19. class State:
  20. def __init__(self, blocks, exa, held):
  21. assert exa is not None
  22. self.blocks = blocks
  23. self.exa = exa
  24. self.held = held
  25. def grabbing_of_dropping(self):
  26. skip = self.colskip(self.exa)
  27. i = (skip + 1) * COLUMNS + self.exa
  28. return i < len(self.blocks) and self.blocks[i] == NOBLOCK
  29. def iter_columns(self):
  30. nrows = self.nrows()
  31. def gen_col(col):
  32. for row in range(nrows):
  33. i = row * COLUMNS + col
  34. if self.blocks[i] != NOBLOCK:
  35. yield i
  36. for col in range(COLUMNS):
  37. yield gen_col(col)
  38. @classmethod
  39. def detect(cls, board, pad=2):
  40. blocks = [NOBLOCK] * (COLUMNS * pad) + list(detect_blocks(board))
  41. exa = detect_exa(board)
  42. held = detect_held(board, exa)
  43. return cls(blocks, exa, held)
  44. def copy(self):
  45. return State(list(self.blocks), self.exa, self.held)
  46. def colsizes(self):
  47. for col in range(COLUMNS):
  48. yield self.nrows() - self.colskip(col)
  49. def empty_column_score(self):
  50. skip = 0
  51. for i, block in enumerate(self.blocks):
  52. if block != NOBLOCK:
  53. skip = i // COLUMNS
  54. break
  55. nrows = self.nrows()
  56. score = 0
  57. for col in range(COLUMNS):
  58. for row in range(skip, nrows):
  59. if self.blocks[row * COLUMNS + col] != NOBLOCK:
  60. break
  61. score += row - skip + 1
  62. return score
  63. def score(self, points, moves, prev):
  64. frag = self.fragmentation()
  65. colsize_score = self.empty_column_score()
  66. #return -points, frag + colsize_score, len(moves)
  67. frag += colsize_score
  68. prev_colsize = max(prev.colsizes())
  69. if prev_colsize >= COLSIZE_PANIC:
  70. return colsize_score, len(moves), -points, frag
  71. elif prev_colsize >= COLSIZE_PRIO:
  72. return -points, colsize_score, frag, len(moves)
  73. else:
  74. return -points, frag, colsize_score, len(moves)
  75. def score_moves(self):
  76. for moves in self.gen_moves():
  77. try:
  78. points, newstate = self.simulate(moves)
  79. yield newstate.score(points, moves, self), moves
  80. except AssertionError:
  81. pass
  82. def colskip(self, col):
  83. nrows = self.nrows()
  84. for row in range(nrows):
  85. if self.blocks[row * COLUMNS + col] != NOBLOCK:
  86. return row
  87. return nrows
  88. def find_unmovable_blocks(self):
  89. unmoveable = set()
  90. bombed = set()
  91. for block, group in self.find_groups():
  92. if is_basic(block) and len(group) >= MIN_BASIC_GROUP_SIZE:
  93. for i in group:
  94. unmoveable.add(i)
  95. elif is_bomb(block) and len(group) >= MIN_BOMB_GROUP_SIZE:
  96. bombed.add(bomb_to_basic(block))
  97. for i in group:
  98. unmoveable.add(i)
  99. for i, block in enumerate(self.blocks):
  100. if block in bombed:
  101. unmoveable.add(i)
  102. return unmoveable
  103. def simulate(self, moves):
  104. s = self.copy()
  105. points = 0
  106. # avoid swapping/grabbing currently exploding items
  107. #unmoveable = s.find_unmovable_blocks()
  108. for move in moves:
  109. if move == LEFT:
  110. assert s.exa > 0
  111. s.exa -= 1
  112. elif move == RIGHT:
  113. assert s.exa < COLUMNS - 1
  114. s.exa += 1
  115. elif move == GRAB:
  116. assert s.held == NOBLOCK
  117. row = s.colskip(s.exa)
  118. assert row < s.nrows()
  119. i = row * COLUMNS + s.exa
  120. #assert i not in unmoveable
  121. s.held = s.blocks[i]
  122. s.blocks[i] = NOBLOCK
  123. elif move == DROP:
  124. assert s.held != NOBLOCK
  125. row = s.colskip(s.exa)
  126. assert row > 0
  127. i = row * COLUMNS + s.exa
  128. s.blocks[i - COLUMNS] = s.held
  129. s.held = NOBLOCK
  130. #points += s.score_points()
  131. elif move == SWAP:
  132. row = s.colskip(s.exa)
  133. assert row < s.nrows() - 2
  134. i = row * COLUMNS + s.exa
  135. j = i + COLUMNS
  136. #assert i not in unmoveable
  137. #assert j not in unmoveable
  138. s.blocks[i], s.blocks[j] = s.blocks[j], s.blocks[i]
  139. #points += s.score_points()
  140. if moves and max(self.colsizes()) < COLSIZE_MAX:
  141. assert max(s.colsizes()) <= COLSIZE_MAX
  142. points += s.score_points()
  143. return points, s
  144. def find_groups(self, depth=FIND_GROUPS_DEPTH, minsize=2):
  145. def follow_group(i, block, group):
  146. if self.blocks[i] == block and i not in visited:
  147. group.append(i)
  148. visited.add(i)
  149. for nb in self.neighbors(i):
  150. follow_group(nb, block, group)
  151. visited = set()
  152. for col in self.iter_columns():
  153. for i in islice(col, depth):
  154. block = self.blocks[i]
  155. group = []
  156. follow_group(i, block, group)
  157. if len(group) >= minsize:
  158. yield block, group
  159. def neighbors(self, i):
  160. def gen_indices():
  161. row, col = divmod(i, COLUMNS)
  162. if col > 0:
  163. yield i - 1
  164. if col < COLUMNS - 1:
  165. yield i + 1
  166. if row > 0:
  167. yield i - COLUMNS
  168. if row < self.nrows() - 1:
  169. yield i + COLUMNS
  170. for j in gen_indices():
  171. if self.blocks[j] != NOBLOCK:
  172. yield j
  173. def fragmentation(self, depth=FRAG_DEPTH):
  174. """
  175. Minimize the sum of dist(i,j) between all blocks i,j of the same color.
  176. Magnify vertical distances to avoid column stacking.
  177. """
  178. def dist(i, j):
  179. yi, xi = divmod(i, COLUMNS)
  180. yj, xj = divmod(j, COLUMNS)
  181. # for blocks in the same group, only count vertical distance so that
  182. # groups are spread out horizontally
  183. if groups[i] == groups[j]:
  184. return abs(yj - yi)
  185. return abs(xj - xi) + abs(yj - yi) * 2 - 1
  186. colors = {}
  187. groups = {}
  188. groupsizes = {}
  189. for groupid, (block, group) in enumerate(self.find_groups(depth, 1)):
  190. colors.setdefault(block, []).extend(group)
  191. for i in group:
  192. groups[i] = groupid
  193. groupsizes[i] = len(group)
  194. return sum(dist(i, j) # * (1 + 2 * is_bomb(block))
  195. for block, color in colors.items()
  196. for i, j in combinations(color, 2))
  197. def score_points(self, multiplier=1):
  198. remove = []
  199. points = 0
  200. for block, group in self.find_groups():
  201. if is_basic(block) and len(group) >= MIN_BASIC_GROUP_SIZE:
  202. remove.extend(group)
  203. points += len(group) * multiplier
  204. elif is_bomb(block) and len(group) >= MIN_BOMB_GROUP_SIZE:
  205. points += BOMB_POINTS
  206. remove.extend(group)
  207. for i, other in enumerate(self.blocks):
  208. if other == bomb_to_basic(block):
  209. remove.append(i)
  210. remove.sort()
  211. prev = None
  212. for i in remove:
  213. if i != prev:
  214. while self.blocks[i] != NOBLOCK:
  215. self.blocks[i] = self.blocks[i - COLUMNS]
  216. i -= COLUMNS
  217. prev = i
  218. if points:
  219. points += self.score_points(min(2, multiplier * 2))
  220. return points
  221. def has_explosion(self):
  222. return any(is_bomb(block) and
  223. any(self.blocks[j] == block for j in self.neighbors(i))
  224. for i, block in enumerate(self.blocks))
  225. def gen_moves(self):
  226. yield ()
  227. def make_move(diff):
  228. direction = RIGHT if diff > 0 else LEFT
  229. return abs(diff) * (direction,)
  230. for src in range(COLUMNS):
  231. mov1 = make_move(src - self.exa)
  232. yield mov1 + (SWAP,)
  233. yield mov1 + (GRAB, SWAP, DROP)
  234. yield mov1 + (SWAP, GRAB, SWAP, DROP)
  235. for dst in range(COLUMNS):
  236. if dst != src:
  237. mov2 = make_move(dst - src)
  238. for get in GET:
  239. for put in PUT:
  240. yield mov1 + get + mov2 + put
  241. def solve(self):
  242. if self.held != NOBLOCK:
  243. return (DROP,)
  244. if self.nrows() < MIN_ROWS:
  245. return ()
  246. if self.grabbing_of_dropping():
  247. return ()
  248. if self.has_explosion():
  249. return ()
  250. score, moves = min(self.score_moves())
  251. if not moves:
  252. return (SPEED,)
  253. return moves
  254. def print(self):
  255. print_board(self.blocks, self.exa, self.held)
  256. def tostring(self):
  257. stream = io.StringIO()
  258. with redirect_stdout(stream):
  259. self.print()
  260. return stream.getvalue()
  261. def nrows(self):
  262. return len(self.blocks) // COLUMNS
  263. def moves_to_keys(moves):
  264. return ''.join('jjkadl'[move] for move in moves)
  265. if __name__ == '__main__':
  266. import sys
  267. from PIL import Image
  268. board = Image.open('screens/board%d.png' % int(sys.argv[1])).convert('HSV')
  269. state = State.detect(board)
  270. print('parsed:')
  271. state.print()
  272. print()
  273. print('empty cols:', state.empty_column_score())
  274. print()
  275. start = time.time()
  276. moves = state.solve()
  277. end = time.time()
  278. print('moves:', moves_to_keys(moves))
  279. print('elapsed:', end - start)
  280. print()
  281. print('target after moves:')
  282. points, newstate = state.simulate(moves)
  283. newstate.print()
  284. print()
  285. for score, moves in sorted(state.score_moves()):
  286. print('move %18s:' % moves_to_keys(moves), score)