"""inference modules"""
import miniosl
import torch
import numpy as np
import logging
import abc
from typing import Tuple
feature_channels = len(miniosl.channel_id)
def p2elo(p, eps=1e-4):
    """convert a win probability to an Elo-like rating difference"""
    return -400 * np.log10(1 / (abs(p) + eps) - 1)
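

# For example, p2elo(0.75) is about +191: the function inverts the Elo win
# expectancy 1 / (1 + 10**(-d/400)) = p for the rating difference d, with
# eps guarding the division at p = 0.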
def sort_moves(moves, policy):
    """sort moves by policy probability, best first"""
    flatten = policy.flatten()
    prob = [(flatten[move.policy_move_label], move) for move in moves]
    prob.sort(key=lambda e: -e[0])
    return prob
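

# Usage sketch (`legal_moves` and `policy` are placeholders for a move list
# and the policy array returned by InferenceModel.eval below):
#
#   ranked = sort_moves(legal_moves, policy)
#   top_prob, top_move = ranked[0]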
def softmax(x):
    """numerically stable softmax"""
    b = np.max(x)  # subtract the max to avoid overflow in exp
    y = np.exp(x - b)
    return y / y.sum()
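

# For instance, softmax(np.array([0.0, np.log(2.0)])) gives [1/3, 2/3];
# subtracting the maximum beforehand cancels in the ratio, so it changes
# nothing but keeps np.exp from overflowing on large logits.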
class InferenceModel(abc.ABC):
"""interface for inference using trained models"""
def __init__(self, device):
super().__init__()
self.device = device
    @abc.abstractmethod
def infer(self, inputs: torch.Tensor) -> Tuple[np.ndarray, np.ndarray,
np.ndarray]:
"""return inference results for batch"""
pass
    def infer_int8(self, inputs: np.ndarray | torch.Tensor
                   ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """optimized path for int8 features scaled by miniosl.One"""
        if isinstance(inputs, np.ndarray):
            if inputs.dtype != np.int8:
                raise ValueError(f'expected int8, got {inputs.dtype}')
            inputs = torch.from_numpy(inputs)
        if inputs.dtype != torch.int8:
            raise ValueError(f'expected int8, got {inputs.dtype}')
        inputs = inputs.to(self.device).float()
        inputs /= miniosl.One
        return self.infer(inputs)
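
    # The int8 path keeps features in int8 (a quarter of float32's bytes)
    # until they reach the device, where the cast to float and the
    # 1/miniosl.One rescaling happen, minimizing host-to-device traffic.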
    def infer_one(self, input: np.ndarray) -> Tuple[np.ndarray, float,
                                                    np.ndarray]:
        """return inference results for a single instance"""
        outputs = self.infer(torch.stack((torch.from_numpy(input),)))
        return tuple(_[0] for _ in outputs)
    def eval(self, input: np.ndarray, *, take_softmax: bool = False
             ) -> Tuple[np.ndarray, float, np.ndarray]:
        """return (move, value, aux) tuple; softmax is applied to the
        move logits only when take_softmax is set"""
        out = self.infer_one(input)
        moves = softmax(out[0]).reshape(-1, 9, 9) \
            if take_softmax else out[0].reshape(-1, 9, 9)
        value = out[1]
        aux = out[2].reshape(-1, 9, 9) if len(out) >= 3 else None
        return (moves, value, aux)
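
    # Usage sketch (`features` stands for a float feature array of shape
    # (feature_channels, 9, 9); any InferenceModel subclass works):
    #
    #   moves, value, aux = model.eval(features, take_softmax=True)
    #   # moves: policy planes reshaped to (-1, 9, 9); value: value-head
    #   # output; aux: auxiliary planes, or None without an aux head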
class OnnxInfer(InferenceModel):
    first_load = True  # log available providers only for the first session
    def __init__(self, path: str, device: str):
        import re
        import onnxruntime as ort
        super().__init__(device)
        if device == '':
            device = 'cpu'
        if device == 'cpu':
            provider = ['CPUExecutionProvider']
        elif device.startswith('cuda'):
            cuda_pattern = r'^cuda:([0-9]+)$'
            if match := re.match(cuda_pattern, device):
                provider = [('CUDAExecutionProvider',
                             {'device_id': int(match.group(1))}),
                            'CPUExecutionProvider']
            else:
                provider = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        else:
            provider = ort.get_available_providers()
        self.ort_session = ort.InferenceSession(path, providers=provider)
        if OnnxInfer.first_load:
            logging.info(self.ort_session.get_providers())
            OnnxInfer.first_load = False
        self.binding = self.ort_session.io_binding()
        # io binding takes the device type and index separately, so split
        # 'cuda:1' into ('cuda', 1); a bare 'cpu' or 'cuda' gets index 0
        dev_type, _, dev_id = device.partition(':')
        self.binding_device = dev_type
        self.binding_device_id = int(dev_id) if dev_id else 0
        logging.debug(f'{device=}')
        self.device = device
    def infer(self, inputs: torch.Tensor):
return self.infer_iobinding(inputs.to(self.device))
# return self.infer_naive(inputs)
    def infer_naive(self, inputs: torch.Tensor):
        """simple path; wasteful gpu-cpu transfer if inputs are already on gpu"""
out = self.ort_session.run(None, {"input": inputs.to('cpu').numpy()})
return out
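
    # Design note: io binding lets the session write outputs straight into
    # pre-allocated torch tensors on the target device, so the only host
    # transfer left is the final .to('cpu') of the three result tensors.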
    def infer_iobinding(self, inputs: torch.Tensor):
        """works with torch 2.5.1, onnxruntime-gpu 1.20.1"""
        inputs = inputs.contiguous()
        device = self.device
        dev_type, dev_id = self.binding_device, self.binding_device_id
        self.binding.bind_input(
            name='input',
            device_type=dev_type, device_id=dev_id, element_type=np.float32,
            shape=tuple(inputs.shape), buffer_ptr=inputs.data_ptr(),
        )
        # policy head: 2187 logits = 27 planes x 81 squares
        move_tensor = torch.empty((inputs.shape[0], 2187),
                                  dtype=torch.float32,
                                  device=device).contiguous()
        self.binding.bind_output(
            name='move',
            device_type=dev_type, device_id=dev_id, element_type=np.float32,
            shape=tuple(move_tensor.shape),
            buffer_ptr=move_tensor.data_ptr(),
        )
# binding.bind_output('move')
        # value head: 4 outputs per instance
        value_tensor = torch.empty((inputs.shape[0], 4),
                                   dtype=torch.float32,
                                   device=device).contiguous()
        self.binding.bind_output(
            name='value',
            device_type=dev_type, device_id=dev_id, element_type=np.float32,
            shape=tuple(value_tensor.shape),
            buffer_ptr=value_tensor.data_ptr(),
        )
# binding.bind_output('value')
        # aux head: 22 planes of 9x9 board predictions
        aux_tensor = torch.empty((inputs.shape[0], 9*9*22),
                                 dtype=torch.float32,
                                 device=device).contiguous()
        self.binding.bind_output(
            name='aux',
            device_type=dev_type, device_id=dev_id, element_type=np.float32,
            shape=tuple(aux_tensor.shape),
            buffer_ptr=aux_tensor.data_ptr(),
        )
# binding.bind_output('aux')
self.ort_session.run_with_iobinding(self.binding)
return (move_tensor.to('cpu').numpy(),
value_tensor.to('cpu').numpy(),
aux_tensor.to('cpu').numpy())
# out = binding.copy_outputs_to_cpu()
class TorchTRTInfer(InferenceModel):
    def __init__(self, path: str, device: str):
        # torch_tensorrt must be imported so that torch.jit.load can
        # deserialize the TensorRT-compiled module
        import torch_tensorrt
super().__init__(device)
with torch_tensorrt.logging.info():
self.trt_module = torch.jit.load(path)
self.device = device
    def infer(self, inputs: torch.Tensor):
with torch.cuda.device(torch.device(self.device)):
tensor = inputs.half().to(self.device)
outputs = self.trt_module(tensor)
ret = [_.to('cpu').numpy() for _ in outputs]
return ret
class TorchScriptInfer(InferenceModel):
def __init__(self, path: str, device: str):
super().__init__(device)
self.ts_module = torch.jit.load(path)
    def infer(self, inputs: torch.Tensor):
with torch.no_grad():
tensor = inputs.to(self.device)
outputs = self.ts_module(tensor)
return [_.to('cpu').numpy() for _ in outputs]
class TorchInfer(InferenceModel):
def __init__(self, model, device: str):
super().__init__(device)
self.model = model
self.model.eval()
self.device = device
    def infer(self, inputs: torch.Tensor):
with torch.no_grad():
tensor = inputs.to(self.device)
outputs = self.model(tensor)
return [_.to('cpu').numpy() for _ in outputs]
def load(path: str, device: str = "", torch_cfg: dict = {},
         *,
         compiled: bool = False,
         strict: bool = True,
         remove_aux_head: bool = False) -> InferenceModel:
    """factory method to load a model from file

    :param path: filepath,
    :param device: torch device such as 'cuda', 'cpu',
    :param torch_cfg: network specification needed for TorchInfer,
    :param compiled: set if the saved state was produced by a compiled model,
    :param strict: require an exact key match when loading a state dict,
    :param remove_aux_head: load only the policy-value part of the network
      (supported for '.pt' and '.chpt' files).
    """
    if remove_aux_head:
        if not (path.endswith('.pt') or path.endswith('.chpt')):
            raise ValueError('remove_aux_head is supported only for'
                             ' .pt and .chpt files')
if path.endswith('.onnx'):
return OnnxInfer(path, device)
if path.endswith('.ts'):
if not device:
device = 'cuda'
return TorchTRTInfer(path, device)
    if path.endswith('.pts'):  # needs an extension distinct from TRT's .ts
return TorchScriptInfer(path, device)
    if path.endswith('.pt'):
        NN = miniosl.network.PVNetwork if remove_aux_head \
            else miniosl.network.StandardNetwork
        raw_model = NN(**torch_cfg).to(device)
        # the compiled wrapper shares parameters with raw_model, so loading
        # a state dict saved from a compiled model populates raw_model, too
        model = torch.compile(raw_model) if compiled else raw_model
        saved_state = torch.load(path, map_location=torch.device(device))
        strict = strict and not remove_aux_head
        model.load_state_dict(saved_state, strict=strict)
        return TorchInfer(raw_model, device)
if path.endswith('.ptd'):
model = miniosl.StandardNetwork.load_with_dict(path)
model = model.to(device)
return TorchInfer(model, device)
    if path.endswith('.chpt'):
        checkpoint = torch.load(path, map_location=torch.device(device))
        cfg = checkpoint['cfg']
        network_cfg = cfg['network_cfg']
        # drop config keys that older checkpoints may still carry
        for obsolete_key in ['make_bottleneck']:
            if obsolete_key in network_cfg:
                del network_cfg[obsolete_key]
        NN = miniosl.network.PVNetwork if remove_aux_head \
            else miniosl.network.StandardNetwork
        raw_model = NN(**network_cfg).to(device)
        model = torch.compile(raw_model) if compiled else raw_model
        strict = strict and not remove_aux_head
        model.load_state_dict(checkpoint['model_state_dict'], strict=strict)
        return TorchInfer(raw_model, device)
raise ValueError("unknown filetype")
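

# Usage sketch ('model.onnx' and 'model.chpt' are placeholder file names):
#
#   infer = load('model.onnx', device='cuda:0')
#   infer_pv = load('model.chpt', device='cpu', remove_aux_head=True)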
class InferenceForGameArray(miniosl.InferenceModelStub):
    def __init__(self, module: InferenceModel):
        super().__init__()
        self.module = module

    def py_infer(self, features):
        # features arrive as a flat int8 array; restore the
        # (batch, channel, 9, 9) layout expected by the int8 path
        features = features.reshape(-1, len(miniosl.channel_id), 9, 9)
        return self.module.infer_int8(features)
def export_onnx(model, *, device, filename, remove_aux_head=False):
    import onnx  # imported only to detect a missing dependency earlier
    _ = onnx.__version__  # referenced so the import is not flagged as unused
    import torch.onnx
    model.eval()
dtype = torch.float
dummy_input = torch.randn(1024, feature_channels, 9, 9, device=device,
dtype=dtype)
if not filename.endswith('.onnx'):
filename = f'{filename}.onnx'
    output_names = ['move', 'value'] if remove_aux_head \
        else ['move', 'value', 'aux']
    dynamic_axes = {name: {0: 'batch_size'}
                    for name in ['input'] + output_names}
    torch.onnx.export(model, dummy_input, filename,
                      dynamic_axes=dynamic_axes,
                      verbose=False, input_names=['input'],
                      output_names=output_names)
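

# Verification sketch (assumes onnxruntime is installed; not part of the
# export itself): run the exported graph once to confirm the dynamic batch
# axis accepts a size other than the dummy input's 1024.
#
#   import onnxruntime as ort
#   sess = ort.InferenceSession('model.onnx',
#                               providers=['CPUExecutionProvider'])
#   batch = np.zeros((2, feature_channels, 9, 9), dtype=np.float32)
#   move, value, aux = sess.run(None, {'input': batch})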
def export_tensorrt(model, *, device, filename, quiet=False):
    logging.debug(f'feature channels {feature_channels}')
import torch_tensorrt
if quiet:
torch_tensorrt.logging.set_reportable_log_level(
torch_tensorrt.logging.Level.Error
)
torch_tensorrt.ts.logging.set_is_colored_output_on(True)
if not device:
device = 'cuda'
elif not device.startswith('cuda'):
raise ValueError(f'unexpected device for trt {device}')
model.eval()
model = model.half()
inputs = [
torch_tensorrt.Input(
min_shape=[1, feature_channels, 9, 9],
opt_shape=[128, feature_channels, 9, 9],
max_shape=[2048, feature_channels, 9, 9],
dtype=torch.half,
)]
enabled_precisions = {torch.half}
with torch.cuda.device(torch.device(device)):
trt_ts_module = torch_tensorrt.compile(
torch.jit.script(model),
inputs=inputs, enabled_precisions=enabled_precisions,
ir='ts',
# device=torch.device(device)
)
    input_data = torch.randn(16, feature_channels, 9, 9, device=device)
    _ = trt_ts_module(input_data.half())  # warm-up/sanity run before saving
savefile = filename if filename.endswith('.ts') else f'{filename}.ts'
torch.jit.save(trt_ts_module, savefile)
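
# Design note on the Input ranges above: TensorRT builds one engine covering
# batch sizes from min_shape to max_shape and tunes kernels for opt_shape,
# so any batch in [1, 2048] runs without a rebuild, optimized around 128.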
def export_torch_script(model, *, device, filename):
    model.eval()
    if device.startswith('cuda'):
        # tracing needs sample inputs resident on the target device
        inputs = torch.rand(8, feature_channels, 9, 9).to(device)
        ts_module = torch.jit.trace(model, inputs)
    else:
        ts_module = torch.jit.optimize_for_inference(torch.jit.script(model))
if not filename.endswith('.pts'):
filename = f'{filename}.pts'
torch.jit.save(ts_module, filename)
def export_model(model: miniosl.PVNetwork, *, device, filename, quiet=False,
remove_aux_head=False):
if filename.endswith('.onnx'):
export_onnx(model, device=device, filename=filename,
remove_aux_head=remove_aux_head)
elif filename.endswith('.ts'):
export_tensorrt(model, device=device, filename=filename, quiet=quiet)
elif filename.endswith('.pts'):
export_torch_script(model, device=device, filename=filename)
elif filename.endswith('.ptd'):
model.save_with_dict(filename)
else:
raise ValueError("unknown filetype")