9. An Asset Pricing Problem

GPU

This lecture was built using a machine with JAX installed and access to a GPU.

To run this lecture on Google Colab, click on the “play” icon top right, select Colab, and set the runtime environment to include a GPU.

To run this lecture on your own machine, you need to install Google JAX.

9.1. Overview

In this lecture we consider some asset pricing problems and use them to illustrate some foundations of JAX programming.

The main difference from the lecture Asset Pricing: The Lucas Asset Pricing Model, which also considers asset prices, is that the state spaces will be discrete and multi-dimensional.

Most of the heavy lifting is done through routines from linear algebra.

Along the way, we will show how to solve some memory-intensive problems with large state spaces.

We do this using elegant techniques made available by JAX, involving the use of linear operators to avoid instantiating large matrices.

If you wish to skip all motivation and move straight to the first equation we plan to solve, you can jump to (9.13).

The code outputs below were generated by a machine connected to the following GPU:

!nvidia-smi
Mon Aug 12 03:59:57 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.54.03              Driver Version: 535.54.03    CUDA Version: 12.5     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  Tesla T4                       On  | 00000001:00:00.0 Off |                  Off |
| N/A   43C    P0              32W /  70W |      2MiB / 16384MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|  No running processes found                                                           |
+---------------------------------------------------------------------------------------+

In addition to JAX and Anaconda, this lecture will need the following libraries:

!pip install quantecon

Below we use the following imports

import scipy
import quantecon as qe
import matplotlib.pyplot as plt
import numpy as np
import jax
import jax.numpy as jnp
from collections import namedtuple
from time import time

We will use 64-bit floats with JAX in order to increase precision.

jax.config.update("jax_enable_x64", True)

9.2. Pricing a single payoff

Suppose, at time \(t\), we have an asset that pays a random amount \(D_{t+1}\) at time \(t+1\) and nothing after that.

The simplest way to price this asset is to use “risk-neutral” asset pricing, which asserts that the price of the asset at time \(t\) should be

(9.1)#\[ P_t = \beta \, \mathbb E_t D_{t+1} \]

Here \(\beta\) is a constant discount factor and \(\mathbb E_t D_{t+1}\) is the expectation of \(D_{t+1}\) at time \(t\).

Roughly speaking, (9.1) says that the cost (i.e., price) equals expected benefit.

The discount factor is introduced because most people prefer payments now to payments in the future.

One problem with this very simple model is that it does not take into account attitudes to risk.

For example, investors often demand higher rates of return for holding risky assets.

This feature of asset prices cannot be captured by risk neutral pricing.

Hence we modify (9.1) to

(9.2)#\[ P_t = \mathbb E_t M_{t+1} D_{t+1} \]

In this expression, \(M_{t+1}\) replaces \(\beta\) and is called the stochastic discount factor.

In essence, allowing discounting to become a random variable gives us the flexibility to combine temporal discounting and attitudes to risk.
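
To make the difference between (9.1) and (9.2) concrete, here is a minimal sketch with two equally likely states. All numbers (the payoffs, the value of \(\beta\) and the state-by-state values of the discount factor) are made up for illustration.

```python
import numpy as np

# Hypothetical two-state example: all numbers are illustrative assumptions
probs = np.array([0.5, 0.5])      # probabilities of the "bad" and "good" state
D = np.array([0.8, 1.2])          # payoff D_{t+1} in each state
β = 0.98                          # constant discount factor

# Risk-neutral price (9.1): discounted expected payoff
p_risk_neutral = β * np.sum(probs * D)

# A stochastic discount factor with mean β that is high in the bad state
M = np.array([1.08, 0.88])

# Price under (9.2): expectation of the product M * D
p_sdf = np.sum(probs * M * D)

print(p_risk_neutral, p_sdf)
```

In this sketch the payoff is low exactly when \(M_{t+1}\) is high, so the price under (9.2) falls below the risk-neutral price even though the mean of \(M_{t+1}\) equals \(\beta\).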

We leave further discussion to other lectures because our aim is to move to the computational problem.

9.3. Pricing a cash flow

Now let’s try to price an asset like a share, which delivers a cash flow \(D_t, D_{t+1}, \ldots\).

We will call these payoffs “dividends”.

If we buy the share, hold it for one period and sell it again, we receive one dividend and our payoff is \(D_{t+1} + P_{t+1}\).

Therefore, by (9.2), the price should be

(9.3)#\[ P_t = \mathbb E_t M_{t+1} [ D_{t+1} + P_{t+1} ] \]

Because prices generally grow over time, which complicates analysis, it will be easier for us to solve for the price-dividend ratio \(V_t := P_t / D_t\).

Let’s write down an expression that this ratio should satisfy.

We can divide both sides of (9.3) by \(D_t\) to get

(9.4)#\[V_t = {\mathbb E}_t \left[ M_{t+1} \frac{D_{t+1}}{D_t} (1 + V_{t+1}) \right]\]

We can also write this as

(9.5)#\[V_t = {\mathbb E}_t \left[ M_{t+1} \exp(G^d_{t+1}) (1 + V_{t+1}) \right]\]

where

\[ G^d_{t+1} = \ln \frac{D_{t+1}}{D_t} \]

is the growth rate of dividends.

Our aim is to solve (9.5) but before that we need to specify

  1. the stochastic discount factor \(M_{t+1}\) and

  2. the growth rate of dividends \(G^d_{t+1}\)

9.4. Choosing the stochastic discount factor

We will adopt the stochastic discount factor described in Asset Pricing: The Lucas Asset Pricing Model, which has the form

(9.6)#\[ M_{t+1} = \beta \frac{u'(C_{t+1})}{u'(C_t)}\]

where \(u\) is a utility function and \(C_t\) is time \(t\) consumption of a representative consumer.

For utility, we’ll assume the constant relative risk aversion (CRRA) specification

(9.7)#\[ u(c) = \frac{c^{1-\gamma}}{1 - \gamma}\]

Inserting the CRRA specification into (9.6) and letting

\[ G^c_{t+1} = \ln \left( \frac{C_{t+1}}{C_t} \right) \]

be the growth rate of consumption, we obtain

(9.8)#\[ M_{t+1} = \beta \left(\frac{C_{t+1}}{C_t}\right)^{-\gamma} = \beta \exp( G^c_{t+1} )^{-\gamma} = \beta \exp(-\gamma G^c_{t+1})\]
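
As a quick numerical check of (9.8), the sketch below evaluates the stochastic discount factor both from the definition (9.6) and from the closed form. The consumption levels and parameter values are hypothetical, and the helper name crra_marginal_utility is introduced only for this example.

```python
import math

def crra_marginal_utility(c, γ):
    "Marginal utility u'(c) = c^(-γ) for the CRRA utility in (9.7)."
    return c ** (-γ)

β, γ = 0.98, 2.5          # illustrative parameter values
C_t, C_next = 1.0, 1.03   # hypothetical consumption levels

# Direct definition (9.6)
M_direct = β * crra_marginal_utility(C_next, γ) / crra_marginal_utility(C_t, γ)

# Closed form (9.8) in terms of consumption growth
G_c = math.log(C_next / C_t)
M_closed = β * math.exp(-γ * G_c)

print(M_direct, M_closed)
```

With positive consumption growth, both expressions give a value below \(\beta\).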

9.5. Solving for the price-dividend ratio

Substituting (9.8) into (9.5) gives the price-dividend ratio formula

(9.9)#\[ V_t = \beta {\mathbb E}_t \left[ \exp(G^d_{t+1} - \gamma G^c_{t+1}) (1 + V_{t+1}) \right] \]

We assume there is a Markov chain \(\{X_t\}\), which we call the state process, such that

\[\begin{split} \begin{aligned} & G^c_{t+1} = \mu_c + X_t + \sigma_c \epsilon_{c, t+1} \\ & G^d_{t+1} = \mu_d + X_t + \sigma_d \epsilon_{d, t+1} \end{aligned} \end{split}\]

Here \(\{\epsilon_{c, t}\}\) and \(\{\epsilon_{d, t}\}\) are IID and standard normal, and independent of each other.

We can think of \(\{X_t\}\) as an aggregate shock that affects both consumption growth and firm profits (and hence dividends).

We let \(P\) be the stochastic matrix that governs \(\{X_t\}\) and assume \(\{X_t\}\) takes values in some finite set \(S\).

We guess that \(V_t\) is a fixed function of this state process (and this guess turns out to be correct).

This means that \(V_t = v(X_t)\) for some unknown function \(v\).

By (9.9), the unknown function \(v\) satisfies the equation

(9.10)#\[ v(X_t) = \beta {\mathbb E}_t \left\{ \exp[ a + (1-\gamma) X_t + \sigma_d \epsilon_{d, t+1} - \gamma \sigma_c \epsilon_{c, t+1} ] (1 + v(X_{t+1})) \right\} \]

where \(a := \mu_d - \gamma \mu_c\).

Since the shocks \(\epsilon_{c, t+1}\) and \(\epsilon_{d, t+1}\) are independent of \(\{X_t\}\), we can integrate them out.

We use the following property of lognormal distributions: if \(Y = \exp(c \epsilon)\) for constant \(c\) and \(\epsilon \sim N(0,1)\), then \(\mathbb E Y = \exp(c^2/2)\).

This yields

(9.11)#\[ v(X_t) = \beta {\mathbb E}_t \left\{ \exp \left[ a + (1-\gamma) X_t + \frac{\sigma_d^2 + \gamma^2 \sigma_c^2}{2} \right] (1 + v(X_{t+1})) \right\} \]
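
The lognormal step used to pass from (9.10) to (9.11) can be checked numerically. The sketch below approximates the two expectations by Gauss-Hermite quadrature; the helper normal_expectation and the parameter values are assumptions made for this illustration.

```python
import numpy as np

def normal_expectation(f, n_nodes=40):
    "Approximate E[f(ε)] for ε ~ N(0, 1) by Gauss-Hermite quadrature."
    # Nodes and weights for the weight function exp(-x^2 / 2)
    nodes, weights = np.polynomial.hermite_e.hermegauss(n_nodes)
    return np.sum(weights * f(nodes)) / np.sqrt(2 * np.pi)

γ, σ_c, σ_d = 2.5, 0.02, 0.04    # illustrative parameter values

# E exp(σ_d ε_d) * E exp(-γ σ_c ε_c), using independence of the two shocks
lhs = (normal_expectation(lambda e: np.exp(σ_d * e))
       * normal_expectation(lambda e: np.exp(-γ * σ_c * e)))

# The term that appears inside the exponent in (9.11)
rhs = np.exp((σ_d**2 + γ**2 * σ_c**2) / 2)

print(lhs, rhs)
```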

Conditioning on \(X_t = x\), we can write this as

(9.12)#\[ v(x) = \beta \sum_{y \in S} \left\{ \exp \left[ a + (1-\gamma) x + \frac{\sigma_d^2 + \gamma^2 \sigma_c^2}{2} \right] (1 + v(y)) \right\} P(x, y) \]

for all \(x \in S\).

Suppose \(S = \{x_1, \ldots, x_N\}\).

Then we can think of \(v\) as an \(N\)-vector and, using square brackets for indices on arrays, write

(9.13)#\[ v[i] = \beta \sum_{j=1}^N \left\{ \exp \left[ a + (1-\gamma) x[i] + \frac{\sigma_d^2 + \gamma^2 \sigma_c^2}{2} \right] (1 + v[j]) \right\} P[i, j] \]

for \(i = 1, \ldots, N\).

Equivalently, we can write

(9.14)#\[ v[i] = \sum_{j=1}^N K[i, j] (1 + v[j]) \]

where \(K\) is the matrix defined by

(9.15)#\[ K[i, j] = \beta \left\{ \exp \left[ a + (1-\gamma) x[i] + \frac{\sigma_d^2 + \gamma^2 \sigma_c^2}{2} \right] \right\} P[i,j] \]

Rewriting (9.14) in vector form yields

(9.16)#\[ v = K (\mathbb 1 + v) \]

Notice that (9.16) can be written as \((I - K)v = K \mathbb 1\).

The Neumann series lemma tells us that \(I - K\) is invertible and the solution is

(9.17)#\[ v = (I - K)^{-1} K \mathbb 1 \]

whenever \(r(K)\), the spectral radius of \(K\), is strictly less than one.

Once we specify \(P\) and all the parameters, we can

  1. obtain \(K\)

  2. check the spectral radius condition \(r(K) < 1\) and, assuming it holds,

  3. compute the solution via (9.17).
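
The three steps above can be sketched on a tiny example. The matrix below is a hypothetical stand-in for \(K\), chosen only so that the numbers are easy to follow; the truncated Neumann series is included to show that it agrees with the direct solve when \(r(K) < 1\).

```python
import numpy as np

# A small hypothetical K: entrywise nonnegative with row sums below one
K = np.array([[0.6, 0.3],
              [0.2, 0.7]])
ones = np.ones(2)

# Spectral radius from the eigenvalues (fine for small matrices)
r = np.max(np.abs(np.linalg.eigvals(K)))

# Direct solution of (I - K) v = K 1, as in (9.17)
v_direct = np.linalg.solve(np.identity(2) - K, K @ ones)

# Truncated Neumann series: v = sum over n >= 1 of K^n 1
v_series = np.zeros(2)
term = ones
for _ in range(2000):
    term = K @ term
    v_series += term

print(r, v_direct, v_series)
```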

9.6. Code

We will use the power iteration algorithm to check the spectral radius condition.

The function below estimates the spectral radius of A.

def power_iteration_sr(A, num_iterations=15, seed=1234):
    " Estimates the spectral radius of A via power iteration. "

    # Initialize
    key = jax.random.PRNGKey(seed)
    b_k = jax.random.normal(key, (A.shape[1],))
    sr = 0

    for _ in range(num_iterations):
        # calculate the matrix-by-vector product Ab
        b_k1 = jnp.dot(A, b_k)

        # calculate the norm
        b_k1_norm = jnp.linalg.norm(b_k1)

        # Record the current estimate of the spectral radius
        sr = jnp.sum(b_k1 * b_k)/jnp.sum(b_k * b_k)

        # re-normalize the vector and continue
        b_k = b_k1 / b_k1_norm

    return sr

power_iteration_sr = jax.jit(power_iteration_sr)

The next function verifies that the spectral radius of a given matrix is \(< 1\).

def test_stability(Q):
    """
    Assert that the spectral radius of matrix Q is < 1.
    """
    sr = power_iteration_sr(Q)
    assert sr < 1, f"Spectral radius condition failed with radius = {sr}"
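
To get a feel for the accuracy of this estimate, the sketch below runs a plain NumPy analogue of the routine on a small matrix whose eigenvalues are known. The function power_iteration_np and the test matrix are assumptions made for this comparison, not part of the lecture's code.

```python
import numpy as np

def power_iteration_np(A, num_iterations=15, seed=1234):
    "NumPy analogue of the routine above, for comparison purposes only."
    rng = np.random.default_rng(seed)
    b = rng.normal(size=A.shape[1])
    sr = 0.0
    for _ in range(num_iterations):
        b_next = A @ b
        sr = (b_next @ b) / (b @ b)      # current estimate of the radius
        b = b_next / np.linalg.norm(b_next)
    return sr

# Hypothetical test matrix with eigenvalues 0.9 and 0.2
A = np.array([[0.5, 0.4],
              [0.3, 0.6]])

exact = np.max(np.abs(np.linalg.eigvals(A)))
estimate = power_iteration_np(A)
print(exact, estimate)
```

The estimate converges quickly here because the two eigenvalues are well separated in modulus; when they are close, more iterations may be needed.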

In what follows we assume that \(\{X_t\}\), the state process, is a discretization of the AR(1) process

\[ X_{t+1} = \rho X_t + \sigma \eta_{t+1} \]

where \(\rho, \sigma\) are parameters and \(\{\eta_t\}\) is IID and standard normal.

To discretize this process we use QuantEcon.py’s tauchen function.
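
For readers who want to see what such a discretization involves, here is a simplified sketch of the Tauchen method. It is not QuantEcon.py's implementation; the names tauchen_sketch, norm_cdf and n_std are hypothetical.

```python
import math
import numpy as np

def norm_cdf(x):
    "Standard normal CDF."
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))

def tauchen_sketch(n, ρ, σ, n_std=3):
    """
    Simplified Tauchen discretization of X' = ρ X + σ η, η ~ N(0, 1).
    Returns an evenly spaced grid and a stochastic matrix.
    """
    # Grid spans n_std stationary standard deviations on each side of zero
    σ_x = σ / math.sqrt(1 - ρ**2)
    x = np.linspace(-n_std * σ_x, n_std * σ_x, n)
    step = x[1] - x[0]
    P = np.empty((n, n))
    for i in range(n):
        for j in range(n):
            upper = norm_cdf((x[j] - ρ * x[i] + step / 2) / σ)
            lower = norm_cdf((x[j] - ρ * x[i] - step / 2) / σ)
            if j == 0:
                P[i, j] = upper            # all mass below the first midpoint
            elif j == n - 1:
                P[i, j] = 1.0 - lower      # all mass above the last midpoint
            else:
                P[i, j] = upper - lower
    return x, P

x, P = tauchen_sketch(5, 0.9, 0.01)
print(P.sum(axis=1))
```

Each row of P is a probability distribution over next-period grid points, so P is a stochastic matrix.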

Below we write a function called create_model() that returns a namedtuple storing the relevant parameters and arrays.

Model = namedtuple('Model',
                   ('P', 'S', 'β', 'γ', 'μ_c', 'μ_d', 'σ_c', 'σ_d'))

def create_model(N=100,         # size of state space for Markov chain
                 ρ=0.9,         # persistence parameter for Markov chain
                 σ=0.01,        # volatility parameter for Markov chain
                 β=0.98,        # discount factor
                 γ=2.5,         # coefficient of risk aversion
                 μ_c=0.01,      # mean growth of consumption
                 μ_d=0.01,      # mean growth of dividends
                 σ_c=0.02,      # consumption volatility
                 σ_d=0.04):     # dividend volatility
    # Create the state process
    mc = qe.tauchen(N, ρ, σ)
    S = mc.state_values
    P = mc.P
    # Shift arrays to the device
    S, P = map(jax.device_put, (S, P))
    # Return the namedtuple
    return Model(P=P, S=S, β=β, γ=γ, μ_c=μ_c, μ_d=μ_d, σ_c=σ_c, σ_d=σ_d)

Our first step is to construct the matrix \(K\) defined in (9.15).

Here’s a function that does this using loops.

def compute_K_loop(model):
    # unpack
    P, S, β, γ, μ_c, μ_d, σ_c, σ_d = model
    N = len(S)
    K = np.empty((N, N))
    a = μ_d - γ * μ_c
    for i, x in enumerate(S):
        for j, y in enumerate(S):
            e = np.exp(a + (1 - γ) * x + (σ_d**2 + γ**2 * σ_c**2) / 2)
            K[i, j] = β * e * P[i, j]
    return K

To exploit the parallelization capabilities of JAX, let’s also write a vectorized (i.e., loop-free) implementation.

def compute_K(model):
    # unpack
    P, S, β, γ, μ_c, μ_d, σ_c, σ_d = model
    N = len(S)
    # Reshape and multiply pointwise using broadcasting
    x = np.reshape(S, (N, 1))
    a = μ_d - γ * μ_c
    e = np.exp(a + (1 - γ) * x + (σ_d**2 + γ**2 * σ_c**2) / 2)
    K = β * e * P
    return K

These two functions produce the same output:

model = create_model(N=10)
K1 = compute_K(model)
K2 = compute_K_loop(model)
np.allclose(K1, K2)
True

Now we can compute the price-dividend ratio:

def price_dividend_ratio(model, test_stable=True):
    """
    Computes the price-dividend ratio of the asset.

    Parameters
    ----------
    model: an instance of Model
        contains primitives

    Returns
    -------
    v : array_like
        price-dividend ratio

    """
    K = compute_K(model)
    N = len(model.S)

    if test_stable:
        test_stability(K)

    # Compute v
    I = np.identity(N)
    ones_vec = np.ones(N)
    v = np.linalg.solve(I - K, K @ ones_vec)

    return v

Here’s a plot of \(v\) as a function of the state for several values of \(\gamma\).

model = create_model()
S = model.S
γs = np.linspace(2.0, 3.0, 5)

fig, ax = plt.subplots()

for γ in γs:
    model = create_model(γ=γ)
    v = price_dividend_ratio(model)
    ax.plot(S, v, lw=2, alpha=0.6, label=rf"$\gamma = {γ}$")

ax.set_ylabel("price-dividend ratio")
ax.set_xlabel("state")
ax.legend(loc='upper right')
plt.show()

Notice that \(v\) is decreasing in each case.

This is because, with a positively correlated state process, higher states indicate higher future consumption growth.

With the stochastic discount factor (9.8), higher growth decreases the discount factor, lowering the weight placed on future dividends.

9.7. An Extended Example

One problem with the last setup is that volatility is constant through time (i.e., \(\sigma_c\) and \(\sigma_d\) are constants).

In reality, financial markets and growth rates of macroeconomic variables exhibit bursts of volatility.

To accommodate this, we now develop a stochastic volatility model.

To begin, suppose that consumption and dividends grow as follows.

\[ G^i_{t+1} = \mu_i + Z_t + \bar \sigma \exp(H^i_t) \epsilon_{i, t+1}, \qquad i \in \{c, d\} \]

where \(\{Z_t\}\) is a finite Markov chain and \(\{H^c_t\}\) and \(\{H^d_t\}\) are volatility processes.

We assume that \(\{H^c_t\}\) and \(\{H^d_t\}\) are AR(1) processes of the form

\[ H^i_{t+1} = \rho_i H^i_t + \sigma_i \eta_{i, t+1}, \qquad i \in \{c, d\} \]

Here \(\{\eta_{c, t}\}\) and \(\{\eta_{d, t}\}\) are IID and standard normal.

Let \(X_t = (H^c_t, H^d_t, Z_t)\).

We call \(\{X_t\}\) the state process and guess that \(V_t\) is a function of this state process, so that \(V_t = v(X_t)\) for some unknown function \(v\).

Modifying (9.10) to accommodate the new growth specifications, we find that \(v\) satisfies

(9.18)#\[\begin{split} \begin{aligned} v(X_t) = & \beta \times \\ & {\mathbb E}_t \left\{ \exp[ a + (1-\gamma) Z_t + \bar \sigma \exp(H^d_t) \epsilon_{d, t+1} - \gamma \bar \sigma \exp(H^c_t) \epsilon_{c, t+1} ] (1 + v(X_{t+1})) \right\} \end{aligned} \end{split}\]

where, as before, \(a := \mu_d - \gamma \mu_c\).

Conditioning on state \(x = (h_c, h_d, z)\), this becomes

(9.19)#\[ v(x) = \beta {\mathbb E}_t \exp[ a + (1-\gamma) z + \bar \sigma \exp(h_d) \epsilon_{d, t+1} - \gamma \bar \sigma \exp(h_c) \epsilon_{c, t+1} ] (1 + v(X_{t+1}))\]

As before, we integrate out the independent shocks and use the rules for expectations of lognormals to obtain

(9.20)#\[ v(x) = \beta {\mathbb E}_t \exp \left[ a + (1-\gamma) z + \bar \sigma^2 \frac{\exp(2 h_d) + \gamma^2 \exp(2 h_c)}{2} \right] (1 + v(X_{t+1}))\]

Let

\[\begin{split} \begin{aligned} A(h_c, h_d, z, h_c', h_d', z') := & \\ & \beta \, \exp \left[ a + (1-\gamma) z + \bar \sigma^2 \frac{\exp(2 h_d) + \gamma^2 \exp(2 h_c)}{2} \right] P(h_c, h_c')Q(h_d, h_d')R(z, z') \end{aligned} \end{split}\]

where \(P, Q, R\) are the stochastic matrices for, respectively, discretized \(\{H^c_t\}\), discretized \(\{H^d_t\}\) and \(\{Z_t\}\).

With this notation, we can write (9.20) more explicitly as

(9.21)#\[ v(h_c, h_d, z) = \sum_{h_c', h_d', z'} (1 + v(h_c', h_d', z')) A(h_c, h_d, z, h_c', h_d', z')\]

Let’s now write the state using indices, with \((i, j, k)\) being the indices for \((h_c, h_d, z)\).

Then (9.21) becomes

(9.22)#\[ v[i, j, k] = \sum_{i', j', k'} A[i, j, k, i', j', k'] (1 + v[i', j', k']) \]

One way to understand this is to reshape \(v\) into an \(N\)-vector, where \(N = I \times J \times K\) and \(I, J, K\) are the numbers of grid points for \(h_c, h_d, z\) respectively, and \(A\) into an \(N \times N\) matrix.

Then we can write (9.22) as

\[ v = A(\mathbb 1 + v) \]

Provided that the spectral radius condition \(r(A) < 1\) holds, the solution is given by

\[ v = (I - A)^{-1} A \mathbb 1 \]

9.8. NumPy Version

Our first implementation will be in NumPy.

Once we have a NumPy version working, we will convert it to JAX and check the difference in the run times.

The code block below provides a function called create_sv_model() that returns a namedtuple containing arrays and other data that form the primitives of the problem.

It assumes that \(\{Z_t\}\) is a discretization of

\[ Z_{t+1} = \rho_z Z_t + \sigma_z \xi_{t+1} \]

where \(\{\xi_t\}\) is IID and standard normal.

SVModel = namedtuple('SVModel',
                        ('P', 'hc_grid',
                         'Q', 'hd_grid',
                         'R', 'z_grid',
                         'β', 'γ', 'bar_σ', 'μ_c', 'μ_d'))

def create_sv_model(β=0.98,        # discount factor
                    γ=2.5,         # coefficient of risk aversion
                    I=14,          # size of state space for h_c
                    ρ_c=0.9,       # persistence parameter for h_c
                    σ_c=0.01,      # volatility parameter for h_c
                    J=14,          # size of state space for h_d
                    ρ_d=0.9,       # persistence parameter for h_d
                    σ_d=0.01,      # volatility parameter for h_d
                    K=14,          # size of state space for z
                    bar_σ=0.01,    # volatility scaling parameter
                    ρ_z=0.9,       # persistence parameter for z
                    σ_z=0.01,      # volatility parameter for z
                    μ_c=0.001,     # mean growth of consumption
                    μ_d=0.005):    # mean growth of dividends

    mc = qe.tauchen(I, ρ_c, σ_c)
    hc_grid = mc.state_values
    P = mc.P
    mc = qe.tauchen(J, ρ_d, σ_d)
    hd_grid = mc.state_values
    Q = mc.P
    mc = qe.tauchen(K, ρ_z, σ_z)
    z_grid = mc.state_values
    R = mc.P

    return SVModel(P=P, hc_grid=hc_grid,
                   Q=Q, hd_grid=hd_grid,
                   R=R, z_grid=z_grid,
                   β=β, γ=γ, bar_σ=bar_σ, μ_c=μ_c, μ_d=μ_d)

Now we provide a function to compute the matrix \(A\).

def compute_A(sv_model):
    # Set up
    P, hc_grid, Q, hd_grid, R, z_grid, β, γ, bar_σ, μ_c, μ_d = sv_model
    I, J, K = len(hc_grid), len(hd_grid), len(z_grid)
    N = I * J * K
    # Reshape and broadcast over (i, j, k, i', j', k')
    hc = np.reshape(hc_grid,     (I, 1, 1, 1,  1,  1))
    hd = np.reshape(hd_grid,     (1, J, 1, 1,  1,  1))
    z = np.reshape(z_grid,       (1, 1, K, 1,  1,  1))
    P = np.reshape(P,            (I, 1, 1, I,  1,  1))
    Q = np.reshape(Q,            (1, J, 1, 1,  J,  1))
    R = np.reshape(R,            (1, 1, K, 1,  1,  K))
    # Compute A and then reshape to create a matrix
    a = μ_d - γ * μ_c
    b = bar_σ**2 * (np.exp(2 * hd) + γ**2 * np.exp(2 * hc)) / 2
    κ = np.exp(a + (1 - γ) * z + b)
    A = β * κ * P * Q * R
    A = np.reshape(A, (N, N))
    return A
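
One way to sanity-check the broadcast-and-reshape pattern used here is to compare it with a Kronecker product. The sketch below does this for the \(P, Q, R\) part only (without the factor \(\beta \kappa\)), using small random stochastic matrices; the helper random_stochastic_matrix and the grid sizes are assumptions made for the illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

def random_stochastic_matrix(n):
    "Hypothetical helper: random matrix with nonnegative rows summing to one."
    M = rng.random((n, n))
    return M / M.sum(axis=1, keepdims=True)

I, J, K = 2, 3, 4
P, Q, R = (random_stochastic_matrix(n) for n in (I, J, K))

# Broadcast over (i, j, k, i', j', k') and flatten, as in compute_A
P6 = np.reshape(P, (I, 1, 1, I, 1, 1))
Q6 = np.reshape(Q, (1, J, 1, 1, J, 1))
R6 = np.reshape(R, (1, 1, K, 1, 1, K))
flat = np.reshape(P6 * Q6 * R6, (I * J * K, I * J * K))

# The same matrix written as a Kronecker product
kron = np.kron(np.kron(P, Q), R)

print(np.allclose(flat, kron))
```

The flattened array is itself a stochastic matrix on the product state space, which is why the reshaped \(A\) can be treated like the matrix \(K\) from the earlier model.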

Here’s our function to compute the price-dividend ratio for the stochastic volatility model.

def sv_pd_ratio(sv_model, test_stable=True):
    """
    Computes the price-dividend ratio of the asset for the stochastic volatility
    model.

    Parameters
    ----------
    sv_model: an instance of SVModel
        contains primitives

    Returns
    -------
    v : array_like
        price-dividend ratio

    """
    # unpack
    P, hc_grid, Q, hd_grid, R, z_grid, β, γ, bar_σ, μ_c, μ_d = sv_model
    I, J, K = len(hc_grid), len(hd_grid), len(z_grid)
    N = I * J * K

    A = compute_A(sv_model)
    # Make sure that a unique solution exists
    if test_stable:
        test_stability(A)

    # Compute v
    ones_array = np.ones(N)
    Id = np.identity(N)
    v = scipy.linalg.solve(Id - A, A @ ones_array)
    # Reshape into an array of the form v[i, j, k]
    v = np.reshape(v, (I, J, K))
    return v

Let’s create an instance of the model and solve it.

sv_model = create_sv_model()
P, hc_grid, Q, hd_grid, R, z_grid, β, γ, bar_σ, μ_c, μ_d = sv_model

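Let’s run it once. This first call includes the time needed to JIT-compile power_iteration_sr, which is invoked by the stability test.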

start = time()
v = sv_pd_ratio(sv_model)
numpy_with_compile = time() - start
print("Numpy compile plus execution time = ", numpy_with_compile)
Numpy compile plus execution time =  0.7299137115478516

Let’s run it again to exclude compile time.

start = time()
v = sv_pd_ratio(sv_model)
numpy_without_compile = time() - start
print("Numpy execution time = ", numpy_without_compile)
Numpy execution time =  0.2265474796295166

Here are some plots of the solution \(v\) along the three dimensions.

fig, ax = plt.subplots()
ax.plot(hc_grid, v[:, 0, 0], lw=2, alpha=0.6, label="$v$ as a function of $h^c$")
ax.set_ylabel("price-dividend ratio")
ax.set_xlabel("state")
ax.legend()
plt.show()
fig, ax = plt.subplots()
ax.plot(hd_grid, v[0, :, 0], lw=2, alpha=0.6, label="$v$ as a function of $h^d$")
ax.set_ylabel("price-dividend ratio")
ax.set_xlabel("state")
ax.legend()
plt.show()
fig, ax = plt.subplots()
ax.plot(z_grid, v[0, 0, :], lw=2, alpha=0.6, label="$v$ as a function of $Z$")
ax.set_ylabel("price-dividend ratio")
ax.set_xlabel("state")
ax.legend()
plt.show()

9.9. JAX Version

Now let’s write a JAX version that is a simple transformation of the NumPy version.

(Below we will write a more efficient version using JAX’s ability to work with linear operators.)

def create_sv_model_jax(sv_model):

    # Take the contents of a NumPy sv_model instance
    P, hc_grid, Q, hd_grid, R, z_grid, β, γ, bar_σ, μ_c, μ_d = sv_model

    # Shift the arrays to the device (GPU if available)
    hc_grid, hd_grid, z_grid = map(jax.device_put, (hc_grid, hd_grid, z_grid))
    P, Q, R = map(jax.device_put, (P, Q, R))

    # Create a new instance and return it
    return SVModel(P=P, hc_grid=hc_grid,
                   Q=Q, hd_grid=hd_grid,
                   R=R, z_grid=z_grid,
                   β=β, γ=γ, bar_σ=bar_σ, μ_c=μ_c, μ_d=μ_d)

Here’s a function to compute \(A\).

We include the extra argument shapes to help the compiler understand the size of the arrays.

This is important when we JIT-compile the function below.

def compute_A_jax(sv_model, shapes):
    # Set up
    P, hc_grid, Q, hd_grid, R, z_grid, β, γ, bar_σ, μ_c, μ_d = sv_model
    I, J, K = shapes
    N = I * J * K
    # Reshape and broadcast over (i, j, k, i', j', k')
    hc = jnp.reshape(hc_grid,     (I, 1, 1, 1,  1,  1))
    hd = jnp.reshape(hd_grid,     (1, J, 1, 1,  1,  1))
    z = jnp.reshape(z_grid,       (1, 1, K, 1,  1,  1))
    P = jnp.reshape(P,            (I, 1, 1, I,  1,  1))
    Q = jnp.reshape(Q,            (1, J, 1, 1,  J,  1))
    R = jnp.reshape(R,            (1, 1, K, 1,  1,  K))
    # Compute A and then reshape to create a matrix
    a = μ_d - γ * μ_c
    b = bar_σ**2 * (jnp.exp(2 * hd) + γ**2 * jnp.exp(2 * hc)) / 2
    κ = jnp.exp(a + (1 - γ) * z + b)
    A = β * κ * P * Q * R
    A = jnp.reshape(A, (N, N))
    return A

Here’s the function that computes the solution.

def sv_pd_ratio_jax(sv_model, shapes):
    """
    Computes the price-dividend ratio of the asset for the stochastic volatility
    model.

    Parameters
    ----------
    sv_model: an instance of SVModel
        contains primitives
    shapes: tuple of int
        grid sizes (I, J, K), passed as a static argument

    Returns
    -------
    v : array_like
        price-dividend ratio

    """
    # unpack
    P, hc_grid, Q, hd_grid, R, z_grid, β, γ, bar_σ, μ_c, μ_d = sv_model
    I, J, K = shapes
    N = I * J * K

    A = compute_A_jax(sv_model, shapes)

    # Compute v, reshape and return
    ones_array = jnp.ones(N)
    Id = jnp.identity(N)
    v = jax.scipy.linalg.solve(Id - A, A @ ones_array)
    return jnp.reshape(v, (I, J, K))

Now let’s target these functions for JIT-compilation, while using static_argnums to indicate that the function will need to be recompiled when shapes changes.

compute_A_jax = jax.jit(compute_A_jax, static_argnums=(1,))
sv_pd_ratio_jax = jax.jit(sv_pd_ratio_jax, static_argnums=(1,))
sv_model = create_sv_model()
sv_model_jax = create_sv_model_jax(sv_model)
P, hc_grid, Q, hd_grid, R, z_grid, β, γ, bar_σ, μ_c, μ_d = sv_model_jax
shapes = len(hc_grid), len(hd_grid), len(z_grid)
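
The role of static_argnums can be seen in a smaller setting. In the sketch below, the function scaled_ones is a hypothetical example: its second argument fixes the size of an array, so it is marked as static, and calling the compiled function with a new shape triggers a fresh compilation.

```python
import jax
import jax.numpy as jnp

def scaled_ones(x, shape):
    "Return an array of the given shape filled with the value x."
    return x * jnp.ones(shape)

# `shape` determines array sizes, so it must be known at compile time
scaled_ones_jit = jax.jit(scaled_ones, static_argnums=(1,))

a = scaled_ones_jit(2.0, (2, 3))   # compiles for shape (2, 3)
b = scaled_ones_jit(5.0, (2, 3))   # same shape, so no recompilation is needed
c = scaled_ones_jit(2.0, (4,))     # new shape, so a new compilation
print(a.shape, c.shape)
```

Static arguments must be hashable, which is why the shape is passed as a tuple.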

Let’s see how long it takes to run with compile time included.

start = time()
v_jax = sv_pd_ratio_jax(sv_model_jax, shapes).block_until_ready()
jnp_with_compile = time() - start
print("JAX compile plus execution time = ", jnp_with_compile)
JAX compile plus execution time =  0.4375894069671631

And now let’s see without compile time.

start = time()
v_jax = sv_pd_ratio_jax(sv_model_jax, shapes).block_until_ready()
jnp_without_compile = time() - start
print("JAX execution time = ", jnp_without_compile)
JAX execution time =  0.09215307235717773

Here’s the ratio of times:

jnp_without_compile / numpy_without_compile
0.40677156288774363

Let’s check that the NumPy and JAX versions realize the same solution.

v = jax.device_put(v)

print(jnp.allclose(v, v_jax))
True

9.10. A memory-efficient JAX version

One problem with the code above is that we instantiate a matrix of size \(N \times N\), where \(N = I \times J \times K\).

This quickly becomes impossible as \(I, J, K\) increase.

Fortunately, JAX makes it possible to solve for the price-dividend ratio without instantiating this large matrix.

The first step is to think of \(A\) not as a matrix, but rather as the linear operator that transforms \(g\) into \(Ag\).

def A(g, sv_model, shapes):
    # Set up
    P, hc_grid, Q, hd_grid, R, z_grid, β, γ, bar_σ, μ_c, μ_d = sv_model
    I, J, K = shapes
    # Reshape and broadcast over (i, j, k, i', j', k')
    hc = jnp.reshape(hc_grid,     (I, 1, 1, 1,  1,  1))
    hd = jnp.reshape(hd_grid,     (1, J, 1, 1,  1,  1))
    z = jnp.reshape(z_grid,       (1, 1, K, 1,  1,  1))
    P = jnp.reshape(P,            (I, 1, 1, I,  1,  1))
    Q = jnp.reshape(Q,            (1, J, 1, 1,  J,  1))
    R = jnp.reshape(R,            (1, 1, K, 1,  1,  K))
    g = jnp.reshape(g,            (1, 1, 1, I,  J,  K))
    a = μ_d - γ * μ_c
    b = bar_σ**2 * (jnp.exp(2 * hd) + γ**2 * jnp.exp(2 * hc)) / 2
    κ = jnp.exp(a + (1 - γ) * z + b)
    A = β * κ * P * Q * R
    Ag = jnp.sum(A * g, axis=(3, 4, 5))
    return Ag
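
As an aside, the product structure of \(P, Q, R\) also allows the sums to be carried out one axis at a time. The NumPy sketch below, which is not the approach used in this lecture, compares such a factored application with the dense computation on a small random example; the array κ is a stand-in for the state-dependent factor and all names are assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
I, J, K = 3, 4, 5

def random_stochastic_matrix(n):
    "Hypothetical helper: random matrix with nonnegative rows summing to one."
    M = rng.random((n, n))
    return M / M.sum(axis=1, keepdims=True)

P, Q, R = (random_stochastic_matrix(n) for n in (I, J, K))
κ = rng.random((I, J, K))     # stand-in for β * exp(...) at each state
g = rng.random((I, J, K))

# Dense reference: build the full (I, J, K, I, J, K) array and sum
A_full = (κ[:, :, :, None, None, None]
          * P[:, None, None, :, None, None]
          * Q[None, :, None, None, :, None]
          * R[None, None, :, None, None, :])
Ag_dense = np.sum(A_full * g[None, None, None, :, :, :], axis=(3, 4, 5))

# Factored version: contract one axis at a time, never forming A_full
tmp = np.einsum('kc,abc->abk', R, g)     # sum over k'
tmp = np.einsum('jb,abk->ajk', Q, tmp)   # sum over j'
tmp = np.einsum('ia,ajk->ijk', P, tmp)   # sum over i'
Ag_factored = κ * tmp

print(np.allclose(Ag_dense, Ag_factored))
```

The factored version only ever stores arrays of size \(I \times J \times K\), at the cost of three separate contractions.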

Now we write a version of the solution function for the price-dividend ratio that acts directly on the linear operator A.

def sv_pd_ratio_linop(sv_model, shapes):
    P, hc_grid, Q, hd_grid, R, z_grid, β, γ, bar_σ, μ_c, μ_d = sv_model
    I, J, K = shapes

    ones_array = jnp.ones((I, J, K))
    # Set up the operator g -> (I - A) g
    I_minus_A = lambda g: g - A(g, sv_model, shapes)
    # Solve v = (I - A)^{-1} A 1
    A1 = A(ones_array, sv_model, shapes)
    # Apply an iterative solver that works for linear operators
    v = jax.scipy.sparse.linalg.bicgstab(I_minus_A, A1)[0]
    return v
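
The same pattern can be tried on a problem small enough to solve directly. In the sketch below, the matrix B and the function I_minus_B are hypothetical; the point is that bicgstab only needs the map \(g \mapsto (I - B) g\), and its answer can be compared with a dense solve.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import bicgstab

jax.config.update("jax_enable_x64", True)

# Hypothetical small problem: B is nonnegative with row sums below one,
# so its spectral radius is below one
B = jnp.array([[0.5, 0.3],
               [0.1, 0.6]])
b = B @ jnp.ones(2)

# The solver only needs the map g -> (I - B) g, not the matrix itself
def I_minus_B(g):
    return g - B @ g

v_iter, _ = bicgstab(I_minus_B, b)
v_dense = jnp.linalg.solve(jnp.identity(2) - B, b)
print(v_iter, v_dense)
```

Because the solver is iterative, agreement with the dense solution holds only up to the solver's tolerance.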

Let’s target these functions for JIT compilation.

A = jax.jit(A, static_argnums=(2,))
sv_pd_ratio_linop = jax.jit(sv_pd_ratio_linop, static_argnums=(1,))

Let’s time the solution with compile time included.

start = time()
v_jax_linop = sv_pd_ratio_linop(sv_model, shapes).block_until_ready()
jnp_linop_with_compile = time() - start
print("JAX compile plus execution time = ", jnp_linop_with_compile)
JAX compile plus execution time =  0.6296956539154053

And now let’s see without compile time.

start = time()
v_jax_linop = sv_pd_ratio_linop(sv_model, shapes).block_until_ready()
jnp_linop_without_compile = time() - start
print("JAX execution time = ", jnp_linop_without_compile)
JAX execution time =  0.0056879520416259766

Let’s verify the solution again:

print(jnp.allclose(v, v_jax_linop))
True

Here’s the ratio of times between the memory-efficient and direct versions:

jnp_linop_without_compile / jnp_without_compile
0.061722869310096816

In this run the memory-efficient version takes about 6% of the time of the direct version (roughly 16 times faster) and, moreover, we can now work with much larger grids.

Here’s a moderately large example, where the state space has 15,625 elements.

sv_model = create_sv_model(I=25, J=25, K=25)
sv_model_jax = create_sv_model_jax(sv_model)
P, hc_grid, Q, hd_grid, R, z_grid, β, γ, bar_σ, μ_c, μ_d = sv_model_jax
shapes = len(hc_grid), len(hd_grid), len(z_grid)

%time _ = sv_pd_ratio_linop(sv_model_jax, shapes).block_until_ready()
%time _ = sv_pd_ratio_linop(sv_model_jax, shapes).block_until_ready()
CPU times: user 493 ms, sys: 2.59 ms, total: 496 ms
Wall time: 692 ms
CPU times: user 178 ms, sys: 0 ns, total: 178 ms
Wall time: 178 ms

The solution is computed relatively quickly and without memory issues.

Readers will find that they can push these numbers further, although we refrain from doing so here.