Initial commit

Corentin 2021-08-26 16:48:09 +09:00
commit e41417c029
8 changed files with 802 additions and 0 deletions

3
.gitignore vendored Normal file

@@ -0,0 +1,3 @@
*.pyc
output

3
.gitmodules vendored Normal file

@@ -0,0 +1,3 @@
[submodule "src/torch_utils"]
path = src/torch_utils
url = git@gitlab.com:corentin-pro/torch_utils.git

209
modulo.py Normal file

@@ -0,0 +1,209 @@
from argparse import ArgumentParser
from pathlib import Path
import math
import shutil
import time
import numpy as np
import torch
from torch import nn
from torch.utils.tensorboard import SummaryWriter
from src.torch_networks import LSTMModel, LSTMCellModel, StackedLSTMModel
from src.torch_utils.utils.batch_generator import BatchGenerator
from src.torch_utils.train import parameter_summary
def generate_data(batch_size: int, data_length: int) -> tuple[np.ndarray, np.ndarray]:
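# Each column of `data` gets a 1 every `modulos[i]` steps starting at `starts[i]`;
# the label is 1 iff the step just past the end (index data_length) would also
# be a 1, i.e. data_length % modulos[i] == starts[i].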
modulos = np.random.uniform(3, data_length // 2 + 1, batch_size).astype(np.int32)
data = np.zeros((data_length, batch_size, 1), dtype=np.float32)
starts = []
for mod in modulos:
starts.append(int(np.random.uniform(0, mod)))
for i in range(batch_size):
# np.where(data[i] % modulos[i] == starts[i], [1.0], data[i])
for j in range(starts[i], data_length, modulos[i]):
data[j, i, 0] = 1.0
label = []
for i in range(batch_size):
label.append(1 if len(data[:, i]) % modulos[i] == starts[i] else 0)
return data, np.asarray(label, dtype=np.int64)
class DataGenerator:
MAX_LENGTH = 1
INITIALIZED = False
@staticmethod
def pipeline(sequence_length, _dummy_label):
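# Seed NumPy once per pipeline worker so parallel workers do not all produce
# identical pseudo-random batches.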
if not DataGenerator.INITIALIZED:
np.random.seed(time.time_ns() % (2**32))
DataGenerator.INITIALIZED = True
data = np.zeros((DataGenerator.MAX_LENGTH, 1), dtype=np.float32)
modulo = int(np.random.uniform(3, sequence_length // 2 + 1))
start = int(np.random.uniform(0, modulo))
for i in range(start, sequence_length, modulo):
data[i, 0] = 1.0
return data, np.asarray(1 if sequence_length % modulo == start else 0, dtype=np.int64)
def main():
parser = ArgumentParser()
parser.add_argument('--output', type=Path, default=Path('output', 'modulo'), help='Output dir')
parser.add_argument('--batch', type=int, default=32, help='Batch size')
parser.add_argument('--sequence', type=int, default=12, help='Max sequence length')
parser.add_argument('--step', type=int, default=2000, help='Number of steps to train')
parser.add_argument('--model', help='Model to train')
arguments = parser.parse_args()
output_dir: Path = arguments.output
batch_size: int = arguments.batch
sequence_size: int = arguments.sequence
max_step: int = arguments.step
model: str = arguments.model
if not output_dir.exists():
output_dir.mkdir(parents=True)
if (output_dir / 'train').exists():
shutil.rmtree(output_dir / 'train')
writer_train = SummaryWriter(log_dir=str(output_dir / 'train'), flush_secs=20)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
torch.backends.cudnn.benchmark = True
if model == 'stack':
network = StackedLSTMModel(1, 16, 2).to(device)
elif model == 'cell':
network = LSTMCellModel(1, 16, 2).to(device)
else:
network = LSTMModel(1, 16, 2).to(device)
torch.save(network.state_dict(), output_dir / 'model_ini.pt')
# Save parameters info
with open(output_dir / 'parameters.csv', 'w') as param_file:
param_summary = parameter_summary(network)
names = [len(name) for name, _, _ in param_summary]
shapes = [len(str(shape)) for _, shape, _ in param_summary]
param_file.write(
'\n'.join(
[f'{name: <{max(names)}} {str(shape): <{max(shapes)}} {size}'
for name, shape, size in param_summary]))
optimizer = torch.optim.Adam(network.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.995)
criterion = nn.CrossEntropyLoss()
sequence_data = np.random.uniform(4, sequence_size + 1, max_step).astype(np.int32)
sequence_data[0] = sequence_size
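# Broadcast each step's sequence length across the whole batch so that every
# sample within a batch shares one length.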
sequence_data_reshaped = np.reshape(np.broadcast_to(
sequence_data,
(batch_size, max_step)).transpose((1, 0)), (batch_size * max_step))
dummy_label = np.zeros((batch_size * max_step), dtype=np.uint8)
DataGenerator.MAX_LENGTH = sequence_size
if model in ['cell', 'stack']:
state = [(torch.zeros((batch_size, 16)).to(device),
torch.zeros((batch_size, 16)).to(device))] * network.NUM_LAYERS
else:
state = None
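# The scripted cell/stack models take explicit per-layer (h, c) zero states;
# the built-in nn.LSTM inside LSTMModel defaults its state when given None.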
with BatchGenerator(sequence_data_reshaped, dummy_label, batch_size=batch_size,
pipeline=DataGenerator.pipeline, num_workers=8, shuffle=False) as batch_generator:
# data_np, _ = generate_data(batch_size, int(np.random.uniform(4, sequence_size + 1)))
data_np = batch_generator.batch_data
label_np = batch_generator.batch_label
# writer_train.add_graph(network, (torch.from_numpy(data_np).to(device),))
running_loss = 0.0
running_accuracy = 0.0
running_count = 0
summary_period = max(max_step // 100, 1)
np.set_printoptions(precision=2)
try:
start_time = time.time()
while batch_generator.epoch == 0:
# data_np, label_np = generate_data(batch_size, int(np.random.uniform(4, sequence_size + 1)))
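# Time-major view: transpose to (sequence, batch, feature) and truncate the
# padded batch to this step's sequence length.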
data = torch.from_numpy(
data_np.transpose((1, 0, 2))[:sequence_data[batch_generator.step]]).to(device)
label = torch.from_numpy(label_np).to(device)
optimizer.zero_grad(set_to_none=True)
outputs, _states = network(data, state)
loss = criterion(outputs[-1], label)
running_loss += loss.item()
outputs_np = outputs[-1].detach().cpu().numpy()
running_accuracy += ((outputs_np[:, 1] > outputs_np[:, 0]).astype(np.int32) == label_np).astype(
np.float32).mean()
running_count += 1
if (batch_generator.step + 1) % summary_period == 0:
writer_train.add_scalar('metric/loss', running_loss / running_count,
global_step=batch_generator.step)
writer_train.add_scalar('metric/error', 1 - (running_accuracy / running_count),
global_step=batch_generator.step)
writer_train.add_scalar('optimizer/lr', scheduler.get_last_lr()[0],
global_step=batch_generator.step)
scheduler.step()
speed = summary_period / (time.time() - start_time)
print(f'Step {batch_generator.step}, loss: {running_loss / running_count:.03e}'
f', acc: {running_accuracy / running_count:.03e}, speed: {speed:0.3f}step/s')
start_time = time.time()
running_loss = 0.0
running_accuracy = 0.0
running_count = 0
loss.backward()
optimizer.step()
data_np, label_np = batch_generator.next_batch()
except KeyboardInterrupt:
print('\r ', end='\r')
writer_train.close()
network.eval()
running_accuracy = 0.0
running_count = 0
for _ in range(math.ceil(1000 / batch_size)):
data_np, label_np = generate_data(batch_size, int(np.random.uniform(4, sequence_size + 1)))
data = torch.from_numpy(data_np).to(device)
label = torch.from_numpy(label_np).to(device)
outputs, _states = network(data, state)
outputs_np = outputs[-1].detach().cpu().numpy()
running_accuracy += ((outputs_np[:, 1] > outputs_np[:, 0]).astype(np.int32) == label_np).astype(
np.float32).mean()
running_count += 1
print(f'Validation accuracy: {running_accuracy / running_count:.03f}')
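# Hand-crafted test sequences; as with generate_data, the label is 1 iff the
# periodic pattern would place a 1 at the step right after the sequence ends.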
test_data = [
[[1.], [0.], [0.], [1.], [0.], [0.]],
[[1.], [0.], [0.], [0.], [1.], [0.], [0.], [0.]],
[[1.], [0.], [0.], [0.], [0.], [1.], [0.], [0.], [0.]],
[[1.], [0.], [0.], [0.], [0.], [1.], [0.], [0.], [0.], [0.]],
[[0.], [1.], [0.], [0.], [0.], [0.], [0.], [1.], [0.], [0.], [0.], [0.]],
[[1.], [0.], [0.], [0.], [0.], [0.], [1.], [0.], [0.], [0.], [0.], [0.]],
[[0.], [1.], [0.], [0.], [0.], [0.], [0.], [1.], [0.], [0.], [0.], [0.], [0.]],
[[0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.]],
[[1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.]],
[[1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.],
[0.], [1.], [0.], [0.], [1.], [0.], [0.]],
[[0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.],
[0.], [1.], [0.], [0.], [1.], [0.]],
]
test_label = np.asarray([1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0], dtype=np.int32)
running_accuracy = 0.0
running_count = 0
if model in ['cell', 'stack']:
state = [(torch.zeros((1, 16)).to(device),
torch.zeros((1, 16)).to(device))] * network.NUM_LAYERS
for data, label in zip(test_data, test_label):
outputs, _states = network(
torch.from_numpy(np.expand_dims(np.asarray(data, dtype=np.float32), 1)).to(device),
state)
outputs_np = outputs[-1].detach().cpu().numpy()
output_correct = int(outputs_np[0, 1] > outputs_np[0, 0]) == label
running_accuracy += 1.0 if output_correct else 0.0
running_count += 1
print(f'{len(data)} {np.asarray(data)[:, 0]}, label: {label}'
f', output: {int(outputs_np[0, 1] > outputs_np[0, 0])}')
print(f'Test accuracy: {running_accuracy / running_count:.03f}')
if __name__ == '__main__':
main()
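# Example run (assuming the torch_utils submodule is initialized):
#   python modulo.py --model stack --sequence 12 --step 2000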

191
modulo_jax.py Normal file

@@ -0,0 +1,191 @@
from argparse import ArgumentParser
from pathlib import Path
import math
import shutil
import time
import sys
import jax
from jax import lax
import jax.numpy as jnp
from jax.experimental import optimizers
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from src.jax_networks import StackedLSTMModel
from src.torch_utils.utils.batch_generator import BatchGenerator
def generate_data(batch_size: int, data_length: int) -> tuple[np.ndarray, np.ndarray]:
modulos = np.random.uniform(3, data_length // 2 + 1, batch_size).astype(np.int32)
data = np.zeros((data_length, batch_size, 1), dtype=np.float32)
starts = []
for mod in modulos:
starts.append(int(np.random.uniform(0, mod)))
for i in range(batch_size):
# np.where(data[i] % modulos[i] == starts[i], [1.0], data[i])
for j in range(starts[i], data_length, modulos[i]):
data[j, i, 0] = 1.0
label = []
for i in range(batch_size):
label.append(1 if len(data[:, i]) % modulos[i] == starts[i] else 0)
return data, np.asarray(label, dtype=np.int64)
class DataGenerator:
MAX_LENGTH = 1
INITIALIZED = False
@staticmethod
def pipeline(sequence_length, _dummy_label):
if not DataGenerator.INITIALIZED:
np.random.seed(time.time_ns() % (2**32))
DataGenerator.INITIALIZED = True
data = np.zeros((DataGenerator.MAX_LENGTH, 1), dtype=np.float32)
modulo = int(np.random.uniform(3, sequence_length // 2 + 1))
start = int(np.random.uniform(0, modulo))
for i in range(start, sequence_length, modulo):
data[i, 0] = 1.0
return data, np.asarray(1 if sequence_length % modulo == start else 0, dtype=np.int64)
def main():
parser = ArgumentParser()
parser.add_argument('--output', type=Path, default=Path('output', 'modulo'), help='Output dir')
parser.add_argument('--batch', type=int, default=32, help='Batch size')
parser.add_argument('--sequence', type=int, default=12, help='Max sequence length')
parser.add_argument('--step', type=int, default=2000, help='Number of steps to train')
parser.add_argument('--model', help='Model to train')
arguments = parser.parse_args()
output_dir: Path = arguments.output
batch_size: int = arguments.batch
sequence_size: int = arguments.sequence
max_step: int = arguments.step
model: str = arguments.model
if not output_dir.exists():
output_dir.mkdir(parents=True)
if (output_dir / 'train').exists():
shutil.rmtree(output_dir / 'train')
writer_train = SummaryWriter(log_dir=str(output_dir / 'train'), flush_secs=20)
rng = jax.random.PRNGKey(0)
if model == 'stack':
network_init, network_fun = StackedLSTMModel(1, 16, 2)
_, network_params = network_init(rng, 1)
else:
print('Model not implemented')
sys.exit(1)
@jax.jit
def loss_fun(params, data, targets):
# Cross-entropy over the two output logits; log_softmax keeps the log stable.
preds, _ = network_fun(params, data, state)
log_probs = jax.nn.log_softmax(preds)
return -jnp.mean(log_probs[jnp.arange(targets.shape[0]), targets])
def accuracy_fun(preds, targets):
return jnp.mean((preds[:, 1] > preds[:, 0]).astype(jnp.int64) == targets)
opt_init, opt_update, opt_params = optimizers.adam(1e-3)
@jax.jit
def update_fun(step, opt_state, data, targets):
# Differentiate w.r.t. the parameters (not the predictions) before stepping Adam.
return opt_update(step, jax.grad(loss_fun)(opt_params(opt_state), data, targets), opt_state)
opt_state = opt_init(network_params)
sequence_data = np.random.uniform(4, sequence_size + 1, max_step).astype(np.int32)
sequence_data[0] = sequence_size
sequence_data_reshaped = np.reshape(np.broadcast_to(
sequence_data,
(batch_size, max_step)).transpose((1, 0)), (batch_size * max_step))
dummy_label = np.zeros((batch_size * max_step), dtype=np.uint8)
DataGenerator.MAX_LENGTH = sequence_size
state = [(jnp.zeros((batch_size, 16)),
jnp.zeros((batch_size, 16)))] * 3
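# Zero (h, c) pairs for the 3 stacked layers (StackedLSTMModel defaults to num_layer=3).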
with BatchGenerator(sequence_data_reshaped, dummy_label, batch_size=batch_size,
pipeline=DataGenerator.pipeline, num_workers=8, shuffle=False) as batch_generator:
data_np = batch_generator.batch_data
label_np = batch_generator.batch_label
running_loss = 0.0
running_accuracy = 0.0
running_count = 0
summary_period = max(max_step // 100, 1)
np.set_printoptions(precision=2)
try:
start_time = time.time()
while batch_generator.epoch == 0:
# data_np, label_np = generate_data(batch_size, int(np.random.uniform(4, sequence_size + 1)))
data = jnp.asarray(data_np.transpose((1, 0, 2))[:sequence_data[batch_generator.step]])
label = jnp.asarray(label_np)
network_params = opt_params(opt_state)
preds, _state = network_fun(network_params, data, state)
loss = loss_fun(network_params, data, label)
opt_state = update_fun(batch_generator.global_step, opt_state, data, label)
running_loss += float(loss)
running_accuracy += float(accuracy_fun(preds, label))
running_count += 1
if (batch_generator.step + 1) % summary_period == 0:
writer_train.add_scalar('metric/loss', running_loss / running_count,
global_step=batch_generator.step)
writer_train.add_scalar('metric/error', 1 - (running_accuracy / running_count),
global_step=batch_generator.step)
speed = summary_period / (time.time() - start_time)
print(f'Step {batch_generator.step}, loss: {running_loss / running_count:.03e}'
f', acc: {running_accuracy / running_count:.03e}, speed: {speed:0.3f}step/s')
start_time = time.time()
running_loss = 0.0
running_accuracy = 0.0
running_count = 0
data_np, label_np = batch_generator.next_batch()
except KeyboardInterrupt:
print('\r ', end='\r')
writer_train.close()
network_params = opt_params(opt_state)
running_accuracy = 0.0
running_count = 0
for _ in range(math.ceil(1000 / batch_size)):
data_np, label_np = generate_data(batch_size, int(np.random.uniform(4, sequence_size + 1)))
data = jnp.asarray(data_np)
label = jnp.asarray(label_np)
preds, _state = network_fun(network_params, data, state)
running_accuracy += accuracy_fun(preds, label)
running_count += 1
print(f'Validation accuracy: {running_accuracy / running_count:.03f}')
test_data = [
[[1.], [0.], [0.], [1.], [0.], [0.]],
[[1.], [0.], [0.], [0.], [1.], [0.], [0.], [0.]],
[[1.], [0.], [0.], [0.], [0.], [1.], [0.], [0.], [0.]],
[[1.], [0.], [0.], [0.], [0.], [1.], [0.], [0.], [0.], [0.]],
[[0.], [1.], [0.], [0.], [0.], [0.], [0.], [1.], [0.], [0.], [0.], [0.]],
[[1.], [0.], [0.], [0.], [0.], [0.], [1.], [0.], [0.], [0.], [0.], [0.]],
[[0.], [1.], [0.], [0.], [0.], [0.], [0.], [1.], [0.], [0.], [0.], [0.], [0.]],
[[0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.]],
[[1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.]],
[[1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.],
[0.], [1.], [0.], [0.], [1.], [0.], [0.]],
[[0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.],
[0.], [1.], [0.], [0.], [1.], [0.]],
]
test_label = np.asarray([1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0], dtype=np.int32)
running_accuracy = 0.0
running_count = 0
state = [(jnp.zeros((1, 16)), jnp.zeros((1, 16)))] * 3
for data, label in zip(test_data, test_label):
outputs, _states = network_fun(
network_params,
jnp.asarray(np.expand_dims(np.asarray(data, dtype=np.float32), 1)),
state)
output_correct = int(outputs[0, 1] > outputs[0, 0]) == label
running_accuracy += 1.0 if output_correct else 0.0
running_count += 1
print(f'{len(data)} {np.asarray(data)[:, 0]}, label: {label}'
f', output: {int(outputs[0, 1] > outputs[0, 0])}')
print(f'Test accuracy: {running_accuracy / running_count:.03f}')
if __name__ == '__main__':
main()

155
recorder.py Normal file

@@ -0,0 +1,155 @@
from argparse import ArgumentParser
from pathlib import Path
import shutil
import time
import numpy as np
import torch
from torch import nn
from torch.utils.tensorboard import SummaryWriter
from src.torch_networks import LSTMModel, StackedLSTMModel
from src.torch_utils.train import parameter_summary
def generate_data(batch_size: int, sequence_length: int, dim: int) -> np.ndarray:
return np.random.uniform(0, dim, (batch_size, sequence_length)).astype(np.int64)
def main():
parser = ArgumentParser()
parser.add_argument('--output', type=Path, default=Path('output', 'recorder'), help='Output dir')
parser.add_argument('--batch', type=int, default=32, help='Batch size')
parser.add_argument('--sequence', type=int, default=8, help='Max sequence length')
parser.add_argument('--dimension', type=int, default=15, help='Input dimension')
parser.add_argument('--step', type=int, default=20_000, help='Number of steps to train')
parser.add_argument('--model', help='Model to train')
arguments = parser.parse_args()
output_dir: Path = arguments.output
batch_size: int = arguments.batch
sequence_size: int = arguments.sequence
input_dim: int = arguments.dimension
max_step: int = arguments.step
model: str = arguments.model
if not output_dir.exists():
output_dir.mkdir(parents=True)
if (output_dir / 'train').exists():
shutil.rmtree(output_dir / 'train')
writer_train = SummaryWriter(log_dir=str(output_dir / 'train'), flush_secs=20)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
if model == 'cell':
network = StackedLSTMModel(input_dim + 1).to(device)
else:
network = LSTMModel(input_dim + 1).to(device)
# Save parameters info
with open(output_dir / 'parameters.csv', 'w') as param_file:
param_summary = parameter_summary(network)
names = [len(name) for name, _, _ in param_summary]
shapes = [len(str(shape)) for _, shape, _ in param_summary]
param_file.write(
'\n'.join(
[f'{name: <{max(names)}} {str(shape): <{max(shapes)}} {size}'
for name, shape, size in param_summary]))
optimizer = torch.optim.Adam(network.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.996)
criterion = nn.CrossEntropyLoss()
zero_data = torch.zeros((batch_size, sequence_size, input_dim + 1)).to(device)
# writer_train.add_graph(network, (zero_data[:, :5],))
if model == 'cell':
state = [(torch.zeros((batch_size, (input_dim + 1) * 2)).to(device),
torch.zeros((batch_size, (input_dim + 1) * 2)).to(device))] * network.NUM_LAYERS
else:
state = None
running_loss = 0.0
running_accuracy = 0.0
running_count = 0
summary_period = max_step // 100
np.set_printoptions(precision=2)
try:
start_time = time.time()
for step in range(1, max_step + 1):
label_np = generate_data(batch_size, int(np.random.uniform(4, sequence_size + 1)), input_dim)
label = torch.from_numpy(label_np).to(device)
data = nn.functional.one_hot(label, input_dim + 1).float()
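# The extra channel (index input_dim) is an input-phase flag: 1 while the
# sequence is presented, 0 (via zero_data) while the network must replay it.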
data[:, :, input_dim] = 1.0
optimizer.zero_grad()
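# Feed the one-hot sequence followed by an equal-length block of zeros, then
# read the replayed tokens from the second half of the outputs (dropping the
# flag logit).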
outputs, _state = network(torch.cat([data, zero_data[:, :label_np.shape[1]]], dim=1).transpose(1, 0), state)
outputs = outputs[label_np.shape[1]:, :, :-1]
# state = (state[0].detach(), state[1].detach())
# state = (state[0][:, :1].detach().expand(state[0].shape[0], batch_size, *state[0].shape[2:]),
# state[1][:, :1].detach().expand(state[1].shape[0], batch_size, *state[1].shape[2:]))
data[:, :, input_dim] = 0.0
loss = criterion(outputs[:, 0], label[0])
for i in range(1, batch_size):
loss += criterion(outputs[:, i], label[i])
loss /= batch_size
running_loss += loss.item()
running_accuracy += (
torch.sum((torch.argmax(outputs, 2).transpose(1, 0) == label).long(),
1) == label.size(1)).float().mean().item()
running_count += 1
if step % summary_period == 0:
writer_train.add_scalar('metric/loss', running_loss / running_count, global_step=step)
writer_train.add_scalar('metric/error', 1 - (running_accuracy / running_count), global_step=step)
writer_train.add_scalar('optimizer/lr', scheduler.get_last_lr()[0], global_step=step)
scheduler.step()
speed = summary_period / (time.time() - start_time)
print(f'Step {step}, loss: {running_loss / running_count:.03e}'
f', acc: {running_accuracy / running_count:.03e}, speed: {speed:0.3f}step/s')
start_time = time.time()
running_loss = 0.0
running_accuracy = 0.0
running_count = 0
loss.backward()
optimizer.step()
except KeyboardInterrupt:
print('\r ', end='\r')
writer_train.close()
network.eval()
test_label = [
np.asarray([[0, 0, 0, 0]], dtype=np.int64),
np.asarray([[2, 2, 2, 2]], dtype=np.int64),
np.asarray([[8, 1, 10, 5, 6, 13]], dtype=np.int64),
generate_data(1, 4, input_dim),
np.asarray([[0, 0, 0, 0, 0, 0]], dtype=np.int64),
np.asarray([[5, 5, 5, 5, 5, 5]], dtype=np.int64),
np.asarray([[11, 0, 2, 8, 5, 1]], dtype=np.int64),
generate_data(1, sequence_size, input_dim)
]
zero_data = torch.zeros((1, sequence_size, input_dim + 1)).to(device)
if model == 'cell':
state = [(torch.zeros((1, (input_dim + 1) * 2)).to(device),
torch.zeros((1, (input_dim + 1) * 2)).to(device))] * network.NUM_LAYERS
running_accuracy = 0.0
running_count = 0
for label_np in test_label:
label = torch.from_numpy(label_np).to(device)
data = nn.functional.one_hot(label, input_dim + 1).float()
data[:, :, input_dim] = 1.0
outputs, _state = network(torch.cat([data, zero_data[:, :label_np.shape[1]]], dim=1).transpose(1, 0), state)
outputs = outputs[label_np.shape[1]:, :, :-1]
# state = (state[0].detach(), state[1].detach())
running_accuracy += (
torch.sum(
(torch.argmax(outputs, 2).transpose(1, 0) == label).long(), 1) == label.size(1)).float().mean().item()
running_count += 1
print(f'{label_np.shape[1]} label: {label_np}, output: {torch.argmax(outputs, 2)[:, 0].detach().cpu().numpy()}')
print(f'Test accuracy: {running_accuracy / running_count:.03f}')
if __name__ == '__main__':
main()

104
src/jax_networks.py Normal file

@@ -0,0 +1,104 @@
import jax
import jax.numpy as jnp
from jax import lax
from jax import nn
@jax.jit
def sigmoid(x):
return 1 / (1 + lax.exp(-x))
def LSTMCell(hidden_size, w_init=nn.initializers.glorot_normal(), b_init=nn.initializers.normal()):
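# stax-style (init_fun, apply_fun) pair; the weights pack the four gates
# (input, forget, cell, output) along the last axis.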
def init_fun(rng, input_size):
k1, k2, k3 = jax.random.split(rng, 3)
weight_ih = w_init(k1, (input_size, 4 * hidden_size))
weight_hh = w_init(k2, (hidden_size, 4 * hidden_size))
bias = b_init(k3, (4 * hidden_size,))
return hidden_size, (weight_ih, weight_hh, bias)
@jax.jit
def apply_fun(params, inputs, states):
(hx, cx) = states
weight_ih, weight_hh, bias = params
gates = lax.dot(inputs, weight_ih) + lax.dot(hx, weight_hh) + bias
# Slice the packed gates by hidden_size rather than a hardcoded 16.
in_gate = sigmoid(gates[:, :hidden_size])
forget_gate = sigmoid(gates[:, hidden_size:2 * hidden_size])
cell_gate = lax.tanh(gates[:, 2 * hidden_size:3 * hidden_size])
out_gate = sigmoid(gates[:, 3 * hidden_size:])
cy = lax.mul(forget_gate, cx) + lax.mul(in_gate, cell_gate)
hy = lax.mul(out_gate, lax.tanh(cy))
return (hy, cy)
return init_fun, apply_fun
def LSTMLayer(hidden_size, w_init=nn.initializers.glorot_normal(), b_init=nn.initializers.normal()):
cell_init, cell_fun = LSTMCell(hidden_size, w_init=w_init, b_init=b_init)
@jax.jit
def apply_fun(params, inputs, states):
output = []
for input_data in inputs:
states = cell_fun(params, input_data, states)
output.append(states[0])
return jnp.stack(output), states
return cell_init, apply_fun
def StackedLSTM(hidden_size, num_layer, w_init=nn.initializers.glorot_normal(), b_init=nn.initializers.normal()):
layer_inits = []
layer_funs = []
for _ in range(num_layer):
layer_init, layer_fun = LSTMLayer(hidden_size, w_init=w_init, b_init=b_init)
layer_inits.append(layer_init)
layer_funs.append(layer_fun)
del layer_init
del layer_fun
def init_fun(rng, input_size):
params = []
output_size = input_size
for init in layer_inits:
# Give each layer its own key so stacked layers do not start with identical weights.
rng, layer_rng = jax.random.split(rng)
output_size, layer_params = init(layer_rng, output_size)
params.append(layer_params)
return output_size, tuple(params)
@jax.jit
def apply_fun(params, inputs, states):
output = inputs
output_states = []
for layer_id in range(num_layer):
output, out_state = layer_funs[layer_id](params[layer_id], output, states[layer_id])
output_states.append(out_state)
return output, output_states
return init_fun, apply_fun
def StackedLSTMModel(input_size, hidden_size=-1, output_size=-1, num_layer=3,
w_init=nn.initializers.glorot_normal(), b_init=nn.initializers.normal()):
hidden_size = hidden_size if hidden_size > 0 else input_size * 2
output_size = output_size if output_size > 0 else input_size
stacked_init, stacked_fun = StackedLSTM(hidden_size, num_layer, w_init=w_init, b_init=b_init)
def init_fun(rng, input_size):
# Use distinct keys for the stacked layers and the readout layer.
stacked_rng, k1, k2 = jax.random.split(rng, 3)
stacked_output_size, stacked_params = stacked_init(stacked_rng, input_size)
weight = w_init(k1, (stacked_output_size, output_size))
bias = b_init(k2, (output_size,))
return output_size, ((weight, bias), stacked_params)
@jax.jit
def apply_fun(params, inputs, states):
(weight, bias), stacked_params = params
output, new_states = stacked_fun(stacked_params, inputs, states)
return lax.dot(output[-1], weight) + bias, new_states
return init_fun, apply_fun

136
src/torch_networks.py Normal file

@@ -0,0 +1,136 @@
import torch
from torch import nn
class LSTMModel(nn.Module):
NUM_LAYERS = 3
def __init__(self, input_size: int, hidden_size: int = -1, output_size: int = -1):
super().__init__()
hidden_size = hidden_size if hidden_size > 0 else input_size * 2
output_size = output_size if output_size > 0 else input_size
self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=self.NUM_LAYERS)
self.dense = nn.Linear(hidden_size, output_size)
def forward(self, input_data: torch.Tensor, init_state=None) -> tuple[
torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
output, state = self.lstm(input_data, init_state)
return self.dense(output), state
class LSTMCell(torch.jit.ScriptModule):
def __init__(self, input_size, hidden_size):
super().__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.weight_ih = nn.Parameter(torch.randn(input_size, 4 * hidden_size))
self.weight_hh = nn.Parameter(torch.randn(hidden_size, 4 * hidden_size))
self.bias = nn.Parameter(torch.randn(4 * hidden_size))
@torch.jit.script_method
def forward(self, input_data: torch.Tensor, state: tuple[torch.Tensor, torch.Tensor]) -> tuple[
torch.Tensor, torch.Tensor]:
hx, cx = state
gates = (
(torch.mm(input_data, self.weight_ih) + torch.mm(hx, self.weight_hh) + self.bias))
ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
ingate = torch.sigmoid(ingate)
forgetgate = torch.sigmoid(forgetgate)
cellgate = torch.tanh(cellgate)
outgate = torch.sigmoid(outgate)
cy = (forgetgate * cx) + (ingate * cellgate)
hy = outgate * torch.tanh(cy)
return (hy, cy)
class LSTMLayer(torch.jit.ScriptModule):
def __init__(self, input_size, hidden_size):
super().__init__()
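# Note: wires in the built-in nn.LSTMCell; the scripted LSTMCell above is an
# equivalent hand-written cell that this layer does not use.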
self.cell = nn.LSTMCell(input_size=input_size, hidden_size=hidden_size)
@torch.jit.script_method
def forward(self, input_data: torch.Tensor, state: tuple[torch.Tensor, torch.Tensor]) -> tuple[
torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
inputs = input_data.unbind(0)
outputs = torch.jit.annotate(list[torch.Tensor], [])
for i in range(len(inputs)):
state = self.cell(inputs[i], state)
outputs += [state[0]]
return torch.stack(outputs), state
class StackedLSTM(torch.jit.ScriptModule):
def __init__(self, input_size, hidden_size, num_layers):
super().__init__()
self.layers = nn.ModuleList(
[LSTMLayer(input_size=input_size, hidden_size=hidden_size)] + [
LSTMLayer(input_size=hidden_size, hidden_size=hidden_size) for _ in range(num_layers - 1)])
@torch.jit.script_method
def forward(self, input_data: torch.Tensor, states: list[tuple[torch.Tensor, torch.Tensor]]) -> tuple[
torch.Tensor, list[tuple[torch.Tensor, torch.Tensor]]]:
output_states = torch.jit.annotate(list[tuple[torch.Tensor, torch.Tensor]], [])
output = input_data
i = 0
for rnn_layer in self.layers:
state = states[i]
output, out_state = rnn_layer(output, state)
output_states += [out_state]
i += 1
return output, output_states
class StackedLSTMModel(torch.jit.ScriptModule):
NUM_LAYERS = 3
def __init__(self, input_size: int, hidden_size: int = -1, output_size: int = -1):
super().__init__()
hidden_size = hidden_size if hidden_size > 0 else input_size * 2
output_size = output_size if output_size > 0 else input_size
self.lstm = StackedLSTM(input_size=input_size, hidden_size=hidden_size, num_layers=self.NUM_LAYERS)
self.dense = nn.Linear(hidden_size, output_size)
@torch.jit.script_method
def forward(self, input_data: torch.Tensor, init_state: list[tuple[torch.Tensor, torch.Tensor]]) -> tuple[
torch.Tensor, list[tuple[torch.Tensor, torch.Tensor]]]:
output, state = self.lstm(input_data, init_state)
return self.dense(output), state
class LSTMCellModel(torch.jit.ScriptModule):
NUM_LAYERS = 3
def __init__(self, input_size: int, hidden_size: int = -1, output_size: int = -1):
super().__init__()
hidden_size = hidden_size if hidden_size > 0 else input_size * 2
output_size = output_size if output_size > 0 else input_size
self.hidden_size = hidden_size
self.layers = nn.ModuleList([
nn.LSTMCell(input_size=input_size, hidden_size=hidden_size),
nn.LSTMCell(input_size=hidden_size, hidden_size=hidden_size),
nn.LSTMCell(input_size=hidden_size, hidden_size=hidden_size)
])
self.dense = nn.Linear(hidden_size, output_size)
@torch.jit.script_method
def forward(self, input_data: torch.Tensor, init_states: list[tuple[torch.Tensor, torch.Tensor]]) -> tuple[
torch.Tensor, list[tuple[torch.Tensor, torch.Tensor]]]:
output_states = torch.jit.annotate(list[tuple[torch.Tensor, torch.Tensor]], [])
cell_inputs = input_data.unbind(0)
cell_id = 0
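# Unroll layer by layer: run each cell over the whole sequence, then feed its
# outputs to the next cell as inputs.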
for cell in self.layers:
cell_state = init_states[cell_id]
cell_outputs = torch.jit.annotate(list[torch.Tensor], [])
for i in range(len(cell_inputs)):
cell_state = cell(cell_inputs[i], cell_state)
cell_outputs += [cell_state[0]]
cell_inputs = cell_outputs
output_states += [cell_state]
cell_id += 1
return self.dense(torch.stack(cell_outputs)), output_states

1
src/torch_utils Submodule

@@ -0,0 +1 @@
Subproject commit 1bac46219b42fe41ba3568fdde3ca364b02e46e9