29. Segregation with Persistent Shocks#

29.1. Overview#

In previous lectures, we saw that the Schelling model converges to a segregated equilibrium: agents relocate until everyone is happy, and then the system stops.

But real cities don’t work this way. People move in and out, neighborhoods change, and the population is constantly being reshuffled by small shocks.

In this lecture, we explore what happens when we add this kind of persistent randomness to the model.

Specifically, after each iteration, we randomly flip the type of some agents with a small probability.

We can interpret this as agents occasionally moving away and being replaced by new agents whose type is randomly determined.

With persistent shocks, the system never converges, so the segregation dynamics keep operating indefinitely.

Because agents are constantly being nudged out of equilibrium, the forces that drive segregation never shut off.

The result is that segregation levels continue to increase over time, reaching levels beyond what we see in the basic model.

We use the parallel JAX implementation for efficiency, allowing us to run longer simulations with more agents.

import matplotlib.pyplot as plt
import numpy as np
import jax
import jax.numpy as jnp
from jax import random, jit, vmap
from functools import partial
from typing import NamedTuple
import time

29.2. Parameters#

We use 2000 agents of each type and add a flip_prob parameter that controls the probability of an agent’s type being flipped after each iteration.

class Params(NamedTuple):
    num_of_type_0: int = 2000    # number of agents of type 0 (orange)
    num_of_type_1: int = 2000    # number of agents of type 1 (green)
    num_neighbors: int = 10      # number of neighbors
    max_other_type: int = 6      # max number of different-type neighbors tolerated
    num_candidates: int = 3      # candidate locations per agent per iteration
    flip_prob: float = 0.01      # probability of flipping an agent's type


params = Params()
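A quick back-of-envelope check on `flip_prob`: with the default 4,000 agents, each iteration flips about 40 types in expectation, so a 1,000-iteration run produces roughly 40,000 flips — on average every agent is resampled about ten times. (The numbers below simply restate the defaults above.)

```python
n = 2000 + 2000          # num_of_type_0 + num_of_type_1 defaults
flip_prob = 0.01         # default flip probability

expected_per_iter = n * flip_prob
print(expected_per_iter)         # 40.0 flips per iteration in expectation
print(expected_per_iter * 1000)  # 40000.0 over a 1,000-iteration run
```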

29.3. Setup#

The following functions are repeated from the previous lecture:

Hide code cell source

def initialize_state(key, params):
    n = params.num_of_type_0 + params.num_of_type_1
    locations = random.uniform(key, (n, 2))
    types = jnp.concatenate([jnp.zeros(params.num_of_type_0, dtype=int),
                             jnp.ones(params.num_of_type_1, dtype=int)])
    return locations, types


@partial(jit, static_argnames=('params',))
def is_happy(loc, agent_idx, locations, types, params):
    " True if an agent at loc has at most max_other_type different-type neighbors. "
    distances = jnp.sum((loc - locations)**2, axis=1)
    distances = distances.at[agent_idx].set(jnp.inf)
    _, neighbors = jax.lax.top_k(-distances, params.num_neighbors)
    num_other = jnp.sum(types[neighbors] != types[agent_idx])
    return num_other <= params.max_other_type


@partial(jit, static_argnames=('params',))
def get_unhappy_agents(locations, types, params):
    " Return a boolean array indicating which agents are unhappy. "
    n = params.num_of_type_0 + params.num_of_type_1

    def is_unhappy(i):
        return ~is_happy(locations[i], i, locations, types, params)

    return vmap(is_unhappy)(jnp.arange(n))


@partial(jit, static_argnames=('params',))
def update_agent_location(i, locations, types, key, params):
    """
    Consider current location and num_candidates random alternatives.
    Return the first happy one. Already happy agents stay put.
    """
    current_loc = locations[i, :]

    # Build candidate list: current location + num_candidates random ones
    random_locs = random.uniform(key, (params.num_candidates, 2))
    candidate_locations = jnp.vstack([current_loc[None, :], random_locs])

    # Check happiness at each candidate (in parallel)
    def check_candidate(loc):
        return is_happy(loc, i, locations, types, params)
    happy_at = vmap(check_candidate)(candidate_locations)

    # Return the first happy candidate location.
    # Already happy agents select candidate_locations[0] and stay put.
    # If no candidate is happy, argmax returns 0 and the agent stays put.
    return candidate_locations[jnp.argmax(happy_at)]


@partial(jit, static_argnames=('params',))
def parallel_update_step(locations, types, key, params):
    """
    One step of the parallel algorithm: for each agent, find a happy
    candidate location (in parallel). Happy agents stay put, unhappy
    agents search for new locations.
    """
    n = params.num_of_type_0 + params.num_of_type_1

    keys = random.split(key, n + 1)
    key = keys[0]
    agent_keys = keys[1:]

    # Closure: wraps update_agent_location so vmap can map over a single argument
    def update_one_agent(i):
        return update_agent_location(i, locations, types, agent_keys[i], params)
    new_locations = vmap(update_one_agent)(jnp.arange(n))

    return new_locations, key
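Two JAX idioms in the functions above are worth unpacking. First, `jax.lax.top_k` returns the *largest* entries, so calling it on negated distances yields the indices of the nearest neighbors. Second, `jnp.argmax` on a boolean array returns the index of the first `True`, and returns 0 when all entries are `False` — which is exactly what keeps an agent with no happy candidate at its current location. A standalone illustration:

```python
import jax
import jax.numpy as jnp

# Nearest neighbors via top_k on negated distances
distances = jnp.array([0.5, 0.1, 0.9, 0.2])
_, nearest = jax.lax.top_k(-distances, 2)
print(nearest)  # indices of the two smallest distances: [1 3]

# argmax on booleans picks the first True...
print(int(jnp.argmax(jnp.array([False, True, True]))))    # 1
# ...and returns 0 when nothing is True (the agent stays put)
print(int(jnp.argmax(jnp.array([False, False, False]))))  # 0
```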

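The key handling in `parallel_update_step` follows the standard JAX pattern for parallel randomness: split the key once into a carry-over key plus one independent key per agent. A minimal sketch with a pretend population of three agents:

```python
from jax import random

key = random.key(42)
n = 3                              # pretend we have 3 agents
keys = random.split(key, n + 1)    # 1 carry-over key + n per-agent keys
key, agent_keys = keys[0], keys[1:]
print(agent_keys.shape[0])         # 3 independent per-agent keys
```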
29.4. Type Flipping#

This is the key addition in this lecture. After each iteration, we randomly flip the type of each agent with probability flip_prob.

@partial(jit, static_argnames=('params',))
def flip_types(types, key, params):
    """
    Randomly flip agent types with probability flip_prob.
    """
    n = params.num_of_type_0 + params.num_of_type_1
    flip_prob = params.flip_prob

    # Generate random numbers for each agent
    random_vals = random.uniform(key, (n,))

    # Determine which agents get flipped
    should_flip = random_vals < flip_prob

    # Flip: 0 -> 1, 1 -> 0  (equivalent to 1 - type)
    flipped_types = 1 - types

    # Apply flips only where should_flip is True
    new_types = jnp.where(should_flip, flipped_types, types)

    return new_types
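As a sanity check, the realized flip fraction should be close to `flip_prob` for a large population. The standalone snippet below replicates the flipping logic on 100,000 agents with a fresh key (this is a check of the mechanism, not part of the simulation):

```python
import jax.numpy as jnp
from jax import random

key = random.key(0)
n = 100_000
types = jnp.zeros(n, dtype=int)
flip_prob = 0.01

# Same logic as flip_types, applied to a large population
random_vals = random.uniform(key, (n,))
new_types = jnp.where(random_vals < flip_prob, 1 - types, types)

realized = float(jnp.mean(new_types != types))
print(realized)  # close to 0.01
```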

Hide code cell source

def plot_distribution(locations, types, title):
    " Plot the distribution of agents. "
    locations_np = np.asarray(locations)
    types_np = np.asarray(types)

    fig, ax = plt.subplots()
    plot_args = {'markersize': 6, 'alpha': 0.8, 'markeredgecolor': 'black', 'markeredgewidth': 0.5}
    colors = 'darkorange', 'green'
    for agent_type, color in zip((0, 1), colors):
        idx = (types_np == agent_type)
        ax.plot(locations_np[idx, 0],
                locations_np[idx, 1],
                'o',
                markerfacecolor=color,
                **plot_args)
    ax.set_title(title)
    plt.show()

29.5. Simulation with Shocks#

The simulation loop now includes type flipping after each iteration. We run for a fixed number of iterations rather than waiting for convergence, since the system will never fully converge with ongoing shocks.

def run_simulation_with_shocks(params, max_iter=1000, seed=42, plot_every=100):
    """
    Run the Schelling simulation with random type flips.

    Parameters
    ----------
    params : Params
        Model parameters including flip_prob.
    max_iter : int
        Number of iterations to run.
    seed : int
        Random seed.
    plot_every : int
        Plot the distribution every this many iterations.
    """
    key = random.key(seed)
    key, init_key = random.split(key)
    locations, types = initialize_state(init_key, params)
    n = locations.shape[0]

    print(f"Running simulation with {n} agents for {max_iter} iterations")
    print(f"Flip probability: {params.flip_prob}")
    print()

    plot_distribution(locations, types, 'Initial distribution')

    start_time = time.time()

    for iteration in range(1, max_iter + 1):
        # Update locations (agents try to find happy spots)
        locations, key = parallel_update_step(locations, types, key, params)

        # Apply random type flips
        key, flip_key = random.split(key)
        types = flip_types(types, flip_key, params)

        # Periodically report progress and plot
        if iteration % plot_every == 0:
            unhappy = get_unhappy_agents(locations, types, params)
            elapsed = time.time() - start_time
            print(f'Iteration {iteration}: {int(jnp.sum(unhappy))} unhappy agents, {elapsed:.1f}s elapsed')
            plot_distribution(locations, types, f'Iteration {iteration}')

    elapsed = time.time() - start_time
    print(f'\nCompleted {max_iter} iterations in {elapsed:.2f} seconds.')

    return locations, types

29.6. Results#

Let’s warm up the JIT-compiled functions and run the simulation:

key = random.key(0)
key, init_key = random.split(key)
test_locations, test_types = initialize_state(init_key, params)

_ = is_happy(test_locations[0], 0, test_locations, test_types, params)
_ = get_unhappy_agents(test_locations, test_types, params)
key, subkey = random.split(key)
_ = update_agent_location(0, test_locations, test_types, subkey, params)
key, subkey = random.split(key)
_, _ = parallel_update_step(test_locations, test_types, subkey, params)
key, subkey = random.split(key)
_ = flip_types(test_types, subkey, params)

print("JAX functions compiled and ready!")
JAX functions compiled and ready!
locations, types = run_simulation_with_shocks(params, max_iter=1000, plot_every=200)
Running simulation with 4000 agents for 1000 iterations
Flip probability: 0.01
Iteration 200: 48 unhappy agents, 3.3s elapsed
Iteration 400: 57 unhappy agents, 6.7s elapsed
Iteration 600: 57 unhappy agents, 10.1s elapsed
Iteration 800: 48 unhappy agents, 13.5s elapsed
Iteration 1000: 45 unhappy agents, 16.9s elapsed

Completed 1000 iterations in 17.14 seconds.
(Figures: scatter plots of the agent distribution at the initial state and at iterations 200, 400, 600, 800, and 1000.)

29.7. Discussion#

The figures show an interesting result: segregation levels at the end of the simulation are much higher than in the basic model without shocks.

Why does this happen?

In the basic model, the system converges to an equilibrium where everyone is happy, and then the dynamics stop.

With persistent shocks, the system never converges — random type flips create local pockets of unhappiness, triggering relocations that can cascade through the population.

The key insight is that the segregation dynamics never shut off.

The result is that segregation continues to increase over time, reaching levels far beyond what we observe when the system is allowed to converge.
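The lecture reads segregation off the figures; a simple way to quantify it (our addition, not part of the lecture's code) is the mean fraction of same-type agents among each agent's nearest neighbors — values near 0.5 indicate mixing, values near 1 indicate full segregation. A minimal NumPy sketch, where `same_type_share` is a hypothetical helper and the two tight clusters are a contrived fully-segregated test case:

```python
import numpy as np

def same_type_share(locations, types, num_neighbors=10):
    """Mean fraction of same-type agents among each agent's nearest neighbors."""
    # Pairwise squared distances; exclude self by setting the diagonal to inf
    d = ((locations[:, None, :] - locations[None, :, :])**2).sum(axis=2)
    np.fill_diagonal(d, np.inf)
    neighbors = np.argsort(d, axis=1)[:, :num_neighbors]
    same = types[neighbors] == types[:, None]
    return same.mean()

# Fully segregated example: type 0 clustered bottom-left, type 1 top-right
rng = np.random.default_rng(0)
left = rng.uniform(0.0, 0.2, size=(50, 2))
right = rng.uniform(0.8, 1.0, size=(50, 2))
locations = np.vstack([left, right])
types = np.concatenate([np.zeros(50, dtype=int), np.ones(50, dtype=int)])
print(same_type_share(locations, types))  # 1.0 for fully separated clusters
```

Applying such an index to snapshots of `locations` and `types` over time would make the "segregation keeps increasing" claim quantitative rather than visual.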

This is arguably more realistic than the static equilibrium of the basic model.

Real cities experience constant population turnover, and the Schelling dynamics operate continuously on the evolving population.