# time-series/recorder.py
from argparse import ArgumentParser
from pathlib import Path
import shutil
import time

import numpy as np
import torch
from torch import nn
from torch.utils.tensorboard import SummaryWriter

from src.torch_networks import LSTMModel, StackedLSTMModel
from src.torch_utils.train import parameter_summary
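

# Synthetic copy-task data: batches of random integer symbols in [0, dim),
# drawn as uniform floats and truncated to int64.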
def generate_data(batch_size: int, sequence_length: int, dim: int) -> np.ndarray:
    return np.random.uniform(0, dim, (batch_size, sequence_length)).astype(np.int64)


def main():
    parser = ArgumentParser()
    parser.add_argument('--output', type=Path, default=Path('output', 'recorder'), help='Output dir')
    parser.add_argument('--batch', type=int, default=32, help='Batch size')
    parser.add_argument('--sequence', type=int, default=8, help='Max sequence length')
    parser.add_argument('--dimension', type=int, default=15, help='Input dimension')
    parser.add_argument('--step', type=int, default=20_000, help='Number of steps to train')
    parser.add_argument('--model', help="Model to train ('cell' selects StackedLSTMModel, otherwise LSTMModel)")
    arguments = parser.parse_args()

    output_dir: Path = arguments.output
    batch_size: int = arguments.batch
    sequence_size: int = arguments.sequence
    input_dim: int = arguments.dimension
    max_step: int = arguments.step
    model: str = arguments.model
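
    # Fresh TensorBoard run: create the output directory and wipe any old 'train' logs.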
    if not output_dir.exists():
        output_dir.mkdir(parents=True)
    if (output_dir / 'train').exists():
        shutil.rmtree(output_dir / 'train')
    writer_train = SummaryWriter(log_dir=str(output_dir / 'train'), flush_secs=20)

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
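
    # The networks take input_dim + 1 channels: one-hot symbols plus an extra
    # channel (index input_dim) that is used below as a phase marker.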
    if model == 'cell':
        network = StackedLSTMModel(input_dim + 1).to(device)
    else:
        network = LSTMModel(input_dim + 1).to(device)

    # Save a fixed-width table of parameter names, shapes and element counts.
    with open(output_dir / 'parameters.csv', 'w') as param_file:
        param_summary = parameter_summary(network)
        name_widths = [len(name) for name, _, _ in param_summary]
        shape_widths = [len(str(shape)) for _, shape, _ in param_summary]
        param_file.write('\n'.join(
            f'{name: <{max(name_widths)}} {str(shape): <{max(shape_widths)}} {size}'
            for name, shape, size in param_summary))
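
    # Adam with exponential LR decay; the scheduler is stepped once per summary period below.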
    optimizer = torch.optim.Adam(network.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.996)
    criterion = nn.CrossEntropyLoss()

    # Zero block appended after each label sequence as the "answer" phase input.
    zero_data = torch.zeros((batch_size, sequence_size, input_dim + 1)).to(device)
    # writer_train.add_graph(network, (zero_data[:, :5],))
    if model == 'cell':
        # One (h, c) pair of zeros per stacked layer; hidden size is (input_dim + 1) * 2.
        state = [(torch.zeros((batch_size, (input_dim + 1) * 2)).to(device),
                  torch.zeros((batch_size, (input_dim + 1) * 2)).to(device))] * network.NUM_LAYERS
    else:
        state = None

    running_loss = 0.0
    running_accuracy = 0.0
    running_count = 0
    summary_period = max_step // 100
    np.set_printoptions(precision=2)
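
    # Training loop: each step feeds a random sequence with the marker channel set,
    # then an equal-length zero block; the loss compares the outputs during the zero
    # block against the input sequence (a sequence-copy task).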
    try:
        start_time = time.time()
        for step in range(1, max_step + 1):
            # Random sequence length in [4, sequence_size].
            label_np = generate_data(batch_size, int(np.random.uniform(4, sequence_size + 1)), input_dim)
            label = torch.from_numpy(label_np).to(device)
            data = nn.functional.one_hot(label, input_dim + 1).float()
            data[:, :, input_dim] = 1.0  # Mark every input step in the extra channel.

            optimizer.zero_grad()
            # Input is (sequence, batch, features): the label block followed by zeros.
            outputs, _state = network(torch.cat([data, zero_data[:, :label_np.shape[1]]], dim=1).transpose(1, 0),
                                      state)
            # Keep only the answer-phase steps and drop the marker class.
            outputs = outputs[label_np.shape[1]:, :, :-1]
            # state = (state[0].detach(), state[1].detach())
            # state = (state[0][:, :1].detach().expand(state[0].shape[0], batch_size, *state[0].shape[2:]),
            #          state[1][:, :1].detach().expand(state[1].shape[0], batch_size, *state[1].shape[2:]))
            data[:, :, input_dim] = 0.0

            # Per-sample cross-entropy over the answer steps, averaged over the batch.
            loss = criterion(outputs[:, 0], label[0])
            for i in range(1, batch_size):
                loss += criterion(outputs[:, i], label[i])
            loss /= batch_size
            running_loss += loss.item()
            # Exact-match accuracy: a sample counts only if every step is predicted correctly.
            running_accuracy += (
                torch.sum((torch.argmax(outputs, 2).transpose(1, 0) == label).long(),
                          1) == label.size(1)).float().mean().item()
            running_count += 1

            if step % summary_period == 0:
                writer_train.add_scalar('metric/loss', running_loss / running_count, global_step=step)
                writer_train.add_scalar('metric/error', 1 - (running_accuracy / running_count), global_step=step)
                writer_train.add_scalar('optimizer/lr', scheduler.get_last_lr()[0], global_step=step)
                # Decay the learning rate once per summary period.
                scheduler.step()
                speed = summary_period / (time.time() - start_time)
                print(f'Step {step}, loss: {running_loss / running_count:.03e}'
                      f', acc: {running_accuracy / running_count:.03e}, speed: {speed:0.3f} steps/s')
                start_time = time.time()
                running_loss = 0.0
                running_accuracy = 0.0
                running_count = 0

            loss.backward()
            optimizer.step()
    except KeyboardInterrupt:
        print('\r ', end='\r')  # Overwrite the '^C' echoed on the terminal line.
    writer_train.close()
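
    # Quick qualitative check on a few fixed patterns plus two random sequences.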
    network.eval()
    test_label = [
        np.asarray([[0, 0, 0, 0]], dtype=np.int64),
        np.asarray([[2, 2, 2, 2]], dtype=np.int64),
        np.asarray([[8, 1, 10, 5, 6, 13]], dtype=np.int64),
        generate_data(1, 4, input_dim),
        np.asarray([[0, 0, 0, 0, 0, 0]], dtype=np.int64),
        np.asarray([[5, 5, 5, 5, 5, 5]], dtype=np.int64),
        np.asarray([[11, 0, 2, 8, 5, 1]], dtype=np.int64),
        generate_data(1, sequence_size, input_dim),
    ]
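
    # Evaluation runs with batch size 1, so the zero block and the recurrent state
    # are rebuilt to match.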
    zero_data = torch.zeros((1, sequence_size, input_dim + 1)).to(device)
    if model == 'cell':
        state = [(torch.zeros((1, (input_dim + 1) * 2)).to(device),
                  torch.zeros((1, (input_dim + 1) * 2)).to(device))] * network.NUM_LAYERS

    running_accuracy = 0.0
    running_count = 0
    for label_np in test_label:
        label = torch.from_numpy(label_np).to(device)
        data = nn.functional.one_hot(label, input_dim + 1).float()
        data[:, :, input_dim] = 1.0
        outputs, _state = network(torch.cat([data, zero_data[:, :label_np.shape[1]]], dim=1).transpose(1, 0),
                                  state)
        outputs = outputs[label_np.shape[1]:, :, :-1]
        # state = (state[0].detach(), state[1].detach())
        running_accuracy += (
            torch.sum(
                (torch.argmax(outputs, 2).transpose(1, 0) == label).long(), 1) == label.size(1)).float().mean().item()
        running_count += 1
        # Print the sequence length, the target sequence and the network's prediction.
        print(f'{label_np.shape[1]} label: {label_np}, output: {torch.argmax(outputs, 2)[:, 0].detach().cpu().numpy()}')
    print(f'Test accuracy: {running_accuracy / running_count:.03f}')


if __name__ == '__main__':
    main()