PyTree Architecture for JAX Compatibility¶
Janssen uses JAX PyTrees as the foundation for all data structures. This enables automatic differentiation, JIT compilation, and parallel execution across the entire library.
What is a PyTree?¶
Definition¶
A PyTree is a nested structure of containers (lists, tuples, dicts) with arbitrary leaves. JAX can automatically traverse, transform, and differentiate through PyTrees.
# Examples of PyTrees
simple = [1, 2, 3] # List of scalars
nested = {"a": [1, 2], "b": (3, 4)} # Dict with list and tuple
arrays = (jnp.array([1, 2]), jnp.array([3, 4])) # Tuple of arrays
Why PyTrees?¶
Composability: Combine multiple arrays into one logical unit
Transformation: Apply
jit,grad,vmapto entire structuresParallelism: Automatic vectorization over batches
Type safety: Preserve structure through computations
Janssen’s PyTree hierarchy. Core types like OpticalWavefront and
CoherentModeSet are registered as PyTrees, enabling seamless use
with JAX transformations.¶
Core Data Types¶
OpticalWavefront¶
The fundamental type for representing optical fields:
from janssen.utils import OpticalWavefront
wavefront = OpticalWavefront(
field=complex_field, # Complex[Array, "H W"] or [H W 2]
wavelength=632.8e-9, # Scalar wavelength
dx=1e-6, # Pixel spacing
z_position=0.0, # Axial position
polarization=False, # Scalar (False) or Jones (True)
)
Wavefront type variants. (a) Scalar wavefront with complex field, (b) polarized wavefront with Jones vector at each pixel.¶
PropagatingWavefront¶
Extends OpticalWavefront with propagation metadata:
from janssen.utils import PropagatingWavefront
prop_wf = PropagatingWavefront(
field=complex_field,
wavelength=632.8e-9,
dx=1e-6,
z_position=0.0,
polarization=False,
source_distance=float("inf"), # From plane wave
propagation_method="angular_spectrum",
)
CoherentModeSet¶
For partially coherent fields (see Coherence Guide):
from janssen.utils import CoherentModeSet
modes = CoherentModeSet(
modes=mode_array, # Complex[Array, "M H W"]
weights=weight_array, # Float[Array, "M"]
wavelength=632.8e-9,
dx=1e-6,
z_position=0.0,
polarization=False,
)
# Access properties
print(f"Number of modes: {modes.num_modes}")
print(f"Effective modes: {modes.effective_mode_count}")
Coherence types. CoherentModeSet stores multiple modes with weights,
while PolychromaticWavefront stores fields at multiple wavelengths.¶
Factory Functions¶
Why Factories?¶
Factory functions provide validation and normalization before creating PyTree instances:
from janssen.utils import make_optical_wavefront
# Factory validates inputs
wavefront = make_optical_wavefront(
field=complex_field,
wavelength=632.8e-9,
dx=1e-6,
)
# - Converts to correct dtype
# - Validates shapes
# - Sets defaults
# - Normalizes parameters
Factory function workflow. Raw inputs are validated, converted to JAX arrays, normalized, and assembled into a registered PyTree.¶
Available Factories¶
Factory |
Creates |
Validations |
|---|---|---|
|
|
Shape, dtype, wavelength > 0 |
|
|
Weights ≥ 0, matching shapes |
|
|
Wavelengths > 0, weights sum |
|
|
Focal length, NA consistency |
|
|
Position count matches patterns |
PyTree Registration¶
The @register_pytree_node_class Decorator¶
Custom classes must be registered with JAX:
from jax.tree_util import register_pytree_node_class
from typing import NamedTuple
@register_pytree_node_class
class MyDataType(NamedTuple):
"""Custom PyTree type."""
field: Array
parameter: float
def tree_flatten(self):
"""Return (children, aux_data)."""
return ((self.field, self.parameter), None)
@classmethod
def tree_unflatten(cls, aux_data, children):
"""Reconstruct from flattened form."""
return cls(*children)
NamedTuple Pattern¶
Janssen uses NamedTuple for PyTrees because:
Immutable (required for JAX)
Named fields (self-documenting)
Memory efficient (no
__dict__)Works with type checkers
JAX Transformations¶
JIT Compilation¶
PyTrees enable JIT compilation of entire workflows:
from jax import jit
@jit
def propagate_and_measure(wavefront, distance):
output = angular_spectrum(wavefront, distance)
return output.intensity
# Compiled function works with PyTree input
intensity = propagate_and_measure(input_wavefront, 1e-3)
Automatic Differentiation¶
Gradients flow through PyTree structures:
from jax import grad
def loss_function(wavefront):
output = propagate(wavefront)
return jnp.sum((output.intensity - target)**2)
# Gradient with respect to wavefront field
grad_fn = grad(loss_function)
gradient = grad_fn(input_wavefront)
Vectorization with vmap¶
Process batches automatically:
from jax import vmap
# Vectorize over multiple wavefronts
batched_propagate = vmap(
angular_spectrum,
in_axes=(0, None), # Batch over wavefronts, same distance
)
# Process batch at once
outputs = batched_propagate(wavefront_batch, distance)
Best Practices¶
Immutability¶
PyTree leaves should be immutable. Use functional updates:
# Wrong: in-place modification
wavefront.field = new_field # Error!
# Right: create new instance
new_wavefront = wavefront._replace(field=new_field)
# Or use factory
new_wavefront = make_optical_wavefront(
field=new_field,
wavelength=wavefront.wavelength,
dx=wavefront.dx,
)
Type Annotations¶
Use jaxtyping for array shapes:
from jaxtyping import Array, Complex, Float
def my_function(
field: Complex[Array, "H W"],
wavelength: Float[Array, ""],
) -> Complex[Array, "H W"]:
...
Device Placement¶
Arrays in PyTrees can be placed on specific devices:
from jax import device_put
# Move entire PyTree to GPU
gpu_wavefront = device_put(wavefront, jax.devices("gpu")[0])
Memory Considerations¶
Efficient PyTree Operations¶
# Prefer: single PyTree operation
output = propagate(wavefront)
# Avoid: extracting and re-creating
field = wavefront.field
wavelength = wavefront.wavelength
result = manual_propagate(field, wavelength) # Loses structure
Large Mode Sets¶
For many modes, consider chunked processing:
from janssen.cohere import propagate_coherent_mode_set
# Process in chunks to manage memory
output = propagate_coherent_mode_set(
mode_set,
distance=1e-3,
chunk_size=10, # Process 10 modes at a time
)
References¶
Bradbury, J. et al. “JAX: composable transformations of Python+NumPy programs” (2018)