Add GRU Model

This commit is contained in:
Corentin 2021-08-31 02:38:15 +09:00
commit 30f55fa967
2 changed files with 89 additions and 21 deletions

View file

@@ -9,7 +9,9 @@ import torch
from torch import nn
from torch.utils.tensorboard import SummaryWriter
from src.torch_networks import TorchLSTMModel, TorchLSTMCellModel, CustomLSTMModel, ChainLSTMLayer, BNLSTMCell
from src.torch_networks import (
TorchLSTMModel, TorchLSTMCellModel, TorchGRUModel,
CustomRNNModel, ChainRNNLayer, BNLSTMCell, CustomLSTMCell)
from src.torch_utils.train import parameter_summary
@@ -40,7 +42,7 @@ def main():
parser.add_argument('--sequence', type=int, default=8, help='Max sequence length')
parser.add_argument('--dimension', type=int, default=15, help='Input dimension')
parser.add_argument('--hidden', type=int, default=32, help='Hidden dimension')
parser.add_argument('--lstm', type=int, default=3, help='LSTM layer stack length')
parser.add_argument('--layer', type=int, default=3, help='LSTM layer stack length')
parser.add_argument('--step', type=int, default=20_000, help='Number of steps to train')
arguments = parser.parse_args()
@@ -50,27 +52,34 @@ def main():
sequence_size: int = arguments.sequence
input_dim: int = arguments.dimension
hidden_size: int = arguments.hidden
num_layer: int = arguments.lstm
num_layer: int = arguments.layer
max_step: int = arguments.step
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
torch.backends.cudnn.benchmark = True
if model == 'stack':
network = CustomLSTMModel(input_dim + 1, hidden_size, num_layer=num_layer).to(device)
elif model == 'stack-bn':
network = CustomLSTMModel(input_dim + 1, hidden_size, num_layer=num_layer, cell_class=BNLSTMCell).to(device)
network = CustomRNNModel(input_dim + 1, hidden_size, num_layer=num_layer).to(device)
elif model == 'stack-torchcell':
network = CustomLSTMModel(input_dim + 1, hidden_size, num_layer=num_layer, cell_class=nn.LSTMCell).to(device)
elif model == 'chain':
network = CustomLSTMModel(input_dim + 1, hidden_size, num_layer=num_layer,
layer_class=ChainLSTMLayer).to(device)
network = CustomRNNModel(input_dim + 1, hidden_size, num_layer=num_layer, cell_class=nn.LSTMCell).to(device)
elif model == 'stack-bn':
network = CustomRNNModel(input_dim + 1, hidden_size, num_layer=num_layer, cell_class=BNLSTMCell).to(device)
elif model == 'stack-custom':
network = CustomRNNModel(input_dim + 1, hidden_size, num_layer=num_layer, cell_class=CustomLSTMCell).to(device)
elif model == 'chain-lstm':
network = CustomRNNModel(input_dim + 1, hidden_size, num_layer=num_layer,
layer_class=ChainRNNLayer).to(device)
elif model == 'chain-bn':
network = CustomLSTMModel(input_dim + 1, hidden_size, num_layer=num_layer,
layer_class=ChainLSTMLayer, cell_class=BNLSTMCell).to(device)
network = CustomRNNModel(input_dim + 1, hidden_size, num_layer=num_layer,
layer_class=ChainRNNLayer, cell_class=BNLSTMCell).to(device)
elif model == 'chain-custom':
network = CustomRNNModel(input_dim + 1, hidden_size, num_layer=num_layer,
layer_class=ChainRNNLayer, cell_class=CustomLSTMCell).to(device)
elif model == 'torch-cell':
network = TorchLSTMCellModel(input_dim + 1, hidden_size, num_layer=num_layer).to(device)
elif model == 'torch-lstm':
network = TorchLSTMModel(input_dim + 1, hidden_size, num_layer=num_layer).to(device)
elif model == 'torch-gru':
network = TorchGRUModel(input_dim + 1, hidden_size, num_layer=num_layer).to(device)
else:
print('Error : Unkown model')
sys.exit(1)
@@ -96,7 +105,7 @@ def main():
[f'{name: <{max(names)}} {str(shape): <{max(shapes)}} {size}'
for name, shape, size in param_summary]))
optimizer = torch.optim.Adam(network.parameters(), lr=1e-3, weight_decay=1e-4)
optimizer = torch.optim.Adam(network.parameters(), lr=1e-3, weight_decay=1e-6)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.996)
criterion = nn.CrossEntropyLoss()