# 15. Endogenous Grid Method

GPU

This lecture was built using a machine with JAX installed and access to a GPU.

To run this lecture on Google Colab, click on the “play” icon top right, select Colab, and set the runtime environment to include a GPU.

To run this lecture on your own machine, you need to install Google JAX.

## 15.1. Overview

In this lecture we use the endogenous grid method (EGM) to solve a basic income fluctuation (optimal savings) problem.

Background on the endogenous grid method can be found in an earlier QuantEcon lecture.

Here we focus on providing an efficient JAX implementation.

In addition to JAX and Anaconda, this lecture will need the following libraries:

```
!pip install --upgrade quantecon
```


```
import quantecon as qe
import matplotlib.pyplot as plt
import numpy as np
import jax
import jax.numpy as jnp
import numba
```

Let’s check the GPU we are running on.

```
!nvidia-smi
```

```
/opt/conda/envs/quantecon/lib/python3.11/pty.py:89: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
pid, fd = os.forkpty()
```

```
Thu Apr 11 22:01:51 2024
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.182.03 Driver Version: 470.182.03 CUDA Version: 12.3 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|===============================+======================+======================|
| 0 Tesla V100-SXM2... Off | 00000000:00:1E.0 Off | 0 |
| N/A 42C P0 42W / 300W | 0MiB / 16160MiB | 2% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=============================================================================|
| No running processes found |
+-----------------------------------------------------------------------------+
```

We use 64-bit floating point numbers for extra precision.

```
jax.config.update("jax_enable_x64", True)
```
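As a quick, hedged check of what this flag does: JAX defaults to 32-bit floats, and with `jax_enable_x64` switched on, newly created floating point arrays should report `float64`. The array `x` below is purely illustrative.

```python
import jax
import jax.numpy as jnp

# Enable 64-bit floats (same call as in the lecture)
jax.config.update("jax_enable_x64", True)

# Illustrative array; with the flag on, its dtype is float64
x = jnp.linspace(0, 1, 5)
print(x.dtype)
```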

## 15.2. Setup

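We consider a household that chooses a state-contingent consumption plan \(\{c_t\}_{t \geq 0}\) to maximize

\[
    \mathbb E \, \sum_{t=0}^{\infty} \beta^t u(c_t)
\]

subject to

\[
    a_{t+1} = R (a_t - c_t) + Y_{t+1},
    \quad c_t \geq 0,
    \quad a_t \geq 0,
    \quad t = 0, 1, \ldots
\]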

Here \(R = 1 + r\) where \(r\) is the interest rate.

The income process \(\{Y_t\}\) is a Markov chain generated by stochastic matrix \(P\).

The matrix \(P\) and the grid of values taken by \(Y_t\) are obtained by discretizing the AR(1) process

\[
    \log Y_{t+1} = \rho \log Y_t + \nu \epsilon_{t+1}
\]

where \(\{\epsilon_t\}\) is IID and standard normal.

Utility has the CRRA specification

\[
    u(c) = \frac{c^{1 - \gamma}}{1 - \gamma}
\]
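The solution method below works with marginal utility and its inverse rather than with \(u\) itself. As a minimal sketch, assuming the standard CRRA form \(u(c) = c^{1-\gamma}/(1-\gamma)\), so that \(u'(c) = c^{-\gamma}\) and \((u')^{-1}(m) = m^{-1/\gamma}\), the following NumPy snippet checks both formulas numerically. The test points in `c` and the step `h` are illustrative choices.

```python
import numpy as np

γ = 1.5  # CRRA parameter (the lecture's default value)

def u(c):
    return c**(1 - γ) / (1 - γ)

def u_prime(c):
    return c**(-γ)

def u_prime_inv(m):
    return m**(-1 / γ)

c = np.array([0.5, 1.0, 2.0])   # illustrative consumption values

# Central finite difference approximation to u'(c)
h = 1e-6
fd = (u(c + h) - u(c - h)) / (2 * h)

print(np.allclose(fd, u_prime(c), rtol=1e-5))       # derivative check
print(np.allclose(u_prime_inv(u_prime(c)), c))      # inverse check
```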

The following function stores default parameter values for the income fluctuation problem and creates suitable arrays.

```
def ifp(R=1.01,             # gross interest rate
        β=0.99,             # discount factor
        γ=1.5,              # CRRA preference parameter
        s_max=16,           # savings grid max
        s_size=200,         # savings grid size
        ρ=0.99,             # income persistence
        ν=0.02,             # income volatility
        y_size=25):         # income grid size

    # require R β < 1 for convergence
    assert R * β < 1, "Stability condition failed."
    # Create income Markov chain
    mc = qe.tauchen(y_size, ρ, ν)
    y_grid, P = jnp.exp(mc.state_values), mc.P
    # Shift to JAX arrays
    P, y_grid = jax.device_put((P, y_grid))
    s_grid = jnp.linspace(0, s_max, s_size)
    # Pack and return
    constants = β, R, γ
    sizes = s_size, y_size
    arrays = s_grid, y_grid, P
    return constants, sizes, arrays
```
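The call to `qe.tauchen` above discretizes the AR(1) process for log income. To give a rough idea of what such a discretization involves, here is a minimal NumPy sketch of a textbook Tauchen scheme. The function `tauchen_sketch` and its argument `n_std` are hypothetical names introduced for illustration; this is not the QuantEcon implementation, and its output need not match `qe.tauchen` exactly.

```python
import math
import numpy as np

def tauchen_sketch(n, ρ, ν, n_std=3):
    """
    Discretize x' = ρ x + ν ε, ε ~ N(0, 1), on n evenly spaced points.
    Hypothetical helper for illustration only.
    """
    σ_x = ν / math.sqrt(1 - ρ**2)                  # stationary std dev of x
    x = np.linspace(-n_std * σ_x, n_std * σ_x, n)  # state grid
    step = x[1] - x[0]
    # Standard normal CDF, applied elementwise
    Φ = np.vectorize(lambda z: 0.5 * (1 + math.erf(z / math.sqrt(2))))
    P = np.empty((n, n))
    for i in range(n):
        upper = Φ((x - ρ * x[i] + step / 2) / ν)
        lower = Φ((x - ρ * x[i] - step / 2) / ν)
        P[i, :] = upper - lower       # mass in the interval around each point
        P[i, 0] = upper[0]            # left tail goes to the first point
        P[i, -1] = 1 - lower[-1]      # right tail goes to the last point
    return x, P

x, P = tauchen_sketch(5, 0.9, 0.1)
y_grid = np.exp(x)                    # income levels, as in ifp()
print(P.sum(axis=1))                  # each row should sum to one
```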

## 15.3. Solution method

Let \(S = \mathbb R_+ \times \mathsf Y\) be the set of possible values for the state \((a_t, Y_t)\).

We aim to compute an optimal consumption policy \(\sigma^* \colon S \to \mathbb R\), under which dynamics are given by

\[
    c_t = \sigma^*(a_t, Y_t)
    \quad \text{and} \quad
    a_{t+1} = R (a_t - c_t) + Y_{t+1}
\]

In this section we discuss how we intend to solve for this policy.

### 15.3.1. Euler equation

The Euler equation for the optimization problem is

\[
    u'(c_t) = \max \left\{
        \beta R \, \mathbb E_t u'(c_{t+1}), \; u'(a_t)
    \right\}
\]

An explanation for this expression can be found here.

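We rewrite the Euler equation in functional form

\[
    (u' \circ \sigma)(a, y) = \max \left\{
        \beta R \, \mathbb E_y (u' \circ \sigma)[R(a - \sigma(a, y)) + \hat Y, \, \hat Y],
        \; u'(a)
    \right\}
\]

where \((u' \circ \sigma)(a, y) := u'(\sigma(a, y))\) and \(\sigma\) is a consumption policy.

Here \(\hat Y\) is next period income and \(\mathbb E_y\) is expectation conditional on current income \(y\).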

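For a given consumption policy \(\sigma\), we define \((K \sigma) (a,y)\) as the unique \(c \in [0, a]\) that solves

\[
    u'(c) = \max \left\{
        \beta R \, \mathbb E_y (u' \circ \sigma)[R(a - c) + \hat Y, \, \hat Y],
        \; u'(a)
    \right\}
    \tag{15.1}
\]

It can be shown that

1. iterating with \(K\) computes an optimal policy and
2. if \(\sigma\) is increasing in its first argument, then so is \(K\sigma\).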

Hence below we always assume that \(\sigma\) is increasing in its first argument.

The EGM is a technique for computing the update \(K\sigma\) given \(\sigma\) along a grid of asset values.

Notice that, since \(u'(a) \to \infty\) as \(a \downarrow 0\), the second term in the max in (15.1) dominates for sufficiently small \(a\).

Also, again using (15.1), we have \(c=a\) for all such \(a\).

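Hence, for sufficiently small \(a\),

\[
    u'(a) \geq \beta R \, \mathbb E_y (u' \circ \sigma)[\hat Y, \, \hat Y]
\]

Equality holds at \(\bar a(y)\) given by

\[
    \bar a(y) = (u')^{-1} \left\{
        \beta R \, \mathbb E_y (u' \circ \sigma)[\hat Y, \, \hat Y]
    \right\}
\]

We can now write

\[
    u'(c) =
    \begin{cases}
        \beta R \, \mathbb E_y (u' \circ \sigma)[R(a - c) + \hat Y, \, \hat Y]
            & \text{if } a > \bar a(y) \\
        u'(a)
            & \text{if } a \leq \bar a(y)
    \end{cases}
\]

Equivalently, we can state that the \(c\) satisfying \(c = (K\sigma)(a, y)\) obeys

\[
    c =
    \begin{cases}
        (u')^{-1} \left\{
            \beta R \, \mathbb E_y (u' \circ \sigma)[R(a - c) + \hat Y, \, \hat Y]
        \right\}
            & \text{if } a > \bar a(y) \\
        a
            & \text{if } a \leq \bar a(y)
    \end{cases}
    \tag{15.2}
\]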
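To make the threshold \(\bar a(y)\) concrete, here is a small NumPy sketch. It assumes CRRA utility, the "consume everything" policy \(\sigma(a, y) = a\) (so that \(\sigma(\hat Y, \hat Y) = \hat Y\)), and a three-state income chain whose values and transition matrix are invented for illustration. Under these assumptions \(\bar a(y) = (u')^{-1}\{\beta R \, \mathbb E_y u'(\hat Y)\}\).

```python
import numpy as np

β, R, γ = 0.99, 1.01, 1.5                  # the lecture's default constants
y_grid = np.array([0.5, 1.0, 2.0])         # illustrative income states
P = np.array([[0.8, 0.2, 0.0],
              [0.1, 0.8, 0.1],
              [0.0, 0.2, 0.8]])            # illustrative stochastic matrix

def u_prime(c):
    return c**(-γ)

def u_prime_inv(m):
    return m**(-1 / γ)

# E_y u'(σ(Ŷ, Ŷ)) for each current income state, with σ(a, y) = a
E = P @ u_prime(y_grid)

# Asset level at or below which the household consumes everything
a_bar = u_prime_inv(β * R * E)
print(a_bar)
```

In this example the threshold rises with current income, because higher current income shifts probability toward higher future income and so lowers expected marginal utility.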

We begin with an *exogenous* grid of saving values \(0 = s_0 < \ldots < s_{N-1}\).

Using the exogenous savings grid, and a fixed value of \(y\), we create an *endogenous* asset grid
\(a_0, \ldots, a_{N-1}\) and a consumption grid \(c_0, \ldots, c_{N-1}\) as follows.

First we set \(a_0 = c_0 = 0\), since zero consumption is an optimal (in fact the only) choice when \(a=0\).

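Then, for \(i > 0\), we compute

\[
    c_i = (u')^{-1} \left\{
        \beta R \, \mathbb E_y (u' \circ \sigma)[R s_i + \hat Y, \, \hat Y]
    \right\}
    \tag{15.3}
\]

and we set

\[
    a_i = s_i + c_i
\]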

We claim that each pair \(a_i, c_i\) obeys (15.2).

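Indeed, since \(s_i > 0\), choosing \(c_i\) according to (15.3) gives

\[
    c_i = (u')^{-1} \left\{
        \beta R \, \mathbb E_y (u' \circ \sigma)[R s_i + \hat Y, \, \hat Y]
    \right\}
    \geq (u')^{-1} \left\{
        \beta R \, \mathbb E_y (u' \circ \sigma)[\hat Y, \, \hat Y]
    \right\}
    = \bar a(y)
\]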

where the inequality uses the fact that \(\sigma\) is increasing in its first argument.

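If we now take \(a_i = s_i + c_i\) we get \(a_i > \bar a(y)\), so the pair \((a_i, c_i)\) satisfies

\[
    c_i = (u')^{-1} \left\{
        \beta R \, \mathbb E_y (u' \circ \sigma)[R(a_i - c_i) + \hat Y, \, \hat Y]
    \right\}
    \quad \text{and} \quad
    a_i > \bar a(y)
\]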

Hence (15.2) holds.

We are now ready to iterate with \(K\).
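Before turning to the full implementation, the construction of the endogenous grid can be sketched in a stripped-down setting. The snippet below assumes a single, deterministic income value `y` (so the expectation drops out), CRRA marginal utility, and a small illustrative savings grid; the initial guess is the "consume everything" policy stored on the grid `a_in`. All of these choices are assumptions made for the example.

```python
import numpy as np

β, R, γ = 0.99, 1.01, 1.5
y = 1.0                               # deterministic income (assumption)
s_grid = np.linspace(0, 4, 5)         # exogenous savings grid

def u_prime(c):
    return c**(-γ)

def u_prime_inv(m):
    return m**(-1 / γ)

# Current guess of the policy: consume all assets, stored on grid a_in
a_in = s_grid.copy()
σ_in = s_grid.copy()

# Next period assets implied by each savings level
a_next = R * s_grid + y
# Evaluate the current policy at a_next by linear interpolation
σ_next = np.interp(a_next, a_in, σ_in)
# Invert the Euler equation to get consumption at each s_i
c = u_prime_inv(β * R * u_prime(σ_next))
c[0] = 0.0                            # a_0 = c_0 = 0 by construction
# Endogenous asset grid
a_out = s_grid + c
print(a_out)
```

Note that `np.interp` holds the policy constant beyond the last grid point, so values of `a_next` above `a_in[-1]` are evaluated at `σ_in[-1]`.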

### 15.3.2. JAX version

First we define a vectorized operator \(K\) based on the EGM.

Notice in the code below that

- we avoid all loops and any mutation of arrays
- the function is pure (no globals, no mutation of inputs)

```
def K_egm(a_in, σ_in, constants, sizes, arrays):
    """
    The vectorized operator K using EGM.

    """

    # Unpack
    β, R, γ = constants
    s_size, y_size = sizes
    s_grid, y_grid, P = arrays

    def u_prime(c):
        return c**(-γ)

    def u_prime_inv(u):
        return u**(-1/γ)

    # Linearly interpolate σ(a, y)
    def σ(a, y):
        return jnp.interp(a, a_in[:, y], σ_in[:, y])
    σ_vec = jnp.vectorize(σ)

    # Broadcast and vectorize
    y_hat = jnp.reshape(y_grid, (1, 1, y_size))
    y_hat_idx = jnp.reshape(jnp.arange(y_size), (1, 1, y_size))
    s = jnp.reshape(s_grid, (s_size, 1, 1))
    P = jnp.reshape(P, (1, y_size, y_size))

    # Evaluate consumption choice
    a_next = R * s + y_hat
    σ_next = σ_vec(a_next, y_hat_idx)
    up = u_prime(σ_next)
    E = jnp.sum(up * P, axis=-1)
    c = u_prime_inv(β * R * E)

    # Set up a column vector with zero in the first row and ones elsewhere
    e_0 = jnp.ones(s_size) - jnp.identity(s_size)[:, 0]
    e_0 = jnp.reshape(e_0, (s_size, 1))

    # The policy is computed consumption with the first row set to zero
    σ_out = c * e_0

    # Compute a_out by a = s + c (jnp rather than np, to stay within JAX)
    a_out = jnp.reshape(s_grid, (s_size, 1)) + σ_out

    return a_out, σ_out
```
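The broadcasting pattern used to compute the expectation `E` can be checked on small arrays. The sketch below uses plain NumPy and made-up values: `up_vals[i, k]` stands in for \(u'(\sigma(R s_i + y_k, y_k))\), and `P[j, k]` is the probability of moving from income state `j` to state `k`. It compares the broadcast sum with an explicit triple loop.

```python
import numpy as np

s_size, y_size = 3, 2
rng = np.random.default_rng(0)

# Stand-in for u'(σ(R s_i + y_k, y_k)), indexed by (i, k)
up_vals = rng.uniform(1, 2, size=(s_size, y_size))
P = np.array([[0.9, 0.1],
              [0.2, 0.8]])

# Broadcast version: shapes (s, 1, y_hat) * (1, y, y_hat), summed over y_hat
up = up_vals.reshape(s_size, 1, y_size)
P3 = P.reshape(1, y_size, y_size)
E_broadcast = np.sum(up * P3, axis=-1)      # shape (s_size, y_size)

# Loop version of E[i, j] = sum_k up_vals[i, k] * P[j, k]
E_loop = np.zeros((s_size, y_size))
for i in range(s_size):
    for j in range(y_size):
        for k in range(y_size):
            E_loop[i, j] += up_vals[i, k] * P[j, k]

print(np.allclose(E_broadcast, E_loop))
```

The same quantity can also be written as the matrix product `up_vals @ P.T`.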

Then we use `jax.jit` to compile \(K\).

We pass `static_argnums=(3,)` so that `sizes` (the argument at position 3) is treated as a compile-time constant. JAX recompiles whenever `sizes` changes, since the compiler likes to specialize on shapes.

```
K_egm_jax = jax.jit(K_egm, static_argnums=(3,))
```
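To illustrate why a shape-determining argument is marked static, consider the following toy function, which is not part of the lecture's code. The names `make_grid` and `make_grid_jit` are hypothetical. The tuple `sizes` fixes the length of the output array, so it needs to be known at compile time, while `x_max` can remain a traced value.

```python
import jax
import jax.numpy as jnp

def make_grid(x_max, sizes):
    # sizes determines an array shape, so it must be a concrete Python value
    n, = sizes
    return jnp.linspace(0, x_max, n)

# Argument 1 (sizes) is static; static arguments must be hashable
make_grid_jit = jax.jit(make_grid, static_argnums=(1,))

g5 = make_grid_jit(1.0, (5,))     # compiles for n = 5
g7 = make_grid_jit(1.0, (7,))     # a new value of sizes triggers a recompile
print(g5.shape, g7.shape)
```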

Next we define a successive approximator that repeatedly applies \(K\).

```
def successive_approx_jax(model,
                          tol=1e-5,
                          max_iter=100_000,
                          verbose=True,
                          print_skip=25):

    # Unpack
    constants, sizes, arrays = model
    β, R, γ = constants
    s_size, y_size = sizes
    s_grid, y_grid, P = arrays

    # Initial condition is to consume all in every state
    σ_init = jnp.repeat(s_grid, y_size)
    σ_init = jnp.reshape(σ_init, (s_size, y_size))
    a_init = jnp.copy(σ_init)
    a_vec, σ_vec = a_init, σ_init

    i = 0
    error = tol + 1

    while i < max_iter and error > tol:
        a_new, σ_new = K_egm_jax(a_vec, σ_vec, constants, sizes, arrays)
        error = jnp.max(jnp.abs(σ_vec - σ_new))
        i += 1
        if verbose and i % print_skip == 0:
            print(f"Error at iteration {i} is {error}.")
        a_vec, σ_vec = jnp.copy(a_new), jnp.copy(σ_new)

    if error > tol:
        print("Failed to converge!")
    else:
        print(f"\nConverged in {i} iterations.")

    return a_new, σ_new
```
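The loop structure above is ordinary successive approximation: apply an operator until successive iterates are within `tol` of each other in the sup norm. A minimal sketch on a toy problem, using a hypothetical helper `successive_approx_sketch` and the contraction \(T x = 0.5 x + 1\) (whose fixed point is 2), looks like this:

```python
import numpy as np

def successive_approx_sketch(T, x0, tol=1e-8, max_iter=1_000):
    "Iterate x -> T(x) until the sup-norm change falls below tol."
    x, error, i = x0, tol + 1, 0
    while i < max_iter and error > tol:
        x_new = T(x)
        error = np.max(np.abs(x_new - x))
        x = x_new
        i += 1
    return x, i

# Toy contraction with modulus 0.5 and fixed point 2
T = lambda x: 0.5 * x + 1.0

x_star, n_iter = successive_approx_sketch(T, np.zeros(3))
print(x_star, n_iter)
```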

### 15.3.3. Numba version

Below we provide a second set of code, which solves the same model with Numba.

The purpose of this code is to cross-check our results from the JAX version, as well as to do a runtime comparison.

Most readers will want to skip ahead to the next section, where we solve the model and run the cross-check.

```
@numba.jit
def K_egm_nb(a_in, σ_in, constants, sizes, arrays):
    """
    The operator K using Numba.

    """

    # Simplify names
    β, R, γ = constants
    s_size, y_size = sizes
    s_grid, y_grid, P = arrays

    def u_prime(c):
        return c**(-γ)

    def u_prime_inv(u):
        return u**(-1/γ)

    # Linear interpolation of policy using endogenous grid
    def σ(a, z):
        return np.interp(a, a_in[:, z], σ_in[:, z])

    # Allocate memory for new consumption array
    σ_out = np.zeros_like(σ_in)
    a_out = np.zeros_like(σ_out)

    for i, s in enumerate(s_grid[1:]):
        i += 1
        for z in range(y_size):
            expect = 0.0
            for z_hat in range(y_size):
                expect += u_prime(σ(R * s + y_grid[z_hat], z_hat)) * \
                            P[z, z_hat]
            c = u_prime_inv(β * R * expect)
            σ_out[i, z] = c
            a_out[i, z] = s + c

    return a_out, σ_out
```

```
def successive_approx_numba(model,        # Tuple (constants, sizes, arrays)
                            tol=1e-5,
                            max_iter=100_000,
                            verbose=True,
                            print_skip=25):

    # Unpack
    constants, sizes, arrays = model
    s_size, y_size = sizes
    # make NumPy versions of arrays
    arrays = tuple(map(np.array, arrays))
    s_grid, y_grid, P = arrays

    # Initial condition is to consume all in every state
    σ_init = np.repeat(s_grid, y_size)
    σ_init = np.reshape(σ_init, (s_size, y_size))
    a_init = np.copy(σ_init)
    a_vec, σ_vec = a_init, σ_init

    # Set up loop
    i = 0
    error = tol + 1

    while i < max_iter and error > tol:
        a_new, σ_new = K_egm_nb(a_vec, σ_vec, constants, sizes, arrays)
        error = np.max(np.abs(σ_vec - σ_new))
        i += 1
        if verbose and i % print_skip == 0:
            print(f"Error at iteration {i} is {error}.")
        a_vec, σ_vec = np.copy(a_new), np.copy(σ_new)

    if error > tol:
        print("Failed to converge!")
    else:
        print(f"\nConverged in {i} iterations.")

    return a_new, σ_new
```

## 15.4. Solutions

Here we solve the IFP with JAX and Numba.

We will compare both the outputs and the execution time.

### 15.4.1. Outputs

```
model = ifp()
```

Here’s a first run of the JAX code.

```
a_star_egm_jax, σ_star_egm_jax = successive_approx_jax(model,
print_skip=100)
```

```
Error at iteration 100 is 0.003274240577000098.
Error at iteration 200 is 0.0013133107388259013.
Error at iteration 300 is 0.0006550972250753961.
Error at iteration 400 is 0.00038003859326907197.
Error at iteration 500 is 0.00024736616926013255.
Error at iteration 600 is 0.00017446354504913053.
Error at iteration 700 is 0.000129892015863442.
Error at iteration 800 is 0.00010058769447773841.
Error at iteration 900 is 7.993256952376626e-05.
Error at iteration 1000 is 6.472028596182788e-05.
Error at iteration 1100 is 5.316228631624398e-05.
Error at iteration 1200 is 4.425450893941196e-05.
Error at iteration 1300 is 3.7260418253914906e-05.
Error at iteration 1400 is 3.1614060126861077e-05.
Error at iteration 1500 is 2.6984975752375462e-05.
Error at iteration 1600 is 2.3148392509719784e-05.
Error at iteration 1700 is 1.9940474091262317e-05.
Error at iteration 1800 is 1.723818132703947e-05.
Error at iteration 1900 is 1.4947303633494613e-05.
Error at iteration 2000 is 1.2994575430580468e-05.
Error at iteration 2100 is 1.132223596411741e-05.

Converged in 2192 iterations.
```

Next let’s solve the same IFP with Numba.

```
qe.tic()
a_star_egm_nb, σ_star_egm_nb = successive_approx_numba(model,
print_skip=100)
qe.toc()
```

```
Error at iteration 100 is 0.0032742405770003202.
Error at iteration 200 is 0.0013133107388259013.
Error at iteration 300 is 0.0006550972250753961.
Error at iteration 400 is 0.0003800385932688499.
Error at iteration 500 is 0.00024736616926013255.
Error at iteration 600 is 0.00017446354504935258.
Error at iteration 700 is 0.000129892015863442.
Error at iteration 800 is 0.00010058769447773841.
Error at iteration 900 is 7.993256952354422e-05.
Error at iteration 1000 is 6.472028596182788e-05.
Error at iteration 1100 is 5.316228631624398e-05.
Error at iteration 1200 is 4.425450893941196e-05.
Error at iteration 1300 is 3.7260418253914906e-05.
Error at iteration 1400 is 3.1614060126861077e-05.
Error at iteration 1500 is 2.6984975752597506e-05.
Error at iteration 1600 is 2.3148392509719784e-05.
Error at iteration 1700 is 1.9940474091262317e-05.
Error at iteration 1800 is 1.7238181326817426e-05.
Error at iteration 1900 is 1.4947303633494613e-05.
Error at iteration 2000 is 1.2994575430802513e-05.
Error at iteration 2100 is 1.132223596411741e-05.

Converged in 2192 iterations.
TOC: Elapsed: 0:01:21.06
```

```
81.06275987625122
```

Now let’s check the outputs in a plot to make sure they are the same.

```
constants, sizes, arrays = model
β, R, γ = constants
s_size, y_size = sizes
s_grid, y_grid, P = arrays

fig, ax = plt.subplots()

for z in (0, y_size-1):
    ax.plot(a_star_egm_nb[:, z],
            σ_star_egm_nb[:, z],
            '--', lw=2,
            label=f"Numba EGM: consumption when $z={z}$")
    ax.plot(a_star_egm_jax[:, z],
            σ_star_egm_jax[:, z],
            label=f"JAX EGM: consumption when $z={z}$")

ax.set_xlabel('asset')
plt.legend()
plt.show()
```
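A plot can hide small discrepancies, so a numerical comparison is a useful complement. The sketch below defines a hypothetical helper `max_abs_gap` and applies it to stand-in arrays; with the objects computed in this lecture one would instead pass, for example, `σ_star_egm_jax` and `σ_star_egm_nb` (and likewise the two asset grids).

```python
import numpy as np

def max_abs_gap(x, y):
    "Largest absolute elementwise difference between two arrays."
    return np.max(np.abs(np.asarray(x) - np.asarray(y)))

# Stand-in arrays playing the role of two computed policies
σ_a = np.linspace(0, 1, 6).reshape(3, 2)
σ_b = σ_a + 1e-12

gap = max_abs_gap(σ_a, σ_b)
print(gap < 1e-10)
```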

### 15.4.2. Timing

Now let’s compare the execution time of the two methods.

```
qe.tic()
a_star_egm_jax, σ_star_egm_jax = successive_approx_jax(model,
print_skip=1000)
jax_time = qe.toc()
```

```
Error at iteration 1000 is 6.472028596182788e-05.
Error at iteration 2000 is 1.2994575430580468e-05.

Converged in 2192 iterations.
TOC: Elapsed: 0:00:3.23
```

```
qe.tic()
a_star_egm_nb, σ_star_egm_nb = successive_approx_numba(model,
print_skip=1000)
numba_time = qe.toc()
```

```
Error at iteration 1000 is 6.472028596182788e-05.
Error at iteration 2000 is 1.2994575430802513e-05.

Converged in 2192 iterations.
TOC: Elapsed: 0:01:18.66
```

```
jax_time / numba_time
```

```
0.04107768511500939
```

The JAX code is significantly faster, as expected: in this run it takes about 4% of the Numba run time, which is roughly a 24-fold speedup.

This difference will increase when more features (and state variables) are added to the model.
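When timing JAX code in general, two details matter: the first call to a jitted function includes compile time, and JAX dispatches work asynchronously, so a timer can stop before the computation has finished unless the result is forced. The sketch below illustrates a common pattern using a toy function `f` (not part of the lecture's code): a warm-up call followed by a timed call, each ended with `block_until_ready()`. In the comparison above, the JAX solver had already been compiled by the earlier run, and the Python-level convergence test forces each iteration to complete.

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # Toy computation used only to illustrate timing
    return jnp.sum(jnp.sin(x) ** 2)

x = jnp.linspace(0, 1, 1_000_000)

# Warm-up call: includes compilation
f(x).block_until_ready()

# Timed call: compiled code only, forced to finish before the clock stops
start = time.perf_counter()
out = f(x).block_until_ready()
elapsed = time.perf_counter() - start
print(f"Elapsed: {elapsed:.6f} seconds")
```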