strategy.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452
  1. import io
  2. import time
  3. from collections import deque
  4. from contextlib import redirect_stdout
  5. from itertools import combinations, islice
  6. from detection import COLUMNS, NOBLOCK, detect_blocks, detect_exa, \
  7. detect_held, print_board, is_basic, is_bomb, \
  8. bomb_to_basic
  9. GRAB, DROP, SWAP, LEFT, RIGHT, SPEED = range(6)
  10. MOVE_DELAYS = (
  11. # in milliseconds
  12. 50, # GRAB
  13. 50, # DROP
  14. 50, # SWAP
  15. 30, # LEFT
  16. 30, # RIGHT
  17. 30, # SPEED
  18. )
  19. GET = ((GRAB,), (SWAP, GRAB), (GRAB, SWAP, DROP, SWAP, GRAB))
  20. PUT = ((DROP,), (DROP, SWAP), (DROP, SWAP, GRAB, SWAP, DROP))
  21. MIN_BASIC_GROUP_SIZE = 4
  22. MIN_BOMB_GROUP_SIZE = 2
  23. POINTS_DEPTH = 3
  24. FRAG_DEPTH = 4
  25. DEFRAG_PRIO = 4
  26. COLSIZE_PRIO = 5
  27. COLSIZE_PANIC = 8
  28. COLSIZE_MAX = 9
  29. BOMB_POINTS = 5
  30. class State:
  31. def __init__(self, blocks, exa, held, colskip=None):
  32. self.blocks = blocks
  33. self.exa = exa
  34. self.held = held
  35. self.moves = ()
  36. self.score = ()
  37. self.nrows = len(self.blocks) // COLUMNS
  38. if colskip is None:
  39. colskip = []
  40. for col in range(COLUMNS):
  41. for row in range(self.nrows):
  42. if self.blocks[row * COLUMNS + col] != NOBLOCK:
  43. colskip.append(row)
  44. break
  45. else:
  46. colskip.append(self.nrows)
  47. self.colskip = colskip
  48. def grabbing_or_dropping(self):
  49. skip = self.colskip[self.exa]
  50. i = (skip + 1) * COLUMNS + self.exa
  51. return i < len(self.blocks) and self.blocks[i] == NOBLOCK
  52. def iter_columns(self):
  53. def gen_col(col):
  54. for row in range(self.nrows):
  55. i = row * COLUMNS + col
  56. if self.blocks[i] != NOBLOCK:
  57. yield i
  58. for col in range(COLUMNS):
  59. yield gen_col(col)
  60. @classmethod
  61. def detect(cls, board, pad=2):
  62. blocks = [NOBLOCK] * (COLUMNS * pad) + list(detect_blocks(board))
  63. exa = detect_exa(board)
  64. held = detect_held(board, exa)
  65. return cls(blocks, exa, held)
  66. def copy(self):
  67. return State(list(self.blocks), self.exa, self.held, list(self.colskip))
  68. def causes_panic(self):
  69. return self.max_colsize() >= COLSIZE_PANIC
  70. def max_colsize(self):
  71. return self.nrows - self.empty_rows()
  72. def empty_rows(self):
  73. for i, block in enumerate(self.blocks):
  74. if block != NOBLOCK:
  75. return i // COLUMNS
  76. return 0
  77. def holes(self):
  78. start_row = self.empty_rows()
  79. score = 0
  80. for col in range(COLUMNS):
  81. for row in range(start_row, self.nrows):
  82. if self.blocks[row * COLUMNS + col] != NOBLOCK:
  83. break
  84. score += row - start_row + 1
  85. return score
  86. def score(self, points, moves, prev):
  87. prev_colsize = prev.nrows - 2
  88. #delay = moves_delay(moves)
  89. delay = len(moves)
  90. # Don't care about defragging for few rows, just score points quickly.
  91. # This saves computation time which in turn makes for nice combos when
  92. # the bot speeds around throwing blocks on top of each other.
  93. if prev_colsize < DEFRAG_PRIO:
  94. return -points, delay
  95. holes = self.holes()
  96. frag = 0 if points else self.fragmentation()
  97. # When rows start stacking up, start defragmenting colors to make
  98. # opportunities for scoring points.
  99. if prev_colsize < COLSIZE_PRIO:
  100. return -points, frag, holes, delay
  101. # When they stack higher, start moving blocks down into holes before
  102. # continuing to defragment.
  103. if prev_colsize < COLSIZE_PANIC:
  104. panic = int(self.causes_panic())
  105. return panic, -points, holes, frag, delay
  106. # Column heights are getting out of hand, just move shit DOWN.
  107. return holes, delay, -points, frag
  108. def find_unmovable_blocks(self):
  109. unmoveable = set()
  110. bombed = set()
  111. for block, group in self.find_groups():
  112. if is_basic(block) and len(group) >= MIN_BASIC_GROUP_SIZE:
  113. for i in group:
  114. unmoveable.add(i)
  115. elif is_bomb(block) and len(group) >= MIN_BOMB_GROUP_SIZE:
  116. bombed.add(bomb_to_basic(block))
  117. for i in group:
  118. unmoveable.add(i)
  119. for i, block in enumerate(self.blocks):
  120. if block in bombed:
  121. unmoveable.add(i)
  122. return unmoveable
  123. def move(self, moves):
  124. s = self.copy() if moves else self
  125. s.moves = moves
  126. # avoid swapping/grabbing currently exploding items
  127. #unmoveable = s.find_unmovable_blocks()
  128. s.placed = set()
  129. s.grabbed = {}
  130. for move in moves:
  131. if move == LEFT:
  132. assert s.exa > 0
  133. s.exa -= 1
  134. elif move == RIGHT:
  135. assert s.exa < COLUMNS - 1
  136. s.exa += 1
  137. elif move == GRAB:
  138. assert s.held == NOBLOCK
  139. row = s.colskip[s.exa]
  140. assert row < s.nrows
  141. i = row * COLUMNS + s.exa
  142. #assert i not in unmoveable
  143. s.held = s.blocks[i]
  144. s.blocks[i] = NOBLOCK
  145. s.grabbed[i] = s.held
  146. s.colskip[s.exa] += 1
  147. elif move == DROP:
  148. assert s.held != NOBLOCK
  149. row = s.colskip[s.exa]
  150. assert row > 0
  151. i = (row - 1) * COLUMNS + s.exa
  152. s.blocks[i] = s.held
  153. s.held = NOBLOCK
  154. s.placed.add(i)
  155. s.colskip[s.exa] -= 1
  156. elif move == SWAP:
  157. row = s.colskip[s.exa]
  158. i = row * COLUMNS + s.exa
  159. j = i + COLUMNS
  160. assert j < len(s.blocks)
  161. #assert i not in unmoveable
  162. #assert j not in unmoveable
  163. bi = s.blocks[i]
  164. bj = s.blocks[j]
  165. if bi != bj:
  166. s.blocks[i] = bj
  167. s.blocks[j] = bi
  168. s.grabbed[i] = bi
  169. s.grabbed[j] = bj
  170. s.placed.add(i)
  171. s.placed.add(j)
  172. if moves and self.max_colsize() < COLSIZE_MAX:
  173. assert s.max_colsize() <= COLSIZE_MAX
  174. return s
  175. def find_groups(self, depth=POINTS_DEPTH, minsize=2):
  176. def follow_group(i, block, group):
  177. if self.blocks[i] == block and i not in visited:
  178. group.append(i)
  179. visited.add(i)
  180. for nb in self.neighbors(i):
  181. follow_group(nb, block, group)
  182. visited = set()
  183. for col in self.iter_columns():
  184. for i in islice(col, depth):
  185. block = self.blocks[i]
  186. group = []
  187. follow_group(i, block, group)
  188. if len(group) >= minsize:
  189. yield block, group
  190. def neighbors(self, i):
  191. row, col = divmod(i, COLUMNS)
  192. if col > 0 and self.blocks[i - 1] != NOBLOCK:
  193. yield i - 1
  194. if col < COLUMNS - 1 and self.blocks[i + 1] != NOBLOCK:
  195. yield i + 1
  196. if row > 0 and self.blocks[i - COLUMNS] != NOBLOCK:
  197. yield i - COLUMNS
  198. if row < self.nrows - 1 and self.blocks[i + COLUMNS] != NOBLOCK:
  199. yield i + COLUMNS
  200. def fragmentation(self, depth=FRAG_DEPTH):
  201. """
  202. Minimize the sum of dist(i,j) between all blocks i,j of the same color.
  203. Magnify vertical distances to avoid column stacking.
  204. """
  205. def dist(i, j):
  206. yi, xi = divmod(i, COLUMNS)
  207. yj, xj = divmod(j, COLUMNS)
  208. # for blocks in the same group, only count vertical distance so
  209. # that groups are spread out horizontally
  210. if groups[i] == groups[j]:
  211. return abs(yj - yi)
  212. return abs(xj - xi) + abs(yj - yi) * 2 - 1
  213. colors = {}
  214. groups = {}
  215. groupsizes = {}
  216. for groupid, (block, group) in enumerate(self.find_groups(depth, 1)):
  217. colors.setdefault(block, []).extend(group)
  218. for i in group:
  219. groups[i] = groupid
  220. groupsizes[i] = len(group)
  221. return sum(dist(i, j)
  222. for block, color in colors.items()
  223. for i, j in combinations(color, 2))
  224. def points(self):
  225. def group_size(start):
  226. work = [start]
  227. visited.add(start)
  228. size = 0
  229. block = self.blocks[start]
  230. while work:
  231. i = work.pop()
  232. # avoid giving points to moving a block within the same group
  233. if self.grabbed.get(i, None) == block:
  234. return 0
  235. if self.blocks[i] == block:
  236. size += 1
  237. for nb in self.neighbors(i):
  238. if nb not in visited:
  239. visited.add(nb)
  240. work.append(nb)
  241. return size
  242. points = 0
  243. visited = set()
  244. for i in self.placed:
  245. if i not in visited:
  246. block = self.blocks[i]
  247. size = group_size(i)
  248. if is_basic(block) and size >= MIN_BASIC_GROUP_SIZE:
  249. points += size
  250. elif is_bomb(block) and size >= MIN_BOMB_GROUP_SIZE:
  251. points += BOMB_POINTS
  252. return -points
  253. def has_explosion(self):
  254. return any(is_bomb(block) and
  255. any(self.blocks[j] == block for j in self.neighbors(i))
  256. for i, block in enumerate(self.blocks))
  257. def gen_moves(self):
  258. yield ()
  259. def shift_exa(diff):
  260. direction = RIGHT if diff > 0 else LEFT
  261. return abs(diff) * (direction,)
  262. ignore_exa_column = self.grabbing_or_dropping()
  263. for src in range(COLUMNS):
  264. mov1 = shift_exa(src - self.exa)
  265. if mov1 or not ignore_exa_column:
  266. yield mov1 + (SWAP,)
  267. yield mov1 + (GRAB, SWAP, DROP)
  268. yield mov1 + (SWAP, GRAB, SWAP, DROP)
  269. yield mov1 + (GRAB, SWAP, DROP, SWAP)
  270. yield mov1 + (SWAP, GRAB, SWAP, DROP, SWAP)
  271. for dst in range(COLUMNS):
  272. if dst != src:
  273. mov2 = shift_exa(dst - src)
  274. for get in GET:
  275. for put in PUT:
  276. yield mov1 + get + mov2 + put
  277. def gen_valid_moves(self):
  278. for moves in self.gen_moves():
  279. try:
  280. yield self.move(moves)
  281. except AssertionError:
  282. pass
  283. def solve(self):
  284. assert self.exa is not None
  285. if self.held != NOBLOCK:
  286. return self.move((DROP,))
  287. valid = deque(self.gen_valid_moves())
  288. if len(valid) == 0:
  289. return self.move(())
  290. best_score = ()
  291. for key in self.score_keys():
  292. if len(valid) == 1:
  293. break
  294. for state in valid:
  295. state.score = key(state)
  296. best = min(state.score for state in valid)
  297. best_score += (best,)
  298. for i in range(len(valid)):
  299. state = valid.popleft()
  300. if state.score == best:
  301. valid.append(state)
  302. best = valid.popleft()
  303. best.score = best_score
  304. return best
  305. def score_keys(self):
  306. cls = self.__class__
  307. colsize = self.nrows - 2
  308. if colsize >= COLSIZE_PANIC:
  309. return cls.holes, cls.nmoves, cls.points, cls.fragmentation
  310. if colsize >= COLSIZE_PRIO:
  311. return cls.causes_panic, cls.points, cls.holes, cls.fragmentation, cls.nmoves
  312. return cls.points, cls.fragmentation, cls.holes, cls.nmoves
  313. def print(self):
  314. print_board(self.blocks, self.exa, self.held)
  315. def tostring(self):
  316. stream = io.StringIO()
  317. with redirect_stdout(stream):
  318. self.print()
  319. return stream.getvalue()
  320. def has_same_exa(self, state):
  321. return self.exa == state.exa and self.held == state.held
  322. def nmoves(self):
  323. return len(self.moves)
  324. def delay(self):
  325. return moves_delay(self.moves)
  326. def keys(self):
  327. return moves_to_keys(self.moves)
  328. def __lt__(self, other):
  329. return self.score < other.score
  330. def loops(self, prev):
  331. return self.moves and \
  332. self.exa == prev.exa and \
  333. self.moves == prev.moves and \
  334. self.score == prev.score
  335. def move_to_key(move):
  336. return 'jjkadl'[move]
  337. def moves_to_keys(moves):
  338. return ''.join(move_to_key(move) for move in moves)
  339. def moves_delay(moves):
  340. return sum(MOVE_DELAYS[m] for m in moves)
  341. if __name__ == '__main__':
  342. import sys
  343. from PIL import Image
  344. board = Image.open('screens/board%d.png' % int(sys.argv[1])).convert('HSV')
  345. state = State.detect(board)
  346. print('parsed:')
  347. state.print()
  348. print()
  349. start = time.time()
  350. newstate = state.solve()
  351. end = time.time()
  352. print('best move:', newstate.keys())
  353. print('score:', newstate.score)
  354. print('elapsed:', round((end - start) * 1000, 1), 'ms')
  355. print()
  356. print('target after move:')
  357. newstate.print()
  358. print()
  359. #for solution in sorted(state.solutions()):
  360. # print('move %18s:' % solution.keys(), solution.score)