"""Train an LSTM on a "recorder" copy task: the network reads a random
integer sequence, then must replay it while being fed zeros."""
from argparse import ArgumentParser
from pathlib import Path
import shutil
import sys
import time

import numpy as np
import torch
from torch import nn
from torch.utils.tensorboard import SummaryWriter

from src.torch_networks import TorchLSTMModel, TorchLSTMCellModel, CustomLSTMModel, ChainLSTMLayer, BNLSTMCell
from src.torch_utils.train import parameter_summary


def generate_data(batch_size: int, sequence_length: int, dim: int) -> np.ndarray:
    """Returns a batch of random integer sequences with values in [0, dim)."""
    return np.random.uniform(0, dim, (batch_size, sequence_length)).astype(np.int64)


def score_sequences(sequences: np.ndarray) -> tuple[float, float]:
    """Counts, for each sequence, the correct predictions before the first mistake.

    Returns the total over the batch and the best single-sequence score.
    """
    score = 0.0
    max_score = 0.0
    for sequence_data in sequences:
        sequence_score = 0.0
        for result in sequence_data:
            if not result:
                break
            sequence_score += 1.0
        if sequence_score > max_score:
            max_score = sequence_score
        score += sequence_score
    return score, max_score


def main():
    parser = ArgumentParser()
    parser.add_argument('--output', type=Path, default=Path('output', 'recorder'), help='Output dir')
    parser.add_argument('--model', default='torch-lstm', help='Model to train')
    parser.add_argument('--batch', type=int, default=32, help='Batch size')
    parser.add_argument('--sequence', type=int, default=8, help='Max sequence length')
    parser.add_argument('--dimension', type=int, default=15, help='Input dimension')
    parser.add_argument('--hidden', type=int, default=32, help='Hidden dimension')
    parser.add_argument('--lstm', type=int, default=3, help='LSTM layer stack length')
    parser.add_argument('--step', type=int, default=20_000, help='Number of steps to train')
    arguments = parser.parse_args()

    output_dir: Path = arguments.output
    model: str = arguments.model
    batch_size: int = arguments.batch
    sequence_size: int = arguments.sequence
    input_dim: int = arguments.dimension
    hidden_size: int = arguments.hidden
    num_layer: int = arguments.lstm
    max_step: int = arguments.step

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    torch.backends.cudnn.benchmark = True

    if model == 'stack':
        network = CustomLSTMModel(input_dim + 1, hidden_size, num_layer=num_layer).to(device)
    elif model == 'stack-bn':
        network = CustomLSTMModel(input_dim + 1, hidden_size, num_layer=num_layer,
                                  cell_class=BNLSTMCell).to(device)
    elif model == 'stack-torchcell':
        network = CustomLSTMModel(input_dim + 1, hidden_size, num_layer=num_layer,
                                  cell_class=nn.LSTMCell).to(device)
    elif model == 'chain':
        network = CustomLSTMModel(input_dim + 1, hidden_size, num_layer=num_layer,
                                  layer_class=ChainLSTMLayer).to(device)
    elif model == 'chain-bn':
        network = CustomLSTMModel(input_dim + 1, hidden_size, num_layer=num_layer,
                                  layer_class=ChainLSTMLayer, cell_class=BNLSTMCell).to(device)
    elif model == 'torch-cell':
        network = TorchLSTMCellModel(input_dim + 1, hidden_size, num_layer=num_layer).to(device)
    elif model == 'torch-lstm':
        network = TorchLSTMModel(input_dim + 1, hidden_size, num_layer=num_layer).to(device)
    else:
        print('Error: Unknown model')
        sys.exit(1)

    output_dir = output_dir.parent / f'recorder_{model}_b{batch_size}_s{sequence_size}_h{hidden_size}_l{num_layer}'
    output_dir.mkdir(parents=True, exist_ok=True)
    if (output_dir / 'train').exists():
        shutil.rmtree(output_dir / 'train')
    writer_train = SummaryWriter(log_dir=str(output_dir / 'train'), flush_secs=20)

    data_sample = torch.zeros((2, 4, input_dim + 1)).to(device)
    torch.save(network.state_dict(), output_dir / 'model_ini.pt')
    writer_train.add_graph(network, (data_sample,))

    # Save a fixed-width parameter summary (name, shape, size per line).
    with open(output_dir / 'parameters.csv', 'w') as param_file:
        param_summary = parameter_summary(network)
        name_widths = [len(name) for name, _, _ in param_summary]
        shape_widths = [len(str(shape)) for _, shape, _ in param_summary]
        param_file.write('\n'.join([f'{name: <{max(name_widths)}} {str(shape): <{max(shape_widths)}} {size}'
                                    for name, shape, size in param_summary]))
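
    # How a batch is encoded (illustration only, inferred from the data
    # pipeline below): each token t in [0, input_dim) becomes a one-hot
    # vector of size input_dim + 1 whose extra last channel is forced to 1.0
    # during the input phase. The encoded sequence is followed by an all-zero
    # block of equal length, and the network is trained to output the
    # original tokens during that zero block.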

    optimizer = torch.optim.Adam(network.parameters(), lr=1e-3, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.996)
    criterion = nn.CrossEntropyLoss()

    class Metrics:
        """Accumulators for training metrics, reset at every summary period."""

        class Bench:
            """Accumulated wall-clock time spent in each part of a training step."""

            def __init__(self):
                self.reset()

            def reset(self):
                self.data_gen = 0.0
                self.data_process = 0.0
                self.predict = 0.0
                self.loss = 0.0
                self.backprop = 0.0
                self.optimizer = 0.0
                self.metrics = 0.0

            def get_proportions(self, train_time: float) -> str:
                return (f'data_gen: {self.data_gen / train_time:.02f}'
                        f', data_process: {self.data_process / train_time:.02f}'
                        f', predict: {self.predict / train_time:.02f}'
                        f', loss: {self.loss / train_time:.02f}'
                        f', backprop: {self.backprop / train_time:.02f}'
                        f', optimizer: {self.optimizer / train_time:.02f}'
                        f', metrics: {self.metrics / train_time:.02f}')

        def __init__(self):
            self.bench = self.Bench()
            self.reset()

        def reset(self):
            self.loss = 0.0
            self.accuracy = 0.0
            self.score = 0.0
            self.max_score = 0.0
            self.count = 0
            self.bench.reset()

    metrics = Metrics()
    summary_period = max_step // 100
    min_sequence_size = 4
    zero_data = torch.zeros((batch_size, sequence_size, input_dim + 1)).to(device)
    np.set_printoptions(precision=2)
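
    # Note on the bench timings below: on GPU, CUDA kernels are launched
    # asynchronously, so the time.time() deltas around predict / loss /
    # backprop mostly capture launch overhead. Treat the reported proportions
    # as indicative; calling torch.cuda.synchronize() before each timestamp
    # would give exact per-segment numbers.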
    try:
        start_time = time.time()
        for step in range(1, max_step + 1):
            step_start_time = time.time()
            label_np = generate_data(
                batch_size, int(np.random.uniform(min_sequence_size, sequence_size + 1)), input_dim)
            data_gen_time = time.time()

            label = torch.from_numpy(label_np).to(device)
            data = nn.functional.one_hot(label, input_dim + 1).float()
            data[:, :, input_dim] = 1.0
            data_process_time = time.time()

            optimizer.zero_grad()
            # Feed the encoded sequence followed by a zero block of equal length,
            # then keep only the predictions made during the zero block.
            outputs, _state = network(torch.cat([data, zero_data[:, :label_np.shape[1]]], dim=1).transpose(1, 0))
            outputs = outputs[label_np.shape[1]:, :, :-1]
            # state = (state[0].detach(), state[1].detach())
            # state = (state[0][:, :1].detach().expand(state[0].shape[0], batch_size, *state[0].shape[2:]),
            #          state[1][:, :1].detach().expand(state[1].shape[0], batch_size, *state[1].shape[2:]))
            predict_time = time.time()

            loss = criterion(outputs[:, 0], label[0])
            for i in range(1, batch_size):
                loss += criterion(outputs[:, i], label[i])
            loss /= batch_size
            loss_time = time.time()

            loss.backward()
            backprop_time = time.time()
            optimizer.step()
            optim_time = time.time()

            sequence_correct = torch.argmax(outputs, 2).transpose(1, 0) == label
            batch_score, max_score = score_sequences(sequence_correct.detach().cpu().numpy())
            # if batch_score > min_sequence_size * 2 and min_sequence_size < sequence_size - 1:
            #     min_sequence_size += 1

            metrics.loss += loss.item()
            metrics.accuracy += (torch.sum(sequence_correct.long(), 1) == label.size(1)).float().sum().item()
            metrics.score += batch_score
            if max_score > metrics.max_score:
                metrics.max_score = max_score
            metrics.count += batch_size

            metrics.bench.data_gen += data_gen_time - step_start_time
            metrics.bench.data_process += data_process_time - data_gen_time
            metrics.bench.predict += predict_time - data_process_time
            metrics.bench.loss += loss_time - predict_time
            metrics.bench.backprop += backprop_time - loss_time
            metrics.bench.optimizer += optim_time - backprop_time
            metrics.bench.metrics += time.time() - optim_time

            if step % summary_period == 0:
                writer_train.add_scalar('metric/loss', metrics.loss / metrics.count, global_step=step)
                writer_train.add_scalar('metric/error', 1 - (metrics.accuracy / metrics.count), global_step=step)
                writer_train.add_scalar('metric/score', metrics.score / metrics.count, global_step=step)
                writer_train.add_scalar('metric/max_score', metrics.max_score, global_step=step)
                writer_train.add_scalar('optimizer/lr', scheduler.get_last_lr()[0], global_step=step)
                scheduler.step()
                train_time = time.time() - start_time
                print(f'Step {step}, loss: {metrics.loss / metrics.count:.03e}'
                      f', acc: {metrics.accuracy / metrics.count:.03e}'
                      f', score: {metrics.score / metrics.count:.03f}'
                      f', speed: {summary_period / train_time:0.3f}step/s'
                      f' => {metrics.count / train_time:.02f}input/s'
                      f'\n ({metrics.bench.get_proportions(train_time)})')
                start_time = time.time()
                metrics.reset()
    except KeyboardInterrupt:
        print('\r ', end='\r')
    writer_train.close()

    network.eval()
    test_label = [
        np.asarray([[0, 0, 0, 0]], dtype=np.int64),
        np.asarray([[2, 2, 2, 2]], dtype=np.int64),
        np.asarray([[1, 2, 3, 4]], dtype=np.int64),
        np.asarray([[8, 1, 10, 5, 6, 13]], dtype=np.int64),
        generate_data(1, 4, input_dim),
        np.asarray([[0, 0, 0, 0, 0, 0]], dtype=np.int64),
        np.asarray([[5, 5, 5, 5, 5, 5]], dtype=np.int64),
        np.asarray([[11, 0, 2, 8, 5, 1]], dtype=np.int64),
        generate_data(1, max(4, sequence_size // 4), input_dim),
        generate_data(1, max(4, sequence_size // 2), input_dim),
        generate_data(1, max(4, sequence_size * 3 // 4), input_dim),
        generate_data(1, max(4, sequence_size * 3 // 4), input_dim),
        generate_data(1, max(4, sequence_size * 3 // 4), input_dim),
        generate_data(1, max(4, sequence_size * 3 // 4), input_dim),
        generate_data(1, sequence_size, input_dim),
        generate_data(1, sequence_size, input_dim),
        generate_data(1, sequence_size, input_dim),
        generate_data(1, sequence_size, input_dim)
    ]
    zero_data = torch.zeros((1, sequence_size, input_dim + 1)).to(device)
    metrics.reset()
    with torch.no_grad():
        for label_np in test_label:
            label = torch.from_numpy(label_np).to(device)
            data = nn.functional.one_hot(label, input_dim + 1).float()
            data[:, :, input_dim] = 1.0
            outputs, _state = network(torch.cat([data, zero_data[:, :label_np.shape[1]]], dim=1).transpose(1, 0))
            # state = (state[0].detach(), state[1].detach())
            outputs = outputs[label_np.shape[1]:, :, :-1]
            sequence_correct = torch.argmax(outputs, 2).transpose(1, 0) == label
            current_score, max_score = score_sequences(sequence_correct.cpu().numpy())
            metrics.accuracy += (torch.sum(sequence_correct.long(), 1) == label.size(1)).float().mean().item()
            metrics.score += current_score
            if max_score > metrics.max_score:
                metrics.max_score = max_score
            metrics.count += 1
            print(f'score: {current_score}/{label_np.shape[1]}, label: {label_np}'
                  f', output: {torch.argmax(outputs, 2)[:, 0].cpu().numpy()}')
    print(f'Test accuracy: {metrics.accuracy / metrics.count:.03f}')


if __name__ == '__main__':
    main()
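
# Example invocation (the script filename below is illustrative; run from the
# repository root so that the `src` package is importable):
#   python train_recorder.py --model stack-bn --sequence 8 --step 20000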