diff --git a/.gitignore b/.gitignore
index 74ebfc8..ec92bdc 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,4 +1,5 @@
 *.pyc
+*.temp
 output
 save
\ No newline at end of file
diff --git a/modulo_tf.py b/modulo_tf.py
new file mode 100644
index 0000000..4dc3e2f
--- /dev/null
+++ b/modulo_tf.py
@@ -0,0 +1,180 @@
+from argparse import ArgumentParser
+from pathlib import Path
+import math
+import shutil
+import sys
+import time
+
+import numpy as np
+import tensorflow as tf
+
+from src.tf_network import TFLSTMModel
+from src.torch_utils.utils.batch_generator import BatchGenerator
+
+
+def generate_data(batch_size: int, data_length: int) -> tuple[np.ndarray, np.ndarray]:
+    """Generate a time-major batch of modulo sequences and the label for the step after the sequence."""
+    modulos = np.random.uniform(3, data_length // 2 + 1, batch_size).astype(np.int32)
+    data = np.zeros((data_length, batch_size, 1), dtype=np.float32)
+    starts = [int(np.random.uniform(0, mod)) for mod in modulos]
+    for i in range(batch_size):
+        for j in range(starts[i], data_length, modulos[i]):
+            data[j, i, 0] = 1.0
+    # Label is 1 when the step right after the sequence would be a spike.
+    label = [1 if data_length % modulos[i] == starts[i] else 0 for i in range(batch_size)]
+    return data, np.asarray(label, dtype=np.int64)
+
+
+class DataGenerator:
+    MAX_LENGTH = 1
+    INITIALIZED = False
+
+    @staticmethod
+    def pipeline(sequence_length, _dummy_label):
+        # Re-seed once per worker process so workers do not produce identical samples.
+        if not DataGenerator.INITIALIZED:
+            np.random.seed(time.time_ns() % (2**32))
+            DataGenerator.INITIALIZED = True
+        data = np.zeros((DataGenerator.MAX_LENGTH, 1), dtype=np.float32)
+        modulo = int(np.random.uniform(3, sequence_length // 2 + 1))
+        start = int(np.random.uniform(0, modulo))
+        for i in range(start, sequence_length, modulo):
+            data[i, 0] = 1.0
+        return data, np.asarray(1 if sequence_length % modulo == start else 0, dtype=np.int64)
+
+
+def main():
+    parser = ArgumentParser()
+    parser.add_argument('--output', type=Path, default=Path('output', 'modulo'), help='Output dir')
+    parser.add_argument('--model', default='tf-lstm', help='Model to train')
+    parser.add_argument('--batch', type=int, default=32, help='Batch size')
+    parser.add_argument('--sequence', type=int, default=12, help='Max sequence length')
+    parser.add_argument('--hidden', type=int, default=16, help='LSTM cells hidden size')
+    parser.add_argument('--step', type=int, default=2000, help='Number of steps to train')
+    arguments = parser.parse_args()
+
+    output_dir: Path = arguments.output
+    model: str = arguments.model
+    batch_size: int = arguments.batch
+    sequence_size: int = arguments.sequence
+    hidden_size: int = arguments.hidden
+    max_step: int = arguments.step
+
+    output_dir = output_dir.parent / f'modulo_{model}_b{batch_size}_s{sequence_size}_h{hidden_size}'
+    if not output_dir.exists():
+        output_dir.mkdir(parents=True)
+    if (output_dir / 'train').exists():
+        shutil.rmtree(output_dir / 'train')
+    writer_train = tf.summary.create_file_writer(str(output_dir / 'train'), flush_millis=20000)
+
+    if model == 'tf-lstm':
+        network = TFLSTMModel(1, hidden_size, 2)
+    else:
+        print('Error : Unknown model')
+        sys.exit(1)
+
+    summary_period = max(max_step // 100, 1)
+    # Decay once per summary period, matching the ExponentialLR(gamma=0.995) schedule of the torch version.
+    learning_rate = tf.keras.optimizers.schedules.ExponentialDecay(
+        1e-3, summary_period, 0.995, staircase=True)
+    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
+    criterion = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
+
+    network(np.zeros((2, sequence_size, 1), dtype=np.float32))  # Build the weights once before saving.
+    network.save_weights(str(output_dir / 'model_ini'))
+
+    def train_step(data, label):
+        """Run one optimization step, returning the loss and the output logits."""
+        with tf.GradientTape() as tape:
+            outputs = network(data, training=True)
+            loss = criterion(label, outputs)
+        gradients = tape.gradient(loss, network.trainable_variables)
+        optimizer.apply_gradients(zip(gradients, network.trainable_variables))
+        return loss, outputs
+
+    sequence_data = np.random.uniform(4, sequence_size + 1, max_step).astype(np.int32)
+    sequence_data[0] = sequence_size
+    sequence_data_reshaped = np.reshape(np.broadcast_to(
+        sequence_data,
+        (batch_size, max_step)).transpose((1, 0)), (batch_size * max_step))
+    dummy_label = np.zeros((batch_size * max_step), dtype=np.uint8)
+    DataGenerator.MAX_LENGTH = sequence_size
+    with BatchGenerator(sequence_data_reshaped, dummy_label, batch_size=batch_size,
+                        pipeline=DataGenerator.pipeline, num_workers=8, shuffle=False) as batch_generator:
+        data_np = batch_generator.batch_data
+        label_np = batch_generator.batch_label
+
+        running_loss = 0.0
+        running_accuracy = 0.0
+        running_count = 0
+        np.set_printoptions(precision=2)
+        try:
+            start_time = time.time()
+            while batch_generator.epoch == 0:
+                # Keras is batch-major: slice the padded batch to this step's sequence length.
+                data = data_np[:, :sequence_data[batch_generator.step]]
+
+                loss, outputs = train_step(data, label_np)
+
+                running_loss += float(loss)
+                outputs_np = outputs.numpy()
+                running_accuracy += ((outputs_np[:, 1] > outputs_np[:, 0]).astype(np.int32) == label_np).astype(
+                    np.float32).mean()
+                running_count += 1
+                if (batch_generator.step + 1) % summary_period == 0:
+                    with writer_train.as_default(step=batch_generator.step):
+                        tf.summary.scalar('metric/loss', running_loss / running_count)
+                        tf.summary.scalar('metric/error', 1 - (running_accuracy / running_count))
+                        tf.summary.scalar('optimizer/lr', learning_rate(optimizer.iterations))
+
+                    speed = summary_period / (time.time() - start_time)
+                    print(f'Step {batch_generator.step}, loss: {running_loss / running_count:.03e}'
+                          f', acc: {running_accuracy / running_count:.03e}, speed: {speed:0.3f}step/s')
+                    start_time = time.time()
+                    running_loss = 0.0
+                    running_accuracy = 0.0
+                    running_count = 0
+                data_np, label_np = batch_generator.next_batch()
+        except KeyboardInterrupt:
+            print('\r ', end='\r')
+        writer_train.close()
+
+    running_accuracy = 0.0
+    running_count = 0
+    for _ in range(math.ceil(1000 / batch_size)):
+        data_np, label_np = generate_data(batch_size, int(np.random.uniform(4, sequence_size + 1)))
+        # generate_data is time-major: transpose to (batch, sequence, 1) for the model.
+        outputs_np = network(data_np.transpose((1, 0, 2))).numpy()
+        running_accuracy += ((outputs_np[:, 1] > outputs_np[:, 0]).astype(np.int32) == label_np).astype(
+            np.float32).mean()
+        running_count += 1
+    print(f'Validation accuracy: {running_accuracy / running_count:.03f}')
+
+    test_data = [
+        [[1.], [0.], [0.], [1.], [0.], [0.]],
+        [[1.], [0.], [0.], [0.], [1.], [0.], [0.], [0.]],
+        [[1.], [0.], [0.], [0.], [0.], [1.], [0.], [0.], [0.]],
+        [[1.], [0.], [0.], [0.], [0.], [1.], [0.], [0.], [0.], [0.]],
+        [[0.], [1.], [0.], [0.], [0.], [0.], [0.], [1.], [0.], [0.], [0.], [0.]],
+        [[1.], [0.], [0.], [0.], [0.], [0.], [1.], [0.], [0.], [0.], [0.], [0.]],
+        [[0.], [1.], [0.], [0.], [0.], [0.], [0.], [1.], [0.], [0.], [0.], [0.], [0.]],
+        [[0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.]],
+        [[1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.]],
+        [[1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.],
+         [0.], [1.], [0.], [0.], [1.], [0.], [0.]],
+        [[0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.],
+         [0.], [1.], [0.], [0.], [1.], [0.]],
+    ]
+    test_label = np.asarray([1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0], dtype=np.int32)
+    running_accuracy = 0.0
+    running_count = 0
+    for data, label in zip(test_data, test_label):
+        outputs_np = network(np.expand_dims(np.asarray(data, dtype=np.float32), 0)).numpy()
+        output_correct = int(outputs_np[0, 1] > outputs_np[0, 0]) == label
+        running_accuracy += 1.0 if output_correct else 0.0
+        running_count += 1
+        print(f'{len(data)} {np.asarray(data)[:, 0]}, label: {label}'
+              f', output: {int(outputs_np[0, 1] > outputs_np[0, 0])}')
+    print(f'Test accuracy: {running_accuracy / running_count:.03f}')
+
+
+if __name__ == '__main__':
+    main()
diff --git a/recorder.py b/recorder.py
index 07fef5e..545842e 100644
--- a/recorder.py
+++ b/recorder.py
@@ -9,6 +9,7 @@ import torch
 from torch import nn
 from torch.utils.tensorboard import SummaryWriter
+from src.metrics import Metrics
 from src.torch_networks import (
     TorchLSTMModel, TorchLSTMCellModel, TorchGRUModel, CustomRNNModel,
     ChainRNNLayer, BNLSTMCell, CustomLSTMCell)
@@ -34,6 +35,35 @@ def score_sequences(sequences: np.ndarray) -> float:
     return score, max_score
+
+def get_network(model: str, input_dim: int, hidden_size: int, num_layer: int, device: str) -> nn.Module:
+    if model == 'stack':
+        return CustomRNNModel(input_dim + 1, hidden_size, num_layer=num_layer).to(device)
+    if model == 'stack-torchcell':
+        return CustomRNNModel(input_dim + 1, hidden_size, num_layer=num_layer, cell_class=nn.LSTMCell).to(device)
+    if model == 'stack-bn':
+        return CustomRNNModel(input_dim + 1, hidden_size, num_layer=num_layer, cell_class=BNLSTMCell).to(device)
+    if model == 'stack-custom':
+        return CustomRNNModel(input_dim + 1, hidden_size, num_layer=num_layer, cell_class=CustomLSTMCell).to(device)
+    if model == 'chain-lstm':
+        return CustomRNNModel(input_dim + 1, hidden_size, num_layer=num_layer,
+                              layer_class=ChainRNNLayer).to(device)
+    if model == 'chain-bn':
+        return CustomRNNModel(input_dim + 1, hidden_size, num_layer=num_layer,
+                              layer_class=ChainRNNLayer, cell_class=BNLSTMCell).to(device)
+    if model == 'chain-custom':
+        return CustomRNNModel(input_dim + 1, hidden_size, num_layer=num_layer,
+                              layer_class=ChainRNNLayer, cell_class=CustomLSTMCell).to(device)
+    if model == 'torch-cell':
+        return TorchLSTMCellModel(input_dim + 1, hidden_size, num_layer=num_layer).to(device)
+    if model == 'torch-lstm':
+        return TorchLSTMModel(input_dim + 1, hidden_size, num_layer=num_layer).to(device)
+    if model == 'torch-gru':
+        return TorchGRUModel(input_dim + 1, hidden_size,
num_layer=num_layer).to(device) + + print('Error : Unknown model') + sys.exit(1) + + def main(): parser = ArgumentParser() parser.add_argument('--output', type=Path, default=Path('output', 'recorder'), help='Output dir') @@ -57,32 +87,8 @@ def main(): device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') torch.backends.cudnn.benchmark = True - if model == 'stack': - network = CustomRNNModel(input_dim + 1, hidden_size, num_layer=num_layer).to(device) - elif model == 'stack-torchcell': - network = CustomRNNModel(input_dim + 1, hidden_size, num_layer=num_layer, cell_class=nn.LSTMCell).to(device) - elif model == 'stack-bn': - network = CustomRNNModel(input_dim + 1, hidden_size, num_layer=num_layer, cell_class=BNLSTMCell).to(device) - elif model == 'stack-custom': - network = CustomRNNModel(input_dim + 1, hidden_size, num_layer=num_layer, cell_class=CustomLSTMCell).to(device) - elif model == 'chain-lstm': - network = CustomRNNModel(input_dim + 1, hidden_size, num_layer=num_layer, - layer_class=ChainRNNLayer).to(device) - elif model == 'chain-bn': - network = CustomRNNModel(input_dim + 1, hidden_size, num_layer=num_layer, - layer_class=ChainRNNLayer, cell_class=BNLSTMCell).to(device) - elif model == 'chain-custom': - network = CustomRNNModel(input_dim + 1, hidden_size, num_layer=num_layer, - layer_class=ChainRNNLayer, cell_class=CustomLSTMCell).to(device) - elif model == 'torch-cell': - network = TorchLSTMCellModel(input_dim + 1, hidden_size, num_layer=num_layer).to(device) - elif model == 'torch-lstm': - network = TorchLSTMModel(input_dim + 1, hidden_size, num_layer=num_layer).to(device) - elif model == 'torch-gru': - network = TorchGRUModel(input_dim + 1, hidden_size, num_layer=num_layer).to(device) - else: - print('Error : Unkown model') - sys.exit(1) + + network = get_network(model, input_dim, hidden_size, num_layer, device) output_dir = output_dir.parent / f'recorder_{model}_b{batch_size}_s{sequence_size}_h{hidden_size}_l{num_layer}' if not
output_dir.exists(): @@ -109,52 +115,6 @@ def main(): scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.996) criterion = nn.CrossEntropyLoss() - class Metrics: - class Bench: - def __init__(self): - self.data_gen = 0.0 - self.data_process = 0.0 - self.predict = 0.0 - self.loss = 0.0 - self.backprop = 0.0 - self.optimizer = 0.0 - self.metrics = 0.0 - - def reset(self): - self.data_gen = 0.0 - self.data_process = 0.0 - self.predict = 0.0 - self.loss = 0.0 - self.backprop = 0.0 - self.optimizer = 0.0 - self.metrics = 0.0 - - def get_proportions(self, train_time: float) -> str: - return ( - f'data_gen: {self.data_gen / train_time:.02f}' - f', data_process: {self.data_process / train_time:.02f}' - f', predict: {self.predict / train_time:.02f}' - f', loss: {self.loss / train_time:.02f}' - f', backprop: {self.backprop / train_time:.02f}' - f', optimizer: {self.optimizer / train_time:.02f}' - f', metrics: {self.metrics / train_time:.02f}') - - def __init__(self): - self.loss = 0.0 - self.accuracy = 0.0 - self.score = 0.0 - self.max_score = 0.0 - self.count = 0 - self.bench = self.Bench() - - def reset(self): - self.loss = 0.0 - self.accuracy = 0.0 - self.score = 0.0 - self.max_score = 0.0 - self.count = 0 - self.bench.reset() - metrics = Metrics() summary_period = max_step // 100 min_sequence_size = 4 diff --git a/sequence_encoder.py b/sequence_encoder.py new file mode 100644 index 0000000..dcda3a1 --- /dev/null +++ b/sequence_encoder.py @@ -0,0 +1,171 @@ +from argparse import ArgumentParser +from pathlib import Path +import shutil +import sys +import time + +import numpy as np +import torch +from torch import nn +from torch.utils.tensorboard import SummaryWriter + +from src.metrics import Metrics +from src.torch_utils.layers import Conv1d, Deconv1d +from src.torch_utils.train import parameter_summary + + +def generate_data(batch_size: int, sequence_length: int, dim: int) -> np.ndarray: + return np.random.uniform(0, dim, (batch_size, 
sequence_length)).astype(np.int64) + + +class SequenceAutoEncoder(nn.Module): + def __init__(self, input_dim: int): + super().__init__() + self.auto_encoder = nn.Sequential( + Conv1d(input_dim, 16, kernel_size=2, stride=2), + Conv1d(16, 16, kernel_size=2, stride=2), + Conv1d(16, 16, kernel_size=2, stride=2), + Conv1d(16, 16, kernel_size=2, stride=2), + Deconv1d(16, 16, kernel_size=2, stride=2), + Deconv1d(16, 16, kernel_size=2, stride=2), + Deconv1d(16, 16, kernel_size=2, stride=2), + Deconv1d(16, input_dim, kernel_size=2, stride=2, use_batch_norm=False, activation=None)) + + def forward(self, input_data: torch.Tensor) -> torch.Tensor: + return self.auto_encoder(input_data) + + +def main(): + parser = ArgumentParser() + parser.add_argument('--output', type=Path, default=Path('output', 'recorder'), help='Output dir') + parser.add_argument('--batch', type=int, default=128, help='Batch size') + parser.add_argument('--sequence', type=int, default=128, help='Max sequence length') + parser.add_argument('--dimension', type=int, default=16, help='Input dimension') + # parser.add_argument('--latent', type=int, default=16, help='Latent space size') + parser.add_argument('--step', type=int, default=80_000, help='Number of steps to train') + arguments = parser.parse_args() + + output_dir: Path = arguments.output + batch_size: int = arguments.batch + sequence_size: int = arguments.sequence + input_dim: int = arguments.dimension + # latent_size: int = arguments.latent + max_step: int = arguments.step + + device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') + torch.backends.cudnn.benchmark = True + + network = SequenceAutoEncoder(input_dim).to(device) + + output_dir = output_dir / f'sequence_encoder_b{batch_size}_s{sequence_size}' + if not output_dir.exists(): + output_dir.mkdir(parents=True) + if (output_dir / 'train').exists(): + shutil.rmtree(output_dir / 'train') + writer_train = SummaryWriter(log_dir=str(output_dir / 'train'), flush_secs=20) + writer_val = 
SummaryWriter(log_dir=str(output_dir / 'val'), flush_secs=20) + data_sample = torch.zeros((2, input_dim, sequence_size)).to(device) + + torch.save(network.state_dict(), output_dir / 'model_ini.pt') + writer_train.add_graph(network, (data_sample,)) + + # Save parameters info + with open(output_dir / 'parameters.csv', 'w') as param_file: + param_summary = parameter_summary(network) + names = [len(name) for name, _, _ in param_summary] + shapes = [len(str(shape)) for _, shape, _ in param_summary] + param_file.write( + '\n'.join( + [f'{name: <{max(names)}} {str(shape): <{max(shapes)}} {size}' + for name, shape, size in param_summary])) + + optimizer = torch.optim.Adam(network.parameters(), lr=1e-3, weight_decay=1e-6) + scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.994) + criterion = nn.CrossEntropyLoss() + + def val_summary(step): + nonlocal network + + network.eval() + val_metrics = Metrics() + for _ in range(10): + label_np = generate_data(batch_size, sequence_size, input_dim) + label = torch.from_numpy(label_np).to(device) + data = nn.functional.one_hot(label, input_dim).float().transpose(2, 1) + + outputs = network(data) + + val_metrics.accuracy += ((torch.argmax(outputs, 1) == label).sum(1) == sequence_size).sum().item() + val_metrics.score += (torch.argmax(outputs, 1) == label).float().mean(1).sum().item() + val_metrics.count += batch_size + writer_val.add_scalar('metric/error', 1 - (val_metrics.accuracy / val_metrics.count), global_step=step) + writer_val.add_scalar('metric/score_inv', 1 - (val_metrics.score / val_metrics.count), global_step=step) + network.train() + + metrics = Metrics() + summary_period = max_step // 100 + np.set_printoptions(precision=2) + try: + start_time = time.time() + for step in range(1, max_step + 1): + label_np = generate_data(batch_size, sequence_size, input_dim) + + label = torch.from_numpy(label_np).to(device) + data = nn.functional.one_hot(label, input_dim).float().transpose(2, 1) + + optimizer.zero_grad() 
+ + outputs = network(data) + + loss = criterion(outputs, label) + loss.backward() + + metrics.loss += loss.item() + metrics.accuracy += ((torch.argmax(outputs, 1) == label).sum(1) == sequence_size).sum().item() + metrics.score += (torch.argmax(outputs, 1) == label).float().mean(1).sum().item() + metrics.count += batch_size + + optimizer.step() + + if step % summary_period == 0: + writer_train.add_scalar('metric/loss', metrics.loss / summary_period, global_step=step) + writer_train.add_scalar('metric/error', 1 - (metrics.accuracy / metrics.count), global_step=step) + writer_train.add_scalar('metric/score_inv', 1 - (metrics.score / metrics.count), global_step=step) + writer_train.add_scalar('optimizer/lr', scheduler.get_last_lr()[0], global_step=step) + scheduler.step() + + train_time = time.time() - start_time + print(f'Step {step}, loss: {metrics.loss / summary_period:.03e}' + f', acc: {metrics.accuracy / metrics.count:.03e}' + f', score_inv: {1 - (metrics.score / metrics.count):.03e}' + f', speed: {summary_period / train_time:0.3f}step/s' + f' => {metrics.count / train_time:.02f}input/s') + val_summary(step) + start_time = time.time() + metrics.reset() + except KeyboardInterrupt: + print('\r ', end='\r') + writer_train.close() + writer_val.close() + + network.eval() + test_label = [generate_data(1, sequence_size, input_dim) for _ in range(10)] + metrics.reset() + for label_np in test_label: + label = torch.from_numpy(label_np).to(device) + data = nn.functional.one_hot(label, input_dim).float().transpose(2, 1) + + outputs = network(data) + # sequence_correct = torch.argmax(outputs, 2).transpose(1, 0) == label + + # metrics.accuracy += (torch.sum(sequence_correct.long(), 1) == label.size(1)).float().mean().item() + metrics.accuracy += ((torch.argmax(outputs, 1) == label).sum(1) == sequence_size).float().sum().item() + metrics.score += (torch.argmax(outputs, 1) == label).float().mean(1).sum().item() + metrics.count += 1 + print(f'label: {label_np}\noutput: 
{torch.argmax(outputs, 1).detach().cpu().numpy()}') + print(f'\nTest accuracy: {metrics.accuracy / metrics.count:.05f}') + print(f'Test score: {metrics.score / metrics.count:.05f}') + + +if __name__ == '__main__': + main() diff --git a/src/metrics.py b/src/metrics.py new file mode 100644 index 0000000..0c1fe39 --- /dev/null +++ b/src/metrics.py @@ -0,0 +1,45 @@ +class Metrics: + class Bench: + def __init__(self): + self.data_gen = 0.0 + self.data_process = 0.0 + self.predict = 0.0 + self.loss = 0.0 + self.backprop = 0.0 + self.optimizer = 0.0 + self.metrics = 0.0 + + def reset(self): + self.data_gen = 0.0 + self.data_process = 0.0 + self.predict = 0.0 + self.loss = 0.0 + self.backprop = 0.0 + self.optimizer = 0.0 + self.metrics = 0.0 + + def get_proportions(self, train_time: float) -> str: + return ( + f'data_gen: {self.data_gen / train_time:.02f}' + f', data_process: {self.data_process / train_time:.02f}' + f', predict: {self.predict / train_time:.02f}' + f', loss: {self.loss / train_time:.02f}' + f', backprop: {self.backprop / train_time:.02f}' + f', optimizer: {self.optimizer / train_time:.02f}' + f', metrics: {self.metrics / train_time:.02f}') + + def __init__(self): + self.loss = 0.0 + self.accuracy = 0.0 + self.score = 0.0 + self.max_score = 0.0 + self.count = 0 + self.bench = self.Bench() + + def reset(self): + self.loss = 0.0 + self.accuracy = 0.0 + self.score = 0.0 + self.max_score = 0.0 + self.count = 0 + self.bench.reset() diff --git a/src/tf_network.py b/src/tf_network.py new file mode 100644 index 0000000..db049d4 --- /dev/null +++ b/src/tf_network.py @@ -0,0 +1,28 @@ +import tensorflow as tf +from tensorflow import keras +from tensorflow.keras import layers +from tensorflow.keras import Model + + +# class TorchLSTMModel(nn.Module): +# def __init__(self, input_size: int, hidden_size: int = -1, output_size: int = -1, num_layer: int = 3): +# super().__init__() +# hidden_size = hidden_size if hidden_size > 0 else input_size * 2 +# output_size = 
output_size if output_size > 0 else input_size + +#         self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layer) +#         self.dense = nn.Linear(hidden_size, output_size) + +#     def forward(self, input_data: torch.Tensor, init_state=None) -> tuple[ +#             torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: +#         output, state = self.lstm(input_data, init_state) +#         return self.dense(output), state + + +def TFLSTMModel(input_size: int, hidden_size: int = -1, output_size: int = -1, num_layer: int = 3): + hidden_size = hidden_size if hidden_size > 0 else input_size * 2 + output_size = output_size if output_size > 0 else input_size + + return tf.keras.models.Sequential( + [layers.LSTM(hidden_size, return_sequences=i < num_layer - 1) for i in range(num_layer)] + + [layers.Dense(output_size)])