Source code for miniosl.dataset

"""dataset for training in pytorch"""
import miniosl
import torch
import numpy as np
import logging
import os
import os.path
import recordclass


def load_sfen_from_file(sfen, *,
                        compress_and_rm: bool = False, strict: bool = False
                        ) -> list[miniosl.SubRecord]:
    """load game records from file for `GameRecordBlock`

    games must be completed to serve as training data.

    If `sfen` itself ends with `.npz`, or a file named `{sfen}.npz` exists,
    the records are read from that npz archive and the text file is not
    touched.  Otherwise `sfen` is read as a text file holding one usi record
    per line.

    :param sfen: path of the text file or of the `.npz` archive
    :param compress_and_rm: after reading the text file, save the records \
    as `{sfen}.npz` and remove the text file
    :param strict: raise `ValueError` on a record that is still in game, \
    instead of skipping it
    :return: list of `miniosl.SubRecord`, one for each accepted record
    """
    # prefer the compressed archive when one is available
    sfen_npz = sfen if sfen.endswith('.npz') else f'{sfen}.npz'
    if os.path.isfile(sfen_npz):
        # from_npz takes the path itself; no need to open the file here
        record_set = miniosl.RecordSet.from_npz(sfen_npz)
        return [miniosl.SubRecord(_) for _ in record_set.records]

    data = miniosl.MiniRecordVector()
    ignored = 0
    with open(sfen, 'r') as f:
        for line in f:
            line = line.strip()
            record = miniosl.usi_record(line)
            if record.result == miniosl.InGame:
                if strict:
                    logging.warning(f'in game {line}')
                    raise ValueError('in game')
                ignored += 1
                continue
            data.append(record)
    if ignored:
        logging.info(f'ignored {ignored} in-game records in {sfen}')
    # the text file is closed at this point, so that removing it is safe
    # even on platforms refusing to delete a file in use
    if compress_and_rm:
        record_set = miniosl.RecordSet(data)
        record_set.save_npz(f'{sfen}.npz')
        os.remove(sfen)
    return [miniosl.SubRecord(_) for _ in data]


class GameDataset(torch.utils.data.Dataset):
    """dataset for training.

    - sample position by game index
      (a random position in the game record is returned)
    - keep the latest `GameRecordBlock` s (e.g., 50) as in MuZero

    :param window_size: number of game records to keep
    :param block_unit: number of game records for a block,
      that is a unit in add/replace operation
    :param batch_with_collate: need to specify
      `collate_fn=lambda indices: dataset.collate(indices)`
      for trainloader, if (and only if) True
    """
    def __init__(self, window_size: int, block_unit: int,
                 batch_with_collate: bool = True):
        self.window_size = window_size
        self.block_unit = block_unit
        self.blocks = miniosl.GameBlockVector()
        self.cur_block_id = 0
        # number of blocks covering window_size records, rounded up
        self.block_limit = (window_size + block_unit - 1) // block_unit
        self.batch_with_collate = batch_with_collate
        self.blocks.reserve(self.block_limit)

    def block_id(self) -> int:
        """number of block added so far"""
        return self.cur_block_id

    def unit_size(self) -> int:
        """number of records in a `GameRecordBlock`"""
        return self.block_unit

    @staticmethod
    def make_block(sfen, compress_and_rm=True, strict=True):
        """a helper function to make a block

        Declared as a static method so that it works both as
        `GameDataset.make_block(...)` and on an instance.

        :param sfen: path of a record file (`str`), loaded by \
        `load_sfen_from_file`, or a `list` of records passed as is
        :param compress_and_rm: forwarded to `load_sfen_from_file`
        :param strict: forwarded to `load_sfen_from_file`
        :raises ValueError: if `sfen` is neither `str` nor `list`
        """
        if isinstance(sfen, str):
            lst = load_sfen_from_file(sfen, compress_and_rm=compress_and_rm,
                                      strict=strict)
            return miniosl.GameRecordBlock(lst)
        if isinstance(sfen, list):
            return miniosl.GameRecordBlock(sfen)
        raise ValueError(f'not supported {type(sfen)}')

    def add(self, new_block):
        """add or replace the oldest one with `new_block`

        :raises ValueError: if `new_block` has fewer records than \
        `unit_size()`
        """
        if len(new_block) < self.unit_size():
            raise ValueError(f'size error {len(new_block)} {self.block_unit}')
        if len(self.blocks) < self.block_limit:
            self.blocks.append(new_block)
        else:
            # slots are reused in a round-robin manner once filled
            self.blocks[self.cur_block_id % self.block_limit] = new_block
        self.cur_block_id += 1

    def stored_game_records(self):
        """number of game records addressable by the blocks stored now"""
        return self.block_unit * len(self.blocks)

    def sample_id(self, index):
        """convert a dataset index into (block index, record index in block)

        :raises ValueError: if `index` is outside of `[0, len(self))`
        """
        if not (0 <= index < len(self)):
            raise ValueError("index")
        # need this due to the spec of __len__(), even after fully filled
        index %= self.stored_game_records()
        return index // self.block_unit, index % self.block_unit

    def __len__(self):
        """epoch should go beyond #records to cover #positions"""
        return self.stored_game_records() * 100

    def __getitem__(self, idx):
        """return the pair of ids for `collate` if `batch_with_collate`,
        or a sampled and reshaped item otherwise"""
        p, s = self.sample_id(idx)
        if self.batch_with_collate:
            return (p, s)
        return self.reshape_item(self.blocks[p].sample(s))

    def reshape_item(self, item):
        """turn a sampled 4-tuple (input, move, value, aux) into torch-ready
        values; input and aux become tensors, value becomes `np.float32`"""
        features, move_label, value_label, aux_label = item
        return (torch.from_numpy(features), move_label,
                np.float32(value_label), torch.from_numpy(aux_label))

    def collate(self, indices):
        """make a batch for `indices`, a list of ids given by `__getitem__`

        Buffers are filled by `miniosl.collate_features`.

        :return: tuple of tensors (inputs, policy_labels, value_labels, \
        aux_labels, inputs2, legalmove_labels), where inputs and inputs2 \
        are reshaped to `(N, -1, 9, 9)`, and aux_labels and \
        legalmove_labels to `(N, -1)`
        """
        N = len(indices)
        inputs = np.zeros(N*miniosl.input_unit, dtype=np.int8)
        inputs2 = np.zeros(N*miniosl.input_unit, dtype=np.int8)
        policy_labels = np.zeros(N, dtype=np.int32)
        value_labels = np.zeros(N, dtype=np.float32)
        aux_labels = np.zeros(N*miniosl.aux_unit, dtype=np.int8)
        legalmove_labels = np.zeros(N * miniosl.legalmove_bs_sz,
                                    dtype=np.uint8)

        miniosl.collate_features(self.blocks, indices,
                                 inputs, policy_labels, value_labels,
                                 aux_labels, inputs2, legalmove_labels)
        return (torch.from_numpy(inputs.reshape(N, -1, 9, 9)),
                torch.from_numpy(policy_labels),
                torch.from_numpy(value_labels),
                torch.from_numpy(aux_labels.reshape(N, -1)),
                torch.from_numpy(inputs2.reshape(N, -1, 9, 9)),
                torch.from_numpy(legalmove_labels.reshape(N, -1)),
                )
def load_torch_dataset(path: str | list[str]) -> torch.utils.data.Dataset:
    """load dataset from file

    :param path: a file name ending with `.sfen`, `.txt` or `sfen.npz`, \
    or a list of items accepted by `GameDataset.make_block`
    :return: `GameDataset` holding one block for each item in `path`
    :raises ValueError: if `path` is an unsupported name or an empty list
    """
    suffixes = ('.sfen', '.txt', 'sfen.npz')
    if not (isinstance(path, list) or path.endswith(suffixes)):
        raise ValueError(f'unsupported data {path}')
    paths = path if isinstance(path, list) else [path]
    if not paths:
        # otherwise min() below fails with a message unrelated to the cause
        raise ValueError('empty list of data files')

    import tqdm
    blocks = []
    with tqdm.tqdm(total=len(paths), disable=(len(paths) < 10)) as pbar:
        pbar.set_description('loading blocks')
        for item in paths:
            blocks.append(miniosl.GameDataset.make_block(
                item, compress_and_rm=False, strict=False))
            pbar.update(1)
    sizes = [len(blk) for blk in blocks]
    # the smallest block defines the unit, so that every block passes
    # the size check in GameDataset.add()
    unit = min(sizes)
    logging.info(f'load {sum(sizes)} data as {unit} x {len(sizes)}')
    dataset = GameDataset(unit * len(sizes), unit, batch_with_collate=True)
    for blk in blocks:
        dataset.add(blk)
    return dataset


# input features (obs, obs_after) and labels (others)
BatchInput = recordclass.recordclass(
    'BatchInput', (
        'obs', 'move', 'value', 'aux',
        'obs_after', 'legal_move'))
# fields of this record: move, value and aux
BatchOutput = recordclass.recordclass(
    'BatchOutput', ('move', 'value', 'aux'))


def preprocess_batch(batch: torch.Tensor):
    """unpack data to feed them into neuralnetworks

    batch should be located on gpus in advance for efficient data transfer

    `batch` is unpacked positionally into the six fields of `BatchInput`.
    `obs`, `aux` and `obs_after` are converted to float and divided by
    `miniosl.One`; `move` is converted to long; `value` and `legal_move`
    are left as they are.

    :return: `BatchInput`
    """
    inputs = BatchInput(*batch)
    inputs.obs = inputs.obs.float()
    inputs.aux = inputs.aux.float()
    inputs.obs_after = inputs.obs_after.float()
    inputs.obs /= miniosl.One
    inputs.move = inputs.move.long()
    inputs.aux /= miniosl.One
    inputs.obs_after /= miniosl.One
    return inputs


# fields of this record: td1, td2 and soft
TargetValue = recordclass.recordclass(
    'TargetValue', ('td1', 'td2', 'soft'))


def process_target_batch(batch_output: tuple[torch.Tensor]):
    """make `TargetValue` from the second element of `batch_output`

    Columns 0, 1 and 2 of that element are converted to float, negated,
    flattened and detached.
    """
    _, succ_value, *_ = batch_output
    succ_value = succ_value.float()
    ret = TargetValue(
        -succ_value[:, 0].flatten().detach(),  # tdtarget
        -succ_value[:, 1].flatten().detach(),  # tdtarget2
        -succ_value[:, 2].flatten().detach(),  # softtarget
    )
    return ret