diff --git a/recorder.py b/recorder.py
index a081ffe..07fef5e 100644
--- a/recorder.py
+++ b/recorder.py
@@ -9,7 +9,9 @@
 import torch
 from torch import nn
 from torch.utils.tensorboard import SummaryWriter
-from src.torch_networks import TorchLSTMModel, TorchLSTMCellModel, CustomLSTMModel, ChainLSTMLayer, BNLSTMCell
+from src.torch_networks import (
+    TorchLSTMModel, TorchLSTMCellModel, TorchGRUModel,
+    CustomRNNModel, ChainRNNLayer, BNLSTMCell, CustomLSTMCell)
 from src.torch_utils.train import parameter_summary
 
 
@@ -40,7 +42,7 @@ def main():
     parser.add_argument('--sequence', type=int, default=8, help='Max sequence length')
     parser.add_argument('--dimension', type=int, default=15, help='Input dimension')
     parser.add_argument('--hidden', type=int, default=32, help='Hidden dimension')
-    parser.add_argument('--lstm', type=int, default=3, help='LSTM layer stack length')
+    parser.add_argument('--layer', type=int, default=3, help='Number of stacked recurrent layers')
     parser.add_argument('--step', type=int, default=20_000, help='Number of steps to train')
 
     arguments = parser.parse_args()
@@ -50,27 +52,34 @@ def main():
     sequence_size: int = arguments.sequence
     input_dim: int = arguments.dimension
     hidden_size: int = arguments.hidden
-    num_layer: int = arguments.lstm
+    num_layer: int = arguments.layer
     max_step: int = arguments.step
 
     device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
     torch.backends.cudnn.benchmark = True
     if model == 'stack':
-        network = CustomLSTMModel(input_dim + 1, hidden_size, num_layer=num_layer).to(device)
-    elif model == 'stack-bn':
-        network = CustomLSTMModel(input_dim + 1, hidden_size, num_layer=num_layer, cell_class=BNLSTMCell).to(device)
+        network = CustomRNNModel(input_dim + 1, hidden_size, num_layer=num_layer).to(device)
     elif model == 'stack-torchcell':
-        network = CustomLSTMModel(input_dim + 1, hidden_size, num_layer=num_layer, cell_class=nn.LSTMCell).to(device)
-    elif model == 'chain':
-        network = CustomLSTMModel(input_dim + 1, hidden_size, num_layer=num_layer,
-                                  layer_class=ChainLSTMLayer).to(device)
+        network = CustomRNNModel(input_dim + 1, hidden_size, num_layer=num_layer, cell_class=nn.LSTMCell).to(device)
+    elif model == 'stack-bn':
+        network = CustomRNNModel(input_dim + 1, hidden_size, num_layer=num_layer, cell_class=BNLSTMCell).to(device)
+    elif model == 'stack-custom':
+        network = CustomRNNModel(input_dim + 1, hidden_size, num_layer=num_layer, cell_class=CustomLSTMCell).to(device)
+    elif model == 'chain-lstm':
+        network = CustomRNNModel(input_dim + 1, hidden_size, num_layer=num_layer,
+                                 layer_class=ChainRNNLayer).to(device)
     elif model == 'chain-bn':
-        network = CustomLSTMModel(input_dim + 1, hidden_size, num_layer=num_layer,
-                                  layer_class=ChainLSTMLayer, cell_class=BNLSTMCell).to(device)
+        network = CustomRNNModel(input_dim + 1, hidden_size, num_layer=num_layer,
+                                 layer_class=ChainRNNLayer, cell_class=BNLSTMCell).to(device)
+    elif model == 'chain-custom':
+        network = CustomRNNModel(input_dim + 1, hidden_size, num_layer=num_layer,
+                                 layer_class=ChainRNNLayer, cell_class=CustomLSTMCell).to(device)
     elif model == 'torch-cell':
         network = TorchLSTMCellModel(input_dim + 1, hidden_size, num_layer=num_layer).to(device)
     elif model == 'torch-lstm':
         network = TorchLSTMModel(input_dim + 1, hidden_size, num_layer=num_layer).to(device)
+    elif model == 'torch-gru':
+        network = TorchGRUModel(input_dim + 1, hidden_size, num_layer=num_layer).to(device)
     else:
         print('Error : Unkown model')
         sys.exit(1)
@@ -96,7 +105,7 @@ def main():
         [f'{name: <{max(names)}} {str(shape): <{max(shapes)}} {size}'
          for name, shape, size in param_summary]))
 
-    optimizer = torch.optim.Adam(network.parameters(), lr=1e-3, weight_decay=1e-4)
+    optimizer = torch.optim.Adam(network.parameters(), lr=1e-3, weight_decay=1e-6)
     scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.996)
     criterion = nn.CrossEntropyLoss()
 
diff --git a/src/torch_networks.py b/src/torch_networks.py
index db7ae5f..347a67e 100644
--- a/src/torch_networks.py
+++ b/src/torch_networks.py
@@ -131,7 +131,51 @@ class BNLSTMCell(nn.Module):
         nn.init.zeros_(self.bias)
 
 
-class StackLSTMLayer(nn.Module):
+class CustomLSTMCell(nn.Module):
+    def __init__(self, input_size, hidden_size):
+        super().__init__()
+        self.input_size = input_size
+        self.hidden_size = hidden_size
+        self.weight_ih = nn.Parameter(torch.empty(input_size, 4 * hidden_size))
+        self.weight_hh = nn.Parameter(torch.empty(hidden_size, 4 * hidden_size))
+        self.bias = nn.Parameter(torch.empty(4 * hidden_size))
+        # self.in_norm = nn.BatchNorm1d(4 * hidden_size)
+        # self.cell_norm = nn.BatchNorm1d(4 * hidden_size)
+        # self.out_norm = nn.LayerNorm(hidden_size)
+
+        self.in_act = nn.Sigmoid()
+        self.forget_act = nn.Sigmoid()
+        self.cell_act = nn.Tanh()
+        self.out_act = nn.Sigmoid()
+
+        self.reset_parameters()
+
+    def forward(self, input_data: torch.Tensor, state: tuple[torch.Tensor, torch.Tensor]) -> tuple[
+            torch.Tensor, torch.Tensor]:
+        hx, cx = state
+        gates = (
+            torch.mm(input_data, self.weight_ih)
+            + torch.mm(hx, self.weight_hh)
+            + self.bias)
+        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
+
+        ingate = self.in_act(ingate)
+        forgetgate = self.forget_act(forgetgate)
+        cellgate = self.cell_act(cellgate)
+        outgate = self.out_act(outgate)
+
+        cy = (forgetgate * cx) + (ingate * cellgate)
+        hy = outgate * torch.tanh(cy)
+
+        return (hy, cy)
+
+    def reset_parameters(self) -> None:
+        for weight in [self.weight_hh, self.weight_ih]:
+            nn.init.xavier_normal_(weight)
+        nn.init.zeros_(self.bias)
+
+
+class StackRNNLayer(nn.Module):
     def __init__(self, input_size, hidden_size, cell_class=LSTMCell):
         super().__init__()
         self.cell = cell_class(input_size=input_size, hidden_size=hidden_size)
@@ -146,7 +190,7 @@ class StackLSTMLayer(nn.Module):
         return torch.stack(outputs), state
 
 
-class ChainLSTMLayer(nn.Module):
+class ChainRNNLayer(nn.Module):
     def __init__(self, input_size, hidden_size, cell_class=LSTMCell):
         super().__init__()
         self.cell = cell_class(input_size=input_size, hidden_size=hidden_size)
@@ -161,8 +205,8 @@ class ChainLSTMLayer(nn.Module):
         return torch.stack(outputs), state
 
 
-class _CustomLSTMLayer(nn.Module):
-    def __init__(self, input_size, hidden_size, num_layers, layer_class=StackLSTMLayer, cell_class=LSTMCell):
+class _CustomRNNLayer(nn.Module):
+    def __init__(self, input_size, hidden_size, num_layers, layer_class=StackRNNLayer, cell_class=LSTMCell):
         super().__init__()
         self.num_layer = num_layers
         self.layers = nn.ModuleList(
@@ -184,15 +228,15 @@ class _CustomLSTMLayer(nn.Module):
         return output, (torch.stack(output_states), torch.stack(output_cell_states))
 
 
-class CustomLSTMModel(nn.Module):
+class CustomRNNModel(nn.Module):
     def __init__(self, input_size: int, hidden_size: int = -1, output_size: int = -1,
-                 num_layer: int = 3, layer_class=StackLSTMLayer, cell_class=LSTMCell):
+                 num_layer: int = 3, layer_class=StackRNNLayer, cell_class=nn.LSTMCell):
         super().__init__()
         self.hidden_size = hidden_size if hidden_size > 0 else input_size * 2
         self.output_size = output_size if output_size > 0 else input_size
 
-        self.lstm = _CustomLSTMLayer(input_size=input_size, hidden_size=self.hidden_size,
-                                     num_layers=num_layer, layer_class=layer_class, cell_class=cell_class)
+        self.lstm = _CustomRNNLayer(input_size=input_size, hidden_size=self.hidden_size,
+                                    num_layers=num_layer, layer_class=layer_class, cell_class=cell_class)
         self.dense = nn.Linear(hidden_size, self.output_size)
 
         self.zero_state = torch.nn.Parameter(
@@ -213,3 +257,18 @@ class CustomLSTMModel(nn.Module):
             input_data.size(1), self.lstm.num_layer, self.hidden_size).transpose(1, 0))
         output, state = self.lstm(input_data, init_state)
         return self.dense(output), state
+
+
+class TorchGRUModel(nn.Module):
+    def __init__(self, input_size: int, hidden_size: int = -1, output_size: int = -1, num_layer: int = 3):
+        super().__init__()
+        hidden_size = hidden_size if hidden_size > 0 else input_size * 2
+        output_size = output_size if output_size > 0 else input_size
+
+        self.lstm = nn.GRU(input_size=input_size, hidden_size=hidden_size, num_layers=num_layer)
+        self.dense = nn.Linear(hidden_size, output_size)
+
+    def forward(self, input_data: torch.Tensor, init_state=None) -> tuple[
+            torch.Tensor, torch.Tensor]:
+        output, state = self.lstm(input_data, init_state)
+        return self.dense(output), state
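
Usage note: a minimal sketch of how the new model variants wire together (this is what recorder.py selects via --model). The toy shapes, and the assumption that CustomRNNModel.forward falls back to its learned zero state when no init_state is passed, are illustrative only, not taken from this diff:

    import torch
    from src.torch_networks import CustomRNNModel, ChainRNNLayer, CustomLSTMCell

    # 'chain-custom' pairs the chain-style layer with the new plain LSTM cell.
    network = CustomRNNModel(input_size=16, hidden_size=32, num_layer=3,
                             layer_class=ChainRNNLayer, cell_class=CustomLSTMCell)

    # Sequence-first dummy batch: (seq_len, batch, input_size).
    dummy = torch.randn(8, 4, 16)
    output, (h, c) = network(dummy)  # assumes init_state defaults to the zero state
    print(output.shape)              # (8, 4, 16): dense head maps hidden back to output_size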
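
Note on TorchGRUModel: unlike the LSTM variants, nn.GRU carries a single hidden-state tensor rather than an (h, c) tuple, which is why its forward returns tuple[torch.Tensor, torch.Tensor]. A quick self-contained illustration with assumed toy shapes:

    import torch
    from torch import nn

    gru = nn.GRU(input_size=16, hidden_size=32, num_layers=3)
    x = torch.randn(8, 4, 16)   # (seq_len, batch, input_size); nn.GRU is sequence-first by default
    output, state = gru(x)      # output: (8, 4, 32); state: (3, 4, 32), one tensor, not a tuple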