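"""Train a small 1D convolutional autoencoder on random one-hot sequences.

Random integer label sequences are one-hot encoded, compressed by four
stride-2 convolutions, reconstructed by four stride-2 transposed convolutions,
and scored with cross-entropy against the original labels. Training and
validation metrics go to TensorBoard under the chosen output directory.

Example invocation (script name illustrative; values shown are the defaults):
    python sequence_autoencoder.py --batch 128 --sequence 128 --dimension 16 --step 80000
"""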
from argparse import ArgumentParser
from pathlib import Path
import shutil
import time

import numpy as np
import torch
from torch import nn
from torch.utils.tensorboard import SummaryWriter

from src.metrics import Metrics
from src.torch_utils.layers import Conv1d, Deconv1d
from src.torch_utils.train import parameter_summary

def generate_data(batch_size: int, sequence_length: int, dim: int) -> np.ndarray:
    """Draw random integer class labels in [0, dim), shape (batch_size, sequence_length)."""
    # Uniform floats in [0, dim) truncated to int64, i.e. equivalent to randint.
    return np.random.uniform(0, dim, (batch_size, sequence_length)).astype(np.int64)

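# The training loop below feeds these labels to the network after one-hot
# encoding, e.g. for 2 sequences of length 8 over 16 classes:
#   labels = torch.from_numpy(generate_data(2, 8, 16))             # (2, 8) int64
#   x = nn.functional.one_hot(labels, 16).float().transpose(2, 1)  # (2, 16, 8)
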
class SequenceAutoEncoder(nn.Module):
    """Fully convolutional autoencoder over one-hot encoded sequences."""

    def __init__(self, input_dim: int):
        super().__init__()
        self.auto_encoder = nn.Sequential(
            # Encoder: four stride-2 convolutions shrink the sequence 16x.
            Conv1d(input_dim, 16, kernel_size=2, stride=2),
            Conv1d(16, 16, kernel_size=2, stride=2),
            Conv1d(16, 16, kernel_size=2, stride=2),
            Conv1d(16, 16, kernel_size=2, stride=2),
            # Decoder: four stride-2 transposed convolutions restore the length.
            Deconv1d(16, 16, kernel_size=2, stride=2),
            Deconv1d(16, 16, kernel_size=2, stride=2),
            Deconv1d(16, 16, kernel_size=2, stride=2),
            # Last layer outputs raw logits: no batch norm, no activation.
            Deconv1d(16, input_dim, kernel_size=2, stride=2, use_batch_norm=False, activation=None))

    def forward(self, input_data: torch.Tensor) -> torch.Tensor:
        return self.auto_encoder(input_data)

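# Shape sketch (assumes the project's Conv1d/Deconv1d wrappers keep the plain
# nn.Conv1d/nn.ConvTranspose1d stride-2 halving/doubling, so the sequence
# length must be a multiple of 16 to be reconstructed exactly):
#   x = torch.zeros(2, 16, 128)        # (batch, input_dim, sequence)
#   SequenceAutoEncoder(16)(x).shape   # -> torch.Size([2, 16, 128])
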
def main():
    parser = ArgumentParser()
    parser.add_argument('--output', type=Path, default=Path('output', 'recorder'), help='Output dir')
    parser.add_argument('--batch', type=int, default=128, help='Batch size')
    parser.add_argument('--sequence', type=int, default=128, help='Max sequence length')
    parser.add_argument('--dimension', type=int, default=16, help='Input dimension')
    # parser.add_argument('--latent', type=int, default=16, help='Latent space size')
    parser.add_argument('--step', type=int, default=80_000, help='Number of steps to train')
    arguments = parser.parse_args()

    output_dir: Path = arguments.output
    batch_size: int = arguments.batch
    sequence_size: int = arguments.sequence
    input_dim: int = arguments.dimension
    # latent_size: int = arguments.latent
    max_step: int = arguments.step

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    # Fixed input sizes, so let cuDNN benchmark and pick the fastest kernels.
    torch.backends.cudnn.benchmark = True

    network = SequenceAutoEncoder(input_dim).to(device)

    output_dir = output_dir / f'sequence_encoder_b{batch_size}_s{sequence_size}'
    output_dir.mkdir(parents=True, exist_ok=True)
    # Start each run with a fresh train log directory.
    if (output_dir / 'train').exists():
        shutil.rmtree(output_dir / 'train')
    writer_train = SummaryWriter(log_dir=str(output_dir / 'train'), flush_secs=20)
    writer_val = SummaryWriter(log_dir=str(output_dir / 'val'), flush_secs=20)
    data_sample = torch.zeros((2, input_dim, sequence_size)).to(device)

    torch.save(network.state_dict(), output_dir / 'model_ini.pt')
    writer_train.add_graph(network, (data_sample,))

    # Save a parameter summary (aligned name / shape / size columns).
    with open(output_dir / 'parameters.csv', 'w') as param_file:
        param_summary = parameter_summary(network)
        names = [len(name) for name, _, _ in param_summary]
        shapes = [len(str(shape)) for _, shape, _ in param_summary]
        param_file.write(
            '\n'.join(
                [f'{name: <{max(names)}} {str(shape): <{max(shapes)}} {size}'
                 for name, shape, size in param_summary]))

    optimizer = torch.optim.Adam(network.parameters(), lr=1e-3, weight_decay=1e-6)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.994)
    criterion = nn.CrossEntropyLoss()
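    # nn.CrossEntropyLoss accepts (N, C, L) logits against (N, L) integer
    # labels, so the decoder output is scored per sequence position directly.
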
    def val_summary(step):
        """Evaluate on fresh random batches and log the metrics to TensorBoard."""
        network.eval()
        val_metrics = Metrics()
        with torch.no_grad():
            for _ in range(10):
                label_np = generate_data(batch_size, sequence_size, input_dim)
                label = torch.from_numpy(label_np).to(device)
                data = nn.functional.one_hot(label, input_dim).float().transpose(2, 1)

                outputs = network(data)

                # A sequence counts as correct only if every position is right.
                val_metrics.accuracy += ((torch.argmax(outputs, 1) == label).sum(1) == sequence_size).sum().item()
                # Score is the per-position accuracy, averaged over the sequence.
                val_metrics.score += (torch.argmax(outputs, 1) == label).float().mean(1).sum().item()
                val_metrics.count += batch_size
        writer_val.add_scalar('metric/error', 1 - (val_metrics.accuracy / val_metrics.count), global_step=step)
        writer_val.add_scalar('metric/score_inv', 1 - (val_metrics.score / val_metrics.count), global_step=step)
        network.train()

    metrics = Metrics()
    summary_period = max_step // 100
    np.set_printoptions(precision=2)
    try:
        start_time = time.time()
        for step in range(1, max_step + 1):
            label_np = generate_data(batch_size, sequence_size, input_dim)

            label = torch.from_numpy(label_np).to(device)
            data = nn.functional.one_hot(label, input_dim).float().transpose(2, 1)

            optimizer.zero_grad()

            outputs = network(data)

            loss = criterion(outputs, label)
            loss.backward()

            metrics.loss += loss.item()
            metrics.accuracy += ((torch.argmax(outputs, 1) == label).sum(1) == sequence_size).sum().item()
            metrics.score += (torch.argmax(outputs, 1) == label).float().mean(1).sum().item()
            metrics.count += batch_size

            optimizer.step()

            if step % summary_period == 0:
                writer_train.add_scalar('metric/loss', metrics.loss / summary_period, global_step=step)
                writer_train.add_scalar('metric/error', 1 - (metrics.accuracy / metrics.count), global_step=step)
                writer_train.add_scalar('metric/score_inv', 1 - (metrics.score / metrics.count), global_step=step)
                writer_train.add_scalar('optimizer/lr', scheduler.get_last_lr()[0], global_step=step)
                # Decay the learning rate once per summary period, not per step.
                scheduler.step()

                train_time = time.time() - start_time
                print(f'Step {step}, loss: {metrics.loss / summary_period:.03e}'
                      f', acc: {metrics.accuracy / metrics.count:.03e}'
                      f', score_inv: {1 - (metrics.score / metrics.count):.03e}'
                      f', speed: {summary_period / train_time:0.3f}step/s'
                      f' => {metrics.count / train_time:.02f}input/s')
                val_summary(step)
                start_time = time.time()
                metrics.reset()
    except KeyboardInterrupt:
        # Allow Ctrl-C to stop training early and fall through to the test pass.
        print('\r ', end='\r')
    writer_train.close()
    writer_val.close()

    # Final qualitative check on a handful of single-sequence batches.
    network.eval()
    test_label = [generate_data(1, sequence_size, input_dim) for _ in range(10)]
    metrics.reset()
    with torch.no_grad():
        for label_np in test_label:
            label = torch.from_numpy(label_np).to(device)
            data = nn.functional.one_hot(label, input_dim).float().transpose(2, 1)

            outputs = network(data)
            # sequence_correct = torch.argmax(outputs, 2).transpose(1, 0) == label
            # metrics.accuracy += (torch.sum(sequence_correct.long(), 1) == label.size(1)).float().mean().item()
            metrics.accuracy += ((torch.argmax(outputs, 1) == label).sum(1) == sequence_size).float().sum().item()
            metrics.score += (torch.argmax(outputs, 1) == label).float().mean(1).sum().item()
            metrics.count += 1
            print(f'label: {label_np}\noutput: {torch.argmax(outputs, 1).cpu().numpy()}')
    print(f'\nTest accuracy: {metrics.accuracy / metrics.count:.05f}')
    print(f'Test score: {metrics.score / metrics.count:.05f}')


if __name__ == '__main__':
    main()