15. Endogenous Grid Method

GPU

This lecture was built using a machine with JAX installed and access to a GPU.

To run this lecture on Google Colab, click on the “play” icon top right, select Colab, and set the runtime environment to include a GPU.

To run this lecture on your own machine, you need to install Google JAX.

15.1. Overview

In this lecture we use the endogenous grid method (EGM) to solve a basic income fluctuation (optimal savings) problem.

Background on the endogenous grid method can be found in an earlier QuantEcon lecture.

Here we focus on providing an efficient JAX implementation.

In addition to JAX and Anaconda, this lecture will need the following libraries:

!pip install --upgrade quantecon
Requirement already satisfied: quantecon in /opt/conda/envs/quantecon/lib/python3.11/site-packages (0.7.2)
Requirement already satisfied: numba>=0.49.0 in /opt/conda/envs/quantecon/lib/python3.11/site-packages (from quantecon) (0.59.0)
Requirement already satisfied: numpy>=1.17.0 in /opt/conda/envs/quantecon/lib/python3.11/site-packages (from quantecon) (1.26.4)
Requirement already satisfied: requests in /opt/conda/envs/quantecon/lib/python3.11/site-packages (from quantecon) (2.31.0)
Requirement already satisfied: scipy>=1.5.0 in /opt/conda/envs/quantecon/lib/python3.11/site-packages (from quantecon) (1.11.4)
Requirement already satisfied: sympy in /opt/conda/envs/quantecon/lib/python3.11/site-packages (from quantecon) (1.12)
Requirement already satisfied: llvmlite<0.43,>=0.42.0dev0 in /opt/conda/envs/quantecon/lib/python3.11/site-packages (from numba>=0.49.0->quantecon) (0.42.0)
Requirement already satisfied: charset-normalizer<4,>=2 in /opt/conda/envs/quantecon/lib/python3.11/site-packages (from requests->quantecon) (2.0.4)
Requirement already satisfied: idna<4,>=2.5 in /opt/conda/envs/quantecon/lib/python3.11/site-packages (from requests->quantecon) (3.4)
Requirement already satisfied: urllib3<3,>=1.21.1 in /opt/conda/envs/quantecon/lib/python3.11/site-packages (from requests->quantecon) (2.0.7)
Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/envs/quantecon/lib/python3.11/site-packages (from requests->quantecon) (2024.2.2)
Requirement already satisfied: mpmath>=0.19 in /opt/conda/envs/quantecon/lib/python3.11/site-packages (from sympy->quantecon) (1.3.0)
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv

import quantecon as qe
import matplotlib.pyplot as plt
import numpy as np
import jax
import jax.numpy as jnp
import numba
from time import time

Let’s check which GPU we are running on.

!nvidia-smi
Thu Jun 13 03:49:58 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.54.03              Driver Version: 535.54.03    CUDA Version: 12.5     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  Tesla T4                       On  | 00000001:00:00.0 Off |                  Off |
| N/A   46C    P8               9W /  70W |      2MiB / 16384MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|  No running processes found                                                           |
+---------------------------------------------------------------------------------------+

We use 64-bit floating point numbers for extra precision.

jax.config.update("jax_enable_x64", True)

15.2. Setup

We consider a household that chooses a state-contingent consumption plan \(\{c_t\}_{t \geq 0}\) to maximize

\[ \mathbb{E} \, \sum_{t=0}^{\infty} \beta^t u(c_t) \]

subject to

\[ a_{t+1} \leq R(a_t - c_t) + Y_{t+1}, \quad c_t \geq 0, \quad a_t \geq 0 \quad t = 0, 1, \ldots \]

Here \(R = 1 + r\) where \(r\) is the interest rate.

The income process \(\{Y_t\}\) is a Markov chain with stochastic (transition) matrix \(P\).

The matrix \(P\) and the grid of values taken by \(Y_t\) are obtained by discretizing, via Tauchen's method (qe.tauchen), the AR(1) process

\[ Z_{t+1} = \rho Z_t + \nu \epsilon_{t+1} \]

and then setting \(Y_t = \exp(Z_t)\), where \(\{\epsilon_t\}\) is IID and standard normal.
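For readers who want to see what the discretization step does, here is a from-scratch sketch of the standard Tauchen construction in NumPy. The name tauchen_sketch is hypothetical, and the grid width of ±3 stationary standard deviations is an assumption, so details may differ from qe.tauchen. Each row of \(P\) assigns to grid point \(z_j\) the normal probability mass of the interval around it, with the two end points absorbing the tails.

```python
import math
import numpy as np

def tauchen_sketch(n, ρ, ν, n_std=3.0):
    """Hypothetical helper: discretize Z' = ρ Z + ν ε, ε ~ N(0, 1), by Tauchen's method.

    Returns (z_grid, P) with P[i, j] ≈ Prob(Z' = z_j | Z = z_i).
    The grid spans ±n_std stationary standard deviations (an assumption).
    """
    std_z = ν / math.sqrt(1 - ρ**2)              # stationary std. dev. of Z
    z_grid = np.linspace(-n_std * std_z, n_std * std_z, n)
    half_step = (z_grid[1] - z_grid[0]) / 2

    def Φ(x):                                    # standard normal CDF via erf
        return 0.5 * (1 + math.erf(x / math.sqrt(2)))

    P = np.empty((n, n))
    for i in range(n):
        mean = ρ * z_grid[i]                     # conditional mean of Z'
        for j in range(n):
            upper = (z_grid[j] + half_step - mean) / ν
            lower = (z_grid[j] - half_step - mean) / ν
            if j == 0:
                P[i, j] = Φ(upper)               # lower tail goes to the first point
            elif j == n - 1:
                P[i, j] = 1 - Φ(lower)           # upper tail goes to the last point
            else:
                P[i, j] = Φ(upper) - Φ(lower)
    return z_grid, P

z_grid, P = tauchen_sketch(25, 0.99, 0.02)       # same ρ, ν, size as the ifp defaults
y_grid = np.exp(z_grid)                          # income levels Y = exp(Z)
```

Because adjacent interval boundaries coincide, the probabilities in each row telescope and sum to one.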

Utility has the CRRA specification

\[ u(c) = \frac{c^{1 - \gamma}} {1 - \gamma} \]
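The functions \(u\), \(u'\) and \((u')^{-1}\) appear throughout the method. A quick NumPy check of their properties, using \(\gamma = 1.5\) as in the defaults of ifp (the test grid of consumption values is an arbitrary choice):

```python
import numpy as np

γ = 1.5  # CRRA parameter; 1.5 is the default in ifp()

def u(c):
    return c**(1 - γ) / (1 - γ)

def u_prime(c):
    return c**(-γ)

def u_prime_inv(x):
    return x**(-1 / γ)

c = np.linspace(0.1, 5.0, 50)

# u' is the derivative of u (central finite-difference check)
h = 1e-6
assert np.allclose((u(c + h) - u(c - h)) / (2 * h), u_prime(c), rtol=1e-5)
# (u')^{-1} undoes u'
assert np.allclose(u_prime_inv(u_prime(c)), c)
# u' is strictly decreasing, so its inverse is decreasing as well
assert np.all(np.diff(u_prime(c)) < 0)
```

The last property, together with \(u'(c) \to \infty\) as \(c \downarrow 0\), is what the borrowing-constraint argument relies on.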

The following function stores default parameter values for the income fluctuation problem and creates suitable arrays.

def ifp(R=1.01,             # gross interest rate
        β=0.99,             # discount factor
        γ=1.5,              # CRRA preference parameter
        s_max=16,           # savings grid max
        s_size=200,         # savings grid size
        ρ=0.99,             # income persistence
        ν=0.02,             # income volatility
        y_size=25):         # income grid size
  
    # require R β < 1 for convergence
    assert R * β < 1, "Stability condition failed."
    # Create income Markov chain
    mc = qe.tauchen(y_size, ρ, ν)
    y_grid, P = jnp.exp(mc.state_values), mc.P
    # Shift to JAX arrays
    P, y_grid = jax.device_put((P, y_grid))
    s_grid = jnp.linspace(0, s_max, s_size)
    # Pack and return
    constants = β, R, γ
    sizes = s_size, y_size
    arrays = s_grid, y_grid, P
    return constants, sizes, arrays

15.3. Solution method

Let \(S = \mathbb R_+ \times \mathsf Y\) be the set of possible values for the state \((a_t, Y_t)\).

We aim to compute an optimal consumption policy \(\sigma^* \colon S \to \mathbb R\), under which dynamics are given by

\[ c_t = \sigma^*(a_t, Y_t) \quad \text{and} \quad a_{t+1} = R (a_t - c_t) + Y_{t+1} \]
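To make the law of motion concrete, here is a minimal simulation sketch. The helper simulate_assets is hypothetical, and the two-state income chain and the rule \(c = 0.5 a\) are illustrative assumptions, not the optimal policy.

```python
import numpy as np

def simulate_assets(σ, a0, y_idx0, y_grid, P, R, T, seed=0):
    """Simulate a_{t+1} = R (a_t - c_t) + Y_{t+1} with c_t = σ(a_t, Y_t).

    σ is any callable σ(a, y_index) returning consumption in [0, a].
    """
    rng = np.random.default_rng(seed)
    a = np.empty(T + 1)
    c = np.empty(T)
    a[0], y_idx = a0, y_idx0
    for t in range(T):
        c[t] = σ(a[t], y_idx)
        # Draw next period's income state from row y_idx of P
        y_idx = rng.choice(len(y_grid), p=P[y_idx])
        a[t + 1] = R * (a[t] - c[t]) + y_grid[y_idx]
    return a, c

# Hypothetical two-state income chain and a simple feasible rule c = 0.5 a
y_grid = np.array([0.5, 1.5])
P = np.array([[0.9, 0.1],
              [0.2, 0.8]])
a_path, c_path = simulate_assets(lambda a, y: 0.5 * a, 1.0, 0, y_grid, P, R=1.01, T=100)
```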

In this section we discuss how we intend to solve for this policy.

15.3.1. Euler equation

The Euler equation for the optimization problem is

\[ u' (c_t) = \max \left\{ \beta R \, \mathbb{E}_t u'(c_{t+1}) \,,\; u'(a_t) \right\} \]

An explanation for this expression can be found in the QuantEcon lecture on the income fluctuation problem.

We rewrite the Euler equation in functional form

\[ (u' \circ \sigma) (a, y) = \max \left\{ \beta R \, \mathbb E_y (u' \circ \sigma) [R (a - \sigma(a, y)) + \hat Y, \, \hat Y] \, , \; u'(a) \right\} \]

where \((u' \circ \sigma)(a, y) := u'(\sigma(a, y))\) and \(\sigma\) is a consumption policy.

For a given consumption policy \(\sigma\), we define \((K \sigma) (a,y)\) as the unique \(c \in [0, a]\) that solves

\[ u'(c) = \max \left\{ \beta R \, \mathbb E_y (u' \circ \sigma) \, [R (a - c) + \hat Y, \, \hat Y] \, , \; u'(a) \right\} \tag{15.1} \]
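Definition (15.1) can also be solved directly, one state at a time, by root-finding. The sketch below, built around the hypothetical helper K_point, computes \((K\sigma)(a, y)\) by bisection on \(c\), representing \(\sigma\) by linear interpolation on a grid as the lecture's code does. This is the time-iteration approach, not EGM: it needs a root search at every grid point, which is exactly the cost that EGM avoids. The toy parameters are assumptions.

```python
import numpy as np

def K_point(a, y_idx, a_in, σ_in, y_grid, P, β, R, γ, tol=1e-10):
    """(Kσ)(a, y) at one state, by bisection on c (time iteration, not EGM).

    σ(·, z) is the linear interpolation of σ_in[:, z] on a_in[:, z].
    """
    def u_prime(c):
        return c**(-γ)

    def rhs(c):
        # β R E_y u'(σ(R(a - c) + Ŷ, Ŷ))
        a_next = R * (a - c) + y_grid
        σ_next = np.array([np.interp(a_next[z], a_in[:, z], σ_in[:, z])
                           for z in range(len(y_grid))])
        return β * R * P[y_idx] @ u_prime(σ_next)

    # Borrowing constraint binds: consume everything
    if u_prime(a) >= rhs(a):
        return a
    # Otherwise u'(c) - rhs(c) is decreasing in c, positive near 0 and
    # negative at c = a, so a unique root lies in (0, a)
    lo, hi = 0.0, a
    while hi - lo > tol:
        mid = 0.5 * (lo + hi)
        if u_prime(mid) > rhs(mid):
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)

# Hypothetical two-state example; σ_in = a_in is the "consume everything" policy
β, R, γ = 0.96, 1.01, 2.0
y_grid = np.array([0.5, 1.5])
P = np.array([[0.9, 0.1], [0.2, 0.8]])
a_in = np.tile(np.linspace(0, 10, 100)[:, None], (1, 2))
σ_in = a_in.copy()

c_vals = [K_point(a, 0, a_in, σ_in, y_grid, P, β, R, γ) for a in (0.2, 1.0, 3.0)]
```

At low assets the constraint binds and \(c = a\); at higher assets the solution is interior, and \(K\sigma\) is increasing in \(a\), consistent with the monotonicity property stated next.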

It can be shown that

  1. iterating with \(K\) computes an optimal policy and

  2. if \(\sigma\) is increasing in its first argument, then so is \(K\sigma\)

Hence below we always assume that \(\sigma\) is increasing in its first argument.

The EGM is a technique for computing the update \(K\sigma\) given \(\sigma\) along a grid of asset values.

Notice that, since \(u'(a) \to \infty\) as \(a \downarrow 0\), the second term in the max in (15.1) dominates for sufficiently small \(a\).

Also, again using (15.1), we have \(c=a\) for all such \(a\).

Hence, for sufficiently small \(a\),

\[ u'(a) \geq \beta R \, \mathbb E_y (u' \circ \sigma) \, [\hat Y, \, \hat Y] \]

Equality holds at \(\bar a(y)\) given by

\[ \bar a (y) = (u')^{-1} \left\{ \beta R \, \mathbb E_y (u' \circ \sigma) \, [\hat Y, \, \hat Y] \right\} \]

We can now write

\[\begin{split} u'(c) = \begin{cases} \beta R \, \mathbb E_y (u' \circ \sigma) \, [R (a - c) + \hat Y, \, \hat Y] & \text{if } a > \bar a (y) \\ u'(a) & \text{if } a \leq \bar a (y) \end{cases} \end{split}\]

Equivalently, we can state that the \(c\) satisfying \(c = (K\sigma)(a, y)\) obeys

\[\begin{split} c = \begin{cases} (u')^{-1} \left\{ \beta R \, \mathbb E_y (u' \circ \sigma) \, [R (a - c) + \hat Y, \, \hat Y] \right\} & \text{if } a > \bar a (y) \\ a & \text{if } a \leq \bar a (y) \end{cases} \end{split} \tag{15.2}\]

We begin with an exogenous grid of savings values \(0 = s_0 < \ldots < s_{N-1}\).

Using the exogenous savings grid, and a fixed value of \(y\), we create an endogenous asset grid \(a_0, \ldots, a_{N-1}\) and a consumption grid \(c_0, \ldots, c_{N-1}\) as follows.

First we set \(a_0 = c_0 = 0\), since zero consumption is an optimal (in fact the only) choice when \(a=0\).

Then, for \(i > 0\), we compute

\[ c_i = (u')^{-1} \left\{ \beta R \, \mathbb E_y (u' \circ \sigma) \, [R s_i + \hat Y, \, \hat Y] \right\} \quad \text{for } i = 1, \ldots, N-1 \tag{15.3} \]

and we set

\[ a_i = s_i + c_i \]

We claim that each pair \((a_i, c_i)\) obeys (15.2).

Indeed, since \(s_i > 0\), choosing \(c_i\) according to (15.3) gives

\[ c_i = (u')^{-1} \left\{ \beta R \, \mathbb E_y (u' \circ \sigma) \, [R s_i + \hat Y, \, \hat Y] \right\} \geq \bar a(y) \]

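where the inequality holds because \(R s_i + \hat Y > \hat Y\), \(\sigma\) is increasing in its first argument, and both \(u'\) and \((u')^{-1}\) are decreasing, so \(c_i\) is at least the value obtained with \(s_i = 0\), which is \(\bar a(y)\).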

Setting \(a_i = s_i + c_i\) then gives \(a_i > \bar a(y)\) (since \(s_i > 0\)), so the pair \((a_i, c_i)\) satisfies

\[ c_i = (u')^{-1} \left\{ \beta R \, \mathbb E_y (u' \circ \sigma) \, [R (a_i - c_i) + \hat Y, \, \hat Y] \right\} \quad \text{and} \quad a_i > \bar a(y) \]

Hence (15.2) holds.
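The argument above can be checked numerically. The sketch below performs one EGM step in NumPy, using loops for clarity, on a hypothetical three-state toy model with the initial guess \(\sigma(a, y) = a\) that the lecture's solvers also start from. It then verifies \(c_i \geq \bar a(y)\) and \(a_i > \bar a(y)\) for \(i > 0\).

```python
import numpy as np

# Hypothetical toy model (not the lecture's calibration)
β, R, γ = 0.96, 1.01, 2.0
y_grid = np.array([0.5, 1.0, 1.5])
P = np.array([[0.80, 0.15, 0.05],
              [0.10, 0.80, 0.10],
              [0.05, 0.15, 0.80]])
s_grid = np.linspace(0, 5, 50)      # exogenous savings grid, s_0 = 0

def u_prime(c):
    return c**(-γ)

def u_prime_inv(x):
    return x**(-1 / γ)

def σ(a):
    # Current guess: consume everything, σ(a, y) = a (increasing in a).
    # It does not depend on y, so the income argument is dropped here.
    return a

y_size = len(y_grid)
c = np.empty((len(s_grid), y_size))
a_bar = np.empty(y_size)
for z in range(y_size):
    # ā(y) = (u')^{-1}{β R E_y u'(σ(Ŷ))}
    a_bar[z] = u_prime_inv(β * R * P[z] @ u_prime(σ(y_grid)))
    for i, s in enumerate(s_grid):
        # Equation (15.3): c_i = (u')^{-1}{β R E_y u'(σ(R s_i + Ŷ))}
        c[i, z] = u_prime_inv(β * R * P[z] @ u_prime(σ(R * s + y_grid)))

c[0, :] = 0.0                  # a_0 = c_0 = 0
a = s_grid[:, None] + c        # endogenous grid: a_i = s_i + c_i

# The claim in the text: for i > 0, c_i >= ā(y) and hence a_i > ā(y)
assert np.all(c[1:] >= a_bar)
assert np.all(a[1:] > a_bar)
```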

We are now ready to iterate with \(K\).

15.3.2. JAX version

First we define a vectorized operator \(K\) based on the EGM.

Notice in the code below that

  • we avoid all loops and any mutation of arrays

  • the function is pure (no globals, no mutation of inputs)
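The broadcasting layout used in K_egm can be checked in plain NumPy against an explicit triple loop. In the sketch below (all toy arrays are assumptions), axis 0 indexes savings \(s_i\), axis 1 the current income state \(z\), and axis 2 next period's state \(\hat z\), so summing u'(σ_next) * P over the last axis gives \(\mathbb E_y u'(\sigma(R s_i + \hat Y, \hat Y))\). Here np.vectorize stands in for jnp.vectorize.

```python
import numpy as np

# Hypothetical toy inputs
rng = np.random.default_rng(1)
s_size, y_size, R, γ = 4, 3, 1.01, 1.5
s_grid = np.linspace(0, 2, s_size)
y_grid = np.array([0.5, 1.0, 1.5])
P = rng.dirichlet(np.ones(y_size), size=y_size)     # random stochastic matrix
a_in = np.tile(np.linspace(0, 4, 10)[:, None], (1, y_size))
σ_in = 0.5 * a_in + 0.1                             # an increasing, positive policy

def σ(a, z):
    return np.interp(a, a_in[:, z], σ_in[:, z])

# Broadcast version, axes (i, z, z_hat)
y_hat = y_grid.reshape(1, 1, y_size)
y_hat_idx = np.arange(y_size).reshape(1, 1, y_size)
s = s_grid.reshape(s_size, 1, 1)
a_next = R * s + y_hat                              # shape (s_size, 1, y_size)
σ_next = np.vectorize(σ)(a_next, y_hat_idx)         # shape (s_size, 1, y_size)
E = np.sum(σ_next**(-γ) * P.reshape(1, y_size, y_size), axis=-1)

# Loop version: E[i, z] = Σ_{z_hat} u'(σ(R s_i + y[z_hat], z_hat)) P[z, z_hat]
E_loop = np.zeros((s_size, y_size))
for i in range(s_size):
    for z in range(y_size):
        for z_hat in range(y_size):
            a_nxt = R * s_grid[i] + y_grid[z_hat]
            E_loop[i, z] += σ(a_nxt, z_hat)**(-γ) * P[z, z_hat]

assert E.shape == (s_size, y_size)
assert np.allclose(E, E_loop)
```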

def K_egm(a_in, σ_in, constants, sizes, arrays):
    """
    The vectorized operator K using EGM.

    """
    
    # Unpack
    β, R, γ = constants
    s_size, y_size = sizes
    s_grid, y_grid, P = arrays
    
    def u_prime(c):
        return c**(-γ)

    def u_prime_inv(u):
        return u**(-1/γ)

    # Linearly interpolate σ(a, y)
    def σ(a, y):
        return jnp.interp(a, a_in[:, y], σ_in[:, y])
    σ_vec = jnp.vectorize(σ)

    # Broadcast and vectorize: axes are (savings, current income, next income)
    y_hat = jnp.reshape(y_grid, (1, 1, y_size))
    y_hat_idx = jnp.reshape(jnp.arange(y_size), (1, 1, y_size))
    s = jnp.reshape(s_grid, (s_size, 1, 1))
    P = jnp.reshape(P, (1, y_size, y_size))
    
    # Evaluate consumption choice
    a_next = R * s + y_hat
    σ_next = σ_vec(a_next, y_hat_idx)
    up = u_prime(σ_next)
    E = jnp.sum(up * P, axis=-1)
    c = u_prime_inv(β * R * E)

    # The policy is computed consumption with the first row (s = 0) set to zero;
    # .at[...].set returns a new array, so nothing is mutated
    σ_out = c.at[0, :].set(0.0)

    # Compute a_out by a = s + c
    a_out = jnp.reshape(s_grid, (s_size, 1)) + σ_out
    
    return a_out, σ_out

Then we use jax.jit to compile \(K\).

We pass static_argnums=(3,) so that sizes, the fourth argument of K_egm, is treated as a compile-time constant: the array shapes inside K_egm depend on it, and JAX recompiles whenever its value changes.

K_egm_jax = jax.jit(K_egm, static_argnums=(3,))
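A small, self-contained sketch of how static_argnums behaves, assuming a recent JAX version. The function fill is a hypothetical example, and the Python counter inside it increments only while JAX traces the function, so it counts compilations. A static argument must be hashable (a tuple works), and each new value triggers a fresh trace; if the shape were traced instead, jnp.full could not use it as a shape.

```python
import jax
import jax.numpy as jnp

trace_count = 0   # counts how many times JAX traces (and so compiles) `fill`

def fill(x, shape):
    global trace_count
    trace_count += 1          # Python side effect: runs only while tracing
    return jnp.full(shape, x)

# Treat `shape` (argument index 1) as a compile-time constant
fill_jit = jax.jit(fill, static_argnums=(1,))

fill_jit(1.0, (2, 3))   # first call with shape (2, 3): trace and compile
fill_jit(2.0, (2, 3))   # same static shape, new x: reuse the compiled code
fill_jit(1.0, (4,))     # new static shape: trace and compile again
print(trace_count)      # prints 2
```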

Next we define a successive approximator that repeatedly applies \(K\).
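Stripped of JAX specifics, this is plain successive approximation: apply the operator, measure the sup-norm change, and stop once it is at most tol. A minimal NumPy sketch, with a hypothetical helper name and a toy contraction whose fixed point is known:

```python
import numpy as np

def successive_approx(T, x_init, tol=1e-5, max_iter=10_000):
    """Iterate x <- T(x) until the sup-norm change is at most tol."""
    x = x_init
    for i in range(1, max_iter + 1):
        x_new = T(x)
        error = np.max(np.abs(x_new - x))
        x = x_new
        if error <= tol:
            return x, i
    raise RuntimeError("Failed to converge")

# Toy contraction: T(x) = 0.5 x + 1 has fixed point x* = 2 in every coordinate
x_star, n_iter = successive_approx(lambda x: 0.5 * x + 1, np.zeros(3))
```

Starting from zero, the change at iteration \(k\) is \(0.5^{k-1}\), so the loop stops at \(k = 18\), the first \(k\) with \(0.5^{k-1} \leq 10^{-5}\).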

def successive_approx_jax(model,
                          tol=1e-5,
                          max_iter=100_000,
                          verbose=True,
                          print_skip=25):

    # Unpack
    constants, sizes, arrays = model
    s_size, y_size = sizes
    s_grid, y_grid, P = arrays
    
    # Initial condition is to consume all in every state: σ(a, y) = a
    σ_vec = jnp.reshape(jnp.repeat(s_grid, y_size), (s_size, y_size))
    a_vec = σ_vec
    
    i = 0
    error = tol + 1

    while i < max_iter and error > tol:
        a_new, σ_new = K_egm_jax(a_vec, σ_vec, constants, sizes, arrays)
        error = jnp.max(jnp.abs(σ_vec - σ_new))
        i += 1
        if verbose and i % print_skip == 0:
            print(f"Error at iteration {i} is {error}.")
        # JAX arrays are immutable, so no defensive copy is needed
        a_vec, σ_vec = a_new, σ_new

    if error > tol:
        print("Failed to converge!")
    else:
        print(f"\nConverged in {i} iterations.")

    return a_vec, σ_vec

15.3.3. Numba version

Below we provide a second set of code, which solves the same model with Numba.

The purpose of this code is to cross-check our results from the JAX version, as well as to do a runtime comparison.

Most readers will want to skip ahead to the next section, where we solve the model and run the cross-check.

@numba.jit
def K_egm_nb(a_in, σ_in, constants, sizes, arrays):
    """
    The operator K using Numba.

    """
    
    # Simplify names
    β, R, γ = constants
    s_size, y_size = sizes
    s_grid, y_grid, P = arrays

    def u_prime(c):
        return c**(-γ)

    def u_prime_inv(u):
        return u**(-1/γ)

    # Linear interpolation of policy using endogenous grid
    def σ(a, z):
        return np.interp(a, a_in[:, z], σ_in[:, z])
    
    # Allocate memory; row 0 (s = 0) stays at a_0 = c_0 = 0
    σ_out = np.zeros_like(σ_in)
    a_out = np.zeros_like(σ_out)
    
    for i in range(1, s_size):
        s = s_grid[i]
        for z in range(y_size):
            # Expected marginal utility E_z u'(σ(R s + Ŷ, Ŷ))
            expect = 0.0
            for z_hat in range(y_size):
                expect += u_prime(σ(R * s + y_grid[z_hat], z_hat)) * P[z, z_hat]
            c = u_prime_inv(β * R * expect)
            σ_out[i, z] = c
            a_out[i, z] = s + c
    
    return a_out, σ_out


def successive_approx_numba(model,        # tuple (constants, sizes, arrays) from ifp()
                            tol=1e-5,
                            max_iter=100_000,
                            verbose=True,
                            print_skip=25):

    # Unpack
    constants, sizes, arrays = model
    s_size, y_size = sizes
    # Make NumPy versions of the arrays
    arrays = tuple(map(np.array, arrays))
    s_grid, y_grid, P = arrays
    
    # Initial condition is to consume all in every state: σ(a, y) = a
    σ_vec = np.reshape(np.repeat(s_grid, y_size), (s_size, y_size))
    a_vec = np.copy(σ_vec)
    
    # Set up loop
    i = 0
    error = tol + 1

    while i < max_iter and error > tol:
        a_new, σ_new = K_egm_nb(a_vec, σ_vec, constants, sizes, arrays)
        error = np.max(np.abs(σ_vec - σ_new))
        i += 1
        if verbose and i % print_skip == 0:
            print(f"Error at iteration {i} is {error}.")
        # K_egm_nb allocates fresh output arrays, so no copy is needed
        a_vec, σ_vec = a_new, σ_new

    if error > tol:
        print("Failed to converge!")
    else:
        print(f"\nConverged in {i} iterations.")

    return a_vec, σ_vec

15.4. Solutions

Here we solve the income fluctuation problem (IFP) with JAX and Numba.

We will compare both the outputs and the execution time.

15.4.1. Outputs

model = ifp()

Here’s a first run of the JAX code.

%%time
a_star_egm_jax, σ_star_egm_jax = successive_approx_jax(model,
                                        print_skip=1000)
Error at iteration 1000 is 6.472028596182788e-05.
Error at iteration 2000 is 1.2994575430580468e-05.

Converged in 2192 iterations.
CPU times: user 2.36 s, sys: 727 ms, total: 3.09 s
Wall time: 2.34 s

Next let’s solve the same IFP with Numba.

%%time
a_star_egm_nb, σ_star_egm_nb = successive_approx_numba(model,
                                    print_skip=1000)
Error at iteration 1000 is 6.472028596182788e-05.
Error at iteration 2000 is 1.2994575430802513e-05.
Converged in 2192 iterations.
CPU times: user 54.1 s, sys: 32.3 ms, total: 54.2 s
Wall time: 54.2 s

Now let’s check the outputs in a plot to make sure they are the same.

constants, sizes, arrays = model
β, R, γ = constants
s_size, y_size = sizes
s_grid, y_grid, P = arrays


fig, ax = plt.subplots()

for z in (0, y_size-1):
    ax.plot(a_star_egm_nb[:, z], 
            σ_star_egm_nb[:, z], 
            '--', lw=2,
            label=f"Numba EGM: consumption when $z={z}$")
    ax.plot(a_star_egm_jax[:, z], 
            σ_star_egm_jax[:, z], 
            label=f"JAX EGM: consumption when $z={z}$")

ax.set_xlabel('asset')
plt.legend()
plt.show()
[Figure: consumption against assets from the Numba EGM (dashed) and JAX EGM (solid) solutions, for the lowest (z=0) and highest (z=24) income states]
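The plot is a visual check. A numerical version compares the two policies at common asset values; since each solution lives on its own endogenous grid, the hypothetical helper max_policy_gap below interpolates both onto shared points first. The toy arrays are assumptions used only to exercise the function.

```python
import numpy as np

def max_policy_gap(a1, c1, a2, c2, a_eval):
    """Largest |σ1(a, z) - σ2(a, z)| over the points a_eval and all states z.

    Each policy is stored on its own (endogenous) asset grid, one column per
    income state, and is evaluated at a_eval by linear interpolation.
    """
    gap = 0.0
    for z in range(a1.shape[1]):
        σ1 = np.interp(a_eval, a1[:, z], c1[:, z])
        σ2 = np.interp(a_eval, a2[:, z], c2[:, z])
        gap = max(gap, float(np.max(np.abs(σ1 - σ2))))
    return gap

# Toy check: a policy compared with itself, and with a copy shifted up by 0.1
a_toy = np.tile(np.linspace(0, 1, 5)[:, None], (1, 2))
c_toy = 0.5 * a_toy
gap_same = max_policy_gap(a_toy, c_toy, a_toy, c_toy, np.linspace(0, 1, 11))
gap_shift = max_policy_gap(a_toy, c_toy, a_toy, c_toy + 0.1, np.linspace(0, 1, 11))
```

Applied to the lecture's outputs, a call would look like max_policy_gap(np.asarray(a_star_egm_jax), np.asarray(σ_star_egm_jax), a_star_egm_nb, σ_star_egm_nb, np.linspace(0, 16, 500)).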

15.4.2. Timing

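Now let’s compare the execution times of the two methods. Both functions were already compiled during the runs above, so these timings exclude compilation.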
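When timing JAX code, keep in mind that JAX dispatches work asynchronously, so a call can return before the computation has finished. In successive_approx_jax the comparison error > tol forces each result to be ready, so the loop timings are not affected. For timing single calls, a small helper like the hypothetical time_call below waits via block_until_ready (a method of JAX arrays) and reports the best of several runs, which excludes compilation after the first run.

```python
from time import perf_counter

def time_call(f, *args, n_runs=3):
    """Best wall-clock time (seconds) of f(*args) over n_runs calls.

    If the result, or an element of a returned tuple, has a
    block_until_ready method (as JAX arrays do), we wait on it so that
    asynchronous dispatch does not make the timing look too short.
    """
    best = float("inf")
    for _ in range(n_runs):
        start = perf_counter()
        out = f(*args)
        for x in (out if isinstance(out, tuple) else (out,)):
            if hasattr(x, "block_until_ready"):
                x.block_until_ready()
        best = min(best, perf_counter() - start)
    return best
```

For example, time_call(K_egm_jax, a_vec, σ_vec, constants, sizes, arrays) would time one application of the compiled operator, with a_vec and σ_vec any policy arrays of shape (s_size, y_size).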

start = time()
a_star_egm_jax, σ_star_egm_jax = successive_approx_jax(model,
                                         print_skip=1000)
jax_time_without_compile = time() - start
print("Jax execution time = ", jax_time_without_compile)
Error at iteration 1000 is 6.472028596182788e-05.
Error at iteration 2000 is 1.2994575430580468e-05.

Converged in 2192 iterations.
Jax execution time =  1.7681024074554443
start = time()
a_star_egm_nb, σ_star_egm_nb = successive_approx_numba(model,
                                         print_skip=1000)
numba_time_without_compile = time() - start
print("Numba execution time = ", numba_time_without_compile)
Error at iteration 1000 is 6.472028596182788e-05.
Error at iteration 2000 is 1.2994575430802513e-05.
Converged in 2192 iterations.
Numba execution time =  53.108508825302124
jax_time_without_compile / numba_time_without_compile
0.03329226232413213

In this run the JAX code is roughly 30 times faster than the Numba code (the time ratio is about 0.033), as expected.

The difference is likely to grow as more features (and state variables) are added to the model.