strategy.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499
  1. import io
  2. import time
  3. from collections import deque
  4. from contextlib import redirect_stdout
  5. from copy import copy
  6. from itertools import combinations
  7. from detection import COLUMNS, NOBLOCK, detect_blocks, detect_exa, \
  8. detect_held, print_board, is_basic, is_bomb
  9. GRAB, DROP, SWAP, LEFT, RIGHT, SPEED = range(6)
  10. MOVE_DELAYS = (
  11. # in milliseconds
  12. 50, # GRAB
  13. 50, # DROP
  14. 50, # SWAP
  15. 30, # LEFT
  16. 30, # RIGHT
  17. 30, # SPEED
  18. )
  19. MIN_BASIC_GROUP_SIZE = 4
  20. MIN_BOMB_GROUP_SIZE = 2
  21. COLSIZE_PRIO = 5
  22. COLSIZE_PANIC = 8
  23. BOMB_POINTS = 5
  24. class State:
  25. def __init__(self, blocks, exa, held, colskip, busy, moves, blockgroups, groups):
  26. self.blocks = blocks
  27. self.exa = exa
  28. self.held = held
  29. self.colskip = colskip
  30. self.busy = busy
  31. self.moves = moves
  32. self.blockgroups = blockgroups
  33. self.groups = groups
  34. self.nrows = len(blocks) // COLUMNS
  35. @classmethod
  36. def detect(cls, board, pad=2):
  37. blocks = [NOBLOCK] * (COLUMNS * pad) + list(detect_blocks(board))
  38. exa = detect_exa(board)
  39. held = detect_held(board, exa)
  40. colskip = get_colskip(blocks)
  41. busy = get_busy(blocks, colskip)
  42. blockgroups, groups = get_groups(blocks)
  43. return cls(bytearray(blocks), exa, held, bytearray(colskip), busy, (),
  44. bytearray(blockgroups), groups)
  45. def copy(self, deep):
  46. mcopy = copy if deep else lambda x: x
  47. return self.__class__(mcopy(self.blocks),
  48. self.exa,
  49. self.held,
  50. mcopy(self.colskip),
  51. self.busy,
  52. self.moves,
  53. mcopy(self.blockgroups),
  54. mcopy(self.groups))
  55. def colbusy(self, col):
  56. return (self.busy >> col) & 1
  57. def colrows(self, col):
  58. return self.nrows - self.colskip[col]
  59. def maxrows(self):
  60. return max(map(self.colrows, range(COLUMNS)))
  61. def causes_panic(self):
  62. return self.max_colsize() >= COLSIZE_PANIC
  63. def max_colsize(self):
  64. return self.nrows - self.empty_rows()
  65. def empty_rows(self):
  66. for i, block in enumerate(self.blocks):
  67. if block != NOBLOCK:
  68. return i // COLUMNS
  69. return 0
  70. def holes(self):
  71. #return self.imbalance()
  72. start_row = self.empty_rows()
  73. score = 0
  74. for col in range(COLUMNS):
  75. for row in range(start_row, self.nrows):
  76. if self.blocks[row * COLUMNS + col] != NOBLOCK:
  77. break
  78. score += row - start_row + 1
  79. return score
  80. def locked(self, i):
  81. size, block = self.groups[self.blockgroups[i]]
  82. if block == NOBLOCK:
  83. return False
  84. if is_basic(block):
  85. return size >= MIN_BASIC_GROUP_SIZE
  86. assert is_bomb(block)
  87. return size >= MIN_BOMB_GROUP_SIZE
  88. def move(self, *moves):
  89. deep = any(move in (GRAB, DROP, SWAP) for move in moves)
  90. s = self.copy(deep)
  91. s.moves += moves
  92. for move in moves:
  93. if move == LEFT:
  94. assert s.exa > 0
  95. s.exa -= 1
  96. elif move == RIGHT:
  97. assert s.exa < COLUMNS - 1
  98. s.exa += 1
  99. elif move == GRAB:
  100. assert not s.colbusy(s.exa)
  101. assert s.held == NOBLOCK
  102. row = s.colskip[s.exa]
  103. assert row < s.nrows
  104. i = row * COLUMNS + s.exa
  105. if not s.locked(i):
  106. s.colskip[s.exa] += 1
  107. s.held = s.remove(i)
  108. elif move == DROP:
  109. if s.held != NOBLOCK:
  110. row = s.colskip[s.exa]
  111. assert row > 0
  112. i = (row - 1) * COLUMNS + s.exa
  113. s.colskip[s.exa] -= 1
  114. s.place(i, s.held)
  115. s.held = NOBLOCK
  116. elif move == SWAP:
  117. assert not s.colbusy(s.exa)
  118. row = s.colskip[s.exa]
  119. i = row * COLUMNS + s.exa
  120. j = i + COLUMNS
  121. if j < len(s.blocks) and not s.locked(i) and not s.locked(j):
  122. bi = s.blocks[i]
  123. bj = s.blocks[j]
  124. assert bi != NOBLOCK
  125. assert bj != NOBLOCK
  126. if bi != bj:
  127. s.blocks[i] = NOBLOCK
  128. s.blocks[j] = NOBLOCK
  129. visited = set()
  130. s.ungroup(i, visited)
  131. s.ungroup(j, visited)
  132. s.place(j, bi)
  133. s.place(i, bj)
  134. return s
  135. def remove(self, i):
  136. block = self.blocks[i]
  137. assert block != NOBLOCK
  138. self.blocks[i] = NOBLOCK
  139. self.ungroup(i, set())
  140. return block
  141. def ungroup(self, i, visited):
  142. assert self.blocks[i] == NOBLOCK
  143. oldid = self.blockgroups[i]
  144. self.blockgroups[i] = 0
  145. self.groups[oldid] = 0, NOBLOCK
  146. for nb in neighbors(i, self.blocks):
  147. if self.blockgroups[nb] == oldid:
  148. self.create_group(nb, visited)
  149. def place(self, i, block):
  150. assert block != NOBLOCK
  151. assert self.blocks[i] == NOBLOCK
  152. self.blocks[i] = block
  153. self.create_group(i, set())
  154. def create_group(self, start, visited):
  155. def scan(i):
  156. if i not in visited:
  157. yield i
  158. visited.add(i)
  159. for nb in neighbors(i, self.blocks):
  160. if self.blocks[nb] == block:
  161. yield from scan(nb)
  162. block = self.blocks[start]
  163. group = tuple(scan(start))
  164. if group:
  165. newid = len(self.groups)
  166. self.groups.append((len(group), block))
  167. for j in group:
  168. self.blockgroups[j] = newid
  169. def fragmentation(self):
  170. """
  171. Minimize the sum of dist(i,j) between all blocks i,j of the same color.
  172. Magnify vertical distances to avoid column stacking.
  173. Add the vertical components of i and j to prioritize lower columns.
  174. """
  175. def dist(i, j):
  176. yi, xi = divmod(i, COLUMNS)
  177. yj, xj = divmod(j, COLUMNS)
  178. # for blocks in the same group, only count vertical distance so that
  179. # groups are spread out horizontally
  180. if self.blockgroups[i] == self.blockgroups[j]:
  181. return abs(yj - yi)
  182. return abs(xj - xi) + abs(yj - yi) * 2 - 1 + \
  183. (self.nrows - yi) + (self.nrows - yj)
  184. colors = {}
  185. for i, block in enumerate(self.blocks):
  186. if block != NOBLOCK:
  187. colors.setdefault(block, []).append(i)
  188. return sum(dist(i, j) ** 2
  189. for blocks in colors.values()
  190. for i, j in combinations(blocks, 2))
  191. def points(self):
  192. points = 0
  193. for size, block in self.groups:
  194. if is_basic(block) and size >= MIN_BASIC_GROUP_SIZE:
  195. points += size
  196. elif is_bomb(block) and size >= MIN_BOMB_GROUP_SIZE:
  197. points += self.maxrows()
  198. return -points
  199. def imbalance(self):
  200. colsizes = tuple(self.nrows - skip
  201. for col, skip in enumerate(self.colskip))
  202. #if not self.colbusy(col))
  203. mean = sum(colsizes) / len(colsizes)
  204. return sum((size - mean) ** 2 for size in colsizes)
  205. def gen_moves(self):
  206. if self.held == NOBLOCK:
  207. yield self
  208. for src in self.gen_shift(True):
  209. yield from src.gen_stationary()
  210. for get in src.gen_get():
  211. for dst in get.gen_shift(False):
  212. yield from dst.gen_put()
  213. else:
  214. for dst in self.gen_shift(True):
  215. yield from dst.gen_put()
  216. def gen_shift(self, allow_noshift):
  217. if allow_noshift:
  218. yield self
  219. left = self
  220. for i in range(self.exa):
  221. left = left.move(LEFT)
  222. yield left
  223. right = self
  224. for i in range(COLUMNS - self.exa - 1):
  225. right = right.move(RIGHT)
  226. yield right
  227. def gen_stationary(self):
  228. # SWAP
  229. # GRAB, SWAP, DROP
  230. # GRAB, SWAP, DROP, SWAP
  231. # SWAP, GRAB, SWAP, DROP
  232. # SWAP, GRAB, SWAP, DROP, SWAP
  233. if not self.colbusy(self.exa):
  234. avail = self.colrows(self.exa)
  235. if avail >= 2:
  236. swap = self.move(SWAP)
  237. yield swap
  238. if avail >= 3:
  239. grab = self.move(GRAB, SWAP, DROP)
  240. yield grab
  241. yield grab.move(SWAP)
  242. swap = swap.move(GRAB, SWAP, DROP)
  243. yield swap
  244. yield swap.move(SWAP)
  245. def gen_get(self):
  246. # GRAB
  247. # SWAP, GRAB
  248. # GRAB, SWAP, DROP, SWAP, GRAB
  249. if not self.colbusy(self.exa):
  250. avail = self.colrows(self.exa)
  251. if avail >= 1:
  252. grab = self.move(GRAB)
  253. yield grab
  254. if avail >= 2:
  255. yield self.move(SWAP, GRAB)
  256. if avail >= 3:
  257. yield grab.move(SWAP, DROP, SWAP, GRAB)
  258. def gen_put(self):
  259. # DROP
  260. # DROP, SWAP
  261. # DROP, SWAP, GRAB, SWAP, DROP
  262. if not self.colbusy(self.exa):
  263. avail = self.colrows(self.exa)
  264. drop = self.move(DROP)
  265. yield drop
  266. if avail >= 1:
  267. swap = drop.move(SWAP)
  268. yield swap
  269. if avail >= 2:
  270. yield swap.move(GRAB, SWAP, DROP)
  271. def solve(self):
  272. assert self.exa is not None
  273. pool = deque(self.gen_moves())
  274. assert len(pool) > 0
  275. best_score = ()
  276. for key in self.score_keys():
  277. if len(pool) == 1:
  278. break
  279. for state in pool:
  280. state.score = key(state)
  281. best = min(state.score for state in pool)
  282. best_score += (best,)
  283. for i in range(len(pool)):
  284. state = pool.popleft()
  285. assert state.held == NOBLOCK
  286. if state.score == best:
  287. pool.append(state)
  288. best = pool.popleft()
  289. best.score = best_score
  290. return best
  291. def score_keys(self):
  292. cls = self.__class__
  293. colsize = self.nrows - 2
  294. if colsize >= COLSIZE_PANIC:
  295. return cls.holes, cls.points, cls.nmoves, cls.fragmentation
  296. if colsize >= COLSIZE_PRIO:
  297. return cls.causes_panic, cls.points, cls.holes, \
  298. cls.fragmentation, cls.nmoves
  299. return cls.points, cls.fragmentation, cls.holes, cls.nmoves
  300. def print_groupsizes(self):
  301. for start in range(len(self.blockgroups) - COLUMNS, -1, -COLUMNS):
  302. print(' '.join('%-2d' % self.groups[g][0]
  303. for g in self.blockgroups[start:start + COLUMNS]))
  304. def print_groups(self):
  305. for start in range(len(self.blockgroups) - COLUMNS, -1, -COLUMNS):
  306. print(' '.join('%-2d' % g
  307. for g in self.blockgroups[start:start + COLUMNS]))
  308. def print(self):
  309. print_board(self.blocks, self.exa, self.held)
  310. def tostring(self):
  311. stream = io.StringIO()
  312. with redirect_stdout(stream):
  313. self.print()
  314. return stream.getvalue()
  315. def has_same_exa(self, state):
  316. return self.exa == state.exa and self.held == state.held
  317. def nmoves(self):
  318. return len(self.moves)
  319. def delay(self):
  320. return moves_delay(self.moves)
  321. def keys(self):
  322. return moves_to_keys(self.moves)
  323. def loops(self, prev):
  324. return self.moves and \
  325. self.exa == prev.exa and \
  326. self.held == prev.held and \
  327. self.moves == prev.moves and \
  328. self.score == prev.score
  329. def get_colskip(blocks):
  330. def colskip(col):
  331. for row, block in enumerate(blocks[col::COLUMNS]):
  332. if block != NOBLOCK:
  333. return row
  334. return len(blocks) // COLUMNS
  335. return list(map(colskip, range(COLUMNS)))
  336. def get_busy(blocks, colskip):
  337. mask = 0
  338. for col, skip in enumerate(colskip):
  339. start = (skip + 1) * COLUMNS + col
  340. colbusy = NOBLOCK in blocks[start::COLUMNS]
  341. mask |= colbusy << col
  342. return mask
  343. def scan_group(blocks, i, block, visited):
  344. yield i
  345. visited.add(i)
  346. for nb in neighbors(i, blocks):
  347. if blocks[nb] == block and nb not in visited:
  348. yield from scan_group(blocks, nb, block, visited)
  349. def get_groups(blocks):
  350. blockgroups = [0] * len(blocks)
  351. groups = [(0, NOBLOCK)]
  352. visited = set()
  353. for i, block in enumerate(blocks):
  354. if block != NOBLOCK and i not in visited:
  355. groupid = len(groups)
  356. group = tuple(scan_group(blocks, i, block, visited))
  357. groups.append((len(group), block))
  358. for j in group:
  359. blockgroups[j] = groupid
  360. return blockgroups, groups
  361. def neighbors(i, blocks):
  362. y, x = divmod(i, COLUMNS)
  363. if x > 0:
  364. yield i - 1
  365. if x < COLUMNS - 1:
  366. yield i + 1
  367. if y > 0:
  368. yield i - COLUMNS
  369. if y < len(blocks) // COLUMNS - 1:
  370. yield i + COLUMNS
  371. def move_to_key(move):
  372. return 'jjkadl'[move]
  373. def moves_to_keys(moves):
  374. return ''.join(move_to_key(move) for move in moves)
  375. def moves_delay(moves):
  376. return sum(MOVE_DELAYS[m] for m in moves)
  377. if __name__ == '__main__':
  378. import sys
  379. from PIL import Image
  380. board = Image.open('screens/board%d.png' % int(sys.argv[1])).convert('HSV')
  381. state = State.detect(board)
  382. print('parsed:')
  383. state.print()
  384. print()
  385. print('groups:')
  386. state.print_groups()
  387. print()
  388. print('groupsizes:')
  389. state.print_groupsizes()
  390. print()
  391. start = time.time()
  392. newstate = state.solve()
  393. end = time.time()
  394. print('best move:', newstate.keys())
  395. print('score:', newstate.score)
  396. print('elapsed:', round((end - start) * 1000, 1), 'ms')
  397. print()
  398. print('target after move:')
  399. newstate.print()
  400. #print()
  401. #print('generated moves:')
  402. #for state in state.gen_moves():
  403. # print(state.keys())