Source code for janssen.utils.gauss_newton

"""General-purpose Gauss-Newton optimization with JAX autodiff.

Extended Summary
----------------
This module provides Jacobian-free Gauss-Newton optimization for solving
nonlinear least-squares problems:

    min_θ  0.5 * ||r(θ)||^2

where r: R^n → R^m is a residual function. The key insight is that JAX's
autodifferentiation primitives (jvp/vjp) enable computing Jacobian-vector
products without ever materializing the Jacobian matrix, making second-order
optimization tractable for large-scale problems.

The Gauss-Newton method approximates the Hessian as H ≈ J^T J and solves:

    (J^T J + λI) δ = -J^T r

where λ is the Levenberg-Marquardt damping parameter. This module provides
all the linear algebra primitives to solve this system efficiently using
conjugate gradient with automatic damping adaptation.

Routine Listings
----------------
jtj_matvec : function
    Create (J^T J + λI) matrix-vector product operator.
jt_residual : function
    Compute residuals and J^T @ r simultaneously.
hessian_matvec : function
    Create exact Hessian-vector product operator.
gn_step : function
    Generic Gauss-Newton step with trust-region damping.
gn_solve : function
    High-level solver that runs Gauss-Newton until convergence.
gn_history : function
    Gauss-Newton solver with per-iteration history tracking.
gn_loss_history : function
    Gauss-Newton solver with per-iteration loss tracking only.
max_eigenval : function
    Estimate largest eigenvalue of J^T J via power iteration.
jtj_diag : function
    Estimate diagonal of J^T J for preconditioning.

Notes
-----
All functions are JAX-compatible (jit/grad/vmap) and follow functional
programming conventions. The optimization state is managed via the
GaussNewtonState PyTree from janssen.types.

This module is domain-agnostic. For specific applications (ptychography,
tomography, etc.), provide an appropriate residual function r(θ).

References
----------
.. [1] Nocedal & Wright, "Numerical Optimization", 2nd ed., Chapter 10
.. [2] Kelley, "Iterative Methods for Optimization", SIAM (1999)
.. [3] Bradbury et al., "JAX: composable transformations of
   Python+NumPy programs" (2018)
"""

from functools import partial

import jax
import jax.numpy as jnp
import jax.scipy.sparse.linalg as sparse_linalg
from beartype import beartype
from beartype.typing import Callable, Tuple, Union
from jaxtyping import Array, Bool, Complex, Float, jaxtyped

from janssen.types import (
    GaussNewtonState,
    ScalarBool,
    ScalarInteger,
    ScalarNumeric,
)

from .math import flatten_params, unflatten_params

TRUST_REGION_EXCELLENT = 0.75
TRUST_REGION_ACCEPTABLE = 0.25
CONVERGENCE_TOL = 1e-8
DIVISION_EPSILON = 1e-12
LOSS_ZERO_TOL = 1e-14
MIN_DAMPING = 1e-12
MAX_DAMPING = 1e8


def jtj_matvec(
    residual_fn: Callable[[Float[Array, " n"]], Float[Array, " m"]],
    params: Float[Array, " n"],
    damping: Float[Array, " "],
) -> Callable[[Float[Array, " n"]], Float[Array, " n"]]:
    """Create matrix-vector product operator for (J^T J + λI).

    This is the core of Jacobian-free Gauss-Newton optimization. Rather
    than forming the Jacobian matrix J explicitly, we compute products
    with (J^T J) using JAX's autodiff primitives:

        (J^T J) @ v = J^T @ (J @ v) = vjp(jvp(v))

    where jvp is forward-mode autodiff and vjp is reverse-mode.

    Implementation Logic
    --------------------
    The function uses a two-stage construction:

    1. **Setup Phase** (executed once, outside returned closure):
       - Computes vjp_fn = jax.vjp(residual_fn, params)[1]
       - This captures the reverse-mode autodiff operator J^T @ (·)
       - Computing vjp_fn once and reusing it in the closure eliminates
         redundant computation when matvec is called multiple times
       - Critical optimization: without this, each matvec call would
         recompute vjp, doubling the number of forward passes

    2. **Matvec Phase** (executed per matrix-vector product):
       Internal function _matvec(v) computes:

       a. jv = jax.jvp(residual_fn, (params,), (v,))[1]
          - Forward-mode autodiff: directional derivative J @ v
          - Cost: 1 forward pass through residual_fn

       b. jtjv = vjp_fn(jv)[0]
          - Reverse-mode autodiff: J^T @ (J @ v)
          - Cost: 1 backward pass through residual_fn

       c. return jtjv + damping * v
          - Adds Levenberg-Marquardt regularization λI

    The key insight: J^T J is never materialized. Each matvec costs only
    1 forward + 1 backward pass, vs O(mn) storage for J and O(mn²) flops
    to form J^T J explicitly.

    Parameters
    ----------
    residual_fn : Callable[[Float[Array, " n"]], Float[Array, " m"]]
        Function mapping parameters to residuals r(θ).
    params : Float[Array, " n"]
        Current parameters (linearization point).
    damping : Float[Array, " "]
        Levenberg-Marquardt damping parameter λ ≥ 0.

    Returns
    -------
    matvec : Callable[[Float[Array, " n"]], Float[Array, " n"]]
        Function computing (J^T J + λI) @ v for any vector v.

    Notes
    -----
    Memory complexity is O(m + n) per matrix-vector product, compared to
    O(m * n) to store J explicitly. For large-scale problems where
    m, n >> 1, this is essential for tractability.

    The vjp_fn is computed once and captured in the closure. This is
    critical for performance: conjugate gradient will call the returned
    matvec 10-100 times per Gauss-Newton step. Without caching vjp_fn,
    every one of those calls would repeat the forward pass inside
    jax.vjp, wasting 10-100 extra forward passes.

    Examples
    --------
    >>> def residual_fn(x):
    ...     return jnp.array([x[0]**2 - 1, x[1]**2 - 4])
    >>> params = jnp.array([1.5, 2.5])
    >>> matvec = jtj_matvec(residual_fn, params, jnp.array(1e-3))
    >>> result = matvec(jnp.array([1.0, 0.0]))
    """
    _: Float[Array, " m"]
    vjp_fn: Callable[[Float[Array, " m"]], Tuple[Float[Array, " n"]]]
    _, vjp_fn = jax.vjp(residual_fn, params)

    def _matvec(v: Float[Array, " n"]) -> Float[Array, " n"]:
        _: Float[Array, " m"]
        jv: Float[Array, " m"]
        _, jv = jax.jvp(residual_fn, (params,), (v,))
        jtjv: Float[Array, " n"]
        (jtjv,) = vjp_fn(jv)
        return jtjv + damping * v

    return _matvec


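# Illustration (a minimal sketch, not part of the module): the block below
# checks the matrix-free product described in jtj_matvec against a dense
# reference on a tiny problem. The names jtj_matvec_sketch and the toy
# residual_fn are hypothetical and exist only for this example.

```python
import jax
import jax.numpy as jnp


def residual_fn(x):
    # Toy residual r: R^2 -> R^3, chosen only for illustration.
    return jnp.array([x[0] ** 2 - 1.0, x[1] ** 2 - 4.0, x[0] * x[1]])


def jtj_matvec_sketch(residual_fn, params, damping):
    # Build the reverse-mode operator J^T @ (.) once and reuse it.
    _, vjp_fn = jax.vjp(residual_fn, params)

    def matvec(v):
        _, jv = jax.jvp(residual_fn, (params,), (v,))  # J @ v
        (jtjv,) = vjp_fn(jv)  # J^T @ (J @ v)
        return jtjv + damping * v

    return matvec


params = jnp.array([1.5, 2.5])
damping = 1e-3
v = jnp.array([1.0, -2.0])

matrix_free = jtj_matvec_sketch(residual_fn, params, damping)(v)

# Dense reference: only feasible because n is tiny here.
jac = jax.jacfwd(residual_fn)(params)
dense = (jac.T @ jac + damping * jnp.eye(2)) @ v
```

# The dense Jacobian is formed here only as a check; for large n the
# matrix-free operator is the one intended for use inside CG.

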
@partial(jax.jit, static_argnums=(0,))
def jt_residual(
    residual_fn: Callable[[Float[Array, " n"]], Float[Array, " m"]],
    params: Float[Array, " n"],
    weights: Union[ScalarNumeric, Float[Array, " m"]] = -1.0,
) -> Tuple[Float[Array, " m"], Float[Array, " n"]]:
    """Compute residuals and J^T @ W @ r simultaneously.

    Uses reverse-mode autodiff to compute the gradient of the loss
    L = 0.5 * r^T W r, which equals J^T @ W @ r. For unweighted least
    squares (W = I), this reduces to J^T @ r.

    Implementation Logic
    --------------------
    The function exploits the chain rule via reverse-mode autodiff:

    1. **Compute residuals and VJP function**:
       residuals, vjp_fn = jax.vjp(residual_fn, params)
       - Evaluates r(θ) at current parameters
       - Constructs vjp_fn: vector → J^T @ vector
       - Single forward pass through residual_fn

    2. **Weight handling** (conditional based on weights parameter):
       - If sum(weights) < 0 (the -1.0 sentinel): set
         weights_normalized = ones(m)
         → Unweighted case W = I
       - Else: broadcast weights to residual shape
         → Weighted case W = diag(weights)
       - Uses jax.lax.cond for JIT compatibility (control flow traced)

    3. **Apply weighting and compute weighted gradient**:
       weighted_residuals = weights_normalized * residuals
       jt_wr = vjp_fn(weighted_residuals)[0]
       - Element-wise multiplication: W @ r for diagonal W
       - vjp_fn(W @ r) = J^T @ W @ r by chain rule
       - Single backward pass through residual_fn

    Total cost: 1 forward + 1 backward pass through residual_fn,
    regardless of whether weighting is used.

    Mathematical Derivation
    -----------------------
    For loss L(θ) = 0.5 * ||W^{1/2} r(θ)||² = 0.5 r(θ)^T W r(θ):

        ∇L = J^T W r

    where J is the Jacobian ∂r/∂θ. The VJP operator computes J^T @ v for
    any vector v. Setting v = W @ r gives J^T @ W @ r directly.

    Parameters
    ----------
    residual_fn : Callable[[Float[Array, " n"]], Float[Array, " m"]]
        Function mapping parameters to residuals.
    params : Float[Array, " n"]
        Current parameters.
    weights : Union[ScalarNumeric, Float[Array, " m"]], optional
        Diagonal weight matrix entries for weighted least squares.
        If -1.0 (default), assumes unweighted (W = I). The sentinel test
        is sum(weights) < 0, so weights are expected to be non-negative.

    Returns
    -------
    residuals : Float[Array, " m"]
        Current residual vector r(θ).
    jt_wr : Float[Array, " n"]
        Weighted gradient J^T @ W @ r = ∇(0.5 * r^T W r).

    Notes
    -----
    For diagonal weighting W = diag(w), the weighted gradient is
    J^T @ W @ r where each residual r_i is weighted by w_i.

    Examples
    --------
    >>> def residual_fn(x):
    ...     return x**2 - jnp.array([1.0, 4.0])
    >>> params = jnp.array([1.5, 2.5])
    >>> r, jt_r = jt_residual(residual_fn, params)
    >>> # Weighted version:
    >>> weights = jnp.array([2.0, 0.5])
    >>> r, jt_wr = jt_residual(residual_fn, params, weights)
    """
    residuals: Float[Array, " m"]
    vjp_fn: Callable[[Float[Array, " m"]], Tuple[Float[Array, " n"]]]
    residuals, vjp_fn = jax.vjp(residual_fn, params)
    weights_arr: Float[Array, "..."] = jnp.asarray(weights)
    weights_normalized: Float[Array, " m"]
    weights_normalized = jax.lax.cond(
        jnp.sum(weights_arr) < 0,
        lambda: jnp.ones_like(residuals),
        lambda: jnp.broadcast_to(weights_arr, residuals.shape),
    )
    weighted_residuals: Float[Array, " m"] = weights_normalized * residuals
    jt_wr: Float[Array, " n"]
    (jt_wr,) = vjp_fn(weighted_residuals)
    return residuals, jt_wr


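# Illustration (a sketch under stated assumptions, not part of the module):
# the identity ∇(0.5 r^T W r) = J^T W r used by jt_residual can be checked
# on the toy residual from the docstring example. weighted_loss is a
# hypothetical helper introduced only for the comparison.

```python
import jax
import jax.numpy as jnp


def residual_fn(x):
    return x ** 2 - jnp.array([1.0, 4.0])


params = jnp.array([1.5, 2.5])
weights = jnp.array([2.0, 0.5])

# One forward pass gives r and the operator v -> J^T @ v.
residuals, vjp_fn = jax.vjp(residual_fn, params)
# One backward pass applied to W @ r gives J^T @ W @ r.
(jt_wr,) = vjp_fn(weights * residuals)


def weighted_loss(x):
    r = residual_fn(x)
    return 0.5 * jnp.sum(weights * r ** 2)


# Reference: gradient of the scalar weighted loss.
grad_ref = jax.grad(weighted_loss)(params)
```

# Here J = diag(3, 5) and W r = [2.5, 1.125], so both routes give
# [7.5, 5.625].

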
def hessian_matvec(
    loss_fn: Callable[[Float[Array, " n"]], Float[Array, " "]],
    params: Float[Array, " n"],
) -> Callable[[Float[Array, " n"]], Float[Array, " n"]]:
    """Create exact Hessian-vector product operator.

    For problems where the Gauss-Newton approximation H ≈ J^T J is
    insufficient, this computes products with the exact Hessian:

        H @ v = d/dε [∇L(θ + εv)]|_{ε=0}

    using forward-over-reverse autodiff (jvp through grad).

    Implementation Logic
    --------------------
    The function uses forward-over-reverse composition:

    1. **Setup Phase** (executed once, outside returned closure):
       grad_fn = jax.grad(loss_fn)
       - Constructs the gradient function ∇L: R^n → R^n
       - This is reverse-mode autodiff: each gradient evaluation costs a
         small constant multiple of one loss evaluation
       - Computing once and reusing avoids redundant graph construction

    2. **Hessian-Vector Product** (executed per matrix-vector product):
       Internal function _hvp(v) computes:
       hv = jax.jvp(grad_fn, (params,), (v,))[1]
       - Forward-mode autodiff through the gradient function
       - Computes directional derivative: d/dε [∇L(θ + εv)]|_{ε=0}
       - This equals H @ v by definition of the Hessian

    Mathematical Background
    -----------------------
    The Hessian H = ∇²L is the Jacobian of the gradient:

        H @ v = J_∇L @ v = d/dε [∇L(θ + εv)]|_{ε=0}

    Forward-over-reverse (fwd ∘ rev) is efficient for matrix-vector
    products, with cost roughly 2× a gradient evaluation.

    Contrast with Gauss-Newton approximation J^T J:
    - Exact Hessian: H = J^T J + Σᵢ rᵢ ∇²rᵢ
    - GN approximation: H ≈ J^T J (drops second term)
    - When residuals rᵢ are small or linear, J^T J is accurate
    - When residuals are large and nonlinear, exact H is needed

    Parameters
    ----------
    loss_fn : Callable[[Float[Array, " n"]], Float[Array, " "]]
        Scalar loss function L(θ).
    params : Float[Array, " n"]
        Current parameters (linearization point).

    Returns
    -------
    hvp : Callable[[Float[Array, " n"]], Float[Array, " n"]]
        Function computing H @ v for any vector v.

    Notes
    -----
    The exact Hessian includes second-order derivative information
    beyond J^T J, capturing curvature from the residual structure
    itself. This is important when residuals are highly nonlinear.

    Cost: ~2× gradient evaluation per HVP. For large n, this is much
    cheaper than forming H explicitly (which would cost O(n) gradient
    evaluations and O(n²) storage).

    Examples
    --------
    >>> def loss_fn(x):
    ...     return jnp.sum(x**4)
    >>> params = jnp.array([1.0, 2.0])
    >>> hvp = hessian_matvec(loss_fn, params)
    >>> result = hvp(jnp.array([1.0, 0.0]))
    """
    grad_fn: Callable[[Float[Array, " n"]], Float[Array, " n"]]
    grad_fn = jax.grad(loss_fn)

    def _hvp(v: Float[Array, " n"]) -> Float[Array, " n"]:
        _: Float[Array, " n"]
        hv: Float[Array, " n"]
        _, hv = jax.jvp(grad_fn, (params,), (v,))
        return hv

    return _hvp


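# Illustration (a minimal sketch, not part of the module): the
# forward-over-reverse product from hessian_matvec, compared with a dense
# Hessian on the docstring's toy loss. The dense reference via jax.hessian
# is an assumption of this example and is feasible only for tiny n.

```python
import jax
import jax.numpy as jnp


def loss_fn(x):
    return jnp.sum(x ** 4)


params = jnp.array([1.0, 2.0])
v = jnp.array([1.0, 0.0])

# Forward-over-reverse: differentiate the gradient along direction v.
_, hv = jax.jvp(jax.grad(loss_fn), (params,), (v,))

# Dense reference: H = diag(12 * x**2) for this loss.
hv_ref = jax.hessian(loss_fn)(params) @ v
```

# For this loss H = diag(12, 48) at params, so H @ v = [12, 0].

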
@partial(jax.jit, static_argnums=(1, 2, 3))
@jaxtyped(typechecker=beartype)
def gn_step(
    state: GaussNewtonState,
    residual_fn: Callable[[Float[Array, " n"]], Float[Array, " m"]],
    cg_maxiter: int = 50,
    cg_tol: float = 1e-5,
    use_preconditioner: ScalarBool = False,
) -> GaussNewtonState:
    """Perform Gauss-Newton step with Levenberg-Marquardt damping.

    Solves the trust-region subproblem using conjugate gradient:

        (J^T J + λI) δ = -J^T r

    then updates parameters:

        θ_{k+1} = θ_k + δ

    The damping λ adapts based on actual vs predicted reduction.

    Implementation Logic
    --------------------
    The function executes a complete trust-region Gauss-Newton step:

    **Phase 1: Setup and Gradient Computation**

    1. Extract sample and probe shapes from state
    2. Flatten complex arrays (sample, probe) → real vector params
    3. Compute residuals r and gradient J^T @ r via jt_residual
       - Cost: 1 forward + 1 backward pass through residual_fn
    4. Compute current loss: 0.5 ||r||²

    **Phase 2: Linear System Construction**

    5. Build matrix-vector operator: matvec = jtj_matvec(...)
       - Computes (J^T J + λI) @ v using JVP/VJP composition
       - Caches vjp_fn to avoid redundant computation in CG
    6. If use_preconditioner=True:
       a. Estimate diagonal via Hutchinson: diag_jtj ≈ diag(J^T J)
          - Uses 10 Rademacher samples (random ±1 vectors)
          - Cost: 10 × (1 forward + 1 backward pass)
       b. Clamp and damp: diag_with_damping = max(diag_jtj, 0) + λ
          - Ensures positive definite preconditioner
          - Negative estimates can occur due to finite sampling
       c. Construct preconditioner: M(v) = v / diag_with_damping
          - Diagonal scaling: approximates (diag(H))^{-1}

    **Phase 3: Solve Trust-Region Subproblem**

    7. Solve (J^T J + λI) δ = -J^T r via conjugate gradient
       - Inputs: matvec operator, RHS = -J^T r, preconditioner M
       - Iterative solver: converges in typically 10-50 iterations
       - Cost per iteration: 1 matvec call = 1 forward + 1 backward
       - Total cost: ~20 × (1 forward + 1 backward) without precond
                     ~10 × (1 forward + 1 backward) with precond
    8. Compute step: params_new = params + δ
    9. Unflatten params_new → (sample_new, probe_new)

    **Phase 4: Evaluate New Point**

    10. Compute residuals_new = residual_fn(params_new)
        - Cost: 1 forward pass
    11. Compute new_loss = 0.5 ||residuals_new||²

    **Phase 5: Trust-Region Decision**

    12. Compute predicted reduction: pred = 0.5 δ^T (J^T J + λI) δ
        - Uses h_delta = matvec(δ), one extra matvec after the CG solve
        - Predicted decrease in quadratic model
    13. Compute actual reduction: actual = current_loss - new_loss
    14. Compute reduction ratio: ρ = actual / predicted
        - If predicted ≤ 0: set ρ = 0 (non-descent direction)
    15. Accept/reject decision:
        accept = (actual > 0) AND (predicted > 0) AND (ρ > 0)
        - Requires all three: loss decreased, descent direction,
          positive ratio

    **Phase 6: Damping Adaptation**

    16. Update damping λ based on ρ (using nested jax.lax.cond):
        - If predicted ≤ 0: λ ← 10λ (drastic increase, bad step)
        - Else if ρ > 0.75: λ ← 0.33λ (excellent step, be aggressive)
        - Else if ρ > 0.25: λ ← λ (acceptable step, maintain)
        - Else: λ ← 3λ (poor step, be conservative)
    17. Clamp damping: λ ∈ [10^{-12}, 10^8]
        - Prevents both over-damping (slow convergence) and
          under-damping (instability)

    **Phase 7: State Update**

    18. If accepted:
        - Use new sample, probe, loss
    19. If rejected:
        - Keep current sample, probe, loss
    20. Compute relative improvement
        |current_loss - final_loss| / current_loss
    21. Check convergence: accept AND (rel_improvement < 10^{-8})
    22. Return updated GaussNewtonState with:
        - sample, probe (updated or kept)
        - iteration counter incremented
        - loss (new or current)
        - damping (adapted)
        - converged flag

    Internal Helper Functions
    -------------------------
    Uses nested function _preconditioner(v) when preconditioning:
    - Computes element-wise: v / (diag_with_damping)
    - Passed as M parameter to jax.scipy.sparse.linalg.cg
    - Transforms system from (J^T J + λI) δ = -J^T r
      to M (J^T J + λI) δ = M (-J^T r)
    - Accelerates CG by improving condition number

    Computational Complexity
    ------------------------
    Without preconditioning:
    - Setup: 1 forward + 1 backward (gradient)
    - CG solve: ~20 × (1 forward + 1 backward) typical
    - Evaluation: 1 forward (new residuals)
    - Total: ~22 forward + ~21 backward passes

    With preconditioning:
    - Setup: 1 forward + 1 backward (gradient)
    - Diagonal estimation: 10 × (1 forward + 1 backward)
    - CG solve: ~10 × (1 forward + 1 backward) typical
    - Evaluation: 1 forward
    - Total: ~22 forward + ~21 backward passes

    Preconditioning breaks even immediately (fewer CG iterations offset
    the diagonal estimation cost), and provides 2-5× speedup for
    ill-conditioned problems.

    Parameters
    ----------
    state : GaussNewtonState
        Current optimization state containing sample, probe, iteration,
        loss, damping, and convergence status.
    residual_fn : Callable[[Float[Array, " n"]], Float[Array, " m"]]
        Function mapping flattened parameters to residuals.
    cg_maxiter : int, optional
        Maximum conjugate gradient iterations. Default is 50.
    cg_tol : float, optional
        CG convergence tolerance. Default is 1e-5.
    use_preconditioner : bool, optional
        Whether to use diagonal preconditioning for CG. Preconditioning
        can significantly improve convergence rate for ill-conditioned
        problems. Default is False.

    Returns
    -------
    new_state : GaussNewtonState
        Updated optimization state after the step.

    Notes
    -----
    This function is generic and works with any residual function.
    The parameters are flattened from (sample, probe) to a real vector
    before passing to residual_fn.

    Trust-region adaptation:
    - If ρ > 0.75: very good step, decrease damping (be aggressive)
    - If 0.25 < ρ ≤ 0.75: acceptable step, keep damping
    - If ρ ≤ 0.25: poor step, increase damping (be conservative)

    where ρ = actual_reduction / predicted_reduction.

    When use_preconditioner=True, a diagonal preconditioner is computed
    using Hutchinson's estimator: M = diag(diag(J^T J) + λ)^{-1}. This
    improves CG convergence for ill-conditioned problems.

    Convergence detection:
    - If loss < 1e-14: already at optimum, mark converged
    - If residuals contain NaN/Inf: mark converged (failure mode)
    - If relative improvement < 1e-8: normal convergence

    CG convergence limitation: JAX's sparse_linalg.cg currently returns
    None for the convergence info (second return value). This means we
    cannot detect when CG fails to converge within cg_maxiter
    iterations. The only recourse is choosing appropriate cg_maxiter and
    cg_tol, or enabling preconditioning for ill-conditioned problems.
    Future JAX versions may provide convergence diagnostics.

    Examples
    --------
    >>> state = make_gauss_newton_state(sample, probe)
    >>> new_state = gn_step(state, my_residual_fn)
    >>> # With preconditioning for better convergence:
    >>> new_state = gn_step(state, my_residual_fn,
    ...                     use_preconditioner=True)
    """
    sample_shape: Tuple[int, int] = state.sample.shape
    probe_shape: Tuple[int, int] = state.probe.shape
    params: Float[Array, " n"] = flatten_params(state.sample, state.probe)
    residuals: Float[Array, " m"]
    jt_r: Float[Array, " n"]
    residuals, jt_r = jt_residual(residual_fn, params)
    residuals_finite: Bool[Array, " "] = jnp.all(jnp.isfinite(residuals))
    residuals = jnp.where(
        residuals_finite, residuals, jnp.zeros_like(residuals)
    )
    current_loss: Float[Array, " "] = 0.5 * jnp.sum(residuals**2)
    loss_near_zero: Bool[Array, " "] = current_loss < LOSS_ZERO_TOL
    matvec: Callable[[Float[Array, " n"]], Float[Array, " n"]] = (
        jtj_matvec(residual_fn, params, state.damping)
    )
    diag_jtj: Float[Array, " n"] = jax.lax.cond(
        use_preconditioner,
        lambda: jtj_diag(residual_fn, params, num_samples=10),
        lambda: jnp.ones_like(params),
    )
    diag_with_damping: Float[Array, " n"] = (
        jnp.maximum(diag_jtj, 0.0) + state.damping
    )

    def _preconditioner(v: Float[Array, " n"]) -> Float[Array, " n"]:
        return jax.lax.cond(
            use_preconditioner,
            lambda: v / diag_with_damping,
            lambda: v,
        )

    delta: Float[Array, " n"]
    cg_info: None
    delta, cg_info = sparse_linalg.cg(
        matvec,
        -jt_r,
        x0=jnp.zeros_like(params),
        maxiter=cg_maxiter,
        tol=cg_tol,
        M=_preconditioner,
    )
    params_new: Float[Array, " n"] = params + delta
    sample_new: Complex[Array, " Hs Ws"]
    probe_new: Complex[Array, " Hp Wp"]
    sample_new, probe_new = unflatten_params(
        params_new, sample_shape, probe_shape
    )
    residuals_new: Float[Array, " m"] = residual_fn(params_new)
    new_loss: Float[Array, " "] = 0.5 * jnp.sum(residuals_new**2)
    h_delta: Float[Array, " n"] = matvec(delta)
    predicted_reduction: Float[Array, " "] = 0.5 * jnp.dot(delta, h_delta)
    actual_reduction: Float[Array, " "] = current_loss - new_loss
    pred_positive: Bool[Array, " "] = predicted_reduction > 0.0
    rho: Float[Array, " "] = jnp.where(
        pred_positive,
        actual_reduction / (predicted_reduction + DIVISION_EPSILON),
        0.0,
    )
    accept: Bool[Array, " "] = (
        (actual_reduction > 0.0) & pred_positive & (rho > 0.0)
    )
    new_damping: Float[Array, " "] = jax.lax.cond(
        pred_positive,
        lambda: jax.lax.cond(
            rho > TRUST_REGION_EXCELLENT,
            lambda: state.damping * 0.33,
            lambda: jax.lax.cond(
                rho > TRUST_REGION_ACCEPTABLE,
                lambda: state.damping,
                lambda: state.damping * 3.0,
            ),
        ),
        lambda: state.damping * 10.0,
    )
    new_damping = jnp.clip(new_damping, MIN_DAMPING, MAX_DAMPING)
    final_sample: Complex[Array, " hh ww"] = jax.lax.cond(
        accept, lambda: sample_new, lambda: state.sample
    )
    final_probe: Complex[Array, " hh ww"] = jax.lax.cond(
        accept, lambda: probe_new, lambda: state.probe
    )
    final_loss: Float[Array, " "] = jax.lax.cond(
        accept, lambda: new_loss, lambda: current_loss
    )
    rel_improvement: Float[Array, " "] = jnp.abs(current_loss - final_loss) / (
        current_loss + DIVISION_EPSILON
    )
    step_converged: Bool[Array, " "] = accept & (
        rel_improvement < CONVERGENCE_TOL
    )
    converged: Bool[Array, " "] = (
        loss_near_zero | step_converged | (~residuals_finite)
    )
    return GaussNewtonState(
        sample=final_sample,
        probe=final_probe,
        iteration=state.iteration + 1,
        loss=final_loss,
        damping=new_damping,
        converged=converged,
    )


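# Illustration (a plain-Python sketch, not part of the module): the damping
# rule that gn_step expresses with nested jax.lax.cond, written with
# ordinary branches so the four cases are easy to read. update_damping is a
# hypothetical helper; the thresholds and factors mirror the module
# constants and the code above.

```python
TRUST_REGION_EXCELLENT = 0.75
TRUST_REGION_ACCEPTABLE = 0.25
MIN_DAMPING = 1e-12
MAX_DAMPING = 1e8


def update_damping(damping, rho, pred_positive):
    """Return the next damping value for a given reduction ratio rho."""
    if not pred_positive:
        new = damping * 10.0  # model predicted no decrease: back off hard
    elif rho > TRUST_REGION_EXCELLENT:
        new = damping * 0.33  # model was accurate: trust it more
    elif rho > TRUST_REGION_ACCEPTABLE:
        new = damping  # acceptable agreement: keep damping
    else:
        new = damping * 3.0  # poor agreement: be more conservative
    # Clamp to the same bounds the module uses.
    return min(max(new, MIN_DAMPING), MAX_DAMPING)


excellent = update_damping(1.0, 0.9, True)
acceptable = update_damping(1.0, 0.5, True)
poor = update_damping(1.0, 0.1, True)
bad_model = update_damping(1.0, 0.0, False)
clamped = update_damping(MAX_DAMPING, 0.1, True)
```

# Inside a jitted function the same logic has to use jax.lax.cond or
# jnp.where, because rho is a traced value there.

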
@partial(jax.jit, static_argnums=(1, 2, 3, 4, 5))
@jaxtyped(typechecker=beartype)
def gn_solve(
    state: GaussNewtonState,
    residual_fn: Callable[[Float[Array, " n"]], Float[Array, " m"]],
    max_iterations: int = 100,
    cg_maxiter: int = 50,
    cg_tol: float = 1e-5,
    use_preconditioner: bool = False,
) -> GaussNewtonState:
    """Run Gauss-Newton optimization until convergence or max iterations.

    High-level solver that repeatedly calls gn_step until the
    optimization converges or reaches the maximum iteration limit. This
    provides a single entry point for running the full optimization.

    Implementation Logic
    --------------------
    The function uses jax.lax.scan for efficient iteration:

    1. **Internal Step Function** step_fn(carry, _):
       Wraps gn_step with early stopping logic:
       a. Check if carry.converged is True
       b. If converged: return carry unchanged (no-op)
       c. If not converged: call gn_step(carry, ...)
       d. Return updated state

       Uses jax.lax.cond for control flow:
       - JIT-compatible conditional execution
       - When converged=True, subsequent iterations are no-ops
       - Ensures fixed iteration count for XLA compilation

    2. **Main Loop**:
       jax.lax.scan(step_fn, state, None, length=max_iterations)
       - Iterates exactly max_iterations times (fixed trip count)
       - Carries GaussNewtonState through iterations
       - Stops updating when converged=True (via lax.cond guard)
       - Returns final_state after max_iterations or convergence

    Why scan instead of while_loop?
    -------------------------------
    JAX provides two iteration primitives:
    - jax.lax.while_loop: Stops when condition is False
      → Data-dependent iteration count
      → JIT-compatible, but not reverse-mode differentiable and unable
        to stack per-iteration outputs
    - jax.lax.scan: Runs fixed number of iterations
      → Deterministic iteration count
      → Efficient JIT compilation
      → Early stopping via lax.cond guard inside body

    We use scan with internal lax.cond guard to get both:
    - Deterministic compilation (fixed max_iterations)
    - Early stopping (when converged=True, step_fn returns carry)

    Trade-off: If convergence happens at iteration 20, iterations 21-100
    are no-ops (just carry forwarding). Cost is negligible compared to
    actual Gauss-Newton steps.

    Computational Cost
    ------------------
    If convergence happens at iteration k < max_iterations:
    - Active iterations: k × cost(gn_step)
    - No-op iterations: (max_iterations - k) × cost(lax.cond)
    - No-op cost is ~microseconds, negligible vs GN step (~seconds)

    Total cost ≈ k × [20-40 forward/backward passes per step]

    For 100 max iterations with typical convergence at k=30:
    - Active cost: 30 × 40 = 1200 forward+backward passes
    - No-op cost: 70 × cost(lax.cond) ≈ 0

    Parameters
    ----------
    state : GaussNewtonState
        Initial optimization state containing sample, probe, iteration,
        loss, damping, and convergence status.
    residual_fn : Callable[[Float[Array, " n"]], Float[Array, " m"]]
        Function mapping flattened parameters to residuals.
    max_iterations : int, optional
        Maximum number of Gauss-Newton iterations. Default is 100.
    cg_maxiter : int, optional
        Maximum conjugate gradient iterations per step. Default is 50.
    cg_tol : float, optional
        CG convergence tolerance. Default is 1e-5.
    use_preconditioner : bool, optional
        Whether to use diagonal preconditioning for CG. Default is False.

    Returns
    -------
    final_state : GaussNewtonState
        Final optimization state after convergence or max iterations.

    Notes
    -----
    The solver stops when either:
    - The state's converged flag becomes True
    - The iteration count reaches max_iterations

    For monitoring progress during optimization, use gn_history or
    gn_loss_history, which record the per-iteration state or loss.

    This function is fully JIT-compatible.

    Examples
    --------
    >>> state = make_gauss_newton_state(sample, probe)
    >>> final_state = gn_solve(state, my_residual_fn,
    ...                        max_iterations=50)
    >>> print(f"Converged: {final_state.converged}")
    >>> print(f"Final loss: {final_state.loss}")
    """

    def step_fn(
        carry: GaussNewtonState, _: None
    ) -> Tuple[GaussNewtonState, None]:
        result: GaussNewtonState = jax.lax.cond(
            carry.converged,
            lambda: carry,
            lambda: gn_step(
                carry,
                residual_fn,
                cg_maxiter=cg_maxiter,
                cg_tol=cg_tol,
                use_preconditioner=use_preconditioner,
            ),
        )
        return result, None

    final_state: GaussNewtonState
    _: None
    final_state, _ = jax.lax.scan(step_fn, state, None, length=max_iterations)
    return final_state


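# Illustration (a toy sketch, not part of the module): the
# scan-with-guard pattern used by gn_solve, with a scalar halving update
# standing in for gn_step. The carry layout (x, converged) and the 1e-2
# threshold are assumptions made only for this example.

```python
import jax
import jax.numpy as jnp


def step(carry, _):
    x, converged = carry

    def do_update():
        x_new = 0.5 * x  # stand-in for one Gauss-Newton step
        return x_new, x_new < 1e-2

    # Once converged, forward the carry unchanged (no-op iteration).
    new_carry = jax.lax.cond(converged, lambda: (x, converged), do_update)
    return new_carry, None


init = (jnp.array(1.0), jnp.array(False))
(x_final, converged), _ = jax.lax.scan(step, init, None, length=20)
```

# The loop always runs 20 iterations, but x stops changing after the 7th,
# when 2**-7 first drops below the threshold.

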
@partial(jax.jit, static_argnums=(1, 2, 3, 4, 5))
@jaxtyped(typechecker=beartype)
def gn_history(
    state: GaussNewtonState,
    residual_fn: Callable[[Float[Array, " n"]], Float[Array, " m"]],
    max_iterations: int = 100,
    cg_maxiter: int = 50,
    cg_tol: float = 1e-5,
    use_preconditioner: bool = False,
) -> Tuple[GaussNewtonState, GaussNewtonState, Float[Array, " N"]]:
    """Run Gauss-Newton optimization with per-iteration history tracking.

    Like gn_solve but returns intermediate states and losses at each
    iteration, enabling convergence diagnostics and visualization.

    Implementation Logic
    --------------------
    Same iteration strategy as gn_solve using jax.lax.scan, but scan
    outputs capture per-iteration states:

    1. **Internal Step Function** step_fn(carry, _):
       a. Check if carry.converged is True
       b. If converged: return carry unchanged (no-op)
       c. If not converged: call gn_step(carry, ...)
       d. Return updated state and output tuple (state, loss)

    2. **Main Loop**:
       jax.lax.scan(step_fn, state, None, length=max_iterations)
       - Returns both final_state and outputs PyTree
       - Outputs contain all intermediate states and losses

    The scan outputs are PyTrees where each leaf has an additional
    leading dimension of size max_iterations, because jax.lax.scan
    stacks outputs along axis 0. For example, if state.sample has shape
    (H, W), then all_states.sample has shape (N, H, W) where
    N = max_iterations.

    Parameters
    ----------
    state : GaussNewtonState
        Initial optimization state containing sample, probe, iteration,
        loss, damping, and convergence status.
    residual_fn : Callable[[Float[Array, " n"]], Float[Array, " m"]]
        Function mapping flattened parameters to residuals.
    max_iterations : int, optional
        Maximum number of Gauss-Newton iterations. Default is 100.
    cg_maxiter : int, optional
        Maximum conjugate gradient iterations per step. Default is 50.
    cg_tol : float, optional
        CG convergence tolerance. Default is 1e-5.
    use_preconditioner : bool, optional
        Whether to use diagonal preconditioning for CG. Default is False.

    Returns
    -------
    final_state : GaussNewtonState
        Final optimization state after convergence or max iterations.
    all_states : GaussNewtonState
        PyTree with all intermediate states. Each field has an extra
        dimension of size max_iterations stacked along the leading axis.
        For example, if state.sample is (H, W), all_states.sample is
        (max_iterations, H, W).
    all_losses : Float[Array, " N"]
        Loss value at each iteration, shape (max_iterations,).

    Notes
    -----
    This function is useful for:
    - Convergence diagnostics and plotting
    - Creating visualizations of optimization progress
    - Debugging optimization issues
    - Resume/continuation with full history tracking

    If you only need the final state (not intermediate history), use
    gn_solve instead; storing the history costs max_iterations copies of
    the state.

    Examples
    --------
    >>> state = make_gauss_newton_state(sample, probe)
    >>> final, history, losses = gn_history(
    ...     state, my_residual_fn, max_iterations=50
    ... )
    >>> print(f"Converged: {final.converged}")
    >>> print(f"Loss history shape: {losses.shape}")  # (50,)
    >>> print(f"Sample history shape: {history.sample.shape}")  # (50, H, W)
    """

    def step_fn(
        carry: GaussNewtonState, _: None
    ) -> Tuple[GaussNewtonState, Tuple[GaussNewtonState, Float[Array, " "]]]:
        result: GaussNewtonState = jax.lax.cond(
            carry.converged,
            lambda: carry,
            lambda: gn_step(
                carry,
                residual_fn,
                cg_maxiter=cg_maxiter,
                cg_tol=cg_tol,
                use_preconditioner=use_preconditioner,
            ),
        )
        return result, (result, result.loss)

    final_state: GaussNewtonState
    outputs: Tuple[GaussNewtonState, Float[Array, " N"]]
    final_state, outputs = jax.lax.scan(
        step_fn, state, None, length=max_iterations
    )
    all_states: GaussNewtonState
    all_losses: Float[Array, " N"]
    all_states, all_losses = outputs
    return final_state, all_states, all_losses


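# Illustration (a toy sketch, not part of the module): how jax.lax.scan
# stacks per-iteration outputs, which determines the shape of the history
# returned by gn_history. The (2, 3) carry stands in for a state leaf such
# as sample; the step function is hypothetical.

```python
import jax
import jax.numpy as jnp


def step(carry, _):
    new = carry + 1.0
    # Second tuple element is the per-iteration output collected by scan.
    return new, (new, jnp.sum(new))


init = jnp.zeros((2, 3))
final, (history, sums) = jax.lax.scan(step, init, None, length=5)

history_shape = history.shape  # iterations are stacked on axis 0
sums_shape = sums.shape
```

# history[k] is the carry after iteration k + 1, so history[-1] equals
# final.

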
@partial(jax.jit, static_argnums=(1, 2, 3, 4, 5))
@jaxtyped(typechecker=beartype)
def gn_loss_history(
    state: GaussNewtonState,
    residual_fn: Callable[[Float[Array, " n"]], Float[Array, " m"]],
    max_iterations: int = 100,
    cg_maxiter: int = 50,
    cg_tol: float = 1e-5,
    use_preconditioner: bool = False,
) -> Tuple[GaussNewtonState, Float[Array, " N"]]:
    """Run Gauss-Newton optimization and return per-iteration losses.

    This variant tracks only the scalar loss per iteration and not the
    full state history, which keeps memory overhead low for large
    ptychography problems.
    """

    def step_fn(
        carry: GaussNewtonState, _: None
    ) -> Tuple[GaussNewtonState, Float[Array, " "]]:
        result: GaussNewtonState = jax.lax.cond(
            carry.converged,
            lambda: carry,
            lambda: gn_step(
                carry,
                residual_fn,
                cg_maxiter=cg_maxiter,
                cg_tol=cg_tol,
                use_preconditioner=use_preconditioner,
            ),
        )
        return result, result.loss

    final_state: GaussNewtonState
    all_losses: Float[Array, " N"]
    final_state, all_losses = jax.lax.scan(
        step_fn, state, None, length=max_iterations
    )
    return final_state, all_losses


@partial(jax.jit, static_argnums=(0, 2))
@jaxtyped(typechecker=beartype)
def max_eigenval(
    residual_fn: Callable[[Float[Array, " n"]], Float[Array, " m"]],
    params: Float[Array, " n"],
    num_iterations: int = 20,
) -> Float[Array, " "]:
    """Estimate largest eigenvalue of J^T J via power iteration.

    A large eigenvalue relative to the smallest one indicates
    ill-conditioning, which can slow or stall convergence. Use this
    diagnostic to tune damping.

    Implementation Logic
    --------------------
    The function implements the power iteration algorithm:

    1. **Setup**:
       - Construct matvec operator: (J^T J) @ v via jtj_matvec
       - Initialize random unit vector v₀ ~ N(0, I), normalized
       - Uses fixed seed (42) for reproducibility

    2. **Power Iteration Loop** (num_iterations times):
       Internal function power_step(v_curr) performs:
       a. Compute Av = (J^T J) @ v_curr via matvec
          - Cost: 1 forward + 1 backward pass through residual_fn
       b. Normalize: v_next = Av / ||Av||
          - Ensures numerical stability, prevents overflow
       c. Return v_next (carry state for next iteration)

       Uses jax.lax.scan for efficient iteration:
       - Traced once and compiled as a single XLA loop
       - Avoids Python loop overhead
       - scan(power_step, v₀, None, length=num_iterations)

    3. **Eigenvalue Extraction**:
       - Compute final Rayleigh quotient: λ = v^T (J^T J) v
       - After k iterations, vₖ ≈ dominant eigenvector
       - Rayleigh quotient converges to λ_max(J^T J)

    Algorithm: Power Iteration
    --------------------------
    For symmetric positive semi-definite A = J^T J:

        vₖ₊₁ = A vₖ / ||A vₖ||

    Converges to dominant eigenvector v₁ at rate (λ₂/λ₁)^k where
    λ₁ ≥ λ₂ ≥ ... are eigenvalues.

    Convergence is geometric. With k=20 iterations and typical spectral
    gaps, relative error is usually < 1%. For problems with a small
    spectral gap (λ₁ ≈ λ₂), increase num_iterations to 50-100.

    Computational Cost
    ------------------
    - Total: num_iterations × (1 forward + 1 backward pass)
    - For num_iterations=20: 20 forward + 20 backward passes
    - Plus 1 final matvec for Rayleigh quotient
    - Memory: O(n) for storing vₖ

    Parameters
    ----------
    residual_fn : Callable[[Float[Array, " n"]], Float[Array, " m"]]
        Residual function.
    params : Float[Array, " n"]
        Current parameters.
    num_iterations : int, optional
        Number of power iterations. Default is 20.

    Returns
    -------
    lambda_max : Float[Array, " "]
        Estimate of the largest eigenvalue of J^T J.

    Notes
    -----
    Power iteration computes:

        v_{k+1} = (J^T J) @ v_k / ||(J^T J) @ v_k||

    After convergence, the Rayleigh quotient v^T (J^T J) v gives λ_max.

    This is useful for:
    - Diagnosing ill-conditioning: condition number κ ≈ λ_max / λ_min
    - Tuning initial damping: set λ₀ ~ 10^{-3} λ_max as starting point
    - Detecting a vanishing Jacobian: if λ_max ≈ 0, then J ≈ 0 at params
      (a null space of J shows up in λ_min, not λ_max). In that case
      ||Av|| ≈ 0 and the normalization step can produce NaN.

    Examples
    --------
    >>> lambda_max = max_eigenval(residual_fn, params)
    >>> print(f"Maximum eigenvalue: {lambda_max:.2e}")
    """
    matvec: Callable[[Float[Array, " n"]], Float[Array, " n"]] = (
        jtj_matvec(residual_fn, params, jnp.array(0.0))
    )

    key: Array = jax.random.PRNGKey(42)
    v: Float[Array, " n"] = jax.random.normal(key, params.shape)
    v = v / jnp.linalg.norm(v)

    def power_step(
        v_curr: Float[Array, " n"], _: None
    ) -> Tuple[Float[Array, " n"], None]:
        av: Float[Array, " n"] = matvec(v_curr)
        v_next: Float[Array, " n"] = av / jnp.linalg.norm(av)
        return v_next, None

    v_final: Float[Array, " n"]
    _: None
    v_final, _ = jax.lax.scan(power_step, v, None, length=num_iterations)

    av: Float[Array, " n"] = matvec(v_final)
    result: Float[Array, " "] = jnp.dot(v_final, av)
    return result

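
# Illustration (not part of janssen): the power iteration described in
# max_eigenval, written as a dependency-free sketch in plain Python rather
# than JAX so the arithmetic is easy to follow. The 2×2 matrix J and the
# helpers jtj_apply, normalize and power_iteration are hypothetical names
# chosen for this example. As in the module, J^T J is never formed: each
# product is computed as J^T (J v).

```python
import math
from typing import List

# Hypothetical toy Jacobian; J^T J = [[10, 14], [14, 20]].
J: List[List[float]] = [[1.0, 2.0], [3.0, 4.0]]


def jtj_apply(v: List[float]) -> List[float]:
    # Forward product J v, then transpose product J^T (J v).
    jv = [sum(J[i][k] * v[k] for k in range(len(v))) for i in range(len(J))]
    return [sum(J[i][k] * jv[i] for i in range(len(J))) for k in range(len(v))]


def normalize(v: List[float]) -> List[float]:
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def power_iteration(v0: List[float], num_iterations: int = 20) -> float:
    v = normalize(v0)
    for _ in range(num_iterations):
        v = normalize(jtj_apply(v))
    # Rayleigh quotient v^T (J^T J) v with ||v|| = 1.
    av = jtj_apply(v)
    return sum(a * b for a, b in zip(v, av))


lambda_max = power_iteration([1.0, 0.0])
# Closed form for this matrix: roots of λ² - 30λ + 4 = 0.
exact = 15.0 + math.sqrt(221.0)
```

# Here λ₂/λ₁ ≈ 0.0045, so 20 iterations are far more than needed; the
# iteration count matters when the two leading eigenvalues are close.
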
@partial(jax.jit, static_argnums=(0, 2))
@jaxtyped(typechecker=beartype)
def jtj_diag(
    residual_fn: Callable[[Float[Array, " n"]], Float[Array, " m"]],
    params: Float[Array, " n"],
    num_samples: int = 10,
) -> Float[Array, " n"]:
    """Estimate diagonal of J^T J via a Hutchinson-style estimator.

    The diagonal approximates per-parameter sensitivities and is useful
    for diagonal preconditioning in conjugate gradient.

    Implementation Logic
    --------------------
    The function implements the diagonal variant of Hutchinson's
    stochastic trace estimator:

    1. **Setup**:
       - Construct matvec operator: (J^T J) @ v via jtj_matvec
       - Generate num_samples independent random keys

    2. **Per-Sample Estimation** (parallelized via vmap):
       Internal function estimate_one(key) performs:
       a. Sample probe vector z ~ Rademacher({-1, +1}^n)
          - Each entry independently ±1 with equal probability
          - Uses params.dtype to match precision (float32/float64)
       b. Compute Az = (J^T J) @ z via matvec
          - Cost: 1 forward + 1 backward pass through residual_fn
       c. Compute element-wise product: z ⊙ Az
          - This is the single-sample diagonal estimate
       d. Return z ⊙ Az

       Uses jax.vmap(estimate_one)(keys):
       - Vectorizes over num_samples keys
       - Compiles to parallel batch of matrix-vector products
       - Returns shape (num_samples, n) stacked estimates

    3. **Averaging**:
       - Compute mean across samples: diagonal = mean(estimates, axis=0)
       - Reduces variance by a factor of num_samples
         (standard error by √num_samples)

    Algorithm: Hutchinson's Estimator
    ----------------------------------
    For any symmetric matrix A, diagonal entries satisfy:

        A_ii = E_z[z_i (A @ z)_i]

    where z ~ Rademacher({-1, +1}^n). This is unbiased:

        E[z_i (A @ z)_i] = E[z_i Σⱼ A_ij z_j]
                         = Σⱼ A_ij E[z_i z_j]
                         = A_ii

    since E[z_i z_j] = δ_ij (the entries of z are independent with zero
    mean and unit variance).

    Variance and Sampling
    ---------------------
    For Rademacher probes z_i² = 1, so the single-sample estimate is
    z_i (A @ z)_i = A_ii + Σ_{j≠i} A_ij z_i z_j and its variance is:

        Var[z_i (A @ z)_i] = Σ_{j≠i} A_ij²

    For matrices with strong off-diagonal structure (like J^T J in
    ptychography due to overlapping scan positions), this variance can
    be large.

    The variance of the averaged estimator scales as:

        Var[diagonal_estimate_i] = (1/num_samples) Σ_{j≠i} A_ij²

    With num_samples=10, standard error is √(0.1 Σ_{j≠i} A_ij²). When
    off-diagonal entries dominate, increase to 50-100 samples.

    Why Rademacher? Could also use Gaussian, but Rademacher:
    - Has lower variance: a Gaussian probe adds 2 A_ii² from the z_i²
      term, which vanishes when z_i² = 1
    - Needs only a sign per entry (stored here in params.dtype)
    - Multiplies exactly: scaling by ±1 introduces no rounding error

    Note: Individual diagonal estimates can be negative due to finite
    sampling, even though true diagonal entries are non-negative
    (J^T J is positive semi-definite). When using for preconditioning,
    clamp to zero: max(diagonal, 0) before inversion.

    Computational Cost
    ------------------
    - Total: num_samples × (1 forward + 1 backward pass)
    - For num_samples=10: 10 forward + 10 backward passes
    - All samples computed in parallel via vmap (batch size = num_samples)
    - Memory: O(num_samples × n) for stacked estimates

    Parameters
    ----------
    residual_fn : Callable[[Float[Array, " n"]], Float[Array, " m"]]
        Residual function.
    params : Float[Array, " n"]
        Current parameters.
    num_samples : int, optional
        Number of random probes. Default is 10.

    Returns
    -------
    diagonal : Float[Array, " n"]
        Estimated diagonal of J^T J.

    Notes
    -----
    Uses Hutchinson's trick: diag(A) ≈ E[z ⊙ (A @ z)] for random
    z ∈ {-1, +1}^n (Rademacher distribution).

    The diagonal can be used to construct a preconditioner:

        M = (max(diag(J^T J), 0) + ε)^{-1/2}

    Trade-off: More samples → lower variance, higher cost.
    For most problems, 10 samples is sufficient. For high accuracy or
    when diagonal entries vary by orders of magnitude, use 50-100 samples.

    Examples
    --------
    >>> diag = jtj_diag(residual_fn, params)
    >>> precond = 1.0 / jnp.sqrt(jnp.maximum(diag, 0.0) + 1e-6)
    """
    n: ScalarInteger = params.shape[0]

    matvec: Callable[[Float[Array, " n"]], Float[Array, " n"]] = (
        jtj_matvec(residual_fn, params, jnp.array(0.0))
    )

    def estimate_one(key: Array) -> Float[Array, " n"]:
        z: Float[Array, " n"] = jax.random.rademacher(
            key, (n,), dtype=params.dtype
        )
        az: Float[Array, " n"] = matvec(z)
        result: Float[Array, " n"] = z * az
        return result

    keys: Array = jax.random.split(jax.random.PRNGKey(0), num_samples)
    estimates: Float[Array, " num_samples n"] = jax.vmap(estimate_one)(keys)
    diagonal: Float[Array, " n"] = jnp.mean(estimates, axis=0)
    return diagonal
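
# Illustration (not part of janssen): the identity A_ii = E[z_i (A z)_i]
# used by jtj_diag can be checked exactly on a small matrix by enumerating
# every Rademacher probe in {-1, +1}^n instead of sampling. The sketch is
# plain Python rather than JAX; the 3×3 symmetric matrix A and the names
# matvec, probes, samples, mean and variance are hypothetical choices for
# this example. For Rademacher probes the exact per-entry variance works
# out to the sum of squared off-diagonal entries in row i, which the last
# lines compute for comparison.

```python
from itertools import product
from typing import List

# Hypothetical symmetric test matrix standing in for J^T J.
A: List[List[int]] = [
    [4, 1, -2],
    [1, 3, 0],
    [-2, 0, 5],
]
n = len(A)


def matvec(z) -> List[int]:
    # Dense A @ z; jtj_diag would use a Jacobian-free operator instead.
    return [sum(A[i][j] * z[j] for j in range(n)) for i in range(n)]


# All 2^n equally likely Rademacher probes.
probes = list(product((-1, 1), repeat=n))

# Single-probe estimates z ⊙ (A z), one row per probe.
samples = [[z[i] * az_i for i, az_i in enumerate(matvec(z))] for z in probes]

# Exact expectation over the probe distribution.
mean = [sum(s[i] for s in samples) / len(samples) for i in range(n)]

# Exact single-probe variance of each diagonal estimate.
variance = [
    sum((s[i] - mean[i]) ** 2 for s in samples) / len(samples)
    for i in range(n)
]

# Row-wise sum of squared off-diagonal entries, for comparison.
offdiag_sq = [sum(A[i][j] ** 2 for j in range(n) if j != i) for i in range(n)]
```

# The exact mean recovers the diagonal of A, and rows with larger
# off-diagonal entries have noisier single-probe estimates, which is why
# more probes help when J^T J is far from diagonal.
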