5. Inventory Dynamics#

GPU

This lecture was built using a machine with JAX installed and access to a GPU.

To run this lecture on Google Colab, click on the “play” icon at the top right, select Colab, and set the runtime environment to include a GPU.

To run this lecture on your own machine, you need to install Google JAX.

5.1. Overview#

This lecture explores the inventory dynamics of a firm using so-called \(s\)-\(S\) inventory control.

Loosely speaking, this means that the firm

  • waits until inventory falls below some value \(s\)

  • and then restocks with a bulk order of \(S\) units (or, in some models, restocks up to level \(S\)).

We will be interested in the marginal distributions of the associated Markov process, which can be thought of as cross-sectional distributions of inventory levels across a large number of firms, all of which

  1. evolve independently and

  2. have the same dynamics.

Note that we also studied this model in a separate lecture, using Numba.

Here we study the same problem using JAX.

We will use the following imports:

import matplotlib.pyplot as plt
import numpy as np
import jax
import jax.numpy as jnp
from jax import random, lax
from collections import namedtuple

Here’s a description of our GPU:

!nvidia-smi
Tue Nov 19 23:57:31 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.54.03              Driver Version: 535.54.03    CUDA Version: 12.6     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  Tesla T4                       On  | 00000001:00:00.0 Off |                    0 |
| N/A   43C    P0              28W /  70W |      2MiB / 15360MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|  No running processes found                                                           |
+---------------------------------------------------------------------------------------+

5.2. Sample paths#

Consider a firm with inventory \(X_t\).

The firm waits until \(X_t \leq s\) and then restocks up to \(S\) units.

It faces stochastic demand \(\{ D_t \}\), which we assume is IID across time and firms.

With notation \(a^+ := \max\{a, 0\}\), inventory dynamics can be written as

\[\begin{split} X_{t+1} = \begin{cases} ( S - D_{t+1})^+ & \quad \text{if } X_t \leq s \\ ( X_t - D_{t+1} )^+ & \quad \text{if } X_t > s \end{cases} \end{split}\]

In what follows, we will assume that each \(D_t\) is lognormal, so that

\[ D_t = \exp(\mu + \sigma Z_t) \]

where \(\mu\) and \(\sigma\) are parameters and \(\{Z_t\}\) is IID and standard normal.

Here’s a namedtuple that stores parameters.

Parameters = namedtuple('Parameters', ['s', 'S', 'μ', 'σ'])

# Create a default instance
params = Parameters(s=10, S=100, μ=1.0, σ=0.5)
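
To fix ideas, here is a minimal sketch of how a single sample path can be simulated under these dynamics. (The helper name sim_inventory_path and its defaults, a 100-period path starting from a stock of 50, are illustrative choices rather than part of the model.)

def sim_inventory_path(params, key, x_init=50.0, sim_length=100):
    """
    Simulate a single inventory path X_1, ..., X_{sim_length},
    starting from X_0 = x_init, with one demand draw per period.
    """
    s, S, μ, σ = params.s, params.S, params.μ, params.σ
    Z = random.normal(key, shape=(sim_length,))
    D = jnp.exp(μ + σ * Z)                      # lognormal demand draws

    def update(x, d):
        # Restock up to S at or below the threshold, otherwise run down stock
        x_next = jnp.where(x <= s,
                           jnp.maximum(S - d, 0),
                           jnp.maximum(x - d, 0))
        return x_next, x_next                   # (new state, recorded value)

    x0 = jnp.asarray(x_init, dtype=D.dtype)     # keep the loop state dtype stable
    _, path = lax.scan(update, x0, D)
    return path

Plotting sim_inventory_path(params, random.PRNGKey(1)) against time shows the characteristic sawtooth pattern of \(s\)-\(S\) behavior: inventory drifts down with demand and jumps back up to around \(S\) whenever it falls to \(s\) or below.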

5.3. Cross-sectional distributions#

Now let’s look at the marginal distribution \(\psi_T\) of \(X_T\) for some fixed \(T\).

The probability distribution \(\psi_T\) is the time \(T\) distribution of firm inventory levels implied by the model.

We will approximate this distribution by

  1. fixing \(n\) to be some large number, indicating the number of firms in the simulation,

  2. fixing \(T\), the time period we are interested in,

  3. generating \(n\) independent draws from some fixed distribution \(\psi_0\) that gives the initial cross-section of inventories for the \(n\) firms, and

  4. shifting this distribution forward in time \(T\) periods, updating each firm \(T\) times via the dynamics described above (independently of the other firms).

We will then visualize \(\psi_T\) by histogramming the cross-section.

We will use the following code to update the cross-section of firms by one period.

@jax.jit
def update_cross_section(params, X_vec, D):
    """
    Update by one period a cross-section of firms with inventory levels given by
    X_vec, given the vector of demand shocks in D.

       * D[i] is the demand shock for firm i with current inventory X_vec[i]

    """
    # Unpack
    s, S = params.s, params.S
    # Restock if the inventory is below the threshold
    X_new = jnp.where(X_vec <= s, 
                      jnp.maximum(S - D, 0), jnp.maximum(X_vec - D, 0))
    return X_new
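
As a quick sanity check, we can apply this update to a tiny, hand-made cross-section (the numbers below are arbitrary illustrative values, with s = 10 and S = 100 as in params).

X_small = jnp.array([5.0, 60.0, 120.0])   # only the first firm is at or below s
D_small = jnp.array([20.0, 20.0, 20.0])   # identical demand shocks
update_cross_section(params, X_small, D_small)
# The first firm restocks to S - 20 = 80; the other two simply lose 20 units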

5.3.1. For loop version#

Now we provide code to compute the cross-sectional distribution \(\psi_T\) given some initial distribution \(\psi_0\) and a positive integer \(T\).

In this code we use an ordinary Python for loop to step forward through time.

While Python loops are slow, this approach is reasonable here because efficiency of outer loops has far less influence on runtime than efficiency of inner loops.

(Below we will squeeze out more speed by compiling the outer loop as well as the update rule.)

In the code below, the initial distribution \(\psi_0\) assigns every firm the same initial inventory level x_init.

def compute_cross_section(params, x_init, T, key, num_firms=50_000):
    # Set up initial distribution
    X_vec = jnp.full((num_firms, ), x_init)
    # Loop
    for i in range(T):
        Z = random.normal(key, shape=(num_firms, ))
        D = jnp.exp(params.μ + params.σ * Z)

        X_vec = update_cross_section(params, X_vec, D)
        _, key = random.split(key)

    return X_vec

We’ll use the following specification:

x_init = 50
T = 500
# Initialize random number generator
key = random.PRNGKey(10)

Let’s look at the timing.

%time X_vec = compute_cross_section(params, \
        x_init, T, key).block_until_ready()
CPU times: user 896 ms, sys: 172 ms, total: 1.07 s
Wall time: 816 ms

Let’s run again to eliminate compile time.

%time X_vec = compute_cross_section(params, \
        x_init, T, key).block_until_ready()
CPU times: user 407 ms, sys: 167 ms, total: 574 ms
Wall time: 285 ms

Here’s a histogram of inventory levels at time \(T\).

fig, ax = plt.subplots()
ax.hist(X_vec, bins=50, 
        density=True, 
        histtype='step', 
        label=f'cross-section when $t = {T}$')
ax.set_xlabel('inventory')
ax.set_ylabel('probability')
ax.legend()
plt.show()

5.3.2. Compiling the outer loop#

Now let’s see if we can gain some speed by compiling the outer loop, which steps through the time dimension.

We will do this using jax.jit and a fori_loop, which is a compiler-ready version of a for loop provided by JAX.
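
As a minimal standalone illustration of the interface, lax.fori_loop(lower, upper, body_fun, init_val) computes val = body_fun(i, val) for i = lower, ..., upper - 1 and returns the final val:

# Sum the integers 0, 1, ..., 9 with a compiled loop
total = lax.fori_loop(0, 10, lambda i, val: val + i, 0)
print(total)   # 45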

def compute_cross_section_fori(params, x_init, T, key, num_firms=50_000):

    s, S, μ, σ = params.s, params.S, params.μ, params.σ
    X = jnp.full((num_firms, ), x_init)

    # Define the function for each update
    def fori_update(t, inputs):
        # Unpack
        X, key = inputs
        # Draw shocks using key
        Z = random.normal(key, shape=(num_firms,))
        D = jnp.exp(μ + σ * Z)
        # Update X
        X = jnp.where(X <= s,
                  jnp.maximum(S - D, 0),
                  jnp.maximum(X - D, 0))
        # Refresh the key
        key, subkey = random.split(key)
        return X, subkey

    # Loop t from 0 to T, applying fori_update each time.
    # The initial condition for fori_update is (X, key).
    X, key = lax.fori_loop(0, T, fori_update, (X, key))

    return X

# Compile taking T and num_firms as static (changes trigger recompile)
compute_cross_section_fori = jax.jit(
    compute_cross_section_fori, static_argnums=(2, 4))

Let’s see how fast this runs with compile time included.

%time X_vec = compute_cross_section_fori(params, \
                x_init, T, key).block_until_ready()
CPU times: user 365 ms, sys: 16.7 ms, total: 382 ms
Wall time: 333 ms

And let’s see how fast it runs without compile time.

%time X_vec = compute_cross_section_fori(params, \
                x_init, T, key).block_until_ready()
CPU times: user 8.54 ms, sys: 135 μs, total: 8.68 ms
Wall time: 10.7 ms

Compared to the original version with a pure Python outer loop, we have produced a nontrivial speed gain.

This is due to the fact that we have compiled the whole operation.
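
For comparison, the same computation can also be expressed with lax.scan, another compiled loop construct provided by JAX. The sketch below (the name compute_cross_section_scan is purely illustrative) threads the cross-section and the PRNG key through the loop in the same way, so its performance should be broadly similar.

def compute_cross_section_scan(params, x_init, T, key, num_firms=50_000):

    s, S, μ, σ = params.s, params.S, params.μ, params.σ
    # Start from a float array so the loop state keeps a stable dtype
    X0 = jnp.full((num_firms, ), x_init, dtype=jnp.float32)

    def scan_update(carry, _):
        X, key = carry
        key, subkey = random.split(key)
        Z = random.normal(subkey, shape=(num_firms, ))
        D = jnp.exp(μ + σ * Z)
        X = jnp.where(X <= s,
                      jnp.maximum(S - D, 0),
                      jnp.maximum(X - D, 0))
        return (X, key), None

    (X, _), _ = lax.scan(scan_update, (X0, key), xs=None, length=T)
    return X

# As before, T and num_firms are treated as static arguments
compute_cross_section_scan = jax.jit(
    compute_cross_section_scan, static_argnums=(2, 4))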

5.3.3. Further vectorization#

For relatively small problems, we can make this code run even faster by generating all random variables at once.

This improves efficiency because we are taking more operations out of the loop.

def compute_cross_section_fori(params, x_init, T, key, num_firms=50_000):

    s, S, μ, σ = params.s, params.S, params.μ, params.σ
    X = jnp.full((num_firms, ), x_init)
    Z = random.normal(key, shape=(T, num_firms))
    D = jnp.exp(μ + σ * Z)

    def update_cross_section(i, X):
        X = jnp.where(X <= s,
                  jnp.maximum(S - D[i, :], 0),
                  jnp.maximum(X - D[i, :], 0))
        return X

    X = lax.fori_loop(0, T, update_cross_section, X)

    return X

# Compile taking T and num_firms as static (changes trigger recompile)
compute_cross_section_fori = jax.jit(
    compute_cross_section_fori, static_argnums=(2, 4))

Let’s test it with compile time included.

%time X_vec = compute_cross_section_fori(params, \
                x_init, T, key).block_until_ready()
CPU times: user 242 ms, sys: 7.81 ms, total: 250 ms
Wall time: 240 ms

Let’s run again to eliminate compile time.

%time X_vec = compute_cross_section_fori(params, \
                x_init, T, key).block_until_ready()
CPU times: user 4.06 ms, sys: 2 μs, total: 4.06 ms
Wall time: 6.63 ms

On the one hand, this version is faster than the previous one, in which the random variables were generated inside the loop.

On the other hand, this implementation consumes far more memory, as we need to store large arrays of random draws.

The high memory consumption becomes problematic for large problems.
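
A back-of-the-envelope calculation makes the point: with float32 draws (4 bytes each), the (T, num_firms) array of demand shocks is already sizable at our specification, and the cost grows linearly in both \(T\) and the number of firms.

# Rough memory footprint of pre-generating a (T, num_firms) array of float32 draws
print(f"{T * 50_000 * 4 / 1e6:.0f} MB for T = {T} and 50,000 firms")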

5.4. Distribution dynamics#

Next let’s take a look at how the distribution sequence evolves over time.

We will go back to using ordinary Python for loops.

Here is code that repeatedly shifts the cross-section forward while recording the cross-section at the dates in sample_dates.

def shift_forward_and_sample(x_init, params, sample_dates,
                        key, num_firms=50_000, sim_length=750):

    # X holds the current cross-section; sampled cross-sections accumulate
    # as rows of res (the first row is a placeholder, dropped on return)
    X = res = jnp.full((num_firms, ), x_init)

    # Use for loop to update X and collect samples
    for i in range(sim_length):
        Z = random.normal(key, shape=(num_firms, ))
        D = jnp.exp(params.μ + params.σ * Z)

        X = update_cross_section(params, X, D)
        _, key = random.split(key)

        # Record the cross-section at the sample dates
        if i + 1 in sample_dates:
            res = jnp.vstack((res, X))

    return res[1:]

Let’s test it.

x_init = 50
num_firms = 10_000
sample_dates = 10, 50, 250, 500, 750
key = random.PRNGKey(10)


%time X = shift_forward_and_sample(x_init, params, \
                    sample_dates, key).block_until_ready()
CPU times: user 817 ms, sys: 224 ms, total: 1.04 s
Wall time: 672 ms

We run the code again to eliminate compile time.

%time X = shift_forward_and_sample(x_init, params, \
                    sample_dates, key).block_until_ready()
CPU times: user 625 ms, sys: 241 ms, total: 866 ms
Wall time: 431 ms

Let’s plot the output.

fig, ax = plt.subplots()

for i, date in enumerate(sample_dates):
    ax.hist(X[i, :], bins=50, 
            density=True, 
            histtype='step',
            label=f'cross-section when $t = {date}$')

ax.set_xlabel('inventory')
ax.set_ylabel('probability')
ax.legend()
plt.show()

This model for inventory dynamics is asymptotically stationary, with a unique stationary distribution.

In particular, the sequence of marginal distributions \(\{\psi_t\}\) converges to a unique limiting distribution that does not depend on initial conditions.

Although we will not prove this here, we can see it in the simulation above.

By \(t=500\) or \(t=750\) the distributions are barely changing.

If you test a few different initial conditions, you will see that they do not affect long-run outcomes.
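
For instance, one informal check is to generate two time-750 cross-sections from very different starting stocks and compare summary statistics (the keys and starting values below are arbitrary).

key_1, key_2 = random.PRNGKey(98), random.PRNGKey(99)
X_lo = compute_cross_section_fori(params, 1, 750, key_1)     # start with almost no stock
X_hi = compute_cross_section_fori(params, 250, 750, key_2)   # start well above S
print(jnp.mean(X_lo), jnp.mean(X_hi))    # the sample means should be close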

5.5. Restock frequency#

As an exercise, let’s study the probability that firms need to restock over a given time period.

In the exercise, we will

  • set the starting stock level to \(X_0 = 70\) and

  • calculate the proportion of firms that need to order twice or more in the first 50 periods.

This proportion approximates the probability of the event when the sample size is large.

5.5.1. For loop version#

We start with a straightforward for loop implementation.

# Define a jitted function for each update
@jax.jit
def update_stock(n_restock, X, params, D):
    # Count a restock for firms at or below the threshold
    n_restock = jnp.where(X <= params.s,
                          n_restock + 1,
                          n_restock)
    # Apply the s-S update rule
    X = jnp.where(X <= params.s,
                  jnp.maximum(params.S - D, 0),
                  jnp.maximum(X - D, 0))
    return n_restock, X

def compute_freq(params, key,
                 x_init=70,
                 sim_length=50,
                 num_firms=1_000_000):

    # Prepare initial arrays
    X = jnp.full((num_firms, ), x_init)

    # Restock counter for each firm
    n_restock = jnp.zeros((num_firms, ))

    # Use a for loop to perform the calculations on all states
    for i in range(sim_length):
        Z = random.normal(key, shape=(num_firms, ))
        D = jnp.exp(params.μ + params.σ * Z)
        n_restock, X = update_stock(n_restock, X, params, D)
        key = random.fold_in(key, i)

    return jnp.mean(n_restock > 1, axis=0)

key = random.PRNGKey(27)

%time freq = compute_freq(params, key).block_until_ready()
CPU times: user 595 ms, sys: 42.1 ms, total: 637 ms
Wall time: 690 ms

We run the code again to get rid of compile time.

%time freq = compute_freq(params, key).block_until_ready()
CPU times: user 53.8 ms, sys: 25.6 ms, total: 79.4 ms
Wall time: 47.6 ms

print(f"Frequency of at least two restocks = {freq}")
Frequency of at least two restocks = 0.4472379982471466

Exercise 5.1

Write a fori_loop version of the last function. See if you can increase the speed while generating a similar answer.