# 5. Inventory Dynamics

This lecture was built using hardware that has access to a GPU.

To run this lecture on Google Colab, click on the “play” icon top right, select Colab, and set the runtime environment to include a GPU.

To run this lecture on your own machine, you need to install Google JAX.

## 5.1. Overview

This lecture explores JAX implementations of the exercises in the lecture on inventory dynamics.

We will use the following imports:

```
import matplotlib.pyplot as plt
import numpy as np
import jax
import jax.numpy as jnp
from jax import random, lax
from collections import namedtuple
```

Let’s check which GPU we are running on

```
!nvidia-smi
```

```
Thu Dec 21 00:47:43 2023
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.182.03 Driver Version: 470.182.03 CUDA Version: 12.1 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|===============================+======================+======================|
| 0 Tesla V100-SXM2... Off | 00000000:00:1E.0 Off | 0 |
| N/A 31C P0 39W / 300W | 0MiB / 16160MiB | 2% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=============================================================================|
| No running processes found |
+-----------------------------------------------------------------------------+
```

## 5.2. Sample paths

Consider a firm with inventory \(X_t\).

The firm waits until \(X_t \leq s\) and then restocks up to \(S\) units.

It faces stochastic demand \(\{ D_t \}\), which we assume is IID.

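With notation \(a^+ := \max\{a, 0\}\), inventory dynamics can be written as

\[
X_{t+1} =
\begin{cases}
(S - D_{t+1})^+ & \text{if } X_t \leq s \\
(X_t - D_{t+1})^+ & \text{if } X_t > s
\end{cases}
\]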

(See our earlier lecture on inventory dynamics for background and motivation.)

In what follows, we will assume that each \(D_t\) is lognormal, so that

\[
D_t = \exp(\mu + \sigma Z_t)
\]

where \(\mu\) and \(\sigma\) are parameters and \(\{Z_t\}\) is IID and standard normal.

Here’s a `namedtuple` that stores parameters.

```
Firm = namedtuple('Firm', ['s', 'S', 'mu', 'sigma'])
firm = Firm(s=10, S=100, mu=1.0, sigma=0.5)
```
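To make the dynamics concrete, here is a minimal NumPy sketch that simulates one sample path for a single firm. The function name `sim_inventory_path`, the seed, and the path length are illustrative assumptions rather than part of the lecture's code; the parameter values match `firm` above.

```python
import numpy as np

def sim_inventory_path(x_init, s, S, mu, sigma, sim_length, seed=0):
    """Simulate one path of X_t under the (s, S) restocking rule."""
    rng = np.random.default_rng(seed)
    X = np.empty(sim_length + 1)
    X[0] = x_init
    for t in range(sim_length):
        # Lognormal demand: D = exp(mu + sigma * Z) with Z standard normal
        D = np.exp(mu + sigma * rng.standard_normal())
        if X[t] <= s:
            # Restock up to S, then meet demand
            X[t + 1] = max(S - D, 0.0)
        else:
            # No restock: inventory falls by demand, floored at zero
            X[t + 1] = max(X[t] - D, 0.0)
    return X

path = sim_inventory_path(x_init=50, s=10, S=100, mu=1.0, sigma=0.5,
                          sim_length=200)
```

Plotting `path` against time should show a sawtooth pattern: inventory drifts down as demand arrives and jumps back up toward \(S\) after it falls to \(s\) or below.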

## 5.3. Example 1: marginal distributions

Now let’s look at the marginal distribution \(\psi_T\) of \(X_T\) for some fixed \(T\).

We can approximate the distribution using a kernel density estimator.

Kernel density estimators can be thought of as smoothed histograms.

We will use a kernel density estimator from scikit-learn.

Here is an example of using kernel density estimators and plotting the result

```
from sklearn.neighbors import KernelDensity

def plot_kde(sample, ax, label=''):
    # Evaluate the fitted density on a grid built from the sample range
    xmin, xmax = 0.9 * min(sample), 1.1 * max(sample)
    xgrid = np.linspace(xmin, xmax, 200)
    kde = KernelDensity(kernel='gaussian').fit(sample[:, None])
    # score_samples returns the log density, so exponentiate before plotting
    log_dens = kde.score_samples(xgrid[:, None])
    ax.plot(xgrid, np.exp(log_dens), label=label)

# Generate simulated data
np.random.seed(42)
sample_1 = np.random.normal(0, 2, size=10_000)
sample_2 = np.random.gamma(2, 2, size=10_000)

# Create a plot
fig, ax = plt.subplots()

# Plot the samples
ax.hist(sample_1, alpha=0.2, density=True, bins=50)
ax.hist(sample_2, alpha=0.2, density=True, bins=50)

# Plot the KDE for each sample
plot_kde(sample_1, ax, label=r'KDE over $X \sim N(0, 2)$')
plot_kde(sample_2, ax, label=r'KDE over $X \sim Gamma(2, 2)$')
ax.set_xlabel('value')
ax.set_ylabel('density')
ax.set_xlim([-5, 10])
ax.set_title('KDE of Simulated Normal and Gamma Data')
ax.legend()
plt.show()
```
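To see why a kernel density estimator behaves like a smoothed histogram, it can help to build one by hand. The sketch below places a Gaussian bump on each observation and averages the bumps. The function name `gaussian_kde_by_hand` is hypothetical, and the bandwidth of 1.0 is an assumption chosen for illustration.

```python
import numpy as np

def gaussian_kde_by_hand(sample, xgrid, bandwidth=1.0):
    # One Gaussian bump per observation, centered at that observation
    u = (xgrid[:, None] - sample[None, :]) / bandwidth
    bumps = np.exp(-0.5 * u**2) / (bandwidth * np.sqrt(2 * np.pi))
    # The density estimate at each grid point is the average bump height
    return bumps.mean(axis=1)

rng = np.random.default_rng(0)
sample = rng.normal(0, 2, size=1_000)
xgrid = np.linspace(-12, 12, 400)
dens = gaussian_kde_by_hand(sample, xgrid)

# Riemann-sum check that the estimate integrates to roughly one
area = np.sum(dens) * (xgrid[1] - xgrid[0])
```

A smaller bandwidth gives a bumpier estimate that tracks the histogram closely, while a larger one smooths out more of the sampling noise.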

This model for inventory dynamics is asymptotically stationary, with a unique stationary distribution.

In particular, the sequence of marginal distributions \(\{\psi_t\}\) converges to a unique limiting distribution that does not depend on initial conditions.

Although we will not prove this here, we can investigate it using simulation.

We can generate and plot the sequence \(\{\psi_t\}\) at times \(t = 10, 50, 250, 500, 750\) based on the kernel density estimator.

We will see convergence, in the sense that differences between successive distributions are getting smaller.

Here is one implementation of the simulation in JAX using a `for` loop

```
# Define a jit-compiled function to update X
@jax.jit
def update_X(X, firm, D):
    # Restock if the inventory is at or below the threshold
    res = jnp.where(X <= firm.s,
                    jnp.maximum(firm.S - D, 0),
                    jnp.maximum(X - D, 0))
    return res

def shift_firms_forward(x_init, firm, sample_dates,
                        key, num_firms=50_000, sim_length=750):

    X = res = jnp.full((num_firms, ), x_init)

    # Use for loop to update X and collect samples
    for i in range(sim_length):
        Z = random.normal(key, shape=(num_firms, ))
        D = jnp.exp(firm.mu + firm.sigma * Z)

        X = update_X(X, firm, D)
        # Get a fresh key for the next period's draws
        _, key = random.split(key)

        # draw a sample at the sample dates
        if (i+1 in sample_dates):
            res = jnp.vstack((res, X))

    # Drop the first row, which holds the initial condition
    return res[1:]
```

```
x_init = 50
num_firms = 50_000
sample_dates = 10, 50, 250, 500, 750
key = random.PRNGKey(10)

fig, ax = plt.subplots()

%time X = shift_firms_forward(x_init, firm, \
                              sample_dates, key).block_until_ready()

# Row i of X holds the cross-section at the i-th sample date
for i, date in enumerate(sample_dates):
    plot_kde(X[i, :], ax, label=f't = {date}')

ax.set_xlabel('inventory')
ax.set_ylabel('density')
ax.legend()
plt.show()
```

Note that we did not JIT-compile the outer loop, since `jit` compilation of the `for` loop can be very time consuming and compiling outer loops only leads to minor speed gains.
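A related detail in the loop above is how random keys are handled. JAX random functions are deterministic given a key, so the loop has to produce a new key each period to obtain new demand shocks. The following small sketch (the variable names are illustrative) shows the behavior:

```python
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)

# Reusing a key reproduces exactly the same draws
a = random.normal(key, shape=(3,))
b = random.normal(key, shape=(3,))

# Splitting produces new keys, and a new key gives different draws
key, subkey = random.split(key)
c = random.normal(subkey, shape=(3,))

print(jnp.array_equal(a, b), jnp.array_equal(a, c))
```

Without the split inside the loop, every period would reuse the same draws, so each firm would face the same demand at every date.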

### 5.3.1. Alternative implementation with `lax.scan`

An alternative to the `for` loop implementation is `lax.scan`.

Here is the same simulation written with `lax.scan`; this version returns the state at every date rather than only at the sample dates

```
@jax.jit
def shift_firms_forward(x_init, firm, key,
                        num_firms=50_000, sim_length=750):

    s, S, mu, sigma = firm.s, firm.S, firm.mu, firm.sigma
    X = jnp.full((num_firms, ), x_init)
    # Draw all demand shocks up front: one row per period
    Z = random.normal(key, shape=(sim_length, num_firms))
    D = jnp.exp(mu + sigma * Z)

    # Define the function for each update
    def update_X(X, D):
        res = jnp.where(X <= s,
                        jnp.maximum(S - D, 0),
                        jnp.maximum(X - D, 0))
        # Return the new carry and the value to stack in the output
        return res, res

    # Use lax.scan to perform the calculations on all states
    _, X_final = lax.scan(update_X, X, D)

    # X_final has shape (sim_length, num_firms)
    return X_final
```

The benefit of the `lax.scan` implementation is that we compile the whole operation.

The disadvantages are that

- as mentioned above, there are only limited speed gains in accelerating outer loops,
- `lax.scan` has a more complicated syntax, and, most importantly,
- the `lax.scan` implementation consumes far more memory, as we need to store large matrices of random draws.
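To put the memory cost in perspective, with 750 periods and 50,000 firms each such matrix holds 37.5 million values, which is roughly 150 MB at 32-bit precision. One possible way around this, sketched below under the assumption that only the final cross-section is needed, is to carry the random key through the scan and draw shocks inside each step. The function name `final_cross_section` and the smaller default sizes are hypothetical and not part of the lecture's code.

```python
from functools import partial
import jax
import jax.numpy as jnp
from jax import random, lax

@partial(jax.jit, static_argnames=("num_firms", "sim_length"))
def final_cross_section(x_init, s, S, mu, sigma, key,
                        num_firms=1_000, sim_length=100):
    # Use the default float dtype so the carry keeps one dtype throughout
    X0 = jnp.full((num_firms,), x_init, dtype=float)

    def step(carry, _):
        X, key = carry
        # Draw this period's shocks from a fresh subkey
        key, subkey = random.split(key)
        Z = random.normal(subkey, shape=(num_firms,))
        D = jnp.exp(mu + sigma * Z)
        X = jnp.where(X <= s,
                      jnp.maximum(S - D, 0),
                      jnp.maximum(X - D, 0))
        # Nothing is stacked, so only one period of draws is live at a time
        return (X, key), None

    (X_final, _), _ = lax.scan(step, (X0, key), None, length=sim_length)
    return X_final

X_T = final_cross_section(50.0, 10.0, 100.0, 1.0, 0.5, random.PRNGKey(0))
```

The trade-off is that intermediate dates are no longer available, so this variant suits questions about a single date \(T\) rather than plots of several dates.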

Let’s call the code to generate a cross-section that is in approximate equilibrium.

```
fig, ax = plt.subplots()

%time X = shift_firms_forward(x_init, firm, key).block_until_ready()

for date in sample_dates:
    # Row i of X holds the state after i + 1 updates, so date t is row t - 1
    plot_kde(X[date - 1, :], ax, label=f't = {date}')

ax.set_xlabel('inventory')
ax.set_ylabel('density')
ax.legend()
plt.show()
```

Notice that by \(t=500\) or \(t=750\) the densities are barely changing.

We have reached a reasonable approximation of the stationary density.

You can test a few more initial conditions to show that they do not affect long-run outcomes.

For example, try rerunning the code above with all firms starting at \(X_0 = 20\)

```
x_init = 20.0

fig, ax = plt.subplots()

%time X = shift_firms_forward(x_init, firm, key).block_until_ready()

for date in sample_dates:
    # Row i of X holds the state after i + 1 updates, so date t is row t - 1
    plot_kde(X[date - 1, :], ax, label=f't = {date}')

ax.set_xlabel('inventory')
ax.set_ylabel('density')
ax.legend()
plt.show()
```

## 5.4. Example 2: restock frequency

Let’s go through another example, where we estimate how often firms need to restock.

Specifically, we set the starting stock level to 70 (\(X_0 = 70\)) and calculate the proportion of firms that need to order twice or more in the first 50 periods.

You will need a large sample size to get an accurate reading.

Again, we start with an easier `for` loop implementation

```
# Define a jitted function for each update
@jax.jit
def update_stock(n_restock, X, firm, D):
    # Count a restock for every firm at or below the threshold
    n_restock = jnp.where(X <= firm.s,
                          n_restock + 1,
                          n_restock)
    X = jnp.where(X <= firm.s,
                  jnp.maximum(firm.S - D, 0),
                  jnp.maximum(X - D, 0))
    return n_restock, X

def compute_freq(firm, key,
                 x_init=70,
                 sim_length=50,
                 num_firms=1_000_000):

    # Prepare initial arrays
    X = jnp.full((num_firms, ), x_init)

    # Track the number of restocks for each firm
    n_restock = jnp.zeros((num_firms, ))

    # Use a for loop to perform the calculations on all states
    for i in range(sim_length):
        Z = random.normal(key, shape=(num_firms, ))
        D = jnp.exp(firm.mu + firm.sigma * Z)
        n_restock, X = update_stock(n_restock, X, firm, D)
        # Derive a new key for the next period
        key = random.fold_in(key, i)

    # Share of firms that restocked at least twice
    return jnp.mean(n_restock > 1, axis=0)
```

```
key = random.PRNGKey(27)
%time freq = compute_freq(firm, key).block_until_ready()
print(f"Frequency of at least two stock outs = {freq}")
```

```
CPU times: user 836 ms, sys: 0 ns, total: 836 ms
Wall time: 1.21 s
Frequency of at least two stock outs = 0.4472379982471466
```

### 5.4.1. Alternative implementation with `lax.fori_loop`

Now let’s write a `lax.fori_loop` version that JIT compiles the whole function

```
@jax.jit
def compute_freq(firm, key,
                 x_init=70,
                 sim_length=50,
                 num_firms=1_000_000):

    s, S, mu, sigma = firm.s, firm.S, firm.mu, firm.sigma

    # Prepare initial arrays
    X = jnp.full((num_firms, ), x_init)
    Z = random.normal(key, shape=(sim_length, num_firms))
    D = jnp.exp(mu + sigma * Z)

    # Pair the inventory with a restock counter for each firm
    restock_count = jnp.zeros((num_firms, ))
    Xs = (X, restock_count)

    # Define the function for each update
    def update_X(i, Xs):
        # Separate the inventory and restock counter
        x, restock_count = Xs[0], Xs[1]
        restock_count = jnp.where(x <= s,
                                  restock_count + 1,
                                  restock_count)
        x = jnp.where(x <= s,
                      jnp.maximum(S - D[i], 0),
                      jnp.maximum(x - D[i], 0))
        Xs = (x, restock_count)
        return Xs

    # Use lax.fori_loop to perform the calculations on all states
    X_final = lax.fori_loop(0, sim_length, update_X, Xs)

    # Share of firms that restocked at least twice
    return jnp.mean(X_final[1] > 1)
```

Note the time the routine takes to run, as well as the output

```
%time freq = compute_freq(firm, key).block_until_ready()
print(f"Frequency of at least two stock outs = {freq}")
```

```
CPU times: user 351 ms, sys: 0 ns, total: 351 ms
Wall time: 316 ms
Frequency of at least two stock outs = 0.44674399495124817
```
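The two estimates reported above, roughly 0.4472 and 0.4467, differ slightly because the implementations consume random draws differently. As a rough check on whether a gap of this size is plausible, the sketch below computes the usual binomial standard error \(\sqrt{\hat p (1 - \hat p) / n}\) for a sample proportion. The function name `proportion_std_error` is hypothetical, and treating the firms as independent draws is an assumption.

```python
import math

def proportion_std_error(p_hat, n):
    # Standard error of a sample proportion from n independent draws
    return math.sqrt(p_hat * (1 - p_hat) / n)

n = 1_000_000
p_for, p_fori = 0.4472379982471466, 0.44674399495124817

se = proportion_std_error(p_for, n)
gap = abs(p_for - p_fori)
print(f"standard error = {se:.6f}, gap between estimates = {gap:.6f}")
```

With one million firms the standard error is on the order of 0.0005, so the two estimates are consistent with each other up to simulation noise. Because the error shrinks at rate \(1/\sqrt{n}\), using only 10,000 firms would give a standard error about ten times larger.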