11. Optimal Savings#
GPU
This lecture was built using hardware that has access to a GPU.
To run this lecture on Google Colab, click on the “play” icon top right, select Colab, and set the runtime environment to include a GPU.
To run this lecture on your own machine, you need to install Google JAX.
In addition to what’s in Anaconda, this lecture will need the following libraries:
!pip install quantecon
We will use the following imports:
import quantecon as qe
import jax
import jax.numpy as jnp
from collections import namedtuple
import matplotlib.pyplot as plt
import time
Let’s check the GPU we are running on.
!nvidia-smi
Fri Sep 22 00:40:22 2023
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.182.03 Driver Version: 470.182.03 CUDA Version: 12.1 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|===============================+======================+======================|
| 0 Tesla V100-SXM2... Off | 00000000:00:1E.0 Off | 0 |
| N/A 29C P0 38W / 300W | 0MiB / 16160MiB | 2% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=============================================================================|
| No running processes found |
+-----------------------------------------------------------------------------+
We use 64-bit floats with JAX in order to match the NumPy code.
By default, JAX uses 32-bit datatypes, while NumPy uses 64-bit datatypes.
jax.config.update("jax_enable_x64", True)
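As a quick check (this snippet is not part of the original lecture), newly created arrays should now default to 64-bit floats:
x = jnp.zeros(3)
print(x.dtype)   # expect float64 once 64-bit mode is enabled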
11.1. Overview#
We consider an optimal savings problem with CRRA utility and budget constraint

\[
W_{t+1} + C_t \leq R W_t + Y_t
\]
We assume that labor income \((Y_t)\) is a discretized AR(1) process.
The right-hand side of the Bellman equation is

\[
B(w, y, w') = u(Rw + y - w') + \beta \sum_{y'} v(w', y') Q(y, y')
\]

where

\[
u(c) = \frac{c^{1-\gamma}}{1-\gamma}
\]
We use successive approximation for VFI.
def successive_approx(T, # Operator (callable)
x_0, # Initial condition
tolerance=1e-6, # Error tolerance
max_iter=10_000, # Max iteration bound
print_step=25, # Print at multiples
verbose=False):
x = x_0
error = tolerance + 1
k = 1
while error > tolerance and k <= max_iter:
x_new = T(x)
error = jnp.max(jnp.abs(x_new - x))
if verbose and k % print_step == 0:
print(f"Completed iteration {k} with error {error}.")
x = x_new
k += 1
if error > tolerance:
print(f"Warning: Iteration hit upper bound {max_iter}.")
elif verbose:
print(f"Terminated successfully in {k} iterations.")
return x
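As a minimal usage sketch (an illustration, not part of the lecture's model), we can apply successive_approx to the scalar contraction \(x \mapsto 0.5 x + 1\), whose fixed point is \(2\):
# Illustrative only: iterate a simple scalar contraction to its fixed point
f = lambda x: 0.5 * x + 1.0
x_star = successive_approx(f, jnp.array(0.0))
print(x_star)   # should be approximately 2.0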
11.2. Model primitives#
First we define a model that stores parameters and grids
def create_consumption_model(R=1.01, # Gross interest rate
β=0.98, # Discount factor
γ=2, # CRRA parameter
w_min=0.01, # Min wealth
w_max=5.0, # Max wealth
                              w_size=150,      # Grid size
ρ=0.9, ν=0.1, y_size=100): # Income parameters
"""
A function that takes in parameters and returns parameters and grids
for the optimal savings problem.
"""
w_grid = jnp.linspace(w_min, w_max, w_size)
mc = qe.tauchen(n=y_size, rho=ρ, sigma=ν)
y_grid, Q = jnp.exp(mc.state_values), mc.P
β, R, γ = jax.device_put([β, R, γ])
w_grid, y_grid, Q = tuple(map(jax.device_put, [w_grid, y_grid, Q]))
sizes = w_size, y_size
return (β, R, γ), sizes, (w_grid, y_grid, Q)
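As an illustrative check (not part of the original lecture), we can call this function and inspect the shapes it returns:
# Inspect the default grid sizes and array shapes (illustrative)
constants, sizes, arrays = create_consumption_model()
w_grid, y_grid, Q = arrays
print(sizes)                                  # (150, 100) with the defaults
print(w_grid.shape, y_grid.shape, Q.shape)    # (150,) (100,) (100, 100)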
Here’s the right-hand side of the Bellman equation:
def B(v, constants, sizes, arrays):
"""
A vectorized version of the right-hand side of the Bellman equation
(before maximization), which is a 3D array representing
B(w, y, w′) = u(Rw + y - w′) + β Σ_y′ v(w′, y′) Q(y, y′)
for all (w, y, w′).
"""
# Unpack
β, R, γ = constants
w_size, y_size = sizes
w_grid, y_grid, Q = arrays
# Compute current rewards r(w, y, wp) as array r[i, j, ip]
w = jnp.reshape(w_grid, (w_size, 1, 1)) # w[i] -> w[i, j, ip]
    y = jnp.reshape(y_grid, (1, y_size, 1))    # y[j]  ->  y[i, j, ip]
wp = jnp.reshape(w_grid, (1, 1, w_size)) # wp[ip] -> wp[i, j, ip]
c = R * w + y - wp
# Calculate continuation rewards at all combinations of (w, y, wp)
v = jnp.reshape(v, (1, 1, w_size, y_size)) # v[ip, jp] -> v[i, j, ip, jp]
Q = jnp.reshape(Q, (1, y_size, 1, y_size)) # Q[j, jp] -> Q[i, j, ip, jp]
EV = jnp.sum(v * Q, axis=3) # sum over last index jp
# Compute the right-hand side of the Bellman equation
return jnp.where(c > 0, c**(1-γ)/(1-γ) + β * EV, -jnp.inf)
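As a quick illustrative check (not from the original lecture), B should return a 3D array of shape (w_size, y_size, w_size), with \(-\infty\) recorded at infeasible choices where consumption is non-positive:
# Shape and feasibility check for B (illustrative; uses the un-jitted version)
constants, sizes, arrays = create_consumption_model()
v0 = jnp.zeros(sizes)
vals = B(v0, constants, sizes, arrays)
print(vals.shape)                        # expect (150, 100, 150)
print(bool(jnp.isneginf(vals).any()))    # expect True: some (w, y, w') choices are infeasible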
11.3. Operators#
We define a function to compute the current rewards \(r_\sigma\) given policy \(\sigma\), which is defined as the vector

\[
r_\sigma(w_i, y_j) := r(w_i, y_j, \sigma(w_i, y_j))
\]
def compute_r_σ(σ, constants, sizes, arrays):
"""
Compute the array r_σ[i, j] = r[i, j, σ[i, j]], which gives current
rewards given policy σ.
"""
# Unpack model
β, R, γ = constants
w_size, y_size = sizes
w_grid, y_grid, Q = arrays
# Compute r_σ[i, j]
w = jnp.reshape(w_grid, (w_size, 1))
y = jnp.reshape(y_grid, (1, y_size))
wp = w_grid[σ]
c = R * w + y - wp
r_σ = c**(1-γ)/(1-γ)
return r_σ
Now we define the policy operator \(T_\sigma\)
def T_σ(v, σ, constants, sizes, arrays):
"The σ-policy operator."
# Unpack model
β, R, γ = constants
w_size, y_size = sizes
w_grid, y_grid, Q = arrays
r_σ = compute_r_σ(σ, constants, sizes, arrays)
# Compute the array v[σ[i, j], jp]
yp_idx = jnp.arange(y_size)
yp_idx = jnp.reshape(yp_idx, (1, 1, y_size))
σ = jnp.reshape(σ, (w_size, y_size, 1))
V = v[σ, yp_idx]
# Convert Q[j, jp] to Q[i, j, jp]
Q = jnp.reshape(Q, (1, y_size, y_size))
# Calculate the expected sum Σ_jp v[σ[i, j], jp] * Q[i, j, jp]
Ev = jnp.sum(V * Q, axis=2)
return r_σ + β * Ev
and the Bellman operator \(T\)
def T(v, constants, sizes, arrays):
"The Bellman operator."
return jnp.max(B(v, constants, sizes, arrays), axis=2)
The next function computes a \(v\)-greedy policy given \(v\)
def get_greedy(v, constants, sizes, arrays):
"Computes a v-greedy policy, returned as a set of indices."
return jnp.argmax(B(v, constants, sizes, arrays), axis=2)
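A useful consistency check (a sketch, not part of the lecture) is that applying \(T_\sigma\) with a \(v\)-greedy policy \(\sigma\) reproduces \(T v\):
# T v should equal T_σ v when σ is v-greedy (illustrative)
constants, sizes, arrays = create_consumption_model()
v = jnp.zeros(sizes)
σ = get_greedy(v, constants, sizes, arrays)
print(bool(jnp.allclose(T(v, constants, sizes, arrays),
                        T_σ(v, σ, constants, sizes, arrays))))   # expect True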
The function below computes the value \(v_\sigma\) of following policy \(\sigma\).
This lifetime value is a function \(v_\sigma\) that satisfies

\[
v_\sigma(w, y) = r_\sigma(w, y) + \beta \sum_{y'} v_\sigma(\sigma(w, y), y') Q(y, y')
\]

We wish to solve this equation for \(v_\sigma\).

Suppose we define the linear operator \(L_\sigma\) by

\[
(L_\sigma v)(w, y) = v(w, y) - \beta \sum_{y'} v(\sigma(w, y), y') Q(y, y')
\]

With this notation, the problem is to solve for \(v\) via

\[
(L_\sigma v)(w, y) = r_\sigma(w, y)
\]

In vector form this is \(L_\sigma v = r_\sigma\), which tells us that the function we seek is

\[
v_\sigma = L_\sigma^{-1} r_\sigma
\]
JAX allows us to solve linear systems defined in terms of operators; the first step is to define the function \(L_{\sigma}\).
def L_σ(v, σ, constants, sizes, arrays):
"""
Here we set up the linear map v -> L_σ v, where
(L_σ v)(w, y) = v(w, y) - β Σ_y′ v(σ(w, y), y′) Q(y, y′)
"""
β, R, γ = constants
w_size, y_size = sizes
w_grid, y_grid, Q = arrays
# Set up the array v[σ[i, j], jp]
zp_idx = jnp.arange(y_size)
zp_idx = jnp.reshape(zp_idx, (1, 1, y_size))
σ = jnp.reshape(σ, (w_size, y_size, 1))
V = v[σ, zp_idx]
# Expand Q[j, jp] to Q[i, j, jp]
Q = jnp.reshape(Q, (1, y_size, y_size))
# Compute and return v[i, j] - β Σ_jp v[σ[i, j], jp] * Q[j, jp]
return v - β * jnp.sum(V * Q, axis=2)
Now we can define a function to compute \(v_{\sigma}\)
def get_value(σ, constants, sizes, arrays):
"Get the value v_σ of policy σ by inverting the linear map L_σ."
# Unpack
β, R, γ = constants
w_size, y_size = sizes
w_grid, y_grid, Q = arrays
r_σ = compute_r_σ(σ, constants, sizes, arrays)
# Reduce L_σ to a function in v
partial_L_σ = lambda v: L_σ(v, σ, constants, sizes, arrays)
return jax.scipy.sparse.linalg.bicgstab(partial_L_σ, r_σ)[0]
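As a sanity check (an illustration, not from the lecture), the value returned by get_value should be a fixed point of \(T_\sigma\), i.e. \(v_\sigma = T_\sigma v_\sigma\) up to the tolerance of the linear solver:
# v_σ should (approximately) satisfy v_σ = T_σ v_σ (illustrative)
constants, sizes, arrays = create_consumption_model()
σ0 = jnp.zeros(sizes, dtype=int)             # always choose w' = w_grid[0]
v_σ = get_value(σ0, constants, sizes, arrays)
gap = jnp.max(jnp.abs(v_σ - T_σ(v_σ, σ0, constants, sizes, arrays)))
print(gap)                                   # should be close to zero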
11.4. JIT compiled versions#
We now JIT-compile the functions above, declaring the sizes argument as static because the grid sizes determine array shapes inside each function.
B = jax.jit(B, static_argnums=(2,))
compute_r_σ = jax.jit(compute_r_σ, static_argnums=(2,))
T = jax.jit(T, static_argnums=(2,))
get_greedy = jax.jit(get_greedy, static_argnums=(2,))
get_value = jax.jit(get_value, static_argnums=(2,))
T_σ = jax.jit(T_σ, static_argnums=(3,))
L_σ = jax.jit(L_σ, static_argnums=(3,))
11.5. Solvers#
Now we define the solvers, which implement VFI, HPI and OPI.
# Implements VFI (value function iteration)
def value_iteration(model, tol=1e-5):
constants, sizes, arrays = model
_T = lambda v: T(v, constants, sizes, arrays)
vz = jnp.zeros(sizes)
v_star = successive_approx(_T, vz, tolerance=tol)
return get_greedy(v_star, constants, sizes, arrays)
# Implements HPI (Howard policy iteration)
def policy_iteration(model, maxiter=250):
constants, sizes, arrays = model
σ = jnp.zeros(sizes, dtype=int)
i, error = 0, 1.0
while error > 0 and i < maxiter:
v_σ = get_value(σ, constants, sizes, arrays)
σ_new = get_greedy(v_σ, constants, sizes, arrays)
error = jnp.max(jnp.abs(σ_new - σ))
σ = σ_new
i = i + 1
print(f"Concluded loop {i} with error {error}.")
return σ
# Implements OPI (optimistic policy iteration)
def optimistic_policy_iteration(model, tol=1e-5, m=10):
constants, sizes, arrays = model
v = jnp.zeros(sizes)
error = tol + 1
while error > tol:
last_v = v
σ = get_greedy(v, constants, sizes, arrays)
for _ in range(m):
v = T_σ(v, σ, constants, sizes, arrays)
error = jnp.max(jnp.abs(v - last_v))
return get_greedy(v, constants, sizes, arrays)
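Since all three routines target the same optimal policy, a simple cross-check (illustrative, not part of the lecture) is that they return the same array of policy indices:
# The three solvers should agree on the optimal policy (illustrative)
model = create_consumption_model()
σ_hpi = policy_iteration(model)
σ_vfi = value_iteration(model)
σ_opi = optimistic_policy_iteration(model)
print(bool((σ_hpi == σ_vfi).all()), bool((σ_hpi == σ_opi).all()))  # typically True, True
# (small tolerance differences could in principle flip a few grid points)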
11.6. Plots#
Create a model for consumption, perform policy iteration, and plot the resulting optimal policy function.
fontsize = 12
model = create_consumption_model()
# Unpack
constants, sizes, arrays = model
β, R, γ = constants
w_size, y_size = sizes
w_grid, y_grid, Q = arrays
σ_star = policy_iteration(model)
fig, ax = plt.subplots(figsize=(9, 5.2))
ax.plot(w_grid, w_grid, "k--", label="45")
ax.plot(w_grid, w_grid[σ_star[:, 1]], label=r"$\sigma^*(\cdot, y_1)$")
ax.plot(w_grid, w_grid[σ_star[:, -1]], label=r"$\sigma^*(\cdot, y_N)$")
ax.legend(fontsize=fontsize)
plt.show()
Concluded loop 1 with error 77.
Concluded loop 2 with error 53.
Concluded loop 3 with error 28.
Concluded loop 4 with error 17.
Concluded loop 5 with error 8.
Concluded loop 6 with error 4.
Concluded loop 7 with error 1.
Concluded loop 8 with error 1.
Concluded loop 9 with error 1.
Concluded loop 10 with error 0.

11.7. Tests#
Here’s a quick test of the timing of each solver.
model = create_consumption_model()
print("Starting HPI.")
start_time = time.time()
out = policy_iteration(model)
elapsed = time.time() - start_time
print(f"HPI completed in {elapsed} seconds.")
Starting HPI.
Concluded loop 1 with error 77.
Concluded loop 2 with error 53.
Concluded loop 3 with error 28.
Concluded loop 4 with error 17.
Concluded loop 5 with error 8.
Concluded loop 6 with error 4.
Concluded loop 7 with error 1.
Concluded loop 8 with error 1.
Concluded loop 9 with error 1.
Concluded loop 10 with error 0.
HPI completed in 0.031373023986816406 seconds.
print("Starting VFI.")
start_time = time.time()
out = value_iteration(model)
elapsed = time.time() - start_time
print(f"VFI(jax not in succ) completed in {elapsed} seconds.")
Starting VFI.
VFI completed in 1.0236868858337402 seconds.
print("Starting OPI.")
start_time = time.time()
out = optimistic_policy_iteration(model, m=100)
elapsed = time.time() - start_time
print(f"OPI completed in {elapsed} seconds.")
Starting OPI.
OPI completed in 0.2709696292877197 seconds.
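Note that timings like these can include JIT compilation on the first call and, since JAX dispatches work asynchronously, time.time() may stop the clock before the device has finished. A more careful sketch (not part of the original lecture) warms the solver up first and blocks on the result:
# Illustrative: time a warm (already compiled) run and wait for the device
out = policy_iteration(model)           # warm-up run triggers any remaining compilation
start_time = time.time()
out = policy_iteration(model)
out.block_until_ready()                 # wait until the result is actually computed
elapsed = time.time() - start_time
print(f"HPI (warm run) completed in {elapsed} seconds.")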
def run_algorithm(algorithm, model, **kwargs):
start_time = time.time()
result = algorithm(model, **kwargs)
end_time = time.time()
elapsed_time = end_time - start_time
print(f"{algorithm.__name__} completed in {elapsed_time:.2f} seconds.")
return result, elapsed_time
model = create_consumption_model()
σ_pi, pi_time = run_algorithm(policy_iteration, model)
σ_vfi, vfi_time = run_algorithm(value_iteration, model, tol=1e-5)
m_vals = range(5, 600, 40)
opi_times = []
for m in m_vals:
σ_opi, opi_time = run_algorithm(optimistic_policy_iteration, model, m=m, tol=1e-5)
opi_times.append(opi_time)
Concluded loop 1 with error 77.
Concluded loop 2 with error 53.
Concluded loop 3 with error 28.
Concluded loop 4 with error 17.
Concluded loop 5 with error 8.
Concluded loop 6 with error 4.
Concluded loop 7 with error 1.
Concluded loop 8 with error 1.
Concluded loop 9 with error 1.
Concluded loop 10 with error 0.
policy_iteration completed in 0.03 seconds.
value_iteration completed in 0.38 seconds.
optimistic_policy_iteration completed in 0.17 seconds.
optimistic_policy_iteration completed in 0.09 seconds.
optimistic_policy_iteration completed in 0.09 seconds.
optimistic_policy_iteration completed in 0.12 seconds.
optimistic_policy_iteration completed in 0.18 seconds.
optimistic_policy_iteration completed in 0.22 seconds.
optimistic_policy_iteration completed in 0.26 seconds.
optimistic_policy_iteration completed in 0.30 seconds.
optimistic_policy_iteration completed in 0.34 seconds.
optimistic_policy_iteration completed in 0.38 seconds.
optimistic_policy_iteration completed in 0.40 seconds.
optimistic_policy_iteration completed in 0.51 seconds.
optimistic_policy_iteration completed in 0.51 seconds.
optimistic_policy_iteration completed in 0.53 seconds.
optimistic_policy_iteration completed in 0.55 seconds.
fig, ax = plt.subplots(figsize=(9, 5.2))
ax.plot(m_vals, jnp.full(len(m_vals), pi_time), lw=2, label="Howard policy iteration")
ax.plot(m_vals, jnp.full(len(m_vals), vfi_time), lw=2, label="value function iteration")
ax.plot(m_vals, opi_times, lw=2, label="optimistic policy iteration")
ax.legend(fontsize=fontsize, frameon=False)
ax.set_xlabel("$m$", fontsize=fontsize)
ax.set_ylabel("time", fontsize=fontsize)
plt.show()