Fix networks

Corentin 2021-08-30 23:21:58 +09:00
commit 1704b7aad1
4 changed files with 343 additions and 145 deletions

.gitignore

@@ -1,3 +1,4 @@
*.pyc
output
output
save


@@ -2,6 +2,7 @@ from argparse import ArgumentParser
from pathlib import Path
import math
import shutil
import sys
import time
import numpy as np
@@ -9,7 +10,7 @@ import torch
from torch import nn
from torch.utils.tensorboard import SummaryWriter
from src.torch_networks import LSTMModel, LSTMCellModel, StackedLSTMModel
from src.torch_networks import TorchLSTMModel, TorchLSTMCellModel, CustomLSTMModel, ChainLSTMLayer
from src.torch_utils.utils.batch_generator import BatchGenerator
from src.torch_utils.train import parameter_summary
@@ -50,18 +51,21 @@ class DataGenerator:
def main():
parser = ArgumentParser()
parser.add_argument('--output', type=Path, default=Path('output', 'modulo'), help='Output dir')
parser.add_argument('--model', default='torch-lstm', help='Model to train')
parser.add_argument('--batch', type=int, default=32, help='Batch size')
parser.add_argument('--sequence', type=int, default=12, help='Max sequence length')
parser.add_argument('--hidden', type=int, default=16, help='LSTM cells hidden size')
parser.add_argument('--step', type=int, default=2000, help='Number of steps to train')
parser.add_argument('--model', help='Model to train')
arguments = parser.parse_args()
output_dir: Path = arguments.output
model: str = arguments.model
batch_size: int = arguments.batch
sequence_size: int = arguments.sequence
hidden_size: int = arguments.hidden
max_step: int = arguments.step
model: str = arguments.model
output_dir = output_dir.parent / f'modulo_{model}_b{batch_size}_s{sequence_size}_h{hidden_size}'
if not output_dir.exists():
output_dir.mkdir(parents=True)
if (output_dir / 'train').exists():
@@ -71,15 +75,24 @@ def main():
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
torch.backends.cudnn.benchmark = True
if model == 'stack':
network = StackedLSTMModel(1, 16, 2).to(device)
elif model == 'cell':
network = LSTMCellModel(1, 16, 2).to(device)
network = CustomLSTMModel(1, hidden_size, 2).to(device)
elif model == 'stack-torchcell':
network = CustomLSTMModel(1, hidden_size, 2, cell_class=nn.LSTMCell).to(device)
elif model == 'chain':
network = CustomLSTMModel(1, hidden_size, 2, layer_class=ChainLSTMLayer).to(device)
elif model == 'torch-cell':
network = TorchLSTMCellModel(1, hidden_size, 2).to(device)
elif model == 'torch-lstm':
network = TorchLSTMModel(1, hidden_size, 2).to(device)
else:
network = LSTMModel(1, 16, 2).to(device)
print('Error: Unknown model')
sys.exit(1)
torch.save(network.state_dict(), output_dir / 'model_ini.pt')
input_sample = torch.from_numpy(generate_data(2, 4)[0]).to(device)
writer_train.add_graph(network, (input_sample,))
# Save parameters info
with open(output_dir / 'parameters.csv', 'w') as param_file:
with open(output_dir / 'parameters.csv', 'w', encoding='utf-8') as param_file:
param_summary = parameter_summary(network)
names = [len(name) for name, _, _ in param_summary]
shapes = [len(str(shape)) for _, shape, _ in param_summary]
@@ -88,7 +101,7 @@ def main():
[f'{name: <{max(names)}} {str(shape): <{max(shapes)}} {size}'
for name, shape, size in param_summary]))
optimizer = torch.optim.Adam(network.parameters(), lr=1e-3)
optimizer = torch.optim.Adam(network.parameters(), lr=1e-3, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.995)
criterion = nn.CrossEntropyLoss()
@@ -99,17 +112,10 @@ def main():
(batch_size, max_step)).transpose((1, 0)), (batch_size * max_step))
dummy_label = np.zeros((batch_size * max_step), dtype=np.uint8)
DataGenerator.MAX_LENGTH = sequence_size
if model in ['cell', 'stack']:
state = [(torch.zeros((batch_size, 16)).to(device),
torch.zeros((batch_size, 16)).to(device))] * network.NUM_LAYERS
else:
state = None
with BatchGenerator(sequence_data_reshaped, dummy_label, batch_size=batch_size,
pipeline=DataGenerator.pipeline, num_workers=8, shuffle=False) as batch_generator:
# data_np, _ = generate_data(batch_size, int(np.random.uniform(4, sequence_size + 1)))
data_np = batch_generator.batch_data
label_np = batch_generator.batch_label
# writer_train.add_graph(network, (torch.from_numpy(data_np).to(device),))
running_loss = 0.0
running_accuracy = 0.0
@@ -126,7 +132,7 @@ def main():
optimizer.zero_grad(set_to_none=True)
outputs, _states = network(data, state)
outputs, _states = network(data)
loss = criterion(outputs[-1], label)
running_loss += loss.item()
outputs_np = outputs[-1].detach().cpu().numpy()
@@ -164,7 +170,7 @@ def main():
data = torch.from_numpy(data_np).to(device)
label = torch.from_numpy(label_np).to(device)
outputs, _states = network(data, state)
outputs, _states = network(data)
outputs_np = outputs[-1].detach().cpu().numpy()
running_accuracy += ((outputs_np[:, 1] > outputs_np[:, 0]).astype(np.int32) == label_np).astype(
np.float32).mean()
@@ -189,13 +195,9 @@ def main():
test_label = np.asarray([1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0], dtype=np.int32)
running_accuracy = 0.0
running_count = 0
if model in ['cell', 'stack']:
state = [(torch.zeros((1, 16)).to(device),
torch.zeros((1, 16)).to(device))] * network.NUM_LAYERS
for data, label in zip(test_data, test_label):
outputs, _states = network(
torch.from_numpy(np.expand_dims(np.asarray(data, dtype=np.float32), 1)).to(device),
state)
torch.from_numpy(np.expand_dims(np.asarray(data, dtype=np.float32), 1)).to(device))
outputs_np = outputs[-1].detach().cpu().numpy()
output_correct = int(outputs_np[0, 1] > outputs_np[0, 0]) == label
running_accuracy += 1.0 if output_correct else 0.0
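
For reference, a minimal sketch of the new calling convention after this change, assuming the shapes used in the training loop above (sequence-first input, a CustomLSTMModel with hidden size 16): the explicit zero-state plumbing is gone, and the model builds its own zero initial state when none is passed.

import torch
from src.torch_networks import CustomLSTMModel

network = CustomLSTMModel(1, 16, 2)
data = torch.randn(12, 32, 1)         # (sequence, batch, input_size)
outputs, (h, c) = network(data)       # zero initial state is created internally
print(outputs.shape)                  # torch.Size([12, 32, 2]): per-step logits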


@@ -1,6 +1,7 @@
from argparse import ArgumentParser
from pathlib import Path
import shutil
import sys
import time
import numpy as np
@@ -8,7 +9,7 @@ import torch
from torch import nn
from torch.utils.tensorboard import SummaryWriter
from src.torch_networks import LSTMModel, StackedLSTMModel
from src.torch_networks import TorchLSTMModel, TorchLSTMCellModel, CustomLSTMModel, ChainLSTMLayer, BNLSTMCell
from src.torch_utils.train import parameter_summary
@@ -16,34 +17,74 @@ def generate_data(batch_size: int, sequence_length: int, dim: int) -> np.ndarray
return np.random.uniform(0, dim, (batch_size, sequence_length)).astype(np.int64)
def score_sequences(sequences: np.ndarray) -> tuple[float, float]:
score = 0.0
max_score = 0.0
for sequence_data in sequences:
sequence_score = 0.0
for result in sequence_data:
if not result:
break
sequence_score += 1.0
if sequence_score > max_score:
max_score = sequence_score
score += sequence_score
return score, max_score
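
To make the scoring concrete: each row is credited with the number of correct predictions before its first error, and the function returns the batch total together with the best row. A small worked example on hypothetical input (np is the module's numpy import):

correct = np.asarray([[True, True, False, True],
                      [True, True, True, True]])
score, max_score = score_sequences(correct)
# score == 6.0 (2 leading hits + 4 leading hits), max_score == 4.0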
def main():
parser = ArgumentParser()
parser.add_argument('--output', type=Path, default=Path('output', 'recorder'), help='Output dir')
parser.add_argument('--model', default='torch-lstm', help='Model to train')
parser.add_argument('--batch', type=int, default=32, help='Batch size')
parser.add_argument('--sequence', type=int, default=8, help='Max sequence length')
parser.add_argument('--dimension', type=int, default=15, help='Input dimension')
parser.add_argument('--hidden', type=int, default=32, help='Hidden dimension')
parser.add_argument('--lstm', type=int, default=3, help='LSTM layer stack length')
parser.add_argument('--step', type=int, default=20_000, help='Number of steps to train')
parser.add_argument('--model', help='Model to train')
arguments = parser.parse_args()
output_dir: Path = arguments.output
model: str = arguments.model
batch_size: int = arguments.batch
sequence_size: int = arguments.sequence
input_dim: int = arguments.dimension
hidden_size: int = arguments.hidden
num_layer: int = arguments.lstm
max_step: int = arguments.step
model: str = arguments.model
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
torch.backends.cudnn.benchmark = True
if model == 'stack':
network = CustomLSTMModel(input_dim + 1, hidden_size, num_layer=num_layer).to(device)
elif model == 'stack-bn':
network = CustomLSTMModel(input_dim + 1, hidden_size, num_layer=num_layer, cell_class=BNLSTMCell).to(device)
elif model == 'stack-torchcell':
network = CustomLSTMModel(input_dim + 1, hidden_size, num_layer=num_layer, cell_class=nn.LSTMCell).to(device)
elif model == 'chain':
network = CustomLSTMModel(input_dim + 1, hidden_size, num_layer=num_layer,
layer_class=ChainLSTMLayer).to(device)
elif model == 'chain-bn':
network = CustomLSTMModel(input_dim + 1, hidden_size, num_layer=num_layer,
layer_class=ChainLSTMLayer, cell_class=BNLSTMCell).to(device)
elif model == 'torch-cell':
network = TorchLSTMCellModel(input_dim + 1, hidden_size, num_layer=num_layer).to(device)
elif model == 'torch-lstm':
network = TorchLSTMModel(input_dim + 1, hidden_size, num_layer=num_layer).to(device)
else:
print('Error: Unknown model')
sys.exit(1)
output_dir = output_dir.parent / f'recorder_{model}_b{batch_size}_s{sequence_size}_h{hidden_size}_l{num_layer}'
if not output_dir.exists():
output_dir.mkdir(parents=True)
if (output_dir / 'train').exists():
shutil.rmtree(output_dir / 'train')
writer_train = SummaryWriter(log_dir=str(output_dir / 'train'), flush_secs=20)
data_sample = torch.zeros((2, 4, input_dim + 1)).to(device)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
if model == 'cell':
network = StackedLSTMModel(input_dim + 1).to(device)
else:
network = LSTMModel(input_dim + 1).to(device)
torch.save(network.state_dict(), output_dir / 'model_ini.pt')
writer_train.add_graph(network, (data_sample,))
# Save parameters info
with open(output_dir / 'parameters.csv', 'w') as param_file:
@@ -55,64 +96,129 @@ def main():
[f'{name: <{max(names)}} {str(shape): <{max(shapes)}} {size}'
for name, shape, size in param_summary]))
optimizer = torch.optim.Adam(network.parameters(), lr=1e-3)
optimizer = torch.optim.Adam(network.parameters(), lr=1e-3, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.996)
criterion = nn.CrossEntropyLoss()
zero_data = torch.zeros((batch_size, sequence_size, input_dim + 1)).to(device)
# writer_train.add_graph(network, (zero_data[:, :5],))
class Metrics:
class Bench:
def __init__(self):
self.data_gen = 0.0
self.data_process = 0.0
self.predict = 0.0
self.loss = 0.0
self.backprop = 0.0
self.optimizer = 0.0
self.metrics = 0.0
if model == 'cell':
state = [(torch.zeros((batch_size, (input_dim + 1) * 2)).to(device),
torch.zeros((batch_size, (input_dim + 1) * 2)).to(device))] * network.NUM_LAYERS
else:
state = None
running_loss = 0.0
running_accuracy = 0.0
running_count = 0
def reset(self):
self.data_gen = 0.0
self.data_process = 0.0
self.predict = 0.0
self.loss = 0.0
self.backprop = 0.0
self.optimizer = 0.0
self.metrics = 0.0
def get_proportions(self, train_time: float) -> str:
return (
f'data_gen: {self.data_gen / train_time:.02f}'
f', data_process: {self.data_process / train_time:.02f}'
f', predict: {self.predict / train_time:.02f}'
f', loss: {self.loss / train_time:.02f}'
f', backprop: {self.backprop / train_time:.02f}'
f', optimizer: {self.optimizer / train_time:.02f}'
f', metrics: {self.metrics / train_time:.02f}')
def __init__(self):
self.loss = 0.0
self.accuracy = 0.0
self.score = 0.0
self.max_score = 0.0
self.count = 0
self.bench = self.Bench()
def reset(self):
self.loss = 0.0
self.accuracy = 0.0
self.score = 0.0
self.max_score = 0.0
self.count = 0
self.bench.reset()
metrics = Metrics()
summary_period = max_step // 100
min_sequence_size = 4
zero_data = torch.zeros((batch_size, sequence_size, input_dim + 1)).to(device)
np.set_printoptions(precision=2)
try:
start_time = time.time()
for step in range(1, max_step + 1):
label_np = generate_data(batch_size, int(np.random.uniform(4, sequence_size + 1)), input_dim)
step_start_time = time.time()
label_np = generate_data(
batch_size, int(np.random.uniform(min_sequence_size, sequence_size + 1)), input_dim)
data_gen_time = time.time()
label = torch.from_numpy(label_np).to(device)
data = nn.functional.one_hot(label, input_dim + 1).float()
data[:, :, input_dim] = 1.0
data_process_time = time.time()
optimizer.zero_grad()
outputs, _state = network(torch.cat([data, zero_data[:, :label_np.shape[1]]], dim=1).transpose(1, 0), state)
outputs, _state = network(torch.cat([data, zero_data[:, :label_np.shape[1]]], dim=1).transpose(1, 0))
outputs = outputs[label_np.shape[1]:, :, :-1]
# state = (state[0].detach(), state[1].detach())
# state = (state[0][:, :1].detach().expand(state[0].shape[0], batch_size, *state[0].shape[2:]),
# state[1][:, :1].detach().expand(state[1].shape[0], batch_size, *state[1].shape[2:]))
predict_time = time.time()
data[:, :, input_dim] = 0.0
loss = criterion(outputs[:, 0], label[0])
for i in range(1, batch_size):
loss += criterion(outputs[:, i], label[i])
loss /= batch_size
running_loss += loss.item()
running_accuracy += (
torch.sum((torch.argmax(outputs, 2).transpose(1, 0) == label).long(),
1) == label.size(1)).float().mean().item()
running_count += 1
loss_time = time.time()
loss.backward()
backprop_time = time.time()
optimizer.step()
optim_time = time.time()
sequence_correct = torch.argmax(outputs, 2).transpose(1, 0) == label
average_score, max_score = score_sequences(sequence_correct.detach().cpu().numpy())
# if current_score > min_sequence_size * 2 and min_sequence_size < sequence_size - 1:
# min_sequence_size += 1
metrics.loss += loss.item()
metrics.accuracy += (torch.sum(sequence_correct.long(), 1) == label.size(1)).float().sum().item()
metrics.score += average_score
if max_score > metrics.max_score:
metrics.max_score = max_score
metrics.count += batch_size
metrics.bench.data_gen += data_gen_time - step_start_time
metrics.bench.data_process += data_process_time - data_gen_time
metrics.bench.predict += predict_time - data_process_time
metrics.bench.loss += loss_time - predict_time
metrics.bench.backprop += backprop_time - loss_time
metrics.bench.optimizer += optim_time - backprop_time
metrics.bench.metrics += time.time() - optim_time
if step % summary_period == 0:
writer_train.add_scalar('metric/loss', running_loss / running_count, global_step=step)
writer_train.add_scalar('metric/error', 1 - (running_accuracy / running_count), global_step=step)
writer_train.add_scalar('metric/loss', metrics.loss / metrics.count, global_step=step)
writer_train.add_scalar('metric/error', 1 - (metrics.accuracy / metrics.count), global_step=step)
writer_train.add_scalar('metric/score', metrics.score / metrics.count, global_step=step)
writer_train.add_scalar('metric/max_score', metrics.max_score, global_step=step)
writer_train.add_scalar('optimizer/lr', scheduler.get_last_lr()[0], global_step=step)
scheduler.step()
speed = summary_period / (time.time() - start_time)
print(f'Step {step}, loss: {running_loss / running_count:.03e}'
f', acc: {running_accuracy / running_count:.03e}, speed: {speed:0.3f}step/s')
train_time = time.time() - start_time
print(f'Step {step}, loss: {metrics.loss / metrics.count:.03e}'
f', acc: {metrics.accuracy / metrics.count:.03e}'
f', score: {metrics.score / metrics.count:.03f}'
f', speed: {summary_period / train_time:0.3f}step/s'
f' => {metrics.count / train_time:.02f}input/s'
f'\n ({metrics.bench.get_proportions(train_time)})')
start_time = time.time()
running_loss = 0.0
running_accuracy = 0.0
running_count = 0
loss.backward()
optimizer.step()
metrics.reset()
except KeyboardInterrupt:
print('\r ', end='\r')
writer_train.close()
@@ -121,34 +227,44 @@ def main():
test_label = [
np.asarray([[0, 0, 0, 0]], dtype=np.int64),
np.asarray([[2, 2, 2, 2]], dtype=np.int64),
np.asarray([[1, 2, 3, 4]], dtype=np.int64),
np.asarray([[8, 1, 10, 5, 6, 13]], dtype=np.int64),
generate_data(1, 4, input_dim),
np.asarray([[0, 0, 0, 0, 0, 0]], dtype=np.int64),
np.asarray([[5, 5, 5, 5, 5, 5]], dtype=np.int64),
np.asarray([[11, 0, 2, 8, 5, 1]], dtype=np.int64),
generate_data(1, max(4, sequence_size // 4), input_dim),
generate_data(1, max(4, sequence_size // 2), input_dim),
generate_data(1, max(4, sequence_size * 3 // 4), input_dim),
generate_data(1, max(4, sequence_size * 3 // 4), input_dim),
generate_data(1, max(4, sequence_size * 3 // 4), input_dim),
generate_data(1, max(4, sequence_size * 3 // 4), input_dim),
generate_data(1, sequence_size, input_dim),
generate_data(1, sequence_size, input_dim),
generate_data(1, sequence_size, input_dim),
generate_data(1, sequence_size, input_dim)
]
zero_data = torch.zeros((1, sequence_size, input_dim + 1)).to(device)
if model == 'cell':
state = [(torch.zeros((1, (input_dim + 1) * 2)).to(device),
torch.zeros((1, (input_dim + 1) * 2)).to(device))] * 3
running_accuracy = 0.0
running_count = 0
metrics.reset()
for label_np in test_label:
label = torch.from_numpy(label_np).to(device)
data = nn.functional.one_hot(label, input_dim + 1).float()
data[:, :, input_dim] = 1.0
outputs, _state = network(torch.cat([data, zero_data[:, :label_np.shape[1]]], dim=1).transpose(1, 0), state)
outputs = outputs[label_np.shape[1]:, :, :-1]
outputs, _state = network(torch.cat([data, zero_data[:, :label_np.shape[1]]], dim=1).transpose(1, 0))
# state = (state[0].detach(), state[1].detach())
outputs = outputs[label_np.shape[1]:, :, :-1]
sequence_correct = torch.argmax(outputs, 2).transpose(1, 0) == label
current_score, max_score = score_sequences(sequence_correct.detach().cpu().numpy())
running_accuracy += (
torch.sum(
(torch.argmax(outputs, 2).transpose(1, 0) == label).long(), 1) == label.size(1)).float().mean().item()
running_count += 1
print(f'{len(label_np)} label: {label_np}, output: {torch.argmax(outputs, 2)[:, 0].detach().cpu().numpy()}')
print(f'Test accuracy: {running_accuracy / running_count:.03f}')
metrics.accuracy += (torch.sum(sequence_correct.long(), 1) == label.size(1)).float().mean().item()
metrics.score += current_score
if max_score > metrics.max_score:
metrics.max_score = max_score
metrics.count += 1
print(f'score: {current_score}/{label_np.shape[1]}, label: {label_np}'
f', output: {torch.argmax(outputs, 2)[:, 0].detach().cpu().numpy()}')
print(f'Test accuracy: {metrics.accuracy / metrics.count:.03f}')
if __name__ == '__main__':

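The recorder task shows the network the one-hot sequence with the extra "listen" flag channel set to 1, then feeds the same number of all-zero steps during which it must replay the sequence. A minimal sketch of that input construction, assuming input_dim = 3 and a single sequence:

import torch
from torch import nn

input_dim = 3
label = torch.tensor([[2, 0, 1]])                       # (batch=1, length=3)
data = nn.functional.one_hot(label, input_dim + 1).float()
data[:, :, input_dim] = 1.0                             # flag channel marks the listen phase
recall = torch.zeros(1, label.shape[1], input_dim + 1)  # blank steps for the recall phase
x = torch.cat([data, recall], dim=1).transpose(1, 0)    # (time=6, batch=1, input_dim + 1)
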
src/torch_networks.py

@@ -1,16 +1,17 @@
import math
from typing import Optional
import torch
from torch import nn
class LSTMModel(nn.Module):
NUM_LAYERS = 3
def __init__(self, input_size: int, hidden_size: int = -1, output_size: int = -1):
class TorchLSTMModel(nn.Module):
def __init__(self, input_size: int, hidden_size: int = -1, output_size: int = -1, num_layer: int = 3):
super().__init__()
hidden_size = hidden_size if hidden_size > 0 else input_size * 2
output_size = output_size if output_size > 0 else input_size
self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=self.NUM_LAYERS)
self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layer)
self.dense = nn.Linear(hidden_size, output_size)
def forward(self, input_data: torch.Tensor, init_state=None) -> tuple[
@@ -19,16 +20,57 @@ class LSTMModel(nn.Module):
return self.dense(output), state
class LSTMCell(torch.jit.ScriptModule):
class TorchLSTMCellModel(nn.Module):
def __init__(self, input_size: int, hidden_size: int = -1, output_size: int = -1, num_layer: int = 3):
super().__init__()
self.num_layer = num_layer
self.hidden_size = hidden_size if hidden_size > 0 else input_size * 2
self.output_size = output_size if output_size > 0 else input_size
self.layers = nn.ModuleList([
nn.LSTMCell(input_size=input_size, hidden_size=self.hidden_size)] + [
nn.LSTMCell(input_size=self.hidden_size, hidden_size=self.hidden_size) for _ in range(num_layer - 1)]
)
self.dense = nn.Linear(self.hidden_size, self.output_size)
def forward(self, input_data: torch.Tensor,
init_states: tuple[torch.Tensor, torch.Tensor] = (torch.zeros(1), torch.zeros(1))) -> tuple[
torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
if len(init_states[0].shape) == 1:
zeros = torch.zeros(self.num_layer, input_data.size(1), self.hidden_size,
dtype=input_data.dtype, device=input_data.device)
init_states = (zeros, zeros)
output_h_states = torch.jit.annotate(list[torch.Tensor], [])
output_c_states = torch.jit.annotate(list[torch.Tensor], [])
cell_inputs = input_data.unbind(0)
cell_id = 0
for cell in self.layers:
cell_h_state = torch.jit.annotate(torch.Tensor, init_states[0][cell_id])
cell_c_state = torch.jit.annotate(torch.Tensor, init_states[1][cell_id])
cell_outputs = torch.jit.annotate(list[torch.Tensor], [])
for i in range(len(cell_inputs)):
cell_h_state, cell_c_state = cell(cell_inputs[i], (cell_h_state, cell_c_state))
cell_outputs += [cell_h_state]
cell_inputs = cell_outputs
output_h_states += [cell_h_state]
output_c_states += [cell_c_state]
cell_id += 1
return self.dense(torch.stack(cell_outputs)), (torch.stack(output_h_states), torch.stack(output_c_states))
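
The torch.zeros(1) defaults act as a sentinel: a one-dimensional tensor means no initial state was supplied, and the model substitutes a correctly shaped zero state. A usage sketch under that assumption:

model = TorchLSTMCellModel(input_size=16, hidden_size=32, output_size=10)
x = torch.randn(5, 8, 16)       # (time, batch, features)
y, (h, c) = model(x)            # sentinel default -> zero initial state
# y: (5, 8, 10); h and c: (num_layer, batch, hidden_size)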
class LSTMCell(nn.Module):
def __init__(self, input_size, hidden_size):
super().__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.weight_ih = nn.Parameter(torch.randn(input_size, 4 * hidden_size))
self.weight_hh = nn.Parameter(torch.randn(hidden_size, 4 * hidden_size))
self.bias = nn.Parameter(torch.randn(4 * hidden_size))
self.weight_ih = nn.Parameter(torch.empty(input_size, 4 * hidden_size))
self.weight_hh = nn.Parameter(torch.empty(hidden_size, 4 * hidden_size))
self.bias = nn.Parameter(torch.empty(4 * hidden_size))
self.reset_parameters()
@torch.jit.script_method
def forward(self, input_data: torch.Tensor, state: tuple[torch.Tensor, torch.Tensor]) -> tuple[
torch.Tensor, torch.Tensor]:
hx, cx = state
@@ -46,13 +88,54 @@ class LSTMCell(torch.jit.ScriptModule):
return (hy, cy)
def reset_parameters(self) -> None:
for weight in [self.weight_hh, self.weight_ih]:
nn.init.xavier_normal_(weight)
nn.init.zeros_(self.bias)
class LSTMLayer(torch.jit.ScriptModule):
class BNLSTMCell(nn.Module):
def __init__(self, input_size, hidden_size):
super().__init__()
self.cell = nn.LSTMCell(input_size=input_size, hidden_size=hidden_size)
self.input_size = input_size
self.hidden_size = hidden_size
self.weight_ih = nn.Parameter(torch.empty(input_size, 4 * hidden_size))
self.weight_hh = nn.Parameter(torch.empty(hidden_size, 4 * hidden_size))
self.bias = nn.Parameter(torch.empty(4 * hidden_size))
self.bn_1 = nn.BatchNorm1d(4 * hidden_size)
self.bn_2 = nn.BatchNorm1d(4 * hidden_size)
self.bn_out = nn.BatchNorm1d(hidden_size)
self.reset_parameters()
def forward(self, input_data: torch.Tensor, state: tuple[torch.Tensor, torch.Tensor]) -> tuple[
torch.Tensor, torch.Tensor]:
hx, cx = state
gates = (
(self.bn_1(torch.mm(input_data, self.weight_ih)) + self.bn_2(torch.mm(hx, self.weight_hh)) + self.bias))
ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
ingate = torch.sigmoid(ingate)
forgetgate = torch.sigmoid(forgetgate)
cellgate = torch.tanh(cellgate)
outgate = torch.sigmoid(outgate)
cy = (forgetgate * cx) + (ingate * cellgate)
hy = outgate * torch.tanh(self.bn_out(cy))
return (hy, cy)
def reset_parameters(self) -> None:
for weight in [self.weight_hh, self.weight_ih]:
nn.init.xavier_normal_(weight)
nn.init.zeros_(self.bias)
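
A single step through the batch-normalized cell, as a sanity-check sketch (note that the BatchNorm1d layers share running statistics across all time steps, and need a batch larger than one in training mode):

cell = BNLSTMCell(input_size=16, hidden_size=32)
x = torch.randn(8, 16)          # one time step for a batch of 8
h = torch.zeros(8, 32)
c = torch.zeros(8, 32)
h, c = cell(x, (h, c))          # returns the updated (hidden, cell) state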
class StackLSTMLayer(nn.Module):
def __init__(self, input_size, hidden_size, cell_class=LSTMCell):
super().__init__()
self.cell = cell_class(input_size=input_size, hidden_size=hidden_size)
@torch.jit.script_method
def forward(self, input_data: torch.Tensor, state: tuple[torch.Tensor, torch.Tensor]) -> tuple[
torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
inputs = input_data.unbind(0)
@@ -63,74 +146,70 @@ class LSTMLayer(torch.jit.ScriptModule):
return torch.stack(outputs), state
class StackedLSTM(torch.jit.ScriptModule):
def __init__(self, input_size, hidden_size, num_layers):
class ChainLSTMLayer(nn.Module):
def __init__(self, input_size, hidden_size, cell_class=LSTMCell):
super().__init__()
self.layers = nn.ModuleList(
[LSTMLayer(input_size=input_size, hidden_size=hidden_size)] + [
LSTMLayer(input_size=hidden_size, hidden_size=hidden_size) for _ in range(num_layers - 1)])
self.cell = cell_class(input_size=input_size, hidden_size=hidden_size)
@torch.jit.script_method
def forward(self, input_data: torch.Tensor, states: list[tuple[torch.Tensor, torch.Tensor]]) -> tuple[
torch.Tensor, list[tuple[torch.Tensor, torch.Tensor]]]:
output_states = torch.jit.annotate(list[tuple[torch.Tensor, torch.Tensor]], [])
def forward(self, input_data: torch.Tensor, state: tuple[torch.Tensor, torch.Tensor]) -> tuple[
torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
inputs = input_data.unbind(0)
outputs = torch.jit.annotate(list[torch.Tensor], [])
for i in range(len(inputs)):
state = self.cell(inputs[i], state)
outputs += [state[1]]
return torch.stack(outputs), state
class _CustomLSTMLayer(nn.Module):
def __init__(self, input_size, hidden_size, num_layers, layer_class=StackLSTMLayer, cell_class=LSTMCell):
super().__init__()
self.num_layer = num_layers
self.layers = nn.ModuleList(
[layer_class(input_size=input_size, hidden_size=hidden_size, cell_class=cell_class)] + [
layer_class(input_size=hidden_size, hidden_size=hidden_size, cell_class=cell_class)
for _ in range(num_layers - 1)])
def forward(self, input_data: torch.Tensor, states: tuple[torch.Tensor, torch.Tensor]) -> tuple[
torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
output_states = torch.jit.annotate(list[torch.Tensor], [])
output_cell_states = torch.jit.annotate(list[torch.Tensor], [])
output = input_data
i = 0
for rnn_layer in self.layers:
state = states[i]
output, out_state = rnn_layer(output, state)
output_states += [out_state]
output, out_state = rnn_layer(output, (states[0][i], states[1][i]))
output_states += [out_state[0]]
output_cell_states += [out_state[1]]
i += 1
return output, output_states
return output, (torch.stack(output_states), torch.stack(output_cell_states))
class StackedLSTMModel(torch.jit.ScriptModule):
NUM_LAYERS = 3
def __init__(self, input_size: int, hidden_size: int = -1, output_size: int = -1):
class CustomLSTMModel(nn.Module):
def __init__(self, input_size: int, hidden_size: int = -1, output_size: int = -1,
num_layer: int = 3, layer_class=StackLSTMLayer, cell_class=LSTMCell):
super().__init__()
hidden_size = hidden_size if hidden_size > 0 else input_size * 2
output_size = output_size if output_size > 0 else input_size
self.hidden_size = hidden_size if hidden_size > 0 else input_size * 2
self.output_size = output_size if output_size > 0 else input_size
self.lstm = StackedLSTM(input_size=input_size, hidden_size=hidden_size, num_layers=self.NUM_LAYERS)
self.dense = nn.Linear(hidden_size, output_size)
self.lstm = _CustomLSTMLayer(input_size=input_size, hidden_size=self.hidden_size,
num_layers=num_layer, layer_class=layer_class, cell_class=cell_class)
self.dense = nn.Linear(self.hidden_size, self.output_size)
@torch.jit.script_method
def forward(self, input_data: torch.Tensor, init_state: list[tuple[torch.Tensor, torch.Tensor]]) -> tuple[
torch.Tensor, list[tuple[torch.Tensor, torch.Tensor]]]:
self.zero_state = torch.nn.Parameter(
torch.zeros(self.lstm.num_layer, self.hidden_size), requires_grad=False)
self.zero_cell_state = torch.nn.Parameter(
torch.zeros(self.lstm.num_layer, self.hidden_size), requires_grad=False)
def forward(self,
input_data: torch.Tensor,
init_state: tuple[torch.Tensor, torch.Tensor] = (torch.zeros(1), torch.zeros(1))
) -> tuple[
torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
if len(init_state[0].shape) == 1:
init_state = (
self.zero_state.expand(
input_data.size(1), self.lstm.num_layer, self.hidden_size).transpose(1, 0),
self.zero_cell_state.expand(
input_data.size(1), self.lstm.num_layer, self.hidden_size).transpose(1, 0))
output, state = self.lstm(input_data, init_state)
return self.dense(output), state
class LSTMCellModel(torch.jit.ScriptModule):
NUM_LAYERS = 3
def __init__(self, input_size: int, hidden_size: int = -1, output_size: int = -1):
super().__init__()
hidden_size = hidden_size if hidden_size > 0 else input_size * 2
output_size = output_size if output_size > 0 else input_size
self.hidden_size = hidden_size
self.layers = nn.ModuleList([
nn.LSTMCell(input_size=input_size, hidden_size=hidden_size),
nn.LSTMCell(input_size=hidden_size, hidden_size=hidden_size),
nn.LSTMCell(input_size=hidden_size, hidden_size=hidden_size)
])
self.dense = nn.Linear(hidden_size, output_size)
@torch.jit.script_method
def forward(self, input_data: torch.Tensor, init_states: list[tuple[torch.Tensor, torch.Tensor]]) -> tuple[
torch.Tensor, list[tuple[torch.Tensor, torch.Tensor]]]:
output_states = torch.jit.annotate(list[tuple[torch.Tensor, torch.Tensor]], [])
cell_inputs = input_data.unbind(0)
cell_id = 0
for cell in self.layers:
cell_state = init_states[cell_id]
cell_outputs = torch.jit.annotate(list[torch.Tensor], [])
for i in range(len(cell_inputs)):
cell_state = cell(cell_inputs[i], cell_state)
cell_outputs += [cell_state[0]]
cell_inputs = cell_outputs
output_states += [cell_state]
cell_id += 1
return self.dense(torch.stack(cell_outputs)), output_states