"""Train a stacked LSTM (JAX) on a synthetic periodic-pulse ("modulo") task.

Each sample is a binary sequence containing a 1 every ``modulo`` steps, starting at
index ``start``.  The label is 1 when ``sequence_length % modulo == start``, i.e. when
the element that would follow the sequence is itself a pulse.
"""
from argparse import ArgumentParser
from pathlib import Path
import math
import shutil
import time
import sys

import jax
from jax import lax
import jax.numpy as jnp
from jax.experimental import optimizers
import numpy as np
from torch.utils.tensorboard import SummaryWriter

from src.jax_networks import StackedLSTMModel
from src.torch_utils.utils.batch_generator import BatchGenerator


def generate_data(batch_size: int, data_length: int) -> tuple[np.ndarray, np.ndarray]:
    """Generate one time-major batch of periodic pulse sequences and their labels.

    Args:
        batch_size: Number of sequences in the batch.
        data_length: Length of every sequence in the batch.

    Returns:
        A ``(data, label)`` pair where ``data`` is a float32 array of shape
        ``(data_length, batch_size, 1)`` and ``label`` is an int64 array of shape
        ``(batch_size,)`` holding 1 when ``data_length % modulo == start`` for that sequence.
    """
    modulos = np.random.uniform(3, data_length // 2 + 1, batch_size).astype(np.int32)
    # Starts are drawn one at a time (after all the periods) so that each one is
    # bounded by its own period.
    starts = [int(np.random.uniform(0, mod)) for mod in modulos]

    data = np.zeros((data_length, batch_size, 1), dtype=np.float32)
    for i, (start, mod) in enumerate(zip(starts, modulos)):
        # A pulse every `mod` steps, beginning at `start`.
        data[start::mod, i, 0] = 1.0

    labels = [int(data_length % mod == start) for start, mod in zip(starts, modulos)]
    return data, np.asarray(labels, dtype=np.int64)


class DataGenerator:
    """Per-sample generator used as the ``pipeline`` callback of ``BatchGenerator``."""

    # Length of the zero-padded buffer returned by `pipeline`; set by `main` before use.
    MAX_LENGTH = 1
    # Whether NumPy's RNG has already been re-seeded in the current process.
    INITIALIZED = False

    @staticmethod
    def pipeline(sequence_length, _dummy_label):
        """Build one pulse sequence of ``sequence_length`` steps padded to ``MAX_LENGTH``.

        Args:
            sequence_length: Number of meaningful steps; must not exceed ``MAX_LENGTH``.
            _dummy_label: Ignored; the label is computed from the generated sequence.

        Returns:
            A ``(data, label)`` pair: a float32 array of shape ``(MAX_LENGTH, 1)`` and a
            0-d int64 array equal to 1 when ``sequence_length % modulo == start``.
        """
        if not DataGenerator.INITIALIZED:
            # Seed from the clock on first use -- presumably so that each BatchGenerator
            # worker process gets its own random stream (TODO confirm worker model).
            np.random.seed(time.time_ns() % (2**32))
            DataGenerator.INITIALIZED = True

        data = np.zeros((DataGenerator.MAX_LENGTH, 1), dtype=np.float32)
        modulo = int(np.random.uniform(3, sequence_length // 2 + 1))
        start = int(np.random.uniform(0, modulo))
        for i in range(start, sequence_length, modulo):
            data[i, 0] = 1.0
        return data, np.asarray(1 if sequence_length % modulo == start else 0, dtype=np.int64)


def main():
    """Parse CLI arguments, train the network, then report validation and test accuracy."""
    parser = ArgumentParser()
    parser.add_argument('--output', type=Path, default=Path('output', 'modulo'), help='Output dir')
    parser.add_argument('--batch', type=int, default=32, help='Batch size')
    parser.add_argument('--sequence', type=int, default=12, help='Max sequence length')
    parser.add_argument('--step', type=int, default=2000, help='Number of steps to train')
    parser.add_argument('--model', help='Model to train')
    arguments = parser.parse_args()

    output_dir: Path = arguments.output
    batch_size: int = arguments.batch
    sequence_size: int = arguments.sequence
    max_step: int = arguments.step
    model: str = arguments.model

    # Validate the model name before touching the filesystem so that a bad argument
    # does not wipe the logs of a previous run.
    if model != 'stack':
        print('Model not implemented')
        sys.exit(1)

    output_dir.mkdir(parents=True, exist_ok=True)
    train_dir = output_dir / 'train'
    if train_dir.exists():
        shutil.rmtree(train_dir)
    writer_train = SummaryWriter(log_dir=str(train_dir), flush_secs=20)

    rng = jax.random.PRNGKey(0)
    network_init, network_fun = StackedLSTMModel(1, 16, 2)
    _, network_params = network_init(rng, 1)

    def loss_fun(preds, targets):
        """Mean cross-entropy between two-class scores and integer targets.

        For two classes this is the binary cross-entropy
        ``-t * log(p) - (1 - t) * log(1 - p)`` with ``p = softmax(preds)[:, 1]``.
        NOTE(review): assumes `preds` has shape (batch, 2) and holds unnormalised scores;
        if the model already outputs probabilities, use ``jnp.log(preds)`` instead.
        """
        # Numerically stable log-softmax; the shift is a constant w.r.t. the gradient.
        shifted = preds - lax.stop_gradient(jnp.max(preds, axis=-1, keepdims=True))
        log_probs = shifted - jnp.log(jnp.sum(jnp.exp(shifted), axis=-1, keepdims=True))
        picked = jnp.take_along_axis(log_probs, targets[:, None], axis=-1)
        return -jnp.mean(picked)

    def accuracy_fun(preds, targets):
        """Fraction of samples whose higher-scoring class matches the target."""
        return jnp.mean((preds[:, 1] > preds[:, 0]).astype(jnp.int32) == targets)

    opt_init, opt_update, get_params = optimizers.adam(1e-3)

    def forward_loss(params, data, targets, lstm_state):
        """Run the network and return ``(loss, preds)`` so the loss is a function of params."""
        preds, _state = network_fun(params, data, lstm_state)
        return loss_fun(preds, targets), preds

    # Re-traced once per distinct sequence length (the leading dimension of `data`).
    @jax.jit
    def update_fun(step, opt_state, data, targets, lstm_state):
        """Apply one Adam step; return the new optimizer state, the loss and the predictions."""
        params = get_params(opt_state)
        (loss, preds), grads = jax.value_and_grad(forward_loss, has_aux=True)(
            params, data, targets, lstm_state)
        return opt_update(step, grads, opt_state), loss, preds

    opt_state = opt_init(network_params)

    # One sequence length per training step; the first step uses the maximum length.
    sequence_data = np.random.uniform(4, sequence_size + 1, max_step).astype(np.int32)
    sequence_data[0] = sequence_size
    # Every sample of a given step shares that step's sequence length.
    sequence_data_reshaped = np.repeat(sequence_data, batch_size)
    dummy_label = np.zeros((batch_size * max_step), dtype=np.uint8)
    DataGenerator.MAX_LENGTH = sequence_size

    state = [(jnp.zeros((batch_size, 16)), jnp.zeros((batch_size, 16)))] * 3

    with BatchGenerator(sequence_data_reshaped, dummy_label, batch_size=batch_size,
                        pipeline=DataGenerator.pipeline, num_workers=8, shuffle=False) as batch_generator:
        data_np = batch_generator.batch_data
        label_np = batch_generator.batch_label

        running_loss = 0.0
        running_accuracy = 0.0
        running_count = 0
        summary_period = max(max_step // 100, 1)
        np.set_printoptions(precision=2)
        try:
            start_time = time.time()
            while batch_generator.epoch == 0:
                # Switch to time-major layout and drop the padding beyond this step's length
                # (assumes batch_data is (batch, MAX_LENGTH, 1) -- TODO confirm).
                data = jnp.asarray(data_np.transpose((1, 0, 2))[:sequence_data[batch_generator.step]])
                label = jnp.asarray(label_np)

                # The predictions returned here were computed with the pre-update parameters.
                opt_state, loss, preds = update_fun(batch_generator.global_step, opt_state, data, label, state)

                running_loss += float(loss)
                running_accuracy += float(accuracy_fun(preds, label))
                running_count += 1
                if (batch_generator.step + 1) % summary_period == 0:
                    writer_train.add_scalar('metric/loss', running_loss / running_count,
                                            global_step=batch_generator.step)
                    writer_train.add_scalar('metric/error', 1 - (running_accuracy / running_count),
                                            global_step=batch_generator.step)

                    speed = summary_period / (time.time() - start_time)
                    print(f'Step {batch_generator.step}, loss: {running_loss / running_count:.03e}'
                          f', acc: {running_accuracy / running_count:.03e}, speed: {speed:0.3f}step/s')
                    start_time = time.time()
                    running_loss = 0.0
                    running_accuracy = 0.0
                    running_count = 0
                data_np, label_np = batch_generator.next_batch()
        except KeyboardInterrupt:
            # Training can be interrupted; evaluation below still runs on the latest parameters.
            print('\r ', end='\r')
    writer_train.close()

    # Evaluate with the trained parameters rather than the initial ones.
    network_params = get_params(opt_state)

    running_accuracy = 0.0
    running_count = 0
    for _ in range(math.ceil(1000 / batch_size)):
        data_np, label_np = generate_data(batch_size, int(np.random.uniform(4, sequence_size + 1)))
        data = jnp.asarray(data_np)
        label = jnp.asarray(label_np)

        preds, _state = network_fun(network_params, data, state)
        running_accuracy += float(accuracy_fun(preds, label))
        running_count += 1
    print(f'Validation accuracy: {running_accuracy / running_count:.03f}')

    test_data = [
        [[1.], [0.], [0.], [1.], [0.], [0.]],
        [[1.], [0.], [0.], [0.], [1.], [0.], [0.], [0.]],
        [[1.], [0.], [0.], [0.], [0.], [1.], [0.], [0.], [0.]],
        [[1.], [0.], [0.], [0.], [0.], [1.], [0.], [0.], [0.], [0.]],
        [[0.], [1.], [0.], [0.], [0.], [0.], [0.], [1.], [0.], [0.], [0.], [0.]],
        [[1.], [0.], [0.], [0.], [0.], [0.], [1.], [0.], [0.], [0.], [0.], [0.]],
        [[0.], [1.], [0.], [0.], [0.], [0.], [0.], [1.], [0.], [0.], [0.], [0.], [0.]],
        [[0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.]],
        [[1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.]],
        [[1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.],
         [1.], [0.], [0.], [1.], [0.], [0.]],
        [[0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.],
         [0.], [1.], [0.], [0.], [1.], [0.]],
    ]
    test_label = np.asarray([1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0], dtype=np.int32)

    running_accuracy = 0.0
    running_count = 0
    # Test sequences are fed one at a time, so the recurrent state has a batch size of 1.
    state = [(jnp.zeros((1, 16)), jnp.zeros((1, 16)))] * 3
    for data, label in zip(test_data, test_label):
        # (length, 1) -> (length, 1, 1): time-major with a singleton batch axis.
        network_input = jnp.asarray(np.expand_dims(np.asarray(data, dtype=np.float32), 1))
        outputs, _states = network_fun(network_params, network_input, state)
        running_accuracy += float(accuracy_fun(outputs, label))
        running_count += 1
        # Class-1 score of the last output row; flattening the leading axes keeps this valid
        # whether the model returns (batch, classes) or (batch, time, classes) -- TODO confirm.
        score = float(np.asarray(outputs).reshape(-1, outputs.shape[-1])[-1, 1])
        print(f'{len(data)} {np.asarray(data)[:, 0]}, label: {label}'
              f', output: {score:.02f}')
    print(f'Test accuracy: {running_accuracy / running_count:.03f}')


if __name__ == '__main__':
    main()