Source code for miniosl.drawing

"""visualize state and channels"""
from __future__ import annotations
import miniosl
import numpy as np
import drawsvg as dw
import logging
import base64

# Font rule appended to every drawing: prevents tofu (missing glyphs) in png.
css = ".piece { font-family: Noto Serif CJK JP; }"

# CSA piece code -> kanji used when drawing the piece
csadict = dict(
    FU='歩', KY='香', KE='桂', GI='銀', KI='金',
    KA='角', HI='飛', OU='王',
    TO='と', NY='杏', NK='圭', NG='全', UM='馬', RY='龍',
)
# Axis labels for the board; index 0 is padding so that both lists can be
# indexed directly with coordinates in [1, 9].
kanjirow = [None, *'一二三四五六七八九']
kanjicol = [None, *'123456789']

# Captions drawn next to the board by ShogiSVG.show_hand
no_pieces_in_hand = 'なし'     # "none"
label_hand_b = '先手持駒'      # label of the hand of black (first player)
label_hand_w = '後手持駒'      # label of the hand of white (second player)
hand_others = '他'             # "others"

# Keyword arguments for drawsvg text elements; `class` binds them to `css`.
_piece_class = {'class': "piece"}
char_property = {**_piece_class, 'fill': 'black'}
sub_property = {**_piece_class, 'fill': 'gray'}
# Keyword arguments for the translucent, border-less square highlights.
b_property = dict(fill='orange', fill_opacity=0.2, stroke_width=0)
w_property = dict(fill='blue', fill_opacity=0.2, stroke_width=0)
in_check_property = dict(fill='red', fill_opacity=0.5, stroke_width=0)
last_move_property = dict(fill='yellow', fill_opacity=0.2, stroke_width=0)
# Side length of one board square in pixels; every coordinate derives from it.
scale = 20


def ptype_to_ja(ptype: miniosl.Ptype) -> str:
    """kanji for a piece type, looked up in `csadict` through its CSA code"""
    csa_code = miniosl.to_csa(ptype)
    return csadict[csa_code]


def hand_pieces_to_ja(state, player):
    """kanji string of the pieces `player` holds in hand

    Piece types are visited in `miniosl.piece_stand_order`; each kanji is
    repeated once per held piece and types with a zero count are skipped.
    The result is an empty string when nothing is reported in hand.
    """
    held = ((ptype, state.count_hand(player, ptype))
            for ptype in miniosl.piece_stand_order)
    return ''.join(csadict[miniosl.to_csa(ptype)] * cnt
                   for ptype, cnt in held if cnt != 0)


def gx(x):
    """reference pixel coordinate for shogi board x in [1,9]

    The value decreases as `x` grows, i.e. larger x is drawn further left.
    Callers also pass values outside [1,9] to place labels around the board.
    """
    half_cell = scale/2
    return (11-x)*scale + half_cell


def gy(y):
    """reference pixel coordinate for shogi board y in [1,9]

    The value grows with `y`; callers also pass values outside [1,9]
    (and fractions) to place labels and captions around the board.
    """
    rows_from_top = y+1
    return rows_from_top*scale


def radius(cnt):
    """radius of circle for strength cnt

    The radius is linear in `cnt`: a quarter of a square for `cnt == 4`.
    """
    quarter_cell = scale/4
    # `cnt ** 1/4` binds as `(cnt ** 1) / 4`; spelled out here.
    # NOTE(review): if a fourth root `cnt ** (1/4)` was intended, this is
    # not what the expression computes -- confirm before changing.
    return quarter_cell * ((cnt ** 1) / 4)


class ShogiSVG:
    """svg under drawing

    Holds a game state together with the `drawsvg` drawing (`self.d`) to
    which the board, pieces, hands and captions are appended by the methods
    below.  Elements are rendered in the order they are added.
    """

    def __init__(self, state: miniosl.BaseState, last_move_ja: str = '',
                 id_prefix: str = '', move_number: int = 0,
                 repeat_distance: int = 0, repeat_count: int = 0):
        """prepare an empty drawing for `state`

        :param state: position to draw,
        :param last_move_ja: last move in japanese, shown in the caption
          when not empty,
        :param id_prefix: forwarded to `drawsvg.Drawing` as `id_prefix`,
        :param move_number: ply in a game record; a positive value
          reserves one more caption row,
        :param repeat_distance: distance to the latest same position,
        :param repeat_count: number of the occurrence of this state;
          stored incremented by one
        """
        self.state = state
        self.last_move_ja = last_move_ja
        self.move_number = move_number
        self.repeat_distance = repeat_distance
        self.repeat_count = repeat_count + 1
        width = scale*14.5
        # one more row of height for the move-number caption
        height = scale*12.5 if move_number > 0 else scale*11.5
        self.d = dw.Drawing(width, height, id_prefix=id_prefix)
        self.d.append_css(css)

    def add(self, element) -> None:
        """append a drawsvg element to the drawing"""
        self.d.append(element)

    def _decorate_king(self, player: miniosl.Player) -> None:
        """paint the square of the king of `player`

        The in-check color is used when `player` is to move and in check;
        otherwise a fainter variant of the color of the player.
        """
        prop = b_property if player == miniosl.black else w_property
        if self.state.turn() == player and self.state.in_check():
            prop = in_check_property
        else:
            # copy so that the shared module-level dict stays untouched
            prop = prop.copy()
            prop['fill_opacity'] = 0.1
        x, y = self.state.king_square(player).to_xy()
        self.add(dw.Rectangle(gx(x), gy(y-1), scale, scale, **prop))

    def decorate(self):
        """highlight both kings and mark the cover counts on each square

        Does nothing unless the state is a `miniosl.State`.  Squares with
        a positive `count_cover` for both players are marked in purple
        (with the sum), those for black only in orange and those for
        white only in blue.
        """
        if not isinstance(self.state, miniosl.State):
            return
        self._decorate_king(miniosl.black)
        self._decorate_king(miniosl.white)
        plane_c, plane_b, plane_w = (np.zeros((9, 9)) for _ in range(3))
        for y in range(1, 10):
            for x in range(1, 10):
                sq = miniosl.Square(x, y)
                eb = self.state.count_cover(miniosl.black, sq)
                ew = self.state.count_cover(miniosl.white, sq)
                if eb > 0 and ew > 0:
                    plane_c[y-1][x-1] = eb+ew
                elif eb > 0:
                    plane_b[y-1][x-1] = eb
                elif ew > 0:
                    plane_w[y-1][x-1] = ew
        self.draw_plane(plane_c, 'purple')
        self.draw_plane(plane_b, 'orange')
        self.draw_plane(plane_w, 'blue')

    @staticmethod
    def _scaled_plane(plane: np.ndarray) -> np.ndarray:
        """return the strengths handed to `radius()` as a new array

        A plane whose maximum exceeds 5 is first divided by its maximum.
        Then a maximum in (0, 2] is stretched to 4, while a maximum above
        2 or equal to 0 is left as is.  `plane` itself is never modified.
        """
        maximum = np.max(plane)
        if maximum > 5:
            # Out-of-place on purpose: `plane /= maximum` rescaled the
            # array of the caller and raised for integer arrays.
            plane = plane / maximum
            maximum = plane.max()
        weight = 1 if maximum > 2 or maximum == 0 else 4/maximum
        return weight * plane

    def draw_plane(self, plane: np.ndarray, color: str, stroke: int = 0) -> None:
        """draw a circle on each square, sized by the value in `plane`

        :param plane: 9x9 array indexed as `plane[y-1][x-1]`; left unmodified,
        :param color: fill and stroke color of the circles,
        :param stroke: stroke width of the circles
        :raises ValueError: if the shape of `plane` is not (9, 9)
        """
        if plane.shape != (9, 9):
            raise ValueError(f'unexpected shape of plane {plane.shape}')
        strength = self._scaled_plane(plane)
        p = {'fill': color,  'fill_opacity': 0.2, 'stroke_width': stroke,
             'stroke': color}
        for y in range(1, 10):
            for x in range(1, 10):
                # the half-step offsets move the circle to the square center
                self.add(dw.Circle(gx(x-.5), gy(y-.5),
                                   radius(strength[y-1][x-1]), **p))

    def draw_piece(self, x, y):
        """draw the kanji of the piece on square (x, y), if there is one

        Pieces of white are drawn rotated by 180 degrees.
        """
        piece = self.state.piece_at(miniosl.Square(x, y))
        if not piece.is_piece():
            return
        kanji = ptype_to_ja(piece.ptype())
        if piece.color() == miniosl.black:
            self.add(dw.Text(kanji, font_size=scale*.875,
                             x=gx(x)+scale*.05, y=gy(y)-scale*.125,
                             **char_property))
        else:
            assert piece.color() == miniosl.white
            # the anchor moves to the opposite corner because of the rotation
            self.add(dw.Text(kanji, font_size=scale*.875,
                             x=gx(x)+scale*.9, y=gy(y)-scale*.8,
                             rotate='180', **char_property))

    def draw_grid(self):
        """draw the coordinate labels, the inner grid and the board frame"""
        for i in range(1, 10):
            # column labels above the board, row labels next to it
            self.add(dw.Text(kanjicol[i], font_size=scale*.75,
                             x=gx(i)+scale*.25, y=gy(0)-scale*0.25,
                             **sub_property))
            self.add(dw.Text(kanjirow[i], font_size=scale*.625,
                             x=gx(0)+scale/4, y=gy(i)-scale*0.25,
                             **sub_property))
            if i < 9:
                self.add(dw.Line(gx(0), gy(i), gx(9), gy(i), stroke='gray'))
                self.add(dw.Line(gx(i), gy(0), gx(i), gy(9), stroke='gray'))
        # the frame is added last so that it is drawn over the gray lines
        self.add(dw.Line(gx(0), gy(0), gx(9), gy(0), stroke='black'))
        self.add(dw.Line(gx(0), gy(9), gx(9), gy(9), stroke='black'))
        self.add(dw.Line(gx(0), gy(0), gx(0), gy(9), stroke='black'))
        self.add(dw.Line(gx(9), gy(0), gx(9), gy(9), stroke='black'))

    def show_side_to_move(self, flipped: bool):
        """write the captions below the board

        Shows the side to move, the last move if given, the move number
        (with repetition information) if positive, and a remark when the
        position is drawn with the sides exchanged.

        :param flipped: add the remark that both sides are exchanged
        """
        player_to_move = "先手" if self.state.turn() == miniosl.black else "後手"
        self.add(dw.Text('手番 '+player_to_move, font_size=scale*.75,
                         x=gx(9)+scale*.05, y=gy(10), **char_property))
        if self.last_move_ja:
            self.add(dw.Text(f'({self.last_move_ja} まで)', font_size=scale*.7,
                             x=gx(5)+scale*.05, y=gy(10), **sub_property))
        if self.move_number > 0:
            msg = f'{self.move_number}手目'
            if self.repeat_distance > 0:
                msg += f' ({self.repeat_distance}手前と同一局面 {self.repeat_count}回目)'
            self.add(dw.Text(msg, font_size=scale*.7,
                             x=gx(9)+scale*.05, y=gy(11), **sub_property))
        if flipped:
            # NOTE(review): with move_number == 0 the drawing is only
            # 11.5*scale high while this text sits at gy(11) == 12*scale,
            # so it looks like it falls outside the canvas -- confirm.
            self.add(dw.Text('先後反転', font_size=scale*.7,
                             x=gx(2)+scale*.05, y=gy(11), **sub_property))

    def hand_pieces_str(self, player: miniosl.Player) -> str:
        """kanji string of the pieces in hand of `player`"""
        return hand_pieces_to_ja(self.state, player)

    def show_hand(self):
        """write the hands of both players beside the board

        Each side gets a vertical label followed by up to two columns of
        seven pieces; `hand_others` is added only when more than fourteen
        pieces are held.  An empty hand is shown as `no_pieces_in_hand`.
        The texts for white are rotated by 180 degrees.
        """
        # i: [0,4] label+space
        hand_b = self.hand_pieces_str(miniosl.black)
        hand_w = self.hand_pieces_str(miniosl.white)
        prop = sub_property.copy()
        prop['font_size'] = scale*.75
        for i, (char_b, char_w) in enumerate(zip(label_hand_b, label_hand_w)):
            self.add(dw.Text(char_b, x=gx(-2)+scale*.05, y=gy(i*0.8),
                             **prop))
            self.add(dw.Text(char_w, x=gx(11)+scale*.5, y=gy(8.4-i*0.8),
                             rotate=180, **prop))
        if not hand_b:
            for i, c in enumerate(no_pieces_in_hand):
                self.add(dw.Text(c, x=gx(-2)+scale*.05, y=gy((i+5)*0.8),
                                 **prop))
        if not hand_w:
            for i, c in enumerate(no_pieces_in_hand):
                self.add(dw.Text(c, x=gx(11)+scale*.5, y=gy(8.4-(i+5)*0.8),
                                 rotate=180, **prop))

        prop = char_property.copy()
        prop['font_size'] = scale*.75
        # black: first column, second column, then the overflow marker
        for i, c in enumerate(hand_b[:7]):
            self.add(dw.Text(c, x=gx(-2)+scale*.05, y=gy((i+5)*0.8),
                             **prop))
        for i, c in enumerate(hand_b[7:14], start=7):
            self.add(dw.Text(c, x=gx(-1.2)+scale*.05, y=gy((i-2)*0.8),
                             **prop))
        # exactly 14 pieces fit into the two columns, so nothing is omitted
        if len(hand_b) > 14:
            self.add(dw.Text(hand_others, x=gx(-0.35), y=gy(5*0.8), **prop))
        # white: same layout on the other side of the board
        for i, c in enumerate(hand_w[:7]):
            self.add(dw.Text(c, x=gx(11)+scale*.5, y=gy(8.4-(i+5)*0.8),
                             rotate=180, **prop))
        for i, c in enumerate(hand_w[7:14], start=7):
            self.add(dw.Text(c, x=gx(10.2)+scale*.5, y=gy(8.4-(i-2)*0.8),
                             rotate=180, **prop))
        if len(hand_w) > 14:
            self.add(dw.Text(hand_others, x=gx(9), y=gy(8.4-5*0.8),
                             rotate=180, **prop))


def state_to_svg(state: miniosl.BaseState, id: int = 0, *,
                 decorate: bool = False,
                 plane: np.ndarray | None = None,
                 plane_color: str = 'orange',
                 last_move_ja: str = '',
                 last_to: miniosl.Square | None = None,
                 move_number: int = 0,
                 repeat_distance: int = 0,
                 repeat_count: int = 0,
                 flip_if_white: bool = False,
                 ) -> dw.drawing.Drawing:
    """make a picture of state as svg

    :param state: state,
    :param id: id for svg if given,
    :param decorate: highlight king location and piece covers for each color,
    :param plane: 9x9 numpy array to make a mark on squares,
    :param plane_color: color of the marks made for `plane`,
    :param last_move_ja: last move in japanese,
    :param last_to: the destination square of the last move,
    :param move_number: ply in a game record,
    :param repeat_distance: distance to the latest same position,
    :param repeat_count: number of the occurrence of this state,
    :param flip_if_white: `rotate180()` if white to move
    """
    flipped = flip_if_white and state.turn() == miniosl.white
    if flipped:
        # show the position from the other side; the highlighted
        # destination square has to be rotated along with it
        state = miniosl.State(state.rotate180())
        if last_to:
            last_to = last_to.rotate180()
    if decorate and not isinstance(state, miniosl.State):
        logging.warning('promote BaseState to State for decoration')
        state = miniosl.State(state)

    id_prefix = str(id)
    if id_prefix == '':
        # no usable id: derive one from the position and its occurrence
        board, stand = state.hash_code()
        id_prefix = f'{board}-{stand}-{move_number}-{repeat_count}'

    canvas = ShogiSVG(state, last_move_ja=last_move_ja,
                      move_number=move_number, id_prefix=id_prefix,
                      repeat_distance=repeat_distance,
                      repeat_count=repeat_count)
    # Highlights go first so that pieces and grid are drawn over them.
    if decorate:
        canvas.decorate()
    if last_to:
        to_x, to_y = last_to.to_xy()
        canvas.add(dw.Rectangle(gx(to_x), gy(to_y-1), scale, scale,
                                **last_move_property))
    for col in range(1, 10):
        for row in range(1, 10):
            canvas.draw_piece(col, row)
    if plane is not None:
        canvas.draw_plane(plane, plane_color, 2)
    canvas.draw_grid()
    canvas.show_side_to_move(flipped)
    canvas.show_hand()
    return canvas.d
def to_png_bytes(png: dw.raster.Raster) -> bytes:
    """return the raw png data of a rasterized drawing

    The data is taken from `png.as_data_uri()`: the leading
    `data:image/png;base64,` is cut off and the rest is base64-decoded.
    """
    header = 'data:image/png;base64,'
    # named `payload`, not `bytes`, so that the builtin is not shadowed
    payload = png.as_data_uri()[len(header):]
    return base64.b64decode(payload.encode('utf-8'))


def save_png(png: dw.raster.Raster, filename: str) -> None:
    """write the png data of a rasterized drawing to `filename`"""
    with open(filename, "wb") as f:
        f.write(to_png_bytes(png))
def state_to_png(*args, filename='', **kwargs) -> dw.raster.Raster | None:
    """make a picture of state as png, save to file if given `filename`

    parameters are the same as :py:func:`state_to_svg` except for `filename`

    :param filename: path of the png file to write; nothing is written
      when empty
    :return: the rasterized picture if `filename` is empty, otherwise
      `None` (the picture is written to the file instead)
    """
    png = state_to_svg(*args, **kwargs).rasterize()
    if not filename:
        return png
    # same function as `miniosl.drawing.save_png`, reached directly
    return save_png(png, filename)
def show_channels(channels, nrows, ncols, flip=False, *, japanese=False):
    """show channels as a `nrows` x `ncols` grid of images with matplotlib

    :param channels: indexable collection of images; `channels[i]` is
      drawn in the i-th cell of the grid
      (presumably 9x9 arrays with values in [0, 1] -- TODO confirm),
    :param nrows: number of rows of the grid,
    :param ncols: number of columns of the grid,
    :param flip: label and orient the axes the other way around,
    :param japanese: label the ranks with kanji instead of digits
    :return: the value of `matplotlib.pyplot.show()`
    """
    # imported here so that matplotlib is needed only by this function
    import matplotlib.pyplot as plt
    from mpl_toolkits.axes_grid1 import ImageGrid

    figure = plt.figure(figsize=(ncols*2.5, nrows*2))
    cells = ImageGrid(fig=figure, rect=111, nrows_ncols=(nrows, ncols),
                      axes_pad=0.3, label_mode='all')
    if japanese:
        dan = kanjirow[1:]
    else:
        dan = np.arange(1, 10)
    for index, axis in enumerate(cells):
        if flip:
            file_labels = np.arange(9, 0, -1)
            rank_labels = reversed(dan)
        else:
            file_labels = np.arange(1, 10)
            rank_labels = dan
        axis.set_xticks(np.arange(9), file_labels)
        axis.set_yticks(np.arange(9), rank_labels)
        # labels go to the top and to the right of each image
        axis.xaxis.tick_top()
        axis.yaxis.tick_right()
        axis.imshow(channels[index], cmap='Oranges', vmin=0, vmax=1,
                    interpolation='none')
        if flip:
            axis.invert_yaxis()
        else:
            axis.invert_xaxis()
        axis.tick_params(axis='both', length=0)
    return plt.show()