Fix networks

Corentin 2021-08-30 23:21:58 +09:00
commit 1704b7aad1
4 changed files with 343 additions and 145 deletions

(file 1 of 4: .gitignore, vendored)

@@ -1,3 +1,4 @@
 *.pyc
 output
+save

(file 2 of 4: training script for the modulo task; filename not preserved in this view)

@@ -2,6 +2,7 @@ from argparse import ArgumentParser
 from pathlib import Path
 import math
 import shutil
+import sys
 import time
 
 import numpy as np
@@ -9,7 +10,7 @@ import torch
 from torch import nn
 from torch.utils.tensorboard import SummaryWriter
 
-from src.torch_networks import LSTMModel, LSTMCellModel, StackedLSTMModel
+from src.torch_networks import TorchLSTMModel, TorchLSTMCellModel, CustomLSTMModel, ChainLSTMLayer
 from src.torch_utils.utils.batch_generator import BatchGenerator
 from src.torch_utils.train import parameter_summary
@@ -50,18 +51,21 @@ class DataGenerator:
 def main():
     parser = ArgumentParser()
     parser.add_argument('--output', type=Path, default=Path('output', 'modulo'), help='Output dir')
+    parser.add_argument('--model', default='torch-lstm', help='Model to train')
     parser.add_argument('--batch', type=int, default=32, help='Batch size')
     parser.add_argument('--sequence', type=int, default=12, help='Max sequence length')
+    parser.add_argument('--hidden', type=int, default=16, help='LSTM cells hidden size')
     parser.add_argument('--step', type=int, default=2000, help='Number of steps to train')
-    parser.add_argument('--model', help='Model to train')
     arguments = parser.parse_args()
 
     output_dir: Path = arguments.output
+    model: str = arguments.model
     batch_size: int = arguments.batch
     sequence_size: int = arguments.sequence
+    hidden_size: int = arguments.hidden
     max_step: int = arguments.step
-    model: str = arguments.model
+    output_dir = output_dir.parent / f'modulo_{model}_b{batch_size}_s{sequence_size}_h{hidden_size}'
     if not output_dir.exists():
         output_dir.mkdir(parents=True)
     if (output_dir / 'train').exists():
@@ -71,15 +75,24 @@ def main():
     device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
     torch.backends.cudnn.benchmark = True
     if model == 'stack':
-        network = StackedLSTMModel(1, 16, 2).to(device)
-    elif model == 'cell':
-        network = LSTMCellModel(1, 16, 2).to(device)
+        network = CustomLSTMModel(1, hidden_size, 2).to(device)
+    elif model == 'stack-torchcell':
+        network = CustomLSTMModel(1, hidden_size, 2, cell_class=nn.LSTMCell).to(device)
+    elif model == 'chain':
+        network = CustomLSTMModel(1, hidden_size, 2, layer_class=ChainLSTMLayer).to(device)
+    elif model == 'torch-cell':
+        network = TorchLSTMCellModel(1, hidden_size, 2).to(device)
+    elif model == 'torch-lstm':
+        network = TorchLSTMModel(1, hidden_size, 2).to(device)
     else:
-        network = LSTMModel(1, 16, 2).to(device)
+        print('Error: Unknown model')
+        sys.exit(1)
     torch.save(network.state_dict(), output_dir / 'model_ini.pt')
+    input_sample = torch.from_numpy(generate_data(2, 4)[0]).to(device)
+    writer_train.add_graph(network, (input_sample,))
 
     # Save parameters info
-    with open(output_dir / 'parameters.csv', 'w') as param_file:
+    with open(output_dir / 'parameters.csv', 'w', encoding='utf-8') as param_file:
         param_summary = parameter_summary(network)
         names = [len(name) for name, _, _ in param_summary]
         shapes = [len(str(shape)) for _, shape, _ in param_summary]
@@ -88,7 +101,7 @@ def main():
             [f'{name: <{max(names)}} {str(shape): <{max(shapes)}} {size}'
              for name, shape, size in param_summary]))
 
-    optimizer = torch.optim.Adam(network.parameters(), lr=1e-3)
+    optimizer = torch.optim.Adam(network.parameters(), lr=1e-3, weight_decay=1e-4)
     scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.995)
     criterion = nn.CrossEntropyLoss()
@@ -99,17 +112,10 @@ def main():
         (batch_size, max_step)).transpose((1, 0)), (batch_size * max_step))
     dummy_label = np.zeros((batch_size * max_step), dtype=np.uint8)
     DataGenerator.MAX_LENGTH = sequence_size
-    if model in ['cell', 'stack']:
-        state = [(torch.zeros((batch_size, 16)).to(device),
-                  torch.zeros((batch_size, 16)).to(device))] * network.NUM_LAYERS
-    else:
-        state = None
     with BatchGenerator(sequence_data_reshaped, dummy_label, batch_size=batch_size,
                         pipeline=DataGenerator.pipeline, num_workers=8, shuffle=False) as batch_generator:
-        # data_np, _ = generate_data(batch_size, int(np.random.uniform(4, sequence_size + 1)))
         data_np = batch_generator.batch_data
         label_np = batch_generator.batch_label
-        # writer_train.add_graph(network, (torch.from_numpy(data_np).to(device),))
 
         running_loss = 0.0
         running_accuracy = 0.0
@@ -126,7 +132,7 @@ def main():
             optimizer.zero_grad(set_to_none=True)
-            outputs, _states = network(data, state)
+            outputs, _states = network(data)
             loss = criterion(outputs[-1], label)
             running_loss += loss.item()
             outputs_np = outputs[-1].detach().cpu().numpy()
@@ -164,7 +170,7 @@ def main():
             data = torch.from_numpy(data_np).to(device)
             label = torch.from_numpy(label_np).to(device)
-            outputs, _states = network(data, state)
+            outputs, _states = network(data)
             outputs_np = outputs[-1].detach().cpu().numpy()
             running_accuracy += ((outputs_np[:, 1] > outputs_np[:, 0]).astype(np.int32) == label_np).astype(
                 np.float32).mean()
@@ -189,13 +195,9 @@ def main():
     test_label = np.asarray([1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0], dtype=np.int32)
     running_accuracy = 0.0
     running_count = 0
-    if model in ['cell', 'stack']:
-        state = [(torch.zeros((1, 16)).to(device),
-                  torch.zeros((1, 16)).to(device))] * network.NUM_LAYERS
     for data, label in zip(test_data, test_label):
         outputs, _states = network(
-            torch.from_numpy(np.expand_dims(np.asarray(data, dtype=np.float32), 1)).to(device),
-            state)
+            torch.from_numpy(np.expand_dims(np.asarray(data, dtype=np.float32), 1)).to(device))
         outputs_np = outputs[-1].detach().cpu().numpy()
         output_correct = int(outputs_np[0, 1] > outputs_np[0, 0]) == label
         running_accuracy += 1.0 if output_correct else 0.0
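
With these flags the model variant and the hidden size are chosen on the command line instead of being hard-coded to 16 units, and the output directory name now encodes the run configuration. A hypothetical invocation (the script's filename is not preserved in this view; train_modulo.py is a stand-in):

    python train_modulo.py --model chain --hidden 32 --batch 64 --sequence 12 --step 2000

Valid --model values after this commit are stack, stack-torchcell, chain, torch-cell and torch-lstm; anything else prints an error and exits with status 1.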

(file 3 of 4: training script for the recorder task; filename not preserved in this view)

@@ -1,6 +1,7 @@
 from argparse import ArgumentParser
 from pathlib import Path
 import shutil
+import sys
 import time
 
 import numpy as np
@@ -8,7 +9,7 @@ import torch
 from torch import nn
 from torch.utils.tensorboard import SummaryWriter
 
-from src.torch_networks import LSTMModel, StackedLSTMModel
+from src.torch_networks import TorchLSTMModel, TorchLSTMCellModel, CustomLSTMModel, ChainLSTMLayer, BNLSTMCell
 from src.torch_utils.train import parameter_summary
@@ -16,34 +17,74 @@ def generate_data(batch_size: int, sequence_length: int, dim: int) -> np.ndarray
     return np.random.uniform(0, dim, (batch_size, sequence_length)).astype(np.int64)
 
 
+def score_sequences(sequences: np.ndarray) -> tuple[float, float]:
+    score = 0.0
+    max_score = 0.0
+    for sequence_data in sequences:
+        sequence_score = 0.0
+        for result in sequence_data:
+            if not result:
+                break
+            sequence_score += 1.0
+        if sequence_score > max_score:
+            max_score = sequence_score
+        score += sequence_score
+    return score, max_score
+
+
 def main():
     parser = ArgumentParser()
     parser.add_argument('--output', type=Path, default=Path('output', 'recorder'), help='Output dir')
+    parser.add_argument('--model', default='torch-lstm', help='Model to train')
     parser.add_argument('--batch', type=int, default=32, help='Batch size')
     parser.add_argument('--sequence', type=int, default=8, help='Max sequence length')
     parser.add_argument('--dimension', type=int, default=15, help='Input dimension')
+    parser.add_argument('--hidden', type=int, default=32, help='Hidden dimension')
+    parser.add_argument('--lstm', type=int, default=3, help='LSTM layer stack length')
     parser.add_argument('--step', type=int, default=20_000, help='Number of steps to train')
-    parser.add_argument('--model', help='Model to train')
     arguments = parser.parse_args()
 
     output_dir: Path = arguments.output
+    model: str = arguments.model
     batch_size: int = arguments.batch
     sequence_size: int = arguments.sequence
     input_dim: int = arguments.dimension
+    hidden_size: int = arguments.hidden
+    num_layer: int = arguments.lstm
     max_step: int = arguments.step
-    model: str = arguments.model
 
+    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+    torch.backends.cudnn.benchmark = True
+    if model == 'stack':
+        network = CustomLSTMModel(input_dim + 1, hidden_size, num_layer=num_layer).to(device)
+    elif model == 'stack-bn':
+        network = CustomLSTMModel(input_dim + 1, hidden_size, num_layer=num_layer, cell_class=BNLSTMCell).to(device)
+    elif model == 'stack-torchcell':
+        network = CustomLSTMModel(input_dim + 1, hidden_size, num_layer=num_layer, cell_class=nn.LSTMCell).to(device)
+    elif model == 'chain':
+        network = CustomLSTMModel(input_dim + 1, hidden_size, num_layer=num_layer,
+                                  layer_class=ChainLSTMLayer).to(device)
+    elif model == 'chain-bn':
+        network = CustomLSTMModel(input_dim + 1, hidden_size, num_layer=num_layer,
+                                  layer_class=ChainLSTMLayer, cell_class=BNLSTMCell).to(device)
+    elif model == 'torch-cell':
+        network = TorchLSTMCellModel(input_dim + 1, hidden_size, num_layer=num_layer).to(device)
+    elif model == 'torch-lstm':
+        network = TorchLSTMModel(input_dim + 1, hidden_size, num_layer=num_layer).to(device)
+    else:
+        print('Error: Unknown model')
+        sys.exit(1)
+
+    output_dir = output_dir.parent / f'recorder_{model}_b{batch_size}_s{sequence_size}_h{hidden_size}_l{num_layer}'
     if not output_dir.exists():
         output_dir.mkdir(parents=True)
     if (output_dir / 'train').exists():
         shutil.rmtree(output_dir / 'train')
     writer_train = SummaryWriter(log_dir=str(output_dir / 'train'), flush_secs=20)
-
-    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
-    if model == 'cell':
-        network = StackedLSTMModel(input_dim + 1).to(device)
-    else:
-        network = LSTMModel(input_dim + 1).to(device)
+    data_sample = torch.zeros((2, 4, input_dim + 1)).to(device)
+    torch.save(network.state_dict(), output_dir / 'model_ini.pt')
+    writer_train.add_graph(network, (data_sample,))
 
     # Save parameters info
     with open(output_dir / 'parameters.csv', 'w') as param_file:
@@ -55,64 +96,129 @@ def main():
             [f'{name: <{max(names)}} {str(shape): <{max(shapes)}} {size}'
              for name, shape, size in param_summary]))
 
-    optimizer = torch.optim.Adam(network.parameters(), lr=1e-3)
+    optimizer = torch.optim.Adam(network.parameters(), lr=1e-3, weight_decay=1e-4)
     scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.996)
     criterion = nn.CrossEntropyLoss()
 
-    zero_data = torch.zeros((batch_size, sequence_size, input_dim + 1)).to(device)
-    # writer_train.add_graph(network, (zero_data[:, :5],))
-    if model == 'cell':
-        state = [(torch.zeros((batch_size, (input_dim + 1) * 2)).to(device),
-                  torch.zeros((batch_size, (input_dim + 1) * 2)).to(device))] * network.NUM_LAYERS
-    else:
-        state = None
-    running_loss = 0.0
-    running_accuracy = 0.0
-    running_count = 0
+    class Metrics:
+        class Bench:
+            def __init__(self):
+                self.data_gen = 0.0
+                self.data_process = 0.0
+                self.predict = 0.0
+                self.loss = 0.0
+                self.backprop = 0.0
+                self.optimizer = 0.0
+                self.metrics = 0.0
+
+            def reset(self):
+                self.data_gen = 0.0
+                self.data_process = 0.0
+                self.predict = 0.0
+                self.loss = 0.0
+                self.backprop = 0.0
+                self.optimizer = 0.0
+                self.metrics = 0.0
+
+            def get_proportions(self, train_time: float) -> str:
+                return (
+                    f'data_gen: {self.data_gen / train_time:.02f}'
+                    f', data_process: {self.data_process / train_time:.02f}'
+                    f', predict: {self.predict / train_time:.02f}'
+                    f', loss: {self.loss / train_time:.02f}'
+                    f', backprop: {self.backprop / train_time:.02f}'
+                    f', optimizer: {self.optimizer / train_time:.02f}'
+                    f', metrics: {self.metrics / train_time:.02f}')
+
+        def __init__(self):
+            self.loss = 0.0
+            self.accuracy = 0.0
+            self.score = 0.0
+            self.max_score = 0.0
+            self.count = 0
+            self.bench = self.Bench()
+
+        def reset(self):
+            self.loss = 0.0
+            self.accuracy = 0.0
+            self.score = 0.0
+            self.max_score = 0.0
+            self.count = 0
+            self.bench.reset()
+
+    metrics = Metrics()
     summary_period = max_step // 100
+    min_sequence_size = 4
+    zero_data = torch.zeros((batch_size, sequence_size, input_dim + 1)).to(device)
     np.set_printoptions(precision=2)
     try:
         start_time = time.time()
         for step in range(1, max_step + 1):
-            label_np = generate_data(batch_size, int(np.random.uniform(4, sequence_size + 1)), input_dim)
+            step_start_time = time.time()
+            label_np = generate_data(
+                batch_size, int(np.random.uniform(min_sequence_size, sequence_size + 1)), input_dim)
+            data_gen_time = time.time()
             label = torch.from_numpy(label_np).to(device)
             data = nn.functional.one_hot(label, input_dim + 1).float()
             data[:, :, input_dim] = 1.0
+            data_process_time = time.time()
 
             optimizer.zero_grad()
-            outputs, _state = network(torch.cat([data, zero_data[:, :label_np.shape[1]]], dim=1).transpose(1, 0), state)
+            outputs, _state = network(torch.cat([data, zero_data[:, :label_np.shape[1]]], dim=1).transpose(1, 0))
             outputs = outputs[label_np.shape[1]:, :, :-1]
             # state = (state[0].detach(), state[1].detach())
             # state = (state[0][:, :1].detach().expand(state[0].shape[0], batch_size, *state[0].shape[2:]),
             #          state[1][:, :1].detach().expand(state[1].shape[0], batch_size, *state[1].shape[2:]))
+            predict_time = time.time()
 
+            data[:, :, input_dim] = 0.0
             loss = criterion(outputs[:, 0], label[0])
             for i in range(1, batch_size):
                 loss += criterion(outputs[:, i], label[i])
             loss /= batch_size
-            running_loss += loss.item()
-            running_accuracy += (
-                torch.sum((torch.argmax(outputs, 2).transpose(1, 0) == label).long(),
-                          1) == label.size(1)).float().mean().item()
-            running_count += 1
+            loss_time = time.time()
+
+            loss.backward()
+            backprop_time = time.time()
+            optimizer.step()
+            optim_time = time.time()
+
+            sequence_correct = torch.argmax(outputs, 2).transpose(1, 0) == label
+            average_score, max_score = score_sequences(sequence_correct.detach().cpu().numpy())
+            # if current_score > min_sequence_size * 2 and min_sequence_size < sequence_size - 1:
+            #     min_sequence_size += 1
+            metrics.loss += loss.item()
+            metrics.accuracy += (torch.sum(sequence_correct.long(), 1) == label.size(1)).float().sum().item()
+            metrics.score += average_score
+            if max_score > metrics.max_score:
+                metrics.max_score = max_score
+            metrics.count += batch_size
+            metrics.bench.data_gen += data_gen_time - step_start_time
+            metrics.bench.data_process += data_process_time - data_gen_time
+            metrics.bench.predict += predict_time - data_process_time
+            metrics.bench.loss += loss_time - predict_time
+            metrics.bench.backprop += backprop_time - loss_time
+            metrics.bench.optimizer += optim_time - backprop_time
+            metrics.bench.metrics += time.time() - optim_time
 
             if step % summary_period == 0:
-                writer_train.add_scalar('metric/loss', running_loss / running_count, global_step=step)
-                writer_train.add_scalar('metric/error', 1 - (running_accuracy / running_count), global_step=step)
+                writer_train.add_scalar('metric/loss', metrics.loss / metrics.count, global_step=step)
+                writer_train.add_scalar('metric/error', 1 - (metrics.accuracy / metrics.count), global_step=step)
+                writer_train.add_scalar('metric/score', metrics.score / metrics.count, global_step=step)
+                writer_train.add_scalar('metric/max_score', metrics.max_score, global_step=step)
                 writer_train.add_scalar('optimizer/lr', scheduler.get_last_lr()[0], global_step=step)
                 scheduler.step()
-                speed = summary_period / (time.time() - start_time)
-                print(f'Step {step}, loss: {running_loss / running_count:.03e}'
-                      f', acc: {running_accuracy / running_count:.03e}, speed: {speed:0.3f}step/s')
+                train_time = time.time() - start_time
+                print(f'Step {step}, loss: {metrics.loss / metrics.count:.03e}'
+                      f', acc: {metrics.accuracy / metrics.count:.03e}'
+                      f', score: {metrics.score / metrics.count:.03f}'
+                      f', speed: {summary_period / train_time:0.3f}step/s'
+                      f' => {metrics.count / train_time:.02f}input/s'
+                      f'\n    ({metrics.bench.get_proportions(train_time)})')
                 start_time = time.time()
-                running_loss = 0.0
-                running_accuracy = 0.0
-                running_count = 0
-
-            loss.backward()
-            optimizer.step()
+                metrics.reset()
     except KeyboardInterrupt:
         print('\r ', end='\r')
     writer_train.close()
@@ -121,34 +227,44 @@ def main():
     test_label = [
         np.asarray([[0, 0, 0, 0]], dtype=np.int64),
         np.asarray([[2, 2, 2, 2]], dtype=np.int64),
+        np.asarray([[1, 2, 3, 4]], dtype=np.int64),
         np.asarray([[8, 1, 10, 5, 6, 13]], dtype=np.int64),
         generate_data(1, 4, input_dim),
         np.asarray([[0, 0, 0, 0, 0, 0]], dtype=np.int64),
         np.asarray([[5, 5, 5, 5, 5, 5]], dtype=np.int64),
         np.asarray([[11, 0, 2, 8, 5, 1]], dtype=np.int64),
+        generate_data(1, max(4, sequence_size // 4), input_dim),
+        generate_data(1, max(4, sequence_size // 2), input_dim),
+        generate_data(1, max(4, sequence_size * 3 // 4), input_dim),
+        generate_data(1, max(4, sequence_size * 3 // 4), input_dim),
+        generate_data(1, max(4, sequence_size * 3 // 4), input_dim),
+        generate_data(1, max(4, sequence_size * 3 // 4), input_dim),
+        generate_data(1, sequence_size, input_dim),
+        generate_data(1, sequence_size, input_dim),
+        generate_data(1, sequence_size, input_dim),
        generate_data(1, sequence_size, input_dim)
     ]
     zero_data = torch.zeros((1, sequence_size, input_dim + 1)).to(device)
-    if model == 'cell':
-        state = [(torch.zeros((1, (input_dim + 1) * 2)).to(device),
-                  torch.zeros((1, (input_dim + 1) * 2)).to(device))] * 3
-    running_accuracy = 0.0
-    running_count = 0
+    metrics.reset()
     for label_np in test_label:
         label = torch.from_numpy(label_np).to(device)
         data = nn.functional.one_hot(label, input_dim + 1).float()
         data[:, :, input_dim] = 1.0
-        outputs, _state = network(torch.cat([data, zero_data[:, :label_np.shape[1]]], dim=1).transpose(1, 0), state)
-        outputs = outputs[label_np.shape[1]:, :, :-1]
+        outputs, _state = network(torch.cat([data, zero_data[:, :label_np.shape[1]]], dim=1).transpose(1, 0))
         # state = (state[0].detach(), state[1].detach())
-        running_accuracy += (
-            torch.sum(
-                (torch.argmax(outputs, 2).transpose(1, 0) == label).long(), 1) == label.size(1)).float().mean().item()
-        running_count += 1
-        print(f'{len(label_np)} label: {label_np}, output: {torch.argmax(outputs, 2)[:, 0].detach().cpu().numpy()}')
-    print(f'Test accuracy: {running_accuracy / running_count:.03f}')
+        outputs = outputs[label_np.shape[1]:, :, :-1]
+        sequence_correct = torch.argmax(outputs, 2).transpose(1, 0) == label
+        current_score, max_score = score_sequences(sequence_correct.detach().cpu().numpy())
+        metrics.accuracy += (torch.sum(sequence_correct.long(), 1) == label.size(1)).float().mean().item()
+        metrics.score += current_score
+        if max_score > metrics.max_score:
+            metrics.max_score = max_score
+        metrics.count += 1
+        print(f'score: {current_score}/{label_np.shape[1]}, label: {label_np}'
+              f', output: {torch.argmax(outputs, 2)[:, 0].detach().cpu().numpy()}')
+    print(f'Test accuracy: {metrics.accuracy / metrics.count:.03f}')
 
 
 if __name__ == '__main__':
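
The new score_sequences metric counts, for each sequence in the batch, how many predictions are correct from the first position up to the first mistake, and returns both the sum over the batch and the best single run. A self-contained restatement on a toy input (illustration only; the committed function operates on the boolean sequence_correct array built in the loops above):

    import numpy as np

    def score_sequences(sequences: np.ndarray) -> tuple[float, float]:
        # Each row scores its run of truthy results from position 0
        # up to the first error; returns (sum over rows, best row).
        score, max_score = 0.0, 0.0
        for row in sequences:
            run = 0.0
            for result in row:
                if not result:
                    break
                run += 1.0
            max_score = max(max_score, run)
            score += run
        return score, max_score

    correct = np.array([
        [True, True, False, True],   # run of 2: third prediction breaks it
        [True, True, True, True],    # run of 4: fully correct
        [False, True, True, True],   # run of 0: wrong from the start
    ])
    print(score_sequences(correct))  # (6.0, 4.0)

In the training loop the per-batch sum is accumulated into metrics.score and divided by metrics.count (the number of inputs seen), so the logged metric/score is the mean run length per sequence.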

(file 4 of 4: the networks module imported as src.torch_networks; filename not preserved in this view)

@@ -1,16 +1,17 @@
+import math
+from typing import Optional
 import torch
 from torch import nn
 
 
-class LSTMModel(nn.Module):
-    NUM_LAYERS = 3
-
-    def __init__(self, input_size: int, hidden_size: int = -1, output_size: int = -1):
+class TorchLSTMModel(nn.Module):
+    def __init__(self, input_size: int, hidden_size: int = -1, output_size: int = -1, num_layer: int = 3):
         super().__init__()
         hidden_size = hidden_size if hidden_size > 0 else input_size * 2
         output_size = output_size if output_size > 0 else input_size
-        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=self.NUM_LAYERS)
+        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layer)
         self.dense = nn.Linear(hidden_size, output_size)
 
     def forward(self, input_data: torch.Tensor, init_state=None) -> tuple[
@@ -19,16 +20,57 @@ class LSTMModel(nn.Module):
         return self.dense(output), state
 
 
-class LSTMCell(torch.jit.ScriptModule):
+class TorchLSTMCellModel(nn.Module):
+    def __init__(self, input_size: int, hidden_size: int = -1, output_size: int = -1, num_layer: int = 3):
+        super().__init__()
+        self.num_layer = num_layer
+        self.hidden_size = hidden_size if hidden_size > 0 else input_size * 2
+        self.output_size = output_size if output_size > 0 else input_size
+        self.layers = nn.ModuleList([
+            nn.LSTMCell(input_size=input_size, hidden_size=self.hidden_size)] + [
+            nn.LSTMCell(input_size=self.hidden_size, hidden_size=self.hidden_size)
+            for _ in range(num_layer - 1)])
+        self.dense = nn.Linear(self.hidden_size, self.output_size)
+
+    def forward(self, input_data: torch.Tensor,
+                init_states: tuple[torch.Tensor, torch.Tensor] = (torch.zeros(1), torch.zeros(1))) -> tuple[
+            torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+        if len(init_states[0].shape) == 1:
+            zeros = torch.zeros(self.num_layer, input_data.size(1), self.hidden_size,
+                                dtype=input_data.dtype, device=input_data.device)
+            init_states = (zeros, zeros)
+        output_h_states = torch.jit.annotate(list[torch.Tensor], [])
+        output_c_states = torch.jit.annotate(list[torch.Tensor], [])
+        cell_inputs = input_data.unbind(0)
+        cell_id = 0
+        for cell in self.layers:
+            cell_h_state = torch.jit.annotate(torch.Tensor, init_states[0][cell_id])
+            cell_c_state = torch.jit.annotate(torch.Tensor, init_states[1][cell_id])
+            cell_outputs = torch.jit.annotate(list[torch.Tensor], [])
+            for i in range(len(cell_inputs)):
+                cell_h_state, cell_c_state = cell(cell_inputs[i], (cell_h_state, cell_c_state))
+                cell_outputs += [cell_h_state]
+            cell_inputs = cell_outputs
+            output_h_states += [cell_h_state]
+            output_c_states += [cell_c_state]
+            cell_id += 1
+        return self.dense(torch.stack(cell_outputs)), (torch.stack(output_h_states), torch.stack(output_c_states))
+
+
+class LSTMCell(nn.Module):
     def __init__(self, input_size, hidden_size):
         super().__init__()
         self.input_size = input_size
         self.hidden_size = hidden_size
-        self.weight_ih = nn.Parameter(torch.randn(input_size, 4 * hidden_size))
-        self.weight_hh = nn.Parameter(torch.randn(hidden_size, 4 * hidden_size))
-        self.bias = nn.Parameter(torch.randn(4 * hidden_size))
+        self.weight_ih = nn.Parameter(torch.empty(input_size, 4 * hidden_size))
+        self.weight_hh = nn.Parameter(torch.empty(hidden_size, 4 * hidden_size))
+        self.bias = nn.Parameter(torch.empty(4 * hidden_size))
+        self.reset_parameters()
 
-    @torch.jit.script_method
     def forward(self, input_data: torch.Tensor, state: tuple[torch.Tensor, torch.Tensor]) -> tuple[
             torch.Tensor, torch.Tensor]:
         hx, cx = state
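
Both Torch-backed models now take an optional (h, c) state: the 1-D sentinel default (torch.zeros(1), torch.zeros(1)) is detected by its shape and replaced with zero states of shape (num_layer, batch, hidden), which is what lets the training scripts above drop their hand-built state lists. A minimal sketch of a forward pass under that convention (sizes are illustrative; the module path matches the imports in the training scripts):

    import torch
    from src.torch_networks import TorchLSTMCellModel

    # Sequence-first layout, as in the training scripts: (seq_len, batch, features).
    seq_len, batch, features = 12, 4, 1
    model = TorchLSTMCellModel(features, hidden_size=16, output_size=2)
    x = torch.randn(seq_len, batch, features)

    # No state passed: the 1-D default triggers zero-state initialisation.
    outputs, (h, c) = model(x)
    print(outputs.shape)     # torch.Size([12, 4, 2]): per-timestep predictions
    print(h.shape, c.shape)  # torch.Size([3, 4, 16]) each: one state per layer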
@@ -46,13 +88,54 @@ class LSTMCell(torch.jit.ScriptModule):
         return (hy, cy)
 
+    def reset_parameters(self) -> None:
+        for weight in [self.weight_hh, self.weight_ih]:
+            nn.init.xavier_normal_(weight)
+        nn.init.zeros_(self.bias)
+
 
-class LSTMLayer(torch.jit.ScriptModule):
+class BNLSTMCell(nn.Module):
     def __init__(self, input_size, hidden_size):
         super().__init__()
-        self.cell = nn.LSTMCell(input_size=input_size, hidden_size=hidden_size)
+        self.input_size = input_size
+        self.hidden_size = hidden_size
+        self.weight_ih = nn.Parameter(torch.empty(input_size, 4 * hidden_size))
+        self.weight_hh = nn.Parameter(torch.empty(hidden_size, 4 * hidden_size))
+        self.bias = nn.Parameter(torch.empty(4 * hidden_size))
+        self.bn_1 = nn.BatchNorm1d(4 * hidden_size)
+        self.bn_2 = nn.BatchNorm1d(4 * hidden_size)
+        self.bn_out = nn.BatchNorm1d(hidden_size)
+        self.reset_parameters()
+
+    def forward(self, input_data: torch.Tensor, state: tuple[torch.Tensor, torch.Tensor]) -> tuple[
+            torch.Tensor, torch.Tensor]:
+        hx, cx = state
+        gates = (
+            self.bn_1(torch.mm(input_data, self.weight_ih)) + self.bn_2(torch.mm(hx, self.weight_hh)) + self.bias)
+        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
+
+        ingate = torch.sigmoid(ingate)
+        forgetgate = torch.sigmoid(forgetgate)
+        cellgate = torch.tanh(cellgate)
+        outgate = torch.sigmoid(outgate)
+
+        cy = (forgetgate * cx) + (ingate * cellgate)
+        hy = outgate * torch.tanh(self.bn_out(cy))
+        return (hy, cy)
+
+    def reset_parameters(self) -> None:
+        for weight in [self.weight_hh, self.weight_ih]:
+            nn.init.xavier_normal_(weight)
+        nn.init.zeros_(self.bias)
+
+
+class StackLSTMLayer(nn.Module):
+    def __init__(self, input_size, hidden_size, cell_class=LSTMCell):
+        super().__init__()
+        self.cell = cell_class(input_size=input_size, hidden_size=hidden_size)
 
-    @torch.jit.script_method
     def forward(self, input_data: torch.Tensor, state: tuple[torch.Tensor, torch.Tensor]) -> tuple[
             torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
         inputs = input_data.unbind(0)
@@ -63,74 +146,70 @@ class LSTMLayer(torch.jit.ScriptModule):
         return torch.stack(outputs), state
 
 
-class StackedLSTM(torch.jit.ScriptModule):
-    def __init__(self, input_size, hidden_size, num_layers):
+class ChainLSTMLayer(nn.Module):
+    def __init__(self, input_size, hidden_size, cell_class=LSTMCell):
         super().__init__()
-        self.layers = nn.ModuleList(
-            [LSTMLayer(input_size=input_size, hidden_size=hidden_size)] + [
-                LSTMLayer(input_size=hidden_size, hidden_size=hidden_size) for _ in range(num_layers - 1)])
+        self.cell = cell_class(input_size=input_size, hidden_size=hidden_size)
 
-    @torch.jit.script_method
-    def forward(self, input_data: torch.Tensor, states: list[tuple[torch.Tensor, torch.Tensor]]) -> tuple[
-            torch.Tensor, list[tuple[torch.Tensor, torch.Tensor]]]:
-        output_states = torch.jit.annotate(list[tuple[torch.Tensor, torch.Tensor]], [])
+    def forward(self, input_data: torch.Tensor, state: tuple[torch.Tensor, torch.Tensor]) -> tuple[
+            torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+        inputs = input_data.unbind(0)
+        outputs = torch.jit.annotate(list[torch.Tensor], [])
+        for i in range(len(inputs)):
+            state = self.cell(inputs[i], state)
+            outputs += [state[1]]
+        return torch.stack(outputs), state
+
+
+class _CustomLSTMLayer(nn.Module):
+    def __init__(self, input_size, hidden_size, num_layers, layer_class=StackLSTMLayer, cell_class=LSTMCell):
+        super().__init__()
+        self.num_layer = num_layers
+        self.layers = nn.ModuleList(
+            [layer_class(input_size=input_size, hidden_size=hidden_size, cell_class=cell_class)] + [
+                layer_class(input_size=hidden_size, hidden_size=hidden_size, cell_class=cell_class)
+                for _ in range(num_layers - 1)])
+
+    def forward(self, input_data: torch.Tensor, states: tuple[torch.Tensor, torch.Tensor]) -> tuple[
+            torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+        output_states = torch.jit.annotate(list[torch.Tensor], [])
+        output_cell_states = torch.jit.annotate(list[torch.Tensor], [])
         output = input_data
         i = 0
         for rnn_layer in self.layers:
-            state = states[i]
-            output, out_state = rnn_layer(output, state)
-            output_states += [out_state]
+            output, out_state = rnn_layer(output, (states[0][i], states[1][i]))
+            output_states += [out_state[0]]
+            output_cell_states += [out_state[1]]
             i += 1
-        return output, output_states
+        return output, (torch.stack(output_states), torch.stack(output_cell_states))
 
 
-class StackedLSTMModel(torch.jit.ScriptModule):
-    NUM_LAYERS = 3
-
-    def __init__(self, input_size: int, hidden_size: int = -1, output_size: int = -1):
+class CustomLSTMModel(nn.Module):
+    def __init__(self, input_size: int, hidden_size: int = -1, output_size: int = -1,
+                 num_layer: int = 3, layer_class=StackLSTMLayer, cell_class=LSTMCell):
         super().__init__()
-        hidden_size = hidden_size if hidden_size > 0 else input_size * 2
-        output_size = output_size if output_size > 0 else input_size
-        self.lstm = StackedLSTM(input_size=input_size, hidden_size=hidden_size, num_layers=self.NUM_LAYERS)
-        self.dense = nn.Linear(hidden_size, output_size)
+        self.hidden_size = hidden_size if hidden_size > 0 else input_size * 2
+        self.output_size = output_size if output_size > 0 else input_size
+        self.lstm = _CustomLSTMLayer(input_size=input_size, hidden_size=self.hidden_size,
+                                     num_layers=num_layer, layer_class=layer_class, cell_class=cell_class)
+        self.dense = nn.Linear(self.hidden_size, self.output_size)
+        self.zero_state = torch.nn.Parameter(
+            torch.zeros(self.lstm.num_layer, self.hidden_size), requires_grad=False)
+        self.zero_cell_state = torch.nn.Parameter(
+            torch.zeros(self.lstm.num_layer, self.hidden_size), requires_grad=False)
 
-    @torch.jit.script_method
-    def forward(self, input_data: torch.Tensor, init_state: list[tuple[torch.Tensor, torch.Tensor]]) -> tuple[
-            torch.Tensor, list[tuple[torch.Tensor, torch.Tensor]]]:
+    def forward(self,
+                input_data: torch.Tensor,
+                init_state: tuple[torch.Tensor, torch.Tensor] = (torch.zeros(1), torch.zeros(1))
+                ) -> tuple[
+            torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+        if len(init_state[0].shape) == 1:
+            init_state = (
+                self.zero_state.expand(
+                    input_data.size(1), self.lstm.num_layer, self.hidden_size).transpose(1, 0),
+                self.zero_cell_state.expand(
                    input_data.size(1), self.lstm.num_layer, self.hidden_size).transpose(1, 0))
         output, state = self.lstm(input_data, init_state)
         return self.dense(output), state
-
-
-class LSTMCellModel(torch.jit.ScriptModule):
-    NUM_LAYERS = 3
-
-    def __init__(self, input_size: int, hidden_size: int = -1, output_size: int = -1):
-        super().__init__()
-        hidden_size = hidden_size if hidden_size > 0 else input_size * 2
-        output_size = output_size if output_size > 0 else input_size
-        self.hidden_size = hidden_size
-        self.layers = nn.ModuleList([
-            nn.LSTMCell(input_size=input_size, hidden_size=hidden_size),
-            nn.LSTMCell(input_size=hidden_size, hidden_size=hidden_size),
-            nn.LSTMCell(input_size=hidden_size, hidden_size=hidden_size)
-        ])
-        self.dense = nn.Linear(hidden_size, output_size)
-
-    @torch.jit.script_method
-    def forward(self, input_data: torch.Tensor, init_states: list[tuple[torch.Tensor, torch.Tensor]]) -> tuple[
-            torch.Tensor, list[tuple[torch.Tensor, torch.Tensor]]]:
-        output_states = torch.jit.annotate(list[tuple[torch.Tensor, torch.Tensor]], [])
-        cell_inputs = input_data.unbind(0)
-        cell_id = 0
-        for cell in self.layers:
-            cell_state = init_states[cell_id]
-            cell_outputs = torch.jit.annotate(list[torch.Tensor], [])
-            for i in range(len(cell_inputs)):
-                cell_state = cell(cell_inputs[i], cell_state)
-                cell_outputs += [cell_state[0]]
-            cell_inputs = cell_outputs
-            output_states += [cell_state]
-            cell_id += 1
-        return self.dense(torch.stack(cell_outputs)), output_states
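
For reference, a short sketch of how the refactored classes compose, mirroring the way the training scripts above instantiate them (a sketch only; sizes are illustrative):

    import torch
    from src.torch_networks import CustomLSTMModel, ChainLSTMLayer, BNLSTMCell

    seq_len, batch, features = 8, 2, 16

    # Stacked hand-written LSTMCell layers (the 'stack' CLI option).
    stack = CustomLSTMModel(features, hidden_size=32, output_size=16, num_layer=3)

    # Same skeleton with the layer and cell implementations swapped out
    # (the 'chain-bn' CLI option in the recorder script).
    chain_bn = CustomLSTMModel(features, hidden_size=32, output_size=16, num_layer=3,
                               layer_class=ChainLSTMLayer, cell_class=BNLSTMCell)

    x = torch.randn(seq_len, batch, features)
    outputs, (h, c) = stack(x)   # zero states are built automatically
    print(outputs.shape)         # torch.Size([8, 2, 16])
    print(h.shape, c.shape)      # torch.Size([3, 2, 32]) each

The zero states are stored as requires_grad=False parameters so they move with the module on .to(device), which is why neither training script needs to build device-specific state tensors any more.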