Add GRU Model

Corentin 2021-08-31 02:38:15 +09:00
commit 30f55fa967
2 changed files with 89 additions and 21 deletions


@@ -9,7 +9,9 @@ import torch
 from torch import nn
 from torch.utils.tensorboard import SummaryWriter
-from src.torch_networks import TorchLSTMModel, TorchLSTMCellModel, CustomLSTMModel, ChainLSTMLayer, BNLSTMCell
+from src.torch_networks import (
+    TorchLSTMModel, TorchLSTMCellModel, TorchGRUModel,
+    CustomRNNModel, ChainRNNLayer, BNLSTMCell, CustomLSTMCell)
 from src.torch_utils.train import parameter_summary
@@ -40,7 +42,7 @@ def main():
     parser.add_argument('--sequence', type=int, default=8, help='Max sequence length')
     parser.add_argument('--dimension', type=int, default=15, help='Input dimension')
     parser.add_argument('--hidden', type=int, default=32, help='Hidden dimension')
-    parser.add_argument('--lstm', type=int, default=3, help='LSTM layer stack length')
+    parser.add_argument('--layer', type=int, default=3, help='LSTM layer stack length')
     parser.add_argument('--step', type=int, default=20_000, help='Number of steps to train')
     arguments = parser.parse_args()
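
Renaming the flag from --lstm to --layer also renames the attribute on the parsed Namespace, which is why the arguments.lstm access in the next hunk becomes arguments.layer. A quick standalone check of that argparse behavior:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--layer', type=int, default=3, help='LSTM layer stack length')
args = parser.parse_args(['--layer', '4'])
print(args.layer)  # 4; the old args.lstm attribute no longer exists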
@@ -50,27 +52,34 @@ def main():
     sequence_size: int = arguments.sequence
     input_dim: int = arguments.dimension
     hidden_size: int = arguments.hidden
-    num_layer: int = arguments.lstm
+    num_layer: int = arguments.layer
     max_step: int = arguments.step
 
     device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
     torch.backends.cudnn.benchmark = True
     if model == 'stack':
-        network = CustomLSTMModel(input_dim + 1, hidden_size, num_layer=num_layer).to(device)
-    elif model == 'stack-bn':
-        network = CustomLSTMModel(input_dim + 1, hidden_size, num_layer=num_layer, cell_class=BNLSTMCell).to(device)
+        network = CustomRNNModel(input_dim + 1, hidden_size, num_layer=num_layer).to(device)
     elif model == 'stack-torchcell':
-        network = CustomLSTMModel(input_dim + 1, hidden_size, num_layer=num_layer, cell_class=nn.LSTMCell).to(device)
-    elif model == 'chain':
-        network = CustomLSTMModel(input_dim + 1, hidden_size, num_layer=num_layer,
-                                  layer_class=ChainLSTMLayer).to(device)
+        network = CustomRNNModel(input_dim + 1, hidden_size, num_layer=num_layer, cell_class=nn.LSTMCell).to(device)
+    elif model == 'stack-bn':
+        network = CustomRNNModel(input_dim + 1, hidden_size, num_layer=num_layer, cell_class=BNLSTMCell).to(device)
+    elif model == 'stack-custom':
+        network = CustomRNNModel(input_dim + 1, hidden_size, num_layer=num_layer, cell_class=CustomLSTMCell).to(device)
+    elif model == 'chain-lstm':
+        network = CustomRNNModel(input_dim + 1, hidden_size, num_layer=num_layer,
+                                 layer_class=ChainRNNLayer).to(device)
     elif model == 'chain-bn':
-        network = CustomLSTMModel(input_dim + 1, hidden_size, num_layer=num_layer,
-                                  layer_class=ChainLSTMLayer, cell_class=BNLSTMCell).to(device)
+        network = CustomRNNModel(input_dim + 1, hidden_size, num_layer=num_layer,
+                                 layer_class=ChainRNNLayer, cell_class=BNLSTMCell).to(device)
+    elif model == 'chain-custom':
+        network = CustomRNNModel(input_dim + 1, hidden_size, num_layer=num_layer,
+                                 layer_class=ChainRNNLayer, cell_class=CustomLSTMCell).to(device)
     elif model == 'torch-cell':
         network = TorchLSTMCellModel(input_dim + 1, hidden_size, num_layer=num_layer).to(device)
     elif model == 'torch-lstm':
         network = TorchLSTMModel(input_dim + 1, hidden_size, num_layer=num_layer).to(device)
+    elif model == 'torch-gru':
+        network = TorchGRUModel(input_dim + 1, hidden_size, num_layer=num_layer).to(device)
     else:
         print('Error: Unknown model')
         sys.exit(1)
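
All of the custom variants funnel through CustomRNNModel, which takes the layer and cell implementations as class arguments. A minimal smoke test of that composition, assuming the sequence-first (seq, batch, feature) layout implied by the torch.stack usage in the layer classes below:

import torch
from src.torch_networks import ChainRNNLayer, CustomLSTMCell, CustomRNNModel

seq_len, batch, input_dim, hidden = 8, 4, 15, 32
x = torch.randn(seq_len, batch, input_dim + 1)

# Equivalent to the 'chain-custom' branch above, run on CPU.
network = CustomRNNModel(input_dim + 1, hidden, num_layer=3,
                         layer_class=ChainRNNLayer, cell_class=CustomLSTMCell)
output, state = network(x)
print(output.shape)  # expected (seq_len, batch, input_dim + 1): output_size defaults to input_size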
@@ -96,7 +105,7 @@ def main():
                 [f'{name: <{max(names)}} {str(shape): <{max(shapes)}} {size}'
                  for name, shape, size in param_summary]))
 
-    optimizer = torch.optim.Adam(network.parameters(), lr=1e-3, weight_decay=1e-4)
+    optimizer = torch.optim.Adam(network.parameters(), lr=1e-3, weight_decay=1e-6)
     scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.996)
     criterion = nn.CrossEntropyLoss()
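
With gamma=0.996, ExponentialLR halves the learning rate roughly every 173 calls to scheduler.step() (ln 0.5 / ln 0.996 ≈ 173). If step() were called on every one of the 20_000 training steps, the rate would end near 1e-3 * 0.996**20000 ≈ 1.5e-38, effectively zero, so it is presumably stepped less often; where step() is invoked falls outside this hunk. A quick check of the halving interval:

import torch

opt = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=1e-3)
sched = torch.optim.lr_scheduler.ExponentialLR(opt, gamma=0.996)
for _ in range(173):
    opt.step()
    sched.step()
print(opt.param_groups[0]['lr'])  # ~5.0e-4, half of the initial 1e-3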

src/torch_networks.py

@@ -131,7 +131,51 @@ class BNLSTMCell(nn.Module):
         nn.init.zeros_(self.bias)
 
 
+class CustomLSTMCell(nn.Module):
+    def __init__(self, input_size, hidden_size):
+        super().__init__()
+        self.input_size = input_size
+        self.hidden_size = hidden_size
+        self.weight_ih = nn.Parameter(torch.empty(input_size + hidden_size, 4 * hidden_size))
+        self.weight_hh = nn.Parameter(torch.empty(hidden_size, 4 * hidden_size))
+        self.bias = nn.Parameter(torch.empty(4 * hidden_size))
+        # self.in_norm = nn.BatchNorm1d(4 * hidden_size)
+        # self.cell_norm = nn.BatchNorm1d(4 * hidden_size)
+        # self.out_norm = nn.LayerNorm(hidden_size)
+        self.in_act = nn.Sigmoid()
+        self.forget_act = nn.Sigmoid()
+        self.cell_act = nn.Tanh()
+        self.out_act = nn.Sigmoid()
+        self.reset_parameters()
+
+    def forward(self, input_data: torch.Tensor, state: tuple[torch.Tensor, torch.Tensor]) -> tuple[
+            torch.Tensor, torch.Tensor]:
+        hx, cx = state
+        gates = (
+            torch.mm(torch.cat([input_data, hx], 1), self.weight_ih)
+            + torch.mm(hx, self.weight_hh)
+            + self.bias)
+        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
+
+        ingate = self.in_act(ingate)
+        forgetgate = self.forget_act(forgetgate)
+        cellgate = self.cell_act(cellgate)
+        outgate = self.out_act(outgate)
+
+        cy = (forgetgate * cx) + (ingate * cellgate)
+        hy = outgate * torch.tanh(cy)
+
+        return (hy, cy)
+
+    def reset_parameters(self) -> None:
+        for weight in [self.weight_hh, self.weight_ih]:
+            nn.init.xavier_normal_(weight)
+        nn.init.zeros_(self.bias)
+
+
-class StackLSTMLayer(nn.Module):
+class StackRNNLayer(nn.Module):
     def __init__(self, input_size, hidden_size, cell_class=LSTMCell):
         super().__init__()
         self.cell = cell_class(input_size=input_size, hidden_size=hidden_size)
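
The new CustomLSTMCell follows the standard LSTM gate equations, with one quirk worth noting: weight_ih already multiplies the concatenated [input, h] pair, and the separate torch.mm(hx, self.weight_hh) term then feeds the previous hidden state into the gates a second time. A minimal single-step sketch (shapes are assumptions read off the torch.cat/torch.mm calls):

import torch
from src.torch_networks import CustomLSTMCell

batch, input_size, hidden = 4, 16, 32
cell = CustomLSTMCell(input_size, hidden)

x = torch.randn(batch, input_size)
h0 = torch.zeros(batch, hidden)
c0 = torch.zeros(batch, hidden)

h1, c1 = cell(x, (h0, c0))
print(h1.shape, c1.shape)  # both (batch, hidden); the fused gates are (batch, 4 * hidden) before chunk(4, 1)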
@@ -146,7 +190,7 @@ class StackLSTMLayer(nn.Module):
         return torch.stack(outputs), state
 
 
-class ChainLSTMLayer(nn.Module):
+class ChainRNNLayer(nn.Module):
     def __init__(self, input_size, hidden_size, cell_class=LSTMCell):
         super().__init__()
         self.cell = cell_class(input_size=input_size, hidden_size=hidden_size)
@@ -161,8 +205,8 @@ class ChainLSTMLayer(nn.Module):
         return torch.stack(outputs), state
 
 
-class _CustomLSTMLayer(nn.Module):
-    def __init__(self, input_size, hidden_size, num_layers, layer_class=StackLSTMLayer, cell_class=LSTMCell):
+class _CustomRNNLayer(nn.Module):
+    def __init__(self, input_size, hidden_size, num_layers, layer_class=StackRNNLayer, cell_class=LSTMCell):
         super().__init__()
         self.num_layer = num_layers
         self.layers = nn.ModuleList(
@@ -184,15 +228,15 @@ class _CustomLSTMLayer(nn.Module):
         return output, (torch.stack(output_states), torch.stack(output_cell_states))
 
 
-class CustomLSTMModel(nn.Module):
+class CustomRNNModel(nn.Module):
     def __init__(self, input_size: int, hidden_size: int = -1, output_size: int = -1,
-                 num_layer: int = 3, layer_class=StackLSTMLayer, cell_class=LSTMCell):
+                 num_layer: int = 3, layer_class=StackRNNLayer, cell_class=nn.LSTMCell):
         super().__init__()
         self.hidden_size = hidden_size if hidden_size > 0 else input_size * 2
         self.output_size = output_size if output_size > 0 else input_size
-        self.lstm = _CustomLSTMLayer(input_size=input_size, hidden_size=self.hidden_size,
-                                     num_layers=num_layer, layer_class=layer_class, cell_class=cell_class)
+        self.lstm = _CustomRNNLayer(input_size=input_size, hidden_size=self.hidden_size,
+                                    num_layers=num_layer, layer_class=layer_class, cell_class=cell_class)
         self.dense = nn.Linear(self.hidden_size, self.output_size)
         self.zero_state = torch.nn.Parameter(
@@ -213,3 +257,18 @@ class CustomLSTMModel(nn.Module):
             input_data.size(1), self.lstm.num_layer, self.hidden_size).transpose(1, 0))
         output, state = self.lstm(input_data, init_state)
         return self.dense(output), state
+
+
+class TorchGRUModel(nn.Module):
+    def __init__(self, input_size: int, hidden_size: int = -1, output_size: int = -1, num_layer: int = 3):
+        super().__init__()
+        hidden_size = hidden_size if hidden_size > 0 else input_size * 2
+        output_size = output_size if output_size > 0 else input_size
+        self.lstm = nn.GRU(input_size=input_size, hidden_size=hidden_size, num_layers=num_layer)
+        self.dense = nn.Linear(hidden_size, output_size)
+
+    def forward(self, input_data: torch.Tensor, init_state=None) -> tuple[torch.Tensor, torch.Tensor]:
+        output, state = self.lstm(input_data, init_state)
+        return self.dense(output), state
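
Unlike nn.LSTM, nn.GRU returns its final state as a single tensor rather than an (h, c) tuple, which the return annotation above reflects, and by default it expects sequence-first input. A minimal usage sketch of the new wrapper (the concrete sizes are illustrative):

import torch
from src.torch_networks import TorchGRUModel

seq_len, batch, input_dim = 8, 4, 16
model = TorchGRUModel(input_dim, hidden_size=32, output_size=10, num_layer=3)

x = torch.randn(seq_len, batch, input_dim)  # nn.GRU default is batch_first=False
output, state = model(x)
print(output.shape)  # (seq_len, batch, 10) after the dense projection
print(state.shape)   # (num_layer, batch, hidden_size) = (3, 4, 32)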