"""Train a recurrent "recorder" network.

Task: the network is fed a random integer sequence (one-hot encoded, with an
extra always-on channel marking the input phase), followed by the same number
of blank timesteps during which it must replay the input sequence.
"""
from argparse import ArgumentParser
from pathlib import Path
import shutil
import sys
import time
from typing import List, Tuple

import numpy as np
import torch
from torch import nn
from torch.utils.tensorboard import SummaryWriter

from src.metrics import Metrics
from src.torch_networks import (
    TorchLSTMModel,
    TorchLSTMCellModel,
    TorchGRUModel,
    CustomRNNModel,
    ChainRNNLayer,
    BNLSTMCell,
    CustomLSTMCell)
from src.torch_utils.train import parameter_summary


def generate_data(batch_size: int, sequence_length: int, dim: int) -> np.ndarray:
    """Return a (batch_size, sequence_length) int64 array of labels drawn uniformly from [0, dim)."""
    # Direct integer sampling; equivalent in distribution to the previous
    # uniform-float-then-floor approach.
    return np.random.randint(0, dim, (batch_size, sequence_length), dtype=np.int64)


def score_sequences(sequences: np.ndarray) -> Tuple[float, float]:
    """Score each sequence by its number of correct timesteps before the first error.

    Args:
        sequences: Boolean-like array of shape (batch, seq_len); a truthy entry
            marks a correctly predicted timestep.

    Returns:
        ``(total_score, max_score)``: the summed per-sequence scores and the
        best single-sequence score. (The annotation previously claimed
        ``-> float`` although a tuple was always returned.)
    """
    total_score = 0.0
    best_score = 0.0
    for sequence in sequences:
        # Count the leading run of correct predictions, stop at the first mistake.
        current = 0.0
        for correct in sequence:
            if not correct:
                break
            current += 1.0
        best_score = max(best_score, current)
        total_score += current
    return total_score, best_score


def get_network(model: str, input_dim: int, hidden_size: int, num_layer: int, device: str) -> nn.Module:
    """Instantiate the requested recurrent model and move it to ``device``.

    ``input_dim + 1`` accounts for the extra "input phase" channel appended to
    the one-hot encoding. Exits the program if ``model`` is not recognized.
    """
    # Variants built on CustomRNNModel, keyed by the extra constructor kwargs.
    custom_kwargs = {
        'stack': {},
        'stack-torchcell': {'cell_class': nn.LSTMCell},
        'stack-bn': {'cell_class': BNLSTMCell},
        'stack-custom': {'cell_class': CustomLSTMCell},
        'chain-lstm': {'layer_class': ChainRNNLayer},
        'chain-bn': {'layer_class': ChainRNNLayer, 'cell_class': BNLSTMCell},
        'chain-custom': {'layer_class': ChainRNNLayer, 'cell_class': CustomLSTMCell},
    }
    if model in custom_kwargs:
        return CustomRNNModel(input_dim + 1, hidden_size, num_layer=num_layer,
                              **custom_kwargs[model]).to(device)

    # Thin wrappers around the stock torch recurrent modules.
    torch_models = {
        'torch-cell': TorchLSTMCellModel,
        'torch-lstm': TorchLSTMModel,
        'torch-gru': TorchGRUModel,
    }
    if model in torch_models:
        return torch_models[model](input_dim + 1, hidden_size, num_layer=num_layer).to(device)

    # Fixed typo in the error message ('Unkown' -> 'Unknown') and included the
    # offending value to make the failure actionable.
    print(f"Error : Unknown model '{model}'")
    sys.exit(1)


def _write_parameter_summary(network: nn.Module, path: Path) -> None:
    """Write one aligned ``name shape size`` line per parameter to ``path``."""
    param_summary = parameter_summary(network)
    # Column widths hoisted out of the formatting loop (previously recomputed
    # with max() for every parameter).
    name_width = max(len(name) for name, _, _ in param_summary)
    shape_width = max(len(str(shape)) for _, shape, _ in param_summary)
    with open(path, 'w') as param_file:
        param_file.write('\n'.join(
            f'{name: <{name_width}} {str(shape): <{shape_width}} {size}'
            for name, shape, size in param_summary))


def _build_test_labels(sequence_size: int, input_dim: int) -> List[np.ndarray]:
    """Fixed plus random single-sequence batches used for the final evaluation."""
    return [
        np.asarray([[0, 0, 0, 0]], dtype=np.int64),
        np.asarray([[2, 2, 2, 2]], dtype=np.int64),
        np.asarray([[1, 2, 3, 4]], dtype=np.int64),
        np.asarray([[8, 1, 10, 5, 6, 13]], dtype=np.int64),
        generate_data(1, 4, input_dim),
        np.asarray([[0, 0, 0, 0, 0, 0]], dtype=np.int64),
        np.asarray([[5, 5, 5, 5, 5, 5]], dtype=np.int64),
        np.asarray([[11, 0, 2, 8, 5, 1]], dtype=np.int64),
        generate_data(1, max(4, sequence_size // 4), input_dim),
        generate_data(1, max(4, sequence_size // 2), input_dim),
        generate_data(1, max(4, sequence_size * 3 // 4), input_dim),
        generate_data(1, max(4, sequence_size * 3 // 4), input_dim),
        generate_data(1, max(4, sequence_size * 3 // 4), input_dim),
        generate_data(1, max(4, sequence_size * 3 // 4), input_dim),
        generate_data(1, sequence_size, input_dim),
        generate_data(1, sequence_size, input_dim),
        generate_data(1, sequence_size, input_dim),
        generate_data(1, sequence_size, input_dim),
    ]


def main():
    parser = ArgumentParser()
    parser.add_argument('--output', type=Path, default=Path('output', 'recorder'), help='Output dir')
    parser.add_argument('--model', default='torch-lstm', help='Model to train')
    parser.add_argument('--batch', type=int, default=32, help='Batch size')
    parser.add_argument('--sequence', type=int, default=8, help='Max sequence length')
    parser.add_argument('--dimension', type=int, default=15, help='Input dimension')
    parser.add_argument('--hidden', type=int, default=32, help='Hidden dimension')
    parser.add_argument('--layer', type=int, default=3, help='LSTM layer stack length')
    parser.add_argument('--step', type=int, default=20_000, help='Number of steps to train')
    arguments = parser.parse_args()

    output_dir: Path = arguments.output
    model: str = arguments.model
    batch_size: int = arguments.batch
    sequence_size: int = arguments.sequence
    input_dim: int = arguments.dimension
    hidden_size: int = arguments.hidden
    num_layer: int = arguments.layer
    max_step: int = arguments.step

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    torch.backends.cudnn.benchmark = True

    network = get_network(model, input_dim, hidden_size, num_layer, device)

    # One run directory per hyper-parameter combination.
    output_dir = output_dir.parent / (f'recorder_{model}_b{batch_size}_s{sequence_size}'
                                      f'_h{hidden_size}_l{num_layer}')
    output_dir.mkdir(parents=True, exist_ok=True)
    # Restart the TensorBoard log from scratch for this configuration.
    if (output_dir / 'train').exists():
        shutil.rmtree(output_dir / 'train')
    writer_train = SummaryWriter(log_dir=str(output_dir / 'train'), flush_secs=20)

    data_sample = torch.zeros((2, 4, input_dim + 1)).to(device)
    torch.save(network.state_dict(), output_dir / 'model_ini.pt')
    writer_train.add_graph(network, (data_sample,))

    # Save parameters info
    _write_parameter_summary(network, output_dir / 'parameters.csv')

    optimizer = torch.optim.Adam(network.parameters(), lr=1e-3, weight_decay=1e-6)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.996)
    criterion = nn.CrossEntropyLoss()
    metrics = Metrics()
    summary_period = max_step // 100
    min_sequence_size = 4
    # Blank ("replay phase") timesteps appended after the input sequence.
    zero_data = torch.zeros((batch_size, sequence_size, input_dim + 1)).to(device)
    np.set_printoptions(precision=2)

    try:
        start_time = time.time()
        for step in range(1, max_step + 1):
            step_start_time = time.time()
            label_np = generate_data(
                batch_size,
                int(np.random.uniform(min_sequence_size, sequence_size + 1)),
                input_dim)
            data_gen_time = time.time()

            label = torch.from_numpy(label_np).to(device)
            data = nn.functional.one_hot(label, input_dim + 1).float()
            data[:, :, input_dim] = 1.0  # extra channel flags "input phase" timesteps
            data_process_time = time.time()

            optimizer.zero_grad()
            # Feed the sequence then as many blank steps; the network is time-major.
            outputs, _state = network(
                torch.cat([data, zero_data[:, :label_np.shape[1]]], dim=1).transpose(1, 0))
            # Keep only the replay phase, and drop the "input phase" logit.
            outputs = outputs[label_np.shape[1]:, :, :-1]
            predict_time = time.time()

            # (seq, batch, classes) -> (batch, classes, seq). With every sample in
            # the batch sharing one length, the default 'mean' reduction equals the
            # per-sample loop (sum of per-sample means / batch) it replaces.
            loss = criterion(outputs.permute(1, 2, 0), label)
            loss_time = time.time()

            loss.backward()
            backprop_time = time.time()
            optimizer.step()
            optim_time = time.time()

            sequence_correct = torch.argmax(outputs, 2).transpose(1, 0) == label
            # score_sequences returns (total over the batch, best single sequence).
            batch_score, max_score = score_sequences(sequence_correct.detach().cpu().numpy())

            metrics.loss += loss.item()
            metrics.accuracy += (torch.sum(sequence_correct.long(), 1)
                                 == label.size(1)).float().sum().item()
            metrics.score += batch_score
            if max_score > metrics.max_score:
                metrics.max_score = max_score
            metrics.count += batch_size
            metrics.bench.data_gen += data_gen_time - step_start_time
            metrics.bench.data_process += data_process_time - data_gen_time
            metrics.bench.predict += predict_time - data_process_time
            metrics.bench.loss += loss_time - predict_time
            metrics.bench.backprop += backprop_time - loss_time
            metrics.bench.optimizer += optim_time - backprop_time
            metrics.bench.metrics += time.time() - optim_time

            if step % summary_period == 0:
                writer_train.add_scalar('metric/loss', metrics.loss / metrics.count,
                                        global_step=step)
                writer_train.add_scalar('metric/error', 1 - (metrics.accuracy / metrics.count),
                                        global_step=step)
                writer_train.add_scalar('metric/score', metrics.score / metrics.count,
                                        global_step=step)
                writer_train.add_scalar('metric/max_score', metrics.max_score, global_step=step)
                writer_train.add_scalar('optimizer/lr', scheduler.get_last_lr()[0],
                                        global_step=step)
                scheduler.step()  # decay the learning rate once per summary period
                train_time = time.time() - start_time
                print(f'Step {step}, loss: {metrics.loss / metrics.count:.03e}'
                      f', acc: {metrics.accuracy / metrics.count:.03e}'
                      f', score: {metrics.score / metrics.count:.03f}'
                      f', speed: {summary_period / train_time:0.3f}step/s'
                      f' => {metrics.count / train_time:.02f}input/s'
                      f'\n ({metrics.bench.get_proportions(train_time)})')
                start_time = time.time()
                metrics.reset()
    except KeyboardInterrupt:
        # Erase the "^C" echoed by the terminal before moving on to evaluation.
        print('\r ', end='\r')
    writer_train.close()

    # --- Final qualitative evaluation on a few fixed and random sequences ---
    network.eval()
    test_label = _build_test_labels(sequence_size, input_dim)

    zero_data = torch.zeros((1, sequence_size, input_dim + 1)).to(device)
    metrics.reset()
    for label_np in test_label:
        label = torch.from_numpy(label_np).to(device)
        data = nn.functional.one_hot(label, input_dim + 1).float()
        data[:, :, input_dim] = 1.0
        outputs, _state = network(
            torch.cat([data, zero_data[:, :label_np.shape[1]]], dim=1).transpose(1, 0))
        outputs = outputs[label_np.shape[1]:, :, :-1]
        sequence_correct = torch.argmax(outputs, 2).transpose(1, 0) == label
        current_score, max_score = score_sequences(sequence_correct.detach().cpu().numpy())
        metrics.accuracy += (torch.sum(sequence_correct.long(), 1)
                             == label.size(1)).float().mean().item()
        # BUG FIX: previously accumulated the stale ``average_score`` left over
        # from the training loop instead of this sequence's own score.
        metrics.score += current_score
        if max_score > metrics.max_score:
            metrics.max_score = max_score
        metrics.count += 1
        print(f'score: {current_score}/{label_np.shape[1]}, label: {label_np}'
              f', output: {torch.argmax(outputs, 2)[:, 0].detach().cpu().numpy()}')
    print(f'Test accuracy: {metrics.accuracy / metrics.count:.03f}')


if __name__ == '__main__':
    main()