24. Policy Gradient-Based Optimal Savings
GPU
This lecture was built using a machine with JAX installed and access to a GPU.
To run this lecture on Google Colab, click on the “play” icon top right, select Colab, and set the runtime environment to include a GPU.
To run this lecture on your own machine, you need to install Google JAX.
24.1. Introduction
In this notebook we solve infinite horizon optimal savings problems using deep learning and policy gradient ascent with JAX.
Each policy is represented as a fully connected feedforward neural network.
We begin with a cake eating problem with a known analytical solution.
Then we turn to an income fluctuation problem, for which an optimal policy is easy to compute with the endogenous grid method (EGM).
We compute that policy first and then try to learn the same policy with deep learning.
Throughout, utility takes the CRRA form \(u(c) = c^{1-\gamma} / (1-\gamma)\) and the discount factor is \(\beta\).
In addition to JAX, we’ll need the following library:
!pip install optax
We’ll use the following imports:
import jax
import jax.numpy as jnp
from jax import grad, jit, random
import optax
import matplotlib.pyplot as plt
from functools import partial
from typing import NamedTuple
24.2. Cake Eating Case
With \(R\) as the gross interest rate, assets evolve according to
\[ a_{t+1} = R (a_t - c_t). \]
To ensure stability we need \(\beta R^{1-\gamma} < 1\).
For this model, it is known that the optimal policy is \(c = \kappa a\), where
\[ \kappa = 1 - \left( \beta R^{1-\gamma} \right)^{1/\gamma}. \]
We use this known exact solution to check our numerical methods.
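The formula for \(\kappa\) can be recovered with a short guess-and-verify argument (a sketch that ignores the transversality condition). With CRRA utility, the Euler equation \(u'(c_t) = \beta R \, u'(c_{t+1})\) gives

\[ c_t^{-\gamma} = \beta R \, c_{t+1}^{-\gamma} \quad \Longrightarrow \quad \frac{c_{t+1}}{c_t} = (\beta R)^{1/\gamma}. \]

Guessing \(c_t = \kappa a_t\) implies \(a_{t+1} = R(1-\kappa) a_t\), so consumption also grows by the factor \(R(1-\kappa)\). Equating the two growth factors:

\[ R(1-\kappa) = (\beta R)^{1/\gamma} \quad \Longrightarrow \quad 1 - \kappa = \beta^{1/\gamma} R^{1/\gamma - 1} = \left( \beta R^{1-\gamma} \right)^{1/\gamma}. \]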
Initial assets \(a_0\) are fixed at 1.0, so the optimization problem is
\[ \max_{\sigma \in \Sigma} v_\sigma(a_0). \]
Here
\(\Sigma\) is the set of all feasible policies and
\(v_\sigma(a)\) is the lifetime value of following stationary policy \(\sigma\), given initial assets \(a\).
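Written out, and matching the simulation in compute_lifetime_value below (which truncates the sum after path_length periods), the lifetime value of a policy \(\sigma\) that maps assets to consumption is

\[ v_\sigma(a) = \sum_{t \geq 0} \beta^t u(\sigma(a_t)), \qquad a_{t+1} = R \left( a_t - \sigma(a_t) \right), \qquad a_0 = a. \]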
24.3. Set up
We use a class called CakeEatingModel to store model parameters.
class CakeEatingModel(NamedTuple):
    """
    Stores parameters for the model.
    """
    γ: float = 1.5
    β: float = 0.96
    R: float = 1.01
We use a class called LayerParams to store the parameters of a single layer of the neural network.
class LayerParams(NamedTuple):
    """
    Stores parameters for one layer of the neural network.
    """
    W: jnp.ndarray     # weights
    b: jnp.ndarray     # biases
The next class stores some fixed values that form part of the network training configuration.
class Config:
    """
    Configuration and parameters for training the neural network.
    """
    seed = 42                            # Seed for network initialization
    epochs = 400                         # Number of training epochs
    path_length = 320                    # Length of each consumption path
    layer_sizes = 1, 6, 6, 6, 6, 6, 1    # Network layer sizes
    learning_rate = 0.001                # Constant learning rate
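The following function initializes a single layer of the network using LeCun initialization, which draws each weight from a normal distribution with variance 1/in_dim and sets the biases to zero.
(LeCun initialization is thought to pair well with SELU activation.)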
def initialize_layer(in_dim, out_dim, key):
    """
    Initialize weights and biases for a single layer of the network.
    Use LeCun initialization.
    """
    s = jnp.sqrt(1.0 / in_dim)
    W = jax.random.normal(key, (in_dim, out_dim)) * s
    b = jnp.zeros((out_dim,))
    return LayerParams(W, b)
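To see what the LeCun scaling does, here is a small check (our illustration, not part of the lecture's pipeline). Weights drawn as in initialize_layer should have variance close to 1/in_dim, so a layer fed unit-variance inputs should produce pre-activations with variance close to 1. The dimensions and seeds below are arbitrary choices for the demo.

```python
import jax
import jax.numpy as jnp

in_dim, out_dim = 6, 500
key_w, key_x = jax.random.split(jax.random.PRNGKey(0))

# LeCun normal scaling, as in initialize_layer: standard normal draws times sqrt(1 / in_dim)
W_demo = jax.random.normal(key_w, (in_dim, out_dim)) * jnp.sqrt(1.0 / in_dim)

# Feed standard normal inputs through the layer (no bias, no activation)
X_demo = jax.random.normal(key_x, (5_000, in_dim))
Z_demo = X_demo @ W_demo

weight_var = float(jnp.var(W_demo))   # should be close to 1 / in_dim
preact_var = float(jnp.var(Z_demo))   # should be close to 1
print(f"Var(W) = {weight_var:.4f} (target {1 / in_dim:.4f})")
print(f"Var(X @ W) = {preact_var:.4f} (target 1.0)")
```

Keeping pre-activation variance near 1 from layer to layer is the property that LeCun scaling is designed to preserve.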
The next function builds an entire network, as represented by its parameters, by initializing layers and stacking them into a list.
def initialize_network(key, layer_sizes):
    """
    Build a network by initializing all of the parameters.
    A network is a list of LayerParams instances, each
    containing a weight-bias pair (W, b).
    """
    params = []
    # One layer per consecutive pair of layer sizes (output layer included)
    for i in range(len(layer_sizes) - 1):
        # Build the layer
        key, subkey = jax.random.split(key)
        layer = initialize_layer(
            layer_sizes[i],      # in dimension for layer
            layer_sizes[i + 1],  # out dimension for layer
            subkey
        )
        # And add it to the parameter list
        params.append(layer)
    return params
Now we provide a function to do a forward pass through the network, given the parameters.
def forward(params, a):
    """
    Evaluate neural network policy: maps a given asset level a to
    consumption rate c/a by running a forward pass through the network.
    """
    σ = jax.nn.selu      # Activation function
    x = jnp.array((a,))  # Make state a 1D array
    # Forward pass through network, without the last step
    for W, b in params[:-1]:
        x = σ(x @ W + b)
    # Complete with sigmoid activation for consumption rate
    W, b = params[-1]
    # Scale output into (0, 0.99), which keeps next-period assets strictly positive
    x = jax.nn.sigmoid(x @ W + b) * 0.99
    # Extract and return consumption rate
    consumption_rate = x[0]
    return consumption_rate
We use CRRA utility.
def u(c, γ):
    """ Utility function. """
    c = jnp.maximum(c, 1e-10)
    return c**(1 - γ) / (1 - γ)
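The EGM operator K below uses marginal utility \(u'(c) = c^{-\gamma}\) and its inverse \(x^{-1/\gamma}\) in closed form. A quick sketch (our addition) checks both against jax.grad. The names crra, crra_prime and crra_prime_inv are hypothetical and chosen so they do not overwrite u.

```python
import jax

gamma_demo = 1.5   # same value as the CakeEatingModel default

def crra(c):
    # CRRA utility u(c) = c**(1 - γ) / (1 - γ)
    return c**(1 - gamma_demo) / (1 - gamma_demo)

crra_prime_ad = jax.grad(crra)                   # marginal utility by autodiff
crra_prime = lambda c: c**(-gamma_demo)          # closed form, as in K below
crra_prime_inv = lambda x: x**(-1 / gamma_demo)  # inverse marginal utility

c_test = 2.0
print(crra_prime_ad(c_test), crra_prime(c_test))  # the two should agree
print(crra_prime_inv(crra_prime(c_test)))         # should recover c_test
```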
The next function approximates lifetime value associated with a given policy, as represented by the parameters of a neural network.
@partial(jax.jit, static_argnames=('path_length',))
def compute_lifetime_value(params, model, path_length):
    """
    Compute the lifetime value of a path generated from
    the policy embedded in params and the initial condition a_0 = 1.
    """
    γ, β, R = model.γ, model.β, model.R
    initial_a = 1.0
    def update(t, state):
        # Unpack and compute consumption given current assets
        a, value, discount = state
        consumption_rate = forward(params, a)
        c = consumption_rate * a
        # Update loop state and return it
        a = R * (a - c)
        value = value + discount * u(c, γ)
        discount = discount * β
        new_state = a, value, discount
        return new_state
    initial_value, initial_discount = 0.0, 1.0
    initial_state = initial_a, initial_value, initial_discount
    final_a, final_value, discount = jax.lax.fori_loop(
        0, path_length, update, initial_state
    )
    return final_value
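The function above accumulates a discounted sum with jax.lax.fori_loop, carrying (a, value, discount) as loop state and stopping after path_length periods. A stripped-down sketch of the same pattern, with a constant reward of 1 as a hypothetical stand-in for \(u(c_t)\), can be checked against the geometric-series formula. For a constant reward, the fraction of the infinite-horizon sum that truncation drops is exactly \(\beta^T\).

```python
import jax

def discounted_sum(reward, beta, horizon):
    """Sum of beta**t * reward for t = 0, ..., horizon - 1, via fori_loop."""
    def update(t, state):
        # Same (value, discount) bookkeeping as compute_lifetime_value
        value, discount = state
        return value + discount * reward, discount * beta
    value, _ = jax.lax.fori_loop(0, horizon, update, (0.0, 1.0))
    return value

beta_demo, horizon_demo = 0.96, 320   # CakeEatingModel.β and Config.path_length
loop_value = float(discounted_sum(1.0, beta_demo, horizon_demo))
geometric = (1 - beta_demo**horizon_demo) / (1 - beta_demo)
print(loop_value, geometric, 1 / (1 - beta_demo))
print(f"tail weight beta**T = {beta_demo**horizon_demo:.1e}")
```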
Here’s the loss function we will minimize.
def loss_function(params, model, path_length):
    """
    Loss is the negation of the lifetime value of the policy
    identified by `params`.
    """
    return -compute_lifetime_value(params, model, path_length)
24.4. Train and solve
First we create an instance of the model and unpack its parameters, along with the training configuration:
model = CakeEatingModel()
γ, β, R = model.γ, model.β, model.R
seed, epochs = Config.seed, Config.epochs
path_length = Config.path_length
layer_sizes = Config.layer_sizes
We test stability.
assert β * R**(1 - γ) < 1, "Parameters fail stability test."
We compute the optimal consumption rate and lifetime value from the analytical expressions.
κ = 1 - (β * R**(1 - γ))**(1/γ)
print(f"Optimal consumption rate = {κ}.\n")
v_max = κ**(-γ) * u(1.0, γ)
print(f"Theoretical maximum lifetime value = {v_max}.\n")
Optimal consumption rate = 0.03007006297501369.
Theoretical maximum lifetime value = -383.5557556152344.
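Because compute_lifetime_value stops after path_length = 320 periods, even the exact policy \(c = \kappa a\) cannot reach \(v_{\max}\) in simulation. The sketch below, which is our addition, simulates that policy for 320 periods in plain Python and compares the result with the closed form \(v_{\max}(1 - (1-\kappa)^T)\). The closed form holds because along the optimal path each discounted period utility is \(1-\kappa\) times the previous one. By our arithmetic the truncation gap is about 0.022, so a trained network is best judged against the truncated value rather than \(v_{\max}\) itself.

```python
def truncated_optimum_check(γ=1.5, β=0.96, R=1.01, T=320):
    """
    Compare v_max with the T-period value of the exact policy c = κ a,
    simulated from a_0 = 1 the same way compute_lifetime_value does.
    Defaults match CakeEatingModel and Config.path_length.
    """
    crra = lambda c: c**(1 - γ) / (1 - γ)
    κ = 1 - (β * R**(1 - γ))**(1 / γ)
    v_max = κ**(-γ) * crra(1.0)
    a, value, discount = 1.0, 0.0, 1.0
    for t in range(T):
        c = κ * a
        value += discount * crra(c)
        a = R * (a - c)
        discount *= β
    # Along c = κ a, the period-t term equals κ * v_max * (1 - κ)**t
    v_formula = v_max * (1 - (1 - κ)**T)
    return v_max, value, v_formula

v_max_check, v_sim_check, v_formula_check = truncated_optimum_check()
print(f"v_max = {v_max_check:.4f}")
print(f"T-period value: simulated = {v_sim_check:.4f}, formula = {v_formula_check:.4f}")
```

Wrapping the check in a function keeps its local names (such as the utility function) from overwriting the lecture's global u, γ, β and R.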
Let’s now set up the Optax optimizer: Adam with a constant learning rate, with gradients clipped to a global norm of at most 1.0.
optimizer = optax.chain(
optax.clip_by_global_norm(1.0), # Gradient clipping for stability
optax.adam(learning_rate=Config.learning_rate)
)
We initialize the parameters in the neural network and the state of the optimizer.
key = random.PRNGKey(seed)
params = initialize_network(key, layer_sizes)
opt_state = optimizer.init(params)
Now let’s train the network.
value_history = []
best_value = -jnp.inf
best_params = params
for i in range(epochs):
    # Compute value and gradients at existing parameterization
    loss, grads = jax.value_and_grad(loss_function)(params, model, path_length)
    lifetime_value = -loss
    value_history.append(lifetime_value)
    # Track best parameters
    if lifetime_value > best_value:
        best_value = lifetime_value
        best_params = params
    # Update parameters using optimizer
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    if i % 100 == 0:
        print(f"Iteration {i}: Value = {lifetime_value:.4f}")
# Use best parameters instead of final
params = best_params
print(f"\nBest value: {best_value:.4f}")
print(f"Final value: {value_history[-1]:.4f}")
Iteration 0: Value = -1495089.7500
Iteration 100: Value = -1252.4659
Iteration 200: Value = -383.5421
Iteration 300: Value = -383.5356
Best value: -383.5334
Final value: -383.5335
First we plot the evolution of lifetime value over the epochs.
# Plot learning progress
fig, ax = plt.subplots()
ax.plot(value_history, 'b-', linewidth=2)
ax.set_xlabel('iteration')
ax.set_ylabel('policy value')
ax.set_title('learning progress')
plt.show()
Next we compare the learned and optimal policies.
a_grid = jnp.linspace(0.01, 1.0, 1000)
policy_vmap = jax.vmap(lambda a: forward(params, a))
consumption_rate = policy_vmap(a_grid)
# Compute actual consumption: c = (c/a) * a
c_learned = consumption_rate * a_grid
c_optimal = κ * a_grid
fig, ax = plt.subplots()
ax.plot(a_grid, c_learned, linestyle='--', lw=4, label='learned policy')
ax.plot(a_grid, c_optimal, lw=2, label='optimal')
ax.set_xlabel('assets')
ax.set_ylabel('consumption')
ax.set_title('Consumption policy')
ax.legend()
plt.show()
Let’s have a look at paths for consumption and assets under the learned and optimal policies.
The figures below show that the learned policy is close to optimal.
def simulate_consumption_path(params, T=120):
    """
    Simulate asset and consumption paths under the learned policy
    identified by params and under the optimal policy c = κ a.
    """
    a_sim = [1.0]  # initial assets a_0 = 1.0
    c_sim = []
    a_opt = [1.0]
    c_opt = []
    a = 1.0
    for t in range(T):
        # Update policy path - forward returns consumption rate
        c = forward(params, a) * a
        c_sim.append(float(c))
        a = R * (a - c)
        a_sim.append(float(a))
        if a <= 1e-10:
            break
    a = 1.0
    for t in range(T):
        # Update optimal path
        c = κ * a
        c_opt.append(c)
        a = R * (a - c)
        a_opt.append(a)
        if a <= 1e-10:
            break
    return a_sim, c_sim, a_opt, c_opt
# Simulate and plot path
a_sim, c_sim, a_opt, c_opt = simulate_consumption_path(params)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
ax1.plot(a_sim, lw=4, linestyle='--', label='learned policy')
ax1.plot(a_opt, lw=2, label='optimal')
ax1.set_xlabel('Time')
ax1.set_ylabel('Assets')
ax1.set_title('Assets over time')
ax1.legend()
ax2.plot(c_sim, lw=4, linestyle='--', label='learned policy')
ax2.plot(c_opt, lw=2, label='optimal')
ax2.set_xlabel('Time')
ax2.set_ylabel('Consumption')
ax2.set_title('Consumption over time')
ax2.legend()
plt.tight_layout()
plt.show()
24.5. Extension: stochastic labor income with IID shocks
Now let’s solve a model with IID stochastic labor income using deep learning.
24.5.1. Set-Up
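A household chooses a consumption plan \(\{c_t\}_{t \geq 0}\) to maximize
\[ \mathbb{E} \sum_{t \geq 0} \beta^t u(c_t) \]
subject to
\[ a_{t+1} = R (a_t - c_t) + Y_{t+1}, \qquad 0 \leq c_t \leq a_t, \qquad t \geq 0. \]
Here \(Y_t\) is labor income, which is IID and lognormally distributed:
\[ Y_t = \exp(Z_t), \qquad Z_t \sim N(m, v). \]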
Since the shocks are IID, the optimal policy depends only on current assets \(a\), not on the shock history.
We assume:
\(\beta R < 1\)
\(u\) is CRRA with parameter \(\gamma\)
24.5.2. JAX Implementation
We start with a class called IFP that stores the model primitives.
class IFP(NamedTuple):
    R: float                 # Gross interest rate R = 1 + r
    β: float                 # Discount factor
    γ: float                 # Preference parameter
    z_mean: float            # Mean of log income shock
    z_std: float             # Std dev of log income shock
    z_samples: jnp.ndarray   # Monte Carlo draws of the log income shock
def create_ifp(
        r=0.01,
        β=0.96,
        γ=1.5,
        z_mean=0.1,
        z_std=0.1,
        n_shocks=200,
        seed=42
    ):
    R = 1 + r
    assert R * β < 1, "Stability condition violated."
    key = random.PRNGKey(seed)
    # Draw log income shocks Z ~ N(z_mean, z_std**2) for Monte Carlo integration
    z_samples = z_mean + z_std * jax.random.normal(key, (n_shocks,))
    return IFP(R, β, γ, z_mean, z_std, z_samples)
24.5.3. Solving the IID model using the EGM
Since the shocks are IID, the optimal policy depends only on current assets \(a\).
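For the IID normal case, the Euler equation requires the expectation
\[ \mathbb{E} \, u'\big(\sigma(R s + Y)\big), \]
where \(s\) is savings, \(\sigma\) is the current guess of the consumption policy, \(Z \sim N(m, v)\) and \(Y = \exp(Z)\).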
We approximate this expectation using Monte Carlo.
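As a sanity check on the Monte Carlo approach, the sketch below (our addition) estimates \(\mathbb{E} Y\) for \(Y = \exp(Z)\), drawing \(Z\) the same way create_ifp does with its default z_mean and z_std, and compares the estimate with the lognormal mean \(\exp(m + v/2)\). The sampling error shrinks roughly like \(1/\sqrt{n}\), which the two sample sizes illustrate.

```python
import jax
import jax.numpy as jnp

z_mean, z_std = 0.1, 0.1   # defaults in create_ifp
mc_key = jax.random.PRNGKey(42)

def mc_mean_income(key, n):
    # Draw log income Z ~ N(z_mean, z_std**2), as in create_ifp, and average Y = exp(Z)
    z = z_mean + z_std * jax.random.normal(key, (n,))
    return float(jnp.mean(jnp.exp(z)))

exact_mean = float(jnp.exp(z_mean + z_std**2 / 2))   # E[exp(Z)] for normal Z
est_200 = mc_mean_income(mc_key, 200)                # sample size used by create_ifp
est_100k = mc_mean_income(mc_key, 100_000)
print(f"exact = {exact_mean:.5f}, n = 200: {est_200:.5f}, n = 100,000: {est_100k:.5f}")
```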
Here is the EGM operator \(K\) for the IID case:
def K(c_in, a_in, ifp, s_grid, n_shocks=50):
    """
    The Euler equation operator for the IFP model with IID shocks using EGM.
    Args:
        c_in: Current consumption policy on endogenous grid
        a_in: Current endogenous asset grid
        ifp: IFP model instance
        s_grid: Exogenous savings grid
        n_shocks: Unused; the Monte Carlo draws stored in ifp.z_samples
            are used instead
    Returns:
        c_out: Updated consumption policy
        a_out: Updated endogenous asset grid
    """
    R, β, γ, z_mean, z_std, z_samples = ifp
    y_samples = jnp.exp(z_samples)
    u_prime = lambda c: c**(-γ)
    u_prime_inv = lambda c: c**(-1/γ)
    def compute_c_i(s_i):
        """Compute consumption for savings level s_i."""
        # For each income realization, compute next period assets and consumption
        def compute_mu_k(y_k):
            next_a = R * s_i + y_k
            # Interpolate to get consumption
            next_c = jnp.interp(next_a, a_in, c_in)
            return u_prime(next_c)
        # Compute expectation over income shocks (Monte Carlo average)
        mu_values = jax.vmap(compute_mu_k)(y_samples)
        expectation = jnp.mean(mu_values)
        # Invert marginal utility to get current consumption
        c = u_prime_inv(β * R * expectation)
        # At s_i = 0, set consumption to 0 so the endogenous grid starts at (a, c) = (0, 0)
        return jnp.where(s_i == 0, 0.0, c)
    # Compute consumption for all savings levels
    c_out = jax.vmap(compute_c_i)(s_grid)
    # Compute endogenous asset grid
    a_out = c_out + s_grid
    return c_out, a_out
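Written out, one application of K computes, for each savings level \(s_i\) on the exogenous grid (our summary of the code above, with \(\sigma\) the linear interpolant of the input policy and \(y_1, \ldots, y_n\) the income draws \(\exp(z_k)\)):

\[ c_i = (u')^{-1}\left( \beta R \, \frac{1}{n} \sum_{k=1}^{n} u'\big(\sigma(R s_i + y_k)\big) \right), \qquad a_i = c_i + s_i, \]

with \(u'(c) = c^{-\gamma}\), \((u')^{-1}(x) = x^{-1/\gamma}\), and \(c_i = 0\) when \(s_i = 0\). The first equation is the Euler equation \(u'(c) = \beta R \, \mathbb{E} \, u'(\sigma(R s + Y))\) with the expectation replaced by a Monte Carlo average; the second recovers the asset level at which consuming \(c_i\) leaves savings \(s_i\).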
Here’s the solver using time iteration:
def solve_model(ifp, s_grid, n_shocks=50, tol=1e-5, max_iter=1000):
    """
    Solve the IID model using time iteration with EGM.
    Args:
        ifp: IFP model instance
        s_grid: Exogenous savings grid
        n_shocks: Passed through to K, which ignores it (the draws in
            ifp.z_samples are used)
        tol: Convergence tolerance
        max_iter: Maximum iterations
    Returns:
        c_out: Optimal consumption policy on endogenous grid
        a_out: Endogenous asset grid
    """
    # Initialize with consumption = assets (consume everything)
    a_init = s_grid.copy()
    c_init = s_grid.copy()
    c_in, a_in = c_init, a_init
    for i in range(max_iter):
        c_out, a_out = K(c_in, a_in, ifp, s_grid, n_shocks)
        error = jnp.max(jnp.abs(c_out - c_in))
        if error < tol:
            print(f"Converged in {i} iterations, error = {error:.2e}")
            break
        c_in, a_in = c_out, a_out
        if i % 100 == 0:
            print(f"Iteration {i}, error = {error:.2e}")
    return c_out, a_out
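solve_model drives time iteration from a Python loop. A possible alternative, not used in this lecture, is to put the fixed-point iteration inside jax.lax.while_loop so that the whole solve can be compiled with jax.jit. The sketch below shows the pattern on a toy contraction, \(x \mapsto \cos(x)\), as a hypothetical stand-in for K; with K, the loop state would hold the pair (c, a) instead of a single array.

```python
import jax
import jax.numpy as jnp

def fixed_point(op, x_init, tol=1e-6, max_iter=1_000):
    """
    Iterate x <- op(x) inside jax.lax.while_loop until the largest
    absolute change falls below tol or max_iter is reached.
    """
    def cond(state):
        i, x, error = state
        return (error > tol) & (i < max_iter)
    def body(state):
        i, x, _ = state
        x_new = op(x)
        return i + 1, x_new, jnp.max(jnp.abs(x_new - x))
    # Explicit dtypes keep the loop carry types fixed across iterations
    init = (jnp.array(0), x_init, jnp.array(jnp.inf, dtype=x_init.dtype))
    i, x, error = jax.lax.while_loop(cond, body, init)
    return x, i, error

# Toy contraction: the fixed point of cos is about 0.739085
solve_toy = jax.jit(lambda x0: fixed_point(jnp.cos, x0))
x_star, n_iter, final_error = solve_toy(jnp.array([1.0]))
print(x_star, n_iter, final_error)
```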
Let’s solve the model and plot the optimal policy:
# Create model instance
ifp = create_ifp(z_mean=0.1, z_std=0.1)
# Create savings grid
s_grid = jnp.linspace(0, 10, 200)
# Solve using EGM
print("Solving IFP model using EGM...\n")
c_egm, a_egm = solve_model(ifp, s_grid, n_shocks=100)
Solving IFP model using EGM...
Iteration 0, error = 1.40e+00
Converged in 38 iterations, error = 6.79e-06
Plot the optimal consumption policy:
24.5.4. Solving the IID model with DL
Since the shocks are IID, the policy depends only on current assets \(a\).
We use the same network architecture as the deterministic case.
The forward pass uses the forward function from the deterministic case.
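Next we implement the lifetime value computation.
The key step is to simulate paths with IID lognormal income, starting from initial assets \(a_0 = 10\), and to average lifetime value across many such paths.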
@partial(jax.jit, static_argnames=('path_length', 'num_paths'))
def compute_lifetime_value_ifp(params, ifp, path_length, num_paths, key):
    """
    Compute expected lifetime value by averaging over multiple
    simulated paths.
    Args:
        params: Neural network parameters
        ifp: IFP model instance
        path_length: Length of each simulated path
        num_paths: Number of paths to simulate for averaging
        key: JAX random key for generating income shocks
    Returns:
        Average lifetime value across all simulated paths
    """
    R, β, γ, z_mean, z_std, z_samples = ifp
    def simulate_path(subkey):
        """Simulate a single path and return its lifetime value."""
        z_shocks = z_mean + z_std * jax.random.normal(subkey, (path_length,))
        Y = jnp.exp(z_shocks)
        def update(t, loop_state):
            a, value, discount = loop_state
            consumption_rate = forward(params, a)
            c = consumption_rate * a
            next_value = value + discount * u(c, γ)
            next_a = R * (a - c) + Y[t]
            next_discount = discount * β
            return next_a, next_value, next_discount
        initial_a = 10.0
        initial_value = 0.0
        initial_discount = 1.0
        initial_state = (initial_a, initial_value, initial_discount)
        final_a, final_value, final_discount = jax.lax.fori_loop(
            0, path_length, update, initial_state
        )
        return final_value
    # Generate keys for all paths
    path_keys = jax.random.split(key, num_paths)
    # Simulate all paths and average
    values = jax.vmap(simulate_path)(path_keys)
    return jnp.mean(values)
The loss function is the negation of the expected lifetime value.
def loss_function_ifp(params, ifp, path_length, num_paths, key):
    return -compute_lifetime_value_ifp(
        params, ifp, path_length, num_paths, key
    )
Now let’s set up and train the network.
We use the same ifp instance that was created for the EGM solution above.
stochastic_config = {
'seed': 1234,
'epochs': 400,
'path_length': 320,
'num_paths': 500, # Number of paths to average over
'learning_rate': 0.001
}
We initialize parameters.
key = random.PRNGKey(seed)
ifp_params = initialize_network(key, layer_sizes)
Let’s set up the optimizer.
ifp_optimizer = optax.chain(
optax.clip_by_global_norm(1.0), # Gradient clipping for stability
optax.adam(learning_rate=stochastic_config['learning_rate'])
)
ifp_opt_state = ifp_optimizer.init(ifp_params)
Train the network using policy gradient ascent.
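We reuse the same random key in every epoch, so each evaluation of the objective uses the same simulated income shocks. This removes Monte Carlo noise from comparisons across epochs (a variance reduction technique known as common random numbers).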
ifp_value_history = []
best_ifp_value = -jnp.inf
best_ifp_params = ifp_params
fixed_key = random.PRNGKey(stochastic_config['seed'])
print("Training IFP model with deep learning...\n")
for i in range(stochastic_config['epochs']):
    # Compute loss and gradients
    loss, grads = jax.value_and_grad(loss_function_ifp)(
        ifp_params, ifp,
        stochastic_config['path_length'],
        stochastic_config['num_paths'],
        fixed_key
    )
    lifetime_value = -loss
    ifp_value_history.append(lifetime_value)
    # Track best parameters
    if lifetime_value > best_ifp_value:
        best_ifp_value = lifetime_value
        best_ifp_params = ifp_params
    # Update parameters
    updates, ifp_opt_state = ifp_optimizer.update(grads, ifp_opt_state)
    ifp_params = optax.apply_updates(ifp_params, updates)
    if i % 50 == 0:
        print(f"Iteration {i}: Value = {lifetime_value:.4f}")
# Use best parameters
ifp_params = best_ifp_params
print(f"\nBest value: {best_ifp_value:.4f}")
print(f"Final value: {ifp_value_history[-1]:.4f}")
Training IFP model with deep learning...
Iteration 0: Value = -46.4904
Iteration 50: Value = -44.8104
Iteration 100: Value = -43.0656
Iteration 150: Value = -43.0144
Iteration 200: Value = -42.9911
Iteration 250: Value = -42.9793
Iteration 300: Value = -42.9777
Iteration 350: Value = -42.9733
Best value: -42.9710
Final value: -42.9713
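Why does a fixed key help? The sketch below uses a hypothetical toy objective \(J(\theta) = \mathbb{E}[-(\theta - Z)^2]\) with \(Z \sim N(0, 1)\), not the lecture's model. It compares estimates of \(J(\theta_1) - J(\theta_2)\) when both parameter values share the same draws (common random numbers) against estimates that use independent draws. Sharing draws should shrink the spread of the estimated difference by roughly an order of magnitude here. One possible cost is that the network is trained against a single fixed sample of 500 paths.

```python
import jax
import jax.numpy as jnp

def crn_estimate(θ, key, n=500):
    # Monte Carlo estimate of the toy objective J(θ) = E[-(θ - Z)**2], Z ~ N(0, 1)
    z = jax.random.normal(key, (n,))
    return jnp.mean(-(θ - z)**2)

θ1, θ2 = 0.50, 0.55
n_reps = 200
crn_keys = jax.random.split(jax.random.PRNGKey(0), 2 * n_reps)
keys_a, keys_b = crn_keys[:n_reps], crn_keys[n_reps:]

# Common random numbers: both parameter values are evaluated on the same draws
diff_common = jax.vmap(lambda k: crn_estimate(θ1, k) - crn_estimate(θ2, k))(keys_a)
# Independent draws for each parameter value
diff_indep = jax.vmap(
    lambda k1, k2: crn_estimate(θ1, k1) - crn_estimate(θ2, k2)
)(keys_a, keys_b)

std_common = float(jnp.std(diff_common))
std_indep = float(jnp.std(diff_indep))
print(f"std of estimated J(θ1) - J(θ2): common {std_common:.4f}, independent {std_indep:.4f}")
```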
Plot the learning progress.
fig, ax = plt.subplots()
ax.plot(ifp_value_history, 'b-', linewidth=2)
ax.set_xlabel('iteration')
ax.set_ylabel('policy value')
ax.set_title('Learning progress')
plt.show()
Compare EGM and DL solutions.
# Evaluate DL policy on asset grid
a_grid_dl = jnp.linspace(0.01, 10.0, 200)
policy_vmap = jax.vmap(lambda a: forward(ifp_params, a))
consumption_rate_dl = policy_vmap(a_grid_dl)
c_dl = consumption_rate_dl * a_grid_dl
fig, ax = plt.subplots()
ax.plot(a_egm, c_egm, lw=2, label='EGM solution')
ax.plot(a_grid_dl, c_dl, lw=2, label='DL solution')
ax.set_xlabel('assets', fontsize=12)
ax.set_ylabel('consumption', fontsize=12)
ax.set_xlim(0, min(a_grid_dl[-1], a_egm[-1]))
ax.legend()
plt.show()