time-series/modulo_jax.py
2021-08-26 21:12:40 +09:00

191 lines
8.1 KiB
Python

import math
import shutil
import sys
import time
from argparse import ArgumentParser
from pathlib import Path

import jax
import jax.numpy as jnp
import numpy as np
from jax import lax
# jax.experimental.optimizers moved to jax.example_libraries.optimizers in
# jax 0.2.25 and was later removed from jax.experimental; fall back so the
# script keeps working on newer jax versions.
try:
    from jax.experimental import optimizers
except ImportError:
    from jax.example_libraries import optimizers
from torch.utils.tensorboard import SummaryWriter

from src.jax_networks import StackedLSTMModel
from src.torch_utils.utils.batch_generator import BatchGenerator
def generate_data(batch_size: int, data_length: int) -> tuple[np.ndarray, np.ndarray]:
    """Build a batch of periodic pulse sequences and next-step labels.

    Each sequence ``i`` gets a random period ``mod`` in [3, data_length // 2]
    and a random phase ``start`` in [0, mod); every timestep congruent to
    ``start`` modulo ``mod`` is set to 1.0.  The label is 1 iff position
    ``data_length`` (the step right after the sequence ends) would also be a
    pulse, i.e. ``data_length % mod == start``.

    Args:
        batch_size: Number of independent sequences in the batch.
        data_length: Number of timesteps per sequence.

    Returns:
        data: float32 array of shape (data_length, batch_size, 1).
        label: int64 array of shape (batch_size,).
    """
    modulos = np.random.uniform(3, data_length // 2 + 1, batch_size).astype(np.int32)
    # One phase per sequence, drawn in the same order as the periods so the
    # RNG stream matches the previous implementation.
    starts = [int(np.random.uniform(0, mod)) for mod in modulos]
    data = np.zeros((data_length, batch_size, 1), dtype=np.float32)
    for i in range(batch_size):
        data[starts[i]::modulos[i], i, 0] = 1.0
    label = [1 if data_length % modulos[i] == starts[i] else 0
             for i in range(batch_size)]
    return data, np.asarray(label, dtype=np.int64)
class DataGenerator:
    """Per-worker factory for periodic pulse sequences.

    ``pipeline`` is designed to run as a BatchGenerator pipeline callback in a
    worker process.  ``MAX_LENGTH`` must be set to the maximum sequence length
    before use so every sample is zero-padded to one fixed shape.
    """

    MAX_LENGTH = 1       # padded length of every emitted sample
    INITIALIZED = False  # per-process flag: has this worker been re-seeded?

    @staticmethod
    def pipeline(sequence_length, _dummy_label):
        """Return one ``(padded_sequence, next_step_label)`` sample.

        The sequence pulses every ``period`` steps starting at ``phase``; the
        label is 1 iff position ``sequence_length`` would also be a pulse.
        """
        if not DataGenerator.INITIALIZED:
            # Re-seed once per worker so forked workers don't share a stream.
            np.random.seed(time.time_ns() % (2**32))
            DataGenerator.INITIALIZED = True
        sample = np.zeros((DataGenerator.MAX_LENGTH, 1), dtype=np.float32)
        period = int(np.random.uniform(3, sequence_length // 2 + 1))
        phase = int(np.random.uniform(0, period))
        for idx in range(phase, sequence_length, period):
            sample[idx, 0] = 1.0
        next_is_pulse = sequence_length % period == phase
        return sample, np.asarray(1 if next_is_pulse else 0, dtype=np.int64)
def main():
    """Train a stacked-LSTM (JAX) to predict whether the timestep right after
    a periodic pulse sequence is itself a pulse, then report validation and
    test accuracy.
    """
    parser = ArgumentParser()
    parser.add_argument('--output', type=Path, default=Path('output', 'modulo'), help='Output dir')
    parser.add_argument('--batch', type=int, default=32, help='Batch size')
    parser.add_argument('--sequence', type=int, default=12, help='Max sequence length')
    parser.add_argument('--step', type=int, default=2000, help='Number of steps to train')
    parser.add_argument('--model', help='Model to train')
    arguments = parser.parse_args()
    output_dir: Path = arguments.output
    batch_size: int = arguments.batch
    sequence_size: int = arguments.sequence
    max_step: int = arguments.step
    model: str = arguments.model

    output_dir.mkdir(parents=True, exist_ok=True)
    # Start each run with a clean TensorBoard log dir.
    if (output_dir / 'train').exists():
        shutil.rmtree(output_dir / 'train')
    writer_train = SummaryWriter(log_dir=str(output_dir / 'train'), flush_secs=20)

    rng = jax.random.PRNGKey(0)
    if model == 'stack':
        network_init, network_fun = StackedLSTMModel(1, 16, 2)
        _, network_params = network_init(rng, 1)
    else:
        print('Model not implemented')
        sys.exit(1)

    @jax.jit
    def loss_fun(preds, targets):
        # Element-wise binary cross-entropy.
        # NOTE(review): assumes preds are probabilities in (0, 1) and broadcast
        # against targets -- confirm against StackedLSTMModel's output shape.
        return lax.mul(-targets, lax.log(preds)) - lax.mul(1 - targets, lax.log(1 - preds))

    def accuracy_fun(preds, targets):
        # FIX: jax arrays have no .int64() method; cast with astype instead.
        return jnp.mean((preds[:, 1] > preds[:, 0]).astype(jnp.int64) == targets)

    opt_init, opt_update, opt_params = optimizers.adam(1e-3)

    def batch_loss(params, data, targets, state):
        # Scalar training loss as a function of the parameters (for jax.grad).
        preds, _state = network_fun(params, data, state)
        return jnp.mean(loss_fun(preds, targets))

    @jax.jit
    def update_fun(step, opt_state, data, targets, state):
        # FIX: the original differentiated loss_fun w.r.t. the *predictions*,
        # passed params where opt_state belongs, and its return value was
        # discarded, so the network never trained.  Differentiate w.r.t. the
        # parameters and return the new optimizer state.
        grads = jax.grad(batch_loss)(opt_params(opt_state), data, targets, state)
        return opt_update(step, grads, opt_state)

    opt_state = opt_init(network_params)

    # Pre-draw a random sequence length for every step; step 0 uses the
    # maximum so the largest shape is seen (and compiled) first.
    sequence_data = np.random.uniform(4, sequence_size + 1, max_step).astype(np.int32)
    sequence_data[0] = sequence_size
    sequence_data_reshaped = np.reshape(np.broadcast_to(
        sequence_data,
        (batch_size, max_step)).transpose((1, 0)), (batch_size * max_step))
    dummy_label = np.zeros((batch_size * max_step), dtype=np.uint8)
    DataGenerator.MAX_LENGTH = sequence_size

    # Initial (hidden, cell) pairs for the 3 stacked LSTM layers, 16 units each.
    state = [(jnp.zeros((batch_size, 16)),
              jnp.zeros((batch_size, 16)))] * 3

    with BatchGenerator(sequence_data_reshaped, dummy_label, batch_size=batch_size,
                        pipeline=DataGenerator.pipeline, num_workers=8, shuffle=False) as batch_generator:
        data_np = batch_generator.batch_data
        label_np = batch_generator.batch_label
        running_loss = 0.0
        running_accuracy = 0.0
        running_count = 0
        summary_period = max(max_step // 100, 1)
        np.set_printoptions(precision=2)
        try:
            start_time = time.time()
            while batch_generator.epoch == 0:
                # Batches come padded to MAX_LENGTH; crop to this step's length
                # and move to time-major layout (time, batch, features).
                data = jnp.asarray(data_np.transpose((1, 0, 2))[:sequence_data[batch_generator.step]])
                label = jnp.asarray(label_np)
                preds, _state = network_fun(network_params, data, state)
                # FIX: reduce to a scalar so running_loss stays a number and
                # the :.03e format below works.
                loss = jnp.mean(loss_fun(preds, label))
                # FIX: keep the updated optimizer state and refresh the params.
                opt_state = update_fun(batch_generator.global_step, opt_state, data, label, state)
                network_params = opt_params(opt_state)
                running_loss += loss
                running_accuracy += accuracy_fun(preds, label)
                running_count += 1
                if (batch_generator.step + 1) % summary_period == 0:
                    writer_train.add_scalar('metric/loss', running_loss / running_count,
                                            global_step=batch_generator.step)
                    writer_train.add_scalar('metric/error', 1 - (running_accuracy / running_count),
                                            global_step=batch_generator.step)
                    speed = summary_period / (time.time() - start_time)
                    print(f'Step {batch_generator.step}, loss: {running_loss / running_count:.03e}'
                          f', acc: {running_accuracy / running_count:.03e}, speed: {speed:0.3f}step/s')
                    start_time = time.time()
                    running_loss = 0.0
                    running_accuracy = 0.0
                    running_count = 0
                data_np, label_np = batch_generator.next_batch()
        except KeyboardInterrupt:
            print('\r ', end='\r')
    writer_train.close()

    # Validation on freshly generated random-length batches (~1000 samples).
    running_accuracy = 0.0
    running_count = 0
    for _ in range(math.ceil(1000 / batch_size)):
        data_np, label_np = generate_data(batch_size, int(np.random.uniform(4, sequence_size + 1)))
        data = jnp.asarray(data_np)
        label = jnp.asarray(label_np)
        preds, _state = network_fun(network_params, data, state)
        running_accuracy += accuracy_fun(preds, label)
        running_count += 1
    print(f'Validation accuracy: {running_accuracy / running_count:.03f}')

    # Hand-written sequences with known periods for a final sanity check.
    test_data = [
        [[1.], [0.], [0.], [1.], [0.], [0.]],
        [[1.], [0.], [0.], [0.], [1.], [0.], [0.], [0.]],
        [[1.], [0.], [0.], [0.], [0.], [1.], [0.], [0.], [0.]],
        [[1.], [0.], [0.], [0.], [0.], [1.], [0.], [0.], [0.], [0.]],
        [[0.], [1.], [0.], [0.], [0.], [0.], [0.], [1.], [0.], [0.], [0.], [0.]],
        [[1.], [0.], [0.], [0.], [0.], [0.], [1.], [0.], [0.], [0.], [0.], [0.]],
        [[0.], [1.], [0.], [0.], [0.], [0.], [0.], [1.], [0.], [0.], [0.], [0.], [0.]],
        [[0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.]],
        [[1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.]],
        [[1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.],
         [0.], [1.], [0.], [0.], [1.], [0.], [0.]],
        [[0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.],
         [0.], [1.], [0.], [0.], [1.], [0.]],
    ]
    test_label = np.asarray([1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0], dtype=np.int32)
    running_accuracy = 0.0
    running_count = 0
    state = [(jnp.zeros((1, 16)), jnp.zeros((1, 16)))] * 3
    for data, label in zip(test_data, test_label):
        outputs, _states = network_fun(
            network_params,
            jnp.asarray(np.expand_dims(np.asarray(data, dtype=np.float32), 1)),
            state)
        # FIX: score the current outputs, not the stale `preds` left over from
        # the validation loop above.
        running_accuracy += accuracy_fun(outputs, label)
        running_count += 1
        # FIX: .detach().cpu().numpy() is torch API; a jax array converts via
        # np.asarray.
        print(f'{len(data)} {np.asarray(data)[:, 0]}, label: {label}'
              f', output: {np.asarray(outputs)[0, -1, 1]:.02f}')
    print(f'Test accuracy: {running_accuracy / running_count:.03f}')
# Allow importing this module without triggering a training run.
if __name__ == '__main__':
    main()