Initial commit
commit e41417c029
8 changed files with 802 additions and 0 deletions
modulo_jax.py (new file, 191 lines)
@@ -0,0 +1,191 @@
from argparse import ArgumentParser
from pathlib import Path
import math
import shutil
import time
import sys

import jax
import jax.numpy as jnp
from jax.experimental import optimizers
import numpy as np
from torch.utils.tensorboard import SummaryWriter

from src.jax_networks import StackedLSTMModel
from src.torch_utils.utils.batch_generator import BatchGenerator

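# Builds one batch for the modulo task: sequence i holds a 1 every
# modulos[i] steps starting at starts[i]; the label is 1 when the step just
# past the end of the sequence would also be a 1.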
def generate_data(batch_size: int, data_length: int) -> tuple[np.ndarray, np.ndarray]:
    modulos = np.random.uniform(3, data_length // 2 + 1, batch_size).astype(np.int32)
    data = np.zeros((data_length, batch_size, 1), dtype=np.float32)
    starts = []
    for mod in modulos:
        starts.append(int(np.random.uniform(0, mod)))
    for i in range(batch_size):
        # np.where(data[i] % modulos[i] == starts[i], [1.0], data[i])
        for j in range(starts[i], data_length, modulos[i]):
            data[j, i, 0] = 1.0
    label = []
    for i in range(batch_size):
        label.append(1 if data_length % modulos[i] == starts[i] else 0)
    return data, np.asarray(label, dtype=np.int64)

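# Same task as generate_data, but built sample-by-sample so it can run as a
# BatchGenerator pipeline in worker processes; the one-time reseed presumably
# keeps forked workers from sharing the parent's RNG state.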
class DataGenerator:
    MAX_LENGTH = 1
    INITIALIZED = False

    @staticmethod
    def pipeline(sequence_length, _dummy_label):
        if not DataGenerator.INITIALIZED:
            np.random.seed(time.time_ns() % (2**32))
            DataGenerator.INITIALIZED = True
        data = np.zeros((DataGenerator.MAX_LENGTH, 1), dtype=np.float32)
        modulo = int(np.random.uniform(3, sequence_length // 2 + 1))
        start = int(np.random.uniform(0, modulo))
        for i in range(start, sequence_length, modulo):
            data[i, 0] = 1.0
        return data, np.asarray(1 if sequence_length % modulo == start else 0, dtype=np.int64)

def main():
    parser = ArgumentParser()
    parser.add_argument('--output', type=Path, default=Path('output', 'modulo'), help='Output dir')
    parser.add_argument('--batch', type=int, default=32, help='Batch size')
    parser.add_argument('--sequence', type=int, default=12, help='Max sequence length')
    parser.add_argument('--step', type=int, default=2000, help='Number of steps to train')
    parser.add_argument('--model', help='Model to train')
    arguments = parser.parse_args()

    output_dir: Path = arguments.output
    batch_size: int = arguments.batch
    sequence_size: int = arguments.sequence
    max_step: int = arguments.step
    model: str = arguments.model

    if not output_dir.exists():
        output_dir.mkdir(parents=True)
    if (output_dir / 'train').exists():
        shutil.rmtree(output_dir / 'train')
    writer_train = SummaryWriter(log_dir=str(output_dir / 'train'), flush_secs=20)

    rng = jax.random.PRNGKey(0)
    if model == 'stack':
        network_init, network_fun = StackedLSTMModel(1, 16, 2)
        _, network_params = network_init(rng, 1)
    else:
        print('Model not implemented')
        sys.exit(1)

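    # Binary cross-entropy on the class-1 probability (preds[:, 1] is taken
    # as p(class 1), matching accuracy_fun below), reduced with jnp.mean so
    # jax.grad gets a scalar.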
    @jax.jit
    def loss_fun(preds, targets):
        return jnp.mean(-targets * jnp.log(preds[:, 1])
                        - (1 - targets) * jnp.log(1 - preds[:, 1]))

    def accuracy_fun(preds, targets):
        return jnp.mean((preds[:, 1] > preds[:, 0]).astype(jnp.int64) == targets)

    opt_init, opt_update, opt_params = optimizers.adam(1e-3)

    @jax.jit
    def update_fun(step, opt_state, data, targets):
        # Differentiate the loss with respect to the network parameters
        # (not the predictions) and return the updated optimizer state.
        def batch_loss(params):
            preds, _ = network_fun(params, data, state)
            return loss_fun(preds, targets)
        return opt_update(step, jax.grad(batch_loss)(opt_params(opt_state)), opt_state)

    opt_state = opt_init(network_params)

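    # Draw a random sequence length for every training step up front (step 0
    # uses the full length) and repeat each length batch_size times so all
    # samples in a batch share the same length.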
    sequence_data = np.random.uniform(4, sequence_size + 1, max_step).astype(np.int32)
    sequence_data[0] = sequence_size
    sequence_data_reshaped = np.reshape(np.broadcast_to(
        sequence_data,
        (batch_size, max_step)).transpose((1, 0)), (batch_size * max_step))
    dummy_label = np.zeros((batch_size * max_step), dtype=np.uint8)
    DataGenerator.MAX_LENGTH = sequence_size
    # One (hidden, cell) zero-state pair per stacked LSTM layer.
    state = [(jnp.zeros((batch_size, 16)),
              jnp.zeros((batch_size, 16)))] * 3
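    # Workers build samples through DataGenerator.pipeline; shuffle=False so
    # that batch i lines up with the length stored in sequence_data[i].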
    with BatchGenerator(sequence_data_reshaped, dummy_label, batch_size=batch_size,
                        pipeline=DataGenerator.pipeline, num_workers=8,
                        shuffle=False) as batch_generator:
        data_np = batch_generator.batch_data
        label_np = batch_generator.batch_label

        running_loss = 0.0
        running_accuracy = 0.0
        running_count = 0
        summary_period = max(max_step // 100, 1)
        np.set_printoptions(precision=2)
        try:
            start_time = time.time()
            while batch_generator.epoch == 0:
                # data_np, label_np = generate_data(batch_size, int(np.random.uniform(4, sequence_size + 1)))
                data = jnp.asarray(data_np.transpose((1, 0, 2))[:sequence_data[batch_generator.step]])
                label = jnp.asarray(label_np)

                network_params = opt_params(opt_state)
                preds, _state = network_fun(network_params, data, state)
                loss = loss_fun(preds, label)
                opt_state = update_fun(batch_generator.global_step, opt_state, data, label)
                running_loss += loss
                running_accuracy += accuracy_fun(preds, label)
                running_count += 1
                if (batch_generator.step + 1) % summary_period == 0:
                    writer_train.add_scalar('metric/loss', running_loss / running_count,
                                            global_step=batch_generator.step)
                    writer_train.add_scalar('metric/error', 1 - (running_accuracy / running_count),
                                            global_step=batch_generator.step)

                    speed = summary_period / (time.time() - start_time)
                    print(f'Step {batch_generator.step}, loss: {running_loss / running_count:.03e}'
                          f', acc: {running_accuracy / running_count:.03e}, speed: {speed:0.3f}step/s')
                    start_time = time.time()
                    running_loss = 0.0
                    running_accuracy = 0.0
                    running_count = 0
                data_np, label_np = batch_generator.next_batch()
        except KeyboardInterrupt:
            # Clear the interrupted progress line.
            print('\r ', end='\r')
        writer_train.close()

    # Validate with the final parameters on freshly generated batches.
    network_params = opt_params(opt_state)
    running_accuracy = 0.0
    running_count = 0
    for _ in range(math.ceil(1000 / batch_size)):
        data_np, label_np = generate_data(batch_size, int(np.random.uniform(4, sequence_size + 1)))
        data = jnp.asarray(data_np)
        label = jnp.asarray(label_np)

        preds, _state = network_fun(network_params, data, state)
        running_accuracy += accuracy_fun(preds, label)
        running_count += 1
    print(f'Validation accuracy: {running_accuracy / running_count:.03f}')

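    # Hand-written sequences: the label is 1 when the periodic pattern would
    # put a 1 at the step immediately after the sequence ends.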
    test_data = [
        [[1.], [0.], [0.], [1.], [0.], [0.]],
        [[1.], [0.], [0.], [0.], [1.], [0.], [0.], [0.]],
        [[1.], [0.], [0.], [0.], [0.], [1.], [0.], [0.], [0.]],
        [[1.], [0.], [0.], [0.], [0.], [1.], [0.], [0.], [0.], [0.]],
        [[0.], [1.], [0.], [0.], [0.], [0.], [0.], [1.], [0.], [0.], [0.], [0.]],
        [[1.], [0.], [0.], [0.], [0.], [0.], [1.], [0.], [0.], [0.], [0.], [0.]],
        [[0.], [1.], [0.], [0.], [0.], [0.], [0.], [1.], [0.], [0.], [0.], [0.], [0.]],
        [[0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.]],
        [[1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.]],
        [[1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.],
         [0.], [1.], [0.], [0.], [1.], [0.], [0.]],
        [[0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.],
         [0.], [1.], [0.], [0.], [1.], [0.]],
    ]
    test_label = np.asarray([1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0], dtype=np.int32)
    running_accuracy = 0.0
    running_count = 0
    # Batch size is 1 for the test sequences, so fresh zero states are needed.
    state = [(jnp.zeros((1, 16)), jnp.zeros((1, 16)))] * 3
    for data, label in zip(test_data, test_label):
        outputs, _states = network_fun(
            network_params,
            jnp.asarray(np.expand_dims(np.asarray(data, dtype=np.float32), 1)),
            state)
        running_accuracy += accuracy_fun(outputs, jnp.asarray([label]))
        running_count += 1
        print(f'{len(data)} {np.asarray(data)[:, 0]}, label: {label}'
              f', output: {np.asarray(outputs)[0, 1]:.02f}')
    print(f'Test accuracy: {running_accuracy / running_count:.03f}')

if __name__ == '__main__':
    main()