JAX Sharp Edges

This guide documents common pitfalls when working with JAX in janssen. Understanding these “sharp edges” helps avoid subtle bugs and performance issues.

Arrays First, Always

The most important rule: prefer JAX arrays over Python lists/tuples.

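JAX traces your code to compile it. Python lists and tuples are containers, not array values: JAX cannot vectorize over their elements, and a Python loop over them either runs in the interpreter or is unrolled element by element under jax.jit. This leads to:

  • Silent performance degradation (loops don’t vectorize)

  • ConcretizationTypeError (or one of its subclasses, such as TracerBoolConversionError) when JIT encounters traced values in Python control flow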

Bad: Python List with Loop

def generate_modes_bad(max_order: int) -> list:
    """Returns a Python list - can't be vmapped."""
    mode_indices = []
    for n in range(max_order):
        for m in range(max_order):
            mode_indices.append((n, m))
    return mode_indices

# This loop runs in Python, not on GPU
for n, m in generate_modes_bad(10):
    process_mode(n, m)

Good: JAX Array with vmap

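def generate_modes_good(max_order: int) -> Int[Array, " num_modes 2"]:
    """Returns a JAX array - ready for vmap."""
    indices = []
    for n in range(max_order):
        for m in range(max_order):
            indices.append((n, m))
    # Mode orders are integers, so store them as int32. Requesting float64
    # would be truncated to float32 unless jax_enable_x64 is set.
    return jnp.array(indices, dtype=jnp.int32)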

# This runs vectorized on GPU
mode_indices = generate_modes_good(10)
jax.vmap(lambda idx: process_mode(idx[0], idx[1]))(mode_indices)

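Note: The Python loop in generate_modes_good is fine because max_order is a plain Python int, so the loop runs once in ordinary Python to build the array (or once at trace time if the function is called inside a jitted function with a static max_order). The key is that the output is a JAX array that can be vmapped over.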
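If the Python loop itself becomes a bottleneck for large max_order, the same index array can be built without it. The following is a minimal sketch, not janssen's own implementation; generate_modes_vectorized is a hypothetical name, and it assumes integer mode indices are acceptable downstream.

```python
import jax.numpy as jnp


def generate_modes_vectorized(max_order: int):
    """Build all (n, m) pairs with n, m in [0, max_order) without a Python loop."""
    orders = jnp.arange(max_order)
    # indexing="ij" makes n vary slowest, matching the nested-loop order.
    n, m = jnp.meshgrid(orders, orders, indexing="ij")
    return jnp.stack([n.ravel(), m.ravel()], axis=-1)


mode_indices = generate_modes_vectorized(3)  # shape (9, 2)
```

The rows come out in the same order as the nested loops, so the result can be swapped in wherever the looped version was used.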

Control Flow with Traced Values

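Python control flow (if, for, while) cannot branch or iterate on traced values: inside jax.jit, a condition or loop bound that depends on a traced argument has no concrete value at trace time. Control flow on ordinary (static) Python values still works.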

Bad: Python if with Traced Value

@jax.jit
def bad_conditional(x):
    if x > 0:  # TracerBoolConversionError (a ConcretizationTypeError)!
        return x
    else:
        return -x

Good: jax.lax.cond

@jax.jit
def good_conditional(x):
    return jax.lax.cond(
        x > 0,
        lambda: x,
        lambda: -x
    )
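jax.lax.cond expects a scalar predicate. When the condition is evaluated per element of an array, jnp.where is usually the simpler tool. The sketch below contrasts the two; abs_elementwise and abs_scalar are hypothetical names used only for illustration.

```python
import jax
import jax.numpy as jnp


@jax.jit
def abs_elementwise(x):
    # jnp.where computes both branches and selects per element,
    # so the condition may be an array.
    return jnp.where(x > 0, x, -x)


@jax.jit
def abs_scalar(x):
    # jax.lax.cond needs a scalar predicate; here the operand is passed
    # explicitly instead of being captured by closure.
    return jax.lax.cond(x > 0, lambda v: v, lambda v: -v, x)


elementwise = abs_elementwise(jnp.array([-2.0, 0.5, 3.0]))
scalar = abs_scalar(jnp.asarray(-2.0))
```

Because jnp.where evaluates both branches, it is best suited to cheap branches; jax.lax.cond is the better fit when the branches are expensive and the predicate is a scalar.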

Bad: Python for with Traced Bound

@jax.jit
def bad_loop(n, x):
    result = 0.0
    for i in range(n):  # n is traced - fails!
        result += x[i]
    return result

Good: jax.lax.fori_loop

@jax.jit
def good_loop(n, x):
    def body(i, acc):
        return acc + x[i]
    return jax.lax.fori_loop(0, n, body, 0.0)
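When the loop body is a plain reduction, a mask can often replace the loop altogether. This is a minimal sketch under the assumption that n never exceeds x.shape[0]; masked_prefix_sum is a hypothetical name.

```python
import jax
import jax.numpy as jnp


@jax.jit
def masked_prefix_sum(n, x):
    # x.shape[0] is static, so the mask has a fixed shape even though n is traced.
    mask = jnp.arange(x.shape[0]) < n
    # Elements at index >= n contribute zero to the sum.
    return jnp.sum(jnp.where(mask, x, 0.0))


total = masked_prefix_sum(3, jnp.arange(5.0))  # sums x[0], x[1], x[2]
```

The masked version does a fixed amount of work regardless of n, which tends to suit accelerators, and it also stays differentiable in reverse mode, whereas a fori_loop with a traced trip count is lowered to a while loop that only supports forward-mode differentiation.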

Static vs Dynamic Arguments

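Arguments that determine an output shape must be static (known at compile time), because JAX compiles one specialized program per shape. Static arguments must be hashable, and every new value triggers a recompilation, so keep them to a small set of distinct values.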

The _impl Pattern

from functools import partial
from typing import Tuple

import jax
import jax.numpy as jnp
from beartype import beartype
from jaxtyping import Array, jaxtyped

@partial(jax.jit, static_argnums=(0, 1))
def _compute_impl(height: int, width: int, data: Array) -> Array:
    """Height and width are static - they determine output shape."""
    result = jnp.zeros((height, width))
    # ... computation ...
    return result

@jaxtyped(typechecker=beartype)
def compute(grid_size: Tuple[int, int], data: Array) -> Array:
    """Public API extracts static args."""
    return _compute_impl(grid_size[0], grid_size[1], data)
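The cost of static arguments is recompilation. One way to see this is to count how often the function body is traced. The sketch below uses static_argnames, the keyword-based counterpart of static_argnums; fill_grid and trace_count are hypothetical names for illustration only.

```python
from functools import partial

import jax
import jax.numpy as jnp

trace_count = 0  # incremented only while JAX traces the function


@partial(jax.jit, static_argnames=("height", "width"))
def fill_grid(data, height, width):
    global trace_count
    trace_count += 1  # Python side effect: skipped on cached calls
    return jnp.full((height, width), jnp.sum(data))


data = jnp.ones(3)
first = fill_grid(data, height=2, width=4)   # traces and compiles
second = fill_grid(data, height=2, width=4)  # same static values: cache hit
third = fill_grid(data, height=3, width=4)   # new static value: retrace
```

After these three calls the body has been traced twice, once per distinct (height, width) pair. A static argument that changes on every call therefore pays the compile cost on every call.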

Common Patterns

Hermite Polynomial Recurrence

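The physicists’ Hermite polynomials satisfy the recurrence $H_n(x) = 2xH_{n-1}(x) - 2(n-1)H_{n-2}(x)$ for $n \ge 2$, with $H_0(x) = 1$ and $H_1(x) = 2x$. Because the order may be a traced value, iterate with jax.lax.fori_loop rather than a Python loop: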

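def hermite_polynomial(order: Array, x: Array) -> Array:
    """Evaluate the physicists' Hermite polynomial H_order(x)."""
    def body_fn(k, carry):
        # carry holds (H_{k-2}, H_{k-1}); return (H_{k-1}, H_k)
        h_prev2, h_prev1 = carry
        h_curr = 2.0 * x * h_prev1 - 2.0 * (k - 1) * h_prev2
        return h_prev1, h_curr

    h0 = jnp.ones_like(x)  # H_0(x) = 1
    h1 = 2.0 * x           # H_1(x) = 2x

    # For order < 2 the loop runs zero times and h_n is simply h1
    _, h_n = jax.lax.fori_loop(2, order + 1, body_fn, (h0, h1))

    # Select the base cases explicitly so order 0 returns H_0, not H_1
    return jnp.where(order == 0, h0, jnp.where(order == 1, h1, h_n))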
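Because order is an array rather than a Python int, the function can be vmapped over several orders at once. The sketch below repeats the function without type annotations so that it is self-contained, and checks it against the closed forms $H_2(x) = 4x^2 - 2$ and $H_3(x) = 8x^3 - 12x$; the evaluation point 0.5 is an arbitrary choice.

```python
import jax
import jax.numpy as jnp


def hermite_polynomial(order, x):
    def body_fn(k, carry):
        h_prev2, h_prev1 = carry
        h_curr = 2.0 * x * h_prev1 - 2.0 * (k - 1) * h_prev2
        return h_prev1, h_curr

    h0 = jnp.ones_like(x)
    h1 = 2.0 * x
    _, h_n = jax.lax.fori_loop(2, order + 1, body_fn, (h0, h1))
    return jnp.where(order == 0, h0, jnp.where(order == 1, h1, h_n))


orders = jnp.arange(4)  # integer orders 0, 1, 2, 3
x = jnp.asarray(0.5)

# Batch over the order axis only; x is shared by every order.
values = jax.vmap(hermite_polynomial, in_axes=(0, None))(orders, x)

# Closed forms evaluated at the same point, for comparison.
expected = jnp.array([1.0, 2 * 0.5, 4 * 0.5**2 - 2, 8 * 0.5**3 - 12 * 0.5])
```

One trade-off of a traced order: the loop is lowered to a while loop, which JAX can differentiate in forward mode but not in reverse mode. If jax.grad with respect to x is needed, making the order static is one option.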

Vectorizing Over Indices

Instead of looping over mode indices, use vmap:

# Bad
modes = []
for n, m in mode_indices:
    modes.append(compute_mode(n, m))
modes = jnp.stack(modes)

# Good
def compute_single_mode(indices):
    n, m = indices[0], indices[1]
    return compute_mode(n, m)

modes = jax.vmap(compute_single_mode)(mode_indices_array)
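A quick way to gain confidence in such a rewrite is to check that both versions agree on a small input. In the sketch below, compute_mode and grid are hypothetical stand-ins for a real mode function and its sampling grid.

```python
import jax
import jax.numpy as jnp

grid = jnp.linspace(0.0, 1.0, 8)


def compute_mode(n, m):
    # Hypothetical stand-in for a real mode function.
    return jnp.cos(n * grid) * jnp.sin(m * grid)


mode_indices_array = jnp.array([[0, 1], [1, 2], [2, 0]])

# Reference: Python loop over rows, one dispatch per mode.
looped = jnp.stack([compute_mode(n, m) for n, m in mode_indices_array])


def compute_single_mode(indices):
    return compute_mode(indices[0], indices[1])


# Vectorized: a single batched computation over the leading axis.
vectorized = jax.vmap(compute_single_mode)(mode_indices_array)
```

Both results have shape (3, 8) and should match to floating-point tolerance.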

Why jaxtyping + beartype Helps

With runtime type checking, incorrect array types fail fast:

@jaxtyped(typechecker=beartype)
def process(data: Float[Array, " N 2"]) -> Float[Array, " N"]:
    return jnp.sum(data, axis=1)

# This fails immediately with a type-check error (a list is not a JAX Array):
process([(1, 2), (3, 4)])

# This works (note the float values: an integer array would not match Float):
process(jnp.array([[1.0, 2.0], [3.0, 4.0]]))

The type annotations serve as documentation and runtime validation, catching list/tuple misuse at function boundaries rather than deep in traced code.

Summary

Pattern         Bad                 Good
--------------  ------------------  ------------------------------
Data structure  Python list/tuple   jnp.array
Iteration       for i in range(n)   jax.vmap or jax.lax.fori_loop
Conditional     if x > 0            jax.lax.cond
Shape params    Traced values       static_argnums
Type safety     Hope for the best   jaxtyping + beartype