Add tensorflow implementation and Encoder experiment
parent 30f55fa967 · commit db52231fa0
6 changed files with 471 additions and 72 deletions
171 sequence_encoder.py Normal file
@@ -0,0 +1,171 @@
from argparse import ArgumentParser
from pathlib import Path
import shutil
import sys
import time

import numpy as np
import torch
from torch import nn
from torch.utils.tensorboard import SummaryWriter

from src.metrics import Metrics
from src.torch_utils.layers import Conv1d, Deconv1d
from src.torch_utils.train import parameter_summary


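# Synthetic data: each sample is a sequence of integers drawn uniformly from
# [0, dim) -- truncating uniform floats to int64 is effectively equivalent to
# np.random.randint here. There is no dataset; batches are generated on the fly.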
def generate_data(batch_size: int, sequence_length: int, dim: int) -> np.ndarray:
    return np.random.uniform(0, dim, (batch_size, sequence_length)).astype(np.int64)


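# Fully convolutional autoencoder over one-hot encoded sequences.
# Four stride-2 Conv1d layers compress the temporal axis by 16x (e.g. 128 -> 8)
# and four stride-2 Deconv1d layers expand it back; the last layer has no batch
# norm or activation so it outputs raw logits. Conv1d / Deconv1d are project
# wrappers from src.torch_utils.layers, assumed to bundle the convolution with
# batch norm and an activation; assuming no padding, the sequence length must
# be divisible by 16 for the encoder/decoder shapes to line up.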
class SequenceAutoEncoder(nn.Module):
    def __init__(self, input_dim: int):
        super().__init__()
        self.auto_encoder = nn.Sequential(
            Conv1d(input_dim, 16, kernel_size=2, stride=2),
            Conv1d(16, 16, kernel_size=2, stride=2),
            Conv1d(16, 16, kernel_size=2, stride=2),
            Conv1d(16, 16, kernel_size=2, stride=2),
            Deconv1d(16, 16, kernel_size=2, stride=2),
            Deconv1d(16, 16, kernel_size=2, stride=2),
            Deconv1d(16, 16, kernel_size=2, stride=2),
            Deconv1d(16, input_dim, kernel_size=2, stride=2, use_batch_norm=False, activation=None))

    def forward(self, input_data: torch.Tensor) -> torch.Tensor:
        return self.auto_encoder(input_data)


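# Train the autoencoder on randomly generated sequences and track how often
# entire sequences are reconstructed exactly.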
def main():
    parser = ArgumentParser()
    parser.add_argument('--output', type=Path, default=Path('output', 'recorder'), help='Output dir')
    parser.add_argument('--batch', type=int, default=128, help='Batch size')
    parser.add_argument('--sequence', type=int, default=128, help='Max sequence length')
    parser.add_argument('--dimension', type=int, default=16, help='Input dimension')
    # parser.add_argument('--latent', type=int, default=16, help='Latent space size')
    parser.add_argument('--step', type=int, default=80_000, help='Number of steps to train')
    arguments = parser.parse_args()

    output_dir: Path = arguments.output
    batch_size: int = arguments.batch
    sequence_size: int = arguments.sequence
    input_dim: int = arguments.dimension
    # latent_size: int = arguments.latent
    max_step: int = arguments.step

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    torch.backends.cudnn.benchmark = True

    network = SequenceAutoEncoder(input_dim).to(device)

    output_dir = output_dir / f'sequence_encoder_b{batch_size}_s{sequence_size}'
    if not output_dir.exists():
        output_dir.mkdir(parents=True)
    if (output_dir / 'train').exists():
        shutil.rmtree(output_dir / 'train')
    writer_train = SummaryWriter(log_dir=str(output_dir / 'train'), flush_secs=20)
    writer_val = SummaryWriter(log_dir=str(output_dir / 'val'), flush_secs=20)
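    # Dummy input for the graph trace: (batch, one-hot channels, sequence length),
    # the channels-first layout expected by Conv1d.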
    data_sample = torch.zeros((2, input_dim, sequence_size)).to(device)

    torch.save(network.state_dict(), output_dir / 'model_ini.pt')
    writer_train.add_graph(network, (data_sample,))

    # Save parameters info
    with open(output_dir / 'parameters.csv', 'w') as param_file:
        param_summary = parameter_summary(network)
        names = [len(name) for name, _, _ in param_summary]
        shapes = [len(str(shape)) for _, shape, _ in param_summary]
        param_file.write(
            '\n'.join(
                [f'{name: <{max(names)}} {str(shape): <{max(shapes)}} {size}'
                 for name, shape, size in param_summary]))

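    # CrossEntropyLoss accepts logits of shape (batch, classes, seq) with integer
    # targets of shape (batch, seq), which matches the network output directly.
    # The scheduler is stepped once per summary period (100 times in total), so
    # the learning rate decays to roughly 1e-3 * 0.994**100 ~= 5.5e-4 by the end.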
    optimizer = torch.optim.Adam(network.parameters(), lr=1e-3, weight_decay=1e-6)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.994)
    criterion = nn.CrossEntropyLoss()

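    # Validation on 10 freshly generated batches: logs the exact-sequence error
    # rate and the inverted per-token score so that both curves decrease.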
    def val_summary(step):
        nonlocal network

        network.eval()
        val_metrics = Metrics()
        with torch.no_grad():  # Evaluation only: skip autograd bookkeeping
            for _ in range(10):
                label_np = generate_data(batch_size, sequence_size, input_dim)
                label = torch.from_numpy(label_np).to(device)
                data = nn.functional.one_hot(label, input_dim).float().transpose(2, 1)

                outputs = network(data)

                val_metrics.accuracy += ((torch.argmax(outputs, 1) == label).sum(1) == sequence_size).sum().item()
                val_metrics.score += (torch.argmax(outputs, 1) == label).float().mean(1).sum().item()
                val_metrics.count += batch_size
        writer_val.add_scalar('metric/error', 1 - (val_metrics.accuracy / val_metrics.count), global_step=step)
        writer_val.add_scalar('metric/score_inv', 1 - (val_metrics.score / val_metrics.count), global_step=step)
        network.train()

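    # Training loop: a fresh random batch is drawn every step, so the model
    # effectively never sees the same data twice and cannot memorise a dataset.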
    metrics = Metrics()
    summary_period = max_step // 100
    np.set_printoptions(precision=2)
    try:
        start_time = time.time()
        for step in range(1, max_step + 1):
            label_np = generate_data(batch_size, sequence_size, input_dim)

            label = torch.from_numpy(label_np).to(device)
            data = nn.functional.one_hot(label, input_dim).float().transpose(2, 1)

            optimizer.zero_grad()

            outputs = network(data)

            loss = criterion(outputs, label)
            loss.backward()

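            # 'accuracy' counts sequences reconstructed perfectly at every
            # position; 'score' is the mean per-token accuracy, a softer signal.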
            metrics.loss += loss.item()
            metrics.accuracy += ((torch.argmax(outputs, 1) == label).sum(1) == sequence_size).sum().item()
            metrics.score += (torch.argmax(outputs, 1) == label).float().mean(1).sum().item()
            metrics.count += batch_size

            optimizer.step()

            if step % summary_period == 0:
                writer_train.add_scalar('metric/loss', metrics.loss / summary_period, global_step=step)
                writer_train.add_scalar('metric/error', 1 - (metrics.accuracy / metrics.count), global_step=step)
                writer_train.add_scalar('metric/score_inv', 1 - (metrics.score / metrics.count), global_step=step)
                writer_train.add_scalar('optimizer/lr', scheduler.get_last_lr()[0], global_step=step)
                scheduler.step()

                train_time = time.time() - start_time
                print(f'Step {step}, loss: {metrics.loss / summary_period:.03e}'
                      f', acc: {metrics.accuracy / metrics.count:.03e}'
                      f', score_inv: {1 - (metrics.score / metrics.count):.03e}'
                      f', speed: {summary_period / train_time:0.3f}step/s'
                      f' => {metrics.count / train_time:.02f}input/s')
                val_summary(step)
                start_time = time.time()
                metrics.reset()
    except KeyboardInterrupt:
        print('\r ', end='\r')  # Erase the echoed ^C before the final report
    writer_train.close()
    writer_val.close()

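    # Final qualitative check: run a few single-sequence batches through the
    # trained model and compare each input with its argmax reconstruction.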
    network.eval()
    test_label = [generate_data(1, sequence_size, input_dim) for _ in range(10)]
    metrics.reset()
    with torch.no_grad():  # Inference only, no gradients needed
        for label_np in test_label:
            label = torch.from_numpy(label_np).to(device)
            data = nn.functional.one_hot(label, input_dim).float().transpose(2, 1)

            outputs = network(data)
            # sequence_correct = torch.argmax(outputs, 2).transpose(1, 0) == label

            # metrics.accuracy += (torch.sum(sequence_correct.long(), 1) == label.size(1)).float().mean().item()
            metrics.accuracy += ((torch.argmax(outputs, 1) == label).sum(1) == sequence_size).float().sum().item()
            metrics.score += (torch.argmax(outputs, 1) == label).float().mean(1).sum().item()
            metrics.count += 1
            print(f'label: {label_np}\noutput: {torch.argmax(outputs, 1).detach().cpu().numpy()}')
    print(f'\nTest accuracy: {metrics.accuracy / metrics.count:.05f}')
    print(f'Test score: {metrics.score / metrics.count:.05f}')


if __name__ == '__main__':
    main()