5. Inventory Dynamics#

GPU

This lecture was built using hardware that has access to a GPU.

To run this lecture on Google Colab, click on the “play” icon top right, select Colab, and set the runtime environment to include a GPU.

To run this lecture on your own machine, you need to install Google JAX.

5.1. Overview#

This lecture explores JAX implementations of the exercises in the lecture on inventory dynamics.

We will use the following imports:

import matplotlib.pyplot as plt
import numpy as np
import jax
import jax.numpy as jnp
from jax import random, lax
from collections import namedtuple

Let’s check the GPU we are running

!nvidia-smi
Fri Sep 22 00:28:03 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.182.03   Driver Version: 470.182.03   CUDA Version: 12.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla V100-SXM2...  Off  | 00000000:00:1E.0 Off |                    0 |
| N/A   29C    P0    38W / 300W |      0MiB / 16160MiB |      2%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+

5.2. Sample paths#

Consider a firm with inventory \(X_t\).

The firm waits until \(X_t \leq s\) and then restocks up to \(S\) units.

It faces stochastic demand \(\{ D_t \}\), which we assume is IID.

With notation \(a^+ := \max\{a, 0\}\), inventory dynamics can be written as

\[\begin{split} X_{t+1} = \begin{cases} ( S - D_{t+1})^+ & \quad \text{if } X_t \leq s \\ ( X_t - D_{t+1} )^+ & \quad \text{if } X_t > s \end{cases} \end{split}\]

(See our earlier lecture on inventory dynamics for background and motivation.)

In what follows, we will assume that each \(D_t\) is lognormal, so that

\[ D_t = \exp(\mu + \sigma Z_t) \]

where \(\mu\) and \(\sigma\) are parameters and \(\{Z_t\}\) is IID and standard normal.

Here’s a namedtuple that stores parameters.

Firm = namedtuple('Firm', ['s', 'S', 'mu', 'sigma'])

firm = Firm(s=10, S=100, mu=1.0, sigma=0.5)

5.3. Example 1: marginal distributions#

Now let’s look at the marginal distribution \(\psi_T\) of \(X_T\) for some fixed \(T\).

We can approximate the distribution using a kernel density estimator.

Kernel density estimators can be thought of as smoothed histograms.

We will use a kernel density estimator from scikit-learn.

Here is an example of using kernel density estimators and plotting the result

from sklearn.neighbors import KernelDensity

def plot_kde(sample, ax, label=''):
    xmin, xmax = 0.9 * min(sample), 1.1 * max(sample)
    xgrid = np.linspace(xmin, xmax, 200)
    kde = KernelDensity(kernel='gaussian').fit(sample[:, None])
    log_dens = kde.score_samples(xgrid[:, None])

    ax.plot(xgrid, np.exp(log_dens), label=label)

# Generate simulated data
np.random.seed(42)
sample_1 = np.random.normal(0, 2, size=10_000)
sample_2 = np.random.gamma(2, 2, size=10_000)

# Create a plot
fig, ax = plt.subplots()

# Plot the samples
ax.hist(sample_1, alpha=0.2, density=True, bins=50)
ax.hist(sample_2, alpha=0.2, density=True, bins=50)

# Plot the KDE for each sample
plot_kde(sample_1, ax, label=r'KDE over $X \sim N(0, 2)$')
plot_kde(sample_2, ax, label=r'KDE over $X \sim Gamma(0, 2)$')
ax.set_xlabel('value')
ax.set_ylabel('density')
ax.set_xlim([-5, 10])
ax.set_title('KDE of Simulated Normal and Gamma Data')
ax.legend()
plt.show()
_images/2dff78f9d10198effda3d05de17138875f57cdc5e86a1611cd1372c678718ecc.png

This model for inventory dynamics is asymptotically stationary, with a unique stationary distribution.

In particular, the sequence of marginal distributions \(\{\psi_t\}\) converges to a unique limiting distribution that does not depend on initial conditions.

Although we will not prove this here, we can investigate it using simulation.

We can generate and plot the sequence \(\{\psi_t\}\) at times \(t = 10, 50, 250, 500, 750\) based on the kernel density estimator.

We will see convergence, in the sense that differences between successive distributions are getting smaller.

Here is one realization of the process in JAX using for loop

# Define a jit-compiled function to update X and key
@jax.jit
def update_X(X, firm, D):
    # Restock if the inventory is below the threshold
    res = jnp.where(X <= firm.s,
            jnp.maximum(firm.S - D, 0),
            jnp.maximum(X - D, 0))
    return res


def shift_firms_forward(x_init, firm, sample_dates,
                        key, num_firms=50_000, sim_length=750):

    X = res = jnp.full((num_firms, ), x_init)

    # Use for loop to update X and collect samples
    for i in range(sim_length):
        Z = random.normal(key, shape=(num_firms, ))
        D = jnp.exp(firm.mu + firm.sigma * Z)

        X = update_X(X, firm, D)
        _, key = random.split(key)

        # draw a sample at the sample dates
        if (i+1 in sample_dates):
          res = jnp.vstack((res, X))

    return res[1:]
x_init = 50
num_firms = 50_000
sample_dates = 10, 50, 250, 500, 750
key = random.PRNGKey(10)

fig, ax = plt.subplots()

%time X = shift_firms_forward(x_init, firm, \
                              sample_dates, key).block_until_ready()

for i, date in enumerate(sample_dates):
    plot_kde(X[i, :], ax, label=f't = {date}')

ax.set_xlabel('inventory')
ax.set_ylabel('probability')
ax.legend()
plt.show()
CPU times: user 2.8 s, sys: 492 ms, total: 3.3 s
Wall time: 2.91 s
_images/741901a18debea0ae9eb6721eacc097a246709cd88dc3db718970f7a2d18b1dd.png

Note that we did not JIT-compile the outer loop, since

  1. jit compilation of the for loop can be very time consuming and

  2. compiling outer loops only leads to minor speed gains.

5.3.1. Alternative implementation with lax.scan#

An alternative to the for loop implementation is lax.scan.

Here is an example of the same function in lax.scan

@jax.jit
def shift_firms_forward(x_init, firm, key,
                        num_firms=50_000, sim_length=750):

    s, S, mu, sigma = firm.s, firm.S, firm.mu, firm.sigma
    X = jnp.full((num_firms, ), x_init)
    Z = random.normal(key, shape=(sim_length, num_firms))
    D = jnp.exp(mu + sigma * Z)

    # Define the function for each update
    def update_X(X, D):
        res = jnp.where(X <= s,
                  jnp.maximum(S - D, 0),
                  jnp.maximum(X - D, 0))
        return res, res

    # Use lax.scan to perform the calculations on all states
    _, X_final = lax.scan(update_X, X, D)

    return X_final

The benefit of the lax.scan implementation is that we compile the whole operation.

The disadvantages are that

  1. as mentioned above, there are only limited speed gains in accelerating outer loops,

  2. lax.scan has a more complicated syntax, and, most importantly,

  3. the lax.scan implementation consumes far more memory, as we need to have to store large matrices of random draws

Let’s call the code to generate a cross-section that is in approximate equilibrium.

fig, ax = plt.subplots()

%time X = shift_firms_forward(x_init, firm, key).block_until_ready()

for date in sample_dates:
    plot_kde(X[date, :], ax, label=f't = {date}')

ax.set_xlabel('inventory')
ax.set_ylabel('probability')
ax.legend()
plt.show()
CPU times: user 306 ms, sys: 0 ns, total: 306 ms
Wall time: 275 ms
_images/db5bcba37f118e95a7a6a19f8e7e4471f5d4d01c648f8f83a6842c416b6c6b0d.png

Notice that by \(t=500\) or \(t=750\) the densities are barely changing.

We have reached a reasonable approximation of the stationary density.

You can test a few more initial conditions to show that they do not affect long-run outcomes.

For example, try rerunning the code above with all firms starting at \(X_0 = 20\)

x_init = 20.0

fig, ax = plt.subplots()

%time X = shift_firms_forward(x_init, firm, key).block_until_ready()

for date in sample_dates:
    plot_kde(X[date, :], ax, label=f't = {date}')

ax.set_xlabel('inventory')
ax.set_ylabel('probability')
ax.legend()
plt.show()
CPU times: user 283 ms, sys: 0 ns, total: 283 ms
Wall time: 229 ms
_images/a89ccee605d93748ead554bd7ccb751bfce2ba1a82dd8fe4279d661333134230.png

5.4. Example 2: restock frequency#

Let’s go through another example where we calculate the probability of firms having restocks.

Specifically we set the starting stock level to 70 (\(X_0 = 70\)), as we calculate the proportion of firms that need to order twice or more in the first 50 periods.

You will need a large sample size to get an accurate reading.

Again, we start with an easier for loop implementation

# Define a jitted function for each update
@jax.jit
def update_stock(n_restock, X, firm, D):
    n_restock = jnp.where(X <= firm.s,
                          n_restock + 1,
                          n_restock)
    X = jnp.where(X <= firm.s,
                  jnp.maximum(firm.S - D, 0),
                  jnp.maximum(X - D, 0))
    return n_restock, X, key

def compute_freq(firm, key,
                 x_init=70,
                 sim_length=50,
                 num_firms=1_000_000):

    # Prepare initial arrays
    X = jnp.full((num_firms, ), x_init)

    # Stack the restock counter on top of the inventory
    n_restock = jnp.zeros((num_firms, ))

    # Use a for loop to perform the calculations on all states
    for i in range(sim_length):
        Z = random.normal(key, shape=(num_firms, ))
        D = jnp.exp(firm.mu + firm.sigma * Z)
        n_restock, X, key = update_stock(
            n_restock, X, firm, D)
        key = random.fold_in(key, i)

    return jnp.mean(n_restock > 1, axis=0)
key = random.PRNGKey(27)
%time freq = compute_freq(firm, key).block_until_ready()
print(f"Frequency of at least two stock outs = {freq}")
CPU times: user 741 ms, sys: 0 ns, total: 741 ms
Wall time: 1.1 s
Frequency of at least two stock outs = 0.4472379982471466

5.4.1. Alternative implementation with lax.fori_loop#

Now let’s write a lax.fori_loop version that JIT compiles the whole function

@jax.jit
def compute_freq(firm, key,
                 x_init=70,
                 sim_length=50,
                 num_firms=1_000_000):

    s, S, mu, sigma = firm.s, firm.S, firm.mu, firm.sigma
    # Prepare initial arrays
    X = jnp.full((num_firms, ), x_init)
    Z = random.normal(key, shape=(sim_length, num_firms))
    D = jnp.exp(mu + sigma * Z)

    # Stack the restock counter on top of the inventory
    restock_count = jnp.zeros((num_firms, ))
    Xs = (X, restock_count)

    # Define the function for each update
    def update_X(i, Xs):
        # Separate the inventory and restock counter
        x, restock_count = Xs[0], Xs[1]
        restock_count = jnp.where(x <= s,
                                restock_count + 1,
                                restock_count)
        x = jnp.where(x <= s,
                      jnp.maximum(S - D[i], 0),
                      jnp.maximum(x - D[i], 0))

        Xs = (x, restock_count)
        return Xs

    # Use lax.fori_loop to perform the calculations on all states
    X_final = lax.fori_loop(0, sim_length, update_X, Xs)

    return jnp.mean(X_final[1] > 1)

Note the time the routine takes to run, as well as the output

%time freq = compute_freq(firm, key).block_until_ready()
print(f"Frequency of at least two stock outs = {freq}")
CPU times: user 380 ms, sys: 0 ns, total: 380 ms
Wall time: 334 ms
Frequency of at least two stock outs = 0.44674399495124817