"""Trains a 1D convolutional autoencoder to reconstruct random one-hot encoded sequences."""
from argparse import ArgumentParser
from pathlib import Path
import shutil
import time

import numpy as np
import torch
from torch import nn
from torch.utils.tensorboard import SummaryWriter

from src.metrics import Metrics
from src.torch_utils.layers import Conv1d, Deconv1d
from src.torch_utils.train import parameter_summary


def generate_data(batch_size: int, sequence_length: int, dim: int) -> np.ndarray:
    """Returns a batch of random integer sequences with values in [0, dim)."""
    return np.random.randint(0, dim, (batch_size, sequence_length), dtype=np.int64)


class SequenceAutoEncoder(nn.Module):
    """Four strided convolutions (16x temporal downsampling) mirrored by four transposed convolutions."""

    def __init__(self, input_dim: int):
        super().__init__()
        self.auto_encoder = nn.Sequential(
            Conv1d(input_dim, 16, kernel_size=2, stride=2),
            Conv1d(16, 16, kernel_size=2, stride=2),
            Conv1d(16, 16, kernel_size=2, stride=2),
            Conv1d(16, 16, kernel_size=2, stride=2),
            Deconv1d(16, 16, kernel_size=2, stride=2),
            Deconv1d(16, 16, kernel_size=2, stride=2),
            Deconv1d(16, 16, kernel_size=2, stride=2),
            # The last layer outputs raw logits: no batch norm, no activation.
            Deconv1d(16, input_dim, kernel_size=2, stride=2, use_batch_norm=False, activation=None))

    def forward(self, input_data: torch.Tensor) -> torch.Tensor:
        return self.auto_encoder(input_data)


def main():
    parser = ArgumentParser()
    parser.add_argument('--output', type=Path, default=Path('output', 'recorder'), help='Output dir')
    parser.add_argument('--batch', type=int, default=128, help='Batch size')
    parser.add_argument('--sequence', type=int, default=128, help='Max sequence length')
    parser.add_argument('--dimension', type=int, default=16, help='Input dimension')
    # parser.add_argument('--latent', type=int, default=16, help='Latent space size')
    parser.add_argument('--step', type=int, default=80_000, help='Number of steps to train')
    arguments = parser.parse_args()

    output_dir: Path = arguments.output
    batch_size: int = arguments.batch
    sequence_size: int = arguments.sequence
    input_dim: int = arguments.dimension
    # latent_size: int = arguments.latent
    max_step: int = arguments.step

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    torch.backends.cudnn.benchmark = True

    network = SequenceAutoEncoder(input_dim).to(device)

    output_dir = output_dir / f'sequence_encoder_b{batch_size}_s{sequence_size}'
    output_dir.mkdir(parents=True, exist_ok=True)
    # Start from a clean train log so TensorBoard curves from previous runs do not overlap.
    if (output_dir / 'train').exists():
        shutil.rmtree(output_dir / 'train')
    writer_train = SummaryWriter(log_dir=str(output_dir / 'train'), flush_secs=20)
    writer_val = SummaryWriter(log_dir=str(output_dir / 'val'), flush_secs=20)

    data_sample = torch.zeros((2, input_dim, sequence_size), device=device)
    torch.save(network.state_dict(), output_dir / 'model_ini.pt')
    writer_train.add_graph(network, (data_sample,))

    # Save a fixed-width parameter summary.
    with open(output_dir / 'parameters.csv', 'w') as param_file:
        param_summary = parameter_summary(network)
        name_widths = [len(name) for name, _, _ in param_summary]
        shape_widths = [len(str(shape)) for _, shape, _ in param_summary]
        param_file.write('\n'.join(
            [f'{name: <{max(name_widths)}} {str(shape): <{max(shape_widths)}} {size}'
             for name, shape, size in param_summary]))

    optimizer = torch.optim.Adam(network.parameters(), lr=1e-3, weight_decay=1e-6)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.994)
    criterion = nn.CrossEntropyLoss()

    @torch.no_grad()
    def val_summary(step: int):
        network.eval()
        val_metrics = Metrics()
        for _ in range(10):
            label = torch.from_numpy(generate_data(batch_size, sequence_size, input_dim)).to(device)
            data = nn.functional.one_hot(label, input_dim).float().transpose(2, 1)
            outputs = network(data)
            predictions = torch.argmax(outputs, 1)
            # A sequence only counts as correct when every timestep is reconstructed.
            val_metrics.accuracy += ((predictions == label).sum(1) == sequence_size).sum().item()
            val_metrics.score += (predictions == label).float().mean(1).sum().item()
            val_metrics.count += batch_size
        writer_val.add_scalar('metric/error', 1 - (val_metrics.accuracy / val_metrics.count), global_step=step)
        writer_val.add_scalar('metric/score_inv', 1 - (val_metrics.score / val_metrics.count), global_step=step)
        network.train()

    network.train()
    metrics = Metrics()
    summary_period = max_step // 100
    np.set_printoptions(precision=2)

    try:
        start_time = time.time()
        for step in range(1, max_step + 1):
            label = torch.from_numpy(generate_data(batch_size, sequence_size, input_dim)).to(device)
            data = nn.functional.one_hot(label, input_dim).float().transpose(2, 1)

            optimizer.zero_grad()
            outputs = network(data)
            loss = criterion(outputs, label)
            loss.backward()
            optimizer.step()

            with torch.no_grad():
                predictions = torch.argmax(outputs, 1)
                metrics.loss += loss.item()
                metrics.accuracy += ((predictions == label).sum(1) == sequence_size).sum().item()
                metrics.score += (predictions == label).float().mean(1).sum().item()
                metrics.count += batch_size

            if step % summary_period == 0:
                writer_train.add_scalar('metric/loss', metrics.loss / summary_period, global_step=step)
                writer_train.add_scalar('metric/error', 1 - (metrics.accuracy / metrics.count), global_step=step)
                writer_train.add_scalar('metric/score_inv', 1 - (metrics.score / metrics.count), global_step=step)
                writer_train.add_scalar('optimizer/lr', scheduler.get_last_lr()[0], global_step=step)
                scheduler.step()  # Decay the learning rate once per summary period.

                train_time = time.time() - start_time
                print(f'Step {step}, loss: {metrics.loss / summary_period:.03e}'
                      f', acc: {metrics.accuracy / metrics.count:.03e}'
                      f', score_inv: {1 - (metrics.score / metrics.count):.03e}'
                      f', speed: {summary_period / train_time:0.3f}step/s'
                      f' => {metrics.count / train_time:.02f}input/s')
                val_summary(step)
                start_time = time.time()
                metrics.reset()
    except KeyboardInterrupt:
        print('\r ', end='\r')  # Clear the echoed ^C before printing the test results.

    writer_train.close()
    writer_val.close()

    network.eval()
    test_labels = [generate_data(1, sequence_size, input_dim) for _ in range(10)]
    metrics.reset()
    with torch.no_grad():
        for label_np in test_labels:
            label = torch.from_numpy(label_np).to(device)
            data = nn.functional.one_hot(label, input_dim).float().transpose(2, 1)
            outputs = network(data)
            predictions = torch.argmax(outputs, 1)
            # sequence_correct = torch.argmax(outputs, 2).transpose(1, 0) == label
            # metrics.accuracy += (torch.sum(sequence_correct.long(), 1) == label.size(1)).float().mean().item()
            metrics.accuracy += ((predictions == label).sum(1) == sequence_size).float().sum().item()
            metrics.score += (predictions == label).float().mean(1).sum().item()
            metrics.count += 1
            print(f'label: {label_np}\noutput: {predictions.cpu().numpy()}')
    print(f'\nTest accuracy: {metrics.accuracy / metrics.count:.05f}')
    print(f'Test score: {metrics.score / metrics.count:.05f}')


if __name__ == '__main__':
    main()
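
# ---------------------------------------------------------------------------
# Reference notes (assumptions, not part of the project helpers themselves):
#
# `src.metrics.Metrics` is used above as a plain accumulator. Inferred from the
# call sites, a minimal sketch would look like the following (kept commented
# out so it never shadows the real import):
#
#     class Metrics:
#         def __init__(self):
#             self.reset()
#
#         def reset(self):
#             self.loss = 0.0
#             self.accuracy = 0.0
#             self.score = 0.0
#             self.count = 0
#
# Example invocation (the script filename is hypothetical; the flags match the
# argparse definitions above, and the log directory matches the --output default):
#
#     python train_sequence_autoencoder.py --batch 128 --sequence 128 --dimension 16 --step 80000
#     tensorboard --logdir output/recorder
# ---------------------------------------------------------------------------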