"""Train a recurrent network on a copy task: the model reads a random integer
sequence while an extra flag channel is set to 1, then must reproduce the
sequence while being fed zeros. ``--model cell`` selects StackedLSTMModel;
any other value (or omitting the flag) selects LSTMModel."""
from argparse import ArgumentParser
from pathlib import Path
import shutil
import time

import numpy as np
import torch
from torch import nn
from torch.utils.tensorboard import SummaryWriter

from src.torch_networks import LSTMModel, StackedLSTMModel
from src.torch_utils.train import parameter_summary


def generate_data(batch_size: int, sequence_length: int, dim: int) -> np.ndarray:
    """Random integer sequences with values in [0, dim)."""
    return np.random.uniform(0, dim, (batch_size, sequence_length)).astype(np.int64)


def main():
    parser = ArgumentParser()
    parser.add_argument('--output', type=Path, default=Path('output', 'recorder'), help='Output dir')
    parser.add_argument('--batch', type=int, default=32, help='Batch size')
    parser.add_argument('--sequence', type=int, default=8, help='Max sequence length')
    parser.add_argument('--dimension', type=int, default=15, help='Input dimension')
    parser.add_argument('--step', type=int, default=20_000, help='Number of steps to train')
    parser.add_argument('--model', help='Model to train')
    arguments = parser.parse_args()

    output_dir: Path = arguments.output
    batch_size: int = arguments.batch
    sequence_size: int = arguments.sequence
    input_dim: int = arguments.dimension
    max_step: int = arguments.step
    model: str = arguments.model

    if not output_dir.exists():
        output_dir.mkdir(parents=True)
    if (output_dir / 'train').exists():
        shutil.rmtree(output_dir / 'train')
    writer_train = SummaryWriter(log_dir=str(output_dir / 'train'), flush_secs=20)

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    # One extra input channel serves as the "input phase" flag.
    if model == 'cell':
        network = StackedLSTMModel(input_dim + 1).to(device)
    else:
        network = LSTMModel(input_dim + 1).to(device)

    # Save parameter info (name, shape, size) as a space-aligned table.
    with open(output_dir / 'parameters.csv', 'w') as param_file:
        param_summary = parameter_summary(network)
        name_widths = [len(name) for name, _, _ in param_summary]
        shape_widths = [len(str(shape)) for _, shape, _ in param_summary]
        param_file.write('\n'.join(
            f'{name: <{max(name_widths)}} {str(shape): <{max(shape_widths)}} {size}'
            for name, shape, size in param_summary))

    optimizer = torch.optim.Adam(network.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.996)
    criterion = nn.CrossEntropyLoss()

    zero_data = torch.zeros((batch_size, sequence_size, input_dim + 1)).to(device)
    # writer_train.add_graph(network, (zero_data[:, :5],))

    if model == 'cell':
        state = [(torch.zeros((batch_size, (input_dim + 1) * 2)).to(device),
                  torch.zeros((batch_size, (input_dim + 1) * 2)).to(device))] * network.NUM_LAYERS
    else:
        state = None

    running_loss = 0.0
    running_accuracy = 0.0
    running_count = 0
    summary_period = max_step // 100
    np.set_printoptions(precision=2)
    try:
        start_time = time.time()
        for step in range(1, max_step + 1):
            # Random sequence length in [4, sequence_size].
            label_np = generate_data(batch_size, int(np.random.uniform(4, sequence_size + 1)), input_dim)
            label = torch.from_numpy(label_np).to(device)
            data = nn.functional.one_hot(label, input_dim + 1).float()
            data[:, :, input_dim] = 1.0  # Raise the flag channel during the input phase.

            optimizer.zero_grad()
            # Feed the sequence, then zeros for as many steps; transpose to (time, batch, features).
            outputs, _state = network(
                torch.cat([data, zero_data[:, :label_np.shape[1]]], dim=1).transpose(1, 0), state)
            # Keep only the recall phase and drop the flag logit.
            outputs = outputs[label_np.shape[1]:, :, :-1]
            # state = (state[0].detach(), state[1].detach())
            # state = (state[0][:, :1].detach().expand(state[0].shape[0], batch_size, *state[0].shape[2:]),
            #          state[1][:, :1].detach().expand(state[1].shape[0], batch_size, *state[1].shape[2:]))
            data[:, :, input_dim] = 0.0
            # Per-sample cross-entropy, averaged over the batch.
            loss = criterion(outputs[:, 0], label[0])
            for i in range(1, batch_size):
                loss += criterion(outputs[:, i], label[i])
            loss /= batch_size

            running_loss += loss.item()
            # Sequence-level accuracy: a sample counts only if every step matches.
            running_accuracy += (
                torch.sum((torch.argmax(outputs, 2).transpose(1, 0) == label).long(), 1)
                == label.size(1)).float().mean().item()
            running_count += 1
            if step % summary_period == 0:
                writer_train.add_scalar('metric/loss', running_loss / running_count, global_step=step)
                writer_train.add_scalar('metric/error', 1 - (running_accuracy / running_count), global_step=step)
                writer_train.add_scalar('optimizer/lr', scheduler.get_last_lr()[0], global_step=step)
                scheduler.step()
                speed = summary_period / (time.time() - start_time)
                print(f'Step {step}, loss: {running_loss / running_count:.03e}'
                      f', acc: {running_accuracy / running_count:.03e}, speed: {speed:0.3f} step/s')
                start_time = time.time()
                running_loss = 0.0
                running_accuracy = 0.0
                running_count = 0

            loss.backward()
            optimizer.step()
    except KeyboardInterrupt:
        print('\r  ', end='\r')  # Erase the "^C" echo.
    writer_train.close()

    # Quick qualitative check on a few fixed and random sequences.
    network.eval()
    test_label = [
        np.asarray([[0, 0, 0, 0]], dtype=np.int64),
        np.asarray([[2, 2, 2, 2]], dtype=np.int64),
        np.asarray([[8, 1, 10, 5, 6, 13]], dtype=np.int64),
        generate_data(1, 4, input_dim),
        np.asarray([[0, 0, 0, 0, 0, 0]], dtype=np.int64),
        np.asarray([[5, 5, 5, 5, 5, 5]], dtype=np.int64),
        np.asarray([[11, 0, 2, 8, 5, 1]], dtype=np.int64),
        generate_data(1, sequence_size, input_dim)
    ]
    zero_data = torch.zeros((1, sequence_size, input_dim + 1)).to(device)
    if model == 'cell':
        # Re-initialize the state for batch size 1, matching the training-time layout.
        state = [(torch.zeros((1, (input_dim + 1) * 2)).to(device),
                  torch.zeros((1, (input_dim + 1) * 2)).to(device))] * network.NUM_LAYERS
    running_accuracy = 0.0
    running_count = 0
    for label_np in test_label:
        label = torch.from_numpy(label_np).to(device)
        data = nn.functional.one_hot(label, input_dim + 1).float()
        data[:, :, input_dim] = 1.0
        outputs, _state = network(
            torch.cat([data, zero_data[:, :label_np.shape[1]]], dim=1).transpose(1, 0), state)
        outputs = outputs[label_np.shape[1]:, :, :-1]
        # state = (state[0].detach(), state[1].detach())
        running_accuracy += (
            torch.sum((torch.argmax(outputs, 2).transpose(1, 0) == label).long(), 1)
            == label.size(1)).float().mean().item()
        running_count += 1
        print(f'{label_np.shape[1]} label: {label_np}, output: '
              f'{torch.argmax(outputs, 2)[:, 0].detach().cpu().numpy()}')
    print(f'Test accuracy: {running_accuracy / running_count:.03f}')


if __name__ == '__main__':
    main()
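
# Usage sketch. The file name `train_recorder.py` and the project layout (a
# local `src/` package providing `torch_networks` and `torch_utils`) are
# assumptions, not given by this script itself; the flags and defaults below
# come from the argparse definitions above.
#
#   python train_recorder.py                  # trains LSTMModel (the default)
#   python train_recorder.py --model cell     # trains StackedLSTMModel
#   python train_recorder.py --batch 64 --sequence 12 --step 10000
#
# Loss, exact-match error, and learning-rate curves are written under
# output/recorder/train and can be inspected with:
#
#   tensorboard --logdir output/recorder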