15. Inventory Management Model
GPU
This lecture was built using a machine with JAX installed and access to a GPU.
To run this lecture on Google Colab, click on the “play” icon top right, select Colab, and set the runtime environment to include a GPU.
To run this lecture on your own machine, you need to install Google JAX.
This lecture provides a JAX implementation of a model in Dynamic Programming.
In addition to JAX and Anaconda, this lecture will need the following libraries:
!pip install --upgrade quantecon
15.1. A model with constant discounting
We study a firm where a manager tries to maximize shareholder value.
To simplify the problem, we assume that the firm only sells one product.
Letting \(\pi_t\) be profits at time \(t\) and \(r > 0\) be the interest rate, the value of the firm is
\[ V_0 = \sum_{t \geq 0} \beta^t \pi_t \quad \text{where} \quad \beta := \frac{1}{1+r}. \]
Suppose the firm faces exogenous demand process \((D_t)_{t \geq 0}\).
We assume \((D_t)_{t \geq 0}\) is IID with common distribution \(\phi \in \mathcal{D}(\mathbb{Z}_+)\), that is, \(\phi\) is a distribution on the nonnegative integers.
Inventory \((X_t)_{t \geq 0}\) of the product obeys
\[ X_{t+1} = \max(X_t - D_{t+1}, 0) + A_t. \]
The term \(A_t\) is units of stock ordered this period, which take one period to arrive.
We assume that the firm can store at most \(K\) items at one time.
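Here's a minimal sketch of these dynamics in plain Python. The function name `next_inventory` and the demand numbers are illustrative assumptions; the update rule mirrors the one used in the simulation code later in this lecture, where unmet demand is lost and the current order is added to whatever stock is left over.

```python
def next_inventory(x, d, a):
    """Next-period stock: leftover after demand d (floored at zero) plus the order a."""
    return max(x - d, 0) + a

# Hypothetical path: start with 5 units and order 2 units every period
x = 5
path = [x]
for d in (3, 7, 0):              # illustrative demand realizations
    x = next_inventory(x, d, a=2)
    path.append(x)
print(path)
```

In the second period demand exceeds stock, so the shelf is emptied and only the new order carries over.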
Profits are given by
\[ \pi_t := \min(X_t, D_{t+1}) - c A_t - \kappa \mathbb{1}\{A_t > 0\}. \]
We take the minimum of current stock and demand because orders in excess of inventory are assumed to be lost rather than back-filled.
Here \(c\) is unit product cost and \(\kappa\) is a fixed cost of ordering inventory.
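As a small illustration, one-period profit can be written as follows. The function name `profit` is an assumption made for this sketch, the unit price is normalized to one as in the lecture's code below, and the default values of `c` and `κ` match the defaults used later.

```python
def profit(x, d, a, c=0.2, κ=0.8):
    """Sales min(x, d) at a unit price, less the variable cost c * a
    and the fixed cost κ, which is paid only when an order is placed."""
    return min(x, d) - c * a - κ * (a > 0)

no_order = profit(x=5, d=3, a=0)     # sells 3 units, orders nothing
with_order = profit(x=5, d=8, a=2)   # sells all 5 units, orders 2
```

Note that the fixed cost makes small, frequent orders expensive, which is why optimal policies in models of this kind tend to order in batches.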
We can map our inventory problem into a dynamic program with state space \(X := \{0, \ldots, K\}\) and action space \(A := X\).
The feasible correspondence \(\Gamma\) is
\[ \Gamma(x) := \{0, \ldots, K - x\}, \]
which represents the set of feasible orders when the current inventory state is \(x\).
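In code, the feasible orders at a given stock level are those that keep total stock within capacity. The helper `feasible_orders` is a hypothetical name used only for this sketch; the value `K = 100` matches the one set later in the lecture.

```python
K = 100  # storage capacity

def feasible_orders(x, K=K):
    """Orders a satisfying x + a <= K, that is, a in {0, ..., K - x}."""
    return range(K - x + 1)

print(list(feasible_orders(98)))
```

When the warehouse is full (`x = K`), the only feasible order is zero.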
The reward function is expected current profits, or
\[ r(x, a) := \sum_{d \geq 0} \min(x, d) \phi(d) - c a - \kappa \mathbb{1}\{a > 0\}. \]
The stochastic kernel (i.e., state-transition probabilities) from the set of feasible state-action pairs is
\[ P(x, a, x') := \mathbb{P}\{\max(x - D, 0) + a = x'\} \quad \text{when} \quad D \sim \phi. \]
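To make the transition probabilities concrete, here's a sketch that builds one row of the kernel, assuming the geometric demand distribution \(\phi(d) = (1-p)^d p\) that the code in this lecture uses later. The names `demand_pmf` and `kernel_row` and the small capacity `K=10` are illustrative assumptions.

```python
import numpy as np

def demand_pmf(d, p=0.6):
    """Geometric distribution on {0, 1, 2, ...}."""
    return (1 - p)**d * p

def kernel_row(x, a, K=10, p=0.6):
    """Distribution of next-period stock given stock x and order a (needs x + a <= K)."""
    row = np.zeros(K + 1)
    for d in range(x):                 # demand below stock leaves x - d units
        row[x - d + a] += demand_pmf(d, p)
    row[a] += (1 - p)**x               # demand >= x empties the shelf: P{D >= x}
    return row

row = kernel_row(x=3, a=2)
```

All demand realizations at or above current stock lead to the same next state, so their probability mass is pooled on \(x' = a\).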
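When discounting is constant, the Bellman equation takes the form
\[ v(x) = \max_{a \in \Gamma(x)} \left\{ r(x, a) + \beta \sum_{d \geq 0} v(\max(x - d, 0) + a) \phi(d) \right\}. \tag{15.1} \]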
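Before moving on, here's a rough NumPy sketch of value function iteration for the constant-discount case on a deliberately small problem. The sizes `K = 10` and `D_MAX = 50`, the interest rate `r = 0.05`, and the geometric demand distribution are assumptions chosen to keep the example fast; this is not the implementation used in the rest of the lecture.

```python
import numpy as np

K, D_MAX = 10, 50                 # small illustrative sizes
c, κ, p = 0.2, 0.8, 0.6
β = 1 / (1 + 0.05)                # constant discount factor from r = 0.05
d_vals = np.arange(D_MAX)
ϕ = (1 - p)**d_vals * p           # geometric demand, truncated at D_MAX

def T(v):
    """One Bellman update with a constant discount factor β."""
    new_v = np.empty(K + 1)
    for x in range(K + 1):
        revenue = np.sum(np.minimum(x, d_vals) * ϕ)
        leftover = np.maximum(x - d_vals, 0)
        new_v[x] = max(
            revenue - c * a - κ * (a > 0) + β * np.sum(v[leftover + a] * ϕ)
            for a in range(K - x + 1)          # feasible orders only
        )
    return new_v

v = np.zeros(K + 1)
error = 1.0
while error > 1e-8:               # T is a contraction, so this loop terminates
    v_new = T(v)
    error = np.max(np.abs(v_new - v))
    v = v_new
```

The explicit loops are slow but easy to read. The JAX code below replaces them with vectorized operations.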
15.2. Time-varying discount rates
We wish to consider a more sophisticated model with time-varying discounting.
This time variation accommodates non-constant interest rates.
To this end, we replace the constant \(\beta\) in (15.1) with a stochastic process \((\beta_t)\), where \(\beta_t = 1/(1+r_t)\) and \(r_t\) is the interest rate at time \(t\).
We suppose that the dynamics can be expressed as \(\beta_t = \beta(Z_t)\), where the exogenous process \((Z_t)_{t \geq 0}\) is a Markov chain on \(Z\) with Markov matrix \(Q\).
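Here's a toy illustration of a state-dependent discount factor. The two-state chain below is hypothetical (the lecture itself uses a discretized AR(1) process with many more states), and we take \(\beta(z) = z\), as the code later in the lecture does.

```python
import numpy as np

z_vals = np.array([0.95, 0.99])          # hypothetical discount factor values
Q = np.array([[0.9, 0.1],
              [0.2, 0.8]])               # hypothetical Markov matrix, rows sum to one

rng = np.random.default_rng(0)
i_z = 0                                  # index of the initial state
β_path = []
for t in range(10):
    β_path.append(z_vals[i_z])
    i_z = rng.choice(len(z_vals), p=Q[i_z])   # draw the next state from row i_z of Q

r_path = 1 / np.array(β_path) - 1        # implied interest rates
```

Because the chain is persistent, periods of low interest rates (high discount factors) tend to cluster together.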
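After relabeling inventory \(X_t\) as \(Y_t\) and \(x\) as \(y\), the Bellman equation becomes
\[ v(y, z) = \max_{a \in \Gamma(y)} B((y, z), a, v) \]
where
\[ B((y, z), a, v) = r(y, a) + \beta(z) \sum_{d, z'} v(\max(y - d, 0) + a, z') \phi(d) Q(z, z'). \tag{15.2} \]
We set
\[ R(y, a, y') := \mathbb{P}\{\max(y - D, 0) + a = y'\} \quad \text{when} \quad D \sim \phi. \]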
Now \(R(y, a, y')\) is the probability of realizing next period inventory level \(y'\) when the current level is \(y\) and the action is \(a\).
Hence we can rewrite (15.2) as
\[ B((y, z), a, v) = r(y, a) + \beta(z) \sum_{y', z'} v(y', z') Q(z, z') R(y, a, y'). \]
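The expected continuation value is a double sum over next-period inventory \(y'\) and next-period exogenous state \(z'\), weighted by \(R(y, a, y')\) and \(Q(z, z')\). The sketch below, which uses small random arrays purely for illustration, shows that this double sum can be computed with two matrix-vector products.

```python
import numpy as np

rng = np.random.default_rng(1)
n_y, n_z = 4, 3                          # tiny illustrative sizes

v = rng.random((n_y, n_z))               # candidate value function v(y', z')
Q = rng.random((n_z, n_z))
Q = Q / Q.sum(axis=1, keepdims=True)     # normalize rows to get a Markov matrix
R_ya = rng.random(n_y)
R_ya = R_ya / R_ya.sum()                 # R(y, a, ·) for one fixed pair (y, a)

z = 0
# Double sum over (y', z') written as matrix-vector products
cont_value = R_ya @ v @ Q[z]

# The same quantity with explicit loops, for comparison
check = sum(v[y1, z1] * Q[z, z1] * R_ya[y1]
            for y1 in range(n_y) for z1 in range(n_z))
```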
Let’s begin with the following imports
import quantecon as qe
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from collections import namedtuple
import numba
from numba import prange
from time import time
Let’s check the GPU we are running on
!nvidia-smi
Tue Nov 19 23:57:40 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.54.03 Driver Version: 535.54.03 CUDA Version: 12.6 |
|-----------------------------------------+----------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+======================+======================|
| 0 Tesla T4 On | 00000001:00:00.0 Off | 0 |
| N/A 43C P0 28W / 70W | 2MiB / 15360MiB | 0% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
+---------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=======================================================================================|
| No running processes found |
+---------------------------------------------------------------------------------------+
We will use 64-bit floats with JAX to increase precision.
jax.config.update("jax_enable_x64", True)
Let’s define a model to represent the inventory management.
# NamedTuple Model
Model = namedtuple("Model", ("c", "κ", "p", "z_vals", "Q"))
We need the following successive approximation function.
def successive_approx(T,                     # Operator (callable)
                      x_0,                   # Initial condition
                      tolerance=1e-6,        # Error tolerance
                      max_iter=10_000,       # Max iteration bound
                      print_step=25,         # Print at multiples
                      verbose=False):
    x = x_0
    error = tolerance + 1
    k = 1
    while error > tolerance and k <= max_iter:
        x_new = T(x)
        error = jnp.max(jnp.abs(x_new - x))
        if verbose and k % print_step == 0:
            print(f"Completed iteration {k} with error {error}.")
        x = x_new
        k += 1
    if error > tolerance:
        print(f"Warning: Iteration hit upper bound {max_iter}.")
    elif verbose:
        print(f"Terminated successfully in {k} iterations.")
    return x
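To see the idea behind successive approximation in isolation, here's a sketch that iterates a scalar contraction whose fixed point is known in closed form. The helper `iterate` and the map `T_demo` are illustrative stand-ins, not the function defined above.

```python
def iterate(T, x_0, tolerance=1e-6, max_iter=10_000):
    """Iterate x <- T(x) until successive iterates differ by at most `tolerance`."""
    x, error, k = x_0, tolerance + 1, 0
    while error > tolerance and k < max_iter:
        x_new = T(x)
        error = abs(x_new - x)
        x, k = x_new, k + 1
    return x, k

def T_demo(x):
    """A contraction with modulus 0.5 and fixed point 2."""
    return 0.5 * x + 1

x_star, n_iter = iterate(T_demo, 0.0)
```

The Bellman operator used below is also a contraction, with modulus governed by the discount factor, which is why the same scheme converges there, although more slowly when discount factors are close to one.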
@jax.jit
def demand_pdf(p, d):
    return (1 - p)**d * p
K = 100
D_MAX = 101
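Since the code sums over demand values below `D_MAX` only, it's worth checking how much probability mass this truncation discards. The sketch below assumes `p = 0.6`, the default demand parameter used in this lecture, and compares the truncated sum with the closed-form tail probability \((1-p)^{D_{MAX}}\) of the geometric distribution.

```python
p, D_MAX = 0.6, 101

# Probability that demand is below D_MAX
mass = sum((1 - p)**d * p for d in range(D_MAX))

# Closed form for the omitted tail, P{D >= D_MAX}
tail = (1 - p)**D_MAX
```

With these values the omitted tail is far below floating-point precision, so the truncation is harmless.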
Let’s define a function to create an inventory model using the given parameters.
def create_sdd_inventory_model(
        ρ=0.98, ν=0.002, n_z=100, b=0.97,          # Z state parameters
        c=0.2, κ=0.8, p=0.6,                       # firm and demand parameters
        use_jax=True):
    mc = qe.tauchen(n_z, ρ, ν)
    z_vals, Q = mc.state_values + b, mc.P
    if use_jax:
        z_vals, Q = map(jnp.array, (z_vals, Q))
    return Model(c=c, κ=κ, p=p, z_vals=z_vals, Q=Q)
Here’s the function B on the right-hand side of the Bellman equation.
@jax.jit
def B(x, i_z, a, v, model):
    """
    The function B(x, z, a, v) = r(x, a) + β(z) Σ_{x′, z′} v(x′, z′) P(x, a, x′) Q(z, z′).
    """
    c, κ, p, z_vals, Q = model
    z = z_vals[i_z]
    d_vals = jnp.arange(D_MAX)
    ϕ_vals = demand_pdf(p, d_vals)
    # Expected revenue and current profit
    revenue = jnp.sum(jnp.minimum(x, d_vals)*ϕ_vals)
    profit = revenue - c * a - κ * (a > 0)
    # Expectation over demand, for each z′
    v_R = jnp.sum(v[jnp.maximum(x - d_vals, 0) + a].T * ϕ_vals, axis=1)
    # Expectation over z′ given the current z
    cv = jnp.sum(v_R*Q[i_z])
    return profit + z * cv
We need to vectorize this function so that we can use it efficiently in JAX.
We apply a sequence of vmap operations to vectorize appropriately in each argument.
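For readers new to this pattern, here's a small self-contained sketch of nested `jax.vmap` calls. The function `f` is a made-up scalar function, used only to show how `in_axes` selects the argument being vectorized and how the order of the calls determines the axis order of the output.

```python
import jax
import jax.numpy as jnp

def f(x, a):
    """Scalar function of a state x and an action a (purely illustrative)."""
    return x - 0.5 * a**2

# Vectorize over a while holding x fixed, then vectorize the result over x
f_vec_a = jax.vmap(f, in_axes=(None, 0))
f_vec_x_a = jax.vmap(f_vec_a, in_axes=(0, None))

x_vals = jnp.arange(3.0)
a_vals = jnp.arange(4.0)
out = f_vec_x_a(x_vals, a_vals)    # out[i, j] = f(x_vals[i], a_vals[j])
```

The outermost `vmap` contributes the leading axis, so here states index rows and actions index columns.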
B_vec_a = jax.vmap(B, in_axes=(None, None, 0, None, None))

@jax.jit
def B2(x, i_z, v, model):
    """
    Evaluate B(x, z, a, v) at every candidate order a, assigning -inf
    to infeasible orders (those with x + a > K).
    """
    c, κ, p, z_vals, Q = model
    a_vals = jnp.arange(K + 1)   # candidate orders 0, ..., K
    res = B_vec_a(x, i_z, a_vals, v, model)
    return jnp.where(a_vals < K - x + 1, res, -jnp.inf)

B2_vec_z = jax.vmap(B2, in_axes=(None, 0, None, None))
B2_vec_z_x = jax.vmap(B2_vec_z, in_axes=(0, None, None, None))
Next we define the Bellman operator.
@jax.jit
def T(v, model):
    """The Bellman operator."""
    c, κ, p, z_vals, Q = model
    i_z_range = jnp.arange(len(z_vals))
    x_range = jnp.arange(K + 1)
    res = B2_vec_z_x(x_range, i_z_range, v, model)
    return jnp.max(res, axis=2)
The following function computes a v-greedy policy.
@jax.jit
def get_greedy(v, model):
    """Get a v-greedy policy.  Returns a zero-based array."""
    c, κ, p, z_vals, Q = model
    i_z_range = jnp.arange(len(z_vals))
    x_range = jnp.arange(K + 1)
    res = B2_vec_z_x(x_range, i_z_range, v, model)
    return jnp.argmax(res, axis=2)
Here’s code to solve the model using value function iteration.
def solve_inventory_model(v_init, model):
    """Use successive_approx to get v_star and then compute greedy."""
    v_star = successive_approx(lambda v: T(v, model), v_init, verbose=True)
    σ_star = get_greedy(v_star, model)
    return v_star, σ_star
Now let’s create an instance and solve it.
model = create_sdd_inventory_model()
c, κ, p, z_vals, Q = model
n_z = len(z_vals)
v_init = jnp.zeros((K + 1, n_z), dtype=float)
start = time()
v_star, σ_star = solve_inventory_model(v_init, model)
jax_time_with_compile = time() - start
print("Jax compile plus execution time = ", jax_time_with_compile)
Completed iteration 25 with error 0.5613828428334688.
Completed iteration 50 with error 0.37764643476880266.
Completed iteration 75 with error 0.2272706235969011.
Completed iteration 100 with error 0.12872204940709508.
Completed iteration 125 with error 0.06744149371262154.
Completed iteration 150 with error 0.03037463954767361.
Completed iteration 175 with error 0.01423099032950148.
Completed iteration 200 with error 0.007396776219316337.
Completed iteration 225 with error 0.0039122383045793185.
Completed iteration 250 with error 0.002068091416653317.
Completed iteration 275 with error 0.001092307533355097.
Completed iteration 300 with error 0.0005766427105911021.
Completed iteration 325 with error 0.00030433217072101115.
Completed iteration 350 with error 0.00016059073674767887.
Completed iteration 375 with error 8.473334524694565e-05.
Completed iteration 400 with error 4.4706045166265085e-05.
Completed iteration 425 with error 2.3586619946058818e-05.
Completed iteration 450 with error 1.2443945934137446e-05.
Completed iteration 475 with error 6.5651783245357365e-06.
Completed iteration 500 with error 3.463639430378862e-06.
Completed iteration 525 with error 1.827332347659194e-06.
Terminated successfully in 550 iterations.
Jax compile plus execution time = 5.9160120487213135
Let’s run again to get rid of the compile time.
start = time()
v_star, σ_star = solve_inventory_model(v_init, model)
jax_time_without_compile = time() - start
print("Jax execution time = ", jax_time_without_compile)
Completed iteration 25 with error 0.5613828428334688.
Completed iteration 50 with error 0.37764643476880266.
Completed iteration 75 with error 0.2272706235969011.
Completed iteration 100 with error 0.12872204940709508.
Completed iteration 125 with error 0.06744149371262154.
Completed iteration 150 with error 0.03037463954767361.
Completed iteration 175 with error 0.01423099032950148.
Completed iteration 200 with error 0.007396776219316337.
Completed iteration 225 with error 0.0039122383045793185.
Completed iteration 250 with error 0.002068091416653317.
Completed iteration 275 with error 0.001092307533355097.
Completed iteration 300 with error 0.0005766427105911021.
Completed iteration 325 with error 0.00030433217072101115.
Completed iteration 350 with error 0.00016059073674767887.
Completed iteration 375 with error 8.473334524694565e-05.
Completed iteration 400 with error 4.4706045166265085e-05.
Completed iteration 425 with error 2.3586619946058818e-05.
Completed iteration 450 with error 1.2443945934137446e-05.
Completed iteration 475 with error 6.5651783245357365e-06.
Completed iteration 500 with error 3.463639430378862e-06.
Completed iteration 525 with error 1.827332347659194e-06.
Terminated successfully in 550 iterations.
Jax execution time = 5.303308963775635
z_mc = qe.MarkovChain(Q, z_vals)
def sim_inventories(ts_length, X_init=0):
    """Simulate given the optimal policy."""
    global p, z_mc
    i_z = z_mc.simulate_indices(ts_length, init=1)
    X = np.zeros(ts_length, dtype=np.int32)
    X[0] = X_init
    # NumPy's geometric draws start at 1, so subtract 1 to get support {0, 1, ...}
    rand = np.random.default_rng().geometric(p=p, size=ts_length-1) - 1
    for t in range(ts_length-1):
        X[t+1] = np.maximum(X[t] - rand[t], 0) + σ_star[X[t], i_z[t]]
    return X, z_vals[i_z]
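The simulation draws demand with NumPy's geometric sampler and subtracts one. Here's a quick check, with an arbitrary seed and sample size, that the shifted draws are consistent with the distribution \((1-p)^d p\) on \(\{0, 1, \ldots\}\) used in `demand_pdf`.

```python
import numpy as np

p = 0.6
rng = np.random.default_rng(0)

# NumPy's geometric distribution has support {1, 2, ...}; shift it to {0, 1, ...}
draws = rng.geometric(p=p, size=100_000) - 1

share_zero = np.mean(draws == 0)     # should be close to P{D = 0} = p
sample_mean = draws.mean()           # should be close to (1 - p) / p
```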
def plot_ts(ts_length=400, fontsize=10):
    X, Z = sim_inventories(ts_length)
    fig, axes = plt.subplots(2, 1, figsize=(9, 5.5))

    ax = axes[0]
    ax.plot(X, label=r"$X_t$", alpha=0.7)
    ax.set_xlabel(r"$t$", fontsize=fontsize)
    ax.set_ylabel("inventory", fontsize=fontsize)
    ax.legend(fontsize=fontsize, frameon=False)
    ax.set_ylim(0, np.max(X)+3)

    # calculate interest rate from discount factors
    r = (1 / Z) - 1

    ax = axes[1]
    ax.plot(r, label=r"$r_t$", alpha=0.7)
    ax.set_xlabel(r"$t$", fontsize=fontsize)
    ax.set_ylabel("interest rate", fontsize=fontsize)
    ax.legend(fontsize=fontsize, frameon=False)

    plt.tight_layout()
    plt.show()
15.3. Numba implementation
Let’s try the same operations in Numba in order to compare the speed.
@numba.njit
def demand_pdf_numba(p, d):
    return (1 - p)**d * p

@numba.njit
def B_numba(x, i_z, a, v, model):
    """
    The function B(x, z, a, v) = r(x, a) + β(z) Σ_x′ v(x′) P(x, a, x′).
    """
    c, κ, p, z_vals, Q = model
    z = z_vals[i_z]
    d_vals = np.arange(D_MAX)
    ϕ_vals = demand_pdf_numba(p, d_vals)
    revenue = np.sum(np.minimum(x, d_vals)*ϕ_vals)
    profit = revenue - c * a - κ * (a > 0)
    v_R = np.sum(v[np.maximum(x - d_vals, 0) + a].T * ϕ_vals, axis=1)
    cv = np.sum(v_R*Q[i_z])
    return profit + z * cv

@numba.njit(parallel=True)
def T_numba(v, model):
    """The Bellman operator."""
    c, κ, p, z_vals, Q = model
    new_v = np.empty_like(v)
    for i_z in prange(len(z_vals)):
        for x in prange(K+1):
            v_1 = np.array([B_numba(x, i_z, a, v, model)
                            for a in range(K-x+1)])
            new_v[x, i_z] = np.max(v_1)
    return new_v

@numba.njit(parallel=True)
def get_greedy_numba(v, model):
    """Get a v-greedy policy.  Returns a zero-based array."""
    c, κ, p, z_vals, Q = model
    n_z = len(z_vals)
    σ_star = np.zeros((K+1, n_z), dtype=np.int32)
    for i_z in prange(n_z):
        for x in range(K+1):
            v_1 = np.array([B_numba(x, i_z, a, v, model)
                            for a in range(K-x+1)])
            σ_star[x, i_z] = np.argmax(v_1)
    return σ_star

def solve_inventory_model_numba(v_init, model):
    """Use successive_approx to get v_star and then compute greedy."""
    v_star = successive_approx(lambda v: T_numba(v, model), v_init, verbose=True)
    σ_star = get_greedy_numba(v_star, model)
    return v_star, σ_star
model = create_sdd_inventory_model(use_jax=False)
c, κ, p, z_vals, Q = model
n_z = len(z_vals)
v_init = np.zeros((K + 1, n_z), dtype=float)
start = time()
v_star_numba, σ_star_numba = solve_inventory_model_numba(v_init, model)
numba_time_with_compile = time() - start
print("Numba compile plus execution time = ", numba_time_with_compile)
Completed iteration 25 with error 0.5613828428334706.
Completed iteration 50 with error 0.37764643476879556.
Completed iteration 75 with error 0.22727062359689398.
Completed iteration 100 with error 0.12872204940708798.
Completed iteration 125 with error 0.06744149371262864.
Completed iteration 150 with error 0.030374639547666504.
Completed iteration 175 with error 0.01423099032948727.
Completed iteration 200 with error 0.007396776219316337.
Completed iteration 225 with error 0.003912238304593529.
Completed iteration 250 with error 0.002068091416653317.
Completed iteration 275 with error 0.0010923075333622023.
Completed iteration 300 with error 0.0005766427105911021.
Completed iteration 325 with error 0.0003043321707281166.
Completed iteration 350 with error 0.00016059073676188973.
Completed iteration 375 with error 8.473334525405107e-05.
Completed iteration 400 with error 4.470604518047594e-05.
Completed iteration 425 with error 2.3586619960269672e-05.
Completed iteration 450 with error 1.2443945934137446e-05.
Completed iteration 475 with error 6.565178331641164e-06.
Completed iteration 500 with error 3.4636394445897167e-06.
Completed iteration 525 with error 1.827332347659194e-06.
Terminated successfully in 550 iterations.
Numba compile plus execution time = 1015.8022680282593
Let’s run again to eliminate the compile time.
start = time()
v_star_numba, σ_star_numba = solve_inventory_model_numba(v_init, model)
numba_time_without_compile = time() - start
print("Numba execution time = ", numba_time_without_compile)
Completed iteration 25 with error 0.5613828428334706.
Completed iteration 50 with error 0.37764643476879556.
Completed iteration 75 with error 0.22727062359689398.
Completed iteration 100 with error 0.12872204940708798.
Completed iteration 125 with error 0.06744149371262864.
Completed iteration 150 with error 0.030374639547666504.
Completed iteration 175 with error 0.01423099032948727.
Completed iteration 200 with error 0.007396776219316337.
Completed iteration 225 with error 0.003912238304593529.
Completed iteration 250 with error 0.002068091416653317.
Completed iteration 275 with error 0.0010923075333622023.
Completed iteration 300 with error 0.0005766427105911021.
Completed iteration 325 with error 0.0003043321707281166.
Completed iteration 350 with error 0.00016059073676188973.
Completed iteration 375 with error 8.473334525405107e-05.
Completed iteration 400 with error 4.470604518047594e-05.
Completed iteration 425 with error 2.3586619960269672e-05.
Completed iteration 450 with error 1.2443945934137446e-05.
Completed iteration 475 with error 6.565178331641164e-06.
Completed iteration 500 with error 3.4636394445897167e-06.
Completed iteration 525 with error 1.827332347659194e-06.
Terminated successfully in 550 iterations.
Numba execution time = 965.9102244377136
Let’s verify that the Numba and JAX implementations converge to the same solution.
np.allclose(v_star_numba, v_star)
True
Here’s the speed comparison.
print("JAX vectorized implementation is "
      f"{numba_time_without_compile/jax_time_without_compile} times faster "
      "than Numba's parallel implementation")
JAX vectorized implementation is 182.13350024209112 times faster than Numba's parallel implementation