Add GRU Model
This commit is contained in:
parent
1704b7aad1
commit
30f55fa967
2 changed files with 89 additions and 21 deletions
35
recorder.py
35
recorder.py
|
|
@ -9,7 +9,9 @@ import torch
|
|||
from torch import nn
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from src.torch_networks import TorchLSTMModel, TorchLSTMCellModel, CustomLSTMModel, ChainLSTMLayer, BNLSTMCell
|
||||
from src.torch_networks import (
|
||||
TorchLSTMModel, TorchLSTMCellModel, TorchGRUModel,
|
||||
CustomRNNModel, ChainRNNLayer, BNLSTMCell, CustomLSTMCell)
|
||||
from src.torch_utils.train import parameter_summary
|
||||
|
||||
|
||||
|
|
@ -40,7 +42,7 @@ def main():
|
|||
parser.add_argument('--sequence', type=int, default=8, help='Max sequence length')
|
||||
parser.add_argument('--dimension', type=int, default=15, help='Input dimension')
|
||||
parser.add_argument('--hidden', type=int, default=32, help='Hidden dimension')
|
||||
parser.add_argument('--lstm', type=int, default=3, help='LSTM layer stack length')
|
||||
parser.add_argument('--layer', type=int, default=3, help='LSTM layer stack length')
|
||||
parser.add_argument('--step', type=int, default=20_000, help='Number of steps to train')
|
||||
arguments = parser.parse_args()
|
||||
|
||||
|
|
@ -50,27 +52,34 @@ def main():
|
|||
sequence_size: int = arguments.sequence
|
||||
input_dim: int = arguments.dimension
|
||||
hidden_size: int = arguments.hidden
|
||||
num_layer: int = arguments.lstm
|
||||
num_layer: int = arguments.layer
|
||||
max_step: int = arguments.step
|
||||
|
||||
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
|
||||
torch.backends.cudnn.benchmark = True
|
||||
if model == 'stack':
|
||||
network = CustomLSTMModel(input_dim + 1, hidden_size, num_layer=num_layer).to(device)
|
||||
elif model == 'stack-bn':
|
||||
network = CustomLSTMModel(input_dim + 1, hidden_size, num_layer=num_layer, cell_class=BNLSTMCell).to(device)
|
||||
network = CustomRNNModel(input_dim + 1, hidden_size, num_layer=num_layer).to(device)
|
||||
elif model == 'stack-torchcell':
|
||||
network = CustomLSTMModel(input_dim + 1, hidden_size, num_layer=num_layer, cell_class=nn.LSTMCell).to(device)
|
||||
elif model == 'chain':
|
||||
network = CustomLSTMModel(input_dim + 1, hidden_size, num_layer=num_layer,
|
||||
layer_class=ChainLSTMLayer).to(device)
|
||||
network = CustomRNNModel(input_dim + 1, hidden_size, num_layer=num_layer, cell_class=nn.LSTMCell).to(device)
|
||||
elif model == 'stack-bn':
|
||||
network = CustomRNNModel(input_dim + 1, hidden_size, num_layer=num_layer, cell_class=BNLSTMCell).to(device)
|
||||
elif model == 'stack-custom':
|
||||
network = CustomRNNModel(input_dim + 1, hidden_size, num_layer=num_layer, cell_class=CustomLSTMCell).to(device)
|
||||
elif model == 'chain-lstm':
|
||||
network = CustomRNNModel(input_dim + 1, hidden_size, num_layer=num_layer,
|
||||
layer_class=ChainRNNLayer).to(device)
|
||||
elif model == 'chain-bn':
|
||||
network = CustomLSTMModel(input_dim + 1, hidden_size, num_layer=num_layer,
|
||||
layer_class=ChainLSTMLayer, cell_class=BNLSTMCell).to(device)
|
||||
network = CustomRNNModel(input_dim + 1, hidden_size, num_layer=num_layer,
|
||||
layer_class=ChainRNNLayer, cell_class=BNLSTMCell).to(device)
|
||||
elif model == 'chain-custom':
|
||||
network = CustomRNNModel(input_dim + 1, hidden_size, num_layer=num_layer,
|
||||
layer_class=ChainRNNLayer, cell_class=CustomLSTMCell).to(device)
|
||||
elif model == 'torch-cell':
|
||||
network = TorchLSTMCellModel(input_dim + 1, hidden_size, num_layer=num_layer).to(device)
|
||||
elif model == 'torch-lstm':
|
||||
network = TorchLSTMModel(input_dim + 1, hidden_size, num_layer=num_layer).to(device)
|
||||
elif model == 'torch-gru':
|
||||
network = TorchGRUModel(input_dim + 1, hidden_size, num_layer=num_layer).to(device)
|
||||
else:
|
||||
print('Error : Unkown model')
|
||||
sys.exit(1)
|
||||
|
|
@ -96,7 +105,7 @@ def main():
|
|||
[f'{name: <{max(names)}} {str(shape): <{max(shapes)}} {size}'
|
||||
for name, shape, size in param_summary]))
|
||||
|
||||
optimizer = torch.optim.Adam(network.parameters(), lr=1e-3, weight_decay=1e-4)
|
||||
optimizer = torch.optim.Adam(network.parameters(), lr=1e-3, weight_decay=1e-6)
|
||||
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.996)
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
|
||||
|
|
|
|||
|
|
@ -131,7 +131,51 @@ class BNLSTMCell(nn.Module):
|
|||
nn.init.zeros_(self.bias)
|
||||
|
||||
|
||||
class StackLSTMLayer(nn.Module):
|
||||
class CustomLSTMCell(nn.Module):
|
||||
def __init__(self, input_size, hidden_size):
|
||||
super().__init__()
|
||||
self.input_size = input_size
|
||||
self.hidden_size = hidden_size
|
||||
self.weight_ih = nn.Parameter(torch.empty(input_size + hidden_size, 4 * hidden_size))
|
||||
self.weight_hh = nn.Parameter(torch.empty(hidden_size, 4 * hidden_size))
|
||||
self.bias = nn.Parameter(torch.empty(4 * hidden_size))
|
||||
# self.in_norm = nn.BatchNorm1d(4 * hidden_size)
|
||||
# self.cell_norm = nn.BatchNorm1d(4 * hidden_size)
|
||||
# self.out_norm = nn.LayerNorm(hidden_size)
|
||||
|
||||
self.in_act = nn.Sigmoid()
|
||||
self.forget_act = nn.Sigmoid()
|
||||
self.cell_act = nn.Tanh()
|
||||
self.out_act = nn.Sigmoid()
|
||||
|
||||
self.reset_parameters()
|
||||
|
||||
def forward(self, input_data: torch.Tensor, state: tuple[torch.Tensor, torch.Tensor]) -> tuple[
|
||||
torch.Tensor, torch.Tensor]:
|
||||
hx, cx = state
|
||||
gates = (
|
||||
torch.mm(torch.cat([input_data, hx], 1), self.weight_ih)
|
||||
+ torch.mm(hx, self.weight_hh)
|
||||
+ self.bias)
|
||||
ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
|
||||
|
||||
ingate = self.in_act(ingate)
|
||||
forgetgate = self.forget_act(forgetgate)
|
||||
cellgate = self.cell_act(cellgate)
|
||||
outgate = self.out_act(outgate)
|
||||
|
||||
cy = (forgetgate * cx) + (ingate * cellgate)
|
||||
hy = outgate * torch.tanh(cy)
|
||||
|
||||
return (hy, cy)
|
||||
|
||||
def reset_parameters(self) -> None:
|
||||
for weight in [self.weight_hh, self.weight_ih]:
|
||||
nn.init.xavier_normal_(weight)
|
||||
nn.init.zeros_(self.bias)
|
||||
|
||||
|
||||
class StackRNNLayer(nn.Module):
|
||||
def __init__(self, input_size, hidden_size, cell_class=LSTMCell):
|
||||
super().__init__()
|
||||
self.cell = cell_class(input_size=input_size, hidden_size=hidden_size)
|
||||
|
|
@ -146,7 +190,7 @@ class StackLSTMLayer(nn.Module):
|
|||
return torch.stack(outputs), state
|
||||
|
||||
|
||||
class ChainLSTMLayer(nn.Module):
|
||||
class ChainRNNLayer(nn.Module):
|
||||
def __init__(self, input_size, hidden_size, cell_class=LSTMCell):
|
||||
super().__init__()
|
||||
self.cell = cell_class(input_size=input_size, hidden_size=hidden_size)
|
||||
|
|
@ -161,8 +205,8 @@ class ChainLSTMLayer(nn.Module):
|
|||
return torch.stack(outputs), state
|
||||
|
||||
|
||||
class _CustomLSTMLayer(nn.Module):
|
||||
def __init__(self, input_size, hidden_size, num_layers, layer_class=StackLSTMLayer, cell_class=LSTMCell):
|
||||
class _CustomRNNLayer(nn.Module):
|
||||
def __init__(self, input_size, hidden_size, num_layers, layer_class=StackRNNLayer, cell_class=LSTMCell):
|
||||
super().__init__()
|
||||
self.num_layer = num_layers
|
||||
self.layers = nn.ModuleList(
|
||||
|
|
@ -184,14 +228,14 @@ class _CustomLSTMLayer(nn.Module):
|
|||
return output, (torch.stack(output_states), torch.stack(output_cell_states))
|
||||
|
||||
|
||||
class CustomLSTMModel(nn.Module):
|
||||
class CustomRNNModel(nn.Module):
|
||||
def __init__(self, input_size: int, hidden_size: int = -1, output_size: int = -1,
|
||||
num_layer: int = 3, layer_class=StackLSTMLayer, cell_class=LSTMCell):
|
||||
num_layer: int = 3, layer_class=StackRNNLayer, cell_class=nn.LSTMCell):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size if hidden_size > 0 else input_size * 2
|
||||
self.output_size = output_size if output_size > 0 else input_size
|
||||
|
||||
self.lstm = _CustomLSTMLayer(input_size=input_size, hidden_size=self.hidden_size,
|
||||
self.lstm = _CustomRNNLayer(input_size=input_size, hidden_size=self.hidden_size,
|
||||
num_layers=num_layer, layer_class=layer_class, cell_class=cell_class)
|
||||
self.dense = nn.Linear(hidden_size, self.output_size)
|
||||
|
||||
|
|
@ -213,3 +257,18 @@ class CustomLSTMModel(nn.Module):
|
|||
input_data.size(1), self.lstm.num_layer, self.hidden_size).transpose(1, 0))
|
||||
output, state = self.lstm(input_data, init_state)
|
||||
return self.dense(output), state
|
||||
|
||||
|
||||
class TorchGRUModel(nn.Module):
|
||||
def __init__(self, input_size: int, hidden_size: int = -1, output_size: int = -1, num_layer: int = 3):
|
||||
super().__init__()
|
||||
hidden_size = hidden_size if hidden_size > 0 else input_size * 2
|
||||
output_size = output_size if output_size > 0 else input_size
|
||||
|
||||
self.lstm = nn.GRU(input_size=input_size, hidden_size=hidden_size, num_layers=num_layer)
|
||||
self.dense = nn.Linear(hidden_size, output_size)
|
||||
|
||||
def forward(self, input_data: torch.Tensor, init_state=None) -> tuple[
|
||||
torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
output, state = self.lstm(input_data, init_state)
|
||||
return self.dense(output), state
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue