22. Simple Neural Network Regression with Keras and JAX#
GPU
This lecture was built using a machine with JAX installed and access to a GPU.
To run this lecture on Google Colab, click on the “play” icon at the top right, select Colab, and set the runtime environment to include a GPU.
To run this lecture on your own machine, you need to install Google JAX.
In this lecture we show how to implement one-dimensional nonlinear regression using a neural network.
We will use the popular deep learning library Keras, which provides a simple and elegant interface to deep learning.
The emphasis in Keras is on providing an intuitive API, while the heavy lifting is done by another library.
Currently the backend library can be TensorFlow, PyTorch, or JAX.
In this lecture we will use JAX.
The objective of this lecture is to provide a very simple introduction to deep learning in a regression setting.
We begin with some standard imports.
import numpy as np
import matplotlib.pyplot as plt
Let’s install Keras.
!pip install keras
Collecting keras
Downloading keras-3.6.0-py3-none-any.whl.metadata (5.8 kB)
Collecting absl-py (from keras)
Downloading absl_py-2.1.0-py3-none-any.whl.metadata (2.3 kB)
Requirement already satisfied: numpy in /opt/conda/envs/quantecon/lib/python3.12/site-packages (from keras) (1.26.4)
Requirement already satisfied: rich in /opt/conda/envs/quantecon/lib/python3.12/site-packages (from keras) (13.7.1)
Collecting namex (from keras)
Downloading namex-0.0.8-py3-none-any.whl.metadata (246 bytes)
Requirement already satisfied: h5py in /opt/conda/envs/quantecon/lib/python3.12/site-packages (from keras) (3.11.0)
Collecting optree (from keras)
Downloading optree-0.13.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (47 kB)
Requirement already satisfied: ml-dtypes in /opt/conda/envs/quantecon/lib/python3.12/site-packages (from keras) (0.5.0)
Requirement already satisfied: packaging in /opt/conda/envs/quantecon/lib/python3.12/site-packages (from keras) (24.1)
Requirement already satisfied: typing-extensions>=4.5.0 in /opt/conda/envs/quantecon/lib/python3.12/site-packages (from optree->keras) (4.11.0)
Requirement already satisfied: markdown-it-py>=2.2.0 in /opt/conda/envs/quantecon/lib/python3.12/site-packages (from rich->keras) (2.2.0)
Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /opt/conda/envs/quantecon/lib/python3.12/site-packages (from rich->keras) (2.15.1)
Requirement already satisfied: mdurl~=0.1 in /opt/conda/envs/quantecon/lib/python3.12/site-packages (from markdown-it-py>=2.2.0->rich->keras) (0.1.0)
Downloading keras-3.6.0-py3-none-any.whl (1.2 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.2/1.2 MB 122.6 MB/s eta 0:00:00
Downloading absl_py-2.1.0-py3-none-any.whl (133 kB)
Downloading namex-0.0.8-py3-none-any.whl (5.8 kB)
Downloading optree-0.13.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (385 kB)
Installing collected packages: namex, optree, absl-py, keras
Successfully installed absl-py-2.1.0 keras-3.6.0 namex-0.0.8 optree-0.13.1
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.
Now we specify that the desired backend is JAX. (This environment variable must be set before keras is imported, since the backend is fixed at import time.)
import os
os.environ['KERAS_BACKEND'] = 'jax'
Next we import some tools from Keras.
import keras
from keras.models import Sequential
from keras.layers import Dense
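As a quick sanity check, we can confirm that Keras picked up the JAX backend. (This verification cell is our addition; keras.backend.backend() simply reports the name of the active backend.)

print(keras.backend.backend())  # should print 'jax'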
22.1. Data#
First let’s write a function to generate some data.
The data has the form

\[
    y_i = f(x_i) + \epsilon_i,
    \qquad i = 1, \ldots, n
\]
The map \(f\) is specified inside the function and \(\epsilon_i\) is an independent draw from a fixed normal distribution.
Here’s the function that creates vectors x and y according to the rule above.
def generate_data(x_min=0, x_max=5, data_size=400):
    x = np.linspace(x_min, x_max, num=data_size)
    # Keras expects inputs of shape (observations, features)
    x = x.reshape(data_size, 1)
    # Independent normal draws, scaled to standard deviation 0.2
    ϵ = 0.2 * np.random.randn(*x.shape)
    # The nonlinear map f(x) = sqrt(x) + sin(x)
    y = x**0.5 + np.sin(x) + ϵ
    x, y = [z.astype('float32') for z in (x, y)]
    return x, y
Now we generate some data to train the model.
x, y = generate_data()
Here’s a plot of the training data.
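A minimal cell that produces this figure is a plain Matplotlib scatter plot of the sample:

fig, ax = plt.subplots()
ax.scatter(x, y)   # the noisy observations (x_i, y_i)
ax.set_xlabel('x')
ax.set_ylabel('y')
plt.show()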
We’ll also use data from the same process for cross-validation.
x_validate, y_validate = generate_data()
22.2. Models#
We supply functions to build two types of models.
The first implements linear regression.
This is achieved by constructing a neural network with just one layer that maps to a single dimension (since the prediction is real-valued).

The input model will be an instance of keras.Sequential, which is used to group a stack of layers into a single prediction model.
def build_regression_model(model):
    # A single dense layer with one output and no activation
    # is equivalent to linear regression
    model.add(Dense(units=1))
    model.compile(optimizer=keras.optimizers.SGD(),
                  loss='mean_squared_error')
    return model
In the function above you can see that we use stochastic gradient descent to train the model, and that the loss is mean squared error (MSE).
MSE is the standard loss function for ordinary least squares regression.
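Concretely, if \(\hat y_i\) denotes the model’s prediction at \(x_i\), then the loss being minimized is

\[
    \frac{1}{n} \sum_{i=1}^{n} (y_i - \hat y_i)^2.
\]

With a single dense layer and no activation, \(\hat y_i\) is an affine function of \(x_i\), so minimizing this loss over the layer’s weight and bias reproduces ordinary least squares.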
The second function creates a dense (i.e., fully connected) neural network with 3 hidden layers, where each hidden layer maps to a k-dimensional output space.
def build_nn_model(model, k=10, activation_function='tanh'):
    # Construct network
    model.add(Dense(units=k, activation=activation_function))
    model.add(Dense(units=k, activation=activation_function))
    model.add(Dense(units=k, activation=activation_function))
    model.add(Dense(1))
    # Embed training configurations
    model.compile(optimizer=keras.optimizers.SGD(),
                  loss='mean_squared_error')
    return model
The following function will be used to plot the MSE of the model during the training process.
Initially the MSE will be relatively high, but it should fall at each iteration, as the parameters are adjusted to better fit the data.
def plot_loss_history(training_history, ax):
    ax.plot(training_history.epoch,
            np.array(training_history.history['loss']),
            label='training loss')
    ax.plot(training_history.epoch,
            np.array(training_history.history['val_loss']),
            label='validation loss')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss (Mean squared error)')
    ax.legend()
22.3. Training#
Now let’s go ahead and train our models.
22.3.1. Linear regression#
We’ll start with linear regression.
First we create a Model instance using Sequential().
model = Sequential()
regression_model = build_regression_model(model)
Now we train the model using the training data.
# Setting batch_size to the sample size means each epoch performs
# one full-batch gradient descent step
training_history = regression_model.fit(
    x, y, batch_size=x.shape[0], verbose=0,
    epochs=4000, validation_data=(x_validate, y_validate))
Let’s have a look at the evolution of MSE as the model is trained.
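A minimal cell that does this, using the plot_loss_history helper defined above, is:

fig, ax = plt.subplots()
plot_loss_history(training_history, ax)
plt.show()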
Let’s print the final MSE on the cross-validation data.
print("Testing loss on the validation set.")
regression_model.evaluate(x_validate, y_validate, verbose=2)
Testing loss on the validation set.
13/13 - 0s - 11ms/step - loss: 0.3260
0.3260418772697449
Here are our predictions on the cross-validation data.
y_predict = regression_model.predict(x_validate, verbose=2)
13/13 - 0s - 4ms/step
We use the following function to plot our predictions along with the data.
def plot_results(x, y, y_predict, ax):
    ax.scatter(x, y)
    ax.plot(x, y_predict, label="fitted model", color='black')
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.legend()   # display the "fitted model" label
Let’s now call the function on the cross-validation data.
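A minimal cell for this, reusing the helper just defined, is:

fig, ax = plt.subplots()
plot_results(x_validate, y_validate, y_predict, ax)
plt.show()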
22.3.2. Deep learning#
Now let’s switch to a neural network with multiple layers.
We implement the same steps as before.
model = Sequential()
nn_model = build_nn_model(model)
training_history = nn_model.fit(
    x, y, batch_size=x.shape[0], verbose=0,
    epochs=4000, validation_data=(x_validate, y_validate))
Here’s the final MSE for the deep learning model.
print("Testing loss on the validation set.")
nn_model.evaluate(x_validate, y_validate, verbose=2)
Testing loss on the validation set.
13/13 - 0s - 31ms/step - loss: 0.0407
0.040745146572589874
You will notice that this loss is much lower than the one we achieved with linear regression, suggesting a better fit.
To confirm this, let’s look at the fitted function.
y_predict = nn_model.predict(x_validate, verbose=2)
13/13 - 0s - 9ms/step
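To display the fitted function against the validation data, we can reuse plot_results from above:

fig, ax = plt.subplots()
plot_results(x_validate, y_validate, y_predict, ax)
plt.show()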