janssen.utils

Utility functions for distributed computing and math operations.

Extended Summary

Utilities for distributed JAX computing across multiple devices, mathematical helper functions for complex-valued operations, and general-purpose optimization algorithms.

Routine Listings

bessel_j0()

Compute J_0(x), regular Bessel function of the first kind, order 0.

bessel_jn()

Compute J_n(x), regular Bessel function of the first kind, order n.

bessel_kv()

Compute K_v(x), modified Bessel function of the second kind.

jt_residual()

Compute residuals and the (optionally weighted) gradient J^T @ W @ r simultaneously.

create_mesh()

Creates a device mesh for data parallelism across available devices.

jtj_diag()

Estimate diagonal of J^T J for preconditioning.

max_eigenval()

Estimate largest eigenvalue of J^T J via power iteration.

flatten_params()

Flatten complex arrays to real parameter vector for optimization.

fourier_shift()

FFT-based sub-pixel shifting of 2D fields.

gn_solve()

High-level solver that runs Gauss-Newton until convergence.

gn_loss_history()

Gauss-Newton solver with per-iteration loss tracking only.

gn_history()

Gauss-Newton solver with per-iteration history tracking.

gn_step()

Generic Gauss-Newton step with trust-region damping.

get_device_count()

Gets the number of available JAX devices.

get_device_memory_gb()

Detects device count and memory per device (GB) for GPUs/CPUs.

hessian_matvec()

Create exact Hessian-vector product operator.

jtj_matvec()

Create Jacobian-free (J^T J + λI) operator for Gauss-Newton.

shard_batch()

Shards array data across the batch dimension for parallel processing.

unflatten_params()

Unflatten real parameter vector back to complex arrays.

wirtinger_grad()

Compute the Wirtinger gradient of a complex-valued function.

Notes

For type definitions and PyTree classes, see the janssen.types module.

janssen.utils.bessel_j0(x: Float[Array, '...']) → Float[Array, '...']

Compute J_0(x), regular Bessel function of the first kind, order 0.

Parameters:

x (Float[Array, "..."]) – Input array.

Returns:

Values of J_0(x).

Return type:

Float[Array, " ..."]

Notes

This is a wrapper around JAX’s scipy implementation for consistency. J_0(x) is the regular Bessel function of the first kind of order 0.

The function is differentiable, JIT-compatible, and supports broadcasting.

Examples

>>> import jax.numpy as jnp
>>> from janssen.utils import bessel_j0
>>> x = jnp.linspace(0, 10, 100)
>>> j0_vals = bessel_j0(x)

janssen.utils.bessel_jn(n: int | Int[Array, ''], x: Float[Array, '...']) → Float[Array, '...']

Compute J_n(x), regular Bessel function of the first kind, order n.

Parameters:
  • n (int) – Order of the Bessel function (integer). Must be a compile-time constant for JIT compilation.

  • x (Float[Array, "..."]) – Input array.

Returns:

Values of J_n(x).

Return type:

Float[Array, " ..."]

Notes

This is a wrapper around JAX’s scipy implementation for consistency. J_n(x) is the regular Bessel function of the first kind of order n.

The function is differentiable and supports broadcasting. Note that the order n must be a compile-time constant (not a traced value) when used inside JIT-compiled functions.

Examples

>>> import jax.numpy as jnp
>>> from janssen.utils import bessel_jn
>>> x = jnp.linspace(0, 10, 100)
>>> j1_vals = bessel_jn(1, x)
>>> j2_vals = bessel_jn(2, x)

janssen.utils.bessel_kv(v: float | Float[Array, ''], x: Float[Array, '...']) → Float[Array, '...']

Compute the modified Bessel function of the second kind K_v(x).

Parameters:
  • v (float) – Order of the Bessel function (v >= 0).

  • x (Float[Array, "..."]) – Positive real input array.

Returns:

Approximated values of K_v(x).

Return type:

Float[Array, " ..."]

Notes

Computes K_v(x) for real order v >= 0 and x > 0, using a numerically stable and differentiable JAX-compatible approximation.

  • Valid for v >= 0 and x > 0

  • Supports broadcasting and autodiff

  • JIT-safe and VMAP-safe

  • Uses series expansion for small x (x <= 2.0) and asymptotic expansion for large x

  • For non-integer v, uses the reflection formula: K_v = π/(2sin(πv)) * (I_{-v} - I_v)

  • For integer v, uses specialized series and recurrence relations

  • Special exact formula for v = 0.5: K_{1/2}(x) = sqrt(π/(2x)) * exp(-x)

  • The transition point between small and large x approximations is set at x = 2.0

Algorithm

  • For integer orders n > 1, uses recurrence relations with masked updates to only update values within the target range
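The closed form for v = 1/2 and the upward recurrence relation mentioned above can be illustrated with a minimal sketch. It is not janssen's implementation: k_half and k_recur_up are hypothetical helpers, the small-x series and large-x asymptotic branches are omitted, and the recurrence K_{v+1}(x) = K_{v-1}(x) + (2v/x) K_v(x) is applied to half-integer orders (starting from K_{-1/2} = K_{1/2}, since K_v is even in v) so that the results can be checked against known closed forms.

```python
import jax
import jax.numpy as jnp


def k_half(x):
    """Exact closed form K_{1/2}(x) = sqrt(pi / (2x)) * exp(-x), valid for x > 0."""
    return jnp.sqrt(jnp.pi / (2.0 * x)) * jnp.exp(-x)


def k_recur_up(v0, k_prev, k_curr, x, steps):
    """Hypothetical helper: apply K_{v+1}(x) = K_{v-1}(x) + (2v / x) K_v(x).

    k_prev holds K_{v0-1}(x) and k_curr holds K_{v0}(x); the function
    returns K_{v0+steps}(x). `steps` is a static Python int in this sketch.
    """

    def body(i, carry):
        k_minus, k_v = carry
        v = v0 + i  # order of k_v at this step
        return k_v, k_minus + (2.0 * v / x) * k_v

    _, k_final = jax.lax.fori_loop(0, steps, body, (k_prev, k_curr))
    return k_final


x = jnp.array([0.5, 1.0, 3.0])
k12 = k_half(x)
# K_{-1/2} = K_{1/2}, so one step from v0 = 1/2 yields K_{3/2}.
k32 = k_recur_up(0.5, k12, k12, x, steps=1)
k52 = k_recur_up(0.5, k12, k12, x, steps=2)
```

Upward recurrence is numerically stable for K_v because the function grows with order, which is why it is a common choice for building higher orders from two starting values.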

janssen.utils.create_mesh(n_devices: int | None = None) → Mesh

Create a device mesh for data parallelism.

Creates a 1D device mesh suitable for sharding arrays across their batch dimension. The mesh can be used with sharding specifications to distribute computation across multiple devices.

Parameters:

n_devices (int, optional) – Number of devices to use in the mesh. If None, uses all available devices detected by JAX (default: None).

Returns:

mesh – Device mesh with axis name ‘batch’ for data parallelism. The mesh contains a linear arrangement of devices for distributing batched computations.

Return type:

Mesh

Examples

>>> # Create a mesh with all available devices
>>> mesh = create_mesh()
>>> print(dict(mesh.shape))  # assuming 4 devices are available
{'batch': 4}
>>> # Create a mesh with a specific number of devices
>>> mesh = create_mesh(n_devices=2)
>>> print(dict(mesh.shape))
{'batch': 2}

Notes

The returned mesh has a single axis named ‘batch’, making it suitable for distributing the first dimension of arrays across devices. For more complex sharding patterns, consider creating custom meshes using jax.sharding.Mesh directly.
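A 1D mesh of this kind can be built directly with jax.sharding.Mesh. The following is a minimal sketch under the assumption that create_mesh takes the first n_devices entries of jax.devices(); create_mesh_sketch is a hypothetical name, and raising ValueError for an out-of-range n_devices is this sketch's choice, not documented janssen behavior.

```python
import jax
import numpy as np
from jax.sharding import Mesh


def create_mesh_sketch(n_devices=None):
    """Hypothetical 1D mesh with a single axis named 'batch'."""
    devices = jax.devices()
    if n_devices is None:
        n_devices = len(devices)
    if not 1 <= n_devices <= len(devices):
        raise ValueError(
            f"n_devices must be in [1, {len(devices)}], got {n_devices}"
        )
    # A 1D array of devices gives a linear arrangement along 'batch'.
    return Mesh(np.array(devices[:n_devices]), axis_names=("batch",))


mesh = create_mesh_sketch()
print(dict(mesh.shape))
```

Mesh.shape is an ordered mapping from axis name to size, so wrapping it in dict() gives the plain {'batch': n} form.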

janssen.utils.get_device_count() → int

Get number of available JAX devices.

Returns:

n_devices – Number of available JAX devices (GPUs/TPUs, or CPU devices when no accelerator is present)

Return type:

int

Examples

>>> import janssen as jns
>>> n = jns.utils.get_device_count()
>>> print(f"Found {n} devices")
Found 8 devices

janssen.utils.get_device_memory_gb() → tuple[int, float]

Detect device count and memory per device.

Attempts to detect the number of JAX devices and total memory per device using platform-specific methods:

  • NVIDIA GPUs: nvidia-smi

  • CPUs: system RAM via psutil or /proc/meminfo

  • Other platforms: conservative fallback

Return type:

tuple[int, float]

Returns:

  • num_devices (int) – Number of JAX devices detected via jax.device_count()

  • memory_per_device_gb (float) – Memory per device in GB. For CPUs, returns total system RAM. Falls back to 16.0 GB for GPUs or 8.0 GB for unknown platforms.

Notes

Detection Methods by Platform:

NVIDIA GPUs:

  • Uses nvidia-smi to query the first GPU’s total memory

  • Converts from MB to GB

  • Multi-GPU systems: assumes all GPUs have the same memory

CPUs:

  • First tries psutil.virtual_memory() for cross-platform detection

  • Falls back to /proc/meminfo on Linux if psutil is unavailable

  • Returns total system RAM (all devices share this pool)

TPUs/Other Accelerators:

  • No direct memory detection available

  • Falls back to conservative defaults

Fallback Values:

GPU fallback (16.0 GB) is safe for:

  • NVIDIA V100 (16 GB variant)

  • NVIDIA Tesla T4 (16 GB)

  • NVIDIA RTX 6000 (24 GB; 16 GB is a safe underestimate)

  • NVIDIA A100 (40 GB variant; 16 GB is a safe underestimate)

CPU fallback (8.0 GB):

  • Conservative for modern systems (most have 16+ GB)

  • Prevents OOM on low-memory machines

Platform Limitations:

  • NVIDIA GPU: Only detects first GPU memory

  • AMD GPU (ROCm): No detection, uses fallback

  • Intel GPU: No detection, uses fallback

  • Google TPU: No detection, uses fallback

  • CPU: Detects total system RAM (shared across all cores)
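The detection strategy described in these notes can be sketched as follows. This is an illustration under assumptions, not janssen's code: the helper names are hypothetical, the MB → GB conversion is assumed to divide by 1024, and the platform string is assumed to come from jax.devices()[0].platform ('gpu', 'cpu', 'tpu', ...).

```python
import shutil
import subprocess
from typing import Optional

GPU_FALLBACK_GB = 16.0
DEFAULT_FALLBACK_GB = 8.0


def parse_nvidia_smi_gb(output: str) -> float:
    """Parse `nvidia-smi --query-gpu=memory.total --format=csv,noheader,nounits`.

    The output has one line per GPU, in MB; only the first GPU is used.
    """
    first_line = output.strip().splitlines()[0]
    return float(first_line) / 1024.0


def parse_meminfo_gb(text: str) -> Optional[float]:
    """Return MemTotal from /proc/meminfo text in GB, or None if absent."""
    for line in text.splitlines():
        if line.startswith("MemTotal:"):
            return float(line.split()[1]) / 1024.0**2  # kB -> GB
    return None


def detect_memory_gb(platform: str) -> float:
    """Hypothetical per-device memory detection with conservative fallbacks."""
    if platform == "gpu":
        if shutil.which("nvidia-smi"):
            try:
                result = subprocess.run(
                    ["nvidia-smi", "--query-gpu=memory.total",
                     "--format=csv,noheader,nounits"],
                    capture_output=True, text=True, check=True, timeout=10,
                )
                return parse_nvidia_smi_gb(result.stdout)
            except (OSError, subprocess.SubprocessError, ValueError, IndexError):
                pass
        return GPU_FALLBACK_GB  # no nvidia-smi (e.g. ROCm) or query failed
    if platform == "cpu":
        try:
            import psutil  # optional dependency

            return psutil.virtual_memory().total / 1024.0**3
        except ImportError:
            pass
        try:
            with open("/proc/meminfo") as f:
                gb = parse_meminfo_gb(f.read())
            if gb is not None:
                return gb
        except OSError:
            pass
    return DEFAULT_FALLBACK_GB  # TPUs, other platforms, or failed CPU detection
```

Keeping the parsers separate from the subprocess and file access makes the parsing logic testable without an NVIDIA GPU or a Linux host.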

Examples

>>> from janssen.utils.distributed import get_device_memory_gb
>>> num_devices, memory = get_device_memory_gb()
>>> print(f"Detected {num_devices} devices with {memory:.1f} GB each")
Detected 8 devices with 16.0 GB each

See also

get_device_count

Get number of available devices

janssen.utils.shard_batch(data: Shaped[Array, '...'], mesh: Mesh) → Shaped[Array, '...']

Shard data across batch dimension.

Distributes an array’s first dimension (batch dimension) across devices in the provided mesh. This enables parallel processing of batched data with automatic memory distribution and computation parallelism.

Parameters:
  • data (Shaped[Array, " ..."]) – Input array to shard. The first dimension is treated as the batch dimension and will be distributed across devices. Can be any JAX or NumPy array.

  • mesh (Mesh) – Device mesh created by create_mesh() or custom mesh with a ‘batch’ axis. Defines how the data will be distributed across devices.

Returns:

sharded_data – Input array with the batch dimension sharded across devices in the mesh. The array’s computation will be automatically parallelized across devices.

Return type:

Shaped[Array, " ..."]

Examples

>>> import jax.numpy as jnp
>>> from janssen.utils.distributed import create_mesh, shard_batch
>>>
>>> # Create sample data with batch dimension
>>> data = jnp.ones((8, 256, 256))
>>>
>>> # Create mesh and shard data
>>> mesh = create_mesh()
>>> sharded_data = shard_batch(data, mesh)
>>>
>>> # The first dimension is now distributed across devices
>>> # Operations on sharded_data will run in parallel

Notes

  • The batch dimension size should ideally be divisible by the number of devices for optimal load balancing.

  • Sharding is applied using NamedSharding with PartitionSpec(‘batch’), which partitions only the first dimension.

  • Subsequent operations on the sharded array will automatically maintain the sharding pattern where possible.
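Following the notes above, the core of such a function can be sketched with NamedSharding and jax.device_put. This is a minimal sketch, not janssen's source; shard_batch_sketch is a hypothetical name, and the batch size is chosen as a multiple of the device count so the split is even.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def shard_batch_sketch(data, mesh):
    """Hypothetical: split axis 0 over the mesh axis 'batch'.

    PartitionSpec("batch") names only the first dimension, so the remaining
    dimensions are not partitioned (each device holds them in full).
    """
    sharding = NamedSharding(mesh, PartitionSpec("batch"))
    return jax.device_put(data, sharding)


n = len(jax.devices())
mesh = Mesh(np.array(jax.devices()), axis_names=("batch",))
data = jnp.ones((2 * n, 16, 16))  # batch size divisible by the device count
sharded = shard_batch_sketch(data, mesh)
# Each device now holds a (2, 16, 16) slice of the batch.
```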

janssen.utils.jt_residual(residual_fn: Callable[[Float[Array, 'n']], Float[Array, 'm']], params: Float[Array, 'n'], weights: int | float | complex | Num[Array, ''] | Float[Array, 'm'] = -1.0) → tuple[Float[Array, 'm'], Float[Array, 'n']]

Compute residuals and J^T @ W @ r simultaneously.

Uses reverse-mode autodiff to compute the gradient of the loss L = 0.5 * r^T W r, which equals J^T @ W @ r.

For unweighted least squares (W = I), this reduces to J^T @ r.

Implementation Logic

The function exploits the chain rule via reverse-mode autodiff:

  1. Compute residuals and the VJP function: residuals, vjp_fn = jax.vjp(residual_fn, params)

    • Evaluates r(θ) at the current parameters

    • Constructs vjp_fn: vector → J^T @ vector

    • Single forward pass through residual_fn

  2. Weight handling (conditional on the weights parameter):

    • If weights < 0 (sentinel): set weights_normalized = ones(m) → unweighted case, W = I

    • Else: broadcast weights to the residual shape → weighted case, W = diag(weights)

    • Uses jax.lax.cond for JIT compatibility (control flow is traced)

  3. Apply weighting and compute the weighted gradient: weighted_residuals = weights_normalized * residuals, then jt_wr = vjp_fn(weighted_residuals)[0]

    • Element-wise multiplication computes W @ r for diagonal W

    • vjp_fn(W @ r) = J^T @ W @ r by the chain rule

    • Single backward pass through residual_fn

Total cost: 1 forward + 1 backward pass through residual_fn, regardless of whether weighting is used.

Mathematical Derivation

For loss L(θ) = 0.5 * ||W^{1/2} r(θ)||² = 0.5 r(θ)^T W r(θ):

∇L = J^T W r

where J is the Jacobian ∂r/∂θ. The VJP operator computes J^T @ v for any vector v. Setting v = W @ r gives J^T @ W @ r directly.
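The VJP-based computation above can be sketched in a few lines. This is an illustrative sketch, not janssen's source: jt_residual_sketch is a hypothetical name, and it signals the unweighted case with weights=None (resolved in Python at trace time) rather than the -1.0 sentinel plus jax.lax.cond that the documented function uses.

```python
import jax
import jax.numpy as jnp


def jt_residual_sketch(residual_fn, params, weights=None):
    """Return (r, J^T W r) using one forward and one backward pass."""
    # Forward pass: residuals plus a closure computing v -> J^T v.
    residuals, vjp_fn = jax.vjp(residual_fn, params)
    if weights is None:
        w = jnp.ones_like(residuals)  # unweighted case, W = I
    else:
        w = jnp.broadcast_to(weights, residuals.shape)  # W = diag(w)
    # Backward pass: J^T (W r) for diagonal W.
    (jt_wr,) = vjp_fn(w * residuals)
    return residuals, jt_wr


def residual_fn(x):
    return x**2 - jnp.array([1.0, 4.0])


def weighted_loss(p):
    # 0.5 * r^T W r with W = diag([2.0, 0.5])
    return 0.5 * jnp.sum(jnp.array([2.0, 0.5]) * residual_fn(p) ** 2)


params = jnp.array([1.5, 2.5])
r, jt_r = jt_residual_sketch(residual_fn, params)
_, jt_wr = jt_residual_sketch(residual_fn, params, jnp.array([2.0, 0.5]))
```

Here J = diag(2x) = diag(3, 5) and r = (1.25, 2.25), so J^T r = (3.75, 11.25), and the weighted result matches jax.grad(weighted_loss), confirming ∇L = J^T W r.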

Parameters:
  • residual_fn (Callable[[Float[Array, " n"]], Float[Array, " m"]]) – Function mapping parameters to residuals.

  • params (Float[Array, " n"]) – Current parameters.

  • weights (Union[ScalarNumeric, Float[Array, " m"]], optional) – Diagonal weight matrix entries for weighted least squares. If -1.0 (default), assumes unweighted (W = I).

Returns:
  • residuals (Float[Array, " m"]) – Current residual vector r(θ).

  • jt_wr (Float[Array, " n"]) – Weighted gradient J^T @ W @ r = ∇(0.5 * r^T W r).

Return type:

tuple[Float[Array, " m"], Float[Array, " n"]]

Notes

For diagonal weighting W = diag(w), the weighted gradient is J^T @ W @ r where each residual r_i is weighted by w_i.

Examples

>>> import jax.numpy as jnp
>>> from janssen.utils import jt_residual
>>> def residual_fn(x):
...     return x**2 - jnp.array([1.0, 4.0])
>>> params = jnp.array([1.5, 2.5])
>>> r, jt_r = jt_residual(residual_fn, params)
>>> # Weighted version:
>>> weights = jnp.array([2.0, 0.5])
>>> r, jt_wr = jt_residual(residual_fn, params, weights)

janssen.utils.jtj_diag(residual_fn: Callable[[Float[Array, 'n']], Float[Array, 'm']], params: Float[Array, 'n'], num_samples: int = 10) → Float[Array, 'n']

Estimate diagonal of J^T J via Hutchinson’s trace estimator.

The diagonal approximates per-parameter sensitivities and is useful for diagonal preconditioning in conjugate gradient.

Implementation Logic

The function implements Hutchinson’s stochastic trace estimator:

  1. Setup:

    • Construct the matvec operator (J^T J) @ v via jtj_matvec

    • Generate num_samples independent random keys

  2. Per-sample estimation (parallelized via vmap): the internal function estimate_one(key) performs:

    a. Sample a probe vector z ~ Rademacher({-1, +1}^n)

      • Each entry is independently ±1 with equal probability

      • Uses params.dtype to match precision (float32/float64)

    b. Compute Az = (J^T J) @ z via matvec (cost: 1 forward + 1 backward pass through residual_fn)

    c. Compute the element-wise product z ⊙ Az, which is the single-sample diagonal estimate

    d. Return z ⊙ Az

    jax.vmap(estimate_one)(keys):

      • Vectorizes over the num_samples keys

      • Compiles to a parallel batch of matrix-vector products

      • Returns the stacked estimates with shape (num_samples, n)

  3. Averaging:

    • Compute the mean across samples: diagonal = mean(estimates, axis=0)

    • Reduces the variance by a factor of num_samples (the standard error by √num_samples)

Algorithm: Hutchinson’s Estimator

For any symmetric matrix A, diagonal entries satisfy:

A_ii = E_z[z_i (A @ z)_i]

where z ~ Rademacher({-1, +1}^n). This is unbiased:

E[z_i (A @ z)_i] = E[z_i Σⱼ A_ij z_j] = Σⱼ A_ij E[z_i z_j] = A_ii

since E[z_i z_j] = δ_ij (the entries are independent, zero-mean, and satisfy z_i² = 1).
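The estimator can be sketched directly with jax.random.rademacher, jax.vmap, and a Jacobian-free J^T J matvec built from jax.jvp and jax.vjp. This is a hypothetical sketch, not janssen's jtj_diag or jtj_matvec; the function names and the seed argument are assumptions.

```python
import jax
import jax.numpy as jnp


def jtj_matvec_sketch(residual_fn, params):
    """Return v -> (J^T J) v without forming J (one JVP plus one VJP)."""
    _, vjp_fn = jax.vjp(residual_fn, params)

    def matvec(v):
        _, jv = jax.jvp(residual_fn, (params,), (v,))  # J v
        (jtjv,) = vjp_fn(jv)  # J^T (J v)
        return jtjv

    return matvec


def jtj_diag_sketch(residual_fn, params, num_samples=10, seed=0):
    matvec = jtj_matvec_sketch(residual_fn, params)
    keys = jax.random.split(jax.random.PRNGKey(seed), num_samples)

    def estimate_one(key):
        # Rademacher probe: each entry is +1 or -1 with equal probability.
        z = jax.random.rademacher(key, params.shape, dtype=params.dtype)
        return z * matvec(z)  # single-sample estimate z ⊙ (A z)

    # Stack num_samples estimates, shape (num_samples, n), then average.
    return jnp.mean(jax.vmap(estimate_one)(keys), axis=0)
```

When J^T J happens to be diagonal (for example, an element-wise residual), every single probe is already exact because z_i² = 1; off-diagonal entries are what introduce sampling noise.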

Variance and Sampling

The variance of the single-sample estimator is:

Var[z_i (A @ z)_i] = Σ_{j≠i} A_ij²

This is exact for Rademacher probes: since z_i² = 1, the diagonal term A_ii contributes no noise, and the products z_i z_j (j ≠ i) are uncorrelated with unit variance.

For matrices with strong off-diagonal structure (like J^T J in ptychography due to overlapping scan positions), this variance can be large. Averaging num_samples independent probes gives:

Var[diagonal_estimate_i] = (1/num_samples) Σ_{j≠i} A_ij²

With num_samples=10, the standard error is √(0.1 Σ_{j≠i} A_ij²). For ill-conditioned problems, increase to 50-100 samples.

Why Rademacher? Gaussian probes would also work, but Rademacher probes:

  • Have lower variance: a Gaussian probe adds a 2 A_ii² term to the variance above, which Rademacher probes avoid because z_i² = 1

  • Can be stored compactly (values are ±1)

  • Allow exact integer arithmetic in some contexts

Note: Individual diagonal estimates can be negative due to finite sampling, even though true diagonal entries are non-negative (J^T J is positive semi-definite). When using for preconditioning, clamp to zero: max(diagonal, 0) before inversion.

Computational Cost

  • Total: num_samples × (1 forward + 1 backward pass)

  • For num_samples=10: 10 forward + 10 backward passes

  • All samples computed in parallel via vmap (batch size = num_samples)

  • Memory: O(num_samples × n) for stacked estimates

Parameters:
  • residual_fn (Callable[[Float[Array, " n"]], Float[Array, " m"]]) – Residual function.

  • params (Float[Array, " n"]) – Current parameters.

  • num_samples (int, optional) – Number of random probes. Default is 10.

Returns:

diagonal – Estimated diagonal of J^T J.

Return type:

Float[Array, " n"]

Notes

Uses Hutchinson’s trick: diag(A) ≈ E[z ⊙ (A @ z)] for random z ∈ {-1, +1}^n (Rademacher distribution).

The diagonal can be used to construct a preconditioner:

M = (diag(J^T J) + ε)^{-1/2}, applied element-wise

Trade-off: More samples → lower variance, higher cost. For most problems, 10 samples is sufficient. For high accuracy or when diagonal entries vary by orders of magnitude, use 50-100 samples.

Examples

>>> diag = jtj_diag(residual_fn, params)
>>> # Clamp noisy negative estimates before inverting
>>> precond = 1.0 / jnp.sqrt(jnp.maximum(diag, 0.0) + 1e-6)

janssen.utils.max_eigenval(residual_fn: Callable[[Float[Array, 'n']], Float[Array, 'm']], params: Float[Array, 'n'], num_iterations: int = 20) → Float[Array, '']

Estimate largest eigenvalue of J^T J via power iteration.

A large eigenvalue indicates potential ill-conditioning, which can cause convergence issues. Use this diagnostic to tune damping.

Implementation Logic

The function implements the power iteration algorithm:

  1. Setup:

    • Construct the matvec operator (J^T J) @ v via jtj_matvec

    • Initialize a random unit vector v₀ ~ N(0, I), normalized

    • Uses a fixed seed (42) for reproducibility

  2. Power iteration loop (num_iterations times): the internal function power_step(v_curr) performs:

    a. Compute Av = (J^T J) @ v_curr via matvec (cost: 1 forward + 1 backward pass through residual_fn)

    b. Normalize: v_next = Av / ||Av||, which keeps the iterate bounded and prevents overflow

    c. Return v_next (the carry state for the next iteration)

    jax.lax.scan drives the iteration:

      • scan(power_step, v₀, None, length=num_iterations)

      • The loop is compiled once as a single XLA loop rather than unrolled

      • Avoids Python loop overhead

  3. Eigenvalue extraction:

    • Compute the final Rayleigh quotient λ = v^T (J^T J) v

    • After k iterations, vₖ ≈ the dominant eigenvector

    • The Rayleigh quotient converges to λ_max(J^T J)

Algorithm: Power Iteration

For symmetric positive semi-definite A = J^T J:

vₖ₊₁ = A vₖ / ||A vₖ||

Converges to dominant eigenvector v₁ at rate (λ₂/λ₁)^k where λ₁ ≥ λ₂ ≥ … are eigenvalues. Convergence is geometric.

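With k=20 iterations and typical spectral gaps, the relative error is usually < 1%. For symmetric A, the Rayleigh quotient error decays like (λ₂/λ₁)^{2k}, twice as fast (in the exponent) as the eigenvector error. When the spectral gap is small (λ₂ ≈ λ₁), increase num_iterations to 50-100.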
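The power iteration described above can be sketched with jax.lax.scan and a Jacobian-free matvec. This is a hypothetical sketch (max_eigenval_sketch is not janssen's function); the fixed seed 42 mirrors the setup step above.

```python
import jax
import jax.numpy as jnp


def max_eigenval_sketch(residual_fn, params, num_iterations=20, seed=42):
    """Estimate lambda_max(J^T J) by power iteration plus a Rayleigh quotient."""
    _, vjp_fn = jax.vjp(residual_fn, params)

    def matvec(v):  # (J^T J) v via one JVP and one VJP
        _, jv = jax.jvp(residual_fn, (params,), (v,))
        return vjp_fn(jv)[0]

    v0 = jax.random.normal(jax.random.PRNGKey(seed), params.shape, params.dtype)
    v0 = v0 / jnp.linalg.norm(v0)

    def power_step(v, _):
        av = matvec(v)
        return av / jnp.linalg.norm(av), None  # normalize to avoid overflow

    v, _ = jax.lax.scan(power_step, v0, None, length=num_iterations)
    return jnp.vdot(v, matvec(v))  # Rayleigh quotient; v has unit norm
```

For a residual r(x) = (1, 2, 3) ⊙ x, J^T J = diag(1, 4, 9), so the estimate should approach 9; with λ₂/λ₁ = 4/9, twenty iterations are ample.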

Computational Cost

  • Total: num_iterations × (1 forward + 1 backward pass)

  • For num_iterations=20: 20 forward + 20 backward passes

  • Plus 1 final matvec for Rayleigh quotient

  • Memory: O(n) for storing vₖ

Parameters:
  • residual_fn (Callable[[Float[Array, " n"]], Float[Array, " m"]]) – Residual function.

  • params (Float[Array, " n"]) – Current parameters.

  • num_iterations (int, optional) – Number of power iterations. Default is 20.

Returns:

lambda_max – Estimate of the largest eigenvalue of J^T J.

Return type:

Float[Array, " "]

Notes

Power iteration computes: v_{k+1} = (J^T J) @ v_k / ||(J^T J) @ v_k||

After convergence, the Rayleigh quotient v^T (J^T J) v gives λ_max.

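This is useful for:

  • Diagnosing ill-conditioning: condition number κ = λ_max / λ_min (λ_min requires a separate estimate)

  • Tuning initial damping: set λ₀ ~ 10^{-3} λ_max as a starting point

  • Detecting degenerate problems: if λ_max ≈ 0, then J ≈ 0 and the residuals are insensitive to every parameter direction (rank deficiency in general shows up as λ_min ≈ 0)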

Examples

>>> lambda_max = max_eigenval(residual_fn, params)
>>> print(f"Maximum eigenvalue: {lambda_max:.2e}")

janssen.utils.gn_solve(state: GaussNewtonState, residual_fn: Callable[[Float[Array, 'n']], Float[Array, 'm']], max_iterations: int = 100, cg_maxiter: int = 50, cg_tol: float = 1e-05, use_preconditioner: bool = False) → GaussNewtonState

Run Gauss-Newton optimization until convergence or max iterations.

High-level solver that repeatedly calls gn_step until the optimization converges or reaches the maximum iteration limit. This provides a single entry point for running the full optimization.

Implementation Logic

The function uses jax.lax.scan for efficient iteration:

  1. Internal step function step_fn(carry, _) wraps gn_step with early-stopping logic:

    a. Check whether carry.converged is True

    b. If converged: return carry unchanged (no-op)

    c. If not converged: call gn_step(carry, …)

    d. Return the updated state

    Uses jax.lax.cond for control flow:

      • JIT-compatible conditional execution

      • When converged=True, subsequent iterations are no-ops

      • Keeps the iteration count fixed for XLA compilation

  2. Main loop: jax.lax.scan(step_fn, state, None, length=max_iterations)

      • Iterates exactly max_iterations times (scan's length must be a static integer)

      • Carries GaussNewtonState through the iterations

      • Stops updating once converged=True (via the lax.cond guard)

      • Returns final_state after max_iterations

Why scan instead of while_loop?

JAX offers two relevant iteration primitives:

  • jax.lax.while_loop: runs until its predicate is False → data-dependent iteration count → still JIT-compilable (the body is traced once), but it does not support reverse-mode differentiation and cannot return stacked per-iteration outputs

  • jax.lax.scan: runs a fixed number of iterations → static iteration count → supports reverse-mode differentiation and stacked per-iteration outputs (as used by gn_history) → early stopping via a lax.cond guard inside the body

We use scan with an internal lax.cond guard to get both:

  • A fixed, static iteration count (max_iterations)

  • Early stopping (once converged=True, step_fn returns carry unchanged)

Trade-off: if convergence happens at iteration 20, iterations 21-100 are no-ops (just carry forwarding). Their cost is negligible compared to actual Gauss-Newton steps.
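The scan-plus-cond pattern can be sketched independently of Gauss-Newton. In this hypothetical example, ToyState and toy_step stand in for GaussNewtonState and gn_step (they are not janssen APIs): each step halves x and sets converged once x < 0.1, so the guard freezes the state after iteration 4 even though scan runs all max_iterations steps.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyState(NamedTuple):  # NamedTuples are JAX PyTrees
    x: jax.Array
    iteration: jax.Array
    converged: jax.Array


def toy_step(state):
    x_new = state.x / 2.0
    return ToyState(x=x_new, iteration=state.iteration + 1, converged=x_new < 0.1)


def solve_sketch(state, max_iterations=100):
    def step_fn(carry, _):
        # Identity branch once converged, real step otherwise.
        new_carry = jax.lax.cond(carry.converged, lambda s: s, toy_step, carry)
        return new_carry, None

    final_state, _ = jax.lax.scan(step_fn, state, None, length=max_iterations)
    return final_state


init = ToyState(
    x=jnp.asarray(1.0, jnp.float32),
    iteration=jnp.asarray(0, jnp.int32),
    converged=jnp.asarray(False, dtype=jnp.bool_),
)
final = solve_sketch(init)
```

Both cond branches must return identical PyTree structures and dtypes, which is why the initial state uses explicit dtypes.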

Computational Cost

If convergence happens at iteration k < max_iterations:

  • Active iterations: k × cost(gn_step)

  • No-op iterations: (max_iterations - k) × cost(lax.cond)

  • No-op cost is ~microseconds, negligible vs. a GN step (typically ~seconds)

Total cost ≈ k × [20-40 forward/backward passes per step]

For 100 max iterations with typical convergence at k=30:

  • Active cost: 30 × 40 = 1200 forward+backward passes

  • No-op cost: 70 × (one predicate check + carry pass-through) ≈ 0

Parameters:
  • state (GaussNewtonState) – Initial optimization state containing sample, probe, iteration, loss, damping, and convergence status.

  • residual_fn (Callable[[Float[Array, " n"]], Float[Array, " m"]]) – Function mapping flattened parameters to residuals.

  • max_iterations (int, optional) – Maximum number of Gauss-Newton iterations. Default is 100.

  • cg_maxiter (int, optional) – Maximum conjugate gradient iterations per step. Default is 50.

  • cg_tol (float, optional) – CG convergence tolerance. Default is 1e-5.

  • use_preconditioner (bool, optional) – Whether to use diagonal preconditioning for CG. Default is False.

Returns:

final_state – Final optimization state after convergence or max iterations.

Return type:

GaussNewtonState

Notes

The solver stops updating when either:

  • The state’s converged flag becomes True

  • The iteration count reaches max_iterations

To monitor progress during optimization, use gn_history (per-iteration states and losses) or gn_loss_history (per-iteration losses only). This function is fully JIT-compatible.

Examples

>>> state = make_gauss_newton_state(sample, probe)
>>> final_state = gn_solve(state, my_residual_fn, max_iterations=50)
>>> print(f"Converged: {final_state.converged}")
>>> print(f"Final loss: {final_state.loss}")

janssen.utils.gn_loss_history(state: GaussNewtonState, residual_fn: Callable[[Float[Array, 'n']], Float[Array, 'm']], max_iterations: int = 100, cg_maxiter: int = 50, cg_tol: float = 1e-05, use_preconditioner: bool = False) → tuple[GaussNewtonState, Float[Array, 'N']]

Run Gauss-Newton optimization and return per-iteration losses.

This variant tracks only the scalar loss per iteration and not the full state history, which keeps memory overhead low for large ptychography problems.

Return type:

tuple[GaussNewtonState, Float[Array, 'N']]

janssen.utils.gn_history(state: GaussNewtonState, residual_fn: Callable[[Float[Array, 'n']], Float[Array, 'm']], max_iterations: int = 100, cg_maxiter: int = 50, cg_tol: float = 1e-05, use_preconditioner: bool = False) → tuple[GaussNewtonState, GaussNewtonState, Float[Array, 'N']]

Run Gauss-Newton optimization with per-iteration history tracking.

Like gn_solve but returns intermediate states and losses at each iteration, enabling convergence diagnostics and visualization.

Implementation Logic

Same iteration strategy as gn_solve using jax.lax.scan, but scan outputs capture per-iteration states:

  1. Internal step function step_fn(carry, _):

    a. Check whether carry.converged is True

    b. If converged: return carry unchanged (no-op)

    c. If not converged: call gn_step(carry, …)

    d. Return the updated state and the output tuple (state, loss)

  2. Main loop: jax.lax.scan(step_fn, state, None, length=max_iterations)

      • Returns both final_state and the stacked outputs PyTree

      • The outputs contain all intermediate states and losses

The scan outputs are PyTrees in which each leaf gains an extra dimension of size N = max_iterations. jax.lax.scan itself stacks outputs along a new leading axis; the layout documented for this function places the iteration axis last, so if state.sample has shape (H, W), then all_states.sample has shape (H, W, N).
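To see where the extra dimension comes from, here is a minimal sketch: jax.lax.scan stacks each per-iteration output along a new leading axis, and moving that axis to the end of every leaf gives the iteration-last layout documented here. The sketch assumes a jnp.moveaxis over every leaf; it illustrates the layout and is not necessarily how gn_history does it. HistState and halve_step are hypothetical stand-ins for GaussNewtonState and gn_step.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class HistState(NamedTuple):
    sample: jax.Array  # shape (H, W)
    loss: jax.Array  # scalar
    converged: jax.Array  # scalar bool


def halve_step(state):
    sample = state.sample / 2.0
    loss = jnp.sum(sample**2)
    return HistState(sample=sample, loss=loss, converged=loss < 0.1)


def history_sketch(state, max_iterations):
    def step_fn(carry, _):
        new = jax.lax.cond(carry.converged, lambda s: s, halve_step, carry)
        return new, (new, new.loss)  # per-iteration outputs to stack

    final, (states, losses) = jax.lax.scan(
        step_fn, state, None, length=max_iterations
    )
    # scan stacks along axis 0, so states.sample is (N, H, W); move it last.
    states = jax.tree_util.tree_map(lambda a: jnp.moveaxis(a, 0, -1), states)
    return final, states, losses


init = HistState(
    sample=jnp.ones((4, 5)),
    loss=jnp.asarray(20.0, jnp.float32),
    converged=jnp.asarray(False, dtype=jnp.bool_),
)
final, all_states, all_losses = history_sketch(init, max_iterations=6)
```

The loss falls from 20 to 5, 1.25, 0.3125, and 0.078125 (below the 0.1 threshold), then stays constant for the remaining no-op iterations.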

Parameters:
  • state (GaussNewtonState) – Initial optimization state containing sample, probe, iteration, loss, damping, and convergence status.

  • residual_fn (Callable[[Float[Array, " n"]], Float[Array, " m"]]) – Function mapping flattened parameters to residuals.

  • max_iterations (int, optional) – Maximum number of Gauss-Newton iterations. Default is 100.

  • cg_maxiter (int, optional) – Maximum conjugate gradient iterations per step. Default is 50.

  • cg_tol (float, optional) – CG convergence tolerance. Default is 1e-5.

  • use_preconditioner (bool, optional) – Whether to use diagonal preconditioning for CG. Default is False.

Returns:
  • final_state (GaussNewtonState) – Final optimization state after convergence or max iterations.

  • all_states (GaussNewtonState) – PyTree with all intermediate states. Each field has an extra dimension of size max_iterations stacked along the last axis. For example, if state.sample is (H, W), all_states.sample is (H, W, max_iterations).

  • all_losses (Float[Array, " N"]) – Loss value at each iteration, shape (max_iterations,).

Return type:

tuple[GaussNewtonState, GaussNewtonState, Float[Array, " N"]]

Notes

This function is useful for:

  • Convergence diagnostics and plotting

  • Creating visualizations of optimization progress

  • Debugging optimization issues

  • Resume/continuation with full history tracking

Storing the history costs memory proportional to max_iterations × the size of the state. If you only need the final state, use gn_solve; if you only need the loss curve, use gn_loss_history.

Examples

>>> state = make_gauss_newton_state(sample, probe)
>>> final, history, losses = gn_history(
...     state, my_residual_fn, max_iterations=50
... )
>>> print(f"Converged: {final.converged}")
>>> print(f"Loss history shape: {losses.shape}")  # (50,)
>>> print(f"Sample history shape: {history.sample.shape}")  # (H, W, 50)

janssen.utils.gn_step(state: GaussNewtonState, residual_fn: Callable[[Float[Array, 'n']], Float[Array, 'm']], cg_maxiter: int = 50, cg_tol: float = 1e-05, use_preconditioner: bool | Bool[Array, ''] = False) → GaussNewtonState

Perform Gauss-Newton step with Levenberg-Marquardt damping.

Solves the trust-region subproblem using conjugate gradient:

(J^T J + λI) δ = -J^T r

then, if the step is accepted, updates the parameters: θ_{k+1} = θ_k + δ

The damping λ adapts based on actual vs predicted reduction.

Implementation Logic

The function executes a complete trust-region Gauss-Newton step:

Phase 1: Setup and Gradient Computation

  1. Extract sample and probe shapes from state

  2. Flatten complex arrays (sample, probe) → real parameter vector params

  3. Compute residuals r and gradient J^T @ r via jt_residual

    • Cost: 1 forward + 1 backward pass through residual_fn

  4. Compute the current loss: 0.5 ||r||²

Phase 2: Linear System Construction

  5. Build the matrix-vector operator: matvec = jtj_matvec(…)

    • Computes (J^T J + λI) @ v using JVP/VJP composition

    • Caches vjp_fn to avoid redundant computation in CG

  6. If use_preconditioner=True:

    a. Estimate the diagonal via Hutchinson: diag_jtj ≈ diag(J^T J)

      • Uses 10 Rademacher samples (random ±1 vectors)

      • Cost: 10 × (1 forward + 1 backward pass)

    b. Clamp negative estimates: max(diag_jtj, 0) + λ

      • Ensures a positive-definite preconditioner

      • Negative estimates can occur due to finite sampling

    c. Construct the preconditioner: M(v) = v / (diag + λ)

      • Diagonal scaling that approximates (diag(J^T J + λI))^{-1}

Phase 3: Solve Trust-Region Subproblem

  7. Solve (J^T J + λI) δ = -J^T r via conjugate gradient

    • Inputs: matvec operator, RHS = -J^T r, preconditioner M

    • Iterative solver: typically converges in 10-50 iterations

    • Cost per iteration: 1 matvec call = 1 forward + 1 backward

    • Total cost: ~20 × (1 forward + 1 backward) without preconditioning, ~10 × (1 forward + 1 backward) with preconditioning

  8. Compute the step: params_new = params + δ

  9. Unflatten params_new → (sample_new, probe_new)

Phase 4: Evaluate New Point

  10. Compute residuals_new = residual_fn(params_new)

    • Cost: 1 forward pass

  11. Compute new_loss = 0.5 ||residuals_new||²

Phase 5: Trust-Region Decision

  12. Compute the predicted reduction: pred = 0.5 δ^T (J^T J + λI) δ

    • Uses h_delta = matvec(δ), which is already computed

    • Predicted decrease in the quadratic model

  13. Compute the actual reduction: actual = current_loss - new_loss

  14. Compute the reduction ratio: ρ = actual / predicted

    • If predicted ≤ 0: set ρ = 0 (non-descent direction)

  15. Accept/reject decision: accept = (actual > 0) AND (predicted > 0) AND (ρ > 0)

    • Requires all three: loss decreased, descent direction, positive ratio

Phase 6: Damping Adaptation

  16. Update damping λ based on ρ (using nested jax.lax.cond):

    • If predicted ≤ 0: λ ← 10λ (drastic increase, bad step)

    • Else if ρ > 0.75: λ ← λ/3 (excellent step, be aggressive)

    • Else if ρ > 0.25: λ ← λ (acceptable step, maintain)

    • Else: λ ← 3λ (poor step, be conservative)

  17. Clamp damping: λ ∈ [10^{-12}, 10^8]

    • Prevents both over-damping (slow convergence) and under-damping (instability)

Phase 7: State Update

  18. If accepted: use the new sample, probe, and loss

  19. If rejected: keep the current sample, probe, and loss

  20. Compute the relative improvement |current_loss - final_loss| / current_loss

  21. Check convergence: accept AND (rel_improvement < 10^{-8})

  22. Return the updated GaussNewtonState with:

    • sample, probe (updated or kept)

    • iteration counter incremented

    • loss (new or current)

    • damping (adapted)

    • converged flag

Internal Helper Functions

Uses the nested function _preconditioner_matvec(v) when preconditioning:

  • Computes element-wise v / diag_with_damping

  • Is passed as the M parameter to jax.scipy.sparse.linalg.cg

  • Transforms the system (J^T J + λI) δ = -J^T r into M (J^T J + λI) δ = M (-J^T r)

  • Accelerates CG by improving the condition number
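Phases 2 and 3, together with the preconditioner described above, can be sketched with jax.scipy.sparse.linalg.cg, which accepts both the operator and the preconditioner M as functions. This is a hypothetical sketch (gn_direction_sketch is not janssen's code); it takes a precomputed diagonal estimate diag_jtj instead of running Hutchinson's estimator internally.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg


def gn_direction_sketch(residual_fn, params, lam, diag_jtj=None,
                        maxiter=50, tol=1e-5):
    """Solve (J^T J + lam I) delta = -J^T r with Jacobian-free CG."""
    residuals, vjp_fn = jax.vjp(residual_fn, params)
    (jt_r,) = vjp_fn(residuals)  # gradient J^T r

    def matvec(v):  # (J^T J + lam I) v; vjp_fn is reused across CG iterations
        _, jv = jax.jvp(residual_fn, (params,), (v,))
        (jtjv,) = vjp_fn(jv)
        return jtjv + lam * v

    precond = None
    if diag_jtj is not None:
        # Clamp noisy negative estimates, then add damping before inverting.
        d = jnp.maximum(diag_jtj, 0.0) + lam

        def precond(v):
            return v / d

    delta, _ = cg(matvec, -jt_r, tol=tol, maxiter=maxiter, M=precond)
    return delta
```

The second return value of cg is discarded because, as noted below, it carries no convergence information.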

Computational Complexity

Without preconditioning:

  • Setup: 1 forward + 1 backward (gradient)

  • CG solve: ~20 × (1 forward + 1 backward) typical

  • Evaluation: 1 forward (new residuals)

  • Total: ~22 forward + ~21 backward passes

With preconditioning:

  • Setup: 1 forward + 1 backward (gradient)

  • Diagonal estimation: 10 × (1 forward + 1 backward)

  • CG solve: ~10 × (1 forward + 1 backward) typical

  • Evaluation: 1 forward

  • Total: ~22 forward + ~21 backward passes

With these typical iteration counts, preconditioning roughly breaks even (the fewer CG iterations offset the diagonal estimation cost). For ill-conditioned problems, where unpreconditioned CG needs many more iterations, it can provide a 2-5× speedup.

type state:

GaussNewtonState

param state:

Current optimization state containing sample, probe, iteration, loss, damping, and convergence status.

type state:

GaussNewtonState

type residual_fn:

Callable[[Float[Array, 'n']], Float[Array, 'm']]

param residual_fn:

Function mapping flattened parameters to residuals.

type cg_maxiter:

int, default: 50

param cg_maxiter:

Maximum conjugate gradient iterations. Default is 50.

type cg_tol:

float, default: 1e-05

param cg_tol:

CG convergence tolerance. Default is 1e-5.

type use_preconditioner:

Union[bool, Bool[Array, '']], default: False

param use_preconditioner:

Whether to use diagonal preconditioning for CG. Preconditioning can significantly improve convergence rate for ill-conditioned problems. Default is False.

returns:

new_state – Updated optimization state after the step.

rtype:

GaussNewtonState

Notes

This function is generic and works with any residual function. The parameters are flattened from (sample, probe) to a real vector before passing to residual_fn.

Trust-region adaptation, where ρ = actual_reduction / predicted_reduction:

  • If ρ > 0.75: very good step, decrease damping (be aggressive)

  • If 0.25 < ρ ≤ 0.75: acceptable step, keep damping

  • If ρ ≤ 0.25: poor step, increase damping (be conservative)

When use_preconditioner=True, a diagonal preconditioner is computed using Hutchinson’s estimator: M = diag(diag(J^T J) + λ)^{-1}. This improves CG convergence for ill-conditioned problems.
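A minimal sketch of Hutchinson's estimator for diag(J^T J), assuming Rademacher (±1) probe vectors and the estimator diag(A) ≈ mean of z ⊙ (A z), which costs one forward plus one backward pass per sample, consistent with the cost model above. hutchinson_jtj_diag is a hypothetical name; the signature of the library's jtj_diag is not shown here:

```python
import jax
import jax.numpy as jnp


def hutchinson_jtj_diag(residual_fn, params, key, num_samples=10):
    """Hypothetical sketch: estimate diag(J^T J) with random +/-1 probe vectors.

    Uses diag(A) ~= mean_k z_k * (A @ z_k) with A = J^T J, which is unbiased
    because E[z_i z_j] = 1 if i == j and 0 otherwise.
    """
    _, vjp_fn = jax.vjp(residual_fn, params)

    def one_sample(z):
        jz = jax.jvp(residual_fn, (params,), (z,))[1]  # J @ z (forward mode)
        jtjz = vjp_fn(jz)[0]                           # J^T @ (J @ z) (reverse mode)
        return z * jtjz

    # Rademacher probes: each entry is +1 or -1 with probability 1/2.
    signs = jax.random.bernoulli(key, 0.5, (num_samples, params.shape[0]))
    zs = jnp.where(signs, 1.0, -1.0).astype(params.dtype)
    return jnp.mean(jax.vmap(one_sample)(zs), axis=0)
```

The preconditioner would then be M = 1 / (estimate + λ), applied element-wise.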

Convergence detection:

  • If loss < 1e-14: already at optimum, mark converged

  • If residuals contain NaN/Inf: mark converged (failure mode)

  • If relative improvement < 1e-8: normal convergence

CG convergence limitation: jax.scipy.sparse.linalg.cg currently returns None as its convergence info (the second return value). This means we cannot detect when CG fails to converge within cg_maxiter iterations. The only recourse is choosing appropriate cg_maxiter and cg_tol, or enabling preconditioning for ill-conditioned problems. Future JAX versions may provide convergence diagnostics.

Examples

>>> state = make_gauss_newton_state(sample, probe)
>>> new_state = gn_step(state, my_residual_fn)
>>> # With preconditioning for better convergence:
>>> new_state = gn_step(state, my_residual_fn,
...                     use_preconditioner=True)
janssen.utils.hessian_matvec(loss_fn: Callable[[Float[Array, 'n']], Float[Array, '']], params: Float[Array, 'n']) Callable[[Float[Array, 'n']], Float[Array, 'n']][source]

Create exact Hessian-vector product operator.

For problems where the Gauss-Newton approximation H ≈ J^T J is insufficient, this computes products with the exact Hessian:

H @ v = d/dε [∇L(θ + εv)]|_{ε=0}

using forward-over-reverse autodiff (jvp through grad).

Implementation Logic

The function uses forward-over-reverse composition:

  1. Setup Phase (executed once, outside the returned closure): grad_fn = jax.grad(loss_fn)

    • Constructs the gradient function ∇L: R^n → R^n

    • Uses reverse-mode autodiff, so one gradient evaluation costs a small constant multiple of one loss evaluation, independent of n

    • Building grad_fn once and reusing it avoids redundant construction on every product

  2. Hessian-Vector Product (executed per matrix-vector product): the internal function _hvp(v) computes hv = jax.jvp(grad_fn, (params,), (v,))[1]

    • Forward-mode autodiff through the gradient function

    • Computes the directional derivative d/dε [∇L(θ + εv)]|_{ε=0}

    • This equals H @ v by definition of the Hessian
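A minimal sketch mirroring the two phases above; hvp_sketch is a hypothetical name, not the library source. It builds grad_fn once and pushes a tangent v through it with jax.jvp:

```python
import jax
import jax.numpy as jnp


def hvp_sketch(loss_fn, params):
    """Forward-over-reverse Hessian-vector product operator."""
    grad_fn = jax.grad(loss_fn)  # setup phase: reverse-mode gradient, built once

    def _hvp(v):
        # Directional derivative of the gradient along v, which equals H @ v.
        return jax.jvp(grad_fn, (params,), (v,))[1]

    return _hvp
```

For L(x) = Σ x⁴ the Hessian is diag(12 x²), so at x = [1, 2] the product with [1, 0] is [12, 0] and with [0, 1] is [0, 48].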

Mathematical Background

The Hessian H = ∇²L is the Jacobian of the gradient:

H @ v = J_∇L @ v = d/dε [∇L(θ + εv)]|_{ε=0}

Forward-over-reverse (fwd ∘ rev) is efficient for matrix-vector products, with cost roughly 2× a gradient evaluation.

Contrast with the Gauss-Newton approximation J^T J, for the least-squares loss L(θ) = ½‖r(θ)‖²:

  • Exact Hessian: H = J^T J + Σᵢ rᵢ ∇²rᵢ

  • GN approximation: H ≈ J^T J (drops the second term)

  • When the residuals rᵢ are small or nearly linear, J^T J is accurate

  • When the residuals are large and nonlinear, the exact H is needed
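The decomposition H = J^T J + Σᵢ rᵢ ∇²rᵢ can be checked numerically. This sketch uses a hypothetical two-residual problem and the loss L = ½‖r‖²; it compares jax.hessian of the loss with the Gauss-Newton term plus the residual-curvature term:

```python
import jax
import jax.numpy as jnp


def residual_fn(theta):
    # Hypothetical nonlinear residuals r(theta) in R^2.
    return jnp.stack([theta[0] ** 2 - 1.0, jnp.sin(theta[0] * theta[1])])


def loss_fn(theta):
    return 0.5 * jnp.sum(residual_fn(theta) ** 2)


theta = jnp.array([1.5, 0.7])
r = residual_fn(theta)
J = jax.jacfwd(residual_fn)(theta)                   # shape (m, n)
residual_hessians = jax.hessian(residual_fn)(theta)  # shape (m, n, n): ∇²r_i
exact_H = jax.hessian(loss_fn)(theta)
gn_H = J.T @ J                                       # Gauss-Newton approximation
second_term = jnp.einsum("i,ijk->jk", r, residual_hessians)
```

Here r₀ = 1.25 is not small and ∇²r₀ is nonzero, so gn_H alone differs visibly from exact_H, while gn_H + second_term matches it.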

type loss_fn:

Callable[[Float[Array, 'n']], Float[Array, '']]

param loss_fn:

Scalar loss function L(θ).

type params:

Float[Array, 'n']

param params:

Current parameters (linearization point).

returns:

hvp – Function computing H @ v for any vector v.

rtype:

Callable[[Float[Array, " n"]], Float[Array, " n"]]

Notes

The exact Hessian includes second-order derivative information beyond J^T J, capturing curvature from the residual structure itself. This is important when residuals are highly nonlinear.

Cost: ~2× gradient evaluation per HVP. For large n, this is much cheaper than forming H explicitly (which would cost O(n) gradient evaluations and O(n²) storage).

Examples

>>> import jax.numpy as jnp
>>> from janssen.utils import hessian_matvec
>>> def loss_fn(x):
...     return jnp.sum(x**4)
>>> params = jnp.array([1.0, 2.0])
>>> hvp = hessian_matvec(loss_fn, params)
>>> result = hvp(jnp.array([1.0, 0.0]))
janssen.utils.jtj_matvec(residual_fn: Callable[[Float[Array, 'n']], Float[Array, 'm']], params: Float[Array, 'n'], damping: Float[Array, '']) Callable[[Float[Array, 'n']], Float[Array, 'n']][source]

Create matrix-vector product operator for (J^T J + λI).

This is the core of Jacobian-free Gauss-Newton optimization. Rather than forming the Jacobian matrix J explicitly, we compute products with (J^T J) using JAX’s autodiff primitives:

(J^T J) @ v = J^T @ (J @ v) = vjp(jvp(v))

where jvp is forward-mode autodiff and vjp is reverse-mode.

Implementation Logic

The function uses a two-stage construction:

  1. Setup Phase (executed once, outside the returned closure):

    • Computes vjp_fn = jax.vjp(residual_fn, params)[1]

    • This captures the reverse-mode autodiff operator J^T @ (·)

    • Computing vjp_fn once and reusing it in the closure eliminates redundant computation when matvec is called multiple times

    • Critical optimization: without this, each matvec call would recompute vjp, doubling the number of forward passes

  2. Matvec Phase (executed per matrix-vector product): the internal function _matvec(v) computes:

    a. jv = jax.jvp(residual_fn, (params,), (v,))[1]

       • Forward-mode autodiff: directional derivative J @ v

       • Cost: 1 forward pass through residual_fn

    b. jtjv = vjp_fn(jv)[0]

       • Reverse-mode autodiff: J^T @ (J @ v)

       • Cost: 1 backward pass through residual_fn

    c. return jtjv + damping * v

       • Adds the Levenberg-Marquardt regularization λI
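A minimal sketch of the two-phase construction described above; jtj_matvec_sketch is a hypothetical name, not the library source. The vjp closure is built once and reused by every call:

```python
import jax
import jax.numpy as jnp


def jtj_matvec_sketch(residual_fn, params, damping):
    """Return v -> (J^T J + damping * I) @ v without forming J."""
    # Setup phase: linearize once; vjp_fn applies J^T and is reused by every call.
    _, vjp_fn = jax.vjp(residual_fn, params)

    def _matvec(v):
        jv = jax.jvp(residual_fn, (params,), (v,))[1]  # J @ v (forward mode)
        jtjv = vjp_fn(jv)[0]                           # J^T @ (J @ v) (reverse mode)
        return jtjv + damping * v                      # Levenberg-Marquardt term

    return _matvec
```

For the residuals in the Examples below, J = diag(2x₀, 2x₁) = diag(3, 5) at (1.5, 2.5), so the product with [1, 0] is [9 + 10⁻³, 0].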

The key insight: J^T J is never materialized. Each matvec costs only 1 forward + 1 backward pass, whereas an explicit Jacobian needs O(mn) storage, and forming J^T J from it costs O(mn²) flops.

type residual_fn:

Callable[[Float[Array, 'n']], Float[Array, 'm']]

param residual_fn:

Function mapping parameters to residuals r(θ).

type params:

Float[Array, 'n']

param params:

Current parameters (linearization point).

type damping:

Float[Array, '']

param damping:

Levenberg-Marquardt damping parameter λ ≥ 0.

returns:

matvec – Function computing (J^T J + λI) @ v for any vector v.

rtype:

Callable[[Float[Array, " n"]], Float[Array, " n"]]

Notes

Memory complexity is O(m + n) per matrix-vector product, compared to O(m * n) to store J explicitly. For large-scale problems where m, n >> 1, this is essential for tractability.

The vjp_fn is computed once and captured in the closure. This is critical for performance: conjugate gradient typically calls the returned matvec 10-100 times per Gauss-Newton step, and without caching vjp_fn each of those calls would spend one extra forward pass recomputing it.

Examples

>>> import jax.numpy as jnp
>>> from janssen.utils import jtj_matvec
>>> def residual_fn(x):
...     return jnp.array([x[0]**2 - 1, x[1]**2 - 4])
>>> params = jnp.array([1.5, 2.5])
>>> matvec = jtj_matvec(residual_fn, params, jnp.array(1e-3))
>>> result = matvec(jnp.array([1.0, 0.0]))
janssen.utils.flatten_params(sample: Complex[Array, 'Hs Ws'], probe: Complex[Array, 'Hp Wp']) Float[Array, 'n'][source]

Flatten complex arrays to real parameter vector for optimization.

Converts two complex 2D arrays into a single real-valued vector by separating real and imaginary components. Handles arrays with different shapes (e.g., sample=512×512 FOV, probe=64×64 spot).

Parameters:
  • sample (Complex[Array, " Hs Ws"]) – First complex array (e.g., object transmission function).

  • probe (Complex[Array, " Hp Wp"]) – Second complex array (e.g., probe wavefront).

Returns:

params – Flattened real parameter vector with n = 2*(Hs*Ws + Hp*Wp). Layout: [sample_real, sample_imag, probe_real, probe_imag]

Return type:

Float[Array, " n"]

Examples

>>> import jax.numpy as jnp
>>> from janssen.utils import flatten_params
>>> sample = jnp.ones((512, 512), dtype=jnp.complex128)
>>> probe = jnp.ones((64, 64), dtype=jnp.complex128) * (1+1j)
>>> params = flatten_params(sample, probe)
>>> params.shape  # 2*(512*512 + 64*64)
(532480,)

See also

unflatten_params

Inverse operation to reconstruct complex arrays
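A minimal sketch of the documented layout [sample_real, sample_imag, probe_real, probe_imag], assuming each block is flattened in row-major order. flatten_params_sketch is a hypothetical name, not the library source:

```python
import jax.numpy as jnp


def flatten_params_sketch(sample, probe):
    """Pack two complex 2D arrays into one real vector of length 2*(Hs*Ws + Hp*Wp)."""
    return jnp.concatenate([
        jnp.real(sample).ravel(),  # sample_real
        jnp.imag(sample).ravel(),  # sample_imag
        jnp.real(probe).ravel(),   # probe_real
        jnp.imag(probe).ravel(),   # probe_imag
    ])
```

Because the two arrays are raveled independently, sample and probe may have different shapes.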

janssen.utils.fourier_shift(field: Complex[Array, 'H W'], shift_x: float | Float[Array, ''], shift_y: float | Float[Array, '']) Complex[Array, 'H W'][source]

Shift a 2D field using FFT-based sub-pixel shifting.

Applies a phase ramp in Fourier space to shift the field by (shift_x, shift_y) pixels relative to the center of the image. Supports sub-pixel shifts with high accuracy.

Parameters:
  • field (Complex[Array, " H W"]) – Input 2D complex field to shift. The field is assumed to be centered (i.e., the center of the image is the origin).

  • shift_x (float) – Shift in x direction (columns) in pixels. Positive shifts move the field to the right.

  • shift_y (float) – Shift in y direction (rows) in pixels. Positive shifts move the field downward.

Returns:

shifted – Shifted field with same shape as input.

Return type:

Complex[Array, " H W"]

Notes

The shift is implemented by multiplying the Fourier transform of the field by a phase ramp:

\[F_{\text{shifted}}(f_x, f_y) = F(f_x, f_y) \cdot \exp\left(-2\pi i \, (f_x \Delta x + f_y \Delta y)\right)\]

This is equivalent to a circular shift with sub-pixel accuracy: content shifted past one edge wraps around to the opposite edge. Shifts are measured relative to the image center, and a shift of (0, 0) leaves the field unchanged.
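A minimal sketch of the phase-ramp formula, assuming standard unshifted FFT ordering from jnp.fft.fftfreq (frequencies in cycles per pixel), rows as y and columns as x. fourier_shift_sketch is a hypothetical name, and the library's handling of the centered-field convention may differ. For integer shifts the result matches jnp.roll, which makes the circular wrap-around visible:

```python
import jax.numpy as jnp


def fourier_shift_sketch(field, shift_x, shift_y):
    """Shift a 2D complex field by (shift_x, shift_y) pixels via a Fourier phase ramp."""
    h, w = field.shape
    fy = jnp.fft.fftfreq(h)[:, None]  # row frequencies (y), cycles per pixel
    fx = jnp.fft.fftfreq(w)[None, :]  # column frequencies (x), cycles per pixel
    # exp(-2*pi*i*(fx*dx + fy*dy)): positive shifts move content right / down.
    ramp = jnp.exp(-2j * jnp.pi * (fx * shift_x + fy * shift_y))
    return jnp.fft.ifft2(jnp.fft.fft2(field) * ramp)
```

Two half-pixel shifts compose exactly into a one-pixel shift, because the phase ramps multiply.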

Examples

>>> import jax.numpy as jnp
>>> from janssen.utils import fourier_shift
>>> field = jnp.ones((64, 64), dtype=jnp.complex128)
>>> shifted = fourier_shift(field, 10.5, -5.25)  # Sub-pixel shift
janssen.utils.unflatten_params(params: Float[Array, 'n'], sample_shape: tuple[int, int], probe_shape: tuple[int, int]) tuple[Complex[Array, 'Hs Ws'], Complex[Array, 'Hp Wp']][source]

Unflatten real parameter vector back to complex arrays.

Reconstructs two complex 2D arrays from a flattened real parameter vector. Handles arrays with different shapes (e.g., sample=512×512 FOV, probe=64×64 spot).

Parameters:
  • params (Float[Array, " n"]) – Flattened real parameter vector with n = 2*(Hs*Ws + Hp*Wp).

  • sample_shape (Tuple[int, int]) – Shape (Hs, Ws) of the sample array.

  • probe_shape (Tuple[int, int]) – Shape (Hp, Wp) of the probe array.

Return type:

tuple[Complex[Array, 'Hs Ws'], Complex[Array, 'Hp Wp']]

Returns:

  • sample (Complex[Array, " Hs Ws"]) – First complex array, reconstructed from the first 2*Hs*Ws elements.

  • probe (Complex[Array, " Hp Wp"]) – Second complex array, reconstructed from the remaining elements.

Notes

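This function cannot be JIT-compiled with traced shape arguments: sample_shape and probe_shape determine the slice bounds and reshape targets, so they must be static Python tuples. When jitting it, or a function that passes the shapes through to it, mark the shape arguments static, for example via static_argnums.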
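A minimal sketch of the inverse layout compiled with the shape arguments marked static, assuming the [sample_real, sample_imag, probe_real, probe_imag] layout documented for flatten_params. unflatten_params_sketch is a hypothetical name, not the library source:

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnums=(1, 2))  # shapes are hashable tuples, treated as static
def unflatten_params_sketch(params, sample_shape, probe_shape):
    """Split a real vector back into complex sample and probe arrays."""
    hs, ws = sample_shape
    hp, wp = probe_shape
    ns, npr = hs * ws, hp * wp  # Python ints, so slice bounds are static under jit
    sample = params[:ns] + 1j * params[ns:2 * ns]
    probe = params[2 * ns:2 * ns + npr] + 1j * params[2 * ns + npr:2 * (ns + npr)]
    return sample.reshape(hs, ws), probe.reshape(hp, wp)
```

Without static_argnums, jax.jit would trace the shape tuples as arrays, and the slicing and reshape above would fail.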

Examples

>>> import jax.numpy as jnp
>>> from janssen.utils import unflatten_params
>>> params = jnp.arange(532480, dtype=jnp.float64)
>>> sample, probe = unflatten_params(params, (512, 512), (64, 64))
>>> sample.shape, probe.shape
((512, 512), (64, 64))

See also

flatten_params

Forward operation to create flattened vector

janssen.utils.wirtinger_grad(func2diff: Callable[[...], Float[Array, '...']], argnums: int | Sequence[int] | None = 0) Callable[[...], Complex[Array, '...'] | tuple[Complex[Array, '...'], ...]][source]

Compute the Wirtinger gradient of a complex-valued function.

This function returns a new function that computes the Wirtinger gradient of func2diff with respect to the specified argument(s). It is based on the Wirtinger derivative with respect to z = x + iy:

\[\frac{\partial f}{\partial z} = \frac{1}{2} \left( \frac{\partial f}{\partial x} - i \frac{\partial f}{\partial y} \right)\]
Parameters:
  • func2diff (Callable[..., Float[Array, " ..."]]) – A function of complex-valued argument(s) to differentiate. Its output is annotated as real-valued (Float), as for a loss function.

  • argnums (Union[int, Sequence[int]], optional) – Specifies which argument(s) to compute the gradient with respect to. Can be an int or a sequence of ints. Default is 0.

Returns:

grad_f – A function that computes the Wirtinger gradient of f with respect to the specified argument(s). Returns a single array if argnums is an int, or a tuple of arrays if argnums is a sequence.

Return type:

Callable

Notes

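The Wirtinger derivative is essential for optimizing real-valued loss functions with respect to complex-valued parameters. For a real-valued f, the steepest-ascent direction in the complex plane is ∂f/∂x + i ∂f/∂y = 2·conj(∂f/∂z), so the Wirtinger derivative determines the descent direction up to complex conjugation and a factor of 2.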
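A minimal sketch of the formula ∂f/∂z = ½(∂f/∂x − i ∂f/∂y) for a real-valued f of a single complex array, computed by differentiating with respect to the real and imaginary parts separately. wirtinger_grad_sketch is hypothetical and handles only one argument; it is not the library's implementation:

```python
import jax
import jax.numpy as jnp


def wirtinger_grad_sketch(f):
    """Return z -> 0.5 * (df/dx - i df/dy) for real-valued f of one complex array."""
    def grad_fn(z):
        x, y = jnp.real(z), jnp.imag(z)
        g = lambda x_, y_: f(x_ + 1j * y_)  # same function, real inputs
        df_dx, df_dy = jax.grad(g, argnums=(0, 1))(x, y)
        return 0.5 * (df_dx - 1j * df_dy)
    return grad_fn
```

For f(z) = Σ|z|², ∂f/∂x = 2x and ∂f/∂y = 2y, so the sketch returns conj(z), e.g. [1-2j, 3-4j] for the inputs in the Examples below.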

Examples

>>> import jax.numpy as jnp
>>> from janssen.utils import wirtinger_grad
>>> def loss(z):
...     return jnp.sum(jnp.abs(z)**2)
>>> grad_fn = wirtinger_grad(loss)
>>> z = jnp.array([1+2j, 3+4j])
>>> grad_fn(z)  # Returns the Wirtinger gradient