JAX Sharp Edges¶
This guide documents common pitfalls when working with JAX in janssen. Understanding these “sharp edges” helps avoid subtle bugs and performance issues.
Arrays First, Always¶
The most important rule: prefer JAX arrays over Python lists/tuples.
JAX traces through your code to compile it. Python lists and tuples are opaque to the tracer, leading to:
Silent performance degradation (loops don’t vectorize)
ConcretizationTypeErrorwhen JIT encounters traced values in Python control flow
Bad: Python List with Loop¶
def generate_modes_bad(max_order: int) -> list:
"""Returns a Python list - can't be vmapped."""
mode_indices = []
for n in range(max_order):
for m in range(max_order):
mode_indices.append((n, m))
return mode_indices
# This loop runs in Python, not on GPU
for n, m in generate_modes_bad(10):
process_mode(n, m)
Good: JAX Array with vmap¶
def generate_modes_good(max_order: int) -> Float[Array, " num_modes 2"]:
"""Returns a JAX array - ready for vmap."""
indices = []
for n in range(max_order):
for m in range(max_order):
indices.append((n, m))
return jnp.array(indices, dtype=jnp.float64)
# This runs vectorized on GPU
mode_indices = generate_modes_good(10)
jax.vmap(lambda idx: process_mode(idx[0], idx[1]))(mode_indices)
Note: The Python loop in generate_modes_good is fine because it runs
once at trace time to build the array. The key is that the output is
a JAX array that can be vmapped over.
Control Flow with Traced Values¶
Python control flow (if, for, while) doesn’t work with traced values.
Bad: Python if with Traced Value¶
@jax.jit
def bad_conditional(x):
if x > 0: # ConcretizationTypeError!
return x
else:
return -x
Good: jax.lax.cond¶
@jax.jit
def good_conditional(x):
return jax.lax.cond(
x > 0,
lambda: x,
lambda: -x
)
Bad: Python for with Traced Bound¶
@jax.jit
def bad_loop(n, x):
result = 0.0
for i in range(n): # n is traced - fails!
result += x[i]
return result
Good: jax.lax.fori_loop¶
@jax.jit
def good_loop(n, x):
def body(i, acc):
return acc + x[i]
return jax.lax.fori_loop(0, n, body, 0.0)
Static vs Dynamic Arguments¶
Some arguments determine output shape and must be static (known at compile time).
The _impl Pattern¶
from functools import partial
@partial(jax.jit, static_argnums=(0, 1))
def _compute_impl(height: int, width: int, data: Array) -> Array:
"""Height and width are static - they determine output shape."""
result = jnp.zeros((height, width))
# ... computation ...
return result
@jaxtyped(typechecker=beartype)
def compute(grid_size: Tuple[int, int], data: Array) -> Array:
"""Public API extracts static args."""
return _compute_impl(grid_size[0], grid_size[1], data)
Common Patterns¶
Hermite Polynomial Recurrence¶
The Hermite polynomial recurrence $H_n(x) = 2xH_{n-1}(x) - 2(n-1)H_{n-2}(x)$
requires iteration. Use jax.lax.fori_loop:
def hermite_polynomial(order: Array, x: Array) -> Array:
def body_fn(k, carry):
h_prev2, h_prev1 = carry
h_curr = 2.0 * x * h_prev1 - 2.0 * (k - 1) * h_prev2
return h_prev1, h_curr
h0 = jnp.ones_like(x)
h1 = 2.0 * x
_, h_n = jax.lax.fori_loop(2, order + 1, body_fn, (h0, h1))
return jnp.where(order == 0, h0, jnp.where(order == 1, h1, h_n))
Vectorizing Over Indices¶
Instead of looping over mode indices, use vmap:
# Bad
modes = []
for n, m in mode_indices:
modes.append(compute_mode(n, m))
modes = jnp.stack(modes)
# Good
def compute_single_mode(indices):
n, m = indices[0], indices[1]
return compute_mode(n, m)
modes = jax.vmap(compute_single_mode)(mode_indices_array)
Why jaxtyping + beartype Helps¶
With runtime type checking, incorrect array types fail fast:
@jaxtyped(typechecker=beartype)
def process(data: Float[Array, " N 2"]) -> Float[Array, " N"]:
return jnp.sum(data, axis=1)
# This fails immediately with a clear error:
process([(1, 2), (3, 4)]) # TypeError: expected Array, got list
# This works:
process(jnp.array([[1, 2], [3, 4]]))
The type annotations serve as documentation and runtime validation, catching list/tuple misuse at function boundaries rather than deep in traced code.
Summary¶
Pattern |
Bad |
Good |
|---|---|---|
Data structure |
Python list/tuple |
|
Iteration |
|
|
Conditional |
|
|
Shape params |
Traced values |
|
Type safety |
Hope for the best |
|