janssen.utils
Utility functions for distributed computing and math operations.
Extended Summary
Utilities for distributed JAX computing across multiple devices, mathematical helper functions for complex-valued operations, and general-purpose optimization algorithms.
Routine Listings
bessel_j0() – Compute J_0(x), regular Bessel function of the first kind, order 0.
bessel_jn() – Compute J_n(x), regular Bessel function of the first kind, order n.
bessel_kv() – Compute K_v(x), modified Bessel function of the second kind.
jt_residual() – Compute residuals and J^T @ r simultaneously.
create_mesh() – Creates a device mesh for data parallelism across available devices.
jtj_diag() – Estimate diagonal of J^T J for preconditioning.
max_eigenval() – Estimate largest eigenvalue of J^T J via power iteration.
flatten_params() – Flatten complex arrays to real parameter vector for optimization.
fourier_shift() – FFT-based sub-pixel shifting of 2D fields.
gn_solve() – High-level solver that runs Gauss-Newton until convergence.
gn_loss_history() – Gauss-Newton solver with per-iteration loss tracking only.
gn_history() – Gauss-Newton solver with per-iteration history tracking.
gn_step() – Generic Gauss-Newton step with trust-region damping.
get_device_count() – Gets the number of available JAX devices.
get_device_memory_gb() – Detects device count and memory per device (GB) for GPUs/CPUs.
hessian_matvec() – Create exact Hessian-vector product operator.
jtj_matvec() – Create Jacobian-free (J^T J + λI) operator for Gauss-Newton.
shard_batch() – Shards array data across the batch dimension for parallel processing.
unflatten_params() – Unflatten real parameter vector back to complex arrays.
wirtinger_grad() – Compute the Wirtinger gradient of a complex-valued function.
Notes
For type definitions and PyTree classes, see the janssen.types module.
- janssen.utils.bessel_j0(x: Float[Array, '...']) → Float[Array, '...']
Compute J_0(x), regular Bessel function of the first kind, order 0.
- Parameters:
x (Float[Array, "..."]) – Input array.
- Returns:
Values of J_0(x).
- Return type:
Float[Array, "..."]
Notes
This is a wrapper around JAX’s scipy implementation for consistency. J_0(x) is the regular Bessel function of the first kind of order 0.
The function is differentiable, JIT-compatible, and supports broadcasting.
Examples
>>> import jax.numpy as jnp
>>> from janssen.utils import bessel_j0
>>> x = jnp.linspace(0, 10, 100)
>>> j0_vals = bessel_j0(x)
- janssen.utils.bessel_jn(n: int | Int[Array, ''], x: Float[Array, '...']) → Float[Array, '...']
Compute J_n(x), regular Bessel function of the first kind, order n.
- Parameters:
n (int) – Order of the Bessel function (integer). Must be a compile-time constant for JIT compilation.
x (Float[Array, "..."]) – Input array.
- Returns:
Values of J_n(x).
- Return type:
Float[Array, "..."]
Notes
This is a wrapper around JAX’s scipy implementation for consistency. J_n(x) is the regular Bessel function of the first kind of order n.
The function is differentiable and supports broadcasting. Note that the order n must be a compile-time constant (not a traced value) when used inside JIT-compiled functions.
Examples
>>> import jax.numpy as jnp
>>> from janssen.utils import bessel_jn
>>> x = jnp.linspace(0, 10, 100)
>>> j1_vals = bessel_jn(1, x)
>>> j2_vals = bessel_jn(2, x)
- janssen.utils.bessel_kv(v: float | Float[Array, ''], x: Float[Array, '...']) → Float[Array, '...']
Compute the modified Bessel function of the second kind K_v(x).
- Parameters:
v (float) – Order of the Bessel function (v >= 0).
x (Float[Array, "..."]) – Positive real input array.
- Returns:
Approximated values of K_v(x).
- Return type:
Float[Array, "..."]
Notes
Computes K_v(x) for real order v >= 0 and x > 0, using a numerically stable and differentiable JAX-compatible approximation.
Valid for v >= 0 and x > 0
Supports broadcasting and autodiff
JIT-safe and VMAP-safe
Uses series expansion for small x (x <= 2.0) and asymptotic expansion for large x
For non-integer v, uses the reflection formula: K_v = π/(2sin(πv)) * (I_{-v} - I_v)
For integer v, uses specialized series and recurrence relations
Special exact formula for v = 0.5: K_{1/2}(x) = sqrt(π/(2x)) * exp(-x)
The transition point between small and large x approximations is set at x = 2.0
Algorithm
For integer orders n > 1, uses recurrence relations with masked updates to only update values within the target range
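The closed form for v = 0.5 and the upward recurrence K_{v+1}(x) = K_{v-1}(x) + (2v/x) K_v(x) can be checked numerically. The following is a minimal, dependency-free sketch, not the library's implementation: the helper names are hypothetical, and the reference values come from the standard integral representation K_v(x) = ∫₀^∞ exp(-x cosh t) cosh(v t) dt evaluated with the trapezoid rule.

```python
import math


def bessel_k_integral(v, x, t_max=12.0, steps=12000):
    """Reference K_v(x) from the integral of exp(-x cosh t) cosh(v t) over t >= 0."""
    h = t_max / steps
    total = 0.0
    for i in range(steps + 1):
        t = i * h
        weight = 0.5 if i in (0, steps) else 1.0  # trapezoid end-point weights
        total += weight * math.exp(-x * math.cosh(t)) * math.cosh(v * t)
    return total * h


def bessel_k_half(x):
    """Closed form K_{1/2}(x) = sqrt(pi / (2 x)) * exp(-x)."""
    return math.sqrt(math.pi / (2.0 * x)) * math.exp(-x)


def bessel_k_next(v, x, k_vm1, k_v):
    """Upward recurrence: K_{v+1}(x) = K_{v-1}(x) + (2 v / x) * K_v(x)."""
    return k_vm1 + (2.0 * v / x) * k_v


x = 1.5
k0 = bessel_k_integral(0.0, x)
k1 = bessel_k_integral(1.0, x)
k2_recurrence = bessel_k_next(1.0, x, k0, k1)  # K_2 built from K_0 and K_1
k2_integral = bessel_k_integral(2.0, x)        # K_2 computed directly
k_half_closed = bessel_k_half(x)
k_half_integral = bessel_k_integral(0.5, x)

print(f"K_2 recurrence={k2_recurrence:.10f} integral={k2_integral:.10f}")
print(f"K_1/2 closed={k_half_closed:.10f} integral={k_half_integral:.10f}")
```

Both pairs of values should agree to many digits, which is a quick sanity check to run against bessel_kv outputs for the same orders.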
- janssen.utils.create_mesh(n_devices: int | None = None) → Mesh
Create a device mesh for data parallelism.
Creates a 1D device mesh suitable for sharding arrays across their batch dimension. The mesh can be used with sharding specifications to distribute computation across multiple devices.
- Parameters:
n_devices (int, optional) – Number of devices to use in the mesh. If None, uses all available devices detected by JAX (default: None).
- Returns:
mesh – Device mesh with axis name ‘batch’ for data parallelism. The mesh contains a linear arrangement of devices for distributing batched computations.
- Return type:
Mesh
Examples
>>> # Create mesh with all available devices
>>> mesh = create_mesh()
>>> print(mesh.shape)
{'batch': 4}  # If 4 devices are available

>>> # Create mesh with specific number of devices
>>> mesh = create_mesh(n_devices=2)
>>> print(mesh.shape)
{'batch': 2}
Notes
The returned mesh has a single axis named ‘batch’, making it suitable for distributing the first dimension of arrays across devices. For more complex sharding patterns, consider creating custom meshes using jax.sharding.Mesh directly.
- janssen.utils.get_device_count() → int
Get number of available JAX devices.
- Returns:
n_devices – Number of available accelerators (GPUs/TPUs)
- Return type:
int
Examples
>>> import janssen as jns
>>> n = jns.utils.get_device_count()
>>> print(f"Found {n} devices")
Found 8 devices
- janssen.utils.get_device_memory_gb() → tuple[int, float]
Detect device count and memory per device.
Attempts to detect the number of JAX devices and total memory per device using platform-specific methods:
- NVIDIA GPUs: nvidia-smi
- CPUs: system RAM via psutil or /proc/meminfo
- Other platforms: conservative fallback
- Returns:
Number of detected devices and memory per device in GB.
- Return type:
tuple[int, float]
Notes
Detection Methods by Platform:
NVIDIA GPUs:
- Uses nvidia-smi to query first GPU’s total memory
- Converts from MB to GB
- Multi-GPU systems: assumes all GPUs have same memory
CPUs:
- First tries psutil.virtual_memory() for cross-platform detection
- Falls back to /proc/meminfo on Linux if psutil unavailable
- Returns total system RAM (all devices share this pool)
TPUs/Other Accelerators:
- No direct memory detection available
- Falls back to conservative defaults
Fallback Values:
GPU fallback (16.0 GB) works for:
- NVIDIA V100 (16 GB variant)
- NVIDIA Tesla T4 (16 GB)
- NVIDIA RTX 6000 (24 GB, safe to use 16)
- NVIDIA A100 (40 GB variant, safe to use 16)
CPU fallback (8.0 GB):
- Conservative for modern systems (most have 16+ GB)
- Prevents OOM on low-memory machines
Platform Limitations:
NVIDIA GPU: Only detects first GPU memory
AMD GPU (ROCm): No detection, uses fallback
Intel GPU: No detection, uses fallback
Google TPU: No detection, uses fallback
CPU: Detects total system RAM (shared across all cores)
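The /proc/meminfo fallback described above can be pictured with a small parser. This is only a sketch under stated assumptions: the function names are hypothetical, the kB-to-GB conversion divides by 1024², and the 8.0 GB default mirrors the CPU fallback value quoted in these notes; the library's actual parsing may differ.

```python
def meminfo_total_gb(meminfo_text):
    """Return the MemTotal entry of /proc/meminfo-style text, converted from kB to GB."""
    for line in meminfo_text.splitlines():
        if line.startswith("MemTotal:"):
            kib = int(line.split()[1])  # second field is the size in kB
            return kib / (1024.0 * 1024.0)
    raise ValueError("MemTotal entry not found")


def memory_or_fallback(meminfo_text, fallback_gb=8.0):
    """Use the parsed value when available, otherwise the conservative fallback."""
    try:
        return meminfo_total_gb(meminfo_text)
    except (ValueError, IndexError):
        return fallback_gb


# Hypothetical sample text in the /proc/meminfo layout.
sample = "MemTotal:       16384000 kB\nMemFree:         1234567 kB\n"
detected_gb = memory_or_fallback(sample)
fallback_gb = memory_or_fallback("")  # no MemTotal line, so the fallback applies

print(f"detected: {detected_gb:.3f} GB, fallback: {fallback_gb:.1f} GB")
```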
Examples
>>> from janssen.utils.distributed import get_device_memory_gb
>>> num_devices, memory = get_device_memory_gb()
>>> print(f"Detected {num_devices} devices with {memory:.1f} GB each")
Detected 8 devices with 16.0 GB each
See also
get_device_count – Get number of available devices
- janssen.utils.shard_batch(data: Shaped[Array, '...'], mesh: Mesh) → Shaped[Array, '...']
Shard data across batch dimension.
Distributes an array’s first dimension (batch dimension) across devices in the provided mesh. This enables parallel processing of batched data with automatic memory distribution and computation parallelism.
- Parameters:
data (Shaped[Array, "..."]) – Input array to shard. The first dimension is treated as the batch dimension and will be distributed across devices. Can be any JAX or NumPy array.
mesh (Mesh) – Device mesh created by create_mesh() or custom mesh with a ‘batch’ axis. Defines how the data will be distributed across devices.
- Returns:
sharded_data – Input array with the batch dimension sharded across devices in the mesh. The array’s computation will be automatically parallelized across devices.
- Return type:
Shaped[Array, "..."]
Examples
>>> import jax.numpy as jnp
>>> from janssen.utils.distributed import create_mesh, shard_batch
>>>
>>> # Create sample data with batch dimension
>>> data = jnp.ones((8, 256, 256))
>>>
>>> # Create mesh and shard data
>>> mesh = create_mesh()
>>> sharded_data = shard_batch(data, mesh)
>>>
>>> # The first dimension is now distributed across devices
>>> # Operations on sharded_data will run in parallel
Notes
The batch dimension size should ideally be divisible by the number of devices for optimal load balancing.
Sharding is applied using NamedSharding with PartitionSpec(‘batch’), which partitions only the first dimension.
Subsequent operations on the sharded array will automatically maintain the sharding pattern where possible.
- janssen.utils.jt_residual(residual_fn: Callable[[Float[Array, 'n']], Float[Array, 'm']], params: Float[Array, 'n'], weights: int | float | complex | Num[Array, ''] | Float[Array, 'm'] = -1.0) → tuple[Float[Array, 'm'], Float[Array, 'n']]
Compute residuals and J^T @ W @ r simultaneously.
Uses reverse-mode autodiff to compute the gradient of the loss L = 0.5 * r^T W r, which equals J^T @ W @ r.
For unweighted least squares (W = I), this reduces to J^T @ r.
Implementation Logic
The function exploits the chain rule via reverse-mode autodiff:
1. Compute residuals and VJP function: residuals, vjp_fn = jax.vjp(residual_fn, params)
   - Evaluates r(θ) at current parameters
   - Constructs vjp_fn: vector → J^T @ vector
   - Single forward pass through residual_fn
2. Weight handling (conditional based on weights parameter):
   - If weights < 0 (sentinel): set weights_normalized = ones(m) → Unweighted case W = I
   - Else: broadcast weights to residual shape → Weighted case W = diag(weights)
   - Uses jax.lax.cond for JIT compatibility (control flow traced)
3. Apply weighting and compute weighted gradient: weighted_residuals = weights_normalized * residuals, then jt_wr = vjp_fn(weighted_residuals)[0]
   - Element-wise multiplication: W @ r for diagonal W
   - vjp_fn(W @ r) = J^T @ W @ r by chain rule
   - Single backward pass through residual_fn
Total cost: 1 forward + 1 backward pass through residual_fn, regardless of whether weighting is used.
Mathematical Derivation
For loss L(θ) = 0.5 * ||W^{1/2} r(θ)||² = 0.5 r(θ)^T W r(θ):
∇L = J^T W r
where J is the Jacobian ∂r/∂θ. The VJP operator computes J^T @ v for any vector v. Setting v = W @ r gives J^T @ W @ r directly.
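The identity ∇L = J^T W r can be verified on a tiny problem without any autodiff. The sketch below is an illustration only, in plain Python rather than JAX: it reuses the quadratic residual from this function's Examples (whose Jacobian is diagonal, J_ii = 2θ_i), forms J^T W r by hand, and compares it against a central finite-difference gradient of L(θ) = 0.5 r^T W r. The helper names are hypothetical.

```python
def residual_fn(theta):
    """Residuals r_i = theta_i**2 - target_i."""
    targets = [1.0, 4.0]
    return [t * t - y for t, y in zip(theta, targets)]


def weighted_loss(theta, weights):
    """L(theta) = 0.5 * sum_i w_i * r_i**2, i.e. 0.5 r^T W r with diagonal W."""
    return 0.5 * sum(w * r * r for w, r in zip(weights, residual_fn(theta)))


def jt_w_r(theta, weights):
    """J^T W r written out by hand; the Jacobian here is diagonal with J_ii = 2 theta_i."""
    r = residual_fn(theta)
    return [2.0 * t * w * r_i for t, w, r_i in zip(theta, weights, r)]


def finite_difference_grad(theta, weights, eps=1e-6):
    """Central-difference gradient of the weighted loss, one coordinate at a time."""
    grad = []
    for i in range(len(theta)):
        up, down = list(theta), list(theta)
        up[i] += eps
        down[i] -= eps
        grad.append((weighted_loss(up, weights) - weighted_loss(down, weights)) / (2.0 * eps))
    return grad


theta = [1.5, 2.5]
weights = [2.0, 0.5]
analytic = jt_w_r(theta, weights)
numeric = finite_difference_grad(theta, weights)
print("J^T W r      :", analytic)
print("finite diff. :", numeric)
```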
- Parameters:
residual_fn (Callable[[Float[Array, "n"]], Float[Array, "m"]]) – Function mapping parameters to residuals.
params (Float[Array, "n"]) – Current parameters.
weights (Union[ScalarNumeric, Float[Array, "m"]], optional) – Diagonal weight matrix entries for weighted least squares. If -1.0 (default), assumes unweighted (W = I).
- Returns:
residuals (Float[Array, "m"]) – Current residual vector r(θ).
jt_wr (Float[Array, "n"]) – Weighted gradient J^T @ W @ r = ∇(0.5 * r^T W r).
- Return type:
tuple[Float[Array, "m"], Float[Array, "n"]]
Notes
For diagonal weighting W = diag(w), the weighted gradient is J^T @ W @ r where each residual r_i is weighted by w_i.
Examples
>>> import jax.numpy as jnp
>>> def residual_fn(x):
...     return x**2 - jnp.array([1.0, 4.0])
>>> params = jnp.array([1.5, 2.5])
>>> r, jt_r = jt_residual(residual_fn, params)
>>> # Weighted version:
>>> weights = jnp.array([2.0, 0.5])
>>> r, jt_wr = jt_residual(residual_fn, params, weights)
- janssen.utils.jtj_diag(residual_fn: Callable[[Float[Array, 'n']], Float[Array, 'm']], params: Float[Array, 'n'], num_samples: int = 10) → Float[Array, 'n']
Estimate diagonal of J^T J via Hutchinson’s trace estimator.
The diagonal approximates per-parameter sensitivities and is useful for diagonal preconditioning in conjugate gradient.
Implementation Logic
The function implements Hutchinson’s stochastic trace estimator:
1. Setup:
   - Construct matvec operator: (J^T J) @ v via jtj_matvec
   - Generate num_samples independent random keys
2. Per-Sample Estimation (parallelized via vmap): internal function estimate_one(key) performs:
   a. Sample probe vector z ~ Rademacher({-1, +1}^n)
      - Each entry independently ±1 with equal probability
      - Uses params.dtype to match precision (float32/float64)
   b. Compute Az = (J^T J) @ z via matvec
      - Cost: 1 forward + 1 backward pass through residual_fn
   c. Compute element-wise product: z ⊙ Az
      - This is the single-sample diagonal estimate
   d. Return z ⊙ Az
   Uses jax.vmap(estimate_one)(keys):
   - Vectorizes over num_samples keys
   - Compiles to parallel batch of matrix-vector products
   - Returns shape (num_samples, n) stacked estimates
3. Averaging:
   - Compute mean across samples: diagonal = mean(estimates, axis=0)
   - Reduces variance by a factor of num_samples (standard error by √num_samples)
Algorithm: Hutchinson’s Estimator
For any symmetric matrix A, diagonal entries satisfy:
A_ii = E_z[z_i (A @ z)_i]
where z ~ Rademacher({-1, +1}^n). This is unbiased:
E[z_i (A @ z)_i] = E[z_i Σⱼ A_ij z_j] = Σⱼ A_ij E[z_i z_j] = A_ii
since E[z_i z_j] = δ_ij (the entries of z are independent, with zero mean and unit variance).
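For a small matrix the expectation can be evaluated exactly by enumerating all 2^n sign vectors instead of sampling. The sketch below does this in plain Python for a hypothetical 3×3 symmetric matrix (not taken from the library); it shows that the mean of z ⊙ (A @ z) over all probes reproduces the diagonal, and it also reports the exact per-entry variance over the probes.

```python
import itertools


def matvec(a, z):
    """Dense matrix-vector product for a list-of-lists matrix."""
    return [sum(a_ij * z_j for a_ij, z_j in zip(row, z)) for row in a]


def exact_probe_statistics(a):
    """Mean and variance of z * (A z) over every Rademacher probe z in {-1, +1}^n."""
    n = len(a)
    samples = []
    for z in itertools.product((-1.0, 1.0), repeat=n):
        az = matvec(a, z)
        samples.append([z_i * az_i for z_i, az_i in zip(z, az)])
    count = len(samples)
    mean = [sum(s[i] for s in samples) / count for i in range(n)]
    var = [sum((s[i] - mean[i]) ** 2 for s in samples) / count for i in range(n)]
    return mean, var


# Hypothetical symmetric test matrix with diagonal (4, 3, 2).
A = [
    [4.0, 1.0, 0.0],
    [1.0, 3.0, 0.5],
    [0.0, 0.5, 2.0],
]
mean, var = exact_probe_statistics(A)
off_diag_sq = [sum(A[i][j] ** 2 for j in range(3) if j != i) for i in range(3)]

print("mean of z*(Az):", mean)      # matches the diagonal of A
print("variance      :", var)
print("sum_{j!=i} A_ij^2:", off_diag_sq)
```

With ±1 probes the z_i² factor is always 1, so the diagonal term contributes no noise and only the off-diagonal entries of each row drive the spread.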
Variance and Sampling
For Rademacher probes z_i² = 1, so the diagonal term contributes no noise and the variance of the single-sample estimator is:
Var[z_i (A @ z)_i] = Σ_{j≠i} A_ij²
For matrices with strong off-diagonal structure (like J^T J in ptychography due to overlapping scan positions), this variance can be large. Averaging over independent probes scales the estimator variance as:
Var[diagonal_estimate_i] = (1/num_samples) Σ_{j≠i} A_ij²
With num_samples=10, standard error is √(0.1 Σ_{j≠i} A_ij²). For ill-conditioned problems, increase to 50-100 samples.
Why Rademacher? Could also use Gaussian, but Rademacher:
- Has lower variance for heavy-tailed eigenvalue distributions
- Is memory-efficient (binary values)
- Allows exact integer arithmetic in some contexts
Note: Individual diagonal estimates can be negative due to finite sampling, even though true diagonal entries are non-negative (J^T J is positive semi-definite). When using for preconditioning, clamp to zero: max(diagonal, 0) before inversion.
Computational Cost
Total: num_samples × (1 forward + 1 backward pass)
For num_samples=10: 10 forward + 10 backward passes
All samples computed in parallel via vmap (batch size = num_samples)
Memory: O(num_samples × n) for stacked estimates
- Parameters:
residual_fn (Callable[[Float[Array, "n"]], Float[Array, "m"]]) – Residual function.
params (Float[Array, "n"]) – Current parameters.
num_samples (int, optional) – Number of random probes. Default is 10.
- Returns:
diagonal – Estimated diagonal of J^T J.
- Return type:
Float[Array, "n"]
Notes
Uses Hutchinson’s trick: diag(A) ≈ E[z ⊙ (A @ z)] for random z ∈ {-1, +1}^n (Rademacher distribution).
The diagonal can be used to construct a preconditioner:
M = (diag(J^T J) + ε)^{-1/2}
Trade-off: More samples → lower variance, higher cost. For most problems, 10 samples is sufficient. For high accuracy or when diagonal entries vary by orders of magnitude, use 50-100 samples.
Examples
>>> diag = jtj_diag(residual_fn, params)
>>> precond = 1.0 / jnp.sqrt(diag + 1e-6)
- janssen.utils.max_eigenval(residual_fn: Callable[[Float[Array, 'n']], Float[Array, 'm']], params: Float[Array, 'n'], num_iterations: int = 20) → Float[Array, '']
Estimate largest eigenvalue of J^T J via power iteration.
A large eigenvalue indicates potential ill-conditioning, which can cause convergence issues. Use this diagnostic to tune damping.
Implementation Logic
The function implements the power iteration algorithm:
1. Setup:
   - Construct matvec operator: (J^T J) @ v via jtj_matvec
   - Initialize random unit vector v₀ ~ N(0, I), normalized
   - Uses fixed seed (42) for reproducibility
2. Power Iteration Loop (num_iterations times): internal function power_step(v_curr) performs:
   a. Compute Av = (J^T J) @ v_curr via matvec
      - Cost: 1 forward + 1 backward pass through residual_fn
   b. Normalize: v_next = Av / ||Av||
      - Ensures numerical stability, prevents overflow
   c. Return v_next (carry state for next iteration)
   Uses jax.lax.scan for efficient iteration:
   - Compiled as single fused kernel by XLA
   - Avoids Python loop overhead
   - scan(power_step, v₀, None, length=num_iterations)
3. Eigenvalue Extraction:
   - Compute final Rayleigh quotient: λ = v^T (J^T J) v
   - After k iterations, vₖ ≈ dominant eigenvector
   - Rayleigh quotient converges to λ_max(J^T J)
Algorithm: Power Iteration
For symmetric positive semi-definite A = J^T J:
vₖ₊₁ = A vₖ / ||A vₖ||
Converges to dominant eigenvector v₁ at rate (λ₂/λ₁)^k where λ₁ ≥ λ₂ ≥ … are eigenvalues. Convergence is geometric.
With k=20 iterations and typical spectral gaps, relative error is usually < 1%. For problems with a small spectral gap (λ₁ ≈ λ₂), increase num_iterations to 50-100.
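The iteration and the final Rayleigh quotient can be followed on a 2×2 example. This is a plain-Python sketch for illustration, using an explicit hypothetical matrix with eigenvalues 3 and 1 in place of the Jacobian-free J^T J operator and a Python loop in place of jax.lax.scan.

```python
import math


def matvec(a, v):
    """Dense matrix-vector product for a list-of-lists matrix."""
    return [sum(a_ij * v_j for a_ij, v_j in zip(row, v)) for row in a]


def normalize(v):
    """Scale v to unit Euclidean norm."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def power_iteration(a, v0, num_iterations=20):
    """Repeat v <- A v / ||A v||, then return the Rayleigh quotient v^T A v and v."""
    v = normalize(v0)
    for _ in range(num_iterations):
        v = normalize(matvec(a, v))
    av = matvec(a, v)
    rayleigh = sum(v_i * av_i for v_i, av_i in zip(v, av))
    return rayleigh, v


# Symmetric matrix with eigenvalues 3 and 1, so lambda_2 / lambda_1 = 1/3.
A = [[2.0, 1.0], [1.0, 2.0]]
lambda_max, v_dominant = power_iteration(A, [1.0, 0.0])
print(f"lambda_max ~ {lambda_max:.12f}")
print("dominant eigenvector ~", v_dominant)
```

The start vector must have a nonzero component along the dominant eigenvector; a random start, as described above, makes that overwhelmingly likely.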
Computational Cost
Total: num_iterations × (1 forward + 1 backward pass)
For num_iterations=20: 20 forward + 20 backward passes
Plus 1 final matvec for Rayleigh quotient
Memory: O(n) for storing vₖ
- Parameters:
residual_fn (Callable[[Float[Array, "n"]], Float[Array, "m"]]) – Residual function.
params (Float[Array, "n"]) – Current parameters.
num_iterations (int, optional) – Number of power iterations. Default is 20.
- Returns:
lambda_max – Estimate of the largest eigenvalue of J^T J.
- Return type:
Float[Array, ""]
Notes
Power iteration computes: v_{k+1} = (J^T J) @ v_k / ||(J^T J) @ v_k||
After convergence, the Rayleigh quotient v^T (J^T J) v gives λ_max.
This is useful for:
- Diagnosing ill-conditioning: condition number κ ≈ λ_max / λ_min
- Tuning initial damping: set λ₀ ~ 10^{-3} λ_max as starting point
- Detecting a vanishing Jacobian: if λ_max ≈ 0, then J ≈ 0 and the residuals are insensitive to the parameters
Examples
>>> lambda_max = max_eigenval(residual_fn, params)
>>> print(f"Maximum eigenvalue: {lambda_max:.2e}")
- janssen.utils.gn_solve(state: GaussNewtonState, residual_fn: Callable[[Float[Array, 'n']], Float[Array, 'm']], max_iterations: int = 100, cg_maxiter: int = 50, cg_tol: float = 1e-05, use_preconditioner: bool = False) → GaussNewtonState
Run Gauss-Newton optimization until convergence or max iterations.
High-level solver that repeatedly calls gn_step until the optimization converges or reaches the maximum iteration limit. This provides a single entry point for running the full optimization.
Implementation Logic
The function uses jax.lax.scan for efficient iteration:
1. Internal Step Function step_fn(carry, _): wraps gn_step with early stopping logic:
   a. Check if carry.converged is True
   b. If converged: return carry unchanged (no-op)
   c. If not converged: call gn_step(carry, …)
   d. Return updated state
   Uses jax.lax.cond for control flow:
   - JIT-compatible conditional execution
   - When converged=True, subsequent iterations are no-ops
   - Ensures fixed iteration count for XLA compilation
2. Main Loop: jax.lax.scan(step_fn, state, None, length=max_iterations)
   - Iterates exactly max_iterations times (required for JIT)
   - Carries GaussNewtonState through iterations
   - Stops updating when converged=True (via lax.cond guard)
   - Returns final_state after max_iterations or convergence
Why scan instead of while_loop?
JAX provides two iteration primitives:
- jax.lax.while_loop: Stops when condition is False
  → Data-dependent iteration count (not known at trace time)
  → Harder to JIT-compile (variable-length trace)
- jax.lax.scan: Runs fixed number of iterations
  → Deterministic iteration count
  → Efficient JIT compilation
  → Early stopping via lax.cond guard inside body
We use scan with internal lax.cond guard to get both:
- Deterministic compilation (fixed max_iterations)
- Early stopping (when converged=True, step_fn returns carry)
Trade-off: If convergence happens at iteration 20, iterations 21-100 are no-ops (just carry forwarding). Cost is negligible compared to actual Gauss-Newton steps.
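The control-flow pattern, a fixed trip count with a guarded no-op once converged, can be emulated in plain Python. The sketch below is only an analogy and makes several assumptions: the State tuple, the Newton square-root step standing in for gn_step, and the tolerance are all hypothetical, the for loop plays the role of jax.lax.scan(..., length=max_iterations), and the if plays the role of the jax.lax.cond guard.

```python
from typing import NamedTuple


class State(NamedTuple):
    """Toy carry: current estimate, number of active steps taken, convergence flag."""
    x: float
    iteration: int
    converged: bool


def newton_sqrt_step(state, target, tol=1e-12):
    """One Newton update for x**2 = target (a stand-in for an expensive solver step)."""
    x_new = 0.5 * (state.x + target / state.x)
    return State(x_new, state.iteration + 1, abs(x_new - state.x) < tol)


def solve_fixed_length(state, target, max_iterations=100):
    """Always loop max_iterations times; only do real work while not converged."""
    for _ in range(max_iterations):       # fixed trip count, like scan(length=...)
        if not state.converged:           # guard, like lax.cond on carry.converged
            state = newton_sqrt_step(state, target)
        # otherwise the carry is forwarded unchanged (no-op iteration)
    return state


final = solve_fixed_length(State(1.0, 0, False), target=2.0)
print(f"x = {final.x:.15f}, active steps = {final.iteration}, converged = {final.converged}")
```

The iteration counter only advances on active steps, so it shows directly how many of the 100 loop passes did real work.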
Computational Cost
If convergence happens at iteration k < max_iterations:
- Active iterations: k × cost(gn_step)
- No-op iterations: (max_iterations - k) × cost(lax.cond)
- No-op cost is ~microseconds, negligible vs GN step (~seconds)
Total cost ≈ k × [20-40 forward/backward passes per step]
For 100 max iterations with typical convergence at k=30:
- Active cost: 30 × 40 = 1200 forward+backward passes
- No-op cost: 70 × 0 ≈ 0 (compiler optimizes)
- Parameters:
state (GaussNewtonState) – Initial optimization state containing sample, probe, iteration, loss, damping, and convergence status.
residual_fn (Callable[[Float[Array, "n"]], Float[Array, "m"]]) – Function mapping flattened parameters to residuals.
max_iterations (int, optional) – Maximum number of Gauss-Newton iterations. Default is 100.
cg_maxiter (int, optional) – Maximum conjugate gradient iterations per step. Default is 50.
cg_tol (float, optional) – CG convergence tolerance. Default is 1e-5.
use_preconditioner (bool, optional) – Whether to use diagonal preconditioning for CG. Default is False.
- Returns:
final_state – Final optimization state after convergence or max iterations.
- Return type:
GaussNewtonState
Notes
The solver stops when either:
- The state’s converged flag becomes True
- The iteration count reaches max_iterations
This function is fully JIT-compatible. To monitor progress during optimization, use gn_loss_history or gn_history, or call gn_step from your own jax.lax.scan body function that records intermediate state.
Examples
>>> state = make_gauss_newton_state(sample, probe)
>>> final_state = gn_solve(state, my_residual_fn,
...                        max_iterations=50)
>>> print(f"Converged: {final_state.converged}")
>>> print(f"Final loss: {final_state.loss}")
- janssen.utils.gn_loss_history(state: GaussNewtonState, residual_fn: Callable[[Float[Array, 'n']], Float[Array, 'm']], max_iterations: int = 100, cg_maxiter: int = 50, cg_tol: float = 1e-05, use_preconditioner: bool = False) → tuple[GaussNewtonState, Float[Array, 'N']]
Run Gauss-Newton optimization and return per-iteration losses.
This variant tracks only the scalar loss per iteration and not the full state history, which keeps memory overhead low for large ptychography problems.
- Return type:
tuple[GaussNewtonState,Float[Array, 'N']]
- janssen.utils.gn_history(state: GaussNewtonState, residual_fn: Callable[[Float[Array, 'n']], Float[Array, 'm']], max_iterations: int = 100, cg_maxiter: int = 50, cg_tol: float = 1e-05, use_preconditioner: bool = False) → tuple[GaussNewtonState, GaussNewtonState, Float[Array, 'N']]
Run Gauss-Newton optimization with per-iteration history tracking.
Like gn_solve but returns intermediate states and losses at each iteration, enabling convergence diagnostics and visualization.
Implementation Logic
Same iteration strategy as gn_solve using jax.lax.scan, but scan outputs capture per-iteration states:
1. Internal Step Function step_fn(carry, _):
   a. Check if carry.converged is True
   b. If converged: return carry unchanged (no-op)
   c. If not converged: call gn_step(carry, …)
   d. Return updated state and output tuple (state, loss)
2. Main Loop: jax.lax.scan(step_fn, state, None, length=max_iterations)
   - Returns both final_state and outputs PyTree
   - Outputs contain all intermediate states and losses
The scan outputs are PyTrees where each leaf has an additional dimension of size max_iterations. For example, if state.sample has shape (H, W), then all_states.sample has shape (H, W, N) where N = max_iterations.
- Parameters:
state (GaussNewtonState) – Initial optimization state containing sample, probe, iteration, loss, damping, and convergence status.
residual_fn (Callable[[Float[Array, "n"]], Float[Array, "m"]]) – Function mapping flattened parameters to residuals.
max_iterations (int, optional) – Maximum number of Gauss-Newton iterations. Default is 100.
cg_maxiter (int, optional) – Maximum conjugate gradient iterations per step. Default is 50.
cg_tol (float, optional) – CG convergence tolerance. Default is 1e-5.
use_preconditioner (bool, optional) – Whether to use diagonal preconditioning for CG. Default is False.
- Returns:
final_state (GaussNewtonState) – Final optimization state after convergence or max iterations.
all_states (GaussNewtonState) – PyTree with all intermediate states. Each field has an extra dimension of size max_iterations stacked along the last axis. For example, if state.sample is (H, W), all_states.sample is (H, W, max_iterations).
all_losses (Float[Array, "N"]) – Loss value at each iteration, shape (max_iterations,).
- Return type:
tuple[GaussNewtonState, GaussNewtonState, Float[Array, "N"]]
Notes
This function is useful for:
- Convergence diagnostics and plotting
- Creating visualizations of optimization progress
- Debugging optimization issues
- Resume/continuation with full history tracking
If you only need the final state (not intermediate history), use gn_solve instead for slightly lower memory usage.
Examples
>>> state = make_gauss_newton_state(sample, probe)
>>> final, history, losses = gn_history(
...     state, my_residual_fn, max_iterations=50
... )
>>> print(f"Converged: {final.converged}")
>>> print(f"Loss history shape: {losses.shape}")  # (50,)
>>> print(f"Sample history shape: {history.sample.shape}")  # (H, W, 50)
- janssen.utils.gn_step(state: GaussNewtonState, residual_fn: Callable[[Float[Array, 'n']], Float[Array, 'm']], cg_maxiter: int = 50, cg_tol: float = 1e-05, use_preconditioner: bool | Bool[Array, ''] = False) → GaussNewtonState
Perform Gauss-Newton step with Levenberg-Marquardt damping.
Solves the trust-region subproblem using conjugate gradient:
(J^T J + λI) δ = -J^T r
then updates parameters: θ_{k+1} = θ_k + δ
The damping λ adapts based on actual vs predicted reduction.
Implementation Logic
The function executes a complete trust-region Gauss-Newton step:
Phase 1: Setup and Gradient Computation
1. Extract sample and probe shapes from state
2. Flatten complex arrays (sample, probe) → real vector params
3. Compute residuals r and gradient J^T @ r via jt_residual
   - Cost: 1 forward + 1 backward pass through residual_fn
4. Compute current loss: 0.5 ||r||²
Phase 2: Linear System Construction
5. Build matrix-vector operator: matvec = jtj_matvec(…)
   - Computes (J^T J + λI) @ v using JVP/VJP composition
   - Caches vjp_fn to avoid redundant computation in CG
6. If use_preconditioner=True:
   a. Estimate diagonal via Hutchinson: diag_jtj ≈ diag(J^T J)
      - Uses 10 Rademacher samples (random ±1 vectors)
      - Cost: 10 × (1 forward + 1 backward pass)
   b. Clamp negative estimates: max(diag_jtj, 0) + λ
      - Ensures positive definite preconditioner
      - Negative estimates can occur due to finite sampling
   c. Construct preconditioner: M(v) = v / (diag + λ)
      - Diagonal scaling: approximates (diag(H))^{-1}
Phase 3: Solve Trust-Region Subproblem
7. Solve (J^T J + λI) δ = -J^T r via conjugate gradient
   - Inputs: matvec operator, RHS = -J^T r, preconditioner M
   - Iterative solver: converges in typically 10-50 iterations
   - Cost per iteration: 1 matvec call = 1 forward + 1 backward
   - Total cost: ~20 × (1 forward + 1 backward) without precond, ~10 × (1 forward + 1 backward) with precond
8. Compute step: params_new = params + δ
9. Unflatten params_new → (sample_new, probe_new)
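The matrix-free solve in Phase 3 can be sketched with a textbook conjugate gradient loop that only ever sees a matvec callable. This is a plain-Python illustration, not the library's code path (which, per the notes here, goes through jax.scipy.sparse.linalg.cg): the explicit 3×2 Jacobian, residual vector, and damping value are hypothetical stand-ins so the result can be checked by hand.

```python
def conjugate_gradient(matvec, b, maxiter=50, tol=1e-5):
    """Unpreconditioned CG for a symmetric positive definite operator given as matvec."""
    x = [0.0] * len(b)
    r = list(b)                  # residual b - A x for the zero initial guess
    p = list(r)
    rs_old = sum(r_i * r_i for r_i in r)
    b_norm = rs_old ** 0.5
    for _ in range(maxiter):
        if rs_old ** 0.5 <= tol * b_norm:
            break                # relative residual small enough
        ap = matvec(p)
        alpha = rs_old / sum(p_i * ap_i for p_i, ap_i in zip(p, ap))
        x = [x_i + alpha * p_i for x_i, p_i in zip(x, p)]
        r = [r_i - alpha * ap_i for r_i, ap_i in zip(r, ap)]
        rs_new = sum(r_i * r_i for r_i in r)
        p = [r_i + (rs_new / rs_old) * p_i for r_i, p_i in zip(r, p)]
        rs_old = rs_new
    return x


# Hypothetical small problem: J is 3x2, r has 3 entries, damping lambda = 0.1.
J = [[1.0, 2.0], [0.0, 1.0], [1.0, 0.0]]
residuals = [1.0, -1.0, 2.0]
damping = 0.1


def jtj_damped_matvec(v):
    """(J^T J + lambda I) @ v computed as J^T (J v) + lambda v, never forming J^T J."""
    jv = [sum(j_ik * v_k for j_ik, v_k in zip(row, v)) for row in J]
    jtjv = [sum(J[i][k] * jv[i] for i in range(len(J))) for k in range(len(v))]
    return [a + damping * v_k for a, v_k in zip(jtjv, v)]


jt_r = [sum(J[i][k] * residuals[i] for i in range(len(J))) for k in range(2)]
delta = conjugate_gradient(jtj_damped_matvec, [-g for g in jt_r])
linear_residual = [a + g for a, g in zip(jtj_damped_matvec(delta), jt_r)]

print("delta =", delta)
print("(J^T J + lambda I) delta + J^T r =", linear_residual)
```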
Phase 4: Evaluate New Point
10. Compute residuals_new = residual_fn(params_new)
    - Cost: 1 forward pass
11. Compute new_loss = 0.5 ||residuals_new||²
Phase 5: Trust-Region Decision
12. Compute predicted reduction: pred = 0.5 δ^T (J^T J + λI) δ
    - Uses h_delta = matvec(δ) already computed
    - Predicted decrease in quadratic model
13. Compute actual reduction: actual = current_loss - new_loss
14. Compute reduction ratio: ρ = actual / predicted
    - If predicted ≤ 0: set ρ = 0 (non-descent direction)
15. Accept/reject decision: accept = (actual > 0) AND (predicted > 0) AND (ρ > 0)
    - Requires all three: loss decreased, descent direction, positive ratio
Phase 6: Damping Adaptation
16. Update damping λ based on ρ (using nested jax.lax.cond):
    - If predicted ≤ 0: λ ← 10λ (drastic increase, bad step)
    - Else if ρ > 0.75: λ ← λ/3 (excellent step, be aggressive)
    - Else if ρ > 0.25: λ ← λ (acceptable step, maintain)
    - Else: λ ← 3λ (poor step, be conservative)
17. Clamp damping: λ ∈ [10^{-12}, 10^8]
    - Prevents both over-damping (slow convergence) and under-damping (instability)
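The damping rule reads naturally as a small pure function. The sketch below restates the thresholds, factors, and clamp bounds listed in this phase using ordinary Python branches; the function names are hypothetical, and the library expresses the same branching with nested jax.lax.cond so that it can be traced.

```python
def reduction_ratio(actual, predicted):
    """rho = actual / predicted, with rho = 0 when the model predicts no decrease."""
    return actual / predicted if predicted > 0.0 else 0.0


def update_damping(damping, rho, predicted, lam_min=1e-12, lam_max=1e8):
    """Adapt the damping from the reduction ratio, then clamp it to [lam_min, lam_max]."""
    if predicted <= 0.0:
        new_damping = 10.0 * damping   # bad step: drastic increase
    elif rho > 0.75:
        new_damping = damping / 3.0    # excellent step: be more aggressive
    elif rho > 0.25:
        new_damping = damping          # acceptable step: keep as is
    else:
        new_damping = 3.0 * damping    # poor step: be more conservative
    return min(max(new_damping, lam_min), lam_max)


# A step that achieved 90% of the predicted decrease shrinks the damping.
rho = reduction_ratio(actual=0.9, predicted=1.0)
print(update_damping(1.0, rho, predicted=1.0))
# A step with a non-positive predicted decrease inflates it tenfold.
print(update_damping(1.0, reduction_ratio(0.5, -1.0), predicted=-1.0))
```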
Phase 7: State Update
18. If accepted:
    - Use new sample, probe, loss
19. If rejected:
    - Keep current sample, probe, loss
20. Compute relative improvement |current_loss - final_loss| / current_loss
21. Check convergence: accept AND (rel_improvement < 10^{-8})
22. Return updated GaussNewtonState with:
    - sample, probe (updated or kept)
    - iteration counter incremented
    - loss (new or current)
    - damping (adapted)
    - converged flag
Internal Helper Functions
Uses nested function _preconditioner_matvec(v) when preconditioning:
- Computes element-wise: v / (diag_with_damping)
- Passed as M parameter to jax.scipy.sparse.linalg.cg
- Transforms system from (J^T J + λI) δ = -J^T r to M (J^T J + λI) δ = M (-J^T r)
- Accelerates CG by improving condition number
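The effect of diagonal scaling on conditioning can be seen on a 2×2 system. In this sketch the matrix A is a hypothetical stand-in for the damped operator (J^T J + λI), chosen with badly mismatched diagonal entries, and M is the inverse of its diagonal. The ratio of extreme eigenvalues of M A is used as the conditioning measure, since M A has the same spectrum as the symmetrically scaled matrix that governs preconditioned CG convergence.

```python
import math


def eigenvalues_2x2(m):
    """Eigenvalues of a 2x2 matrix with real spectrum, from its trace and determinant."""
    trace = m[0][0] + m[1][1]
    det = m[0][0] * m[1][1] - m[0][1] * m[1][0]
    disc = math.sqrt(trace * trace - 4.0 * det)
    return (trace - disc) / 2.0, (trace + disc) / 2.0


def eigenvalue_ratio(m):
    """Largest eigenvalue divided by smallest eigenvalue."""
    smallest, largest = eigenvalues_2x2(m)
    return largest / smallest


# Hypothetical damped operator with poorly scaled diagonal.
A = [[100.0, 1.0], [1.0, 1.0]]

# Diagonal preconditioner M = diag(A)^{-1}, applied row by row to form M A.
diag_inv = [1.0 / A[0][0], 1.0 / A[1][1]]
MA = [[diag_inv[i] * A[i][j] for j in range(2)] for i in range(2)]

cond_before = eigenvalue_ratio(A)
cond_after = eigenvalue_ratio(MA)
print(f"eigenvalue ratio without preconditioning: {cond_before:.2f}")
print(f"eigenvalue ratio with diagonal scaling  : {cond_after:.4f}")
```

The gain is this large only because the poor conditioning here comes from diagonal scaling; when off-diagonal coupling dominates, a diagonal preconditioner helps much less.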
Computational Complexity
Without preconditioning:
- Setup: 1 forward + 1 backward (gradient)
- CG solve: ~20 × (1 forward + 1 backward) typical
- Evaluation: 1 forward (new residuals)
- Total: ~22 forward + ~21 backward passes
With preconditioning:
- Setup: 1 forward + 1 backward (gradient)
- Diagonal estimation: 10 × (1 forward + 1 backward)
- CG solve: ~10 × (1 forward + 1 backward) typical
- Evaluation: 1 forward
- Total: ~22 forward + ~21 backward passes
Preconditioning breaks even immediately (fewer CG iterations offset the diagonal estimation cost), and provides 2-5× speedup for ill-conditioned problems.
- Parameters:
state (GaussNewtonState) – Current optimization state containing sample, probe, iteration, loss, damping, and convergence status.
residual_fn (Callable[[Float[Array, "n"]], Float[Array, "m"]]) – Function mapping flattened parameters to residuals.
cg_maxiter (int, optional) – Maximum conjugate gradient iterations. Default is 50.
cg_tol (float, optional) – CG convergence tolerance. Default is 1e-5.
use_preconditioner (bool, optional) – Whether to use diagonal preconditioning for CG. Preconditioning can significantly improve convergence rate for ill-conditioned problems. Default is False.
- Returns:
new_state – Updated optimization state after the step.
- Return type:
GaussNewtonState
Notes
This function is generic and works with any residual function. The parameters are flattened from (sample, probe) to a real vector before passing to residual_fn.
Trust-region adaptation:
- If ρ > 0.75: very good step, decrease damping (be aggressive)
- If 0.25 < ρ ≤ 0.75: acceptable step, keep damping
- If ρ ≤ 0.25: poor step, increase damping (be conservative)
where ρ = actual_reduction / predicted_reduction.
When use_preconditioner=True, a diagonal preconditioner is computed using Hutchinson’s estimator: M = diag(diag(J^T J) + λ)^{-1}. This improves CG convergence for ill-conditioned problems.
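A minimal sketch of how such a diagonal could be estimated, assuming Rademacher (±1) probe vectors and the estimate diag(J^T J) ≈ mean(z ⊙ (J^T J z)). The function name hutchinson_jtj_diag and the probe count are hypothetical; the library's own jtj_diag may differ in detail.

```python
import jax
import jax.numpy as jnp


def hutchinson_jtj_diag(residual_fn, params, key, num_probes=10):
    """Estimate diag(J^T J) with Hutchinson's estimator (illustrative)."""
    # Reverse-mode operator J^T @ (.), built once and reused for every probe.
    _, vjp_fn = jax.vjp(residual_fn, params)

    def one_probe(k):
        # Rademacher probe: entries are +1 or -1 with equal probability.
        z = jax.random.rademacher(k, params.shape, dtype=params.dtype)
        jz = jax.jvp(residual_fn, (params,), (z,))[1]   # J @ z
        jtjz = vjp_fn(jz)[0]                            # J^T @ (J @ z)
        return z * jtjz                                 # one diagonal sample

    keys = jax.random.split(key, num_probes)
    return jnp.mean(jax.vmap(one_probe)(keys), axis=0)


def residual_fn(x):
    return jnp.array([x[0] ** 2 - 1, x[1] ** 2 - 4])


params = jnp.array([1.5, 2.5])
damping = 1e-3
diag_est = hutchinson_jtj_diag(residual_fn, params, jax.random.PRNGKey(0))


def precondition(v):
    # Apply M = diag(diag(J^T J) + λ)^{-1} to a vector.
    return v / (diag_est + damping)
```

For this toy residual J^T J is itself diagonal, so every probe returns the exact diagonal; in general the estimate is stochastic and improves with more probes.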
Convergence detection:
- If loss < 1e-14: already at optimum, mark converged
- If residuals contain NaN/Inf: mark converged (failure mode)
- If relative improvement < 1e-8: normal convergence
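These checks might look roughly like the following. The function name check_converged and the exact definition of relative improvement are assumptions; the source gives only the thresholds. This version uses Python control flow for readability and is not written to run under jax.jit.

```python
import jax.numpy as jnp


def check_converged(prev_loss, new_loss, residuals,
                    loss_floor=1e-14, rel_tol=1e-8):
    """Return (converged, reason) following the three rules above."""
    if new_loss < loss_floor:
        return True, "loss below floor"
    if not bool(jnp.all(jnp.isfinite(residuals))):
        # Failure mode: stop iterating rather than propagate NaN/Inf.
        return True, "non-finite residuals"
    # Assumed definition: improvement relative to the previous loss.
    rel_improvement = abs(prev_loss - new_loss) / max(abs(prev_loss), 1e-30)
    if rel_improvement < rel_tol:
        return True, "relative improvement below tolerance"
    return False, "continue"
```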
CG convergence limitation: jax.scipy.sparse.linalg.cg currently returns None for the convergence info (second return value). This means gn_step cannot detect when CG fails to converge within cg_maxiter iterations. The only recourse is to choose cg_maxiter and cg_tol appropriately, or to enable preconditioning for ill-conditioned problems. Future JAX versions may provide convergence diagnostics.
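For callers who want a diagnostic anyway, one possible workaround (not something the source says gn_step does) is to recompute the linear residual after the solve, at the cost of one extra matvec. The 2×2 system and the 1e-4 acceptance threshold below are purely illustrative.

```python
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

# Small symmetric positive-definite system standing in for (J^T J + λI).
A = jnp.array([[4.0, 1.0], [1.0, 3.0]])
b = jnp.array([1.0, 2.0])


def matvec(v):
    return A @ v


# The second return value carries no convergence information.
x, info = cg(matvec, b, tol=1e-5, maxiter=50)

# Manual check: relative residual ||A x - b|| / ||b||.
rel_residual = jnp.linalg.norm(matvec(x) - b) / jnp.linalg.norm(b)
cg_converged = bool(rel_residual <= 1e-4)
```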
Examples
>>> state = make_gauss_newton_state(sample, probe)
>>> new_state = gn_step(state, my_residual_fn)
>>> # With preconditioning for better convergence:
>>> new_state = gn_step(state, my_residual_fn,
...                     use_preconditioner=True)
- janssen.utils.hessian_matvec(loss_fn: Callable[[Float[Array, 'n']], Float[Array, '']], params: Float[Array, 'n']) Callable[[Float[Array, 'n']], Float[Array, 'n']][source]¶
Create exact Hessian-vector product operator.
For problems where the Gauss-Newton approximation H ≈ J^T J is insufficient, this computes products with the exact Hessian:
H @ v = d/dε [∇L(θ + εv)]|_{ε=0}
using forward-over-reverse autodiff (jvp through grad).
Implementation Logic¶
The function uses forward-over-reverse composition:
Setup Phase (executed once, outside returned closure): grad_fn = jax.grad(loss_fn)
- Constructs the gradient function ∇L: R^n → R^n
- This is reverse-mode autodiff: each gradient evaluation costs a small constant multiple of one loss evaluation
- Computing once and reusing avoids redundant graph construction
Hessian-Vector Product (executed per matrix-vector product): Internal function _hvp(v) computes: hv = jax.jvp(grad_fn, (params,), (v,))[1]
- Forward-mode autodiff through the gradient function
- Computes directional derivative: d/dε [∇L(θ + εv)]|_{ε=0}
- This equals H @ v by definition of the Hessian
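The two phases above translate almost line for line into JAX. The sketch below follows that description; the name hessian_matvec_sketch is hypothetical and the real function may add type checking or other details. The dense Hessian is formed only to check the result on a two-parameter problem.

```python
import jax
import jax.numpy as jnp


def hessian_matvec_sketch(loss_fn, params):
    # Setup phase: build the gradient function once (reverse mode).
    grad_fn = jax.grad(loss_fn)

    def _hvp(v):
        # Forward mode through the gradient: directional derivative of
        # ∇L at params in direction v, which equals H @ v.
        return jax.jvp(grad_fn, (params,), (v,))[1]

    return _hvp


def loss_fn(x):
    return jnp.sum(x ** 4)


params = jnp.array([1.0, 2.0])
v = jnp.array([1.0, 0.0])
result = hessian_matvec_sketch(loss_fn, params)(v)

# Reference: dense Hessian, affordable only because n = 2 here.
dense = jax.hessian(loss_fn)(params) @ v
```

For L(x) = Σ x⁴ the Hessian is diag(12 x²), so the product with the first basis vector at x = (1, 2) is (12, 0).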
Mathematical Background¶
The Hessian H = ∇²L is the Jacobian of the gradient:
H @ v = J_∇L @ v = d/dε [∇L(θ + εv)]|_{ε=0}
Forward-over-reverse (fwd ∘ rev) is efficient for matrix-vector products, with cost roughly 2× a gradient evaluation.
Contrast with Gauss-Newton approximation J^T J:
- Exact Hessian: H = J^T J + Σᵢ rᵢ ∇²rᵢ
- GN approximation: H ≈ J^T J (drops second term)
- When residuals rᵢ are small or linear, J^T J is accurate
- When residuals are large and nonlinear, exact H is needed
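The decomposition can be checked numerically on a toy problem. This sketch assumes the loss convention L = ½ Σᵢ rᵢ², which is the convention under which H = J^T J + Σᵢ rᵢ ∇²rᵢ holds exactly; the residual function is illustrative only.

```python
import jax
import jax.numpy as jnp


def residual_fn(x):
    return jnp.array([x[0] ** 2 - 1, x[1] ** 2 - 4])


def loss_fn(x):
    # Assumed convention: L = 1/2 * sum of squared residuals.
    return 0.5 * jnp.sum(residual_fn(x) ** 2)


params = jnp.array([1.5, 2.5])

J = jax.jacobian(residual_fn)(params)      # shape (m, n)
gauss_newton = J.T @ J                     # first term
exact = jax.hessian(loss_fn)(params)       # full Hessian

# Second term: sum_i r_i * Hessian(r_i); jax.hessian gives shape (m, n, n).
residual_hessians = jax.hessian(residual_fn)(params)
second_term = jnp.einsum("i,ijk->jk", residual_fn(params), residual_hessians)
```

Here the residuals are not small at the chosen point, so the second term is a noticeable fraction of the exact Hessian.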
- Parameters:
loss_fn (Callable[[Float[Array, 'n']], Float[Array, '']]) – Scalar loss function L(θ).
params (Float[Array, 'n']) – Current parameters (linearization point).
- Returns:
hvp – Function computing H @ v for any vector v.
- Return type:
Callable[[Float[Array, 'n']], Float[Array, 'n']]
Notes
The exact Hessian includes second-order derivative information beyond J^T J, capturing curvature from the residual structure itself. This is important when residuals are highly nonlinear.
Cost: ~2× gradient evaluation per HVP. For large n, this is much cheaper than forming H explicitly (which would cost O(n) gradient evaluations and O(n²) storage).
Examples
>>> def loss_fn(x):
...     return jnp.sum(x**4)
>>> params = jnp.array([1.0, 2.0])
>>> hvp = hessian_matvec(loss_fn, params)
>>> result = hvp(jnp.array([1.0, 0.0]))
- janssen.utils.jtj_matvec(residual_fn: Callable[[Float[Array, 'n']], Float[Array, 'm']], params: Float[Array, 'n'], damping: Float[Array, '']) Callable[[Float[Array, 'n']], Float[Array, 'n']][source]¶
Create matrix-vector product operator for (J^T J + λI).
This is the core of Jacobian-free Gauss-Newton optimization. Rather than forming the Jacobian matrix J explicitly, we compute products with (J^T J) using JAX’s autodiff primitives:
(J^T J) @ v = J^T @ (J @ v) = vjp(jvp(v))
where jvp is forward-mode autodiff and vjp is reverse-mode.
Implementation Logic¶
The function uses a two-stage construction:
Setup Phase (executed once, outside returned closure):
- Computes vjp_fn = jax.vjp(residual_fn, params)[1]
- This captures the reverse-mode autodiff operator J^T @ (·)
- Computing vjp_fn once and reusing it in the closure eliminates redundant computation when matvec is called multiple times
- Critical optimization: without this, each matvec call would recompute vjp, doubling the number of forward passes
Matvec Phase (executed per matrix-vector product): Internal function _matvec(v) computes:
a. jv = jax.jvp(residual_fn, (params,), (v,))[1]
- Forward-mode autodiff: directional derivative J @ v
- Cost: 1 forward pass through residual_fn
b. jtjv = vjp_fn(jv)[0]
- Reverse-mode autodiff: J^T @ (J @ v)
- Cost: 1 backward pass through residual_fn
c. return jtjv + damping * v
- Adds Levenberg-Marquardt regularization λI
The key insight: J^T J is never materialized. Each matvec costs only 1 forward + 1 backward pass, versus O(mn) storage and O(mn) flops per product when J is formed explicitly.
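A minimal sketch of the construction described above, checked against an explicitly formed J^T J + λI. The name jtj_matvec_sketch is hypothetical, and forming the dense Jacobian is done only for verification on a two-parameter problem.

```python
import jax
import jax.numpy as jnp


def jtj_matvec_sketch(residual_fn, params, damping):
    # Setup phase: capture J^T @ (.) once, outside the closure.
    _, vjp_fn = jax.vjp(residual_fn, params)

    def _matvec(v):
        jv = jax.jvp(residual_fn, (params,), (v,))[1]   # J @ v (forward mode)
        jtjv = vjp_fn(jv)[0]                            # J^T @ (J @ v)
        return jtjv + damping * v                       # add λ I @ v

    return _matvec


def residual_fn(x):
    return jnp.array([x[0] ** 2 - 1, x[1] ** 2 - 4])


params = jnp.array([1.5, 2.5])
damping = jnp.array(1e-3)
v = jnp.array([1.0, 0.0])
result = jtj_matvec_sketch(residual_fn, params, damping)(v)

# Reference: explicit Jacobian, feasible only for tiny problems.
J = jax.jacobian(residual_fn)(params)
dense = (J.T @ J + damping * jnp.eye(2)) @ v
```

The returned closure has the signature expected by matrix-free solvers such as jax.scipy.sparse.linalg.cg, which accept a callable in place of a matrix.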
- Parameters:
residual_fn (Callable[[Float[Array, 'n']], Float[Array, 'm']]) – Function mapping parameters to residuals r(θ).
params (Float[Array, 'n']) – Current parameters (linearization point).
damping (Float[Array, '']) – Levenberg-Marquardt damping parameter λ ≥ 0.
- Returns:
matvec – Function computing (J^T J + λI) @ v for any vector v.
- Return type:
Callable[[Float[Array, 'n']], Float[Array, 'n']]
Notes
Memory complexity is O(m + n) per matrix-vector product, compared to O(m * n) to store J explicitly. For large-scale problems where m, n >> 1, this is essential for tractability.
The vjp_fn is computed once and captured in the closure. This is critical for performance: conjugate gradient calls the returned matvec 10-100 times per Gauss-Newton step, so recomputing vjp_fn inside each call would waste one forward pass per call, or 10-100 extra forward passes per step.
Examples
>>> def residual_fn(x):
...     return jnp.array([x[0]**2 - 1, x[1]**2 - 4])
>>> params = jnp.array([1.5, 2.5])
>>> matvec = jtj_matvec(residual_fn, params, jnp.array(1e-3))
>>> result = matvec(jnp.array([1.0, 0.0]))
- janssen.utils.flatten_params(sample: Complex[Array, 'Hs Ws'], probe: Complex[Array, 'Hp Wp']) Float[Array, 'n'][source]¶
Flatten complex arrays to real parameter vector for optimization.
Converts two complex 2D arrays into a single real-valued vector by separating real and imaginary components. Handles arrays with different shapes (e.g., sample=512×512 FOV, probe=64×64 spot).
- Parameters:
sample (Complex[Array, " Hs Ws"]) – First complex array (e.g., object transmission function).
probe (Complex[Array, " Hp Wp"]) – Second complex array (e.g., probe wavefront).
- Returns:
params – Flattened real parameter vector with n = 2*(Hs*Ws + Hp*Wp). Layout: [sample_real, sample_imag, probe_real, probe_imag]
- Return type:
Float[Array," n"]
Examples
>>> sample = jnp.ones((512, 512), dtype=jnp.complex128)
>>> probe = jnp.ones((64, 64), dtype=jnp.complex128) * (1+1j)
>>> params = flatten_params(sample, probe)
>>> params.shape
(532480,)  # 2*(512*512 + 64*64)
See also
unflatten_params – Inverse operation to reconstruct complex arrays
- janssen.utils.fourier_shift(field: Complex[Array, 'H W'], shift_x: float | Float[Array, ''], shift_y: float | Float[Array, '']) Complex[Array, 'H W'][source]¶
Shift a 2D field using FFT-based sub-pixel shifting.
Applies a phase ramp in Fourier space to shift the field by (shift_x, shift_y) pixels relative to the center of the image. Supports sub-pixel shifts with high accuracy.
- Parameters:
field (Complex[Array, " H W"]) – Input 2D complex field to shift. The field is assumed to be centered (i.e., the center of the image is the origin).
shift_x (float) – Shift in x direction (columns) in pixels. Positive shifts move the field to the right.
shift_y (float) – Shift in y direction (rows) in pixels. Positive shifts move the field downward.
- Returns:
shifted – Shifted field with same shape as input.
- Return type:
Complex[Array," H W"]
Notes
The shift is implemented by multiplying the Fourier transform of the field by a phase ramp:
\[F_{shifted}(f_x, f_y) = F(f_x, f_y) \cdot \exp(-2\pi i (f_x \cdot \Delta x + f_y \cdot \Delta y))\]
This is equivalent to circular shifting with sub-pixel accuracy. The shift is relative to the center of the image, so a shift of (0, 0) leaves the field unchanged.
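The phase-ramp formula can be sketched directly with jnp.fft. This is an illustration under assumptions, not the library's implementation: the name fourier_shift_sketch is hypothetical, frequencies are taken from jnp.fft.fftfreq in cycles per pixel, and no fftshift centering is applied. For integer shifts the result should match a circular roll, which gives a simple check.

```python
import jax.numpy as jnp


def fourier_shift_sketch(field, shift_x, shift_y):
    h, w = field.shape
    # Frequency grids in cycles per pixel: rows (y) and columns (x).
    fy = jnp.fft.fftfreq(h)[:, None]
    fx = jnp.fft.fftfreq(w)[None, :]
    # Phase ramp exp(-2πi (f_x Δx + f_y Δy)) from the formula above.
    phase_ramp = jnp.exp(-2j * jnp.pi * (fx * shift_x + fy * shift_y))
    return jnp.fft.ifft2(jnp.fft.fft2(field) * phase_ramp)


ramp = jnp.arange(64.0).reshape(8, 8)
field = ramp + 1j * ramp[::-1]

# Integer shift: 3 columns to the right, 2 rows upward.
shifted = fourier_shift_sketch(field, 3.0, -2.0)
rolled = jnp.roll(field, shift=(-2, 3), axis=(0, 1))
```

Fractional shifts use the same function; only the integer case has an exact jnp.roll counterpart to compare against.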
Examples
>>> import jax.numpy as jnp
>>> field = jnp.ones((64, 64), dtype=jnp.complex128)
>>> shifted = fourier_shift(field, 10.5, -5.25)  # Sub-pixel shift
- janssen.utils.unflatten_params(params: Float[Array, 'n'], sample_shape: tuple[int, int], probe_shape: tuple[int, int]) tuple[Complex[Array, 'Hs Ws'], Complex[Array, 'Hp Wp']][source]¶
Unflatten real parameter vector back to complex arrays.
Reconstructs two complex 2D arrays from a flattened real parameter vector. Handles arrays with different shapes (e.g., sample=512×512 FOV, probe=64×64 spot).
- Parameters:
params (Float[Array, " n"]) – Flattened real parameter vector with n = 2*(Hs*Ws + Hp*Wp).
sample_shape (Tuple[int, int]) – Shape (Hs, Ws) of the sample array.
probe_shape (Tuple[int, int]) – Shape (Hp, Wp) of the probe array.
- Return type:
tuple[Complex[Array, 'Hs Ws'], Complex[Array, 'Hp Wp']]
- Returns:
sample (Complex[Array, " Hs Ws"]) – First complex array reconstructed from first 2*Hs*Ws elements.
probe (Complex[Array, " Hp Wp"]) – Second complex array reconstructed from remaining elements.
Notes
This function cannot be JIT-compiled directly because its slice boundaries are computed from the shape parameters, which must be concrete integers at trace time. When used within a JIT-compiled function, ensure the shapes are static, for example via static_argnums.
Examples
>>> params = jnp.arange(532480, dtype=jnp.float64)
>>> sample, probe = unflatten_params(params, (512, 512), (64, 64))
>>> sample.shape, probe.shape
((512, 512), (64, 64))
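To illustrate the note on JIT compilation, here is a sketch of a flatten/unflatten pair with the documented layout [sample_real, sample_imag, probe_real, probe_imag], jitted with the shape arguments marked static. The names flatten_sketch and unflatten_sketch are hypothetical, and row-major ordering within each block is an assumption.

```python
import jax
import jax.numpy as jnp


def flatten_sketch(sample, probe):
    # Layout: [sample_real, sample_imag, probe_real, probe_imag].
    return jnp.concatenate([
        sample.real.ravel(), sample.imag.ravel(),
        probe.real.ravel(), probe.imag.ravel(),
    ])


def unflatten_sketch(params, sample_shape, probe_shape):
    # Slice boundaries depend on the shapes, so the shapes must be
    # concrete Python integers when the function is traced.
    n_s = sample_shape[0] * sample_shape[1]
    n_p = probe_shape[0] * probe_shape[1]
    sample = (params[:n_s] + 1j * params[n_s:2 * n_s]).reshape(sample_shape)
    rest = params[2 * n_s:]
    probe = (rest[:n_p] + 1j * rest[n_p:2 * n_p]).reshape(probe_shape)
    return sample, probe


# Mark both shape tuples as static so jit sees concrete integers.
unflatten_jit = jax.jit(unflatten_sketch, static_argnums=(1, 2))

sample = jnp.arange(12.0).reshape(3, 4) * (1 - 2j)
probe = jnp.ones((2, 2)) * (1 + 1j)
params = flatten_sketch(sample, probe)
sample_rt, probe_rt = unflatten_jit(params, (3, 4), (2, 2))
```

Each distinct pair of shapes triggers a separate compilation, which is the usual trade-off of static arguments.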
See also
flatten_params – Forward operation to create flattened vector
- janssen.utils.wirtinger_grad(func2diff: Callable[[...], Float[Array, '...']], argnums: int | Sequence[int] | None = 0) Callable[[...], Complex[Array, '...'] | tuple[Complex[Array, '...'], ...]][source]¶
Compute the Wirtinger gradient of a complex-valued function.
This function returns a new function that computes the Wirtinger gradient of the input function f with respect to the specified argument(s). It is based on the formula for the Wirtinger derivative:
\[\frac{\partial f}{\partial z} = \frac{1}{2} \left( \frac{\partial f}{\partial x} - i \frac{\partial f}{\partial y} \right)\]
- Parameters:
func2diff (Callable[..., Float[Array, " ..."]]) – A complex-valued function to differentiate.
argnums (Union[int, Sequence[int]], optional) – Specifies which argument(s) to compute the gradient with respect to. Can be an int or a sequence of ints. Default is 0.
- Returns:
grad_f – A function that computes the Wirtinger gradient of f with respect to the specified argument(s). Returns a single array if argnums is an int, or a tuple of arrays if argnums is a sequence.
- Return type:
Callable
Notes
The Wirtinger derivative is essential for optimizing real-valued loss functions with respect to complex-valued parameters. It provides the correct gradient direction for steepest descent in the complex plane.
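As a worked illustration of the formula, the derivative can be assembled from two real gradients by writing z = x + iy. This is a sketch with a hypothetical name (wirtinger_sketch) that handles a single argument only; it is not claimed to match the library's implementation or its conjugation convention.

```python
import jax
import jax.numpy as jnp


def wirtinger_sketch(f):
    """Return z -> ∂f/∂z for a real-valued f of one complex argument."""

    def f_xy(x, y):
        # Re-express f as a function of the real and imaginary parts.
        return f(x + 1j * y)

    def grad_f(z):
        df_dx, df_dy = jax.grad(f_xy, argnums=(0, 1))(jnp.real(z), jnp.imag(z))
        # Wirtinger derivative: 1/2 (∂f/∂x - i ∂f/∂y).
        return 0.5 * (df_dx - 1j * df_dy)

    return grad_f


def loss(z):
    return jnp.sum(jnp.abs(z) ** 2)


z = jnp.array([1 + 2j, 3 + 4j])
dz = wirtinger_sketch(loss)(z)
```

For f(z) = Σ|z|² = Σ z z̄, the formula gives ∂f/∂z = z̄, so dz should equal the complex conjugate of z.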
Examples
>>> import jax.numpy as jnp
>>> def loss(z):
...     return jnp.sum(jnp.abs(z)**2)
>>> grad_fn = wirtinger_grad(loss)
>>> z = jnp.array([1+2j, 3+4j])
>>> grad_fn(z)  # Returns the Wirtinger gradient