"""neural networks in pytorch"""
from __future__ import annotations
import torch
from torch import nn
from typing import Tuple
class ResBlock(nn.Module):
    """Residual wrapper returning `block(x) + x`."""
def __init__(self, block: nn.Module):
super().__init__()
self.block = block
def forward(self, data: torch.Tensor) -> torch.Tensor:
return self.block(data) + data
class ResBlockAlt(nn.Module):
    """Residual wrapper applying SiLU after the sum: `silu(block(x) + x)`."""
def __init__(self, block: nn.Module):
super().__init__()
self.block = block
def forward(self, data: torch.Tensor) -> torch.Tensor:
return nn.functional.silu(self.block(data) + data)
class Conv2d(nn.Module):
"""a variant of `nn.Conv2d` separiting bias terms as BatchNorm2d.
The parameters follow the pseudocode in the Gumbel MuZero's paper,
except for BN; (1) scaling is fixed in the original but
learnable here due to the lack of api in pytorch, and (2) momentum
is relaxed than the original of 0.001 to speed up learning.
"""
    def __init__(self, in_channels: int, out_channels: int,
                 kernel_size: int | Tuple[int, int],
                 *, padding: int = 0, stride: int = 1, groups: int = 1):
super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size,
padding=padding, bias=False,
stride=stride, groups=groups)
self.bn = nn.BatchNorm2d(out_channels, eps=1e-3)
def forward(self, data: torch.Tensor) -> torch.Tensor:
return self.bn(self.conv(data))
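# A minimal usage sketch (not part of the original code): checks that the
# wrapper keeps the spatial size with padding=1 and that the bias term lives
# in the BatchNorm2d rather than in the convolution itself. The 9x9 input
# size is an assumption matching the board size used elsewhere in this file.
def _demo_conv2d() -> None:
    conv = Conv2d(4, 16, 3, padding=1)
    x = torch.zeros(2, 4, 9, 9)
    assert conv(x).shape == (2, 16, 9, 9)
    assert conv.conv.bias is None  # the shift comes from conv.bn instead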
class ConvTranspose2d(nn.Module):
    """A transposed-convolution counterpart of `Conv2d` above."""
    def __init__(self, in_channels: int, out_channels: int, kernel_size: int,
                 *, padding: int = 0, stride: int = 1, groups: int = 1):
super().__init__()
self.conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size,
padding=padding, bias=False,
stride=stride, groups=groups)
self.bn = nn.BatchNorm2d(out_channels, eps=1e-3)
def forward(self, data: torch.Tensor) -> torch.Tensor:
return self.bn(self.conv(data))
class PolicyHead(nn.Module):
    """Policy head: 1x1 conv -> ReLU -> 1x1 conv, flattened."""
def __init__(self, *, channels: int, out_channels: int):
super().__init__()
self.head = nn.Sequential(
Conv2d(channels, channels, 1),
nn.ReLU(),
Conv2d(channels, out_channels, 1),
nn.Flatten()
)
def forward(self, data: torch.Tensor) -> torch.Tensor:
return self.head(data)
class ValueHead(nn.Module):
    """Value head: 1x1 conv then an MLP, with tanh outputs in [-1, 1]."""
    def __init__(self, *, channels: int, hidden_layer_size: int):
super().__init__()
self.head = nn.Sequential(
Conv2d(channels, 1, 1),
nn.ReLU(),
nn.Flatten(),
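            # 81 = 9 * 9 board squares after the single-channel 1x1 conv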
nn.Linear(81, hidden_layer_size),
nn.ReLU(),
nn.Linear(hidden_layer_size, 4),
nn.Tanh(),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.head(x)
class PoolBias(nn.Module):
"""an interpretation of Fig. 12 in the Gumbel MuZero paper"""
def __init__(self, *, channels: int):
super().__init__()
self.channels = channels
self.conv1x1a = Conv2d(channels, channels, 1)
self.conv1x1b = Conv2d(channels, channels, 1)
self.conv1x1out = Conv2d(channels, channels, 1)
self.linear = nn.Linear(2*channels, channels)
def forward(self, x: torch.Tensor) -> torch.Tensor:
a = nn.functional.relu(self.conv1x1a(x))
b = nn.functional.relu(self.conv1x1b(x))
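        # pool over the whole board: kernel_size=9 assumes a 9x9 input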
bmax = nn.functional.max_pool2d(b, kernel_size=9)
bmean = nn.functional.avg_pool2d(b, kernel_size=9)
b = torch.cat((bmax, bmean),
dim=1).squeeze(-1).squeeze(-1)
b = self.linear(b)
c = a + b[:, :, None, None]
return self.conv1x1out(c)
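# A shape sketch for PoolBias (assumes a 9x9 board, matching the hardcoded
# kernel_size=9 above): the max- and average-pooled features collapse to
# 1x1, pass through the linear layer, and are broadcast back over the whole
# board before the output convolution.
def _demo_pool_bias() -> None:
    pb = PoolBias(channels=16)
    x = torch.zeros(2, 16, 9, 9)
    assert pb(x).shape == (2, 16, 9, 9)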
class KatagoBlock(nn.Module):
"""Nested bottleneck residual architecture described in KataGo's doc
with an additional broadcasting path.
This block places four (instead of two) 3x3 cnn layers inside a
bottleneck part for efficiency. The additional path provides a
global information by broadcasting at the end of the bottleneck
skipping four 3x3 cnn layers. A file oriented convolution gives a
relatively good balance between the cost and accuracy. Note that
the kernel shape of 9x1 is not a popular one but already tried
in shogi though in a bit different context.
https://github.com/lightvector/KataGo/blob/master/docs/KataGoMethods.md#nested-bottleneck-residual-nets
https://raw.githubusercontent.com/lightvector/KataGo/master/images/docs/bottlenecknestedresblock.png
https://www.apply.computer-shogi.org/wcsc33/appeal/Ryfamate/appeal_ryfamate_20230421.pdf
"""
    def __init__(self, channels: int, bottleneck_scale: int = 2):
super().__init__()
b_channels = channels // bottleneck_scale
self.convin = Conv2d(channels, b_channels, 1)
self.block1 = nn.Sequential(
Conv2d(b_channels, b_channels, 3, padding=1),
nn.ReLU(),
Conv2d(b_channels, b_channels, 3, padding=1),
)
self.block2 = nn.Sequential(
Conv2d(b_channels, b_channels, 3, padding=1),
nn.ReLU(),
Conv2d(b_channels, b_channels, 3, padding=1),
)
self.convout = Conv2d(b_channels, channels, 1)
self.conv_filea = Conv2d(b_channels, b_channels, (9, 1))
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        a = nn.functional.relu(self.convin(x))  # enter the bottleneck
        b = self.block1(a)
        c = nn.functional.silu(a + b)           # first inner residual
        d = self.block2(c)
        files = self.conv_filea(a)              # no padding: (N, C, 1, 9) on a 9x9 board
        e = nn.functional.silu(c + files + d)   # second inner residual + broadcast
        f = self.convout(e)                     # leave the bottleneck
        return nn.functional.silu(x + f)        # outer residual
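# A sketch of the broadcasting path (9x9 board assumed, sizes illustrative):
# the (9, 1) file convolution has no padding, so it collapses one spatial
# axis to 1; the resulting (N, C, 1, 9) map is then broadcast over the other
# axis when added to `c` and `d` in forward above.
def _demo_katago_block() -> None:
    blk = KatagoBlock(32, bottleneck_scale=2)
    assert blk.conv_filea(torch.zeros(2, 16, 9, 9)).shape == (2, 16, 1, 9)
    assert blk(torch.zeros(2, 32, 9, 9)).shape == (2, 32, 9, 9)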
def make_gumbel_az_block(channels: int) -> nn.Module:
    """make a bottleneck residual block (cf. the Gumbel MuZero pseudocode)"""
    b_channels = channels // 2
return ResBlockAlt(
nn.Sequential(
Conv2d(channels, b_channels, 1),
nn.ReLU(),
Conv2d(b_channels, b_channels, 3, padding=1),
nn.ReLU(),
Conv2d(b_channels, b_channels, 3, padding=1),
nn.ReLU(),
Conv2d(b_channels, channels, 1),
))
class BasicBody(nn.Module):
"""Body of networks to provide a good feature vector for heads.
Typical composition would be
- [KatagoBlock(256, 4) x 2 + PoolBias] x 3, or
- [gumbel_az_block(128) x 3 + PoolBias] x 3.
"""
def __init__(self, *, in_channels: int, channels: int, num_blocks: int,
broadcast_every: int = 8):
super().__init__()
self.conv1 = Conv2d(in_channels, channels, 3, padding=1)
self.body = nn.Sequential(
            *[KatagoBlock(channels, 4)  # alternative: make_gumbel_az_block(channels)
              if (i + 1) % broadcast_every != 0
              else ResBlockAlt(PoolBias(channels=channels))
              for i in range(num_blocks)
])
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = nn.functional.relu(self.conv1(x))
x = self.body(x)
return x
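# A sketch matching the docstring's typical composition (sizes here are
# illustrative): num_blocks=9 with broadcast_every=3 yields
# [KatagoBlock x 2 + PoolBias] x 3 after the input convolution.
def _demo_basic_body() -> None:
    body = BasicBody(in_channels=4, channels=32, num_blocks=9,
                     broadcast_every=3)
    assert body(torch.zeros(2, 4, 9, 9)).shape == (2, 32, 9, 9)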
class BasicNetwork(nn.Module):
    """A minimal network: a two-block body followed by a policy head."""
def __init__(self, *, in_channels: int, channels: int, out_channels: int):
super().__init__()
self.body = BasicBody(in_channels=in_channels, channels=channels,
num_blocks=2)
self.head = PolicyHead(channels=channels, out_channels=out_channels)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.body(x)
return self.head(x)
class PVNetwork(nn.Module):
def __init__(self, *, in_channels: int, channels: int, out_channels: int,
num_blocks: int,
value_head_hidden: int = 256, broadcast_every: int = 3):
super().__init__()
self.body = BasicBody(in_channels=in_channels, channels=channels,
num_blocks=num_blocks,
broadcast_every=broadcast_every)
self.head = PolicyHead(channels=channels, out_channels=out_channels)
self.value_head = ValueHead(channels=channels,
hidden_layer_size=value_head_hidden)
self.config = {
'in_channels': in_channels, 'channels': channels,
'out_channels': out_channels, # no aux
'num_blocks': num_blocks, 'value_head_hidden': value_head_hidden,
'broadcast_every': broadcast_every,
}
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""take a batch of input features
and return a batch of [policies, values].
"""
x = self.body(x)
return self.head(x), self.value_head(x)
def save_with_dict(self, filename):
"""save config and weights into .ptd file"""
torch.save({'cfg': self.config,
'model_state_dict': self.state_dict()},
filename)
@classmethod
def load_with_dict(cls, filename):
"""make a module with configs and weights saved in .ptd file"""
import logging
import json
objs = torch.load(filename, map_location=torch.device('cpu'))
cfg = objs['cfg']
model = cls(**cfg)
logging.debug(json.dumps(cfg, indent=4))
if 'model_state_dict' in objs:
model.load_state_dict(objs['model_state_dict'])
return model
def clone(self):
"""clone a module with current weights"""
cls = self.__class__
obj = cls(**self.config)
obj.load_state_dict(self.state_dict())
return obj
    def soft_update(self, new_state_dict, tau: float = .5, keys=None):
        """update weights toward new ones: tau * new + (1 - tau) * old"""
my_parameters = self.state_dict()
if not keys:
keys = my_parameters.keys()
for key in keys:
old_value = my_parameters[key]
new_value = new_state_dict[key]
my_parameters[key] = tau * new_value + (1-tau) * old_value
self.load_state_dict(my_parameters)
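# A usage sketch for PVNetwork (the file name and sizes here are
# hypothetical): save_with_dict/load_with_dict round-trip the config together
# with the weights, and soft_update blends in another state dict with
# coefficient tau.
def _demo_pv_network() -> None:
    net = PVNetwork(in_channels=4, channels=32, out_channels=27,
                    num_blocks=3)
    policy, value = net(torch.zeros(2, 4, 9, 9))
    assert policy.shape == (2, 27 * 9 * 9)  # flattened per-square outputs
    assert value.shape == (2, 4)
    net.save_with_dict('/tmp/pvnet.ptd')  # hypothetical path
    copy = PVNetwork.load_with_dict('/tmp/pvnet.ptd')
    copy.soft_update(net.state_dict(), tau=0.5)  # no-op here: same weights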
class StandardNetwork(PVNetwork):
"""Standard residual networks with bottleneck architecture
:param in_channels: number of channels in input features,
:param channels: number of channels in main body,
:param out_channels: number of channels in policy_head,
:param auxout_channels: number of channels in miscellaneous output,
:param value_head_hidden: hidden units in the last layer in the value head
"""
def __init__(self, *, in_channels: int, channels: int, out_channels: int,
auxout_channels: int, num_blocks: int,
value_head_hidden: int = 256, broadcast_every: int = 3):
super().__init__(in_channels=in_channels, channels=channels,
out_channels=out_channels,
num_blocks=num_blocks,
value_head_hidden=value_head_hidden,
broadcast_every=broadcast_every)
self.aux_head = PolicyHead(channels=channels,
out_channels=auxout_channels)
self.config['auxout_channels'] = auxout_channels
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor,
torch.Tensor]:
"""take a batch of input features
and return a batch of [policies, values, misc].
"""
x = self.body(x)
return self.head(x), self.value_head(x), self.aux_head(x)
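# A sketch of StandardNetwork's three-headed output (sizes here are
# hypothetical): policies, values, and the auxiliary head's output.
def _demo_standard_network() -> None:
    net = StandardNetwork(in_channels=4, channels=32, out_channels=27,
                          auxout_channels=2, num_blocks=3)
    policy, value, aux = net(torch.zeros(2, 4, 9, 9))
    assert policy.shape == (2, 27 * 9 * 9)
    assert value.shape == (2, 4)
    assert aux.shape == (2, 2 * 9 * 9)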
# Local Variables:
# python-indent-offset: 4
# End: