Add GRU Model
This commit is contained in:
parent
1704b7aad1
commit
30f55fa967
2 changed files with 89 additions and 21 deletions
35
recorder.py
35
recorder.py
|
|
@ -9,7 +9,9 @@ import torch
|
|||
from torch import nn
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from src.torch_networks import TorchLSTMModel, TorchLSTMCellModel, CustomLSTMModel, ChainLSTMLayer, BNLSTMCell
|
||||
from src.torch_networks import (
|
||||
TorchLSTMModel, TorchLSTMCellModel, TorchGRUModel,
|
||||
CustomRNNModel, ChainRNNLayer, BNLSTMCell, CustomLSTMCell)
|
||||
from src.torch_utils.train import parameter_summary
|
||||
|
||||
|
||||
|
|
@ -40,7 +42,7 @@ def main():
|
|||
parser.add_argument('--sequence', type=int, default=8, help='Max sequence length')
|
||||
parser.add_argument('--dimension', type=int, default=15, help='Input dimension')
|
||||
parser.add_argument('--hidden', type=int, default=32, help='Hidden dimension')
|
||||
parser.add_argument('--lstm', type=int, default=3, help='LSTM layer stack length')
|
||||
parser.add_argument('--layer', type=int, default=3, help='LSTM layer stack length')
|
||||
parser.add_argument('--step', type=int, default=20_000, help='Number of steps to train')
|
||||
arguments = parser.parse_args()
|
||||
|
||||
|
|
@ -50,27 +52,34 @@ def main():
|
|||
sequence_size: int = arguments.sequence
|
||||
input_dim: int = arguments.dimension
|
||||
hidden_size: int = arguments.hidden
|
||||
num_layer: int = arguments.lstm
|
||||
num_layer: int = arguments.layer
|
||||
max_step: int = arguments.step
|
||||
|
||||
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
|
||||
torch.backends.cudnn.benchmark = True
|
||||
if model == 'stack':
|
||||
network = CustomLSTMModel(input_dim + 1, hidden_size, num_layer=num_layer).to(device)
|
||||
elif model == 'stack-bn':
|
||||
network = CustomLSTMModel(input_dim + 1, hidden_size, num_layer=num_layer, cell_class=BNLSTMCell).to(device)
|
||||
network = CustomRNNModel(input_dim + 1, hidden_size, num_layer=num_layer).to(device)
|
||||
elif model == 'stack-torchcell':
|
||||
network = CustomLSTMModel(input_dim + 1, hidden_size, num_layer=num_layer, cell_class=nn.LSTMCell).to(device)
|
||||
elif model == 'chain':
|
||||
network = CustomLSTMModel(input_dim + 1, hidden_size, num_layer=num_layer,
|
||||
layer_class=ChainLSTMLayer).to(device)
|
||||
network = CustomRNNModel(input_dim + 1, hidden_size, num_layer=num_layer, cell_class=nn.LSTMCell).to(device)
|
||||
elif model == 'stack-bn':
|
||||
network = CustomRNNModel(input_dim + 1, hidden_size, num_layer=num_layer, cell_class=BNLSTMCell).to(device)
|
||||
elif model == 'stack-custom':
|
||||
network = CustomRNNModel(input_dim + 1, hidden_size, num_layer=num_layer, cell_class=CustomLSTMCell).to(device)
|
||||
elif model == 'chain-lstm':
|
||||
network = CustomRNNModel(input_dim + 1, hidden_size, num_layer=num_layer,
|
||||
layer_class=ChainRNNLayer).to(device)
|
||||
elif model == 'chain-bn':
|
||||
network = CustomLSTMModel(input_dim + 1, hidden_size, num_layer=num_layer,
|
||||
layer_class=ChainLSTMLayer, cell_class=BNLSTMCell).to(device)
|
||||
network = CustomRNNModel(input_dim + 1, hidden_size, num_layer=num_layer,
|
||||
layer_class=ChainRNNLayer, cell_class=BNLSTMCell).to(device)
|
||||
elif model == 'chain-custom':
|
||||
network = CustomRNNModel(input_dim + 1, hidden_size, num_layer=num_layer,
|
||||
layer_class=ChainRNNLayer, cell_class=CustomLSTMCell).to(device)
|
||||
elif model == 'torch-cell':
|
||||
network = TorchLSTMCellModel(input_dim + 1, hidden_size, num_layer=num_layer).to(device)
|
||||
elif model == 'torch-lstm':
|
||||
network = TorchLSTMModel(input_dim + 1, hidden_size, num_layer=num_layer).to(device)
|
||||
elif model == 'torch-gru':
|
||||
network = TorchGRUModel(input_dim + 1, hidden_size, num_layer=num_layer).to(device)
|
||||
else:
|
||||
print('Error : Unkown model')
|
||||
sys.exit(1)
|
||||
|
|
@ -96,7 +105,7 @@ def main():
|
|||
[f'{name: <{max(names)}} {str(shape): <{max(shapes)}} {size}'
|
||||
for name, shape, size in param_summary]))
|
||||
|
||||
optimizer = torch.optim.Adam(network.parameters(), lr=1e-3, weight_decay=1e-4)
|
||||
optimizer = torch.optim.Adam(network.parameters(), lr=1e-3, weight_decay=1e-6)
|
||||
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.996)
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue