commit e41417c0295bf68e296f914e606505e7a2a6d5c5
Author: Corentin
Date:   Thu Aug 26 16:48:09 2021 +0900

    Initial commit

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..b1d5e39
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,3 @@
+*.pyc
+
+output
\ No newline at end of file
diff --git a/.gitmodules b/.gitmodules
new file mode 100644
index 0000000..33fe1b0
--- /dev/null
+++ b/.gitmodules
@@ -0,0 +1,3 @@
+[submodule "src/torch_utils"]
+	path = src/torch_utils
+	url = git@gitlab.com:corentin-pro/torch_utils.git
diff --git a/modulo.py b/modulo.py
new file mode 100644
index 0000000..51ce082
--- /dev/null
+++ b/modulo.py
@@ -0,0 +1,209 @@
+from argparse import ArgumentParser
+from pathlib import Path
+import math
+import shutil
+import time
+
+import numpy as np
+import torch
+from torch import nn
+from torch.utils.tensorboard import SummaryWriter
+
+from src.torch_networks import LSTMModel, LSTMCellModel, StackedLSTMModel
+from src.torch_utils.utils.batch_generator import BatchGenerator
+from src.torch_utils.train import parameter_summary
+
+
+def generate_data(batch_size: int, data_length: int) -> tuple[np.ndarray, np.ndarray]:
+    modulos = np.random.uniform(3, data_length // 2 + 1, batch_size).astype(np.int32)
+    data = np.zeros((data_length, batch_size, 1), dtype=np.float32)
+    starts = []
+    for mod in modulos:
+        starts.append(int(np.random.uniform(0, mod)))
+    for i in range(batch_size):
+        # Set a pulse every modulos[i] steps, beginning at starts[i].
+        for j in range(starts[i], data_length, modulos[i]):
+            data[j, i, 0] = 1.0
+    label = []
+    for i in range(batch_size):
+        # Label is 1 when the step right after the sequence would be a pulse
+        # (len(data[:, i]) in the original is always data_length).
+        label.append(1 if data_length % modulos[i] == starts[i] else 0)
+    return data, np.asarray(label, dtype=np.int64)
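+# Illustrative walk-through (editor's example, not part of the original code):
+# with data_length=6, modulo=3 and start=1 the sequence is
+#     [0, 1, 0, 0, 1, 0]
+# and the label asks whether the next step (index 6) would be a pulse:
+# 6 % 3 == 0 != 1, so the label is 0.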
+
+
+class DataGenerator:
+    MAX_LENGTH = 1
+    INITIALIZED = False
+
+    @staticmethod
+    def pipeline(sequence_length, _dummy_label):
+        if not DataGenerator.INITIALIZED:
+            np.random.seed(time.time_ns() % (2**32))
+            DataGenerator.INITIALIZED = True
+        data = np.zeros((DataGenerator.MAX_LENGTH, 1), dtype=np.float32)
+        modulo = int(np.random.uniform(3, sequence_length // 2 + 1))
+        start = int(np.random.uniform(0, modulo))
+        for i in range(start, sequence_length, modulo):
+            data[i, 0] = 1.0
+        return data, np.asarray(1 if sequence_length % modulo == start else 0, dtype=np.int64)
+
+
+def main():
+    parser = ArgumentParser()
+    parser.add_argument('--output', type=Path, default=Path('output', 'modulo'), help='Output dir')
+    parser.add_argument('--batch', type=int, default=32, help='Batch size')
+    parser.add_argument('--sequence', type=int, default=12, help='Max sequence length')
+    parser.add_argument('--step', type=int, default=2000, help='Number of steps to train')
+    parser.add_argument('--model', help='Model to train')
+    arguments = parser.parse_args()
+
+    output_dir: Path = arguments.output
+    batch_size: int = arguments.batch
+    sequence_size: int = arguments.sequence
+    max_step: int = arguments.step
+    model: str = arguments.model
+
+    if not output_dir.exists():
+        output_dir.mkdir(parents=True)
+    if (output_dir / 'train').exists():
+        shutil.rmtree(output_dir / 'train')
+    writer_train = SummaryWriter(log_dir=str(output_dir / 'train'), flush_secs=20)
+
+    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+    torch.backends.cudnn.benchmark = True
+    if model == 'stack':
+        network = StackedLSTMModel(1, 16, 2).to(device)
+    elif model == 'cell':
+        network = LSTMCellModel(1, 16, 2).to(device)
+    else:
+        network = LSTMModel(1, 16, 2).to(device)
+    torch.save(network.state_dict(), output_dir / 'model_ini.pt')
+
+    # Save parameters info
+    with open(output_dir / 'parameters.csv', 'w') as param_file:
+        param_summary = parameter_summary(network)
+        names = [len(name) for name, _, _ in param_summary]
+        shapes = [len(str(shape)) for _, shape, _ in param_summary]
+        param_file.write(
+            '\n'.join(
+                [f'{name: <{max(names)}} {str(shape): <{max(shapes)}} {size}'
+                 for name, shape, size in param_summary]))
+
+    optimizer = torch.optim.Adam(network.parameters(), lr=1e-3)
+    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.995)
+    criterion = nn.CrossEntropyLoss()
+
+    sequence_data = np.random.uniform(4, sequence_size + 1, max_step).astype(np.int32)
+    sequence_data[0] = sequence_size
+    sequence_data_reshaped = np.reshape(np.broadcast_to(
+        sequence_data,
+        (batch_size, max_step)).transpose((1, 0)), (batch_size * max_step))
+    dummy_label = np.zeros((batch_size * max_step), dtype=np.uint8)
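+    # Editor's gloss: the broadcast/transpose above repeats each of the
+    # max_step sequence lengths batch_size times, so every sample drawn for
+    # one batch shares a single length and can be stacked without padding,
+    # e.g. sequence_data [5, 7] with batch_size 2 -> [5, 5, 7, 7].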
+    DataGenerator.MAX_LENGTH = sequence_size
+    if model in ['cell', 'stack']:
+        state = [(torch.zeros((batch_size, 16)).to(device),
+                  torch.zeros((batch_size, 16)).to(device))] * network.NUM_LAYERS
+    else:
+        state = None
+    with BatchGenerator(sequence_data_reshaped, dummy_label, batch_size=batch_size,
+                        pipeline=DataGenerator.pipeline, num_workers=8, shuffle=False) as batch_generator:
+        # data_np, _ = generate_data(batch_size, int(np.random.uniform(4, sequence_size + 1)))
+        data_np = batch_generator.batch_data
+        label_np = batch_generator.batch_label
+        # writer_train.add_graph(network, (torch.from_numpy(data_np).to(device),))
+
+        running_loss = 0.0
+        running_accuracy = 0.0
+        running_count = 0
+        summary_period = max(max_step // 100, 1)
+        np.set_printoptions(precision=2)
+        try:
+            start_time = time.time()
+            while batch_generator.epoch == 0:
+                # data_np, label_np = generate_data(batch_size, int(np.random.uniform(4, sequence_size + 1)))
+                data = torch.from_numpy(
+                    data_np.transpose((1, 0, 2))[:sequence_data[batch_generator.step]]).to(device)
+                label = torch.from_numpy(label_np).to(device)
+
+                optimizer.zero_grad(set_to_none=True)
+
+                outputs, _states = network(data, state)
+                loss = criterion(outputs[-1], label)
+                running_loss += loss.item()
+                outputs_np = outputs[-1].detach().cpu().numpy()
+                running_accuracy += ((outputs_np[:, 1] > outputs_np[:, 0]).astype(np.int32) == label_np).astype(
+                    np.float32).mean()
+                running_count += 1
+                if (batch_generator.step + 1) % summary_period == 0:
+                    writer_train.add_scalar('metric/loss', running_loss / running_count,
+                                            global_step=batch_generator.step)
+                    writer_train.add_scalar('metric/error', 1 - (running_accuracy / running_count),
+                                            global_step=batch_generator.step)
+                    writer_train.add_scalar('optimizer/lr', scheduler.get_last_lr()[0],
+                                            global_step=batch_generator.step)
+                    # Note: the LR schedule is stepped once per summary period, not every step.
+                    scheduler.step()
+
+                    speed = summary_period / (time.time() - start_time)
+                    print(f'Step {batch_generator.step}, loss: {running_loss / running_count:.03e}'
+                          f', acc: {running_accuracy / running_count:.03e}, speed: {speed:0.3f}step/s')
+                    start_time = time.time()
+                    running_loss = 0.0
+                    running_accuracy = 0.0
+                    running_count = 0
+                loss.backward()
+                optimizer.step()
+                data_np, label_np = batch_generator.next_batch()
+        except KeyboardInterrupt:
+            print('\r                                        ', end='\r')
+    writer_train.close()
+
+    network.eval()
+    running_accuracy = 0.0
+    running_count = 0
+    for _ in range(math.ceil(1000 / batch_size)):
+        data_np, label_np = generate_data(batch_size, int(np.random.uniform(4, sequence_size + 1)))
+        data = torch.from_numpy(data_np).to(device)
+        label = torch.from_numpy(label_np).to(device)
+
+        outputs, _states = network(data, state)
+        outputs_np = outputs[-1].detach().cpu().numpy()
+        running_accuracy += ((outputs_np[:, 1] > outputs_np[:, 0]).astype(np.int32) == label_np).astype(
+            np.float32).mean()
+        running_count += 1
+    print(f'Validation accuracy: {running_accuracy / running_count:.03f}')
+
+    test_data = [
+        [[1.], [0.], [0.], [1.], [0.], [0.]],
+        [[1.], [0.], [0.], [0.], [1.], [0.], [0.], [0.]],
+        [[1.], [0.], [0.], [0.], [0.], [1.], [0.], [0.], [0.]],
+        [[1.], [0.], [0.], [0.], [0.], [1.], [0.], [0.], [0.], [0.]],
+        [[0.], [1.], [0.], [0.], [0.], [0.], [0.], [1.], [0.], [0.], [0.], [0.]],
+        [[1.], [0.], [0.], [0.], [0.], [0.], [1.], [0.], [0.], [0.], [0.], [0.]],
+        [[0.], [1.], [0.], [0.], [0.], [0.], [0.], [1.], [0.], [0.], [0.], [0.], [0.]],
+        [[0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.]],
+        [[1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.]],
+        [[1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.],
+         [0.], [1.], [0.], [0.], [1.], [0.], [0.]],
+        [[0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.],
+         [0.], [1.], [0.], [0.], [1.], [0.]],
+    ]
+    test_label = np.asarray([1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0], dtype=np.int32)
+    running_accuracy = 0.0
+    running_count = 0
+    if model in ['cell', 'stack']:
+        state = [(torch.zeros((1, 16)).to(device),
+                  torch.zeros((1, 16)).to(device))] * network.NUM_LAYERS
+    for data, label in zip(test_data, test_label):
+        outputs, _states = network(
+            torch.from_numpy(np.expand_dims(np.asarray(data, dtype=np.float32), 1)).to(device),
+            state)
+        outputs_np = outputs[-1].detach().cpu().numpy()
+        output_correct = int(outputs_np[0, 1] > outputs_np[0, 0]) == label
+        running_accuracy += 1.0 if output_correct else 0.0
+        running_count += 1
+        print(f'{len(data)} {np.asarray(data)[:, 0]}, label: {label}'
+              f', output: {int(outputs_np[0, 1] > outputs_np[0, 0])}')
+    print(f'Test accuracy: {running_accuracy / running_count:.03f}')
+
+
+if __name__ == '__main__':
+    main()
diff --git a/modulo_jax.py b/modulo_jax.py
new file mode 100644
index 0000000..c5c1b6e
--- /dev/null
+++ b/modulo_jax.py
@@ -0,0 +1,191 @@
+from argparse import ArgumentParser
+from pathlib import Path
+import math
+import shutil
+import time
+import sys
+
+import jax
+from jax import lax
+from jax import nn  # needed for nn.log_softmax in the loss below
+import jax.numpy as jnp
+from jax.experimental import optimizers
+import numpy as np
+from torch.utils.tensorboard import SummaryWriter
+
+from src.jax_networks import StackedLSTMModel
+from src.torch_utils.utils.batch_generator import BatchGenerator
+
+
+def generate_data(batch_size: int, data_length: int) -> tuple[np.ndarray, np.ndarray]:
+    modulos = np.random.uniform(3, data_length // 2 + 1, batch_size).astype(np.int32)
+    data = np.zeros((data_length, batch_size, 1), dtype=np.float32)
+    starts = []
+    for mod in modulos:
+        starts.append(int(np.random.uniform(0, mod)))
+    for i in range(batch_size):
+        for j in range(starts[i], data_length, modulos[i]):
+            data[j, i, 0] = 1.0
+    label = []
+    for i in range(batch_size):
+        label.append(1 if data_length % modulos[i] == starts[i] else 0)
+    return data, np.asarray(label, dtype=np.int64)
+
+
+class DataGenerator:
+    MAX_LENGTH = 1
+    INITIALIZED = False
+
+    @staticmethod
+    def pipeline(sequence_length, _dummy_label):
+        if not DataGenerator.INITIALIZED:
+            np.random.seed(time.time_ns() % (2**32))
+            DataGenerator.INITIALIZED = True
+        data = np.zeros((DataGenerator.MAX_LENGTH, 1), dtype=np.float32)
+        modulo = int(np.random.uniform(3, sequence_length // 2 + 1))
+        start = int(np.random.uniform(0, modulo))
+        for i in range(start, sequence_length, modulo):
+            data[i, 0] = 1.0
+        return data, np.asarray(1 if sequence_length % modulo == start else 0, dtype=np.int64)
+
+
+def main():
+    parser = ArgumentParser()
+    parser.add_argument('--output', type=Path, default=Path('output', 'modulo'), help='Output dir')
+    parser.add_argument('--batch', type=int, default=32, help='Batch size')
+    parser.add_argument('--sequence', type=int, default=12, help='Max sequence length')
+    parser.add_argument('--step', type=int, default=2000, help='Number of steps to train')
+    parser.add_argument('--model', help='Model to train')
+    arguments = parser.parse_args()
+
+    output_dir: Path = arguments.output
+    batch_size: int = arguments.batch
+    sequence_size: int = arguments.sequence
+    max_step: int = arguments.step
+    model: str = arguments.model
+
+    if not output_dir.exists():
+        output_dir.mkdir(parents=True)
+    if (output_dir / 'train').exists():
+        shutil.rmtree(output_dir / 'train')
+    writer_train = SummaryWriter(log_dir=str(output_dir / 'train'), flush_secs=20)
+
+    rng = jax.random.PRNGKey(0)
+    if model == 'stack':
+        network_init, network_fun = StackedLSTMModel(1, 16, 2)
+        _, network_params = network_init(rng, 1)
+    else:
+        print('Model not implemented')
+        sys.exit(1)
+
+    def loss_fun(params, data, targets, states):
+        # Scalar softmax cross-entropy over the two output logits. (The
+        # original element-wise expression applied lax.log to raw logits and
+        # returned a vector, so jax.grad could not be taken through it.)
+        preds, _states = network_fun(params, data, states)
+        log_probs = nn.log_softmax(preds)
+        return -jnp.mean(log_probs[jnp.arange(targets.shape[0]), targets]), preds
+
+    def accuracy_fun(preds, targets):
+        return jnp.mean((preds[:, 1] > preds[:, 0]).astype(jnp.int64) == targets)
+
+    opt_init, opt_update, opt_params = optimizers.adam(1e-3)
+
+    @jax.jit
+    def update_fun(step, opt_state, data, targets, states):
+        # Differentiate w.r.t. the parameters (not the predictions, as the
+        # original did) and return the new optimizer state instead of
+        # discarding it.
+        (loss, preds), grads = jax.value_and_grad(loss_fun, has_aux=True)(
+            opt_params(opt_state), data, targets, states)
+        return loss, preds, opt_update(step, grads, opt_state)
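+    # Editor's note: jax.experimental.optimizers exposes an
+    # (init, update, get_params) triple rather than a stateful object:
+    #     opt_state = opt_init(params)
+    #     opt_state = opt_update(step, grads, opt_state)   # returns a NEW state
+    #     params = opt_params(opt_state)
+    # The state therefore has to be threaded through the training loop
+    # explicitly; the original discarded the returned state, so no update
+    # ever reached the parameters.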
+
+    opt_state = opt_init(network_params)
+
+    sequence_data = np.random.uniform(4, sequence_size + 1, max_step).astype(np.int32)
+    sequence_data[0] = sequence_size
+    sequence_data_reshaped = np.reshape(np.broadcast_to(
+        sequence_data,
+        (batch_size, max_step)).transpose((1, 0)), (batch_size * max_step))
+    dummy_label = np.zeros((batch_size * max_step), dtype=np.uint8)
+    DataGenerator.MAX_LENGTH = sequence_size
+    state = [(jnp.zeros((batch_size, 16)),
+              jnp.zeros((batch_size, 16)))] * 3
+    with BatchGenerator(sequence_data_reshaped, dummy_label, batch_size=batch_size,
+                        pipeline=DataGenerator.pipeline, num_workers=8, shuffle=False) as batch_generator:
+        data_np = batch_generator.batch_data
+        label_np = batch_generator.batch_label
+
+        running_loss = 0.0
+        running_accuracy = 0.0
+        running_count = 0
+        summary_period = max(max_step // 100, 1)
+        np.set_printoptions(precision=2)
+        try:
+            start_time = time.time()
+            while batch_generator.epoch == 0:
+                # data_np, label_np = generate_data(batch_size, int(np.random.uniform(4, sequence_size + 1)))
+                data = jnp.asarray(data_np.transpose((1, 0, 2))[:sequence_data[batch_generator.step]])
+                label = jnp.asarray(label_np)
+
+                loss, preds, opt_state = update_fun(
+                    batch_generator.global_step, opt_state, data, label, state)
+
+                running_loss += loss
+                running_accuracy += accuracy_fun(preds, label)
+                running_count += 1
+                if (batch_generator.step + 1) % summary_period == 0:
+                    writer_train.add_scalar('metric/loss', running_loss / running_count,
+                                            global_step=batch_generator.step)
+                    writer_train.add_scalar('metric/error', 1 - (running_accuracy / running_count),
+                                            global_step=batch_generator.step)
+
+                    speed = summary_period / (time.time() - start_time)
+                    print(f'Step {batch_generator.step}, loss: {running_loss / running_count:.03e}'
+                          f', acc: {running_accuracy / running_count:.03e}, speed: {speed:0.3f}step/s')
+                    start_time = time.time()
+                    running_loss = 0.0
+                    running_accuracy = 0.0
+                    running_count = 0
+                data_np, label_np = batch_generator.next_batch()
+        except KeyboardInterrupt:
+            print('\r                                        ', end='\r')
+    writer_train.close()
+
+    # Evaluate with the trained parameters (the original kept using the
+    # initial network_params because the optimizer state was never read back).
+    network_params = opt_params(opt_state)
+    running_accuracy = 0.0
+    running_count = 0
+    for _ in range(math.ceil(1000 / batch_size)):
+        data_np, label_np = generate_data(batch_size, int(np.random.uniform(4, sequence_size + 1)))
+        data = jnp.asarray(data_np)
+        label = jnp.asarray(label_np)
+
+        preds, _state = network_fun(network_params, data, state)
+        running_accuracy += accuracy_fun(preds, label)
+        running_count += 1
+    print(f'Validation accuracy: {running_accuracy / running_count:.03f}')
+
+    test_data = [
+        [[1.], [0.], [0.], [1.], [0.], [0.]],
+        [[1.], [0.], [0.], [0.], [1.], [0.], [0.], [0.]],
+        [[1.], [0.], [0.], [0.], [0.], [1.], [0.], [0.], [0.]],
+        [[1.], [0.], [0.], [0.], [0.], [1.], [0.], [0.], [0.], [0.]],
+        [[0.], [1.], [0.], [0.], [0.], [0.], [0.], [1.], [0.], [0.], [0.], [0.]],
+        [[1.], [0.], [0.], [0.], [0.], [0.], [1.], [0.], [0.], [0.], [0.], [0.]],
+        [[0.], [1.], [0.], [0.], [0.], [0.], [0.], [1.], [0.], [0.], [0.], [0.], [0.]],
+        [[0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.]],
+        [[1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.]],
+        [[1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.],
+         [0.], [1.], [0.], [0.], [1.], [0.], [0.]],
+        [[0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.],
+         [0.], [1.], [0.], [0.], [1.], [0.]],
+    ]
+    test_label = np.asarray([1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0], dtype=np.int32)
+    running_accuracy = 0.0
+    running_count = 0
+    state = [(jnp.zeros((1, 16)), jnp.zeros((1, 16)))] * 3
+    for data, label in zip(test_data, test_label):
+        outputs, _states = network_fun(
+            network_params,
+            jnp.asarray(np.expand_dims(np.asarray(data, dtype=np.float32), 1)),
+            state)
+        # Score the fresh outputs (the original reused the stale `preds` from
+        # the validation loop and called torch methods on a jax array).
+        outputs_np = np.asarray(outputs)
+        running_accuracy += accuracy_fun(outputs, jnp.asarray([label]))
+        running_count += 1
+        print(f'{len(data)} {np.asarray(data)[:, 0]}, label: {label}'
+              f', output: {int(outputs_np[0, 1] > outputs_np[0, 0])}')
+    print(f'Test accuracy: {running_accuracy / running_count:.03f}')
+
+
+if __name__ == '__main__':
+    main()
diff --git a/recorder.py b/recorder.py
new file mode 100644
index 0000000..2d9f573
--- /dev/null
+++ b/recorder.py
@@ -0,0 +1,155 @@
+from argparse import ArgumentParser
+from pathlib import Path
+import shutil
+import time
+
+import numpy as np
+import torch
+from torch import nn
+from torch.utils.tensorboard import SummaryWriter
+
+from src.torch_networks import LSTMModel, StackedLSTMModel
+from src.torch_utils.train import parameter_summary
+
+
+def generate_data(batch_size: int, sequence_length: int, dim: int) -> np.ndarray:
+    return np.random.uniform(0, dim, (batch_size, sequence_length)).astype(np.int64)
+
+
+def main():
+    parser = ArgumentParser()
+    parser.add_argument('--output', type=Path, default=Path('output', 'recorder'), help='Output dir')
+    parser.add_argument('--batch', type=int, default=32, help='Batch size')
+    parser.add_argument('--sequence', type=int, default=8, help='Max sequence length')
+    parser.add_argument('--dimension', type=int, default=15, help='Input dimension')
+    parser.add_argument('--step', type=int, default=20_000, help='Number of steps to train')
+    parser.add_argument('--model', help='Model to train')
+    arguments = parser.parse_args()
+
+    output_dir: Path = arguments.output
+    batch_size: int = arguments.batch
+    sequence_size: int = arguments.sequence
+    input_dim: int = arguments.dimension
+    max_step: int = arguments.step
+    model: str = arguments.model
+
+    if not output_dir.exists():
+        output_dir.mkdir(parents=True)
+    if (output_dir / 'train').exists():
+        shutil.rmtree(output_dir / 'train')
+    writer_train = SummaryWriter(log_dir=str(output_dir / 'train'), flush_secs=20)
+
+    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+    if model == 'cell':
+        network = StackedLSTMModel(input_dim + 1).to(device)
+    else:
+        network = LSTMModel(input_dim + 1).to(device)
+
+    # Save parameters info
+    with open(output_dir / 'parameters.csv', 'w') as param_file:
+        param_summary = parameter_summary(network)
+        names = [len(name) for name, _, _ in param_summary]
+        shapes = [len(str(shape)) for _, shape, _ in param_summary]
+        param_file.write(
+            '\n'.join(
+                [f'{name: <{max(names)}} {str(shape): <{max(shapes)}} {size}'
+                 for name, shape, size in param_summary]))
+
+    optimizer = torch.optim.Adam(network.parameters(), lr=1e-3)
+    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.996)
+    criterion = nn.CrossEntropyLoss()
+
+    zero_data = torch.zeros((batch_size, sequence_size, input_dim + 1)).to(device)
+    # writer_train.add_graph(network, (zero_data[:, :5],))
+
+    if model == 'cell':
+        state = [(torch.zeros((batch_size, (input_dim + 1) * 2)).to(device),
+                  torch.zeros((batch_size, (input_dim + 1) * 2)).to(device))] * network.NUM_LAYERS
+    else:
+        state = None
+    running_loss = 0.0
+    running_accuracy = 0.0
+    running_count = 0
+    summary_period = max_step // 100
+    np.set_printoptions(precision=2)
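+    # Editor's sketch of one training example (hypothetical values): for a
+    # label sequence [2, 0] with input_dim = 3, the network sees the one-hot
+    # rows with a marker channel (last column) set to 1, followed by an equal
+    # number of all-zero rows during which it must replay the sequence:
+    #     [0, 0, 1, 1]   # "2", marker on
+    #     [1, 0, 0, 1]   # "0", marker on
+    #     [0, 0, 0, 0]   # recall step 1 -> should output 2
+    #     [0, 0, 0, 0]   # recall step 2 -> should output 0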
+    try:
+        start_time = time.time()
+        for step in range(1, max_step + 1):
+            label_np = generate_data(batch_size, int(np.random.uniform(4, sequence_size + 1)), input_dim)
+            label = torch.from_numpy(label_np).to(device)
+            data = nn.functional.one_hot(label, input_dim + 1).float()
+            data[:, :, input_dim] = 1.0
+
+            optimizer.zero_grad()
+
+            outputs, _state = network(torch.cat([data, zero_data[:, :label_np.shape[1]]], dim=1).transpose(1, 0), state)
+            outputs = outputs[label_np.shape[1]:, :, :-1]
+            # state = (state[0].detach(), state[1].detach())
+            # state = (state[0][:, :1].detach().expand(state[0].shape[0], batch_size, *state[0].shape[2:]),
+            #          state[1][:, :1].detach().expand(state[1].shape[0], batch_size, *state[1].shape[2:]))
+
+            data[:, :, input_dim] = 0.0
+            loss = criterion(outputs[:, 0], label[0])
+            for i in range(1, batch_size):
+                loss += criterion(outputs[:, i], label[i])
+            loss /= batch_size
+            running_loss += loss.item()
+            running_accuracy += (
+                torch.sum((torch.argmax(outputs, 2).transpose(1, 0) == label).long(),
+                          1) == label.size(1)).float().mean().item()
+            running_count += 1
+            if step % summary_period == 0:
+                writer_train.add_scalar('metric/loss', running_loss / running_count, global_step=step)
+                writer_train.add_scalar('metric/error', 1 - (running_accuracy / running_count), global_step=step)
+                writer_train.add_scalar('optimizer/lr', scheduler.get_last_lr()[0], global_step=step)
+                scheduler.step()
+
+                speed = summary_period / (time.time() - start_time)
+                print(f'Step {step}, loss: {running_loss / running_count:.03e}'
+                      f', acc: {running_accuracy / running_count:.03e}, speed: {speed:0.3f}step/s')
+                start_time = time.time()
+                running_loss = 0.0
+                running_accuracy = 0.0
+                running_count = 0
+            loss.backward()
+            optimizer.step()
+    except KeyboardInterrupt:
+        print('\r                                        ', end='\r')
+    writer_train.close()
+
+    network.eval()
+    test_label = [
+        np.asarray([[0, 0, 0, 0]], dtype=np.int64),
+        np.asarray([[2, 2, 2, 2]], dtype=np.int64),
+        np.asarray([[8, 1, 10, 5, 6, 13]], dtype=np.int64),
+        generate_data(1, 4, input_dim),
+        np.asarray([[0, 0, 0, 0, 0, 0]], dtype=np.int64),
+        np.asarray([[5, 5, 5, 5, 5, 5]], dtype=np.int64),
+        np.asarray([[11, 0, 2, 8, 5, 1]], dtype=np.int64),
+        generate_data(1, sequence_size, input_dim)
+    ]
+    zero_data = torch.zeros((1, sequence_size, input_dim + 1)).to(device)
+    if model == 'cell':
+        state = [(torch.zeros((1, (input_dim + 1) * 2)).to(device),
+                  torch.zeros((1, (input_dim + 1) * 2)).to(device))] * 3
+    running_accuracy = 0.0
+    running_count = 0
+    for label_np in test_label:
+        label = torch.from_numpy(label_np).to(device)
+        data = nn.functional.one_hot(label, input_dim + 1).float()
+        data[:, :, input_dim] = 1.0
+
+        outputs, _state = network(torch.cat([data, zero_data[:, :label_np.shape[1]]], dim=1).transpose(1, 0), state)
+        outputs = outputs[label_np.shape[1]:, :, :-1]
+        # state = (state[0].detach(), state[1].detach())
+
+        running_accuracy += (
+            torch.sum(
+                (torch.argmax(outputs, 2).transpose(1, 0) == label).long(), 1) == label.size(1)).float().mean().item()
+        running_count += 1
+        print(f'{len(label_np)} label: {label_np}, output: {torch.argmax(outputs, 2)[:, 0].detach().cpu().numpy()}')
+    print(f'Test accuracy: {running_accuracy / running_count:.03f}')
+
+
+if __name__ == '__main__':
+    main()
diff --git a/src/jax_networks.py b/src/jax_networks.py
new file mode 100644
index 0000000..57afce0
--- /dev/null
+++ b/src/jax_networks.py
@@ -0,0 +1,104 @@
+import jax
+import jax.numpy as jnp
+from jax import lax
+from jax import nn
+
+
+@jax.jit
+def sigmoid(x):
+    return 1 / (1 + lax.exp(-x))
+
+
+def LSTMCell(hidden_size, w_init=nn.initializers.glorot_normal(), b_init=nn.initializers.normal()):
+    def init_fun(rng, input_size):
+        k1, k2, k3 = jax.random.split(rng, 3)
+        weight_ih = w_init(k1, (input_size, 4 * hidden_size))
+        weight_hh = w_init(k2, (hidden_size, 4 * hidden_size))
+        bias = b_init(k3, (4 * hidden_size,))
+        return hidden_size, (weight_ih, weight_hh, bias)
+
+    @jax.jit
+    def apply_fun(params, inputs, states):
+        (hx, cx) = states
+        weight_ih, weight_hh, bias = params
+        gates = lax.dot(inputs, weight_ih) + lax.dot(hx, weight_hh) + bias
+
+        # Slice the four gates by hidden_size (the original hard-coded 16,
+        # which silently breaks for any other hidden size).
+        in_gate = sigmoid(gates[:, :hidden_size])
+        forget_gate = sigmoid(gates[:, hidden_size:2 * hidden_size])
+        cell_gate = lax.tanh(gates[:, 2 * hidden_size:3 * hidden_size])
+        out_gate = sigmoid(gates[:, 3 * hidden_size:])
+
+        cy = lax.mul(forget_gate, cx) + lax.mul(in_gate, cell_gate)
+        hy = lax.mul(out_gate, lax.tanh(cy))
+
+        return (hy, cy)
+
+    return init_fun, apply_fun
+
+
+def LSTMLayer(hidden_size, w_init=nn.initializers.glorot_normal(), b_init=nn.initializers.normal()):
+    cell_init, cell_fun = LSTMCell(hidden_size, w_init=w_init, b_init=b_init)
+
+    @jax.jit
+    def apply_fun(params, inputs, states):
+        output = []
+        for input_data in inputs:
+            states = cell_fun(params, input_data, states)
+            output.append(states[0])
+        return jnp.stack(output), states
+
+    return cell_init, apply_fun
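+# Editor's note: under jit the Python loop above is unrolled at trace time,
+# once per timestep. A lax.scan formulation (sketch only, not wired in here)
+# keeps the trace small for long sequences:
+#     def step(carry, x):
+#         carry = cell_fun(params, x, carry)
+#         return carry, carry[0]
+#     states, output = lax.scan(step, states, inputs)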
+
+
+def StackedLSTM(hidden_size, num_layer, w_init=nn.initializers.glorot_normal(), b_init=nn.initializers.normal()):
+    layer_inits = []
+    layer_funs = []
+    for _ in range(num_layer):
+        layer_init, layer_fun = LSTMLayer(hidden_size, w_init=w_init, b_init=b_init)
+        layer_inits.append(layer_init)
+        layer_funs.append(layer_fun)
+
+    def init_fun(rng, input_size):
+        params = []
+        output_size = input_size
+        for init in layer_inits:
+            # Give each layer its own key (the original reused the same rng,
+            # so all layers started from identical weights).
+            rng, layer_rng = jax.random.split(rng)
+            output_size, layer_params = init(layer_rng, output_size)
+            params.append(layer_params)
+        return output_size, tuple(params)
+
+    @jax.jit
+    def apply_fun(params, inputs, states):
+        output = inputs
+        output_states = []
+        for layer_id in range(num_layer):
+            output, out_state = layer_funs[layer_id](params[layer_id], output, states[layer_id])
+            output_states.append(out_state)
+        # `output` is already a stacked (seq, batch, hidden) array.
+        return output, output_states
+
+    return init_fun, apply_fun
+
+
+def StackedLSTMModel(input_size, hidden_size=-1, output_size=-1, num_layer=3,
+                     w_init=nn.initializers.glorot_normal(), b_init=nn.initializers.normal()):
+    hidden_size = hidden_size if hidden_size > 0 else input_size * 2
+    output_size = output_size if output_size > 0 else input_size
+    stacked_init, stacked_fun = StackedLSTM(hidden_size, num_layer, w_init=w_init, b_init=b_init)
+
+    def init_fun(rng, input_size):
+        # Separate keys so the dense head does not share its key with the stack.
+        k_stacked, k1, k2 = jax.random.split(rng, 3)
+        stacked_output_size, stacked_params = stacked_init(k_stacked, input_size)
+
+        weight = w_init(k1, (stacked_output_size, output_size))
+        bias = b_init(k2, (output_size,))
+
+        return output_size, ((weight, bias), stacked_params)
+
+    @jax.jit
+    def apply_fun(params, inputs, states):
+        (weight, bias), stacked_params = params
+        output, new_states = stacked_fun(stacked_params, inputs, states)
+        return lax.dot(output[-1], weight) + bias, new_states
+
+    return init_fun, apply_fun
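+
+# Usage sketch (editor's illustration, mirroring modulo_jax.py):
+#     init_fun, apply_fun = StackedLSTMModel(1, 16, 2)
+#     _, params = init_fun(jax.random.PRNGKey(0), 1)
+#     states = [(jnp.zeros((batch, 16)), jnp.zeros((batch, 16)))] * 3
+#     logits, states = apply_fun(params, inputs, states)  # inputs: (seq, batch, 1)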
diff --git a/src/torch_networks.py b/src/torch_networks.py
new file mode 100644
index 0000000..4bf3c6d
--- /dev/null
+++ b/src/torch_networks.py
@@ -0,0 +1,136 @@
+import torch
+from torch import nn
+
+
+class LSTMModel(nn.Module):
+    NUM_LAYERS = 3
+
+    def __init__(self, input_size: int, hidden_size: int = -1, output_size: int = -1):
+        super().__init__()
+        hidden_size = hidden_size if hidden_size > 0 else input_size * 2
+        output_size = output_size if output_size > 0 else input_size
+
+        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=self.NUM_LAYERS)
+        self.dense = nn.Linear(hidden_size, output_size)
+
+    def forward(self, input_data: torch.Tensor, init_state=None) -> tuple[
+            torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+        output, state = self.lstm(input_data, init_state)
+        return self.dense(output), state
+
+
+# Hand-written cell kept for reference; the models below use nn.LSTMCell.
+class LSTMCell(torch.jit.ScriptModule):
+    def __init__(self, input_size, hidden_size):
+        super().__init__()
+        self.input_size = input_size
+        self.hidden_size = hidden_size
+        self.weight_ih = nn.Parameter(torch.randn(input_size, 4 * hidden_size))
+        self.weight_hh = nn.Parameter(torch.randn(hidden_size, 4 * hidden_size))
+        self.bias = nn.Parameter(torch.randn(4 * hidden_size))
+
+    @torch.jit.script_method
+    def forward(self, input_data: torch.Tensor, state: tuple[torch.Tensor, torch.Tensor]) -> tuple[
+            torch.Tensor, torch.Tensor]:
+        hx, cx = state
+        gates = (
+            (torch.mm(input_data, self.weight_ih) + torch.mm(hx, self.weight_hh) + self.bias))
+        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
+
+        ingate = torch.sigmoid(ingate)
+        forgetgate = torch.sigmoid(forgetgate)
+        cellgate = torch.tanh(cellgate)
+        outgate = torch.sigmoid(outgate)
+
+        cy = (forgetgate * cx) + (ingate * cellgate)
+        hy = outgate * torch.tanh(cy)
+
+        return (hy, cy)
+
+
+class LSTMLayer(torch.jit.ScriptModule):
+    def __init__(self, input_size, hidden_size):
+        super().__init__()
+        self.cell = nn.LSTMCell(input_size=input_size, hidden_size=hidden_size)
+
+    @torch.jit.script_method
+    def forward(self, input_data: torch.Tensor, state: tuple[torch.Tensor, torch.Tensor]) -> tuple[
+            torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+        inputs = input_data.unbind(0)
+        outputs = torch.jit.annotate(list[torch.Tensor], [])
+        for i in range(len(inputs)):
+            state = self.cell(inputs[i], state)
+            outputs += [state[0]]
+        return torch.stack(outputs), state
+
+
+class StackedLSTM(torch.jit.ScriptModule):
+    def __init__(self, input_size, hidden_size, num_layers):
+        super().__init__()
+        self.layers = nn.ModuleList(
+            [LSTMLayer(input_size=input_size, hidden_size=hidden_size)] + [
+                LSTMLayer(input_size=hidden_size, hidden_size=hidden_size) for _ in range(num_layers - 1)])
+
+    @torch.jit.script_method
+    def forward(self, input_data: torch.Tensor, states: list[tuple[torch.Tensor, torch.Tensor]]) -> tuple[
+            torch.Tensor, list[tuple[torch.Tensor, torch.Tensor]]]:
+        output_states = torch.jit.annotate(list[tuple[torch.Tensor, torch.Tensor]], [])
+        output = input_data
+        i = 0
+        for rnn_layer in self.layers:
+            state = states[i]
+            output, out_state = rnn_layer(output, state)
+            output_states += [out_state]
+            i += 1
+        return output, output_states
+
+
+class StackedLSTMModel(torch.jit.ScriptModule):
+    NUM_LAYERS = 3
+
+    def __init__(self, input_size: int, hidden_size: int = -1, output_size: int = -1):
+        super().__init__()
+        hidden_size = hidden_size if hidden_size > 0 else input_size * 2
+        output_size = output_size if output_size > 0 else input_size
+
+        self.lstm = StackedLSTM(input_size=input_size, hidden_size=hidden_size, num_layers=self.NUM_LAYERS)
+        self.dense = nn.Linear(hidden_size, output_size)
+
+    @torch.jit.script_method
+    def forward(self, input_data: torch.Tensor, init_state: list[tuple[torch.Tensor, torch.Tensor]]) -> tuple[
+            torch.Tensor, list[tuple[torch.Tensor, torch.Tensor]]]:
+        output, state = self.lstm(input_data, init_state)
+        return self.dense(output), state
+
+
+class LSTMCellModel(torch.jit.ScriptModule):
+    NUM_LAYERS = 3
+
+    def __init__(self, input_size: int, hidden_size: int = -1, output_size: int = -1):
+        super().__init__()
+        hidden_size = hidden_size if hidden_size > 0 else input_size * 2
+        output_size = output_size if output_size > 0 else input_size
+
+        self.hidden_size = hidden_size
+        self.layers = nn.ModuleList([
+            nn.LSTMCell(input_size=input_size, hidden_size=hidden_size),
+            nn.LSTMCell(input_size=hidden_size, hidden_size=hidden_size),
+            nn.LSTMCell(input_size=hidden_size, hidden_size=hidden_size)
+        ])
+        self.dense = nn.Linear(hidden_size, output_size)
+
+    @torch.jit.script_method
+    def forward(self, input_data: torch.Tensor, init_states: list[tuple[torch.Tensor, torch.Tensor]]) -> tuple[
+            torch.Tensor, list[tuple[torch.Tensor, torch.Tensor]]]:
+        output_states = torch.jit.annotate(list[tuple[torch.Tensor, torch.Tensor]], [])
+        cell_inputs = input_data.unbind(0)
+        cell_id = 0
+        for cell in self.layers:
+            cell_state = init_states[cell_id]
+            cell_outputs = torch.jit.annotate(list[torch.Tensor], [])
+            for i in range(len(cell_inputs)):
+                cell_state = cell(cell_inputs[i], cell_state)
+                cell_outputs += [cell_state[0]]
+            cell_inputs = cell_outputs
+            output_states += [cell_state]
+            cell_id += 1
+        return self.dense(torch.stack(cell_outputs)), output_states
diff --git a/src/torch_utils b/src/torch_utils
new file mode 160000
index 0000000..1bac462
--- /dev/null
+++ b/src/torch_utils
@@ -0,0 +1 @@
+Subproject commit 1bac46219b42fe41ba3568fdde3ca364b02e46e9