time-series/recorder.py

240 lines
11 KiB
Python

from argparse import ArgumentParser
from pathlib import Path
import shutil
import sys
import time
from typing import Tuple

import numpy as np
import torch
from torch import nn
from torch.utils.tensorboard import SummaryWriter

from src.metrics import Metrics
from src.torch_networks import (
    TorchLSTMModel, TorchLSTMCellModel, TorchGRUModel,
    CustomRNNModel, ChainRNNLayer, BNLSTMCell, CustomLSTMCell)
from src.torch_utils.train import parameter_summary
def generate_data(batch_size: int, sequence_length: int, dim: int) -> np.ndarray:
    """Draw a batch of random integer token sequences.

    Floats are sampled uniformly from ``[0, dim)`` and truncated to ``int64``,
    giving token values in ``0 .. dim - 1``.

    Returns:
        An ``int64`` array of shape ``(batch_size, sequence_length)``.
    """
    shape = (batch_size, sequence_length)
    samples = np.random.uniform(low=0, high=dim, size=shape)
    return samples.astype(np.int64)
def score_sequences(sequences: np.ndarray) -> Tuple[float, float]:
    """Score each sequence by the length of its leading run of correct steps.

    Args:
        sequences: 2-D array of per-step correctness flags, one row per sequence
            (any truthy value counts as correct).

    Returns:
        ``(total, best)``: the sum over rows of the correct-prefix lengths and the
        largest single prefix length, both as Python floats. An empty batch
        (no rows or no columns) scores ``(0.0, 0.0)``.
    """
    flags = np.asarray(sequences, dtype=bool)
    if flags.size == 0:
        return 0.0, 0.0
    # A step only counts while every earlier step in the row was correct:
    # a running logical AND along the time axis, then count the surviving Trues.
    prefix_lengths = np.logical_and.accumulate(flags, axis=1).sum(axis=1)
    return float(prefix_lengths.sum()), float(prefix_lengths.max())
def get_network(model: str, input_dim: int, hidden_size: int, num_layer: int, device: str) -> nn.Module:
    """Instantiate the recurrent network named ``model`` and move it to ``device``.

    Every variant is built with ``input_dim + 1`` input features, ``hidden_size``
    hidden units and ``num_layer`` stacked layers.

    Args:
        model: one of the keys of ``builders`` below (e.g. ``'torch-lstm'``).
        input_dim: number of distinct token values; one extra feature is added.
        hidden_size: hidden dimension passed to the network constructor.
        num_layer: number of stacked recurrent layers.
        device: anything accepted by ``nn.Module.to`` (``main`` passes a ``torch.device``).

    Returns:
        The constructed network, already moved to ``device``.

    Exits the process with status 1, after printing an error to stderr, when
    ``model`` is not a known name.
    """
    input_size = input_dim + 1
    # Lambdas defer construction so only the requested network is ever built.
    builders = {
        'stack': lambda: CustomRNNModel(input_size, hidden_size, num_layer=num_layer),
        'stack-torchcell': lambda: CustomRNNModel(
            input_size, hidden_size, num_layer=num_layer, cell_class=nn.LSTMCell),
        'stack-bn': lambda: CustomRNNModel(
            input_size, hidden_size, num_layer=num_layer, cell_class=BNLSTMCell),
        'stack-custom': lambda: CustomRNNModel(
            input_size, hidden_size, num_layer=num_layer, cell_class=CustomLSTMCell),
        'chain-lstm': lambda: CustomRNNModel(
            input_size, hidden_size, num_layer=num_layer, layer_class=ChainRNNLayer),
        'chain-bn': lambda: CustomRNNModel(
            input_size, hidden_size, num_layer=num_layer,
            layer_class=ChainRNNLayer, cell_class=BNLSTMCell),
        'chain-custom': lambda: CustomRNNModel(
            input_size, hidden_size, num_layer=num_layer,
            layer_class=ChainRNNLayer, cell_class=CustomLSTMCell),
        'torch-cell': lambda: TorchLSTMCellModel(input_size, hidden_size, num_layer=num_layer),
        'torch-lstm': lambda: TorchLSTMModel(input_size, hidden_size, num_layer=num_layer),
        'torch-gru': lambda: TorchGRUModel(input_size, hidden_size, num_layer=num_layer),
    }
    builder = builders.get(model)
    if builder is None:
        print(f'Error : Unknown model {model!r} (expected one of: {", ".join(builders)})',
              file=sys.stderr)
        sys.exit(1)
    return builder().to(device)
def _parse_arguments():
    """Parse the recorder experiment's command-line options from ``sys.argv``."""
    parser = ArgumentParser()
    parser.add_argument('--output', type=Path, default=Path('output', 'recorder'), help='Output dir')
    parser.add_argument('--model', default='torch-lstm', help='Model to train')
    parser.add_argument('--batch', type=int, default=32, help='Batch size')
    parser.add_argument('--sequence', type=int, default=8, help='Max sequence length')
    parser.add_argument('--dimension', type=int, default=15, help='Input dimension')
    parser.add_argument('--hidden', type=int, default=32, help='Hidden dimension')
    parser.add_argument('--layer', type=int, default=3, help='LSTM layer stack length')
    parser.add_argument('--step', type=int, default=20_000, help='Number of steps to train')
    return parser.parse_args()


def _write_parameter_summary(network: nn.Module, path: Path) -> None:
    """Write one column-aligned ``name shape size`` line per entry of ``parameter_summary(network)``."""
    param_summary = parameter_summary(network)
    # default=0 keeps an empty summary from crashing max().
    name_width = max((len(name) for name, _, _ in param_summary), default=0)
    shape_width = max((len(str(shape)) for _, shape, _ in param_summary), default=0)
    with open(path, 'w') as param_file:
        param_file.write(
            '\n'.join(
                f'{name: <{name_width}} {str(shape): <{shape_width}} {size}'
                for name, shape, size in param_summary))


def _encode_sequences(label: torch.Tensor, input_dim: int) -> torch.Tensor:
    """Build the time-major network input for a ``(batch, length)`` batch of tokens.

    The tokens are one-hot encoded over ``input_dim + 1`` channels with the last
    channel forced to 1, then followed by ``length`` all-zero steps during which
    the network must reproduce the tokens. The result has shape
    ``(2 * length, batch, input_dim + 1)``.
    """
    data = nn.functional.one_hot(label, input_dim + 1).float()
    data[:, :, input_dim] = 1.0
    recall_phase = torch.zeros_like(data)
    return torch.cat([data, recall_phase], dim=1).transpose(1, 0)


def _predict_recall(network: nn.Module, network_input: torch.Tensor, sequence_length: int) -> torch.Tensor:
    """Run ``network`` and keep only the recall-phase logits, minus the last (flag) channel.

    Returns a ``(sequence_length, batch, input_dim)`` tensor (time-major).
    """
    outputs, _state = network(network_input)
    return outputs[sequence_length:, :, :-1]


def main():
    """Train a recurrent network to echo back random token sequences, then print test recalls."""
    arguments = _parse_arguments()
    output_dir: Path = arguments.output
    model: str = arguments.model
    batch_size: int = arguments.batch
    sequence_size: int = arguments.sequence
    input_dim: int = arguments.dimension
    hidden_size: int = arguments.hidden
    num_layer: int = arguments.layer
    max_step: int = arguments.step

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    torch.backends.cudnn.benchmark = True
    network = get_network(model, input_dim, hidden_size, num_layer, device)

    # Only the parent of --output is used; the run folder name encodes the hyper-parameters.
    output_dir = output_dir.parent / f'recorder_{model}_b{batch_size}_s{sequence_size}_h{hidden_size}_l{num_layer}'
    output_dir.mkdir(parents=True, exist_ok=True)
    if (output_dir / 'train').exists():
        shutil.rmtree(output_dir / 'train')
    writer_train = SummaryWriter(log_dir=str(output_dir / 'train'), flush_secs=20)

    data_sample = torch.zeros((2, 4, input_dim + 1)).to(device)
    torch.save(network.state_dict(), output_dir / 'model_ini.pt')
    writer_train.add_graph(network, (data_sample,))
    _write_parameter_summary(network, output_dir / 'parameters.csv')

    optimizer = torch.optim.Adam(network.parameters(), lr=1e-3, weight_decay=1e-6)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.996)
    criterion = nn.CrossEntropyLoss()
    metrics = Metrics()
    # At least 1: with --step < 100 the integer division would give 0 and `step % 0` would raise.
    summary_period = max(1, max_step // 100)
    # Clamped so that --sequence < 4 still draws lengths within [min, sequence_size].
    min_sequence_size = min(4, sequence_size)
    np.set_printoptions(precision=2)

    try:
        start_time = time.time()
        for step in range(1, max_step + 1):
            step_start_time = time.time()
            sequence_length = int(np.random.uniform(min_sequence_size, sequence_size + 1))
            label_np = generate_data(batch_size, sequence_length, input_dim)
            data_gen_time = time.time()

            label = torch.from_numpy(label_np).to(device)
            network_input = _encode_sequences(label, input_dim)
            data_process_time = time.time()

            optimizer.zero_grad()
            outputs = _predict_recall(network, network_input, sequence_length)
            predict_time = time.time()

            # CrossEntropyLoss accepts (N, C, T) logits with (N, T) targets; every sequence in the
            # batch has the same length, so this equals the mean of the per-sequence losses.
            loss = criterion(outputs.permute(1, 2, 0), label)
            loss_time = time.time()
            loss.backward()
            backprop_time = time.time()
            optimizer.step()
            optim_time = time.time()

            sequence_correct = torch.argmax(outputs, 2).transpose(1, 0) == label
            batch_score, batch_max_score = score_sequences(sequence_correct.detach().cpu().numpy())
            # `loss` is a batch mean while metrics.count counts sequences: scale it back to a
            # sum so that metrics.loss / metrics.count is the per-sequence mean loss.
            metrics.loss += loss.item() * batch_size
            metrics.accuracy += sequence_correct.all(dim=1).float().sum().item()
            metrics.score += batch_score
            if batch_max_score > metrics.max_score:
                metrics.max_score = batch_max_score
            metrics.count += batch_size
            metrics.bench.data_gen += data_gen_time - step_start_time
            metrics.bench.data_process += data_process_time - data_gen_time
            metrics.bench.predict += predict_time - data_process_time
            metrics.bench.loss += loss_time - predict_time
            metrics.bench.backprop += backprop_time - loss_time
            metrics.bench.optimizer += optim_time - backprop_time
            metrics.bench.metrics += time.time() - optim_time

            if step % summary_period == 0:
                writer_train.add_scalar('metric/loss', metrics.loss / metrics.count, global_step=step)
                writer_train.add_scalar('metric/error', 1 - (metrics.accuracy / metrics.count), global_step=step)
                writer_train.add_scalar('metric/score', metrics.score / metrics.count, global_step=step)
                writer_train.add_scalar('metric/max_score', metrics.max_score, global_step=step)
                writer_train.add_scalar('optimizer/lr', scheduler.get_last_lr()[0], global_step=step)
                # The learning rate decays once per summary period, not once per step.
                scheduler.step()
                train_time = time.time() - start_time
                print(f'Step {step}, loss: {metrics.loss / metrics.count:.03e}'
                      f', acc: {metrics.accuracy / metrics.count:.03e}'
                      f', score: {metrics.score / metrics.count:.03f}'
                      f', speed: {summary_period / train_time:0.3f}step/s'
                      f' => {metrics.count / train_time:.02f}input/s'
                      f'\n ({metrics.bench.get_proportions(train_time)})')
                start_time = time.time()
                metrics.reset()
    except KeyboardInterrupt:
        # Ctrl-C ends training early; evaluation below still runs.
        print('\r ', end='\r')
    finally:
        writer_train.close()

    network.eval()
    test_label = [
        np.asarray([[0, 0, 0, 0]], dtype=np.int64),
        np.asarray([[2, 2, 2, 2]], dtype=np.int64),
        np.asarray([[1, 2, 3, 4]], dtype=np.int64),
        np.asarray([[8, 1, 10, 5, 6, 13]], dtype=np.int64),
        generate_data(1, 4, input_dim),
        np.asarray([[0, 0, 0, 0, 0, 0]], dtype=np.int64),
        np.asarray([[5, 5, 5, 5, 5, 5]], dtype=np.int64),
        np.asarray([[11, 0, 2, 8, 5, 1]], dtype=np.int64),
        generate_data(1, max(4, sequence_size // 4), input_dim),
        generate_data(1, max(4, sequence_size // 2), input_dim),
        generate_data(1, max(4, sequence_size * 3 // 4), input_dim),
        generate_data(1, max(4, sequence_size * 3 // 4), input_dim),
        generate_data(1, max(4, sequence_size * 3 // 4), input_dim),
        generate_data(1, max(4, sequence_size * 3 // 4), input_dim),
        generate_data(1, sequence_size, input_dim),
        generate_data(1, sequence_size, input_dim),
        generate_data(1, sequence_size, input_dim),
        generate_data(1, sequence_size, input_dim)
    ]
    # The hard-coded cases use token values up to 13; skip any that fall outside the
    # vocabulary of a model trained with a smaller --dimension.
    test_label = [label_np for label_np in test_label if label_np.max() < input_dim]

    metrics.reset()
    with torch.no_grad():
        for label_np in test_label:
            sequence_length = label_np.shape[1]
            label = torch.from_numpy(label_np).to(device)
            outputs = _predict_recall(network, _encode_sequences(label, input_dim), sequence_length)
            predicted = torch.argmax(outputs, 2)
            sequence_correct = predicted.transpose(1, 0) == label
            current_score, best_prefix = score_sequences(sequence_correct.cpu().numpy())
            metrics.accuracy += sequence_correct.all(dim=1).float().mean().item()
            metrics.score += current_score
            if best_prefix > metrics.max_score:
                metrics.max_score = best_prefix
            metrics.count += 1
            print(f'score: {current_score}/{sequence_length}, label: {label_np}'
                  f', output: {predicted[:, 0].cpu().numpy()}')
    if metrics.count:
        print(f'Test accuracy: {metrics.accuracy / metrics.count:.03f}')
        print(f'Test score: {metrics.score / metrics.count:.03f}')


if __name__ == '__main__':
    main()