"""Train an LSTM variant to predict the next value of a periodic binary sequence.

Each sample is a sequence of zeros with ones at positions ``start, start + modulo, ...``.
Its label is 1 when the element right after the sequence would be a one, i.e. when
``len(sequence) % modulo == start``.
"""
from argparse import ArgumentParser, Namespace
from pathlib import Path
import math
import os
import shutil
import sys
import time

import numpy as np
import torch
from torch import nn
from torch.utils.tensorboard import SummaryWriter

from src.torch_networks import TorchLSTMModel, TorchLSTMCellModel, CustomLSTMModel, ChainLSTMLayer
from src.torch_utils.utils.batch_generator import BatchGenerator
from src.torch_utils.train import parameter_summary


# Factories for the models selectable with ``--model``; each takes the LSTM hidden size.
_MODEL_FACTORIES = {
    'stack': lambda hidden_size: CustomLSTMModel(1, hidden_size, 2),
    'stack-torchcell': lambda hidden_size: CustomLSTMModel(1, hidden_size, 2, cell_class=nn.LSTMCell),
    'chain': lambda hidden_size: CustomLSTMModel(1, hidden_size, 2, layer_class=ChainLSTMLayer),
    'torch-cell': lambda hidden_size: TorchLSTMCellModel(1, hidden_size, 2),
    'torch-lstm': lambda hidden_size: TorchLSTMModel(1, hidden_size, 2),
}

# Hand-written sequences (written as 0/1 strings) and their expected labels. Several are
# longer than the default ``--sequence`` of 12.
_KNOWN_SEQUENCES = (
    ('100100', 1),
    ('10001000', 1),
    ('100001000', 0),
    ('1000010000', 1),
    ('010000010000', 0),
    ('100000100000', 1),
    ('0100000100000', 1),
    ('00100100100100', 1),
    ('10010010010010', 0),
    ('100100100100100100100', 1),
    ('010010010010010010010', 0),
)


def generate_data(batch_size: int, data_length: int) -> tuple[np.ndarray, np.ndarray]:
    """Generate a batch of periodic binary sequences in time-major layout.

    Each sample gets a period ``modulo`` drawn uniformly from ``[3, data_length // 2 + 1)``
    (truncated to int) and an offset ``start`` drawn from ``[0, modulo)``; ones are placed at
    ``start, start + modulo, ...``.

    Args:
        batch_size: Number of sequences in the batch.
        data_length: Number of time steps in every sequence.

    Returns:
        ``data`` of shape ``(data_length, batch_size, 1)`` (float32) and ``label`` of shape
        ``(batch_size,)`` (int64), which is 1 when ``data_length % modulo == start``, i.e. when
        the next (unseen) step would be a one.
    """
    modulos = np.random.uniform(3, data_length // 2 + 1, batch_size).astype(np.int32)
    # One draw per sample, in order, so the random stream matches a per-sample loop.
    starts = np.asarray([int(np.random.uniform(0, modulo)) for modulo in modulos], dtype=np.int64)
    data = np.zeros((data_length, batch_size, 1), dtype=np.float32)
    for sample, (modulo, start) in enumerate(zip(modulos, starts)):
        data[start::modulo, sample, 0] = 1.0
    label = (data_length % modulos == starts).astype(np.int64)
    return data, label


class DataGenerator:
    """Per-sample pipeline handed to ``BatchGenerator`` that builds one periodic sequence.

    Configuration is kept in class attributes because ``pipeline`` is passed around as a plain
    function rather than a bound method.
    """

    # Padded length of every sample; the training code sets it to the max sequence length.
    MAX_LENGTH = 1
    # Whether the global NumPy RNG has already been re-seeded in the current process.
    INITIALIZED = False

    @staticmethod
    def pipeline(sequence_length, _dummy_label):
        """Build one zero-padded periodic sequence of ``sequence_length`` steps.

        Args:
            sequence_length: Number of meaningful steps; must not exceed ``MAX_LENGTH``.
            _dummy_label: Ignored; the label is derived from the generated sequence.

        Returns:
            ``(data, label)`` where ``data`` has shape ``(MAX_LENGTH, 1)`` (float32) and ``label``
            is a 0-d int64 array, 1 when ``sequence_length % modulo == start``.
        """
        if not DataGenerator.INITIALIZED:
            # Re-seed once per process. Mixing in the PID keeps the streams distinct even when
            # several workers (presumably copies of the parent's RNG state) start within the
            # same clock tick.
            seed_state = np.random.SeedSequence([time.time_ns(), os.getpid()]).generate_state(1)
            np.random.seed(int(seed_state[0]))
            DataGenerator.INITIALIZED = True
        data = np.zeros((DataGenerator.MAX_LENGTH, 1), dtype=np.float32)
        modulo = int(np.random.uniform(3, sequence_length // 2 + 1))
        start = int(np.random.uniform(0, modulo))
        for i in range(start, sequence_length, modulo):
            data[i, 0] = 1.0
        return data, np.asarray(1 if sequence_length % modulo == start else 0, dtype=np.int64)


def _parse_arguments() -> Namespace:
    """Parse the command-line options of the training script."""
    parser = ArgumentParser()
    parser.add_argument('--output', type=Path, default=Path('output', 'modulo'), help='Output dir')
    parser.add_argument('--model', default='torch-lstm', help='Model to train')
    parser.add_argument('--batch', type=int, default=32, help='Batch size')
    parser.add_argument('--sequence', type=int, default=12, help='Max sequence length')
    parser.add_argument('--hidden', type=int, default=16, help='LSTM cells hidden size')
    parser.add_argument('--step', type=int, default=2000, help='Number of steps to train')
    return parser.parse_args()


def _batch_accuracy(logits: np.ndarray, labels: np.ndarray):
    """Return the fraction of rows where ``logits[:, 1] > logits[:, 0]`` matches ``labels``."""
    predictions = (logits[:, 1] > logits[:, 0]).astype(np.int32)
    return (predictions == labels).astype(np.float32).mean()


def _write_parameter_summary(network: nn.Module, path: Path) -> None:
    """Write one column-aligned ``name shape size`` line per entry of ``parameter_summary``."""
    param_summary = list(parameter_summary(network))
    name_width = max(len(name) for name, _, _ in param_summary)
    shape_width = max(len(str(shape)) for _, shape, _ in param_summary)
    with open(path, 'w', encoding='utf-8') as param_file:
        param_file.write('\n'.join(
            f'{name: <{name_width}} {str(shape): <{shape_width}} {size}'
            for name, shape, size in param_summary))


def _train(network: nn.Module, device: torch.device, writer: SummaryWriter,
           batch_size: int, sequence_size: int, max_step: int) -> None:
    """Train ``network`` for one pass over ``max_step`` generated batches.

    Every batch uses a single random sequence length in ``[4, sequence_size]`` (the first batch
    uses ``sequence_size``). Loss, error rate and learning rate are logged to ``writer`` and
    printed every ``max(max_step // 100, 1)`` steps. Ctrl+C stops training early without error.
    """
    optimizer = torch.optim.Adam(network.parameters(), lr=1e-3, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.995)
    criterion = nn.CrossEntropyLoss()

    sequence_data = np.random.uniform(4, sequence_size + 1, max_step).astype(np.int32)
    sequence_data[0] = sequence_size
    # Repeat each step's length once per sample so that, with shuffle=False, every sample of a
    # batch shares the same length.
    per_sample_lengths = np.repeat(sequence_data, batch_size)
    dummy_label = np.zeros(batch_size * max_step, dtype=np.uint8)
    DataGenerator.MAX_LENGTH = sequence_size
    summary_period = max(max_step // 100, 1)

    with BatchGenerator(per_sample_lengths, dummy_label, batch_size=batch_size,
                        pipeline=DataGenerator.pipeline, num_workers=8,
                        shuffle=False) as batch_generator:
        data_np = batch_generator.batch_data
        label_np = batch_generator.batch_label
        running_loss = 0.0
        running_accuracy = 0.0
        running_count = 0
        try:
            start_time = time.perf_counter()
            while batch_generator.epoch == 0:
                step = batch_generator.step
                # Batches are presumably (batch, MAX_LENGTH, 1); the network gets time-major
                # input cropped to this step's sequence length.
                data = torch.from_numpy(
                    data_np.transpose((1, 0, 2))[:sequence_data[step]]).to(device)
                label = torch.from_numpy(label_np).to(device)

                optimizer.zero_grad(set_to_none=True)
                outputs, _states = network(data)
                # Classify from the output of the last time step only.
                loss = criterion(outputs[-1], label)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                running_accuracy += _batch_accuracy(outputs[-1].detach().cpu().numpy(), label_np)
                running_count += 1

                if (step + 1) % summary_period == 0:
                    mean_loss = running_loss / running_count
                    mean_accuracy = running_accuracy / running_count
                    writer.add_scalar('metric/loss', mean_loss, global_step=step)
                    writer.add_scalar('metric/error', 1 - mean_accuracy, global_step=step)
                    writer.add_scalar('optimizer/lr', scheduler.get_last_lr()[0], global_step=step)
                    # Decay once per summary period, after optimizer.step() as PyTorch expects.
                    scheduler.step()
                    speed = summary_period / (time.perf_counter() - start_time)
                    print(f'Step {step}, loss: {mean_loss:.03e}'
                          f', acc: {mean_accuracy:.03e}, speed: {speed:0.3f}step/s')
                    start_time = time.perf_counter()
                    running_loss = 0.0
                    running_accuracy = 0.0
                    running_count = 0

                data_np, label_np = batch_generator.next_batch()
        except KeyboardInterrupt:
            # Stop training early; presumably overwrites the terminal's "^C" echo.
            print('\r ', end='\r')


def _validate(network: nn.Module, device: torch.device, batch_size: int, sequence_size: int,
              sample_count: int = 1000) -> float:
    """Return the mean per-batch accuracy over about ``sample_count`` freshly generated samples."""
    batch_count = math.ceil(sample_count / batch_size)
    running_accuracy = 0.0
    for _ in range(batch_count):
        data_np, label_np = generate_data(batch_size, int(np.random.uniform(4, sequence_size + 1)))
        outputs, _states = network(torch.from_numpy(data_np).to(device))
        running_accuracy += _batch_accuracy(outputs[-1].detach().cpu().numpy(), label_np)
    return running_accuracy / batch_count


def _evaluate_known_sequences(network: nn.Module, device: torch.device) -> None:
    """Print the prediction for every hand-written sequence, then the overall accuracy."""
    correct_count = 0
    for bits, label in _KNOWN_SEQUENCES:
        data = np.asarray([[float(bit)] for bit in bits], dtype=np.float32)
        # Add a batch axis of size 1: (time, 1) -> (time, 1, 1).
        outputs, _states = network(torch.from_numpy(np.expand_dims(data, 1)).to(device))
        outputs_np = outputs[-1].detach().cpu().numpy()
        prediction = int(outputs_np[0, 1] > outputs_np[0, 0])
        correct_count += int(prediction == label)
        print(f'{len(data)} {data[:, 0]}, label: {label}'
              f', output: {prediction}')
    print(f'Test accuracy: {correct_count / len(_KNOWN_SEQUENCES):.03f}')


def main():
    """Train the selected model, then report validation and hand-written test accuracy."""
    arguments = _parse_arguments()
    output_dir: Path = arguments.output
    model: str = arguments.model
    batch_size: int = arguments.batch
    sequence_size: int = arguments.sequence
    hidden_size: int = arguments.hidden
    max_step: int = arguments.step

    # Validate before touching the filesystem so a typo does not create or wipe directories.
    if model not in _MODEL_FACTORIES:
        sys.exit(f'Error : Unknown model {model!r} (choose from: {", ".join(_MODEL_FACTORIES)})')

    # The last component of --output is replaced by a name encoding the run configuration.
    output_dir = output_dir.parent / f'modulo_{model}_b{batch_size}_s{sequence_size}_h{hidden_size}'
    output_dir.mkdir(parents=True, exist_ok=True)
    train_log_dir = output_dir / 'train'
    if train_log_dir.exists():
        shutil.rmtree(train_log_dir)

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    torch.backends.cudnn.benchmark = True
    network = _MODEL_FACTORIES[model](hidden_size).to(device)
    torch.save(network.state_dict(), output_dir / 'model_ini.pt')

    writer_train = SummaryWriter(log_dir=str(train_log_dir), flush_secs=20)
    try:
        input_sample = torch.from_numpy(generate_data(2, 4)[0]).to(device)
        writer_train.add_graph(network, (input_sample,))
        _write_parameter_summary(network, output_dir / 'parameters.csv')
        _train(network, device, writer_train, batch_size, sequence_size, max_step)
    finally:
        writer_train.close()

    np.set_printoptions(precision=2)
    network.eval()
    # Evaluation only: no autograd graph is needed.
    with torch.no_grad():
        validation_accuracy = _validate(network, device, batch_size, sequence_size)
        print(f'Validation accuracy: {validation_accuracy:.03f}')
        _evaluate_known_sequences(network, device)


if __name__ == '__main__':
    main()