Initial commit
commit e41417c029
8 changed files with 802 additions and 0 deletions

.gitignore (vendored, Normal file, +3)
@@ -0,0 +1,3 @@
*.pyc

output

.gitmodules (vendored, Normal file, +3)
@@ -0,0 +1,3 @@
[submodule "src/torch_utils"]
    path = src/torch_utils
    url = git@gitlab.com:corentin-pro/torch_utils.git

modulo.py (Normal file, +209)
@@ -0,0 +1,209 @@
from argparse import ArgumentParser
from pathlib import Path
import math
import shutil
import time

import numpy as np
import torch
from torch import nn
from torch.utils.tensorboard import SummaryWriter

from src.torch_networks import LSTMModel, LSTMCellModel, StackedLSTMModel
from src.torch_utils.utils.batch_generator import BatchGenerator
from src.torch_utils.train import parameter_summary


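# Task: a sample is a pulse train of length `data_length`, with ones at
# positions start, start + mod, start + 2 * mod, ... and zeros elsewhere.
# The label says whether the pulse would fire again at index data_length,
# i.e. whether data_length % mod == start. Data is time-major: (length, batch, 1).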
def generate_data(batch_size: int, data_length: int) -> tuple[np.ndarray, np.ndarray]:
    modulos = np.random.uniform(3, data_length // 2 + 1, batch_size).astype(np.int32)
    data = np.zeros((data_length, batch_size, 1), dtype=np.float32)
    starts = []
    for mod in modulos:
        starts.append(int(np.random.uniform(0, mod)))
    for i in range(batch_size):
        # np.where(data[i] % modulos[i] == starts[i], [1.0], data[i])
        for j in range(starts[i], data_length, modulos[i]):
            data[j, i, 0] = 1.0
    label = []
    for i in range(batch_size):
        label.append(1 if len(data[:, i]) % modulos[i] == starts[i] else 0)
    return data, np.asarray(label, dtype=np.int64)


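# Same generation as above, but as a per-sample pipeline for BatchGenerator
# workers: samples are padded to MAX_LENGTH so every item in a batch has the
# same shape, and each worker re-seeds numpy once so workers do not share a seed.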
class DataGenerator:
    MAX_LENGTH = 1
    INITIALIZED = False

    @staticmethod
    def pipeline(sequence_length, _dummy_label):
        if not DataGenerator.INITIALIZED:
            np.random.seed(time.time_ns() % (2**32))
            DataGenerator.INITIALIZED = True
        data = np.zeros((DataGenerator.MAX_LENGTH, 1), dtype=np.float32)
        modulo = int(np.random.uniform(3, sequence_length // 2 + 1))
        start = int(np.random.uniform(0, modulo))
        for i in range(start, sequence_length, modulo):
            data[i, 0] = 1.0
        return data, np.asarray(1 if sequence_length % modulo == start else 0, dtype=np.int64)


def main():
    parser = ArgumentParser()
    parser.add_argument('--output', type=Path, default=Path('output', 'modulo'), help='Output dir')
    parser.add_argument('--batch', type=int, default=32, help='Batch size')
    parser.add_argument('--sequence', type=int, default=12, help='Max sequence length')
    parser.add_argument('--step', type=int, default=2000, help='Number of steps to train')
    parser.add_argument('--model', help='Model to train')
    arguments = parser.parse_args()

    output_dir: Path = arguments.output
    batch_size: int = arguments.batch
    sequence_size: int = arguments.sequence
    max_step: int = arguments.step
    model: str = arguments.model

    if not output_dir.exists():
        output_dir.mkdir(parents=True)
    if (output_dir / 'train').exists():
        shutil.rmtree(output_dir / 'train')
    writer_train = SummaryWriter(log_dir=str(output_dir / 'train'), flush_secs=20)

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    torch.backends.cudnn.benchmark = True
    if model == 'stack':
        network = StackedLSTMModel(1, 16, 2).to(device)
    elif model == 'cell':
        network = LSTMCellModel(1, 16, 2).to(device)
    else:
        network = LSTMModel(1, 16, 2).to(device)
    torch.save(network.state_dict(), output_dir / 'model_ini.pt')

    # Save parameters info
    with open(output_dir / 'parameters.csv', 'w') as param_file:
        param_summary = parameter_summary(network)
        names = [len(name) for name, _, _ in param_summary]
        shapes = [len(str(shape)) for _, shape, _ in param_summary]
        param_file.write(
            '\n'.join(
                [f'{name: <{max(names)}} {str(shape): <{max(shapes)}} {size}'
                 for name, shape, size in param_summary]))

    optimizer = torch.optim.Adam(network.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.995)
    criterion = nn.CrossEntropyLoss()

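    # One sequence length is drawn per training step and broadcast across the
    # batch, so every sample in a given batch shares the same length.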
    sequence_data = np.random.uniform(4, sequence_size + 1, max_step).astype(np.int32)
    sequence_data[0] = sequence_size
    sequence_data_reshaped = np.reshape(np.broadcast_to(
        sequence_data,
        (batch_size, max_step)).transpose((1, 0)), (batch_size * max_step))
    dummy_label = np.zeros((batch_size * max_step), dtype=np.uint8)
    DataGenerator.MAX_LENGTH = sequence_size
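    # The custom cell/stack models take an explicit list of per-layer (h, c)
    # states; nn.LSTM accepts None and builds zero states itself.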
    if model in ['cell', 'stack']:
        state = [(torch.zeros((batch_size, 16)).to(device),
                  torch.zeros((batch_size, 16)).to(device))] * network.NUM_LAYERS
    else:
        state = None
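    # Workers produce fixed-size (MAX_LENGTH) batches; each training step
    # slices them down to that step's sequence length before the forward pass.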
    with BatchGenerator(sequence_data_reshaped, dummy_label, batch_size=batch_size,
                        pipeline=DataGenerator.pipeline, num_workers=8, shuffle=False) as batch_generator:
        # data_np, _ = generate_data(batch_size, int(np.random.uniform(4, sequence_size + 1)))
        data_np = batch_generator.batch_data
        label_np = batch_generator.batch_label
        # writer_train.add_graph(network, (torch.from_numpy(data_np).to(device),))

        running_loss = 0.0
        running_accuracy = 0.0
        running_count = 0
        summary_period = max(max_step // 100, 1)
        np.set_printoptions(precision=2)
        try:
            start_time = time.time()
            while batch_generator.epoch == 0:
                # data_np, label_np = generate_data(batch_size, int(np.random.uniform(4, sequence_size + 1)))
                data = torch.from_numpy(
                    data_np.transpose((1, 0, 2))[:sequence_data[batch_generator.step]]).to(device)
                label = torch.from_numpy(label_np).to(device)

                optimizer.zero_grad(set_to_none=True)

                outputs, _states = network(data, state)
                loss = criterion(outputs[-1], label)
                running_loss += loss.item()
                outputs_np = outputs[-1].detach().cpu().numpy()
                running_accuracy += ((outputs_np[:, 1] > outputs_np[:, 0]).astype(np.int32) == label_np).astype(
                    np.float32).mean()
                running_count += 1
                if (batch_generator.step + 1) % summary_period == 0:
                    writer_train.add_scalar('metric/loss', running_loss / running_count,
                                            global_step=batch_generator.step)
                    writer_train.add_scalar('metric/error', 1 - (running_accuracy / running_count),
                                            global_step=batch_generator.step)
                    writer_train.add_scalar('optimizer/lr', scheduler.get_last_lr()[0],
                                            global_step=batch_generator.step)
                    scheduler.step()

                    speed = summary_period / (time.time() - start_time)
                    print(f'Step {batch_generator.step}, loss: {running_loss / running_count:.03e}'
                          f', acc: {running_accuracy / running_count:.03e}, speed: {speed:0.3f}step/s')
                    start_time = time.time()
                    running_loss = 0.0
                    running_accuracy = 0.0
                    running_count = 0
                loss.backward()
                optimizer.step()
                data_np, label_np = batch_generator.next_batch()
        except KeyboardInterrupt:
            print('\r ', end='\r')
        writer_train.close()

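    # Validation on freshly generated random batches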
    network.eval()
    running_accuracy = 0.0
    running_count = 0
    for _ in range(math.ceil(1000 / batch_size)):
        data_np, label_np = generate_data(batch_size, int(np.random.uniform(4, sequence_size + 1)))
        data = torch.from_numpy(data_np).to(device)
        label = torch.from_numpy(label_np).to(device)

        outputs, _states = network(data, state)
        outputs_np = outputs[-1].detach().cpu().numpy()
        running_accuracy += ((outputs_np[:, 1] > outputs_np[:, 0]).astype(np.int32) == label_np).astype(
            np.float32).mean()
        running_count += 1
    print(f'Validation accuracy: {running_accuracy / running_count:.03f}')

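    # Hand-written pulse trains; the label is 1 when the pulse would fire
    # again at the step just past the end of the sequence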
    test_data = [
        [[1.], [0.], [0.], [1.], [0.], [0.]],
        [[1.], [0.], [0.], [0.], [1.], [0.], [0.], [0.]],
        [[1.], [0.], [0.], [0.], [0.], [1.], [0.], [0.], [0.]],
        [[1.], [0.], [0.], [0.], [0.], [1.], [0.], [0.], [0.], [0.]],
        [[0.], [1.], [0.], [0.], [0.], [0.], [0.], [1.], [0.], [0.], [0.], [0.]],
        [[1.], [0.], [0.], [0.], [0.], [0.], [1.], [0.], [0.], [0.], [0.], [0.]],
        [[0.], [1.], [0.], [0.], [0.], [0.], [0.], [1.], [0.], [0.], [0.], [0.], [0.]],
        [[0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.]],
        [[1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.]],
        [[1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.],
         [0.], [1.], [0.], [0.], [1.], [0.], [0.]],
        [[0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.],
         [0.], [1.], [0.], [0.], [1.], [0.]],
    ]
    test_label = np.asarray([1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0], dtype=np.int32)
    running_accuracy = 0.0
    running_count = 0
    if model in ['cell', 'stack']:
        state = [(torch.zeros((1, 16)).to(device),
                  torch.zeros((1, 16)).to(device))] * network.NUM_LAYERS
    for data, label in zip(test_data, test_label):
        outputs, _states = network(
            torch.from_numpy(np.expand_dims(np.asarray(data, dtype=np.float32), 1)).to(device),
            state)
        outputs_np = outputs[-1].detach().cpu().numpy()
        output_correct = int(outputs_np[0, 1] > outputs_np[0, 0]) == label
        running_accuracy += 1.0 if output_correct else 0.0
        running_count += 1
        print(f'{len(data)} {np.asarray(data)[:, 0]}, label: {label}'
              f', output: {int(outputs_np[0, 1] > outputs_np[0, 0])}')
    print(f'Test accuracy: {running_accuracy / running_count:.03f}')


if __name__ == '__main__':
    main()

modulo_jax.py (Normal file, +191)
@@ -0,0 +1,191 @@
from argparse import ArgumentParser
from pathlib import Path
import math
import shutil
import time
import sys

import jax
from jax import lax
import jax.numpy as jnp
from jax.experimental import optimizers
import numpy as np
from torch.utils.tensorboard import SummaryWriter

from src.jax_networks import StackedLSTMModel
from src.torch_utils.utils.batch_generator import BatchGenerator


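# Same pulse-train task and data pipeline as modulo.py, implemented with JAX
# instead of torch.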
def generate_data(batch_size: int, data_length: int) -> tuple[np.ndarray, np.ndarray]:
    modulos = np.random.uniform(3, data_length // 2 + 1, batch_size).astype(np.int32)
    data = np.zeros((data_length, batch_size, 1), dtype=np.float32)
    starts = []
    for mod in modulos:
        starts.append(int(np.random.uniform(0, mod)))
    for i in range(batch_size):
        # np.where(data[i] % modulos[i] == starts[i], [1.0], data[i])
        for j in range(starts[i], data_length, modulos[i]):
            data[j, i, 0] = 1.0
    label = []
    for i in range(batch_size):
        label.append(1 if len(data[:, i]) % modulos[i] == starts[i] else 0)
    return data, np.asarray(label, dtype=np.int64)


class DataGenerator:
    MAX_LENGTH = 1
    INITIALIZED = False

    @staticmethod
    def pipeline(sequence_length, _dummy_label):
        if not DataGenerator.INITIALIZED:
            np.random.seed(time.time_ns() % (2**32))
            DataGenerator.INITIALIZED = True
        data = np.zeros((DataGenerator.MAX_LENGTH, 1), dtype=np.float32)
        modulo = int(np.random.uniform(3, sequence_length // 2 + 1))
        start = int(np.random.uniform(0, modulo))
        for i in range(start, sequence_length, modulo):
            data[i, 0] = 1.0
        return data, np.asarray(1 if sequence_length % modulo == start else 0, dtype=np.int64)


def main():
    parser = ArgumentParser()
    parser.add_argument('--output', type=Path, default=Path('output', 'modulo'), help='Output dir')
    parser.add_argument('--batch', type=int, default=32, help='Batch size')
    parser.add_argument('--sequence', type=int, default=12, help='Max sequence length')
    parser.add_argument('--step', type=int, default=2000, help='Number of steps to train')
    parser.add_argument('--model', help='Model to train')
    arguments = parser.parse_args()

    output_dir: Path = arguments.output
    batch_size: int = arguments.batch
    sequence_size: int = arguments.sequence
    max_step: int = arguments.step
    model: str = arguments.model

    if not output_dir.exists():
        output_dir.mkdir(parents=True)
    if (output_dir / 'train').exists():
        shutil.rmtree(output_dir / 'train')
    writer_train = SummaryWriter(log_dir=str(output_dir / 'train'), flush_secs=20)

    rng = jax.random.PRNGKey(0)
    if model == 'stack':
        network_init, network_fun = StackedLSTMModel(1, 16, 2)
        _, network_params = network_init(rng, 1)
    else:
        print('Model not implemented')
        sys.exit(1)

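    # optimizers.adam returns an (init, update, get_params) triple; the update
    # below is applied functionally to opt_state at every step.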
    @jax.jit
    def loss_fun(params, inputs, targets, states):
        preds, _ = network_fun(params, inputs, states)
        # Two-class cross-entropy over the logits (the torch script uses
        # nn.CrossEntropyLoss for the same task)
        shifted = preds - jnp.max(preds, axis=1, keepdims=True)
        log_probs = shifted - jnp.log(jnp.sum(jnp.exp(shifted), axis=1, keepdims=True))
        return -jnp.mean(log_probs[jnp.arange(targets.shape[0]), targets])

    def accuracy_fun(preds, targets):
        return jnp.mean((preds[:, 1] > preds[:, 0]).astype(jnp.int64) == targets)

    opt_init, opt_update, opt_params = optimizers.adam(1e-3)

    @jax.jit
    def update_fun(step, opt_state, inputs, targets, states):
        # Differentiate the loss w.r.t. the parameters held in opt_state and
        # apply the Adam update, returning the loss and the new state
        params = opt_params(opt_state)
        loss, grads = jax.value_and_grad(loss_fun)(params, inputs, targets, states)
        return loss, opt_update(step, grads, opt_state)

    opt_state = opt_init(network_params)

    sequence_data = np.random.uniform(4, sequence_size + 1, max_step).astype(np.int32)
    sequence_data[0] = sequence_size
    sequence_data_reshaped = np.reshape(np.broadcast_to(
        sequence_data,
        (batch_size, max_step)).transpose((1, 0)), (batch_size * max_step))
    dummy_label = np.zeros((batch_size * max_step), dtype=np.uint8)
    DataGenerator.MAX_LENGTH = sequence_size
    state = [(jnp.zeros((batch_size, 16)),
              jnp.zeros((batch_size, 16)))] * 3
    with BatchGenerator(sequence_data_reshaped, dummy_label, batch_size=batch_size,
                        pipeline=DataGenerator.pipeline, num_workers=8, shuffle=False) as batch_generator:
        data_np = batch_generator.batch_data
        label_np = batch_generator.batch_label

        running_loss = 0.0
        running_accuracy = 0.0
        running_count = 0
        summary_period = max(max_step // 100, 1)
        np.set_printoptions(precision=2)
        try:
            start_time = time.time()
            while batch_generator.epoch == 0:
                # data_np, label_np = generate_data(batch_size, int(np.random.uniform(4, sequence_size + 1)))
                data = jnp.asarray(data_np.transpose((1, 0, 2))[:sequence_data[batch_generator.step]])
                label = jnp.asarray(label_np)

                loss, opt_state = update_fun(batch_generator.global_step, opt_state, data, label, state)
                network_params = opt_params(opt_state)
                preds, _state = network_fun(network_params, data, state)

                running_loss += float(loss)
                running_accuracy += float(accuracy_fun(preds, label))
                running_count += 1
                if (batch_generator.step + 1) % summary_period == 0:
                    writer_train.add_scalar('metric/loss', running_loss / running_count,
                                            global_step=batch_generator.step)
                    writer_train.add_scalar('metric/error', 1 - (running_accuracy / running_count),
                                            global_step=batch_generator.step)

                    speed = summary_period / (time.time() - start_time)
                    print(f'Step {batch_generator.step}, loss: {running_loss / running_count:.03e}'
                          f', acc: {running_accuracy / running_count:.03e}, speed: {speed:0.3f}step/s')
                    start_time = time.time()
                    running_loss = 0.0
                    running_accuracy = 0.0
                    running_count = 0
                data_np, label_np = batch_generator.next_batch()
        except KeyboardInterrupt:
            print('\r ', end='\r')
        writer_train.close()

    running_accuracy = 0.0
    running_count = 0
    for _ in range(math.ceil(1000 / batch_size)):
        data_np, label_np = generate_data(batch_size, int(np.random.uniform(4, sequence_size + 1)))
        data = jnp.asarray(data_np)
        label = jnp.asarray(label_np)

        preds, _state = network_fun(network_params, data, state)
        running_accuracy += float(accuracy_fun(preds, label))
        running_count += 1
    print(f'Validation accuracy: {running_accuracy / running_count:.03f}')

    test_data = [
        [[1.], [0.], [0.], [1.], [0.], [0.]],
        [[1.], [0.], [0.], [0.], [1.], [0.], [0.], [0.]],
        [[1.], [0.], [0.], [0.], [0.], [1.], [0.], [0.], [0.]],
        [[1.], [0.], [0.], [0.], [0.], [1.], [0.], [0.], [0.], [0.]],
        [[0.], [1.], [0.], [0.], [0.], [0.], [0.], [1.], [0.], [0.], [0.], [0.]],
        [[1.], [0.], [0.], [0.], [0.], [0.], [1.], [0.], [0.], [0.], [0.], [0.]],
        [[0.], [1.], [0.], [0.], [0.], [0.], [0.], [1.], [0.], [0.], [0.], [0.], [0.]],
        [[0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.]],
        [[1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.]],
        [[1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.],
         [0.], [1.], [0.], [0.], [1.], [0.], [0.]],
        [[0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.],
         [0.], [1.], [0.], [0.], [1.], [0.]],
    ]
    test_label = np.asarray([1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0], dtype=np.int32)
    running_accuracy = 0.0
    running_count = 0
    state = [(jnp.zeros((1, 16)), jnp.zeros((1, 16)))] * 3
    for data, label in zip(test_data, test_label):
        outputs, _states = network_fun(
            network_params,
            jnp.asarray(np.expand_dims(np.asarray(data, dtype=np.float32), 1)),
            state)
        running_accuracy += float(accuracy_fun(outputs, label))
        running_count += 1
        print(f'{len(data)} {np.asarray(data)[:, 0]}, label: {label}'
              f', output: {int(outputs[0, 1] > outputs[0, 0])}')
    print(f'Test accuracy: {running_accuracy / running_count:.03f}')


if __name__ == '__main__':
    main()

recorder.py (Normal file, +155)
@@ -0,0 +1,155 @@
from argparse import ArgumentParser
from pathlib import Path
import shutil
import time

import numpy as np
import torch
from torch import nn
from torch.utils.tensorboard import SummaryWriter

from src.torch_networks import LSTMModel, StackedLSTMModel
from src.torch_utils.train import parameter_summary


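# Recall task: the network sees a one-hot sequence with an extra "recording"
# flag set to 1, then the same number of all-zero steps, and must replay the
# sequence during the zero steps.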
def generate_data(batch_size: int, sequence_length: int, dim: int) -> np.ndarray:
    return np.random.uniform(0, dim, (batch_size, sequence_length)).astype(np.int64)


def main():
    parser = ArgumentParser()
    parser.add_argument('--output', type=Path, default=Path('output', 'recorder'), help='Output dir')
    parser.add_argument('--batch', type=int, default=32, help='Batch size')
    parser.add_argument('--sequence', type=int, default=8, help='Max sequence length')
    parser.add_argument('--dimension', type=int, default=15, help='Input dimension')
    parser.add_argument('--step', type=int, default=20_000, help='Number of steps to train')
    parser.add_argument('--model', help='Model to train')
    arguments = parser.parse_args()

    output_dir: Path = arguments.output
    batch_size: int = arguments.batch
    sequence_size: int = arguments.sequence
    input_dim: int = arguments.dimension
    max_step: int = arguments.step
    model: str = arguments.model

    if not output_dir.exists():
        output_dir.mkdir(parents=True)
    if (output_dir / 'train').exists():
        shutil.rmtree(output_dir / 'train')
    writer_train = SummaryWriter(log_dir=str(output_dir / 'train'), flush_secs=20)

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    if model == 'cell':
        network = StackedLSTMModel(input_dim + 1).to(device)
    else:
        network = LSTMModel(input_dim + 1).to(device)

    # Save parameters info
    with open(output_dir / 'parameters.csv', 'w') as param_file:
        param_summary = parameter_summary(network)
        names = [len(name) for name, _, _ in param_summary]
        shapes = [len(str(shape)) for _, shape, _ in param_summary]
        param_file.write(
            '\n'.join(
                [f'{name: <{max(names)}} {str(shape): <{max(shapes)}} {size}'
                 for name, shape, size in param_summary]))

    optimizer = torch.optim.Adam(network.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.996)
    criterion = nn.CrossEntropyLoss()

    zero_data = torch.zeros((batch_size, sequence_size, input_dim + 1)).to(device)
    # writer_train.add_graph(network, (zero_data[:, :5],))

    if model == 'cell':
        state = [(torch.zeros((batch_size, (input_dim + 1) * 2)).to(device),
                  torch.zeros((batch_size, (input_dim + 1) * 2)).to(device))] * network.NUM_LAYERS
    else:
        state = None
    running_loss = 0.0
    running_accuracy = 0.0
    running_count = 0
    summary_period = max_step // 100
    np.set_printoptions(precision=2)
    try:
        start_time = time.time()
        for step in range(1, max_step + 1):
            label_np = generate_data(batch_size, int(np.random.uniform(4, sequence_size + 1)), input_dim)
            label = torch.from_numpy(label_np).to(device)
            data = nn.functional.one_hot(label, input_dim + 1).float()
            data[:, :, input_dim] = 1.0

            optimizer.zero_grad()

            outputs, _state = network(torch.cat([data, zero_data[:, :label_np.shape[1]]], dim=1).transpose(1, 0), state)
            outputs = outputs[label_np.shape[1]:, :, :-1]
            # state = (state[0].detach(), state[1].detach())
            # state = (state[0][:, :1].detach().expand(state[0].shape[0], batch_size, *state[0].shape[2:]),
            #          state[1][:, :1].detach().expand(state[1].shape[0], batch_size, *state[1].shape[2:]))

            data[:, :, input_dim] = 0.0
            loss = criterion(outputs[:, 0], label[0])
            for i in range(1, batch_size):
                loss += criterion(outputs[:, i], label[i])
            loss /= batch_size
            running_loss += loss.item()
            running_accuracy += (
                torch.sum((torch.argmax(outputs, 2).transpose(1, 0) == label).long(),
                          1) == label.size(1)).float().mean().item()
            running_count += 1
            if step % summary_period == 0:
                writer_train.add_scalar('metric/loss', running_loss / running_count, global_step=step)
                writer_train.add_scalar('metric/error', 1 - (running_accuracy / running_count), global_step=step)
                writer_train.add_scalar('optimizer/lr', scheduler.get_last_lr()[0], global_step=step)
                scheduler.step()

                speed = summary_period / (time.time() - start_time)
                print(f'Step {step}, loss: {running_loss / running_count:.03e}'
                      f', acc: {running_accuracy / running_count:.03e}, speed: {speed:0.3f}step/s')
                start_time = time.time()
                running_loss = 0.0
                running_accuracy = 0.0
                running_count = 0
            loss.backward()
            optimizer.step()
    except KeyboardInterrupt:
        print('\r ', end='\r')
    writer_train.close()

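    # Probe sequences: constant, repeated, and random patterns, evaluated one
    # at a time (batch of 1)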
    network.eval()
    test_label = [
        np.asarray([[0, 0, 0, 0]], dtype=np.int64),
        np.asarray([[2, 2, 2, 2]], dtype=np.int64),
        np.asarray([[8, 1, 10, 5, 6, 13]], dtype=np.int64),
        generate_data(1, 4, input_dim),
        np.asarray([[0, 0, 0, 0, 0, 0]], dtype=np.int64),
        np.asarray([[5, 5, 5, 5, 5, 5]], dtype=np.int64),
        np.asarray([[11, 0, 2, 8, 5, 1]], dtype=np.int64),
        generate_data(1, sequence_size, input_dim)
    ]
    zero_data = torch.zeros((1, sequence_size, input_dim + 1)).to(device)
    if model == 'cell':
        state = [(torch.zeros((1, (input_dim + 1) * 2)).to(device),
                  torch.zeros((1, (input_dim + 1) * 2)).to(device))] * 3
    running_accuracy = 0.0
    running_count = 0
    for label_np in test_label:
        label = torch.from_numpy(label_np).to(device)
        data = nn.functional.one_hot(label, input_dim + 1).float()
        data[:, :, input_dim] = 1.0

        outputs, _state = network(torch.cat([data, zero_data[:, :label_np.shape[1]]], dim=1).transpose(1, 0), state)
        outputs = outputs[label_np.shape[1]:, :, :-1]
        # state = (state[0].detach(), state[1].detach())

        running_accuracy += (
            torch.sum(
                (torch.argmax(outputs, 2).transpose(1, 0) == label).long(), 1) == label.size(1)).float().mean().item()
        running_count += 1
        print(f'{len(label_np)} label: {label_np}, output: {torch.argmax(outputs, 2)[:, 0].detach().cpu().numpy()}')
    print(f'Test accuracy: {running_accuracy / running_count:.03f}')


if __name__ == '__main__':
    main()

src/jax_networks.py (Normal file, +104)
@@ -0,0 +1,104 @@
import jax
import jax.numpy as jnp
from jax import lax
from jax import nn


@jax.jit
def sigmoid(x):
    return 1 / (1 + lax.exp(-x))


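# stax-style layers: each constructor returns an (init_fun, apply_fun) pair,
# where init_fun builds the parameter pytree and apply_fun is a pure,
# jit-compiled function of (params, inputs, states).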
def LSTMCell(hidden_size, w_init=nn.initializers.glorot_normal(), b_init=nn.initializers.normal()):
    def init_fun(rng, input_size):
        k1, k2, k3 = jax.random.split(rng, 3)
        weight_ih = w_init(k1, (input_size, 4 * hidden_size))
        weight_hh = w_init(k2, (hidden_size, 4 * hidden_size))
        bias = b_init(k3, (4 * hidden_size,))
        return hidden_size, (weight_ih, weight_hh, bias)

    @jax.jit
    def apply_fun(params, inputs, states):
        (hx, cx) = states
        weight_ih, weight_hh, bias = params
        gates = lax.dot(inputs, weight_ih) + lax.dot(hx, weight_hh) + bias

        # Split the fused projection into the input, forget, cell and output
        # gates, in units of hidden_size
        in_gate = sigmoid(gates[:, :hidden_size])
        forget_gate = sigmoid(gates[:, hidden_size:2 * hidden_size])
        cell_gate = lax.tanh(gates[:, 2 * hidden_size:3 * hidden_size])
        out_gate = sigmoid(gates[:, 3 * hidden_size:])

        cy = lax.mul(forget_gate, cx) + lax.mul(in_gate, cell_gate)
        hy = lax.mul(out_gate, lax.tanh(cy))

        return (hy, cy)

    return init_fun, apply_fun


def LSTMLayer(hidden_size, w_init=nn.initializers.glorot_normal(), b_init=nn.initializers.normal()):
    cell_init, cell_fun = LSTMCell(hidden_size, w_init=w_init, b_init=b_init)

    @jax.jit
    def apply_fun(params, inputs, states):
        # Unroll the cell over the leading (time) axis of `inputs`
        output = []
        for input_data in inputs:
            states = cell_fun(params, input_data, states)
            output.append(states[0])
        return jnp.stack(output), states

    return cell_init, apply_fun


def StackedLSTM(hidden_size, num_layer, w_init=nn.initializers.glorot_normal(), b_init=nn.initializers.normal()):
    layer_inits = []
    layer_funs = []
    for _ in range(num_layer):
        layer_init, layer_fun = LSTMLayer(hidden_size, w_init=w_init, b_init=b_init)
        layer_inits.append(layer_init)
        layer_funs.append(layer_fun)
    del layer_init
    del layer_fun

    def init_fun(rng, input_size):
        params = []
        output_size = input_size
        for init in layer_inits:
            # Give each layer its own RNG key so the layers do not start with
            # identical weights
            rng, layer_rng = jax.random.split(rng)
            output_size, layer_params = init(layer_rng, output_size)
            params.append(layer_params)
        return output_size, tuple(params)

    @jax.jit
    def apply_fun(params, inputs, states):
        output = inputs
        output_states = []
        for layer_id in range(num_layer):
            output, out_state = layer_funs[layer_id](params[layer_id], output, states[layer_id])
            output_states.append(out_state)
        return output, output_states

    return init_fun, apply_fun


def StackedLSTMModel(input_size, hidden_size=-1, output_size=-1, num_layer=3,
                     w_init=nn.initializers.glorot_normal(), b_init=nn.initializers.normal()):
    hidden_size = hidden_size if hidden_size > 0 else input_size * 2
    output_size = output_size if output_size > 0 else input_size
    stacked_init, stacked_fun = StackedLSTM(hidden_size, num_layer, w_init=w_init, b_init=b_init)

    def init_fun(rng, input_size):
        stacked_output_size, stacked_params = stacked_init(rng, input_size)

        k1, k2 = jax.random.split(rng)
        weight = w_init(k1, (stacked_output_size, output_size))
        bias = b_init(k2, (output_size,))

        return output_size, ((weight, bias), stacked_params)

    @jax.jit
    def apply_fun(params, inputs, states):
        (weight, bias), stacked_params = params
        output, new_states = stacked_fun(stacked_params, inputs, states)
        # Dense head applied to the last timestep only, unlike the torch
        # models, which return dense outputs for every timestep
        return lax.dot(output[-1], weight) + bias, new_states

    return init_fun, apply_fun

src/torch_networks.py (Normal file, +136)
@@ -0,0 +1,136 @@
import torch
from torch import nn


class LSTMModel(nn.Module):
    NUM_LAYERS = 3

    def __init__(self, input_size: int, hidden_size: int = -1, output_size: int = -1):
        super().__init__()
        hidden_size = hidden_size if hidden_size > 0 else input_size * 2
        output_size = output_size if output_size > 0 else input_size

        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=self.NUM_LAYERS)
        self.dense = nn.Linear(hidden_size, output_size)

    def forward(self, input_data: torch.Tensor, init_state=None) -> tuple[
            torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        output, state = self.lstm(input_data, init_state)
        return self.dense(output), state


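# TorchScript building blocks: LSTMCell spells out the fused gate math,
# LSTMLayer unrolls an nn.LSTMCell over time, StackedLSTM chains layers.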
class LSTMCell(torch.jit.ScriptModule):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.weight_ih = nn.Parameter(torch.randn(input_size, 4 * hidden_size))
        self.weight_hh = nn.Parameter(torch.randn(hidden_size, 4 * hidden_size))
        self.bias = nn.Parameter(torch.randn(4 * hidden_size))

    @torch.jit.script_method
    def forward(self, input_data: torch.Tensor, state: tuple[torch.Tensor, torch.Tensor]) -> tuple[
            torch.Tensor, torch.Tensor]:
        hx, cx = state
        gates = (
            torch.mm(input_data, self.weight_ih) + torch.mm(hx, self.weight_hh) + self.bias)
        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

        ingate = torch.sigmoid(ingate)
        forgetgate = torch.sigmoid(forgetgate)
        cellgate = torch.tanh(cellgate)
        outgate = torch.sigmoid(outgate)

        cy = (forgetgate * cx) + (ingate * cellgate)
        hy = outgate * torch.tanh(cy)

        return (hy, cy)


class LSTMLayer(torch.jit.ScriptModule):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.cell = nn.LSTMCell(input_size=input_size, hidden_size=hidden_size)

    @torch.jit.script_method
    def forward(self, input_data: torch.Tensor, state: tuple[torch.Tensor, torch.Tensor]) -> tuple[
            torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        inputs = input_data.unbind(0)
        outputs = torch.jit.annotate(list[torch.Tensor], [])
        for i in range(len(inputs)):
            state = self.cell(inputs[i], state)
            outputs += [state[0]]
        return torch.stack(outputs), state


class StackedLSTM(torch.jit.ScriptModule):
    def __init__(self, input_size, hidden_size, num_layers):
        super().__init__()
        self.layers = nn.ModuleList(
            [LSTMLayer(input_size=input_size, hidden_size=hidden_size)] + [
                LSTMLayer(input_size=hidden_size, hidden_size=hidden_size) for _ in range(num_layers - 1)])

    @torch.jit.script_method
    def forward(self, input_data: torch.Tensor, states: list[tuple[torch.Tensor, torch.Tensor]]) -> tuple[
            torch.Tensor, list[tuple[torch.Tensor, torch.Tensor]]]:
        output_states = torch.jit.annotate(list[tuple[torch.Tensor, torch.Tensor]], [])
        output = input_data
        i = 0
        for rnn_layer in self.layers:
            state = states[i]
            output, out_state = rnn_layer(output, state)
            output_states += [out_state]
            i += 1
        return output, output_states


class StackedLSTMModel(torch.jit.ScriptModule):
    NUM_LAYERS = 3

    def __init__(self, input_size: int, hidden_size: int = -1, output_size: int = -1):
        super().__init__()
        hidden_size = hidden_size if hidden_size > 0 else input_size * 2
        output_size = output_size if output_size > 0 else input_size

        self.lstm = StackedLSTM(input_size=input_size, hidden_size=hidden_size, num_layers=self.NUM_LAYERS)
        self.dense = nn.Linear(hidden_size, output_size)

    @torch.jit.script_method
    def forward(self, input_data: torch.Tensor, init_state: list[tuple[torch.Tensor, torch.Tensor]]) -> tuple[
            torch.Tensor, list[tuple[torch.Tensor, torch.Tensor]]]:
        output, state = self.lstm(input_data, init_state)
        return self.dense(output), state


class LSTMCellModel(torch.jit.ScriptModule):
    NUM_LAYERS = 3

    def __init__(self, input_size: int, hidden_size: int = -1, output_size: int = -1):
        super().__init__()
        hidden_size = hidden_size if hidden_size > 0 else input_size * 2
        output_size = output_size if output_size > 0 else input_size

        self.hidden_size = hidden_size
        self.layers = nn.ModuleList([
            nn.LSTMCell(input_size=input_size, hidden_size=hidden_size),
            nn.LSTMCell(input_size=hidden_size, hidden_size=hidden_size),
            nn.LSTMCell(input_size=hidden_size, hidden_size=hidden_size)
        ])
        self.dense = nn.Linear(hidden_size, output_size)

    @torch.jit.script_method
    def forward(self, input_data: torch.Tensor, init_states: list[tuple[torch.Tensor, torch.Tensor]]) -> tuple[
            torch.Tensor, list[tuple[torch.Tensor, torch.Tensor]]]:
        output_states = torch.jit.annotate(list[tuple[torch.Tensor, torch.Tensor]], [])
        # Layer-major unroll: run each cell over the whole sequence, then feed
        # its outputs to the next cell
        cell_inputs = input_data.unbind(0)
        cell_id = 0
        for cell in self.layers:
            cell_state = init_states[cell_id]
            cell_outputs = torch.jit.annotate(list[torch.Tensor], [])
            for i in range(len(cell_inputs)):
                cell_state = cell(cell_inputs[i], cell_state)
                cell_outputs += [cell_state[0]]
            cell_inputs = cell_outputs
            output_states += [cell_state]
            cell_id += 1
        return self.dense(torch.stack(cell_outputs)), output_states

src/torch_utils (Submodule, +1)
@@ -0,0 +1 @@
Subproject commit 1bac46219b42fe41ba3568fdde3ca364b02e46e9