Initial commit
commit e41417c029
8 changed files with 802 additions and 0 deletions
src/jax_networks.py · 104 additions · Normal file
@@ -0,0 +1,104 @@
import jax
import jax.numpy as jnp
from jax import lax
from jax import nn

@jax.jit
def sigmoid(x):
    # Direct logistic sigmoid built from lax primitives
    # (jax.nn.sigmoid is the numerically stabler built-in).
    return 1 / (1 + lax.exp(-x))

def LSTMCell(hidden_size, w_init=nn.initializers.glorot_normal(), b_init=nn.initializers.normal()):
    def init_fun(rng, input_size):
        k1, k2, k3 = jax.random.split(rng, 3)
        # Fused weights for the four gates: input, forget, cell, output.
        weight_ih = w_init(k1, (input_size, 4 * hidden_size))
        weight_hh = w_init(k2, (hidden_size, 4 * hidden_size))
        bias = b_init(k3, (4 * hidden_size,))
        return hidden_size, (weight_ih, weight_hh, bias)

    @jax.jit
    def apply_fun(params, inputs, states):
        (hx, cx) = states
        weight_ih, weight_hh, bias = params
        gates = lax.dot(inputs, weight_ih) + lax.dot(hx, weight_hh) + bias

        # Slice the fused projection by hidden_size; the original hard-coded
        # offsets 16/32/48 only worked when hidden_size == 16.
        in_gate = sigmoid(gates[:, :hidden_size])
        forget_gate = sigmoid(gates[:, hidden_size:2 * hidden_size])
        cell_gate = lax.tanh(gates[:, 2 * hidden_size:3 * hidden_size])
        out_gate = sigmoid(gates[:, 3 * hidden_size:])

        cy = lax.mul(forget_gate, cx) + lax.mul(in_gate, cell_gate)
        hy = lax.mul(out_gate, lax.tanh(cy))

        return (hy, cy)

    return init_fun, apply_fun
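
As a quick shape check, the cell can be driven on its own. This sketch is not part of the commit; the key, batch size of 2, input size of 8, and hidden size of 16 are arbitrary choices:

# Not part of the diff: a minimal LSTMCell sanity check (arbitrary sizes).
rng = jax.random.PRNGKey(0)
cell_init, cell_apply = LSTMCell(16)
_, params = cell_init(rng, input_size=8)
x = jnp.ones((2, 8))                               # (batch, input_size)
state = (jnp.zeros((2, 16)), jnp.zeros((2, 16)))   # (hx, cx)
hy, cy = cell_apply(params, x, state)              # each (2, 16)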

def LSTMLayer(hidden_size, w_init=nn.initializers.glorot_normal(), b_init=nn.initializers.normal()):
    cell_init, cell_fun = LSTMCell(hidden_size, w_init=w_init, b_init=b_init)

    @jax.jit
    def apply_fun(params, inputs, states):
        # Python loop over the leading (time) axis; under jit it is
        # unrolled into the traced computation.
        output = []
        for input_data in inputs:
            states = cell_fun(params, input_data, states)
            output.append(states[0])
        return jnp.stack(output), states

    return cell_init, apply_fun
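
The unrolled loop keeps the commit simple, but the traced program grows with sequence length. A scan-based body with the same cell semantics (not in the commit; it assumes it is defined inside LSTMLayer so that cell_fun is in scope) would look like:

# Not in the commit: an equivalent apply_fun using lax.scan, which keeps
# the traced program size constant in sequence length.
def scan_apply_fun(params, inputs, states):
    def step(carry, x):
        carry = cell_fun(params, x, carry)
        return carry, carry[0]        # carry new (hx, cx); emit hx
    states, outputs = lax.scan(step, states, inputs)
    return outputs, states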

def StackedLSTM(hidden_size, num_layer, w_init=nn.initializers.glorot_normal(), b_init=nn.initializers.normal()):
    layer_inits = []
    layer_funs = []
    for _ in range(num_layer):
        layer_init, layer_fun = LSTMLayer(hidden_size, w_init=w_init, b_init=b_init)
        layer_inits.append(layer_init)
        layer_funs.append(layer_fun)

    def init_fun(rng, input_size):
        params = []
        output_size = input_size
        # Give each layer its own PRNG key; the original reused the same
        # rng, so all same-shaped layers received identical weights.
        for init, key in zip(layer_inits, jax.random.split(rng, num_layer)):
            output_size, layer_params = init(key, output_size)
            params.append(layer_params)
        return output_size, tuple(params)

    @jax.jit
    def apply_fun(params, inputs, states):
        output = inputs
        output_states = []
        for layer_id in range(num_layer):
            output, out_state = layer_funs[layer_id](params[layer_id], output, states[layer_id])
            output_states.append(out_state)
        # output is already a stacked (time, batch, hidden) array; the
        # original jnp.stack(output) was a no-op and is dropped.
        return output, output_states

    return init_fun, apply_fun
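
The stack expects one (hx, cx) pair per layer. A hypothetical helper, not in the diff, for building zeroed initial states:

# Not in the diff: hypothetical helper building zeroed per-layer states.
def init_states(batch_size, hidden_size, num_layer):
    zeros = jnp.zeros((batch_size, hidden_size))
    return [(zeros, zeros) for _ in range(num_layer)]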

def StackedLSTMModel(input_size, hidden_size=-1, output_size=-1, num_layer=3,
                     w_init=nn.initializers.glorot_normal(), b_init=nn.initializers.normal()):
    # Defaults: hidden state twice the input width, output same width as input.
    hidden_size = hidden_size if hidden_size > 0 else input_size * 2
    output_size = output_size if output_size > 0 else input_size
    stacked_init, stacked_fun = StackedLSTM(hidden_size, num_layer, w_init=w_init, b_init=b_init)

    def init_fun(rng, input_size):
        # Split the incoming key so the projection head does not share
        # randomness with the stacked layers (the original reused rng).
        k_stack, k1, k2 = jax.random.split(rng, 3)
        stacked_output_size, stacked_params = stacked_init(k_stack, input_size)

        weight = w_init(k1, (stacked_output_size, output_size))
        bias = b_init(k2, (output_size,))

        return output_size, ((weight, bias), stacked_params)

    @jax.jit
    def apply_fun(params, inputs, states):
        (weight, bias), stacked_params = params
        output, new_states = stacked_fun(stacked_params, inputs, states)
        # Project the last time step's hidden state down to output_size.
        return lax.dot(output[-1], weight) + bias, new_states

    return init_fun, apply_fun
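
Putting it together, a minimal end-to-end sketch (not part of the commit, reusing the hypothetical init_states helper above; with input_size=8 the hidden size defaults to 16):

# Not part of the commit: end-to-end usage sketch (arbitrary sizes).
rng = jax.random.PRNGKey(42)
model_init, model_apply = StackedLSTMModel(input_size=8, num_layer=2)
out_size, params = model_init(rng, 8)
inputs = jnp.ones((5, 2, 8))              # (time, batch, input_size)
states = init_states(batch_size=2, hidden_size=16, num_layer=2)
preds, new_states = model_apply(params, inputs, states)   # preds: (2, 8)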