"""Minimal stax-style LSTM building blocks written with JAX.

Every constructor returns an ``(init_fun, apply_fun)`` pair where
``init_fun(rng, input_size)`` returns ``(output_size, params)`` and
``apply_fun(params, inputs, states)`` evaluates the module.
"""
import jax
import jax.numpy as jnp
from jax import lax
from jax import nn


@jax.jit
def sigmoid(x):
    """Element-wise logistic sigmoid ``1 / (1 + exp(-x))``."""
    return nn.sigmoid(x)


def LSTMCell(hidden_size, w_init=nn.initializers.glorot_normal(),
             b_init=nn.initializers.normal()):
    """Build a single LSTM cell.

    Args:
        hidden_size: Number of hidden units.
        w_init: Initializer ``(key, shape) -> array`` for the weight matrices.
        b_init: Initializer ``(key, shape) -> array`` for the bias.

    Returns:
        ``(init_fun, apply_fun)``. ``init_fun(rng, input_size)`` returns
        ``(hidden_size, (weight_ih, weight_hh, bias))``.
        ``apply_fun(params, inputs, states)`` takes ``states = (hx, cx)`` and
        returns the updated ``(hy, cy)``.
    """
    def init_fun(rng, input_size):
        k1, k2, k3 = jax.random.split(rng, 3)
        weight_ih = w_init(k1, (input_size, 4 * hidden_size))
        weight_hh = w_init(k2, (hidden_size, 4 * hidden_size))
        bias = b_init(k3, (4 * hidden_size,))
        return hidden_size, (weight_ih, weight_hh, bias)

    @jax.jit
    def apply_fun(params, inputs, states):
        hx, cx = states
        weight_ih, weight_hh, bias = params
        gates = jnp.dot(inputs, weight_ih) + jnp.dot(hx, weight_hh) + bias
        # The 4 * hidden_size gate columns are laid out as
        # [input, forget, cell, output]; split into hidden_size-wide blocks
        # instead of hard-coding widths, so any hidden_size works.
        in_gate, forget_gate, cell_gate, out_gate = jnp.split(gates, 4, axis=-1)
        in_gate = sigmoid(in_gate)
        forget_gate = sigmoid(forget_gate)
        cell_gate = jnp.tanh(cell_gate)
        out_gate = sigmoid(out_gate)
        cy = forget_gate * cx + in_gate * cell_gate
        hy = out_gate * jnp.tanh(cy)
        return (hy, cy)

    return init_fun, apply_fun


def LSTMLayer(hidden_size, w_init=nn.initializers.glorot_normal(),
              b_init=nn.initializers.normal()):
    """Build an LSTM layer that runs one ``LSTMCell`` over a sequence.

    ``apply_fun(params, inputs, states)`` iterates over the leading axis of
    ``inputs`` (time steps), feeding each slice to the cell.
    It returns ``(outputs, final_states)``, where ``outputs`` stacks the
    hidden state ``hy`` of every step along that leading axis.
    """
    cell_init, cell_fun = LSTMCell(hidden_size, w_init=w_init, b_init=b_init)

    @jax.jit
    def apply_fun(params, inputs, states):
        def step(carry, x_t):
            new_states = cell_fun(params, x_t, carry)
            return new_states, new_states[0]

        # lax.scan compiles the step once, instead of unrolling a Python loop
        # that grows with sequence length.
        # asarray stacks a list of per-step arrays; tuple() normalises the carry.
        final_states, outputs = lax.scan(step, tuple(states),
                                         jnp.asarray(inputs))
        return outputs, final_states

    return cell_init, apply_fun


def StackedLSTM(hidden_size, num_layer, w_init=nn.initializers.glorot_normal(),
                b_init=nn.initializers.normal()):
    """Build ``num_layer`` ``LSTMLayer``s applied one after another.

    ``init_fun(rng, input_size)`` returns ``(output_size, params)`` with one
    params entry per layer.
    ``apply_fun(params, inputs, states)`` expects ``states[i]`` to be the
    ``(hx, cx)`` pair of layer ``i``. It returns ``(outputs, output_states)``,
    where ``outputs`` comes from the last layer and ``output_states`` lists
    each layer's final states.
    """
    layers = [LSTMLayer(hidden_size, w_init=w_init, b_init=b_init)
              for _ in range(num_layer)]
    layer_inits = [layer_init for layer_init, _ in layers]
    layer_funs = [layer_fun for _, layer_fun in layers]

    def init_fun(rng, input_size):
        params = []
        output_size = input_size
        # Give every layer its own key; reusing ``rng`` would make layers
        # with equal input sizes start from identical weights.
        keys = jax.random.split(rng, num_layer) if num_layer > 0 else ()
        for layer_init, key in zip(layer_inits, keys):
            output_size, layer_params = layer_init(key, output_size)
            params.append(layer_params)
        return output_size, tuple(params)

    @jax.jit
    def apply_fun(params, inputs, states):
        output = inputs
        output_states = []
        for layer_id, layer_fun in enumerate(layer_funs):
            output, out_state = layer_fun(params[layer_id], output,
                                          states[layer_id])
            output_states.append(out_state)
        return jnp.asarray(output), output_states

    return init_fun, apply_fun


def StackedLSTMModel(input_size, hidden_size=-1, output_size=-1, num_layer=3,
                     w_init=nn.initializers.glorot_normal(),
                     b_init=nn.initializers.normal()):
    """Stacked LSTM followed by a linear readout of the last time step.

    Args:
        input_size: Feature size; used to derive the defaults below.
        hidden_size: Hidden units per layer (``<= 0`` means ``2 * input_size``).
        output_size: Readout size (``<= 0`` means ``input_size``).
        num_layer: Number of stacked LSTM layers.
        w_init, b_init: Weight / bias initializers.

    Returns:
        ``(init_fun, apply_fun)``. ``apply_fun`` returns
        ``(prediction, new_states)``, where ``prediction`` is
        ``output[-1] @ weight + bias``.
    """
    hidden_size = hidden_size if hidden_size > 0 else input_size * 2
    output_size = output_size if output_size > 0 else input_size
    stacked_init, stacked_fun = StackedLSTM(hidden_size, num_layer,
                                            w_init=w_init, b_init=b_init)

    def init_fun(rng, input_size):
        # Independent keys for the stack and the readout (no key reuse).
        k_stack, k_weight, k_bias = jax.random.split(rng, 3)
        stacked_output_size, stacked_params = stacked_init(k_stack, input_size)
        weight = w_init(k_weight, (stacked_output_size, output_size))
        bias = b_init(k_bias, (output_size,))
        return output_size, ((weight, bias), stacked_params)

    @jax.jit
    def apply_fun(params, inputs, states):
        (weight, bias), stacked_params = params
        output, new_states = stacked_fun(stacked_params, inputs, states)
        # Only the last time step of the top layer feeds the readout.
        return jnp.dot(output[-1], weight) + bias, new_states

    return init_fun, apply_fun