# 10. Optimal Investment

GPU

This lecture was built using hardware that has access to a GPU.

To run this lecture on Google Colab, click on the “play” icon top right, select Colab, and set the runtime environment to include a GPU.

To run this lecture on your own machine, you need to install Google JAX.

We require the following library to be installed.

```
!pip install --upgrade quantecon
```


We study a monopolist who faces the inverse demand curve

\[
    P_t = a_0 - a_1 Y_t + Z_t,
\]

where

\(P_t\) is price,

\(Y_t\) is output and

\(Z_t\) is a demand shock.

We assume that \(Z_t\) is a discretized AR(1) process, specified below.

Current profits are

\[
    P_t Y_t - c Y_t - \gamma (Y_{t+1} - Y_t)^2.
\]

Combining with the demand curve and writing \(y, y'\) for \(Y_t, Y_{t+1}\), this becomes

\[
    r(y, z, y') := (a_0 - a_1 y + z - c) y - \gamma (y' - y)^2.
\]

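The firm maximizes the present value of expected discounted profits. The Bellman equation is

\[
    v(y, z) = \max_{y'} \left\{ r(y, z, y') + \beta \sum_{z'} v(y', z') Q(z, z') \right\},
\]

where \(\beta\) is the discount factor and \(Q(z, z')\) is the probability that the shock moves from \(z\) to \(z'\).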

We discretize \(y\) to a finite grid `y_grid`.

In essence, the firm tries to choose output close to the monopolist profit maximizer, given \(Z_t\), but is constrained by adjustment costs.
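To make this trade-off concrete, here is a small sketch that computes the static monopoly optimum (ignoring adjustment costs) and the cost of jumping straight to it. It assumes the profit function \((a_0 - a_1 y + z - c) y - \gamma (y' - y)^2\) and the default parameter values used later in this lecture. The helper names `static_optimum` and `adjustment_cost` are illustrative only.

```python
# Illustrative parameters (the defaults used later in this lecture)
a_0, a_1, c, γ = 10.0, 1.0, 1.0, 25.0

def static_optimum(z):
    """Output maximizing (a_0 - a_1*y + z - c)*y, ignoring adjustment costs."""
    return (a_0 + z - c) / (2 * a_1)

def adjustment_cost(y, yp):
    """Cost γ (y' - y)^2 of moving output from y to yp."""
    return γ * (yp - y)**2

y_star = static_optimum(z=0.0)
static_profit = (a_0 - a_1 * y_star - c) * y_star

print(y_star)                         # static optimum when z = 0
print(static_profit)                  # per-period profit at that output
print(adjustment_cost(2.0, y_star))   # cost of jumping there from y = 2
```

With these values, jumping from \(y = 2\) to the static optimum \(4.5\) costs \(156.25\), far more than the per-period static profit of \(20.25\). This suggests why the firm would adjust output gradually rather than in one step.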

Let’s begin with the following imports

```
import quantecon as qe
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
```

Let’s check the GPU we are running

```
!nvidia-smi
```

```
Thu Dec 21 00:59:38 2023
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.182.03   Driver Version: 470.182.03   CUDA Version: 12.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla V100-SXM2...  Off  | 00000000:00:1E.0 Off |                    0 |
| N/A   33C    P0    39W / 300W |      0MiB / 16160MiB |      2%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+
```

We will use 64-bit floats with JAX in order to increase the precision.

```
jax.config.update("jax_enable_x64", True)
```
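The following sketch illustrates why the extra precision can matter for the error tolerances used below. It uses NumPy rather than JAX so that it runs anywhere, on the assumption that both follow the same IEEE 754 rules for `float32` and `float64`. The magnitude `2000.0` is a hypothetical value chosen for illustration, not one taken from the model.

```python
import numpy as np

# NumPy stands in for JAX here; float32 and float64 follow IEEE 754 in both.
small = 1e-8
lost_32 = np.float32(1.0) + np.float32(small) == np.float32(1.0)
lost_64 = np.float64(1.0) + np.float64(small) == np.float64(1.0)
print(lost_32)   # True: the increment is below float32 resolution
print(lost_64)   # False: float64 still resolves it

# Gap between adjacent floats near a hypothetical value of 2000
gap_32 = np.spacing(np.float32(2000.0))
gap_64 = np.spacing(np.float64(2000.0))
print(gap_32, gap_64)
```

Near 2000, adjacent `float32` numbers are roughly `1.2e-4` apart, so a stopping rule with tolerance `1e-5` could not be met reliably at that magnitude in single precision.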

We need the following successive approximation function.

```
def successive_approx(T,                     # Operator (callable)
                      x_0,                   # Initial condition
                      tolerance=1e-6,        # Error tolerance
                      max_iter=10_000,       # Max iteration bound
                      print_step=25,         # Print at multiples
                      verbose=False):
    x = x_0
    error = tolerance + 1
    k = 1
    while error > tolerance and k <= max_iter:
        x_new = T(x)
        error = jnp.max(jnp.abs(x_new - x))
        if verbose and k % print_step == 0:
            print(f"Completed iteration {k} with error {error}.")
        x = x_new
        k += 1
    if error > tolerance:
        print(f"Warning: Iteration hit upper bound {max_iter}.")
    elif verbose:
        # k was incremented after the final iteration, so report k - 1
        print(f"Terminated successfully in {k - 1} iterations.")
    return x
```
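As a quick illustration of the idea, here is a scalar sketch of successive approximation applied to a simple contraction. The function `successive_approx_scalar` is a simplified stand-in written for this example (plain floats, no JAX, no printing) and is not part of the lecture's code.

```python
def successive_approx_scalar(T, x_0, tolerance=1e-6, max_iter=10_000):
    """Iterate x <- T(x) from x_0 until successive iterates differ by <= tolerance."""
    x, error, k = x_0, tolerance + 1, 0
    while error > tolerance and k < max_iter:
        x_new = T(x)
        error = abs(x_new - x)
        x = x_new
        k += 1
    return x, k

# T(x) = 0.5 x + 1 is a contraction with modulus 0.5 and fixed point 2
x_star, num_iter = successive_approx_scalar(lambda x: 0.5 * x + 1, x_0=0.0)
print(x_star, num_iter)
```

The distance between successive iterates halves at every step here, so the loop stops after a few dozen iterations with `x_star` close to 2.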

Let’s define a function to create an investment model using the given parameters.

```
def create_investment_model(
        r=0.01,                              # Interest rate
        a_0=10.0, a_1=1.0,                   # Demand parameters
        γ=25.0, c=1.0,                       # Adjustment and unit cost
        y_min=0.0, y_max=20.0, y_size=100,   # Grid for output
        ρ=0.9, ν=1.0,                        # AR(1) parameters
        z_size=150):                         # Grid size for shock
    """
    A function that takes in parameters and returns a tuple
    (constants, sizes, arrays) that contains data for the investment problem.
    """
    β = 1 / (1 + r)
    y_grid = jnp.linspace(y_min, y_max, y_size)
    mc = qe.tauchen(z_size, ρ, ν)
    z_grid, Q = mc.state_values, mc.P

    # Break up parameters into static and nonstatic components
    constants = β, a_0, a_1, γ, c
    sizes = y_size, z_size
    arrays = y_grid, z_grid, Q

    # Shift arrays to the device (e.g., GPU)
    arrays = tuple(map(jax.device_put, arrays))
    return constants, sizes, arrays
```
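The shock grid and transition matrix above come from `qe.tauchen`. For readers who want to see what such a discretization involves, here is a textbook sketch of Tauchen's method in plain Python. It is not the `quantecon` source: the function name `tauchen_sketch`, the grid width of three stationary standard deviations, and the reading of `ν` as the innovation standard deviation are all assumptions made for this illustration.

```python
from math import erf, sqrt

def norm_cdf(x):
    """Standard normal CDF."""
    return 0.5 * (1 + erf(x / sqrt(2)))

def tauchen_sketch(n, ρ, ν, n_std=3):
    """
    Textbook Tauchen discretization of z' = ρ z + ν ε, with ε ~ N(0, 1).
    Returns the grid and the transition matrix as nested lists.
    """
    σ_z = ν / sqrt(1 - ρ**2)                  # stationary standard deviation
    z_max = n_std * σ_z
    step = 2 * z_max / (n - 1)
    grid = [-z_max + k * step for k in range(n)]
    P = []
    for z in grid:
        row = []
        for j, zp in enumerate(grid):
            upper = norm_cdf((zp - ρ * z + step / 2) / ν)
            lower = norm_cdf((zp - ρ * z - step / 2) / ν)
            if j == 0:
                row.append(upper)             # all mass below the first midpoint
            elif j == n - 1:
                row.append(1 - lower)         # all mass above the last midpoint
            else:
                row.append(upper - lower)     # mass between adjacent midpoints
        P.append(row)
    return grid, P

grid, P = tauchen_sketch(n=5, ρ=0.9, ν=1.0)
print([round(z, 3) for z in grid])
print([round(sum(row), 12) for row in P])     # each row should sum to one
```

Each row of `P` is a probability distribution over next-period shock values, which is the role played by `Q` in the model.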

Let’s re-write the vectorized version of the right-hand side of the Bellman equation (before maximization), which is a 3D array representing

\[
    B(v)(y, z, y') = r(y, z, y') + \beta \sum_{z'} v(y', z') Q(z, z')
\]

for all \((y, z, y')\).

```
def B(v, constants, sizes, arrays):
    """
    A vectorized version of the right-hand side of the Bellman equation
    (before maximization)
    """

    # Unpack
    β, a_0, a_1, γ, c = constants
    y_size, z_size = sizes
    y_grid, z_grid, Q = arrays

    # Compute current rewards r(y, z, yp) as array r[i, j, ip]
    y  = jnp.reshape(y_grid, (y_size, 1, 1))    # y[i]   ->  y[i, j, ip]
    z  = jnp.reshape(z_grid, (1, z_size, 1))    # z[j]   ->  z[i, j, ip]
    yp = jnp.reshape(y_grid, (1, 1, y_size))    # yp[ip] -> yp[i, j, ip]
    r = (a_0 - a_1 * y + z - c) * y - γ * (yp - y)**2

    # Calculate continuation rewards at all combinations of (y, z, yp)
    v = jnp.reshape(v, (1, 1, y_size, z_size))  # v[ip, jp] -> v[i, j, ip, jp]
    Q = jnp.reshape(Q, (1, z_size, 1, z_size))  # Q[j, jp]  -> Q[i, j, ip, jp]
    EV = jnp.sum(v * Q, axis=3)                 # sum over last index jp

    # Compute the right-hand side of the Bellman equation
    return r + β * EV

# Create a jitted function
B = jax.jit(B, static_argnums=(2,))
```
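The reshaping pattern in `B` relies on broadcasting. The sketch below checks the pattern on a tiny example against explicit loops. It uses NumPy in place of JAX (on the assumption that `jnp` follows the same broadcasting rules), and the grids, the matrix `Q`, and the simplified reward are made-up values for illustration.

```python
import numpy as np

# Tiny illustrative sizes and values (not the lecture's parameters)
y_size, z_size = 3, 2
y_grid = np.array([0.0, 1.0, 2.0])
z_grid = np.array([-1.0, 1.0])
γ = 0.5

# Reshape so that broadcasting produces an array indexed by (i, j, ip)
y = y_grid.reshape(y_size, 1, 1)
z = z_grid.reshape(1, z_size, 1)
yp = y_grid.reshape(1, 1, y_size)
r = (z - y) * y - γ * (yp - y)**2          # a stand-in reward, shape (3, 2, 3)

# The same array built with explicit loops
r_loop = np.empty((y_size, z_size, y_size))
for i in range(y_size):
    for j in range(z_size):
        for ip in range(y_size):
            r_loop[i, j, ip] = ((z_grid[j] - y_grid[i]) * y_grid[i]
                                - γ * (y_grid[ip] - y_grid[i])**2)
print(r.shape, np.allclose(r, r_loop))

# Expected continuation values: EV[0, j, ip] = Σ_jp v[ip, jp] * Q[j, jp]
v = np.arange(6.0).reshape(y_size, z_size)
Q = np.array([[0.9, 0.1],
              [0.2, 0.8]])
EV = np.sum(v.reshape(1, 1, y_size, z_size) * Q.reshape(1, z_size, 1, z_size),
            axis=3)
print(EV.shape, np.allclose(EV[0], Q @ v.T))
```

Note that `EV` has a leading axis of length one, so adding it to `r` broadcasts it across all current output levels `i`.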

We define a function to compute the current rewards \(r_\sigma\) given policy \(\sigma\), which is defined as the vector

\[
    r_\sigma(y, z) := r(y, z, \sigma(y, z)).
\]

```
def compute_r_σ(σ, constants, sizes, arrays):
    """
    Compute the array r_σ[i, j] = r[i, j, σ[i, j]], which gives current
    rewards given policy σ.
    """

    # Unpack model
    β, a_0, a_1, γ, c = constants
    y_size, z_size = sizes
    y_grid, z_grid, Q = arrays

    # Compute r_σ[i, j]
    y = jnp.reshape(y_grid, (y_size, 1))
    z = jnp.reshape(z_grid, (1, z_size))
    yp = y_grid[σ]
    r_σ = (a_0 - a_1 * y + z - c) * y - γ * (yp - y)**2

    return r_σ

# Create the jitted function
compute_r_σ = jax.jit(compute_r_σ, static_argnums=(2,))
```
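To see that this direct computation agrees with the definition `r_σ[i, j] = r[i, j, σ[i, j]]`, the following sketch builds the full 3D reward array on a tiny grid and gathers the entries selected by a policy. NumPy is used instead of JAX, and the grids and the policy `σ` are hypothetical values chosen for the example.

```python
import numpy as np

# Illustrative grids and a hypothetical policy (not the lecture's defaults)
y_grid = np.array([0.0, 1.0, 2.0])
z_grid = np.array([-0.5, 0.5])
a_0, a_1, γ, c = 10.0, 1.0, 25.0, 1.0
y_size, z_size = len(y_grid), len(z_grid)
σ = np.array([[1, 1],
              [1, 2],
              [2, 2]])                       # σ[i, j] = index of next-period output

# Direct computation, following compute_r_σ
y = y_grid.reshape(y_size, 1)
z = z_grid.reshape(1, z_size)
yp = y_grid[σ]                               # yp[i, j] = y_grid[σ[i, j]]
r_σ = (a_0 - a_1 * y + z - c) * y - γ * (yp - y)**2

# Same numbers, gathered from the full array r[i, j, ip]
y3, z3 = y[:, :, None], z[:, :, None]
yp3 = y_grid.reshape(1, 1, y_size)
r = (a_0 - a_1 * y3 + z3 - c) * y3 - γ * (yp3 - y3)**2
r_σ_gathered = np.take_along_axis(r, σ[:, :, None], axis=2)[:, :, 0]

print(np.allclose(r_σ, r_σ_gathered))
```

The direct route avoids materializing the 3D array, which is presumably why the lecture's `compute_r_σ` indexes `y_grid` by `σ` instead of slicing `r`.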

Define the Bellman operator.

```
def T(v, constants, sizes, arrays):
    """The Bellman operator."""
    return jnp.max(B(v, constants, sizes, arrays), axis=2)

T = jax.jit(T, static_argnums=(2,))
```

The following function computes a v-greedy policy.

```
def get_greedy(v, constants, sizes, arrays):
    "Computes a v-greedy policy, returned as a set of indices."
    return jnp.argmax(B(v, constants, sizes, arrays), axis=2)

get_greedy = jax.jit(get_greedy, static_argnums=(2,))
```
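Both `T` and `get_greedy` reduce the 3D array over its last axis, which indexes the choice of next-period output. Here is a small NumPy sketch of the two reductions. The array `B_vals` holds made-up numbers standing in for the output of `B`.

```python
import numpy as np

# Hypothetical values of B[i, j, ip] for one i and two values of j
B_vals = np.array([[[1.0, 3.0, 2.0],
                    [0.0, -1.0, 4.0]]])     # shape (1, 2, 3)

Tv = np.max(B_vals, axis=2)                 # best attainable value at each (i, j)
σ = np.argmax(B_vals, axis=2)               # index ip that attains it
print(Tv)   # [[3. 4.]]
print(σ)    # [[1 2]]

# The greedy indices recover the maximized values
recovered = np.take_along_axis(B_vals, σ[:, :, None], axis=2)[:, :, 0]
print(np.array_equal(recovered, Tv))
```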

Define the \(\sigma\)-policy operator.

```
def T_σ(v, σ, constants, sizes, arrays):
    """The σ-policy operator."""

    # Unpack model
    β, a_0, a_1, γ, c = constants
    y_size, z_size = sizes
    y_grid, z_grid, Q = arrays

    r_σ = compute_r_σ(σ, constants, sizes, arrays)

    # Compute the array v[σ[i, j], jp]
    zp_idx = jnp.arange(z_size)
    zp_idx = jnp.reshape(zp_idx, (1, 1, z_size))
    σ = jnp.reshape(σ, (y_size, z_size, 1))
    V = v[σ, zp_idx]

    # Convert Q[j, jp] to Q[i, j, jp]
    Q = jnp.reshape(Q, (1, z_size, z_size))

    # Calculate the expected sum Σ_jp v[σ[i, j], jp] * Q[i, j, jp]
    Ev = jnp.sum(V * Q, axis=2)

    return r_σ + β * Ev

T_σ = jax.jit(T_σ, static_argnums=(3,))
```
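The step `V = v[σ, zp_idx]` uses integer-array indexing with broadcasting, which can be hard to read at first. The sketch below reproduces it on a tiny example with NumPy (assumed here to index the same way as `jnp` for in-bounds integer arrays) and checks it entry by entry. The arrays `v`, `σ` and `Q` are made-up values.

```python
import numpy as np

y_size, z_size = 3, 2
v = np.arange(6.0).reshape(y_size, z_size)      # v[ip, jp]
σ = np.array([[2, 0],
              [1, 1],
              [0, 2]])                          # hypothetical policy σ[i, j]

# Index arrays of shape (3, 2, 1) and (1, 1, 2) broadcast to (3, 2, 2)
zp_idx = np.arange(z_size).reshape(1, 1, z_size)
V = v[σ.reshape(y_size, z_size, 1), zp_idx]

# Check V[i, j, jp] == v[σ[i, j], jp] entry by entry
ok = all(V[i, j, jp] == v[σ[i, j], jp]
         for i in range(y_size)
         for j in range(z_size)
         for jp in range(z_size))
print(V.shape, ok)

# Expected value Σ_jp v[σ[i, j], jp] * Q[j, jp]
Q = np.array([[0.9, 0.1],
              [0.2, 0.8]])
Ev = np.sum(V * Q.reshape(1, z_size, z_size), axis=2)
print(Ev[0, 0])   # v[2, :] weighted by Q[0, :]
```

For instance, `Ev[0, 0]` combines `v[2, :] = [4, 5]` with the weights `Q[0, :] = [0.9, 0.1]`, giving approximately `4.1`.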

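Next, we want to compute the lifetime value of following policy \(\sigma\).

This lifetime value is a function \(v_\sigma\) that satisfies

\[
    v_\sigma(y, z) = r_\sigma(y, z) + \beta \sum_{z'} v_\sigma(\sigma(y, z), z') Q(z, z').
\]

We wish to solve this equation for \(v_\sigma\).

Suppose we define the linear operator \(L_\sigma\) by

\[
    (L_\sigma v)(y, z) = v(y, z) - \beta \sum_{z'} v(\sigma(y, z), z') Q(z, z').
\]

With this notation, the problem is to solve for \(v\) via

\[
    (L_\sigma v)(y, z) = r_\sigma(y, z).
\]

In vector form this is \(L_\sigma v = r_\sigma\), which tells us that the function we seek is

\[
    v_\sigma = L_\sigma^{-1} r_\sigma.
\]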

JAX allows us to solve linear systems defined in terms of operators; the first step is to define the function \(L_{\sigma}\).

```
def L_σ(v, σ, constants, sizes, arrays):

    β, a_0, a_1, γ, c = constants
    y_size, z_size = sizes
    y_grid, z_grid, Q = arrays

    # Set up the array v[σ[i, j], jp]
    zp_idx = jnp.arange(z_size)
    zp_idx = jnp.reshape(zp_idx, (1, 1, z_size))
    σ = jnp.reshape(σ, (y_size, z_size, 1))
    V = v[σ, zp_idx]

    # Expand Q[j, jp] to Q[i, j, jp]
    Q = jnp.reshape(Q, (1, z_size, z_size))

    # Compute and return v[i, j] - β Σ_jp v[σ[i, j], jp] * Q[j, jp]
    return v - β * jnp.sum(V * Q, axis=2)

L_σ = jax.jit(L_σ, static_argnums=(3,))
```

Now we can define a function to compute \(v_{\sigma}\)

```
def get_value(σ, constants, sizes, arrays):

    # Unpack
    β, a_0, a_1, γ, c = constants
    y_size, z_size = sizes
    y_grid, z_grid, Q = arrays

    r_σ = compute_r_σ(σ, constants, sizes, arrays)

    # Reduce L_σ to a function in v
    partial_L_σ = lambda v: L_σ(v, σ, constants, sizes, arrays)

    return jax.scipy.sparse.linalg.bicgstab(partial_L_σ, r_σ)[0]

get_value = jax.jit(get_value, static_argnums=(2,))
```
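To check the logic of solving \(L_\sigma v = r_\sigma\) on a problem small enough to inspect, the sketch below swaps in a different technique: it assembles the matrix of the operator explicitly and calls a dense solver from NumPy instead of using `bicgstab`. The rewards `r_σ` and transition matrix `P_σ` are randomly generated stand-ins over a flat state space of four states, not objects from the investment model.

```python
import numpy as np

β = 0.95
n = 4                                       # number of states (illustrative)
rng = np.random.default_rng(0)
r_σ = rng.uniform(size=n)                   # hypothetical rewards under σ
P_σ = rng.uniform(size=(n, n))
P_σ /= P_σ.sum(axis=1, keepdims=True)       # hypothetical transitions under σ

def L_σ(v):
    "The linear operator v -> v - β P_σ v."
    return v - β * P_σ @ v

# Assemble the matrix of L_σ column by column by applying it to basis vectors
L_mat = np.column_stack([L_σ(e) for e in np.eye(n)])
v_σ = np.linalg.solve(L_mat, r_σ)

# v_σ should be a fixed point of the policy operator v -> r_σ + β P_σ v
print(np.allclose(v_σ, r_σ + β * P_σ @ v_σ))
```

The dense approach does not scale to the model in this lecture. With the default sizes the state space has 100 × 150 = 15,000 points, so the full matrix would hold 225 million entries (about 1.8 GB in 64-bit floats). An operator-based iterative solver needs only the ability to apply \(L_\sigma\) to a vector.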

Finally, we introduce the solvers that implement VFI, HPI and OPI.

```
# Implements VFI-Value Function iteration
def value_iteration(model, tol=1e-5):
    constants, sizes, arrays = model
    _T = lambda v: T(v, constants, sizes, arrays)
    vz = jnp.zeros(sizes)

    v_star = successive_approx(_T, vz, tolerance=tol)
    return get_greedy(v_star, constants, sizes, arrays)
```

```
# Implements HPI-Howard policy iteration routine
def policy_iteration(model, maxiter=250):
    constants, sizes, arrays = model
    σ = jnp.zeros(sizes, dtype=int)
    i, error = 0, 1.0
    while error > 0 and i < maxiter:
        v_σ = get_value(σ, constants, sizes, arrays)
        σ_new = get_greedy(v_σ, constants, sizes, arrays)
        error = jnp.max(jnp.abs(σ_new - σ))
        σ = σ_new
        i = i + 1
        print(f"Concluded loop {i} with error {error}.")
    return σ
```

```
# Implements the OPI-Optimistic policy iteration routine
def optimistic_policy_iteration(model, tol=1e-5, m=10):
    constants, sizes, arrays = model
    v = jnp.zeros(sizes)
    error = tol + 1
    while error > tol:
        last_v = v
        σ = get_greedy(v, constants, sizes, arrays)
        for _ in range(m):
            v = T_σ(v, σ, constants, sizes, arrays)
        error = jnp.max(jnp.abs(v - last_v))
    return get_greedy(v, constants, sizes, arrays)
```
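Before running the solvers on the investment model, it may help to see the three algorithms side by side on a problem small enough to reason about. The sketch below is a simplified NumPy analogue, not the lecture's implementation: it uses a generic finite MDP with randomly generated rewards `r[s, a]` and transitions `P[s, a, s']`, and its policy iteration uses an exact dense solve in place of `bicgstab`. All names and sizes are hypothetical.

```python
import numpy as np

β = 0.9
rng = np.random.default_rng(1)
n, m = 5, 3                                  # hypothetical numbers of states, actions
r = rng.uniform(size=(n, m))                 # r[s, a]
P = rng.uniform(size=(n, m, n))
P /= P.sum(axis=2, keepdims=True)            # P[s, a, s'], rows sum to one
idx = np.arange(n)

def B(v):
    "Action values r[s, a] + β Σ_s' P[s, a, s'] v[s']."
    return r + β * P @ v

def greedy(v):
    return np.argmax(B(v), axis=1)

def T_σ(v, σ):
    "Policy operator for the policy σ (an array of action indices)."
    return r[idx, σ] + β * P[idx, σ] @ v

def vfi(tol=1e-10):
    v = np.zeros(n)
    while True:
        v_new = np.max(B(v), axis=1)
        if np.max(np.abs(v_new - v)) < tol:
            return greedy(v_new)
        v = v_new

def hpi():
    σ = np.zeros(n, dtype=int)
    while True:
        # Exact policy evaluation: solve (I - β P_σ) v = r_σ
        v_σ = np.linalg.solve(np.eye(n) - β * P[idx, σ], r[idx, σ])
        σ_new = greedy(v_σ)
        if np.array_equal(σ_new, σ):
            return σ
        σ = σ_new

def opi(m_steps=20, tol=1e-10):
    v = np.zeros(n)
    while True:
        last_v = v
        σ = greedy(v)
        for _ in range(m_steps):             # partial policy evaluation
            v = T_σ(v, σ)
        if np.max(np.abs(v - last_v)) < tol:
            return greedy(v)

σ_vfi, σ_hpi, σ_opi = vfi(), hpi(), opi()
print(σ_vfi, σ_hpi, σ_opi)
```

The only structural difference between the three loops is how much policy evaluation happens between greedy updates: one application of the Bellman operator in VFI, `m_steps` applications of the policy operator in OPI, and an exact solve in HPI.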

```
model = create_investment_model()
print("Starting HPI.")
qe.tic()
out = policy_iteration(model)
elapsed = qe.toc()
print(out)
print(f"HPI completed in {elapsed} seconds.")
```

```
Starting HPI.
```

```
Concluded loop 1 with error 50.
Concluded loop 2 with error 26.
Concluded loop 3 with error 17.
Concluded loop 4 with error 10.
Concluded loop 5 with error 7.
Concluded loop 6 with error 4.
Concluded loop 7 with error 3.
Concluded loop 8 with error 1.
Concluded loop 9 with error 1.
Concluded loop 10 with error 1.
Concluded loop 11 with error 1.
Concluded loop 12 with error 0.
TOC: Elapsed: 0:00:1.30
[[ 2 2 2 ... 6 6 6]
[ 3 3 3 ... 7 7 7]
[ 4 4 4 ... 7 7 7]
...
[82 82 82 ... 86 86 86]
[83 83 83 ... 86 86 86]
[84 84 84 ... 87 87 87]]
HPI completed in 1.3070704936981201 seconds.
```

```
print("Starting VFI.")
qe.tic()
out = value_iteration(model)
elapsed = qe.toc()
print(out)
print(f"VFI completed in {elapsed} seconds.")
```

```
Starting VFI.
```

```
TOC: Elapsed: 0:00:2.04
[[ 2 2 2 ... 6 6 6]
[ 3 3 3 ... 7 7 7]
[ 4 4 4 ... 7 7 7]
...
[82 82 82 ... 86 86 86]
[83 83 83 ... 86 86 86]
[84 84 84 ... 87 87 87]]
VFI completed in 2.0406382083892822 seconds.
```

```
print("Starting OPI.")
qe.tic()
out = optimistic_policy_iteration(model, m=100)
elapsed = qe.toc()
print(out)
print(f"OPI completed in {elapsed} seconds.")
```

```
Starting OPI.
```

```
TOC: Elapsed: 0:00:0.82
[[ 2 2 2 ... 6 6 6]
[ 3 3 3 ... 7 7 7]
[ 4 4 4 ... 7 7 7]
...
[82 82 82 ... 86 86 86]
[83 83 83 ... 86 86 86]
[84 84 84 ... 87 87 87]]
OPI completed in 0.8238322734832764 seconds.
```

Here’s the plot of the Howard policy, as a function of \(y\) at the highest and lowest values of \(z\).

```
model = create_investment_model()
constants, sizes, arrays = model
β, a_0, a_1, γ, c = constants
y_size, z_size = sizes
y_grid, z_grid, Q = arrays
```

```
σ_star = policy_iteration(model)

fig, ax = plt.subplots(figsize=(9, 5))
ax.plot(y_grid, y_grid, "k--", label="45")
# Column 0 is the lowest shock value and column -1 the highest
ax.plot(y_grid, y_grid[σ_star[:, 0]], label=r"$\sigma^*(\cdot, z_1)$")
ax.plot(y_grid, y_grid[σ_star[:, -1]], label=r"$\sigma^*(\cdot, z_N)$")
ax.legend(fontsize=12)
plt.show()
```

```
Concluded loop 1 with error 50.
Concluded loop 2 with error 26.
Concluded loop 3 with error 17.
Concluded loop 4 with error 10.
Concluded loop 5 with error 7.
Concluded loop 6 with error 4.
Concluded loop 7 with error 3.
Concluded loop 8 with error 1.
Concluded loop 9 with error 1.
Concluded loop 10 with error 1.
Concluded loop 11 with error 1.
Concluded loop 12 with error 0.
```

Let’s plot the time taken by each of the solvers and compare them.

```
m_vals = range(5, 600, 40)
```

```
model = create_investment_model()
print("Running Howard policy iteration.")
qe.tic()
σ_pi = policy_iteration(model)
pi_time = qe.toc()
```

```
Running Howard policy iteration.
Concluded loop 1 with error 50.
Concluded loop 2 with error 26.
Concluded loop 3 with error 17.
Concluded loop 4 with error 10.
Concluded loop 5 with error 7.
Concluded loop 6 with error 4.
Concluded loop 7 with error 3.
Concluded loop 8 with error 1.
Concluded loop 9 with error 1.
Concluded loop 10 with error 1.
Concluded loop 11 with error 1.
Concluded loop 12 with error 0.
TOC: Elapsed: 0:00:0.04
```

```
print(f"PI completed in {pi_time} seconds.")
print("Running value function iteration.")
qe.tic()
σ_vfi = value_iteration(model, tol=1e-5)
vfi_time = qe.toc()
print(f"VFI completed in {vfi_time} seconds.")
```

```
PI completed in 0.04577064514160156 seconds.
Running value function iteration.
```

```
TOC: Elapsed: 0:00:1.44
VFI completed in 1.442993402481079 seconds.
```

```
opi_times = []
for m in m_vals:
    print(f"Running optimistic policy iteration with m={m}.")
    qe.tic()
    σ_opi = optimistic_policy_iteration(model, m=m, tol=1e-5)
    opi_time = qe.toc()
    print(f"OPI with m={m} completed in {opi_time} seconds.")
    opi_times.append(opi_time)
```

```
Running optimistic policy iteration with m=5.
```

```
TOC: Elapsed: 0:00:0.83
OPI with m=5 completed in 0.8374240398406982 seconds.
Running optimistic policy iteration with m=45.
```

```
TOC: Elapsed: 0:00:0.63
OPI with m=45 completed in 0.6356334686279297 seconds.
Running optimistic policy iteration with m=85.
```

```
TOC: Elapsed: 0:00:0.66
OPI with m=85 completed in 0.6696212291717529 seconds.
Running optimistic policy iteration with m=125.
```

```
TOC: Elapsed: 0:00:0.68
OPI with m=125 completed in 0.6830582618713379 seconds.
Running optimistic policy iteration with m=165.
```

```
TOC: Elapsed: 0:00:0.73
OPI with m=165 completed in 0.7334933280944824 seconds.
Running optimistic policy iteration with m=205.
```

```
TOC: Elapsed: 0:00:0.84
OPI with m=205 completed in 0.8447787761688232 seconds.
Running optimistic policy iteration with m=245.
```

```
TOC: Elapsed: 0:00:0.90
OPI with m=245 completed in 0.905911922454834 seconds.
Running optimistic policy iteration with m=285.
```

```
TOC: Elapsed: 0:00:0.97
OPI with m=285 completed in 0.9756362438201904 seconds.
Running optimistic policy iteration with m=325.
```

```
TOC: Elapsed: 0:00:1.12
OPI with m=325 completed in 1.120331048965454 seconds.
Running optimistic policy iteration with m=365.
```

```
TOC: Elapsed: 0:00:1.27
OPI with m=365 completed in 1.2743430137634277 seconds.
Running optimistic policy iteration with m=405.
```

```
TOC: Elapsed: 0:00:1.40
OPI with m=405 completed in 1.4083504676818848 seconds.
Running optimistic policy iteration with m=445.
```

```
TOC: Elapsed: 0:00:1.56
OPI with m=445 completed in 1.5652389526367188 seconds.
Running optimistic policy iteration with m=485.
```

```
TOC: Elapsed: 0:00:1.67
OPI with m=485 completed in 1.675036907196045 seconds.
Running optimistic policy iteration with m=525.
```

```
TOC: Elapsed: 0:00:1.82
OPI with m=525 completed in 1.8284337520599365 seconds.
Running optimistic policy iteration with m=565.
```

```
TOC: Elapsed: 0:00:1.96
OPI with m=565 completed in 1.9662644863128662 seconds.
```

```
fig, ax = plt.subplots(figsize=(9, 5))
ax.plot(m_vals, jnp.full(len(m_vals), pi_time),
        lw=2, label="Howard policy iteration")
ax.plot(m_vals, jnp.full(len(m_vals), vfi_time),
        lw=2, label="value function iteration")
ax.plot(m_vals, opi_times, lw=2, label="optimistic policy iteration")
ax.legend(fontsize=12, frameon=False)
ax.set_xlabel("$m$", fontsize=12)
ax.set_ylabel("time(s)", fontsize=12)
plt.show()
```