From 1704b7aad1235835cf800979149e3612c4c54ba1 Mon Sep 17 00:00:00 2001
From: Corentin
Date: Mon, 30 Aug 2021 23:21:58 +0900
Subject: [PATCH] Fix networks

---
 .gitignore            |   3 +-
 modulo.py             |  48 ++++-----
 recorder.py           | 216 +++++++++++++++++++++++++++++++----------
 src/torch_networks.py | 221 ++++++++++++++++++++++++++++--------------
 4 files changed, 343 insertions(+), 145 deletions(-)

diff --git a/.gitignore b/.gitignore
index b1d5e39..74ebfc8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,4 @@
 *.pyc
-output
\ No newline at end of file
+output
+save
\ No newline at end of file

diff --git a/modulo.py b/modulo.py
index 51ce082..0a2ee9c 100644
--- a/modulo.py
+++ b/modulo.py
@@ -2,6 +2,7 @@ from argparse import ArgumentParser
 from pathlib import Path
 import math
 import shutil
+import sys
 import time
 
 import numpy as np
@@ -9,7 +10,7 @@ import torch
 from torch import nn
 from torch.utils.tensorboard import SummaryWriter
 
-from src.torch_networks import LSTMModel, LSTMCellModel, StackedLSTMModel
+from src.torch_networks import TorchLSTMModel, TorchLSTMCellModel, CustomLSTMModel, ChainLSTMLayer
 from src.torch_utils.utils.batch_generator import BatchGenerator
 from src.torch_utils.train import parameter_summary
@@ -50,18 +51,21 @@ class DataGenerator:
 def main():
     parser = ArgumentParser()
     parser.add_argument('--output', type=Path, default=Path('output', 'modulo'), help='Output dir')
+    parser.add_argument('--model', default='torch-lstm', help='Model to train')
     parser.add_argument('--batch', type=int, default=32, help='Batch size')
     parser.add_argument('--sequence', type=int, default=12, help='Max sequence length')
+    parser.add_argument('--hidden', type=int, default=16, help='LSTM cells hidden size')
     parser.add_argument('--step', type=int, default=2000, help='Number of steps to train')
-    parser.add_argument('--model', help='Model to train')
     arguments = parser.parse_args()
 
     output_dir: Path = arguments.output
+    model: str = arguments.model
     batch_size: int = arguments.batch
     sequence_size: int = arguments.sequence
+    hidden_size: int = arguments.hidden
     max_step: int = arguments.step
-    model: str = arguments.model
 
+    output_dir = output_dir.parent / f'modulo_{model}_b{batch_size}_s{sequence_size}_h{hidden_size}'
     if not output_dir.exists():
         output_dir.mkdir(parents=True)
     if (output_dir / 'train').exists():
@@ -71,15 +75,24 @@ def main():
     device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
     torch.backends.cudnn.benchmark = True
     if model == 'stack':
-        network = StackedLSTMModel(1, 16, 2).to(device)
-    elif model == 'cell':
-        network = LSTMCellModel(1, 16, 2).to(device)
+        network = CustomLSTMModel(1, hidden_size, 2).to(device)
+    elif model == 'stack-torchcell':
+        network = CustomLSTMModel(1, hidden_size, 2, cell_class=nn.LSTMCell).to(device)
+    elif model == 'chain':
+        network = CustomLSTMModel(1, hidden_size, 2, layer_class=ChainLSTMLayer).to(device)
+    elif model == 'torch-cell':
+        network = TorchLSTMCellModel(1, hidden_size, 2).to(device)
+    elif model == 'torch-lstm':
+        network = TorchLSTMModel(1, hidden_size, 2).to(device)
    else:
-        network = LSTMModel(1, 16, 2).to(device)
+        print('Error: Unknown model')
+        sys.exit(1)
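Note: recorder.py repeats this if/elif dispatch below with two extra variants. A table-driven sketch using only the constructors imported above (the MODELS name is hypothetical, not part of the patch):

    MODELS = {
        'stack': lambda: CustomLSTMModel(1, hidden_size, 2),
        'stack-torchcell': lambda: CustomLSTMModel(1, hidden_size, 2, cell_class=nn.LSTMCell),
        'chain': lambda: CustomLSTMModel(1, hidden_size, 2, layer_class=ChainLSTMLayer),
        'torch-cell': lambda: TorchLSTMCellModel(1, hidden_size, 2),
        'torch-lstm': lambda: TorchLSTMModel(1, hidden_size, 2),
    }
    if model not in MODELS:
        print('Error: Unknown model')
        sys.exit(1)
    network = MODELS[model]().to(device)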
     torch.save(network.state_dict(), output_dir / 'model_ini.pt')
+    input_sample = torch.from_numpy(generate_data(2, 4)[0]).to(device)
+    writer_train.add_graph(network, (input_sample,))
 
     # Save parameters info
-    with open(output_dir / 'parameters.csv', 'w') as param_file:
+    with open(output_dir / 'parameters.csv', 'w', encoding='utf-8') as param_file:
         param_summary = parameter_summary(network)
         names = [len(name) for name, _, _ in param_summary]
         shapes = [len(str(shape)) for _, shape, _ in param_summary]
@@ -88,7 +101,7 @@ def main():
                 [f'{name: <{max(names)}} {str(shape): <{max(shapes)}} {size}'
                  for name, shape, size in param_summary]))
 
-    optimizer = torch.optim.Adam(network.parameters(), lr=1e-3)
+    optimizer = torch.optim.Adam(network.parameters(), lr=1e-3, weight_decay=1e-4)
     scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.995)
     criterion = nn.CrossEntropyLoss()
@@ -99,17 +112,10 @@ def main():
                                                        (batch_size, max_step)).transpose((1, 0)), (batch_size * max_step))
     dummy_label = np.zeros((batch_size * max_step), dtype=np.uint8)
     DataGenerator.MAX_LENGTH = sequence_size
-    if model in ['cell', 'stack']:
-        state = [(torch.zeros((batch_size, 16)).to(device),
-                  torch.zeros((batch_size, 16)).to(device))] * network.NUM_LAYERS
-    else:
-        state = None
     with BatchGenerator(sequence_data_reshaped, dummy_label, batch_size=batch_size,
                         pipeline=DataGenerator.pipeline, num_workers=8, shuffle=False) as batch_generator:
-        # data_np, _ = generate_data(batch_size, int(np.random.uniform(4, sequence_size + 1)))
         data_np = batch_generator.batch_data
         label_np = batch_generator.batch_label
-        # writer_train.add_graph(network, (torch.from_numpy(data_np).to(device),))
 
         running_loss = 0.0
         running_accuracy = 0.0
@@ -126,7 +132,7 @@ def main():
 
             optimizer.zero_grad(set_to_none=True)
 
-            outputs, _states = network(data, state)
+            outputs, _states = network(data)
             loss = criterion(outputs[-1], label)
             running_loss += loss.item()
             outputs_np = outputs[-1].detach().cpu().numpy()
@@ -164,7 +170,7 @@ def main():
                 data = torch.from_numpy(data_np).to(device)
                 label = torch.from_numpy(label_np).to(device)
 
-                outputs, _states = network(data, state)
+                outputs, _states = network(data)
                 outputs_np = outputs[-1].detach().cpu().numpy()
                 running_accuracy += ((outputs_np[:, 1] > outputs_np[:, 0]).astype(np.int32) == label_np).astype(
                     np.float32).mean()
@@ -189,13 +195,9 @@ def main():
     test_label = np.asarray([1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0], dtype=np.int32)
     running_accuracy = 0.0
     running_count = 0
-    if model in ['cell', 'stack']:
-        state = [(torch.zeros((1, 16)).to(device),
-                  torch.zeros((1, 16)).to(device))] * network.NUM_LAYERS
     for data, label in zip(test_data, test_label):
         outputs, _states = network(
-            torch.from_numpy(np.expand_dims(np.asarray(data, dtype=np.float32), 1)).to(device),
-            state)
+            torch.from_numpy(np.expand_dims(np.asarray(data, dtype=np.float32), 1)).to(device))
         outputs_np = outputs[-1].detach().cpu().numpy()
         output_correct = int(outputs_np[0, 1] > outputs_np[0, 0]) == label
         running_accuracy += 1.0 if output_correct else 0.0
diff --git a/recorder.py b/recorder.py
index 2d9f573..a081ffe 100644
--- a/recorder.py
+++ b/recorder.py
@@ -1,6 +1,7 @@
 from argparse import ArgumentParser
 from pathlib import Path
 import shutil
+import sys
 import time
 
 import numpy as np
@@ -8,7 +9,7 @@ import torch
 from torch import nn
 from torch.utils.tensorboard import SummaryWriter
 
-from src.torch_networks import LSTMModel, StackedLSTMModel
+from src.torch_networks import TorchLSTMModel, TorchLSTMCellModel, CustomLSTMModel, ChainLSTMLayer, BNLSTMCell
 from src.torch_utils.train import parameter_summary
@@ -16,34 +17,74 @@ def generate_data(batch_size: int, sequence_length: int, dim: int) -> np.ndarray:
     return np.random.uniform(0, dim, (batch_size, sequence_length)).astype(np.int64)
 
 
+def score_sequences(sequences: np.ndarray) -> tuple[float, float]:
+    """Sum of leading-correct-run lengths over all sequences, plus the best single run."""
+    score = 0.0
+    max_score = 0.0
+    for sequence_data in sequences:
+        sequence_score = 0.0
+        for result in sequence_data:
+            if not result:
+                break
+            sequence_score += 1.0
+        if sequence_score > max_score:
+            max_score = sequence_score
+        score += sequence_score
+    return score, max_score
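Note: score_sequences counts, for each row, the length of the leading run of correct steps before the first miss, summing over rows and tracking the best row. A worked example (values chosen for illustration):

    import numpy as np

    # Row 0 is correct for 2 steps before its first miss; row 1 for all 4.
    correct = np.array([[1, 1, 0, 1],
                        [1, 1, 1, 1]], dtype=bool)
    total, best = score_sequences(correct)
    assert (total, best) == (6.0, 4.0)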
+
+
 def main():
     parser = ArgumentParser()
     parser.add_argument('--output', type=Path, default=Path('output', 'recorder'), help='Output dir')
+    parser.add_argument('--model', default='torch-lstm', help='Model to train')
     parser.add_argument('--batch', type=int, default=32, help='Batch size')
     parser.add_argument('--sequence', type=int, default=8, help='Max sequence length')
     parser.add_argument('--dimension', type=int, default=15, help='Input dimension')
+    parser.add_argument('--hidden', type=int, default=32, help='Hidden dimension')
+    parser.add_argument('--lstm', type=int, default=3, help='LSTM layer stack length')
     parser.add_argument('--step', type=int, default=20_000, help='Number of steps to train')
-    parser.add_argument('--model', help='Model to train')
     arguments = parser.parse_args()
 
     output_dir: Path = arguments.output
+    model: str = arguments.model
     batch_size: int = arguments.batch
     sequence_size: int = arguments.sequence
     input_dim: int = arguments.dimension
+    hidden_size: int = arguments.hidden
+    num_layer: int = arguments.lstm
     max_step: int = arguments.step
-    model: str = arguments.model
 
+    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+    torch.backends.cudnn.benchmark = True
+    if model == 'stack':
+        network = CustomLSTMModel(input_dim + 1, hidden_size, num_layer=num_layer).to(device)
+    elif model == 'stack-bn':
+        network = CustomLSTMModel(input_dim + 1, hidden_size, num_layer=num_layer, cell_class=BNLSTMCell).to(device)
+    elif model == 'stack-torchcell':
+        network = CustomLSTMModel(input_dim + 1, hidden_size, num_layer=num_layer, cell_class=nn.LSTMCell).to(device)
+    elif model == 'chain':
+        network = CustomLSTMModel(input_dim + 1, hidden_size, num_layer=num_layer,
+                                  layer_class=ChainLSTMLayer).to(device)
+    elif model == 'chain-bn':
+        network = CustomLSTMModel(input_dim + 1, hidden_size, num_layer=num_layer,
+                                  layer_class=ChainLSTMLayer, cell_class=BNLSTMCell).to(device)
+    elif model == 'torch-cell':
+        network = TorchLSTMCellModel(input_dim + 1, hidden_size, num_layer=num_layer).to(device)
+    elif model == 'torch-lstm':
+        network = TorchLSTMModel(input_dim + 1, hidden_size, num_layer=num_layer).to(device)
+    else:
+        print('Error: Unknown model')
+        sys.exit(1)
+
+    output_dir = output_dir.parent / f'recorder_{model}_b{batch_size}_s{sequence_size}_h{hidden_size}_l{num_layer}'
     if not output_dir.exists():
         output_dir.mkdir(parents=True)
     if (output_dir / 'train').exists():
         shutil.rmtree(output_dir / 'train')
     writer_train = SummaryWriter(log_dir=str(output_dir / 'train'), flush_secs=20)
+    data_sample = torch.zeros((2, 4, input_dim + 1)).to(device)
 
-    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
-    if model == 'cell':
-        network = StackedLSTMModel(input_dim + 1).to(device)
-    else:
-        network = LSTMModel(input_dim + 1).to(device)
+    torch.save(network.state_dict(), output_dir / 'model_ini.pt')
+    writer_train.add_graph(network, (data_sample,))
 
     # Save parameters info
     with open(output_dir / 'parameters.csv', 'w') as param_file:
@@ -55,64 +96,129 @@ def main():
                 [f'{name: <{max(names)}} {str(shape): <{max(shapes)}} {size}'
                  for name, shape, size in param_summary]))
 
-    optimizer = torch.optim.Adam(network.parameters(), lr=1e-3)
+    optimizer = torch.optim.Adam(network.parameters(), lr=1e-3, weight_decay=1e-4)
     scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.996)
     criterion = nn.CrossEntropyLoss()
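Note: scheduler.step() is only reached once per summary period in the loop below (max_step // 100 iterations apart), so the decay runs about 100 times per run rather than once per step. A rough check of where the learning rate lands:

    final_lr = 1e-3 * 0.996 ** 100  # ≈ 6.7e-4 after the ~100 scheduler steps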
 
-    zero_data = torch.zeros((batch_size, sequence_size, input_dim + 1)).to(device)
-    # writer_train.add_graph(network, (zero_data[:, :5],))
-    if model == 'cell':
-        state = [(torch.zeros((batch_size, (input_dim + 1) * 2)).to(device),
-                  torch.zeros((batch_size, (input_dim + 1) * 2)).to(device))] * network.NUM_LAYERS
-    else:
-        state = None
-    running_loss = 0.0
-    running_accuracy = 0.0
-    running_count = 0
+    class Metrics:
+        class Bench:
+            def __init__(self):
+                self.reset()
+
+            def reset(self):
+                self.data_gen = 0.0
+                self.data_process = 0.0
+                self.predict = 0.0
+                self.loss = 0.0
+                self.backprop = 0.0
+                self.optimizer = 0.0
+                self.metrics = 0.0
+
+            def get_proportions(self, train_time: float) -> str:
+                return (
+                    f'data_gen: {self.data_gen / train_time:.02f}'
+                    f', data_process: {self.data_process / train_time:.02f}'
+                    f', predict: {self.predict / train_time:.02f}'
+                    f', loss: {self.loss / train_time:.02f}'
+                    f', backprop: {self.backprop / train_time:.02f}'
+                    f', optimizer: {self.optimizer / train_time:.02f}'
+                    f', metrics: {self.metrics / train_time:.02f}')
+
+        def __init__(self):
+            self.bench = self.Bench()
+            self.reset()
+
+        def reset(self):
+            self.loss = 0.0
+            self.accuracy = 0.0
+            self.score = 0.0
+            self.max_score = 0.0
+            self.count = 0
+            self.bench.reset()
+
+    metrics = Metrics()
     summary_period = max_step // 100
+    min_sequence_size = 4
+    zero_data = torch.zeros((batch_size, sequence_size, input_dim + 1)).to(device)
     np.set_printoptions(precision=2)
     try:
         start_time = time.time()
         for step in range(1, max_step + 1):
-            label_np = generate_data(batch_size, int(np.random.uniform(4, sequence_size + 1)), input_dim)
+            step_start_time = time.time()
+            label_np = generate_data(
+                batch_size, int(np.random.uniform(min_sequence_size, sequence_size + 1)), input_dim)
+            data_gen_time = time.time()
+
             label = torch.from_numpy(label_np).to(device)
             data = nn.functional.one_hot(label, input_dim + 1).float()
             data[:, :, input_dim] = 1.0
+            data_process_time = time.time()
 
             optimizer.zero_grad()
-            outputs, _state = network(torch.cat([data, zero_data[:, :label_np.shape[1]]], dim=1).transpose(1, 0), state)
+            outputs, _state = network(torch.cat([data, zero_data[:, :label_np.shape[1]]], dim=1).transpose(1, 0))
             outputs = outputs[label_np.shape[1]:, :, :-1]
             # state = (state[0].detach(), state[1].detach())
             # state = (state[0][:, :1].detach().expand(state[0].shape[0], batch_size, *state[0].shape[2:]),
             #          state[1][:, :1].detach().expand(state[1].shape[0], batch_size, *state[1].shape[2:]))
+            predict_time = time.time()
 
-            data[:, :, input_dim] = 0.0
             loss = criterion(outputs[:, 0], label[0])
             for i in range(1, batch_size):
                 loss += criterion(outputs[:, i], label[i])
             loss /= batch_size
-            running_loss += loss.item()
-            running_accuracy += (
-                torch.sum((torch.argmax(outputs, 2).transpose(1, 0) == label).long(),
-                          1) == label.size(1)).float().mean().item()
-            running_count += 1
+            loss_time = time.time()
+
+            loss.backward()
+            backprop_time = time.time()
+            optimizer.step()
+            optim_time = time.time()
+
+            sequence_correct = torch.argmax(outputs, 2).transpose(1, 0) == label
+            average_score, max_score = score_sequences(sequence_correct.detach().cpu().numpy())
+            # if current_score > min_sequence_size * 2 and min_sequence_size < sequence_size - 1:
+            #     min_sequence_size += 1
+            metrics.loss += loss.item()
+            metrics.accuracy += (torch.sum(sequence_correct.long(), 1) == label.size(1)).float().sum().item()
+            metrics.score += average_score
+            if max_score > metrics.max_score:
+                metrics.max_score = max_score
+            metrics.count += batch_size
+
+            metrics.bench.data_gen += data_gen_time - step_start_time
+            metrics.bench.data_process += data_process_time - data_gen_time
+            metrics.bench.predict += predict_time - data_process_time
+            metrics.bench.loss += loss_time - predict_time
+            metrics.bench.backprop += backprop_time - loss_time
+            metrics.bench.optimizer += optim_time - backprop_time
+            metrics.bench.metrics += time.time() - optim_time
 
             if step % summary_period == 0:
-                writer_train.add_scalar('metric/loss', running_loss / running_count, global_step=step)
-                writer_train.add_scalar('metric/error', 1 - (running_accuracy / running_count), global_step=step)
+                writer_train.add_scalar('metric/loss', metrics.loss / metrics.count, global_step=step)
+                writer_train.add_scalar('metric/error', 1 - (metrics.accuracy / metrics.count), global_step=step)
+                writer_train.add_scalar('metric/score', metrics.score / metrics.count, global_step=step)
+                writer_train.add_scalar('metric/max_score', metrics.max_score, global_step=step)
                 writer_train.add_scalar('optimizer/lr', scheduler.get_last_lr()[0], global_step=step)
                 scheduler.step()
-                speed = summary_period / (time.time() - start_time)
-                print(f'Step {step}, loss: {running_loss / running_count:.03e}'
-                      f', acc: {running_accuracy / running_count:.03e}, speed: {speed:0.3f}step/s')
+                train_time = time.time() - start_time
+                print(f'Step {step}, loss: {metrics.loss / metrics.count:.03e}'
+                      f', acc: {metrics.accuracy / metrics.count:.03e}'
+                      f', score: {metrics.score / metrics.count:.03f}'
+                      f', speed: {summary_period / train_time:0.3f}step/s'
+                      f' => {metrics.count / train_time:.02f}input/s'
+                      f'\n    ({metrics.bench.get_proportions(train_time)})')
                 start_time = time.time()
-                running_loss = 0.0
-                running_accuracy = 0.0
-                running_count = 0
-            loss.backward()
-            optimizer.step()
+                metrics.reset()
     except KeyboardInterrupt:
         print('\r ', end='\r')
     writer_train.close()
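Note: the Bench buckets above read time.time() between steps, but CUDA kernels launch asynchronously, so GPU time can be attributed to whichever bucket happens to block next. A hedged sketch of a synchronizing timestamp (profiling only, as it costs throughput; the helper name is hypothetical):

    import time
    import torch

    def checkpoint() -> float:
        # Wait for queued GPU work before reading the wall clock.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        return time.time()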
@@ -121,34 +227,44 @@ def main():
     test_label = [
         np.asarray([[0, 0, 0, 0]], dtype=np.int64),
         np.asarray([[2, 2, 2, 2]], dtype=np.int64),
+        np.asarray([[1, 2, 3, 4]], dtype=np.int64),
         np.asarray([[8, 1, 10, 5, 6, 13]], dtype=np.int64),
         generate_data(1, 4, input_dim),
         np.asarray([[0, 0, 0, 0, 0, 0]], dtype=np.int64),
         np.asarray([[5, 5, 5, 5, 5, 5]], dtype=np.int64),
         np.asarray([[11, 0, 2, 8, 5, 1]], dtype=np.int64),
+        generate_data(1, max(4, sequence_size // 4), input_dim),
+        generate_data(1, max(4, sequence_size // 2), input_dim),
+        generate_data(1, max(4, sequence_size * 3 // 4), input_dim),
+        generate_data(1, max(4, sequence_size * 3 // 4), input_dim),
+        generate_data(1, max(4, sequence_size * 3 // 4), input_dim),
+        generate_data(1, max(4, sequence_size * 3 // 4), input_dim),
+        generate_data(1, sequence_size, input_dim),
+        generate_data(1, sequence_size, input_dim),
+        generate_data(1, sequence_size, input_dim),
         generate_data(1, sequence_size, input_dim)
     ]
     zero_data = torch.zeros((1, sequence_size, input_dim + 1)).to(device)
-    if model == 'cell':
-        state = [(torch.zeros((1, (input_dim + 1) * 2)).to(device),
-                  torch.zeros((1, (input_dim + 1) * 2)).to(device))] * 3
-    running_accuracy = 0.0
-    running_count = 0
+    metrics.reset()
     for label_np in test_label:
         label = torch.from_numpy(label_np).to(device)
         data = nn.functional.one_hot(label, input_dim + 1).float()
         data[:, :, input_dim] = 1.0
 
-        outputs, _state = network(torch.cat([data, zero_data[:, :label_np.shape[1]]], dim=1).transpose(1, 0), state)
-        outputs = outputs[label_np.shape[1]:, :, :-1]
+        outputs, _state = network(torch.cat([data, zero_data[:, :label_np.shape[1]]], dim=1).transpose(1, 0))
         # state = (state[0].detach(), state[1].detach())
+        outputs = outputs[label_np.shape[1]:, :, :-1]
+        sequence_correct = torch.argmax(outputs, 2).transpose(1, 0) == label
+        current_score, max_score = score_sequences(sequence_correct.detach().cpu().numpy())
 
-        running_accuracy += (
-            torch.sum(
-                (torch.argmax(outputs, 2).transpose(1, 0) == label).long(), 1) == label.size(1)).float().mean().item()
-        running_count += 1
-        print(f'{len(label_np)} label: {label_np}, output: {torch.argmax(outputs, 2)[:, 0].detach().cpu().numpy()}')
-    print(f'Test accuracy: {running_accuracy / running_count:.03f}')
+        metrics.accuracy += (torch.sum(sequence_correct.long(), 1) == label.size(1)).float().mean().item()
+        metrics.score += current_score
+        if max_score > metrics.max_score:
+            metrics.max_score = max_score
+        metrics.count += 1
+        print(f'score: {current_score}/{label_np.shape[1]}, label: {label_np}'
+              f', output: {torch.argmax(outputs, 2)[:, 0].detach().cpu().numpy()}')
+    print(f'Test accuracy: {metrics.accuracy / metrics.count:.03f}')
 
 
 if __name__ == '__main__':
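Note: the recorder encoding used in both loops above, spelled out as a standalone sketch (sizes are arbitrary): one-hot tokens carry an extra always-on flag channel during presentation, then a zero recall phase of equal length follows, time-major for the LSTMs:

    import torch
    from torch import nn

    input_dim = 15
    label = torch.randint(0, input_dim, (2, 4))                    # (batch, seq)
    data = nn.functional.one_hot(label, input_dim + 1).float()
    data[:, :, input_dim] = 1.0                                    # presentation flag
    recall = torch.zeros(2, 4, input_dim + 1)                      # network must replay here
    net_input = torch.cat([data, recall], dim=1).transpose(1, 0)   # (2 * seq, batch, dim + 1)
    assert net_input.shape == (8, 2, 16)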
diff --git a/src/torch_networks.py b/src/torch_networks.py
index 4bf3c6d..db7ae5f 100644
--- a/src/torch_networks.py
+++ b/src/torch_networks.py
@@ -1,16 +1,17 @@
+import math
+from typing import Optional
+
 import torch
 from torch import nn
 
 
-class LSTMModel(nn.Module):
-    NUM_LAYERS = 3
-
-    def __init__(self, input_size: int, hidden_size: int = -1, output_size: int = -1):
+class TorchLSTMModel(nn.Module):
+    def __init__(self, input_size: int, hidden_size: int = -1, output_size: int = -1, num_layer: int = 3):
         super().__init__()
         hidden_size = hidden_size if hidden_size > 0 else input_size * 2
         output_size = output_size if output_size > 0 else input_size
 
-        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=self.NUM_LAYERS)
+        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layer)
         self.dense = nn.Linear(hidden_size, output_size)
 
     def forward(self, input_data: torch.Tensor, init_state=None) -> tuple[
@@ -19,16 +20,57 @@ class LSTMModel(nn.Module):
         return self.dense(output), state
 
 
-class LSTMCell(torch.jit.ScriptModule):
+class TorchLSTMCellModel(nn.Module):
+    def __init__(self, input_size: int, hidden_size: int = -1, output_size: int = -1, num_layer: int = 3):
+        super().__init__()
+        self.num_layer = num_layer
+        self.hidden_size = hidden_size if hidden_size > 0 else input_size * 2
+        self.output_size = output_size if output_size > 0 else input_size
+
+        self.layers = nn.ModuleList([
+            nn.LSTMCell(input_size=input_size, hidden_size=self.hidden_size)] + [
+            nn.LSTMCell(input_size=self.hidden_size, hidden_size=self.hidden_size) for _ in range(num_layer - 1)]
+        )
+        self.dense = nn.Linear(self.hidden_size, self.output_size)
+
+    def forward(self, input_data: torch.Tensor,
+                init_states: tuple[torch.Tensor, torch.Tensor] = (torch.zeros(1), torch.zeros(1))) -> tuple[
+            torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+        # A 1-D placeholder state (the default) means "start from zeros".
+        if len(init_states[0].shape) == 1:
+            zeros = torch.zeros(self.num_layer, input_data.size(1), self.hidden_size,
+                                dtype=input_data.dtype, device=input_data.device)
+            init_states = (zeros, zeros)
+
+        output_h_states = torch.jit.annotate(list[torch.Tensor], [])
+        output_c_states = torch.jit.annotate(list[torch.Tensor], [])
+        cell_inputs = input_data.unbind(0)
+        cell_id = 0
+        for cell in self.layers:
+            cell_h_state = torch.jit.annotate(torch.Tensor, init_states[0][cell_id])
+            cell_c_state = torch.jit.annotate(torch.Tensor, init_states[1][cell_id])
+            cell_outputs = torch.jit.annotate(list[torch.Tensor], [])
+            for i in range(len(cell_inputs)):
+                cell_h_state, cell_c_state = cell(cell_inputs[i], (cell_h_state, cell_c_state))
+                cell_outputs += [cell_h_state]
+            cell_inputs = cell_outputs
+            output_h_states += [cell_h_state]
+            output_c_states += [cell_c_state]
+            cell_id += 1
+        return self.dense(torch.stack(cell_outputs)), (torch.stack(output_h_states), torch.stack(output_c_states))
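Note: both wrappers keep nn.LSTM's time-major convention. A shape-contract sketch for TorchLSTMModel (sizes are arbitrary):

    import torch

    net = TorchLSTMModel(input_size=16, hidden_size=32, output_size=16, num_layer=3)
    x = torch.randn(8, 4, 16)                     # (seq_len, batch, input_size)
    out, (h_n, c_n) = net(x)
    assert out.shape == (8, 4, 16)                # per-step logits from the dense head
    assert h_n.shape == c_n.shape == (3, 4, 32)   # (num_layer, batch, hidden_size)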
+
+
+class LSTMCell(nn.Module):
     def __init__(self, input_size, hidden_size):
         super().__init__()
         self.input_size = input_size
         self.hidden_size = hidden_size
-        self.weight_ih = nn.Parameter(torch.randn(input_size, 4 * hidden_size))
-        self.weight_hh = nn.Parameter(torch.randn(hidden_size, 4 * hidden_size))
-        self.bias = nn.Parameter(torch.randn(4 * hidden_size))
+        self.weight_ih = nn.Parameter(torch.empty(input_size, 4 * hidden_size))
+        self.weight_hh = nn.Parameter(torch.empty(hidden_size, 4 * hidden_size))
+        self.bias = nn.Parameter(torch.empty(4 * hidden_size))
+
+        self.reset_parameters()
 
-    @torch.jit.script_method
     def forward(self, input_data: torch.Tensor, state: tuple[torch.Tensor, torch.Tensor]) -> tuple[
             torch.Tensor, torch.Tensor]:
         hx, cx = state
@@ -46,13 +88,54 @@ class LSTMCell(torch.jit.ScriptModule):
 
         return (hy, cy)
 
+    def reset_parameters(self) -> None:
+        # Xavier init for the gate weights, zero bias (torch.randn previously left them unscaled).
+        for weight in [self.weight_hh, self.weight_ih]:
+            nn.init.xavier_normal_(weight)
+        nn.init.zeros_(self.bias)
 
-class LSTMLayer(torch.jit.ScriptModule):
+
+class BNLSTMCell(nn.Module):
     def __init__(self, input_size, hidden_size):
         super().__init__()
-        self.cell = nn.LSTMCell(input_size=input_size, hidden_size=hidden_size)
+        self.input_size = input_size
+        self.hidden_size = hidden_size
+        self.weight_ih = nn.Parameter(torch.empty(input_size, 4 * hidden_size))
+        self.weight_hh = nn.Parameter(torch.empty(hidden_size, 4 * hidden_size))
+        self.bias = nn.Parameter(torch.empty(4 * hidden_size))
+        # Normalise the input-to-hidden and hidden-to-hidden projections separately,
+        # plus the cell state before the output gate.
+        self.bn_1 = nn.BatchNorm1d(4 * hidden_size)
+        self.bn_2 = nn.BatchNorm1d(4 * hidden_size)
+        self.bn_out = nn.BatchNorm1d(hidden_size)
+
+        self.reset_parameters()
+
+    def forward(self, input_data: torch.Tensor, state: tuple[torch.Tensor, torch.Tensor]) -> tuple[
+            torch.Tensor, torch.Tensor]:
+        hx, cx = state
+        gates = (
+            (self.bn_1(torch.mm(input_data, self.weight_ih)) + self.bn_2(torch.mm(hx, self.weight_hh)) + self.bias))
+        # Gate order follows the chunk layout: input, forget, cell candidate, output.
+        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
+
+        ingate = torch.sigmoid(ingate)
+        forgetgate = torch.sigmoid(forgetgate)
+        cellgate = torch.tanh(cellgate)
+        outgate = torch.sigmoid(outgate)
+
+        cy = (forgetgate * cx) + (ingate * cellgate)
+        hy = outgate * torch.tanh(self.bn_out(cy))
+
+        return (hy, cy)
+
+    def reset_parameters(self) -> None:
+        for weight in [self.weight_hh, self.weight_ih]:
+            nn.init.xavier_normal_(weight)
+        nn.init.zeros_(self.bias)
+
+
+class StackLSTMLayer(nn.Module):
+    def __init__(self, input_size, hidden_size, cell_class=LSTMCell):
+        super().__init__()
+        self.cell = cell_class(input_size=input_size, hidden_size=hidden_size)
 
-    @torch.jit.script_method
     def forward(self, input_data: torch.Tensor, state: tuple[torch.Tensor, torch.Tensor]) -> tuple[
             torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
         inputs = input_data.unbind(0)
@@ -63,74 +146,70 @@ class LSTMLayer(torch.jit.ScriptModule):
         return torch.stack(outputs), state
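Note: BNLSTMCell above looks like a simplified take on recurrent batch normalization (Cooijmans et al., 2016); unlike that paper it shares one set of BatchNorm1d statistics across all time steps, and training mode needs a batch size greater than 1. A one-step smoke test (sketch):

    import torch

    cell = BNLSTMCell(input_size=10, hidden_size=20)
    x = torch.randn(4, 10)                        # (batch, input_size)
    h, c = cell(x, (torch.zeros(4, 20), torch.zeros(4, 20)))
    assert h.shape == c.shape == (4, 20)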
 
 
-class StackedLSTM(torch.jit.ScriptModule):
-    def __init__(self, input_size, hidden_size, num_layers):
+class ChainLSTMLayer(nn.Module):
+    def __init__(self, input_size, hidden_size, cell_class=LSTMCell):
         super().__init__()
-        self.layers = nn.ModuleList(
-            [LSTMLayer(input_size=input_size, hidden_size=hidden_size)] + [
-                LSTMLayer(input_size=hidden_size, hidden_size=hidden_size) for _ in range(num_layers - 1)])
+        self.cell = cell_class(input_size=input_size, hidden_size=hidden_size)
 
-    @torch.jit.script_method
-    def forward(self, input_data: torch.Tensor, states: list[tuple[torch.Tensor, torch.Tensor]]) -> tuple[
-            torch.Tensor, list[tuple[torch.Tensor, torch.Tensor]]]:
-        output_states = torch.jit.annotate(list[tuple[torch.Tensor, torch.Tensor]], [])
+    def forward(self, input_data: torch.Tensor, state: tuple[torch.Tensor, torch.Tensor]) -> tuple[
+            torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+        inputs = input_data.unbind(0)
+        outputs = torch.jit.annotate(list[torch.Tensor], [])
+        for i in range(len(inputs)):
+            state = self.cell(inputs[i], state)
+            # Unlike StackLSTMLayer, the chain variant emits the cell state (state[1])
+            # rather than the hidden state as its per-step output.
+            outputs += [state[1]]
+        return torch.stack(outputs), state
+
+
+class _CustomLSTMLayer(nn.Module):
+    def __init__(self, input_size, hidden_size, num_layers, layer_class=StackLSTMLayer, cell_class=LSTMCell):
+        super().__init__()
+        self.num_layer = num_layers
+        self.layers = nn.ModuleList(
+            [layer_class(input_size=input_size, hidden_size=hidden_size, cell_class=cell_class)] + [
+                layer_class(input_size=hidden_size, hidden_size=hidden_size, cell_class=cell_class)
+                for _ in range(num_layers - 1)])
+
+    def forward(self, input_data: torch.Tensor, states: tuple[torch.Tensor, torch.Tensor]) -> tuple[
+            torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+        output_states = torch.jit.annotate(list[torch.Tensor], [])
+        output_cell_states = torch.jit.annotate(list[torch.Tensor], [])
         output = input_data
         i = 0
         for rnn_layer in self.layers:
-            state = states[i]
-            output, out_state = rnn_layer(output, state)
-            output_states += [out_state]
+            output, out_state = rnn_layer(output, (states[0][i], states[1][i]))
+            output_states += [out_state[0]]
+            output_cell_states += [out_state[1]]
             i += 1
-        return output, output_states
+        return output, (torch.stack(output_states), torch.stack(output_cell_states))
 
 
-class StackedLSTMModel(torch.jit.ScriptModule):
-    NUM_LAYERS = 3
-
-    def __init__(self, input_size: int, hidden_size: int = -1, output_size: int = -1):
+class CustomLSTMModel(nn.Module):
+    def __init__(self, input_size: int, hidden_size: int = -1, output_size: int = -1,
+                 num_layer: int = 3, layer_class=StackLSTMLayer, cell_class=LSTMCell):
         super().__init__()
-        hidden_size = hidden_size if hidden_size > 0 else input_size * 2
-        output_size = output_size if output_size > 0 else input_size
+        self.hidden_size = hidden_size if hidden_size > 0 else input_size * 2
+        self.output_size = output_size if output_size > 0 else input_size
 
-        self.lstm = StackedLSTM(input_size=input_size, hidden_size=hidden_size, num_layers=self.NUM_LAYERS)
-        self.dense = nn.Linear(hidden_size, output_size)
+        self.lstm = _CustomLSTMLayer(input_size=input_size, hidden_size=self.hidden_size,
+                                     num_layers=num_layer, layer_class=layer_class, cell_class=cell_class)
+        self.dense = nn.Linear(self.hidden_size, self.output_size)
 
-    @torch.jit.script_method
-    def forward(self, input_data: torch.Tensor, init_state: list[tuple[torch.Tensor, torch.Tensor]]) -> tuple[
-            torch.Tensor, list[tuple[torch.Tensor, torch.Tensor]]]:
+        # Non-trainable parameters so the zero state follows the module across devices.
+        self.zero_state = torch.nn.Parameter(
+            torch.zeros(self.lstm.num_layer, self.hidden_size), requires_grad=False)
+        self.zero_cell_state = torch.nn.Parameter(
+            torch.zeros(self.lstm.num_layer, self.hidden_size), requires_grad=False)
+
+    def forward(self,
+                input_data: torch.Tensor,
+                init_state: tuple[torch.Tensor, torch.Tensor] = (torch.zeros(1), torch.zeros(1))
+                ) -> tuple[
+            torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+        # A 1-D placeholder means "no state given": expand the zero state to the batch size.
+        if len(init_state[0].shape) == 1:
+            init_state = (
+                self.zero_state.expand(
+                    input_data.size(1), self.lstm.num_layer, self.hidden_size).transpose(1, 0),
+                self.zero_cell_state.expand(
+                    input_data.size(1), self.lstm.num_layer, self.hidden_size).transpose(1, 0))
         output, state = self.lstm(input_data, init_state)
         return self.dense(output), state
-
-
-class LSTMCellModel(torch.jit.ScriptModule):
-    NUM_LAYERS = 3
-
-    def __init__(self, input_size: int, hidden_size: int = -1, output_size: int = -1):
-        super().__init__()
-        hidden_size = hidden_size if hidden_size > 0 else input_size * 2
-        output_size = output_size if output_size > 0 else input_size
-
-        self.hidden_size = hidden_size
-        self.layers = nn.ModuleList([
-            nn.LSTMCell(input_size=input_size, hidden_size=hidden_size),
-            nn.LSTMCell(input_size=hidden_size, hidden_size=hidden_size),
-            nn.LSTMCell(input_size=hidden_size, hidden_size=hidden_size)
-        ])
-        self.dense = nn.Linear(hidden_size, output_size)
-
-    @torch.jit.script_method
-    def forward(self, input_data: torch.Tensor, init_states: list[tuple[torch.Tensor, torch.Tensor]]) -> tuple[
-            torch.Tensor, list[tuple[torch.Tensor, torch.Tensor]]]:
-        output_states = torch.jit.annotate(list[tuple[torch.Tensor, torch.Tensor]], [])
-        cell_inputs = input_data.unbind(0)
-        cell_id = 0
-        for cell in self.layers:
-            cell_state = init_states[cell_id]
-            cell_outputs = torch.jit.annotate(list[torch.Tensor], [])
-            for i in range(len(cell_inputs)):
-                cell_state = cell(cell_inputs[i], cell_state)
-                cell_outputs += [cell_state[0]]
-            cell_inputs = cell_outputs
-            output_states += [cell_state]
-            cell_id += 1
-        return self.dense(torch.stack(cell_outputs)), output_states
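Note: CustomLSTMModel mirrors nn.LSTM's state layout, (h, c) each shaped (num_layer, batch, hidden_size), and the 1-D sentinel default expands the stored zero state to the batch. A usage sketch (sizes are arbitrary; the CLI equivalent would be e.g. python recorder.py --model stack --hidden 32 --lstm 3):

    import torch

    net = CustomLSTMModel(input_size=16, hidden_size=32, output_size=16, num_layer=3)
    x = torch.randn(8, 4, 16)                     # (seq, batch, features), time-major
    out, (h_n, c_n) = net(x)                      # omitted init_state -> zero state
    assert out.shape == (8, 4, 16)
    assert h_n.shape == c_n.shape == (3, 4, 32)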