# Source code for miniosl.ui

"""primary user interface of miniosl"""
from __future__ import annotations
import miniosl
from minioslcc import MiniRecord, BaseState, State, Square, Move
import miniosl.drawing
import numpy as np
import os.path
import copy
import urllib
import logging
from typing import Tuple

# multiplier applied to the network's value output before it is printed
# or returned by UI.eval
Value_Scale = 1_000


def is_in_notebook() -> bool:
    """Tell whether the code runs inside an IPython notebook kernel.

    Based on https://stackoverflow.com/questions/15411967/ : the class of
    the object returned by ``get_ipython()`` is inspected.  A Google Colab
    shell or a ``ZMQInteractiveShell`` (Jupyter notebook / qtconsole) counts
    as a notebook.  The terminal IPython shell, any other shell, and a plain
    interpreter (where ``get_ipython`` is undefined) do not.
    """
    try:
        shell_cls = get_ipython().__class__
    except NameError:
        return False  # probably the standard Python interpreter
    if 'google.colab' in str(shell_cls):
        return True
    return shell_cls.__name__ == 'ZMQInteractiveShell'


def default_device() -> str:
    """Return the preferred torch device name.

    :return: ``"cuda"`` when torch is importable and reports an available
        CUDA device, otherwise ``"cpu"`` (including when torch is not
        installed at all)
    """
    try:
        import torch
    except ImportError:
        # a missing module raises ImportError/ModuleNotFoundError, not NameError
        return "cpu"
    return "cuda" if torch.cuda.is_available() else "cpu"


class UI:
    """General interface for a shogi state.

    Combines the current :py:class:`State` (board position + pieces in hand)
    with its move history (:py:class:`MiniRecord`) and other utilities.

    :param init: initial contents handled by :py:meth:`load_record`; another
        :py:class:`UI` is copied (its record and current position)
    :param prefer_text: preference of state visualization; ``None`` means
        :py:attr:`UI.default_prefer_text`
    :param default_format: representation used by ``__str__``;
        ``'csa'`` selects csa, anything else usi

    make the initial state in shogi

    >>> shogi = miniosl.UI()

    or read a csa/usi file (local or web)

    >>> url = 'http://live4.computer-shogi.org/wcsc33/kifu/WCSC33+F7_1-900-5F+dlshogi+Ryfamate+20230505161013.csa'
    >>> shogi = miniosl.UI(url)
    >>> _ = shogi.go(10)
    >>> print(shogi.to_csa(), end='')  # doctest: +NORMALIZE_WHITESPACE
    P1-KY-KE * -KI *  * -GI-KE-KY
    P2 * -HI-GI * -OU * -KI-KA *
    P3-FU * -FU-FU-FU-FU-FU-FU-FU
    P4 *  *  *  *  *  *  *  *  *
    P5 * -FU *  *  *  *  * +FU *
    P6+FU *  *  *  *  *  *  *  *
    P7 * +FU+FU+FU+FU+FU+FU * +FU
    P8 * +KA+KI *  *  * +GI+HI *
    P9+KY+KE+GI * +OU+KI * +KE+KY
    +
    """
    default_prefer_text = False  # shared preference on board representation

    def __init__(self, init: str | BaseState | MiniRecord | UI = '', *,
                 prefer_text=None, default_format='usi'):
        self.default_format = default_format
        self.opening_tree = None
        self.nn = None
        self._features = None       # cached input features of current state
        self._infer_result = None   # cached output of eval() for that state
        self.model = None
        self.fig = None
        self.prefer_text = (
            UI.default_prefer_text if prefer_text is None else prefer_text
        )
        if isinstance(init, UI):
            # copy constructor: copy the record and go to the same position;
            # model and opening tree are not carried over
            self._record = copy.copy(init._record)
            self.replay(init.cur)
            return
        self.load_record(init)

    def load_record_set(self, path: str, idx: int):
        """load idx-th game record in :py:class:`RecordSet`

        :param path: filepath for `.npz` generated by
            :py:meth:`RecordSet.save_npz` or sfen text
        :param idx: zero-based index of the record to load
        :raises ValueError: if ``idx`` is negative or not smaller than the
            number of records
        """
        if idx < 0:
            raise ValueError(f'idx {idx} must be non-negative')
        if path.endswith('.npz'):
            record_set = miniosl.RecordSet.from_npz(path, limit=idx+1)
        else:
            record_set = miniosl.RecordSet.from_usi_file(path)
        records = record_set.records
        if idx >= len(records):
            raise ValueError(f'idx {idx} >= len {len(records)}'
                             + f' of record set {path}')
        self.load_record(records[idx])

    def load_record(self, src: str | BaseState | MiniRecord = ''):
        """load a game record from various sources.

        :param src: :py:class:`BaseState` or :py:class:`MiniRecord`
            or URL or filepath containing `.csa` or usi, or an inline csa
            board (starting with ``P1``) / usi text; ``''`` means the
            initial position
        :raises ValueError: for an unsupported string or type; the
            previously loaded record is kept in that case
        """
        if isinstance(src, MiniRecord):
            record = copy.copy(src)
        elif isinstance(src, BaseState):
            record = MiniRecord()
            record.set_initial_state(src)
        elif isinstance(src, str):
            if src == '':
                record = MiniRecord()
                record.set_initial_state(State())  # default
            elif src.startswith('http') and src.endswith('csa'):
                # `import urllib` alone does not load the request submodule
                import urllib.request
                with urllib.request.urlopen(src) as response:
                    the_csa = response.read().decode('utf-8')
                record = miniosl.csa_record(the_csa)
            elif os.path.isfile(src):
                if src.endswith('.csa'):
                    record = miniosl.csa_file(src)
                else:
                    record = miniosl.usi_file(src)
            elif len(src) >= 8:
                if src[:2] == 'P1':
                    record = MiniRecord()
                    record.set_initial_state(miniosl.csa_board(src))
                else:
                    record = miniosl.usi_record(src)
            else:
                raise ValueError(src+' not expected')
        else:
            raise ValueError(f'{src!r} unexpected type')
        self._record = record
        return self.replay(0)

    def __repr__(self) -> str:
        return "<UI '" + self.to_usi_history() + "'>"

    def __str__(self):
        if self.default_format == 'csa':
            return self._state.to_csa()
        return self._state.to_usi()

    def __copy__(self):
        return UI(self)

    def __deepcopy__(self, memo):
        return UI(self)

    def __len__(self) -> int:
        """number of states in the record, that is, number of moves + 1"""
        return self._record.state_size()

    def hint(self, show_hint: bool = True, **kwargs):
        """optionally draw the state, and return it in usi

        :param show_hint: draw via :py:meth:`to_img` (with ``kwargs``)
            unless :py:attr:`prefer_text` is set
        :return: the usi representation of the current state; any image is
            produced only as a side effect
        """
        if show_hint and not self.prefer_text:
            self.to_img(**kwargs)
        return self.to_usi()

    # delegation for self._record

    def to_usi_history(self) -> str:
        """show the history in usi

        >>> shogi = miniosl.UI()
        >>> _ = shogi.make_move('+2726FU')
        >>> _ = shogi.make_move('-3334FU')
        >>> shogi.to_usi()
        'sfen lnsgkgsnl/1r5b1/pppppp1pp/6p2/9/7P1/PPPPPPP1P/1B5R1/LNSGKGSNL b - 1'
        >>> shogi.to_usi_history()
        'startpos moves 2g2f 3c3d'
        """
        return self._record.to_usi()

    def previous_repeat_index(self) -> int:
        """the latest repeating state in the history (0 at the first state)"""
        if self.cur > 0:
            return self._record.previous_repeat_index(self.cur)
        return 0

    def repeat_count(self) -> int:
        """number of occurrences of the state in the history"""
        return self._record.repeat_count(self.cur) if len(self._record) else 0

    def to_anim(self, *args, **kwargs):
        """make matplotlib animation showing current game record"""
        return self._record.to_anim(*args, **kwargs)

    # delegation for self._state

    def to_move(self, move_rep: str) -> Move:
        """interpret string as a Move"""
        return self._state.to_move(move_rep)

    def read_japanese_move(self, move_rep: str,
                           last_to: Square | None = None) -> Move:
        """interpret a move written in Japanese notation

        :param last_to: destination of the previous move, used for
            notations referring to it; ``None`` is passed as ``Square()``
        """
        # NOTE(review): this table maps ASCII digits onto themselves, so the
        # translate() is a no-op; it was presumably meant to convert between
        # full-width and ASCII digits -- confirm which form
        # BaseState.read_japanese_move expects before changing it.
        tbl = str.maketrans('123456789', '123456789')
        return self._state.read_japanese_move(move_rep.translate(tbl),
                                              last_to or Square())

    def genmove(self):
        """generate legal moves in the state"""
        return self._state.genmove()

    def genmove_check(self):
        """generate legal moves giving check to the opponent"""
        return self._state.genmove_check()

    def in_check(self):
        """return whether king of the side to move is in check"""
        return self._state.in_check()

    def in_checkmate(self):
        """return whether king of the side to move is in checkmate"""
        return self._state.in_checkmate()

    def hash_code(self):
        """hash code of the current state"""
        return self._state.hash_code()

    def to_csa(self) -> str:
        """show state in csa"""
        return self._state.to_csa()

    def to_usi(self) -> str:
        """show state in usi"""
        return self._state.to_usi()

    def _add_drawing_properties(self, props):
        """store history-dependent drawing options into ``props`` in place"""
        props['last_to'] = self.last_to()
        props['last_move_ja'] = self.last_move_ja
        props['move_number'] = self.cur + 1
        props['repeat_distance'] = self.cur - self.previous_repeat_index()
        props['repeat_count'] = self.repeat_count()

    def to_img(self, *args, **kwargs):
        """show state in matplotlib image

        :return: an image shown in colab or jupyter notebooks
        """
        self._add_drawing_properties(kwargs)
        self.fig = miniosl.state_to_img(self._state, *args, **kwargs)
        return self.fig.fig

    # original / modified methods

    def first(self):
        """go to the first state in the history"""
        return self.replay(0)

    def last(self):
        """go to the last state in the history"""
        return self.replay(len(self._record))

    def last_move_number(self):
        """number of moves in the record"""
        return len(self._record)

    def go(self, step):
        """make moves (step > 0) or unmake moves (step < 0)

        :raises ValueError: if the destination is outside the record
        """
        if not (0 <= self.cur + step <= len(self._record)):
            raise ValueError(f'step out of range {self.cur}'
                             + f' + {step} max {len(self._record)}')
        return self.replay(self.cur+step)

    def make_move(self, move):
        """make a move in the current state

        :param move: :py:class:`Move` or its string representation
        :raises ValueError: if the string cannot be parsed or the move is
            illegal

        If the move equals the next recorded move the record is kept as is;
        otherwise the record is branched at the current position
        (``branch_at``) and the move is appended.
        """
        if isinstance(move, str):
            move_text = move
            move = self._state.to_move(move_text)
            if not move.is_normal():
                raise ValueError('please check syntax '+move_text)
        if not self._state.is_legal(move):
            raise ValueError('illegal move '+str(move))
        follows_record = (self.cur < len(self._record)
                          and self._record.moves[self.cur] == move)
        if follows_record:
            self._do_make_move(move, self.last_to())
        else:
            if self.cur < len(self._record):
                self._record = self._record.branch_at(self.cur)
            self._do_make_move(move, self.last_to())
            self._record.append_move(move, self.in_check())
            self._record.settle_repetition()  # better to improve performance
        self.cur += 1
        return self.hint()

    def unmake_move(self):
        """undo the last move

        :raises ValueError: at the first state
        """
        if self.cur < 1:
            raise ValueError('history empty')
        return self.replay(self.cur-1)

    def genmove_ja(self) -> list[str]:
        """generate legal moves in Japanese"""
        return [move.to_ja(self._state) for move in self._state.genmove()]

    def pv_to_ja(self, pv: list[Move]) -> list[str]:
        """show a sequence of moves from the current state in Japanese"""
        ret = []
        last_to = self.last_to()
        state = State(self._state)
        for move in pv:
            ret.append(move.to_ja(state, last_to))
            state.make_move(move)
            last_to = move.dst
        return ret

    def last_move(self) -> Move | None:
        """the last move played or None"""
        if not (0 < self.cur <= len(self._record)):
            return None
        return self._record.moves[self.cur-1]

    def last_to(self) -> Square | None:
        """the destination square of the last move or None

        >>> shogi = miniosl.UI()
        >>> shogi.last_to()
        >>> _ = shogi.make_move('+2726FU')
        >>> shogi.last_to() == miniosl.Square(2, 6)
        True
        """
        move = self.last_move()
        return move.dst if move is not None else None

    def count_hand(self, color: miniosl.Player, ptype: miniosl.Ptype) -> int:
        """number of pieces in hand"""
        return self._state.count_hand(color, ptype)

    def turn(self) -> miniosl.Player:
        """side to move"""
        return self._state.turn

    def to_np_state_feature(self) -> np.ndarray:
        """make tensor of feature for the current state (w/o history).

        - 44ch each 9x9 channel is responsible
          for a specific piece type and color
        - where each element is 1 (0) for existence (absent) of the piece
          for the first 30ch
        - filled by the same value indicating number of hand pieces
          for the latter 14ch
        - 13ch for heuristic features
        """
        return self._state.to_np_state_feature()

    def to_np_cover(self) -> np.ndarray:
        """make planes to show a square is covered (1) or not (0)"""
        return self._state.to_np_cover()

    def to_np_pack(self) -> np.ndarray:
        """compress state information"""
        return self._state.to_np_pack()

    def replay(self, idx: int, show_hint: bool = False):
        """move to idx-th state in the history

        :param idx: number of recorded moves to play from the initial state
        :param show_hint: passed to :py:meth:`hint` when ``idx > 0``
        :raises ValueError: if ``idx`` is negative or exceeds the number of
            moves; the current position is left untouched in that case
        """
        n_moves = len(self._record)
        if idx < 0:
            raise ValueError(f'negative index {idx}')
        if idx > n_moves:
            raise ValueError(f'index too large {idx} > {n_moves}')
        self._state = State(self._record.initial_state)
        self._features = None
        self._infer_result = None
        self.cur = idx
        self.last_move_ja = None
        if idx == 0:
            # NOTE(review): unlike idx > 0, this path calls hint() with its
            # default show_hint=True and ignores `show_hint`; kept as-is --
            # confirm whether drawing here is intended.
            return self.hint()
        moves = self._record.moves
        # only the final move needs its Japanese notation (last_move_ja)
        for move in moves[:idx-1]:
            self._state.make_move(move)
        last_to = moves[idx-2].dst if idx > 1 else Square()
        self._do_make_move(moves[idx-1], last_to)
        return self.hint(show_hint)

    def legal_move_to_ja(self, move: Move, last_to: Square | None = None):
        """show a legal move in the current state in Japanese"""
        return move.to_ja(self._state, last_to)

    def _do_make_move(self, move: Move, last_to: Square | None = None):
        """apply ``move`` to the state, remembering its Japanese notation"""
        # both caches describe the previous state and become stale
        self._features = None
        self._infer_result = None
        self.last_move_ja = self.legal_move_to_ja(move, last_to)
        self._state.make_move(move)

    def load_opening_tree(self, filename):
        """load opening db"""
        self.opening_tree = miniosl.load_opening_tree(filename)
        logging.info(f'load opening of size {self.opening_tree.size()}')

    def opening_moves(self):
        """retrieve or show opening moves.

        Note: need to load data by :py:meth:`load_opening_tree` in advance

        Outside a notebook each move is printed with its share and count;
        in a notebook the shares are drawn on the board.
        """
        if self.opening_tree is None:
            raise RuntimeError("opening_tree not loaded")
        children = self.opening_tree.retrieve_children(self._state)
        total = sum(c[0].count() for c in children)
        if not is_in_notebook():
            for c in children:
                print(c[1], f'{c[0].count()/total*100:5.2f}% {c[0].count():5}'
                      + f' ({c[0].black_advantage()*100:5.2f}%)', )
        else:
            plane = np.zeros((9, 9))
            for c in children:
                x, y = c[1].dst.to_xy()
                plane[y-1][x-1] = c[0].count()/total
            return self.hint(plane=plane)

    def follow_opening(self, nth: int = 0):
        """make the nth move (0-based) of the opening db

        :raises IndexError: if ``nth`` is not a valid child index
        """
        if self.opening_tree is None:
            raise RuntimeError("opening_tree not loaded")
        children = self.opening_tree.retrieve_children(self._state)
        if nth < 0:
            raise IndexError(f"opening {nth} < 0")
        if nth >= len(children):
            raise IndexError(f"opening {nth} >= {len(children)}")
        self.make_move(children[nth][1])
        return self.opening_moves()

    def make_opening_db_from_sfen(self, filename: str, threshold: int = 100):
        """make opening tree from sfen records, and save in npz

        The record set is also saved as ``sfen.npz`` and the tree as
        ``opening.npz`` in the current directory.
        """
        # NOTE(review): the `.npz` check suggests npz input was meant to be
        # accepted, but from_usi_file is used unconditionally -- confirm.
        record_set = miniosl.RecordSet.from_usi_file(filename)
        if not filename.endswith('.npz'):
            record_set.save_npz("sfen.npz")
        tree = miniosl.OpeningTree.from_record_set(record_set, threshold)
        self.opening_tree = tree
        tree.save_npz('opening.npz')

    def update_features(self):
        """compute and cache features of the current state with its history"""
        if self._features is None:
            history = self._record.moves[:self.cur]
            self._features = self._record.initial_state.export_features(
                history)

    def show_channels(self, *args, **kwargs):
        """draw feature channels, using a Japanese font when available"""
        fontname = 'Noto Sans CJK JP'
        if not hasattr(UI, 'japanese_available_in_plt'):
            # probe once and cache the answer on the class
            import matplotlib.font_manager as fm
            font = fm.findfont(fontname)
            UI.japanese_available_in_plt = (
                bool(font) and ("NotoSansCJK" in font)
            )
            if UI.japanese_available_in_plt:
                import matplotlib
                matplotlib.rcParams['font.family'] = [fontname]
        return miniosl.show_channels(
            *args, **kwargs, japanese=UI.japanese_available_in_plt)

    def show_features(self, plane_id: int | str, *args, **kwargs):
        """visualize features in matplotlib.

        :param plane_id: ``'pieces'`` | ``'hands'`` | ``'lastmove'``
            | ``'long'`` | ``'safety'``, or integer id (in internal
            representation)
        """
        self.update_features()
        features = self._features
        flip = self.turn() == miniosl.white
        if isinstance(plane_id, str):
            if plane_id == "pieces":
                print('side-to-move PLNSBRk/gplnsbr')
                self.show_channels(features[16:], 2, 7, flip)
                print('opponent PLNSBRk/gplnsbr')
                self.show_channels(features, 2, 7, flip)
                return
            if plane_id == "hands":
                print('plnsgbr')
                return self.show_channels(features[30:], 2, 7, flip)
            if plane_id == "long":
                print('lbr+k')
                return self.show_channels(features[44:], 4, 4, flip)
            if plane_id == "safety":
                return self.show_channels(features[61:], 1, 3, flip)
            if plane_id == "safety_1":
                channel = miniosl.channel_id['check_piece_1']
                return self.show_channels(features[channel:], 1, 6, flip)
            if plane_id == "safety_2":
                channel = miniosl.channel_id['check_piece_2']
                return self.show_channels(features[channel:], 1, 6, flip)
            latest_history = 64
            if plane_id == "lastmove":
                return self.show_channels(features[latest_history:],
                                          1, 3, flip)
            plane_id = miniosl.channel_id[plane_id]
        if not is_in_notebook():
            print(features[plane_id])
        return self.hint(*args, plane=features[plane_id],
                         flip_if_white=True, **kwargs)

    def load_eval(self, path: str = "", device: str = "",
                  torch_cfg: dict | None = None):
        """load parameters of evaluation function from file

        parameters will be passed to :py:func:`inference.load`.
        An empty ``path`` selects the pretrained model for the variant of
        the current state when one exists; an empty ``device`` is chosen by
        :py:func:`default_device`.  A missing file is only logged.
        """
        import miniosl.inference
        if torch_cfg is None:
            torch_cfg = {}
        variant, _ = self._state.guess_variant()
        if not path and miniosl.has_pretrained_eval(variant=variant):
            path = miniosl.pretrained_eval_path(variant=variant)
        path = os.path.expanduser(path)
        if not device:
            device = default_device()
        if os.path.exists(path):
            self.model = miniosl.inference.load(path, device, torch_cfg)
            logging.info(f'load {path=}')
        else:
            logging.warning(f'failed to load eval {path=}')

    def show_top8_moves(self, verbose=False):
        """print (up to) the top 8 moves by policy

        With an opening tree loaded, each line also shows a novelty score
        computed from the opening counts in a PUCT-like formula.
        """
        _, moves = self.eval()
        moves = moves[:8]  # fewer than 8 when there are few legal moves
        expl = [''] * len(moves)
        if self.opening_tree:
            children = self.opening_tree.retrieve_children(self._state)
            table = {child[1]: child[0].count() for child in children}
            count = [table.get(m[1], 0) for m in moves]
            total = sum(count)
            # constants appear to follow AlphaZero's PUCT (c_base, c_init)
            cs = np.log((total + 19652 + 1) / 19652) + 1.25
            for i, m in enumerate(moves):
                pb = cs * np.sqrt(total) / (count[i] + 1)
                pbc = pb * m[0]
                expl[i] = f' novelty {pbc*2:.3f}'
                if verbose:
                    expl[i] += f' <- {pb:.3f} * {m[0]:.3f}'
        for i, m in enumerate(moves):
            print(f'{m[0]:.3f} {m[1].to_csa()}{expl[i]}')

    def eval(self, verbose=False) -> Tuple[np.ndarray, list]:
        """return value and policy for current state.

        need to call :py:meth:`load_eval` in advance

        :return: ``(value * Value_Scale, moves)`` where ``moves`` comes from
            ``miniosl.inference.sort_moves`` (entries indexed as
            ``[probability, move]``)
        :raises RuntimeError: if no model is loaded
        """
        if self.model is None:
            raise RuntimeError('load model')
        self.update_features()
        logits, value, aux = self.model.eval(self._features,
                                             take_softmax=False)
        policy = miniosl.softmax(logits.reshape(-1)).reshape(-1, 9, 9)
        flip = self.turn() == miniosl.white
        if verbose:
            self.show_channels([np.max(policy, axis=0)], 1, 1, flip)
            print(f'eval = {value[0]*Value_Scale:.0f}')
        mp = miniosl.inference.sort_moves(self.genmove(), policy)
        self._infer_result = {'policy': policy, 'value': value, 'mp': mp,
                              'aux': aux, 'logits': logits}
        if verbose:
            for i in range(min(len(mp), 3)):
                print(mp[i][1], f'{mp[i][0]*100:6.1f}%',
                      mp[i][1].to_ja(self._state))
        return value*Value_Scale, mp

    def gumbel_one(self, width: int = 4, cscale: int = 1):
        """score the top-``width`` policy moves by one-ply lookahead

        Each entry is ``[logit + transformQ(v), move, logit, v]`` where ``v``
        is the negated value after the move (or +1/-1/0 for a terminal
        result); entries are sorted by score, descending, and cached in
        ``_infer_result[f'gumbel{width}']``.
        """
        if self._features is None or self._infer_result is None:
            self.eval(verbose=False)
        history = self._record.moves[:self.cur]
        mp = self._infer_result['mp']
        logits = self._infer_result['logits'].reshape(-1)
        values = []
        for i in range(min(len(mp), width)):
            move = mp[i][1]
            f, terminal = \
                self._record.initial_state.export_features_after_move(
                    history, move
                )
            _, v, *_ = self.model.infer_one(f)
            logit = logits[move.policy_move_label]
            v = -v[0].item()  # negamax
            if terminal == miniosl.win_result(self.turn()):
                v = 1.0
            elif terminal == miniosl.win_result(self.turn().alt):
                v = -1.0
            elif terminal == miniosl.Draw:
                v = 0
            values.append([logit + miniosl.transformQ(v, cscale=cscale),
                           move, logit, v])
        values.sort(key=lambda e: -e[0])
        self._infer_result[f'gumbel{width}'] = values
        return values

    def mcts(self, budget: int = 100, report: int = 4, *,
             batch_size=1, resume_root=None):
        """run mcts

        - budget: number of simulations
        - report: number of (interim) reports during search
        """
        import tqdm
        if not self.model:
            raise RuntimeError('load model')
        record = self._record
        if self.cur < record.move_size():
            record = record.branch_at(self.cur)
        mgr = miniosl.GameManager.from_record(record)
        root = resume_root
        report = min(budget, report)
        sim_done = 0
        total = max(1, report)
        with tqdm.tqdm(total=total, disable=(report == 0)) as pbar:
            for i in range(total):
                step_forward = budget*(i+1)//max(1, report)
                root = miniosl.run_mcts(mgr, step_forward-sim_done,
                                        self.model, batch_size=batch_size,
                                        root=root)
                sim_done = step_forward
                pv = root.pv(ratio=0.25)
                if pv:
                    m = pv[0]
                    pbar.set_postfix(
                        pv=f'{m[0].to_csa()} {m[1]:.3f}'
                        + f' {"".join([_[0].to_csa() for _ in pv[1:]])}'
                        + f' {m[2]/max(sim_done, 1)*100:4.1f}%'
                    )
                pbar.update(1)
        if sim_done < budget:  # in case report == 0
            root = miniosl.run_mcts(mgr, budget-sim_done, self.model,
                                    batch_size=batch_size, root=root)
        return root

    def analyze(self):
        """evaluate all positions in the current game and plot them

        Each value is multiplied by ``turn().sign`` of its position; the
        current position is restored afterwards, even on error.
        """
        import matplotlib.pyplot as plt
        now = self.cur
        moves = self._record.moves
        evals = []
        try:
            self.replay(0)
            for i in range(self._record.state_size()):
                if i > 0:
                    self._do_make_move(moves[i-1])
                self.cur = i
                value, *_ = self.eval(False)
                evals.append(value * self.turn().sign)
        finally:
            self.replay(now)
        plt.axhline(y=0, linestyle='dotted')
        return plt.plot(np.array(evals), '+')

    def show_inference_after_move(self):
        """draw the auxiliary output of the last :py:meth:`eval`, if any"""
        if self._features is None or self._infer_result is None:
            return
        return self.show_channels(self._infer_result['aux'], 2, 6)

    def ipywidget(self):
        """make an interactive widget for viewing record in colab"""
        import miniosl.widget
        return miniosl.widget.record_view(self)

    def game_play(self):
        """make an interactive widget for playing in colab"""
        import miniosl.widget
        return miniosl.widget.game_play(self)

    def ipywidget_slider(self):
        """make an interactive widget for colab or jupyter notebooks"""
        import miniosl.widget
        return miniosl.widget.slider_ui(self)