23. Neural Network Regression with JAX and Optax#
GPU
This lecture was built using a machine with JAX installed and access to a GPU.
To run this lecture on Google Colab, click on the “play” icon top right, select Colab, and set the runtime environment to include a GPU.
To run this lecture on your own machine, you need to install Google JAX.
In a previous lecture, we showed how to implement regression using a neural network via the popular deep learning library Keras.
In this lecture, we solve the same problem directly, using JAX operations rather than relying on the Keras frontend.
The objective is to understand the nuts and bolts of the exercise better, as well as to explore more features of JAX.
The lecture proceeds in three stages:
We repeat the Keras exercise, to give ourselves a benchmark.
We solve the same problem in pure JAX, using pytree operations and gradient descent.
We solve the same problem using a combination of JAX and Optax, an optimization library built for JAX.
We begin with imports and installs.
import numpy as np
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import os
from time import time
!pip install keras
!pip install optax
os.environ['KERAS_BACKEND'] = 'jax'
import keras
from keras import Sequential
from keras.layers import Dense
import optax
23.1. Set Up#
Let’s hardcode some of the learning-related constants we’ll use across all implementations.
EPOCHS = 4000 # Number of passes through the data set
DATA_SIZE = 400 # Sample size
NUM_LAYERS = 4 # Depth of the network
OUTPUT_DIM = 10 # Output dimension of input and hidden layers
LEARNING_RATE = 0.001 # Learning rate for gradient descent
The next piece of code is repeated from our Keras lecture and generates the data.
def generate_data(x_min=0,
x_max=5,
data_size=DATA_SIZE,
seed=1234): # Default size for dataset
np.random.seed(seed)
x = np.linspace(x_min, x_max, num=data_size)
ϵ = 0.2 * np.random.randn(data_size)
y = x**0.5 + np.sin(x) + ϵ
# Return observations as column vectors
x, y = [np.reshape(z, (data_size, 1)) for z in (x, y)]
return x, y
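As a quick sanity check (not part of the original lecture), we can call `generate_data` and confirm that it returns column vectors of the expected size:

```python
# Illustrative check only: inspect the shapes returned by generate_data
x_demo, y_demo = generate_data()
print(x_demo.shape, y_demo.shape)   # both (DATA_SIZE, 1), i.e., (400, 1)
```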
23.2. Training with Keras#
We repeat the Keras training exercise from our Keras lecture as a benchmark.
The code is essentially the same, although written slightly more succinctly.
Here is a function to build the model.
def build_keras_model(num_layers=NUM_LAYERS,
activation_function='tanh'):
model = Sequential()
# Add layers to the network sequentially, from inputs towards outputs
    for i in range(num_layers - 1):
model.add(
Dense(units=OUTPUT_DIM,
activation=activation_function)
)
# Add a final layer that maps to a scalar value, for regression.
model.add(Dense(units=1))
# Embed training configurations
model.compile(
optimizer=keras.optimizers.SGD(),
loss='mean_squared_error'
)
return model
Here is a function to train the model.
def train_keras_model(model, x, y, x_validate, y_validate):
print(f"Training NN using Keras.")
start_time = time()
training_history = model.fit(
x, y,
batch_size=max(x.shape),
verbose=0,
epochs=EPOCHS,
validation_data=(x_validate, y_validate)
)
elapsed = time() - start_time
mse = model.evaluate(x_validate, y_validate, verbose=2)
print(f"Trained Keras model in {elapsed:.2f} seconds with final MSE on validation data = {mse}")
return model, training_history
The next function visualizes the prediction.
def plot_keras_output(model, x, y, x_validate, y_validate):
y_predict = model.predict(x_validate, verbose=2)
fig, ax = plt.subplots()
ax.scatter(x, y)
ax.plot(x, y_predict, label="fitted model", color='black')
ax.set_xlabel('x')
ax.set_ylabel('y')
plt.show()
Here’s a function to run all the routines above.
def keras_run_all():
model = build_keras_model()
x, y = generate_data()
x_validate, y_validate = generate_data()
model, training_history = train_keras_model(
model, x, y, x_validate, y_validate
)
plot_keras_output(model, x, y, x_validate, y_validate)
Let’s put it to work:
keras_run_all()
We’ve seen this figure before and we note the relatively low final MSE.
23.3. Training with JAX#
For the JAX implementation, we need to construct the network ourselves, as a map from inputs to outputs.
We’ll use the same network structure we used for the Keras implementation.
23.3.1. Background and set up#
The neural network has the form

$$
f(\theta, x)
= (A_3 \circ \sigma \circ A_2 \circ \sigma \circ A_1 \circ \sigma \circ A_0)(x)
$$
Here
\(x\) is a scalar input – a point on the horizontal axis in the Keras estimation above,
\(\circ\) means composition of maps,
\(\sigma\) is the activation function – in our case, \(\tanh\), and
\(A_i\) represents the affine map \(A_i x = W_i x + b_i\).
Each matrix \(W_i\) is called a weight matrix and each vector \(b_i\) is called a bias term.
The symbol \(\theta\) represents the entire collection of parameters:

$$
\theta := (W_0, b_0, W_1, b_1, W_2, b_2, W_3, b_3)
$$
In fact, when we implement the affine map \(A_i x = W_i x + b_i\), we will work with row vectors rather than column vectors, so that
\(x\) and \(b_i\) are stored as row vectors, and
the mapping is executed by JAX via the expression `x @ W + b`.
We work with row vectors because NumPy and JAX arrays are stored in row-major order by default, so row-based operations tend to be more efficient.
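To make the convention concrete, here is a small illustrative sketch (with hypothetical names `x_demo`, `W_demo`, `b_demo`) showing how the bias row vector broadcasts across a batch of row-vector inputs:

```python
# Illustrative shapes only: 3 scalar inputs mapped to 10 hidden units
x_demo = jnp.ones((3, 1))        # each observation is stored as a row
W_demo = jnp.ones((1, 10))       # weight matrix of the first layer
b_demo = jnp.ones((1, 10))       # bias stored as a row vector
print((x_demo @ W_demo + b_demo).shape)   # (3, 10): the bias broadcasts across rows
```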
Here’s a function to initialize parameters.
The parameter “vector” `θ` will be stored as a list of dicts.
def initialize_params(seed=1234):
"""
Generate an initial parameterization for a feed forward neural network with
    number of layers = NUM_LAYERS. Each hidden layer has OUTPUT_DIM
units.
"""
k = OUTPUT_DIM
shapes = (
(1, k), # W_0.shape
(k, k), # W_1.shape
(k, k), # W_2.shape
(k, 1) # W_3.shape
)
np.random.seed(seed)
# A function to generate weight matrices
def w_init(m, n):
return np.random.normal(size=(m, n)) * np.sqrt(2 / m)
# Build list of dicts, each containing a (weight, bias) pair
θ = []
for w_shape in shapes:
m, n = w_shape
θ.append(dict(W=w_init(m, n), b=np.ones((1, n))) )
return θ
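As an optional check on the structure just created (not part of the original lecture), we can map a shape query over the parameter pytree:

```python
# Display the shape of every leaf (weight matrix and bias) in the pytree
θ_demo = initialize_params()
print(jax.tree.map(lambda a: a.shape, θ_demo))
```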
Wait, you say!
Shouldn’t we concatenate the elements of \(\theta\) into some kind of big array, so that we can do autodiff with respect to this array?
Actually we don’t need to, as will become clear below.
23.3.2. Coding the network#
Here’s our implementation of \(f\):
@jax.jit
def f(θ, x, σ=jnp.tanh):
"""
Perform a forward pass over the network to evaluate f(θ, x).
The state x is stored and iterated on as a row vector.
"""
*hidden, last = θ
for layer in hidden:
W, b = layer['W'], layer['b']
x = σ(x @ W + b)
W, b = last['W'], last['b']
x = x @ W + b
return x
The function \(f\) is appropriately vectorized, so that we can pass in the entire set of input observations as `x` and return the predicted vector of outputs `y_hat = f(θ, x)` corresponding to each data point.
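For instance, here is an illustrative shape check (using throwaway names, and assuming the helpers above have been run), confirming that the network returns one prediction per observation:

```python
# Illustrative only: evaluate the network on the full data set
θ_demo = initialize_params()
x_demo, y_demo = generate_data()
print(f(θ_demo, x_demo).shape)   # (DATA_SIZE, 1): one prediction per observation
```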
The loss function is mean squared error, the same as the Keras case.
@jax.jit
def loss_fn(θ, x, y):
"Loss is mean squared error."
return jnp.mean((f(θ, x) - y)**2)
We’ll use its gradient to do stochastic gradient descent.
(Technically, we will be doing gradient descent, rather than stochastic gradient descent, since we will not randomize over sample points when we evaluate the gradient.)
The gradient below is with respect to the first argument `θ`.
loss_gradient = jax.jit(jax.grad(loss_fn))
The line above seems kind of magical, since we are differentiating with respect to a parameter “vector” stored as a list of dictionaries containing arrays.
How can we differentiate with respect to such a complex object?
The answer is that the list of dictionaries is treated internally as a pytree.
The JAX function `grad` understands how to
extract the individual arrays (the “leaves” of the tree),
compute derivatives with respect to each one, and
pack the resulting derivatives into a pytree with the same structure as the parameter vector.
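Here is a toy illustration of this behavior (not part of the original exercise): differentiating a scalar function of a small list-of-dicts pytree returns gradients with exactly the same nested structure.

```python
# Toy pytree with the same kind of structure as θ
toy_params = [dict(W=jnp.ones((2, 2)), b=jnp.zeros((1, 2)))]

def toy_loss(params):
    layer = params[0]
    return jnp.sum(layer['W']**2) + jnp.sum(layer['b']**2)

toy_grad = jax.grad(toy_loss)(toy_params)
# The gradient is a pytree with the same structure as the parameters
print(jax.tree.map(lambda a: a.shape, toy_grad))   # [{'W': (2, 2), 'b': (1, 2)}]
```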
23.3.3. Gradient descent#
Using the above code, we can now write our rule for updating the parameters via gradient descent, which is the algorithm we covered in our lecture on autodiff.
In this case, however, to keep things as simple as possible, we’ll use a fixed learning rate for every iteration.
@jax.jit
def update_parameters(θ, x, y):
λ = LEARNING_RATE
gradient = loss_gradient(θ, x, y)
θ = jax.tree.map(lambda p, g: p - λ * g, θ, gradient)
return θ
We are implementing the gradient descent update
new_params = current_params - learning_rate * gradient_of_loss_function
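In more standard notation, with \(\lambda\) denoting the learning rate and \(L\) the loss, this update can be written as

$$
\theta \leftarrow \theta - \lambda \, \nabla_\theta L(\theta; x, y)
$$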
This is nontrivial for a complex structure such as a neural network, so how is it done?
The key line in the function above is `θ = jax.tree.map(lambda p, g: p - λ * g, θ, gradient)`.
The `jax.tree.map` function understands `θ` and `gradient` as pytrees of the same structure and executes `p - λ * g` on the corresponding leaves of the pair of trees.
This means that each weight matrix and bias vector is updated by gradient descent, exactly as required.
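The following toy snippet (illustrative values only) shows the same leaf-wise update on a pair of small pytrees:

```python
# Toy illustration of the leaf-wise gradient descent update
params_demo = [dict(W=jnp.ones((2, 2)), b=jnp.zeros((1, 2)))]
grads_demo  = [dict(W=0.5 * jnp.ones((2, 2)), b=jnp.ones((1, 2)))]

updated = jax.tree.map(lambda p, g: p - 0.01 * g, params_demo, grads_demo)
print(updated[0]['W'][0, 0])   # 1.0 - 0.01 * 0.5 = 0.995
```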
Here’s code that puts this all together.
def train_jax_model(θ, x, y, x_validate, y_validate):
"""
Train model using gradient descent via JAX autodiff.
"""
training_loss = np.empty(EPOCHS)
validation_loss = np.empty(EPOCHS)
for i in range(EPOCHS):
training_loss[i] = loss_fn(θ, x, y)
validation_loss[i] = loss_fn(θ, x_validate, y_validate)
θ = update_parameters(θ, x, y)
return θ, training_loss, validation_loss
23.3.4. Execution#
Let’s run our code and see how it goes.
θ = initialize_params()
x, y = generate_data()
x_validate, y_validate = generate_data()
%%time
θ, training_loss, validation_loss = train_jax_model(
θ, x, y, x_validate, y_validate
)
CPU times: user 3.99 s, sys: 2.18 s, total: 6.17 s
Wall time: 3.39 s
This figure shows MSE across iterations:
fig, ax = plt.subplots()
ax.plot(range(EPOCHS), validation_loss, label='validation loss')
ax.legend()
plt.show()
Let’s check the final MSE on the validation data, at the estimated parameters.
print(f"""
Final MSE on test data set = {loss_fn(θ, x_validate, y_validate)}.
"""
)
Final MSE on test data set = 0.052898816764354706.
This MSE is not as low as we got for Keras, but we did quite well given how simple our implementation is.
Here’s a visualization of the quality of our fit.
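The plotting cell is not reproduced here; the sketch below, built around a hypothetical helper `plot_jax_output` that mirrors `plot_keras_output`, shows how such a figure could be generated:

```python
# Hypothetical helper (not part of the original lecture) mirroring plot_keras_output
def plot_jax_output(θ, x, y, x_validate, y_validate):
    y_hat = f(θ, x_validate)            # predictions at the validation inputs
    fig, ax = plt.subplots()
    ax.scatter(x, y)
    ax.plot(x_validate, y_hat, label="fitted model", color='black')
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.legend()
    plt.show()

plot_jax_output(θ, x, y, x_validate, y_validate)
```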
23.4. JAX plus Optax#
Our hand-coded optimization routine above was quite effective, but in practice we might wish to use an optimization library written for JAX.
One such library is Optax.
23.4.1. Optax with SGD#
Here’s a training routine using Optax’s stochastic gradient descent solver.
def train_jax_optax(θ, x, y):
solver = optax.sgd(learning_rate=LEARNING_RATE)
opt_state = solver.init(θ)
for _ in range(EPOCHS):
grad = loss_gradient(θ, x, y)
updates, opt_state = solver.update(grad, opt_state, θ)
θ = optax.apply_updates(θ, updates)
return θ
Let’s try running it.
# Reset parameter vector
θ = initialize_params()
# Train network
%time θ = train_jax_optax(θ, x, y)
CPU times: user 12.8 s, sys: 4.46 s, total: 17.2 s
Wall time: 10.1 s
The resulting MSE is essentially the same as that of our hand-coded routine.
print(f"""
Completed training JAX model using Optax with SGD.
Final MSE on test data set = {loss_fn(θ, x_validate, y_validate)}.
"""
)
Completed training JAX model using Optax with SGD.
Final MSE on test data set = 0.052898820489645004.
23.4.2. Optax with ADAM#
We can also consider using a slightly more sophisticated gradient-based method, such as ADAM.
You will notice that the syntax for using this alternative optimizer is very similar.
def train_jax_optax(θ, x, y):
solver = optax.adam(learning_rate=LEARNING_RATE)
opt_state = solver.init(θ)
for _ in range(EPOCHS):
grad = loss_gradient(θ, x, y)
updates, opt_state = solver.update(grad, opt_state, θ)
θ = optax.apply_updates(θ, updates)
return θ
# Reset parameter vector
θ = initialize_params()
# Train network
%time θ = train_jax_optax(θ, x, y)
CPU times: user 1min 4s, sys: 29.7 s, total: 1min 33s
Wall time: 45.6 s
Here’s the MSE.
print(f"""
Completed training JAX model using Optax with ADAM.
Final MSE on test data set = {loss_fn(θ, x_validate, y_validate)}.
"""
)
Completed training JAX model using Optax with ADAM.
Final MSE on test data set = 0.03738926723599434.
Here’s a visualization of the result.
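As before, the figure itself is not reproduced here; the hypothetical `plot_jax_output` helper sketched earlier can be reused with the Adam-trained parameters:

```python
# Reuse the plotting sketch from the pure-JAX section with the Adam-trained θ
plot_jax_output(θ, x, y, x_validate, y_validate)
```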