strategy.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361
  1. import io
  2. import time
  3. from contextlib import redirect_stdout
  4. from itertools import combinations, islice
  5. from parse import COLUMNS, NOBLOCK, detect_blocks, detect_exa, \
  6. detect_held, print_board, is_basic, is_bomb, bomb_to_basic
  7. GRAB, DROP, SWAP, LEFT, RIGHT, SPEED = range(6)
  8. GET = ((GRAB,), (SWAP, GRAB), (GRAB, SWAP, DROP, SWAP, GRAB))
  9. PUT = ((DROP,), (DROP, SWAP), (DROP, SWAP, GRAB, SWAP, DROP))
  10. MIN_BASIC_GROUP_SIZE = 4
  11. MIN_BOMB_GROUP_SIZE = 2
  12. FIND_GROUPS_DEPTH = 3
  13. FRAG_DEPTH = 3
  14. DEFRAG_PRIO = 3
  15. COLSIZE_PRIO = 5
  16. COLSIZE_PANIC = 7
  17. COLSIZE_MAX = 8
  18. BOMB_POINTS = 1
  19. MIN_ROWS = 2
  20. MAX_SPEED_ROWS = 3
  21. class State:
  22. def __init__(self, blocks, exa, held):
  23. self.blocks = blocks
  24. self.exa = exa
  25. self.held = held
  26. def grabbing_of_dropping(self):
  27. skip = self.colskip(self.exa)
  28. i = (skip + 1) * COLUMNS + self.exa
  29. return i < len(self.blocks) and self.blocks[i] == NOBLOCK
  30. def iter_columns(self):
  31. nrows = self.nrows()
  32. def gen_col(col):
  33. for row in range(nrows):
  34. i = row * COLUMNS + col
  35. if self.blocks[i] != NOBLOCK:
  36. yield i
  37. for col in range(COLUMNS):
  38. yield gen_col(col)
  39. @classmethod
  40. def detect(cls, board, pad=2):
  41. blocks = [NOBLOCK] * (COLUMNS * pad) + list(detect_blocks(board))
  42. exa = detect_exa(board)
  43. held = detect_held(board, exa)
  44. return cls(blocks, exa, held)
  45. def copy(self):
  46. return State(list(self.blocks), self.exa, self.held)
  47. def colsizes(self):
  48. for col in range(COLUMNS):
  49. yield self.nrows() - self.colskip(col)
  50. def colsize_panic(self):
  51. return int(max(self.colsizes()) >= COLSIZE_PANIC)
  52. def empty_column_score(self):
  53. skip = 0
  54. for i, block in enumerate(self.blocks):
  55. if block != NOBLOCK:
  56. skip = i // COLUMNS
  57. break
  58. nrows = self.nrows()
  59. score = 0
  60. for col in range(COLUMNS):
  61. for row in range(skip, nrows):
  62. if self.blocks[row * COLUMNS + col] != NOBLOCK:
  63. break
  64. score += row - skip + 1
  65. return score
  66. def score(self, points, moves, prev):
  67. prev_colsize = max(prev.colsizes())
  68. points = self.score_points()
  69. if prev_colsize >= COLSIZE_PANIC:
  70. colsize = self.empty_column_score()
  71. #frag = self.fragmentation()
  72. return colsize, len(moves), -points #, frag
  73. elif prev_colsize >= DEFRAG_PRIO:
  74. colsize = self.empty_column_score()
  75. frag = self.fragmentation()
  76. panic = self.colsize_panic()
  77. return -points, panic, frag, colsize, len(moves)
  78. elif prev_colsize >= COLSIZE_PRIO:
  79. colsize = self.empty_column_score()
  80. return -points, colsize, len(moves)
  81. else:
  82. return -points, len(moves)
  83. def score_moves(self):
  84. for moves in self.gen_moves():
  85. try:
  86. points, newstate = self.simulate(moves)
  87. yield newstate.score(points, moves, self), moves
  88. except AssertionError:
  89. pass
  90. def colskip(self, col):
  91. nrows = self.nrows()
  92. for row in range(nrows):
  93. if self.blocks[row * COLUMNS + col] != NOBLOCK:
  94. return row
  95. return nrows
  96. def find_unmovable_blocks(self):
  97. unmoveable = set()
  98. bombed = set()
  99. for block, group in self.find_groups():
  100. if is_basic(block) and len(group) >= MIN_BASIC_GROUP_SIZE:
  101. for i in group:
  102. unmoveable.add(i)
  103. elif is_bomb(block) and len(group) >= MIN_BOMB_GROUP_SIZE:
  104. bombed.add(bomb_to_basic(block))
  105. for i in group:
  106. unmoveable.add(i)
  107. for i, block in enumerate(self.blocks):
  108. if block in bombed:
  109. unmoveable.add(i)
  110. return unmoveable
  111. def simulate(self, moves):
  112. s = self.copy()
  113. points = 0
  114. # avoid swapping/grabbing currently exploding items
  115. #unmoveable = s.find_unmovable_blocks()
  116. for move in moves:
  117. if move == LEFT:
  118. assert s.exa > 0
  119. s.exa -= 1
  120. elif move == RIGHT:
  121. assert s.exa < COLUMNS - 1
  122. s.exa += 1
  123. elif move == GRAB:
  124. assert s.held == NOBLOCK
  125. row = s.colskip(s.exa)
  126. assert row < s.nrows()
  127. i = row * COLUMNS + s.exa
  128. #assert i not in unmoveable
  129. s.held = s.blocks[i]
  130. s.blocks[i] = NOBLOCK
  131. elif move == DROP:
  132. assert s.held != NOBLOCK
  133. row = s.colskip(s.exa)
  134. assert row > 0
  135. i = row * COLUMNS + s.exa
  136. s.blocks[i - COLUMNS] = s.held
  137. s.held = NOBLOCK
  138. #points += s.score_points()
  139. elif move == SWAP:
  140. row = s.colskip(s.exa)
  141. assert row < s.nrows() - 2
  142. i = row * COLUMNS + s.exa
  143. j = i + COLUMNS
  144. #assert i not in unmoveable
  145. #assert j not in unmoveable
  146. s.blocks[i], s.blocks[j] = s.blocks[j], s.blocks[i]
  147. #points += s.score_points()
  148. if moves and max(self.colsizes()) < COLSIZE_MAX:
  149. assert max(s.colsizes()) <= COLSIZE_MAX
  150. return points, s
  151. def find_groups(self, depth=FIND_GROUPS_DEPTH, minsize=2):
  152. def follow_group(i, block, group):
  153. if self.blocks[i] == block and i not in visited:
  154. group.append(i)
  155. visited.add(i)
  156. for nb in self.neighbors(i):
  157. follow_group(nb, block, group)
  158. visited = set()
  159. for col in self.iter_columns():
  160. for i in islice(col, depth):
  161. block = self.blocks[i]
  162. group = []
  163. follow_group(i, block, group)
  164. if len(group) >= minsize:
  165. yield block, group
  166. def neighbors(self, i):
  167. def gen_indices():
  168. row, col = divmod(i, COLUMNS)
  169. if col > 0:
  170. yield i - 1
  171. if col < COLUMNS - 1:
  172. yield i + 1
  173. if row > 0:
  174. yield i - COLUMNS
  175. if row < self.nrows() - 1:
  176. yield i + COLUMNS
  177. for j in gen_indices():
  178. if self.blocks[j] != NOBLOCK:
  179. yield j
  180. def fragmentation(self, depth=FRAG_DEPTH):
  181. """
  182. Minimize the sum of dist(i,j) between all blocks i,j of the same color.
  183. Magnify vertical distances to avoid column stacking.
  184. """
  185. def dist(i, j):
  186. yi, xi = divmod(i, COLUMNS)
  187. yj, xj = divmod(j, COLUMNS)
  188. # for blocks in the same group, only count vertical distance so that
  189. # groups are spread out horizontally
  190. if groups[i] == groups[j]:
  191. return abs(yj - yi)
  192. return abs(xj - xi) + abs(yj - yi) * 2 - 1
  193. colors = {}
  194. groups = {}
  195. groupsizes = {}
  196. for groupid, (block, group) in enumerate(self.find_groups(depth, 1)):
  197. colors.setdefault(block, []).extend(group)
  198. for i in group:
  199. groups[i] = groupid
  200. groupsizes[i] = len(group)
  201. return sum(dist(i, j) # * (1 + 2 * is_bomb(block))
  202. for block, color in colors.items()
  203. for i, j in combinations(color, 2))
  204. def score_points(self, multiplier=1):
  205. #remove = []
  206. points = 0
  207. for block, group in self.find_groups():
  208. if is_basic(block) and len(group) >= MIN_BASIC_GROUP_SIZE:
  209. #remove.extend(group)
  210. points += len(group) * multiplier
  211. elif is_bomb(block) and len(group) >= MIN_BOMB_GROUP_SIZE:
  212. points += BOMB_POINTS
  213. #remove.extend(group)
  214. #for i, other in enumerate(self.blocks):
  215. # if other == bomb_to_basic(block):
  216. # remove.append(i)
  217. #remove.sort()
  218. #prev = None
  219. #for i in remove:
  220. # if i != prev:
  221. # while self.blocks[i] != NOBLOCK:
  222. # self.blocks[i] = self.blocks[i - COLUMNS]
  223. # i -= COLUMNS
  224. # prev = i
  225. #if points:
  226. # points += self.score_points(min(2, multiplier * 2))
  227. return points
  228. def has_explosion(self):
  229. return any(is_bomb(block) and
  230. any(self.blocks[j] == block for j in self.neighbors(i))
  231. for i, block in enumerate(self.blocks))
  232. def gen_moves(self):
  233. yield ()
  234. def make_move(diff):
  235. direction = RIGHT if diff > 0 else LEFT
  236. return abs(diff) * (direction,)
  237. for src in range(COLUMNS):
  238. mov1 = make_move(src - self.exa)
  239. yield mov1 + (SWAP,)
  240. yield mov1 + (GRAB, SWAP, DROP)
  241. yield mov1 + (SWAP, GRAB, SWAP, DROP)
  242. for dst in range(COLUMNS):
  243. if dst != src:
  244. mov2 = make_move(dst - src)
  245. for get in GET:
  246. for put in PUT:
  247. yield mov1 + get + mov2 + put
  248. def solve(self):
  249. assert self.exa is not None
  250. if self.held != NOBLOCK:
  251. return (DROP,)
  252. if self.nrows() < MIN_ROWS:
  253. return ()
  254. if self.grabbing_of_dropping():
  255. return ()
  256. if self.has_explosion():
  257. return ()
  258. score, moves = min(self.score_moves())
  259. if not moves and max(self.colsizes()) <= MAX_SPEED_ROWS:
  260. return (SPEED,)
  261. return moves
  262. def print(self):
  263. print_board(self.blocks, self.exa, self.held)
  264. def tostring(self):
  265. stream = io.StringIO()
  266. with redirect_stdout(stream):
  267. self.print()
  268. return stream.getvalue()
  269. def nrows(self):
  270. return len(self.blocks) // COLUMNS
  271. def moves_to_keys(moves):
  272. return ''.join('jjkadl'[move] for move in moves)
  273. if __name__ == '__main__':
  274. import sys
  275. from PIL import Image
  276. board = Image.open('screens/board%d.png' % int(sys.argv[1])).convert('HSV')
  277. state = State.detect(board)
  278. print('parsed:')
  279. state.print()
  280. print()
  281. print('empty cols:', state.empty_column_score())
  282. print()
  283. start = time.time()
  284. moves = state.solve()
  285. end = time.time()
  286. print('moves:', moves_to_keys(moves))
  287. print('elapsed:', end - start)
  288. print()
  289. print('target after moves:')
  290. points, newstate = state.simulate(moves)
  291. newstate.print()
  292. print()
  293. for score, moves in sorted(state.score_moves()):
  294. print('move %18s:' % moves_to_keys(moves), score)