24. Policy Gradient-Based Optimal Savings

GPU

This lecture was built using a machine with JAX installed and access to a GPU.

To run this lecture on Google Colab, click on the “play” icon top right, select Colab, and set the runtime environment to include a GPU.

To run this lecture on your own machine, you need to install Google JAX.

24.1. Introduction

In this lecture we solve infinite-horizon optimal savings problems using deep learning and policy gradient ascent with JAX.

Each policy is represented as a fully connected feedforward neural network.

We begin with a cake eating problem with a known analytical solution.

Then we shift to an income fluctuation problem, for which an optimal policy can be computed easily with the endogenous grid method (EGM).

We compute this EGM solution first and then try to learn the same policy with deep learning.

Throughout, utility takes the CRRA form \(u(c) = c^{1-\gamma} / (1-\gamma)\) and the discount factor is \(\beta\).

In addition to JAX, we’ll use the Optax library, which can be installed as follows:

!pip install optax


We’ll use the following imports:

import jax
import jax.numpy as jnp
from jax import grad, jit, random
import optax
import matplotlib.pyplot as plt
from functools import partial
from typing import NamedTuple

24.2. Cake Eating Case

With \(R\) as the gross interest rate, assets evolve according to

\[ a' = R (a - c) \]

To ensure stability we need \(\beta R^{1-\gamma} < 1\).

For this model, it is known that the optimal policy is \(c = \kappa a\), where

\[ \kappa := 1 - [\beta R^{1-\gamma}]^{1/\gamma} \]

We use this known exact solution to check our numerical methods.
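As a quick sanity check on these formulas, the sketch below (plain NumPy, separate from the JAX code that follows) simulates the path generated by \(c = \kappa a\) from \(a_0 = 1\) and compares the discounted sum of utilities with the closed-form value \(v(a_0) = \kappa^{-\gamma} u(a_0)\), the expression used for the theoretical maximum later in the lecture. It also checks the Euler equation \(u'(c_t) = \beta R \, u'(c_{t+1})\) along the path. The parameter values are assumed to match the CakeEatingModel defaults.

```python
import numpy as np

# Assumed parameter values: the CakeEatingModel defaults
γ, β, R = 1.5, 0.96, 1.01

def u(c):
    """CRRA utility."""
    return c**(1 - γ) / (1 - γ)

def u_prime(c):
    """CRRA marginal utility."""
    return c**(-γ)

κ = 1 - (β * R**(1 - γ))**(1 / γ)
v_closed = κ**(-γ) * u(1.0)          # closed-form value at a_0 = 1

# Simulate the path generated by c = κ a for many periods
T = 2000
a, v_sim, c_path = 1.0, 0.0, []
for t in range(T):
    c = κ * a
    c_path.append(c)
    v_sim += β**t * u(c)
    a = R * (a - c)

c_path = np.array(c_path)
# Euler equation residuals u'(c_t) - β R u'(c_{t+1}), relative to u'(c_t)
rel_resid = np.abs(u_prime(c_path[:-1]) - β * R * u_prime(c_path[1:])) / u_prime(c_path[:-1])

print(f"closed form: {v_closed:.6f}, simulated: {v_sim:.6f}")
print(f"max relative Euler residual: {rel_resid.max():.2e}")
```

The two values agree to many digits because each discounted utility term is \(1 - \kappa\) times the previous one, so the tail beyond 2000 periods is negligible.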

Initial assets \(a_0\) are fixed at 1.0, so the objective is

\[ \max_{\sigma \in \Sigma} v_\sigma(a_0) \quad \text{with} \quad a_0 := 1.0 \]

Here

  • \(\Sigma\) is the set of all feasible policies and

  • \(v_\sigma(a)\) is the lifetime value of following stationary policy \(\sigma\), given initial assets \(a\).

24.3. Set up

We use a class called CakeEatingModel to store model parameters.

class CakeEatingModel(NamedTuple):
    """
    Stores parameters for the model.

    """
    γ: float = 1.5
    β: float = 0.96
    R: float = 1.01

We use a class called LayerParams to store parameters representing a single layer of the neural network.

class LayerParams(NamedTuple):
    """
    Stores parameters for one layer of the neural network.

    """
    W: jnp.ndarray     # weights
    b: jnp.ndarray     # biases

The next class stores some fixed values that form part of the network training configuration.

class Config:
    """
    Configuration and parameters for training the neural network.

    """
    seed = 42                           # Seed for network initialization
    epochs = 400                        # Number of training epochs
    path_length = 320                   # Length of each consumption path
    layer_sizes = 1, 6, 6, 6, 6, 6, 1   # Network layer sizes
    learning_rate = 0.001               # Constant learning rate

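The following function initializes a single layer of the network using LeCun (normal) initialization: weights are drawn from a normal distribution with mean zero and standard deviation \(\sqrt{1/n_{\text{in}}}\), where \(n_{\text{in}}\) is the input dimension of the layer, and biases are set to zero.

(LeCun initialization is commonly recommended for use with SELU activations, which we use below.)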

def initialize_layer(in_dim, out_dim, key):
    """
    Initialize weights and biases for a single layer of the network.
    Use LeCun initialization.

    """
    s = jnp.sqrt(1.0 / in_dim)
    W = jax.random.normal(key, (in_dim, out_dim)) * s
    b = jnp.zeros((out_dim,))
    return LayerParams(W, b)
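To see why the scale \(\sqrt{1/n_{\text{in}}}\) is used, the NumPy sketch below (an illustration, not part of the lecture's pipeline) draws a weight matrix the same way as initialize_layer and checks that the pre-activations x @ W have variance close to one when the inputs have unit variance. The dimensions are arbitrary choices for the demonstration; the lecture's own hidden layers have width 6.

```python
import numpy as np

rng = np.random.default_rng(0)
in_dim, out_dim, n_samples = 256, 64, 10_000   # illustrative sizes

# LeCun normal initialization: W_ij ~ N(0, 1 / in_dim), zero biases
W = rng.standard_normal((in_dim, out_dim)) * np.sqrt(1.0 / in_dim)
b = np.zeros(out_dim)

# Inputs with zero mean and unit variance in every coordinate
x = rng.standard_normal((n_samples, in_dim))
z = x @ W + b

# Variance of each output coordinate, averaged across coordinates
avg_var = z.var(axis=0).mean()

# Without the 1/in_dim scaling the variance grows with the input dimension
z_unscaled = x @ (W / np.sqrt(1.0 / in_dim))

print(f"scaled:   average pre-activation variance {avg_var:.3f}")
print(f"unscaled: average pre-activation variance {z_unscaled.var(axis=0).mean():.1f}")
```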

The next function builds an entire network, as represented by its parameters, by initializing layers and stacking them into a list.

def initialize_network(key, layer_sizes):
    """
    Build a network by initializing all of the parameters.
    A network is a list of LayerParams instances, each 
    containing a weight-bias pair (W, b).

    """
    params = []
    # One iteration per layer: layer i maps layer_sizes[i] -> layer_sizes[i + 1]
    for i in range(len(layer_sizes) - 1):
        # Split the key so that each layer gets its own random draws
        key, subkey = jax.random.split(key)
        layer = initialize_layer(
            layer_sizes[i],      # in dimension for layer
            layer_sizes[i + 1],  # out dimension for layer
            subkey
        )
        # And add it to the parameter list
        params.append(layer)

    return params
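For the layer sizes in Config, initialize_network produces six layers. The short pure-Python sketch below mirrors the loop above to list the weight and bias shapes of each layer and count the trainable parameters, which shows how small this network is.

```python
layer_sizes = (1, 6, 6, 6, 6, 6, 1)   # same as Config.layer_sizes

shapes = []
for i in range(len(layer_sizes) - 1):
    n_in, n_out = layer_sizes[i], layer_sizes[i + 1]
    shapes.append(((n_in, n_out), (n_out,)))   # (W shape, b shape)

# Each layer contributes n_in * n_out weights plus n_out biases
n_params = sum(w[0] * w[1] + b[0] for w, b in shapes)

for k, (w_shape, b_shape) in enumerate(shapes):
    print(f"layer {k}: W {w_shape}, b {b_shape}")
print(f"total trainable parameters: {n_params}")
```

This reports 187 trainable parameters in total.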

Now we provide a function to do a forward pass through the network, given the parameters.

def forward(params, a):
    """
    Evaluate neural network policy: maps a given asset level a to
    consumption rate c/a by running a forward pass through the network.

    """
    σ = jax.nn.selu          # Activation function
    x = jnp.array((a,))      # Make state a 1D array
    # Forward pass through network, without the last step
    for W, b in params[:-1]:
        x = σ(x @ W + b)
    # Complete with sigmoid activation for consumption rate
    W, b = params[-1]
    # Direct output in [0, 0.99] range for stability
    x = jax.nn.sigmoid(x @ W + b) * 0.99 
    # Extract and return consumption rate
    consumption_rate = x[0]
    return consumption_rate
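Because the last layer passes through a sigmoid scaled by 0.99, the consumption rate always lies strictly between 0 and 0.99, so next-period assets \(R(a - c) = R(1 - c/a)\,a\) stay strictly positive in the cake eating model. The NumPy sketch below reimplements the same forward pass (SELU hidden layers, scaled sigmoid output) with randomly drawn parameters and checks this bound over a range of asset levels. The helpers selu, sigmoid and forward_np are hypothetical names introduced only for this illustration.

```python
import numpy as np

# Standard SELU constants (the values used by jax.nn.selu)
SELU_ALPHA = 1.6732632423543772
SELU_SCALE = 1.0507009873554805

def selu(x):
    return SELU_SCALE * np.where(x > 0, x, SELU_ALPHA * np.expm1(x))

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def forward_np(params, a):
    """Hypothetical NumPy mirror of forward: asset level -> consumption rate."""
    x = np.array([a])
    for W, b in params[:-1]:
        x = selu(x @ W + b)
    W, b = params[-1]
    return (sigmoid(x @ W + b) * 0.99)[0]

# Random LeCun-style parameters for layer sizes 1, 6, 6, 6, 6, 6, 1
rng = np.random.default_rng(1)
layer_sizes = (1, 6, 6, 6, 6, 6, 1)
params = [
    (rng.standard_normal((m, n)) * np.sqrt(1.0 / m), np.zeros(n))
    for m, n in zip(layer_sizes[:-1], layer_sizes[1:])
]

R = 1.01
a_vals = np.linspace(0.01, 1.0, 100)
rates = np.array([forward_np(params, a) for a in a_vals])
next_a = R * (1 - rates) * a_vals      # a' = R (a - c) with c = rate * a
print(f"rate range: [{rates.min():.4f}, {rates.max():.4f}], min next a: {next_a.min():.4f}")
```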

We use CRRA utility.

def u(c, γ):
    """ Utility function. """
    c = jnp.maximum(c, 1e-10)
    return c**(1 - γ) / (1 - γ)

The next function approximates lifetime value associated with a given policy, as represented by the parameters of a neural network.

@partial(jax.jit, static_argnames=('path_length',))
def compute_lifetime_value(params, model, path_length):
    """
    Compute the lifetime value of a path generated from
    the policy embedded in params and the initial condition a_0 = 1.

    """
    γ, β, R = model.γ, model.β, model.R
    initial_a = 1.0

    def update(t, state):
        # Unpack and compute consumption given current assets
        a, value, discount = state
        consumption_rate = forward(params, a)
        c = consumption_rate * a
        # Update loop state and return it
        a = R * (a - c)
        value = value + discount * u(c, γ)
        discount = discount * β
        new_state = a, value, discount
        return new_state

    initial_value, initial_discount = 0.0, 1.0
    initial_state = initial_a, initial_value, initial_discount
    final_a, final_value, discount = jax.lax.fori_loop(
        0, path_length, update, initial_state
    )
    return final_value

Here’s the loss function we will minimize.

def loss_function(params, model, path_length):
    """
    Loss is the negation of the lifetime value of the policy 
    identified by `params`.

    """
    return -compute_lifetime_value(params, model, path_length)

24.4. Train and solve

First we create an instance of the model and unpack the names we need:

model = CakeEatingModel()
γ, β, R = model.γ, model.β, model.R
seed, epochs = Config.seed, Config.epochs
path_length = Config.path_length
layer_sizes = Config.layer_sizes

We test stability.

assert β * R**(1 - γ) < 1, "Parameters fail stability test."

We compute the optimal consumption rate and lifetime value from the analytical expressions.

κ = 1 - (β * R**(1 - γ))**(1/γ)
print(f"Optimal consumption rate = {κ}.\n")
v_max = κ**(-γ) * u(1.0, γ)
print(f"Theoretical maximum lifetime value = {v_max}.\n")
Optimal consumption rate = 0.03007006297501369.
Theoretical maximum lifetime value = -383.5557556152344.

Let’s now set up the Optax optimizer: Adam with a constant learning rate, preceded by clipping of the gradient’s global norm at 1.0.

optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),  # Gradient clipping for stability
    optax.adam(learning_rate=Config.learning_rate)
)
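optax.clip_by_global_norm(1.0) rescales the whole gradient, across all layers at once, whenever its global Euclidean norm exceeds 1.0, and leaves it unchanged otherwise, so the direction is preserved. The NumPy function below is a minimal sketch of that rule written for illustration; clip_by_global_norm_np is a hypothetical name, not Optax's code.

```python
import numpy as np

def clip_by_global_norm_np(grads, max_norm):
    """Scale a list of gradient arrays so their joint L2 norm is at most max_norm."""
    global_norm = np.sqrt(sum(np.sum(g**2) for g in grads))
    if global_norm <= max_norm:
        return grads                    # small gradients pass through unchanged
    scale = max_norm / global_norm      # one common factor for every array
    return [g * scale for g in grads]

# Two "layers" of gradients with joint norm 5 (a 3-4-5 triangle)
grads = [np.array([3.0]), np.array([[4.0]])]
clipped = clip_by_global_norm_np(grads, max_norm=1.0)
print([c.tolist() for c in clipped])   # same direction, norm reduced to 1
```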

We initialize the parameters in the neural network and the state of the optimizer.

key = random.PRNGKey(seed)
params = initialize_network(key, layer_sizes)
opt_state = optimizer.init(params)

Now let’s train the network.

value_history = []
best_value = -jnp.inf
best_params = params
for i in range(epochs):

    # Compute value and gradients at existing parameterization
    loss, grads = jax.value_and_grad(loss_function)(params, model, path_length)
    lifetime_value = - loss
    value_history.append(lifetime_value)

    # Track best parameters
    if lifetime_value > best_value:
        best_value = lifetime_value
        best_params = params

    # Update parameters using optimizer
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)

    if i % 100 == 0:
        print(f"Iteration {i}: Value = {lifetime_value:.4f}")

# Use best parameters instead of final
params = best_params
print(f"\nBest value: {best_value:.4f}")
print(f"Final value: {value_history[-1]:.4f}")
Iteration 0: Value = -1495089.7500
Iteration 100: Value = -1252.4659
Iteration 200: Value = -383.5421
Iteration 300: Value = -383.5356

Best value: -383.5334
Final value: -383.5335
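The best value reached in training, about -383.533, is slightly higher than the theoretical maximum of about -383.556 printed earlier. This is expected: compute_lifetime_value sums discounted utility over path_length = 320 periods only, whereas v_max is the value of the infinite sum. Under the optimal policy each discounted utility term equals \(1 - \kappa\) times the previous one, so truncation at \(T\) periods removes a fraction \((1-\kappa)^T\) of the total, and because utility is negative here, the truncated value is higher. The plain-Python sketch below computes the truncated value of the exact optimal policy both by direct summation and in closed form; it comes out at about -383.534, in line with the values reported during training.

```python
γ, β, R = 1.5, 0.96, 1.01      # CakeEatingModel defaults
T = 320                        # Config.path_length

u = lambda c: c**(1 - γ) / (1 - γ)
κ = 1 - (β * R**(1 - γ))**(1 / γ)
v_max = κ**(-γ) * u(1.0)       # infinite-horizon value at a_0 = 1

# Direct summation of the first T discounted utility terms under c = κ a
a, v_T = 1.0, 0.0
for t in range(T):
    c = κ * a
    v_T += β**t * u(c)
    a = R * (a - c)

# Closed form: each discounted term is (1 - κ) times the previous one
v_T_closed = v_max * (1 - (1 - κ)**T)

print(f"infinite horizon: {v_max:.4f}")
print(f"truncated at T={T}: {v_T:.4f} (closed form {v_T_closed:.4f})")
```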

First we plot the evolution of lifetime value over the epochs.

# Plot learning progress
fig, ax = plt.subplots()
ax.plot(value_history, 'b-', linewidth=2)
ax.set_xlabel('iteration')
ax.set_ylabel('policy value')
ax.set_title('learning progress')
plt.show()

Next we compare the learned and optimal policies.

a_grid = jnp.linspace(0.01, 1.0, 1000)
policy_vmap = jax.vmap(lambda a: forward(params, a))
consumption_rate = policy_vmap(a_grid)
# Compute actual consumption: c = (c/a) * a
c_learned = consumption_rate * a_grid
c_optimal = κ * a_grid

fig, ax = plt.subplots()
ax.plot(a_grid, c_learned, linestyle='--', lw=4, label='learned policy')
ax.plot(a_grid, c_optimal, lw=2, label='optimal')
ax.set_xlabel('assets')
ax.set_ylabel('consumption')
ax.set_title('Consumption policy')
ax.legend()
plt.show()

Let’s have a look at paths for consumption and assets under the learned and optimal policies.

The figures below show that the learned policies are close to optimal.

def simulate_consumption_path(params, T=120):
    """
    Compute consumption path using neural network policy identified by params.

    """
    a_sim = [1.0]   # 1.0 is the initial assets
    c_sim = []
    a_opt = [1.0]
    c_opt = []

    a = 1.0
    for t in range(T):

        # Update policy path - forward returns consumption rate
        c = forward(params, a) * a
        c_sim.append(float(c))
        a = R * (a - c)
        a_sim.append(float(a))

        if a <= 1e-10:
            break

    a = 1.0
    for t in range(T):

        # Update optimal path
        c = κ * a
        c_opt.append(c)
        a = R * (a - c)
        a_opt.append(a)

        if a <= 1e-10:
            break

    return a_sim, c_sim, a_opt, c_opt

# Simulate and plot path
a_sim, c_sim, a_opt, c_opt = simulate_consumption_path(params)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

ax1.plot(a_sim, lw=4, linestyle='--', label='learned policy')
ax1.plot(a_opt, lw=2, label='optimal')
ax1.set_xlabel('Time')
ax1.set_ylabel('Assets')
ax1.set_title('Assets over time')
ax1.legend()

ax2.plot(c_sim, lw=4, linestyle='--', label='learned policy')
ax2.plot(c_opt, lw=2, label='optimal')
ax2.set_xlabel('Time')
ax2.set_ylabel('Consumption')
ax2.set_title('Consumption over time')
ax2.legend()

plt.tight_layout()
plt.show()

24.5. Extension: stochastic labor income with IID shocks

Now let’s solve a model with IID stochastic labor income using deep learning.

24.5.1. Set-Up

A household chooses a consumption plan \(\{c_t\}_{t \geq 0}\) to maximize

\[ \mathbb{E} \, \sum_{t=0}^{\infty} \beta^t u(c_t) \]

subject to

\[ a_{t+1} = R (a_t - c_t) + Y_{t+1}, \quad c_t \geq 0, \quad a_t \geq 0, \quad t = 0, 1, \ldots \]

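Here \(Y_t\) is labor income, which is IID and lognormally distributed, with \(m\) and \(v\) the mean and variance of log income \(Z_t\) (in the code below, \(m\) is z_mean and \(\sqrt{v}\) is z_std):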

\[ Z_t \sim N(m, v), \quad Y_t = \exp(Z_t) \]
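Since \(Y_t\) is lognormal, its mean is \(\mathbb{E} Y_t = \exp(m + v/2)\) rather than \(\exp(m)\). With the values used below (z_mean = 0.1 and z_std = 0.1, so \(v = 0.01\)), mean income is about 1.111. The NumPy sketch below compares this formula with a Monte Carlo estimate of the kind the lecture uses for expectations; the seed and sample size are arbitrary choices.

```python
import numpy as np

m, s = 0.1, 0.1                      # z_mean and z_std in create_ifp
v = s**2                             # variance of Z

rng = np.random.default_rng(42)
z = m + s * rng.standard_normal(1_000_000)
y = np.exp(z)                        # lognormal income draws

print(f"exact mean of Y: {np.exp(m + v / 2):.5f}")
print(f"Monte Carlo:     {y.mean():.5f}")
```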

Since the shocks are IID, the optimal policy depends only on current assets \(a\), not on the shock history.

We assume:

  1. \(\beta R < 1\)

  2. \(u\) is CRRA with parameter \(\gamma\)

24.5.2. JAX Implementation

We start with a class called IFP that stores the model primitives.

class IFP(NamedTuple):
    R: float                 # Gross interest rate R = 1 + r
    β: float                 # Discount factor
    γ: float                 # Preference parameter
    z_mean: float            # Mean of log income shock
    z_std: float             # Std dev of log income shock
    z_samples: jnp.ndarray   # Monte Carlo draws of the log income shock


def create_ifp(
        r=0.01,
        β=0.96,
        γ=1.5,
        z_mean=0.1,
        z_std=0.1,
        n_shocks=200,
        seed=42
    ):
    R = 1 + r
    assert R * β < 1, "Stability condition violated."
    key = random.PRNGKey(seed)
    z_samples = z_mean + z_std * jax.random.normal(key, (n_shocks,))
    return IFP(R, β, γ, z_mean, z_std, z_samples)

24.5.3. Solving the IID model using the EGM

Since the shocks are IID, the optimal policy depends only on current assets \(a\).

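For the IID case, whenever savings \(s = a - c\) are strictly positive, optimal consumption satisfies the Euler equation

\[ u'(c) = \beta R \, \mathbb{E}[u'(\sigma(R s + Y))] \]

where \(\sigma\) is the optimal consumption policy, \(Z \sim N(m, v)\) and \(Y = \exp(Z)\).

The EGM starts from a grid of savings levels \(s\) and a guess for \(\sigma\), evaluates the expectation on the right-hand side at each \(s\), inverts \(u'\) to obtain \(c\), and recovers the endogenous asset level \(a = c + s\).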

We approximate this expectation using Monte Carlo.
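To get a feel for the accuracy of this approximation, the sketch below evaluates \(\mathbb{E}[u'(\sigma(R s + Y))]\) at one savings level for the simple policy \(\sigma(a) = a\) (consume everything, the initial guess used by the solver below), using Monte Carlo with 200 draws (the default n_shocks in create_ifp) and with 10,000 draws, and compares both with Gauss-Hermite quadrature. The quadrature benchmark and the savings level \(s = 1\) are our choices for illustration; the lecture itself uses Monte Carlo only.

```python
import numpy as np
from numpy.polynomial.hermite import hermgauss

R, γ = 1.01, 1.5
m, s_z = 0.1, 0.1          # mean and std dev of Z = log Y
s = 1.0                    # savings level at which we evaluate the expectation

def integrand(z):
    """u'(σ(R s + Y)) with σ(a) = a and Y = exp(z)."""
    return (R * s + np.exp(z))**(-γ)

# Gauss-Hermite: E f(Z) ≈ (1/√π) Σ w_i f(m + √2 s_z x_i) for Z ~ N(m, s_z²)
x, w = hermgauss(20)
exact = np.sum(w * integrand(m + np.sqrt(2) * s_z * x)) / np.sqrt(np.pi)

rng = np.random.default_rng(42)
mc_200 = integrand(m + s_z * rng.standard_normal(200)).mean()
mc_10k = integrand(m + s_z * rng.standard_normal(10_000)).mean()

for label, val in [("200 draws", mc_200), ("10,000 draws", mc_10k)]:
    print(f"{label}: {val:.6f}, relative error {abs(val / exact - 1):.2e}")
```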

Here is the EGM operator \(K\) for the IID case:

def K(c_in, a_in, ifp, s_grid, n_shocks=50):
    """
    The Euler equation operator for the IFP model with IID shocks using EGM.

    Args:
        c_in: Current consumption policy on endogenous grid
        a_in: Current endogenous asset grid
        ifp: IFP model instance
        s_grid: Exogenous savings grid
        n_shocks: Not used; expectations are averaged over ifp.z_samples

    Returns:
        c_out: Updated consumption policy
        a_out: Updated endogenous asset grid
    """
    R, β, γ, z_mean, z_std, z_samples = ifp
    y_samples = jnp.exp(z_samples)
    u_prime = lambda c: c**(-γ)
    u_prime_inv = lambda c: c**(-1/γ)

    def compute_c_i(s_i):
        """Compute consumption for savings level s_i."""

        # For each income realization, compute next period assets and consumption
        def compute_mu_k(y_k):
            next_a = R * s_i + y_k
            # Interpolate to get consumption
            next_c = jnp.interp(next_a, a_in, c_in)
            return u_prime(next_c)

        # Compute expectation over income shocks (Monte Carlo average)
        mu_values = jax.vmap(compute_mu_k)(y_samples)
        expectation = jnp.mean(mu_values)

        # Invert marginal utility in the Euler equation to get consumption
        c = u_prime_inv(β * R * expectation)

        # At s_i = 0 set c = 0, which anchors the endogenous grid at the
        # origin (a, c) = (0, 0) for interpolation at low asset levels
        return jnp.where(s_i == 0, 0.0, c)

    # Compute consumption for all savings levels
    c_out = jax.vmap(compute_c_i)(s_grid)
    # Compute endogenous asset grid
    a_out = c_out + s_grid

    return c_out, a_out
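The sketch below is a NumPy transcription of a single application of K, starting from the solver's initial guess \(c = a\) on a small savings grid. It is meant only to make the mechanics visible: each savings level \(s_i\) yields a consumption level from the inverted Euler equation and an endogenous asset level \(a_i = c_i + s_i\). The name egm_step_np is hypothetical, and the grid size, seed and number of draws are arbitrary.

```python
import numpy as np

R, β, γ = 1.01, 0.96, 1.5
rng = np.random.default_rng(42)
y_samples = np.exp(0.1 + 0.1 * rng.standard_normal(200))   # income draws

def egm_step_np(c_in, a_in, s_grid):
    """One EGM update: map a policy on (a_in, c_in) to a new policy."""
    c_out = np.empty_like(s_grid)
    for i, s in enumerate(s_grid):
        next_a = R * s + y_samples                     # next-period assets
        next_c = np.interp(next_a, a_in, c_in)         # current policy guess
        expectation = np.mean(next_c**(-γ))            # E u'(σ(R s + Y))
        c_out[i] = (β * R * expectation)**(-1 / γ)     # invert u'
    c_out[s_grid == 0] = 0.0                           # anchor at the origin
    a_out = c_out + s_grid                             # endogenous grid
    return c_out, a_out

s_grid = np.linspace(0, 10, 11)
c_out, a_out = egm_step_np(s_grid.copy(), s_grid.copy(), s_grid)
print(np.round(a_out, 3))
print(np.round(c_out, 3))
```

Because savings are nonnegative, consumption never exceeds assets on the endogenous grid.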

Here’s the solver using time iteration:

def solve_model(ifp, s_grid, n_shocks=50, tol=1e-5, max_iter=1000):
    """
    Solve the IID model using time iteration with EGM.

    Args:
        ifp: IFP model instance
        s_grid: Exogenous savings grid
        n_shocks: Passed on to K, which does not use it (K averages over ifp.z_samples)
        tol: Convergence tolerance
        max_iter: Maximum iterations

    Returns:
        c_out: Optimal consumption policy on endogenous grid
        a_out: Endogenous asset grid
    """
    # Initialize with consumption = assets (consume everything)
    a_init = s_grid.copy()
    c_init = s_grid.copy()
    c_in, a_in = c_init, a_init

    for i in range(max_iter):
        c_out, a_out = K(c_in, a_in, ifp, s_grid, n_shocks)

        error = jnp.max(jnp.abs(c_out - c_in))

        if error < tol:
            print(f"Converged in {i} iterations, error = {error:.2e}")
            break

        c_in, a_in = c_out, a_out

        if i % 100 == 0:
            print(f"Iteration {i}, error = {error:.2e}")

    return c_out, a_out

Let’s solve the model and plot the optimal policy:

# Create model instance
ifp = create_ifp(z_mean=0.1, z_std=0.1)

# Create savings grid
s_grid = jnp.linspace(0, 10, 200)

# Solve using EGM
print("Solving IFP model using EGM...\n")
c_egm, a_egm = solve_model(ifp, s_grid, n_shocks=100)
Solving IFP model using EGM...
Iteration 0, error = 1.40e+00
Converged in 38 iterations, error = 6.79e-06

Plot the optimal consumption policy:

fig, ax = plt.subplots()
ax.plot(a_egm, c_egm, 'b-', lw=2, label='EGM solution')
ax.set_xlabel('assets')
ax.set_ylabel('consumption')
ax.set_title('Optimal consumption policy (IFP model, EGM)')
ax.legend()
plt.show()

24.5.4. Solving the IID model with DL

As in the EGM case, the policy depends only on current assets \(a\), so we reuse the network architecture and the forward function from the deterministic case.

Next we implement the lifetime value computation.

The key step is to simulate asset paths driven by IID lognormal income, starting from initial assets \(a_0 = 10\), and to average discounted utility across many such paths.

@partial(jax.jit, static_argnames=('path_length', 'num_paths'))
def compute_lifetime_value_ifp(params, ifp, path_length, num_paths, key):
    """
    Compute expected lifetime value by averaging over multiple 
    simulated paths.

    Args:
        params: Neural network parameters
        ifp: IFP model instance
        path_length: Length of each simulated path
        num_paths: Number of paths to simulate for averaging
        key: JAX random key for generating income shocks

    Returns:
        Average lifetime value across all simulated paths
    """
    R, β, γ, z_mean, z_std, z_samples = ifp

    def simulate_path(subkey):
        """Simulate a single path and return its lifetime value."""
        z_shocks = z_mean + z_std * jax.random.normal(subkey, (path_length,))
        Y = jnp.exp(z_shocks)

        def update(t, loop_state):
            a, value, discount = loop_state
            consumption_rate = forward(params, a)
            c = consumption_rate * a
            next_value = value + discount * u(c, γ)
            next_a = R * (a - c) + Y[t]
            next_discount = discount * β
            return next_a, next_value, next_discount

        initial_a = 10.0
        initial_value = 0.0
        initial_discount = 1.0
        initial_state = (initial_a, initial_value, initial_discount)
        final_a, final_value, final_discount = jax.lax.fori_loop(
            0, path_length, update, initial_state
        )

        return final_value

    # Generate keys for all paths
    path_keys = jax.random.split(key, num_paths)

    # Simulate all paths and average
    values = jax.vmap(simulate_path)(path_keys)
    return jnp.mean(values)

The loss function is the negation of the expected lifetime value.

def loss_function_ifp(params, ifp, path_length, num_paths, key):
    return -compute_lifetime_value_ifp(
        params, ifp, path_length, num_paths, key
    )

Now let’s set up and train the network.

We use the same ifp instance that was created for the EGM solution above.

stochastic_config = {
    'seed': 1234,
    'epochs': 400,
    'path_length': 320,
    'num_paths': 500,  # Number of paths to average over
    'learning_rate': 0.001
}

We initialize parameters.

key = random.PRNGKey(seed)
ifp_params = initialize_network(key, layer_sizes)

Let’s set up the optimizer.

ifp_optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),  # Gradient clipping for stability
    optax.adam(learning_rate=stochastic_config['learning_rate'])
)
ifp_opt_state = ifp_optimizer.init(ifp_params)

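Train the network using policy gradient ascent.

We reuse the same random key in every epoch, so each epoch evaluates the policy on the same simulated income paths. This makes the objective, and hence its gradient, a deterministic function of the network parameters, at the cost of optimizing against one fixed sample of shocks.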

ifp_value_history = []
best_ifp_value = -jnp.inf
best_ifp_params = ifp_params
fixed_key = random.PRNGKey(stochastic_config['seed'])

print("Training IFP model with deep learning...\n")

for i in range(stochastic_config['epochs']):

    # Compute loss and gradients
    loss, grads = jax.value_and_grad(loss_function_ifp)(
        ifp_params, ifp,
        stochastic_config['path_length'],
        stochastic_config['num_paths'],
        fixed_key
    )
    lifetime_value = -loss
    ifp_value_history.append(lifetime_value)

    # Track best parameters
    if lifetime_value > best_ifp_value:
        best_ifp_value = lifetime_value
        best_ifp_params = ifp_params

    # Update parameters
    updates, ifp_opt_state = ifp_optimizer.update(grads, ifp_opt_state)
    ifp_params = optax.apply_updates(ifp_params, updates)

    if i % 50 == 0:
        print(f"Iteration {i}: Value = {lifetime_value:.4f}")

# Use best parameters
ifp_params = best_ifp_params
print(f"\nBest value: {best_ifp_value:.4f}")
print(f"Final value: {ifp_value_history[-1]:.4f}")
Training IFP model with deep learning...
Iteration 0: Value = -46.4904
Iteration 50: Value = -44.8104
Iteration 100: Value = -43.0656
Iteration 150: Value = -43.0144
Iteration 200: Value = -42.9911
Iteration 250: Value = -42.9793
Iteration 300: Value = -42.9777
Iteration 350: Value = -42.9733

Best value: -42.9710
Final value: -42.9713
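The training loop above reuses fixed_key in every epoch. The NumPy sketch below illustrates, in a simplified setting, what this buys. It evaluates the expected lifetime value of fixed-rate policies \(c = \theta a\), a hypothetical stand-in for the network, over simulated income paths with the same parameters, path length, number of paths and initial assets as the IFP code. With a common seed the estimate is a deterministic function of the policy, and the estimated difference in value between two nearby policies varies less across replications than when each policy sees independent shocks (the standard common random numbers idea).

```python
import numpy as np

R, β, γ = 1.01, 0.96, 1.5
m, s_z = 0.1, 0.1
path_length, num_paths = 320, 500
a0 = 10.0                            # initial assets, as in the IFP simulation

def u(c):
    return np.maximum(c, 1e-10)**(1 - γ) / (1 - γ)

def value_estimate(θ, seed):
    """Average discounted utility of c = θ a over simulated income paths."""
    rng = np.random.default_rng(seed)
    Y = np.exp(m + s_z * rng.standard_normal((num_paths, path_length)))
    a = np.full(num_paths, a0)
    value = np.zeros(num_paths)
    for t in range(path_length):
        c = θ * a
        value += β**t * u(c)
        a = R * (a - c) + Y[:, t]
    return value.mean()

# Same seed: two evaluations of the same policy give identical numbers
v1 = value_estimate(0.2, seed=0)
v2 = value_estimate(0.2, seed=0)
print(f"same seed, two evaluations identical: {v1 == v2}")

# Spread of the estimated difference V(0.20) - V(0.21) across 20 replications
seeds = range(20)
diff_crn = [value_estimate(0.20, k) - value_estimate(0.21, k) for k in seeds]
diff_ind = [value_estimate(0.20, k) - value_estimate(0.21, 100 + k) for k in seeds]
print(f"std of difference, common seeds:      {np.std(diff_crn):.2e}")
print(f"std of difference, independent seeds: {np.std(diff_ind):.2e}")
```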

Plot the learning progress.

fig, ax = plt.subplots()
ax.plot(ifp_value_history, 'b-', linewidth=2)
ax.set_xlabel('iteration')
ax.set_ylabel('policy value')
ax.set_title('Learning progress')
plt.show()

Compare EGM and DL solutions.

# Evaluate DL policy on asset grid
a_grid_dl = jnp.linspace(0.01, 10.0, 200)
policy_vmap = jax.vmap(lambda a: forward(ifp_params, a))
consumption_rate_dl = policy_vmap(a_grid_dl)
c_dl = consumption_rate_dl * a_grid_dl

fig, ax = plt.subplots()
ax.plot(a_egm, c_egm, lw=2, label='EGM solution')
ax.plot(a_grid_dl, c_dl, lw=2, label='DL solution')
ax.set_xlabel('assets', fontsize=12)
ax.set_ylabel('consumption', fontsize=12)
ax.set_xlim(0, min(a_grid_dl[-1], a_egm[-1]))
ax.legend()
plt.show()
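The plot gives a visual comparison. To summarize the gap by a single number, one option is to interpolate both policies onto a common grid over the range where both are defined and report the largest absolute difference. The helper max_policy_gap below is a hypothetical NumPy sketch of this idea, not part of the lecture; with the arrays above it would be called as max_policy_gap(a_egm, c_egm, a_grid_dl, c_dl). It assumes each asset grid is increasing.

```python
import numpy as np

def max_policy_gap(a1, c1, a2, c2, n=500):
    """
    Largest absolute difference between two consumption policies, each given
    as values c on an increasing asset grid a, over the overlap of the grids.
    """
    a1, c1, a2, c2 = map(np.asarray, (a1, c1, a2, c2))
    lo, hi = max(a1[0], a2[0]), min(a1[-1], a2[-1])
    grid = np.linspace(lo, hi, n)
    return np.max(np.abs(np.interp(grid, a1, c1) - np.interp(grid, a2, c2)))

# Example with two synthetic linear policies that differ by 0.01 everywhere
a = np.linspace(0.0, 10.0, 50)
gap = max_policy_gap(a, 0.3 * a, a + 0.5, 0.3 * (a + 0.5) + 0.01)
print(gap)
```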