3. Adventures with Autodiff#

GPU

This lecture was built using a machine with JAX installed and access to a GPU.

To run this lecture on Google Colab, click on the “play” icon top right, select Colab, and set the runtime environment to include a GPU.

To run this lecture on your own machine, you need to install Google JAX.

3.1. Overview#

This lecture gives a brief introduction to automatic differentiation using Google JAX.

Automatic differentiation is one of the key elements of modern machine learning and artificial intelligence.

As such it has attracted a great deal of investment and there are several powerful implementations available.

Among the best of these are the automatic differentiation routines contained in JAX.

While other software packages also offer this feature, the JAX version is particularly powerful because it integrates so well with other core components of JAX (e.g., JIT compilation and parallelization).

As we will see in later lectures, automatic differentiation can be used not only for AI but also for many problems faced in mathematical modeling, such as multi-dimensional nonlinear optimization and root-finding problems.

We need the following imports

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

Checking for a GPU:

!nvidia-smi
Mon Oct 21 10:29:29 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.54.03              Driver Version: 535.54.03    CUDA Version: 12.6     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  Tesla T4                       On  | 00000001:00:00.0 Off |                  Off |
| N/A   43C    P0              35W /  70W |      2MiB / 16384MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|  No running processes found                                                           |
+---------------------------------------------------------------------------------------+

3.2. What is automatic differentiation?#

Autodiff is a technique for calculating derivatives on a computer.

3.2.1. Autodiff is not finite differences#

The derivative of \(f(x) = \exp(2x)\) is

\[ f'(x) = 2 \exp(2x) \]

A computer that doesn’t know how to take derivatives might approximate this with the finite difference ratio

\[ (Df)(x) := \frac{f(x+h) - f(x)}{h} \]

where \(h\) is a small positive number.

def f(x):
    "Original function."
    return np.exp(2 * x)

def f_prime(x):
    "True derivative."
    return 2 * np.exp(2 * x)

def Df(x, h=0.1):
    "Approximate derivative (finite difference)."
    return (f(x + h) - f(x))/h

x_grid = np.linspace(-2, 1, 200)
fig, ax = plt.subplots()
ax.plot(x_grid, f_prime(x_grid), label="$f'$")
ax.plot(x_grid, Df(x_grid), label="$Df$")
ax.legend()
plt.show()
[Figure: the true derivative f' and the finite-difference approximation Df]

This kind of numerical derivative is often inaccurate and unstable.

One reason is that

\[ \frac{f(x+h) - f(x)}{h} \approx \frac{0}{0} \]

Small numbers in the numerator and denominator cause rounding errors.
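We can check this numerically using the f, f_prime and Df defined above: the approximation error at a fixed point first shrinks as \(h\) falls, then grows again once rounding dominates (a quick sketch; the exact numbers will vary by machine).

# Finite-difference error at x = 1.0 for a range of step sizes
for h in (1e-2, 1e-4, 1e-6, 1e-8, 1e-10, 1e-12):
    error = abs(Df(1.0, h) - f_prime(1.0))
    print(f"h = {h:.0e}, error = {error:.2e}")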

The situation is substantially worse in high dimensions and with higher-order derivatives.

3.2.2. Autodiff is not symbolic calculus#

Symbolic calculus tries to use rules for differentiation to produce a single closed-form expression representing a derivative.

from sympy import symbols, diff

m, a, b, x = symbols('m a b x')
f_x = (a*x + b)**m
f_x.diff((x, 6))  # 6-th order derivative
\[\displaystyle \frac{a^{6} m \left(a x + b\right)^{m} \left(m^{5} - 15 m^{4} + 85 m^{3} - 225 m^{2} + 274 m - 120\right)}{\left(a x + b\right)^{6}}\]

Symbolic calculus is not well suited to high performance computing.

One disadvantage is that symbolic calculus cannot differentiate through control flow.

Also, using symbolic calculus might involve redundant calculations.

For example, consider

\[ (f g h)' = (f' g + g' f) h + (f g) h' \]

If we evaluate at \(x\), then we evaluate \(f(x)\) and \(g(x)\) twice each.

Also, computing \(f'(x)\) and \(f(x)\) might involve similar terms (e.g., \(f(x) = \exp(2x) \implies f'(x) = 2 f(x)\)), but this is not exploited in symbolic algebra.

3.2.3. Autodiff#

Autodiff produces functions that evaluate derivatives at numerical values passed in by the calling code, rather than producing a single symbolic expression representing the entire derivative.

Derivatives are constructed by breaking calculations into component parts via the chain rule.

The chain rule is applied until the point where the terms reduce to primitive functions that the program knows how to differentiate exactly (addition, subtraction, exponentiation, sine and cosine, etc.)
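For instance, given \(f(x) = \exp(2x)\), JAX composes the derivatives of the primitives (multiplication and \(\exp\)) via the chain rule. Here is a minimal check; the function g below is introduced just for this illustration.

g = lambda x: jnp.exp(2 * x)
g_prime = jax.grad(g)
g_prime(1.0), 2 * jnp.exp(2.0)   # both should be ≈ 14.78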

3.3. Some experiments#

Let’s start with some real-valued functions on \(\mathbb R\).

3.3.1. A differentiable function#

Let’s test JAX’s autodiff with a relatively simple function.

def f(x):
    return jnp.sin(x) - 2 * jnp.cos(3 * x) * jnp.exp(- x**2)

We use grad to compute the gradient of a real-valued function:

f_prime = jax.grad(f)

Let’s plot the result:

x_grid = jnp.linspace(-5, 5, 100)
fig, ax = plt.subplots()
ax.plot(x_grid, [f(x) for x in x_grid], label="$f$")
ax.plot(x_grid, [f_prime(x) for x in x_grid], label="$f'$")
ax.legend()
plt.show()
[Figure: f and its JAX-computed derivative f']

3.3.2. Absolute value function#

What happens if the function is not differentiable?

def f(x):
    return jnp.abs(x)
f_prime = jax.grad(f)
fig, ax = plt.subplots()
ax.plot(x_grid, [f(x) for x in x_grid], label="$f$")
ax.plot(x_grid, [f_prime(x) for x in x_grid], label="$f'$")
ax.legend()
plt.show()
[Figure: the absolute value function f and its derivative f']

At the nondifferentiable point \(0\), jax.grad returns the right-hand derivative:

f_prime(0.0)
Array(1., dtype=float32, weak_type=True)

3.3.3. Differentiating through control flow#

Let’s try differentiating through some loops and conditions.

def f(x):
    def f1(x):
        for i in range(2):
            x *= 0.2 * x
        return x
    def f2(x):
        x = sum((x**i + i) for i in range(3))
        return x
    y = f1(x) if x < 0 else f2(x)
    return y
f_prime = jax.grad(f)
x_grid = jnp.linspace(-5, 5, 100)
fig, ax = plt.subplots()
ax.plot(x_grid, [f(x) for x in x_grid], label="$f$")
ax.plot(x_grid, [f_prime(x) for x in x_grid], label="$f'$")
ax.legend()
plt.show()
[Figure: the function f defined via control flow and its derivative f']

3.3.4. Differentiating through a linear interpolation#

We can differentiate through linear interpolation, even though the function is not smooth:

n = 20
xp = jnp.linspace(-5, 5, n)
yp = jnp.cos(2 * xp)

fig, ax = plt.subplots()
ax.plot(x_grid, jnp.interp(x_grid, xp, yp))
plt.show()
[Figure: linear interpolation of cos(2x) on the grid]
# jax.grad(jnp.interp) differentiates with respect to the first argument
# and requires a scalar evaluation point, so we vectorize the derivative
# over the grid with jax.vmap.
f_prime = jax.grad(jnp.interp)
f_prime_vec = jax.vmap(f_prime, in_axes=(0, None, None))
fig, ax = plt.subplots()
ax.plot(x_grid, f_prime_vec(x_grid, xp, yp))
plt.show()
[Figure: derivative of the linear interpolant]

3.4. Gradient Descent#

Let’s try implementing gradient descent.

As a simple application, we’ll use gradient descent to solve for the OLS parameter estimates in simple linear regression.

3.4.1. A function for gradient descent#

Here’s an implementation of gradient descent.

def grad_descent(f,       # Function to be minimized
                 args,    # Extra arguments to the function
                 x0,      # Initial condition
                 λ=0.1,   # Initial learning rate
                 tol=1e-5, 
                 max_iter=1_000):
    """
    Minimize the function f via gradient descent, starting from guess x0.

    The learning rate is computed according to the Barzilai-Borwein method.
    
    """
    
    f_grad = jax.grad(f)
    x = jnp.array(x0)
    df = f_grad(x, args)
    ϵ = tol + 1
    i = 0
    while ϵ > tol and i < max_iter:
        new_x = x - λ * df
        new_df = f_grad(new_x, args)
        Δx = new_x - x
        Δdf = new_df - df
        λ = jnp.abs(Δx @ Δdf) / (Δdf @ Δdf)   # Barzilai-Borwein step size
        ϵ = jnp.max(jnp.abs(Δx))              # Size of the most recent step
        x, df = new_x, new_df
        i += 1
        
    return x
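
For reference, the step size computed inside the loop is the Barzilai-Borwein choice

\[ \lambda_k = \frac{| \Delta x_k^\top \Delta g_k |}{\Delta g_k^\top \Delta g_k}, \qquad \Delta x_k := x_k - x_{k-1}, \quad \Delta g_k := \nabla f(x_k) - \nabla f(x_{k-1}) \]

which adapts the learning rate using only the most recent step and the most recent change in the gradient.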
    

3.4.2. Simulated data#

We’re going to test our gradient descent function by minimizing the sum of squared residuals in a regression problem.

Let’s generate some simulated data:

n = 100
key = jax.random.PRNGKey(1234)
x = jax.random.uniform(key, (n,))

α, β, σ = 0.5, 1.0, 0.1  # Set the true slope, intercept and noise level.
key, subkey = jax.random.split(key)
ϵ = jax.random.normal(subkey, (n,))

y = α * x + β + σ * ϵ
fig, ax = plt.subplots()
ax.scatter(x, y)
plt.show()
[Figure: scatter plot of the simulated data]

Let’s start by calculating the estimated slope and intercept using closed form solutions.

mx = x.mean()
my = y.mean()
α_hat = jnp.sum((x - mx) * (y - my)) / jnp.sum((x - mx)**2)
β_hat = my - α_hat * mx
α_hat, β_hat
(Array(0.45994443, dtype=float32), Array(1.0219053, dtype=float32))
fig, ax = plt.subplots()
ax.scatter(x, y)
ax.plot(x, α_hat * x + β_hat, 'k-')
ax.text(0.1, 1.55, rf'$\hat \alpha = {α_hat:.3}$')
ax.text(0.1, 1.50, rf'$\hat \beta = {β_hat:.3}$')
plt.show()
[Figure: data with the closed-form regression line and the estimates of α and β]

3.4.3. Minimizing squared loss by gradient descent#

Let’s see if we can get the same values with our gradient descent function.

First we set up the least squares loss function.

@jax.jit
def loss(params, data):
    a, b = params
    x, y = data
    return jnp.sum((y - a * x - b)**2)

Now we minimize it:

p0 = jnp.zeros(2)  # Initial guess for α, β
data = x, y
α_hat, β_hat = grad_descent(loss, data, p0)

Let’s plot the results.

fig, ax = plt.subplots()
x_grid = jnp.linspace(0, 1, 100)
ax.scatter(x, y)
ax.plot(x_grid, α_hat * x_grid + β_hat, 'k-', alpha=0.6)
ax.text(0.1, 1.55, rf'$\hat \alpha = {α_hat:.3}$')
ax.text(0.1, 1.50, rf'$\hat \beta = {β_hat:.3}$')
plt.show()
[Figure: data with the regression line fitted by gradient descent]

Notice that we get the same estimates as we did from the closed form solutions.
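As a quick check, we can recompute the closed-form estimates and compare them with the gradient descent output (the differences should be negligible, up to the solver tolerance and single precision):

α_cf = jnp.sum((x - mx) * (y - my)) / jnp.sum((x - mx)**2)
β_cf = my - α_cf * mx
α_hat - α_cf, β_hat - β_cf   # both differences should be close to zero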

3.4.4. Adding a squared term#

Now let’s try fitting a second order polynomial.

Here’s our new loss function.

@jax.jit
def loss(params, data):
    a, b, c = params
    x, y = data
    return jnp.sum((y - a * x**2 - b * x - c)**2)

Now we’re minimizing in three dimensions.

Let’s try it.

p0 = jnp.zeros(3)
α_hat, β_hat, γ_hat = grad_descent(loss, data, p0)

fig, ax = plt.subplots()
ax.scatter(x, y)
ax.plot(x_grid, α_hat * x_grid**2 + β_hat * x_grid + γ_hat, 'k-', alpha=0.6)
ax.text(0.1, 1.55, rf'$\hat \alpha = {α_hat:.3}$')
ax.text(0.1, 1.50, rf'$\hat \beta = {β_hat:.3}$')
plt.show()
[Figure: data with the fitted quadratic]

3.5. Exercises#

Exercise 3.1

The function jnp.polyval evaluates polynomials.

For example, if len(p) is 3, then jnp.polyval(p, x) returns

\[ f(p, x) := p_0 x^2 + p_1 x + p_2 \]
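For example, with the (hypothetical) coefficient vector p = jnp.array([2.0, 3.0, 1.0]), jnp.polyval evaluates \(2 x^2 + 3 x + 1\):

p = jnp.array([2.0, 3.0, 1.0])
jnp.polyval(p, 2.0)   # 2*4 + 3*2 + 1 = 15.0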

Use this function for polynomial regression.

The (empirical) loss becomes

\[ \ell(p, x, y) = \sum_{i=1}^n (y_i - f(p, x_i))^2 \]

Set \(k=4\) and use jnp.zeros(k) as the initial guess for params.

Use gradient descent to find the array params that minimizes the loss function and plot the result (following the examples above).