# time-series/sequence_encoder.py (171 lines, 7.3 KiB, Python)

from argparse import ArgumentParser
from pathlib import Path
import shutil
import sys
import time
import numpy as np
import torch
from torch import nn
from torch.utils.tensorboard import SummaryWriter
from src.metrics import Metrics
from src.torch_utils.layers import Conv1d, Deconv1d
from src.torch_utils.train import parameter_summary
def generate_data(batch_size: int, sequence_length: int, dim: int) -> np.ndarray:
    """Draw random integer class labels for a batch of sequences.

    Args:
        batch_size: Number of sequences (rows of the result).
        sequence_length: Number of positions per sequence (columns of the result).
        dim: Number of distinct classes; labels are drawn uniformly from ``[0, dim)``.

    Returns:
        ``int64`` array of shape ``(batch_size, sequence_length)``.
    """
    # randint's upper bound is exclusive by contract. Truncating
    # uniform(0, dim) floats is not safe: NumPy documents that `high` may be
    # returned due to float rounding, which would give label == dim, an
    # out-of-range index for one_hot / CrossEntropyLoss.
    return np.random.randint(0, dim, size=(batch_size, sequence_length), dtype=np.int64)
class SequenceAutoEncoder(nn.Module):
    """Stack of four stride-2 ``Conv1d`` layers followed by four stride-2 ``Deconv1d`` layers.

    Every hidden layer has 16 channels. The final ``Deconv1d`` maps back to
    ``input_dim`` channels with ``use_batch_norm=False`` and ``activation=None``.
    Its output is fed to ``nn.CrossEntropyLoss`` in ``main``.
    NOTE(review): presumably each stride-2 layer halves or doubles the sequence length,
    so the length should be divisible by 16. Confirm against the
    ``src.torch_utils.layers`` wrappers.
    """

    def __init__(self, input_dim: int):
        super().__init__()
        hidden = 16
        # Encoder: first layer lifts input_dim -> hidden, next three keep hidden width.
        encoder = [Conv1d(input_dim, hidden, kernel_size=2, stride=2)]
        encoder += [Conv1d(hidden, hidden, kernel_size=2, stride=2) for _ in range(3)]
        # Decoder: three hidden->hidden layers, then a plain projection back to input_dim.
        decoder = [Deconv1d(hidden, hidden, kernel_size=2, stride=2) for _ in range(3)]
        decoder.append(Deconv1d(hidden, input_dim, kernel_size=2, stride=2,
                                use_batch_norm=False, activation=None))
        self.auto_encoder = nn.Sequential(*encoder, *decoder)

    def forward(self, input_data: torch.Tensor) -> torch.Tensor:
        """Run ``input_data`` through the encoder then the decoder and return the result."""
        reconstruction = self.auto_encoder(input_data)
        return reconstruction
def _to_one_hot(label_np: np.ndarray, input_dim: int, device: torch.device) -> tuple:
    """Move integer labels to ``device`` and build the matching network input.

    Returns:
        ``(label, data)``: ``label`` is ``label_np`` as a tensor on ``device``.
        ``data`` is its float one-hot encoding with the class axis moved to dim 1,
        i.e. ``(batch, input_dim, seq)`` for a ``(batch, seq)`` label array.
    """
    label = torch.from_numpy(label_np).to(device)
    data = nn.functional.one_hot(label, input_dim).float().transpose(2, 1)
    return label, data


def _prediction_stats(outputs: torch.Tensor, label: torch.Tensor) -> tuple:
    """Compare the arg-max over the class axis (dim 1) of ``outputs`` with ``label``.

    Returns:
        ``(exact, score)``:
        - ``exact`` is the number of sequences in the batch predicted correctly at
          every position.
        - ``score`` is the per-sequence fraction of correct positions, summed over
          the batch.
    """
    correct = torch.argmax(outputs, 1) == label
    exact = (correct.sum(1) == label.size(1)).sum().item()
    score = correct.float().mean(1).sum().item()
    return exact, score


def _write_parameter_summary(network: nn.Module, path: Path) -> None:
    """Write one column-aligned ``name shape size`` line per ``parameter_summary`` entry."""
    param_summary = parameter_summary(network)
    # Widths computed once. default=0 lets an empty summary write an empty file
    # instead of raising on max([]).
    name_width = max((len(name) for name, _, _ in param_summary), default=0)
    shape_width = max((len(str(shape)) for _, shape, _ in param_summary), default=0)
    with open(path, 'w') as param_file:
        param_file.write('\n'.join(
            f'{name: <{name_width}} {str(shape): <{shape_width}} {size}'
            for name, shape, size in param_summary))


def main():
    """Train ``SequenceAutoEncoder`` to reproduce random one-hot sequences.

    Steps:
    - Parse CLI options.
    - Save the initial weights and a parameter table under the output dir.
    - Train with Adam, an exponential LR decay and cross-entropy loss.
    - Log train/val metrics to TensorBoard about 100 times per run.
    - Print predictions and accuracy for 10 random test sequences.

    Ctrl-C stops training early and still runs the final test.
    """
    parser = ArgumentParser()
    parser.add_argument('--output', type=Path, default=Path('output', 'recorder'), help='Output dir')
    parser.add_argument('--batch', type=int, default=128, help='Batch size')
    parser.add_argument('--sequence', type=int, default=128, help='Max sequence length')
    parser.add_argument('--dimension', type=int, default=16, help='Input dimension')
    parser.add_argument('--step', type=int, default=80_000, help='Number of steps to train')
    arguments = parser.parse_args()
    output_dir: Path = arguments.output
    batch_size: int = arguments.batch
    sequence_size: int = arguments.sequence
    input_dim: int = arguments.dimension
    max_step: int = arguments.step

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    torch.backends.cudnn.benchmark = True

    network = SequenceAutoEncoder(input_dim).to(device)

    output_dir = output_dir / f'sequence_encoder_b{batch_size}_s{sequence_size}'
    output_dir.mkdir(parents=True, exist_ok=True)
    # Clear BOTH log dirs so curves from an earlier run in the same directory
    # do not mix with this run's train or val curves.
    for split in ('train', 'val'):
        if (output_dir / split).exists():
            shutil.rmtree(output_dir / split)
    writer_train = SummaryWriter(log_dir=str(output_dir / 'train'), flush_secs=20)
    writer_val = SummaryWriter(log_dir=str(output_dir / 'val'), flush_secs=20)

    # Batch of 2 for tracing the graph (a batch of 1 can break batch norm in train mode).
    data_sample = torch.zeros((2, input_dim, sequence_size)).to(device)
    torch.save(network.state_dict(), output_dir / 'model_ini.pt')
    writer_train.add_graph(network, (data_sample,))
    _write_parameter_summary(network, output_dir / 'parameters.csv')

    optimizer = torch.optim.Adam(network.parameters(), lr=1e-3, weight_decay=1e-6)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.994)
    criterion = nn.CrossEntropyLoss()

    def val_summary(step):
        """Log error / inverse score on 10 fresh random batches to the val writer."""
        network.eval()
        val_metrics = Metrics()
        # Evaluation only: skip building autograd graphs.
        with torch.no_grad():
            for _ in range(10):
                label, data = _to_one_hot(
                    generate_data(batch_size, sequence_size, input_dim), input_dim, device)
                exact, score = _prediction_stats(network(data), label)
                val_metrics.accuracy += exact
                val_metrics.score += score
                val_metrics.count += batch_size
        writer_val.add_scalar('metric/error', 1 - (val_metrics.accuracy / val_metrics.count), global_step=step)
        writer_val.add_scalar('metric/score_inv', 1 - (val_metrics.score / val_metrics.count), global_step=step)
        network.train()

    metrics = Metrics()
    # Aim for ~100 summaries per run. Clamp to 1 so short runs (--step < 100)
    # do not evaluate `step % 0`.
    summary_period = max(1, max_step // 100)
    np.set_printoptions(precision=2)
    try:
        start_time = time.time()
        for step in range(1, max_step + 1):
            label, data = _to_one_hot(
                generate_data(batch_size, sequence_size, input_dim), input_dim, device)
            optimizer.zero_grad()
            outputs = network(data)
            loss = criterion(outputs, label)
            loss.backward()
            metrics.loss += loss.item()
            exact, score = _prediction_stats(outputs, label)
            metrics.accuracy += exact
            metrics.score += score
            metrics.count += batch_size
            optimizer.step()
            if step % summary_period == 0:
                writer_train.add_scalar('metric/loss', metrics.loss / summary_period, global_step=step)
                writer_train.add_scalar('metric/error', 1 - (metrics.accuracy / metrics.count), global_step=step)
                writer_train.add_scalar('metric/score_inv', 1 - (metrics.score / metrics.count), global_step=step)
                writer_train.add_scalar('optimizer/lr', scheduler.get_last_lr()[0], global_step=step)
                scheduler.step()
                train_time = time.time() - start_time
                print(f'Step {step}, loss: {metrics.loss / summary_period:.03e}'
                      f', acc: {metrics.accuracy / metrics.count:.03e}'
                      f', score_inv: {1 - (metrics.score / metrics.count):.03e}'
                      f', speed: {summary_period / train_time:0.3f}step/s'
                      f' => {metrics.count / train_time:.02f}input/s')
                val_summary(step)
                start_time = time.time()
                metrics.reset()
    except KeyboardInterrupt:
        # Ctrl-C ends training early; clear the "^C" echo and fall through to the test.
        print('\r ', end='\r')
    finally:
        # Always flush/close TensorBoard writers, even if training raised.
        writer_train.close()
        writer_val.close()

    network.eval()
    test_label = [generate_data(1, sequence_size, input_dim) for _ in range(10)]
    metrics.reset()
    with torch.no_grad():
        for label_np in test_label:
            label, data = _to_one_hot(label_np, input_dim, device)
            outputs = network(data)
            exact, score = _prediction_stats(outputs, label)
            metrics.accuracy += exact
            metrics.score += score
            metrics.count += 1
            print(f'label: {label_np}\noutput: {torch.argmax(outputs, 1).cpu().numpy()}')
    print(f'\nTest accuracy: {metrics.accuracy / metrics.count:.05f}')
    print(f'Test score: {metrics.score / metrics.count:.05f}')
# Script entry point: run training only when executed directly, not when imported.
if __name__ == '__main__':
    main()