Source code for janssen.invert.mixed_state

"""Mixed-state ptychography for partially coherent sources.

Extended Summary
----------------
Mixed-state ptychography reconstructs both object and probe modes when
the illumination is partially coherent. The forward model sums diffraction
intensities from each coherent mode incoherently:

    I_i = Sigma_n w_n |FFT(probe_n * shift(object, r_i))|^2

All functions are fully differentiable, enabling gradient-based optimization
of object, probe modes, mode weights, and even coherence parameters.

Routine Listings
----------------
mixed_state_forward : function
    Compute predicted diffraction patterns for all positions.
mixed_state_forward_single_position : function
    Forward model for one scan position with mixed-state illumination.
mixed_state_loss : function
    Compute reconstruction loss for mixed-state ptychography.
coherence_parameterized_loss : function
    Loss function with coherence width as optimizable parameter.
mixed_state_gradient_step : function
    Single gradient descent step for mixed-state reconstruction.
mixed_state_reconstruct : function
    Run mixed-state ptychography reconstruction.

Notes
-----
The coherent mode representation enables efficient simulation of partial
coherence. For M modes on an N×N grid, memory scales as O(M * N^2) versus
O(N^4) for full mutual intensity.

The MixedStatePtychoData PyTree and make_mixed_state_ptycho_data factory
function are defined in janssen.types.coherence_types and re-exported here
for convenience.
"""

import jax
import jax.numpy as jnp
from beartype import beartype
from jaxtyping import Array, Complex, Float, jaxtyped

from janssen.cohere import gaussian_schell_model_modes
from janssen.types import (
    CoherentModeSet,
    MixedStatePtychoData,
    ScalarInteger,
    make_coherent_mode_set,
    make_mixed_state_ptycho_data,
)


@jaxtyped(typechecker=beartype)
def _fourier_shift(
    field: Complex[Array, " H W"],
    shift_x: Float[Array, " "],
    shift_y: Float[Array, " "],
) -> Complex[Array, " H W"]:
    """Shift field using Fourier phase ramp."""
    h, w = field.shape
    ky: Float[Array, " H 1"] = jnp.fft.fftfreq(h)[:, None]
    kx: Float[Array, " 1 W"] = jnp.fft.fftfreq(w)[None, :]
    phase_ramp: Complex[Array, " H W"] = jnp.exp(
        -2j * jnp.pi * (kx * shift_x + ky * shift_y)
    )
    shifted: Complex[Array, " H W"] = jnp.fft.ifft2(
        jnp.fft.fft2(field) * phase_ramp
    )
    return shifted


[docs] @jaxtyped(typechecker=beartype) def mixed_state_forward_single_position( probe_modes: Complex[Array, " M H W"], mode_weights: Float[Array, " M"], obj: Complex[Array, " H W"], shift_x: Float[Array, " "], shift_y: Float[Array, " "], ) -> Float[Array, " H W"]: """Forward model for one scan position with mixed-state illumination. Computes: I = Σₙ wₙ |FFT(probe_n · shift(object, r_i))|² Parameters ---------- probe_modes : Complex[Array, " M H W"] Coherent probe modes. mode_weights : Float[Array, " M"] Mode weights (eigenvalues). obj : Complex[Array, " H W"] Object transmission. shift_x, shift_y : Float[Array, " "] Scan position in pixels. Returns ------- intensity : Float[Array, " H W"] Diffraction intensity (incoherent sum over modes). """ def _single_mode_intensity_int( mode: Complex[Array, " H W"], ) -> Float[Array, " H W"]: """Compute diffraction intensity for single probe mode.""" probe_shifted: Complex[Array, " H W"] = _fourier_shift( mode, shift_x, shift_y ) exit_wave: Complex[Array, " H W"] = obj * probe_shifted exit_wave_ft: Complex[Array, " H W"] = jnp.fft.fftshift( jnp.fft.fft2(exit_wave) ) intensity: Float[Array, " H W"] = jnp.abs(exit_wave_ft) ** 2 return intensity mode_intensities: Float[Array, " M H W"] = jax.vmap( _single_mode_intensity_int )(probe_modes) total: Float[Array, " H W"] = jnp.einsum( "m,mhw->hw", mode_weights, mode_intensities ) return total
[docs] @jaxtyped(typechecker=beartype) def mixed_state_forward( data: MixedStatePtychoData, ) -> Float[Array, " N H W"]: """Compute predicted diffraction patterns for all positions. Parameters ---------- data : MixedStatePtychoData Ptychography data with probe modes and object. Returns ------- predicted : Float[Array, " N H W"] Predicted diffraction intensities. """ modes: Complex[Array, " M H W"] = data.probe_modes.modes weights: Float[Array, " M"] = data.probe_modes.weights obj: Complex[Array, " H W"] = data.sample positions: Float[Array, " N 2"] = data.positions def _forward_one_int(pos: Float[Array, " 2"]) -> Float[Array, " H W"]: """Compute forward model for single scan position.""" intensity: Float[Array, " H W"] = mixed_state_forward_single_position( modes, weights, obj, pos[0], pos[1] ) return intensity predicted: Float[Array, " N H W"] = jax.vmap(_forward_one_int)(positions) return predicted
[docs] @jaxtyped(typechecker=beartype) def mixed_state_loss( data: MixedStatePtychoData, loss_type: str = "amplitude", ) -> Float[Array, " "]: """Compute reconstruction loss for mixed-state ptychography. Parameters ---------- data : MixedStatePtychoData Current state with probe modes and object. loss_type : str "amplitude" for ||√I_exp - √I_pred||² (default, robust) "intensity" for ||I_exp - I_pred||² "poisson" for Poisson negative log-likelihood Returns ------- loss : Float[Array, " "] Scalar loss value. """ predicted: Float[Array, " N H W"] = mixed_state_forward(data) measured: Float[Array, " N H W"] = data.diffraction_patterns if loss_type == "amplitude": amp_meas: Float[Array, " N H W"] = jnp.sqrt(jnp.maximum(measured, 0.0)) amp_pred: Float[Array, " N H W"] = jnp.sqrt( jnp.maximum(predicted, 1e-10) ) loss: Float[Array, " "] = jnp.sum((amp_meas - amp_pred) ** 2) return loss if loss_type == "intensity": loss: Float[Array, " "] = jnp.sum((measured - predicted) ** 2) return loss if loss_type == "poisson": eps: float = 1e-10 loss: Float[Array, " "] = jnp.sum( predicted - measured * jnp.log(predicted + eps) ) return loss raise ValueError(f"Unknown loss_type: {loss_type}")
[docs] @jaxtyped(typechecker=beartype) def coherence_parameterized_loss( coherence_width: Float[Array, " "], sample: Complex[Array, " H W"], diffraction_patterns: Float[Array, " N H W"], positions: Float[Array, " N 2"], beam_width: Float[Array, " "], wavelength: Float[Array, " "], dx: Float[Array, " "], num_modes: ScalarInteger = 10, ) -> Float[Array, " "]: """Loss function with coherence width as the optimizable parameter. This enables gradient-based recovery of source coherence: σ_c* = argmin_{σ_c} L(σ_c; I_measured) Parameters ---------- coherence_width : Float[Array, " "] Coherence width to optimize. sample : Complex[Array, " H W"] Object (fixed or jointly optimized). diffraction_patterns : Float[Array, " N H W"] Measured data. positions : Float[Array, " N 2"] Scan positions. beam_width : Float[Array, " "] Known beam intensity width. wavelength : Float[Array, " "] Wavelength. dx : Float[Array, " "] Pixel size. num_modes : int Number of GSM modes. Returns ------- loss : Float[Array, " "] Scalar loss. Notes ----- The gradient ∂L/∂σ_c flows through: 1. σ_c → eigenvalues λₙ (analytical for GSM) 2. σ_c → mode width w_mode → mode shapes φₙ 3. modes → forward model → loss No hand-derived update rules needed! """ hh: int = sample.shape[0] ww: int = sample.shape[1] probe_modes: CoherentModeSet = gaussian_schell_model_modes( wavelength=wavelength, dx=dx, grid_size=(hh, ww), beam_width=beam_width, coherence_width=coherence_width, num_modes=num_modes, ) data: MixedStatePtychoData = make_mixed_state_ptycho_data( diffraction_patterns=diffraction_patterns, probe_modes=probe_modes, sample=sample, positions=positions, wavelength=wavelength, dx=dx, ) loss: Float[Array, " "] = mixed_state_loss(data, loss_type="amplitude") return loss
[docs] @jaxtyped(typechecker=beartype) def mixed_state_gradient_step( data: MixedStatePtychoData, learning_rate: float = 1e-3, update_object: bool = True, update_modes: bool = True, update_weights: bool = False, ) -> MixedStatePtychoData: """Single gradient descent step for mixed-state reconstruction. Parameters ---------- data : MixedStatePtychoData Current state. learning_rate : float Step size. update_object : bool Whether to update object. update_modes : bool Whether to update probe mode fields. update_weights : bool Whether to update mode weights. Returns ------- updated_data : MixedStatePtychoData State after one gradient step. """ def _loss_fn_int( sample: Complex[Array, " Hs Ws"], modes: Complex[Array, " M H W"], weights: Float[Array, " M"], ) -> Float[Array, " "]: """Compute mixed-state loss for gradient computation.""" updated_probe: CoherentModeSet = make_coherent_mode_set( modes=modes, weights=weights, wavelength=data.probe_modes.wavelength, dx=data.probe_modes.dx, z_position=data.probe_modes.z_position, polarization=data.probe_modes.polarization, ) updated_data: MixedStatePtychoData = make_mixed_state_ptycho_data( diffraction_patterns=data.diffraction_patterns, probe_modes=updated_probe, sample=sample, positions=data.positions, wavelength=data.wavelength, dx=data.dx, ) loss: Float[Array, " "] = mixed_state_loss(updated_data) return loss grad_fn = jax.grad(_loss_fn_int, argnums=(0, 1, 2)) grad_sample: Complex[Array, " Hs Ws"] grad_modes: Complex[Array, " M H W"] grad_weights: Float[Array, " M"] grad_sample, grad_modes, grad_weights = grad_fn( data.sample, data.probe_modes.modes, data.probe_modes.weights ) new_sample: Complex[Array, " Hs Ws"] = jax.lax.cond( update_object, lambda: data.sample - learning_rate * grad_sample, lambda: data.sample, ) new_modes: Complex[Array, " M H W"] = jax.lax.cond( update_modes, lambda: data.probe_modes.modes - learning_rate * grad_modes, lambda: data.probe_modes.modes, ) new_weights: Float[Array, " M"] = jax.lax.cond( update_weights, lambda: data.probe_modes.weights - learning_rate * grad_weights, lambda: data.probe_modes.weights, ) new_weights = jnp.maximum(new_weights, 1e-10) new_weights = new_weights / jnp.sum(new_weights) updated_probe: CoherentModeSet = make_coherent_mode_set( modes=new_modes, weights=new_weights, wavelength=data.probe_modes.wavelength, dx=data.probe_modes.dx, z_position=data.probe_modes.z_position, polarization=data.probe_modes.polarization, normalize_weights=False, ) updated_data: MixedStatePtychoData = make_mixed_state_ptycho_data( diffraction_patterns=data.diffraction_patterns, probe_modes=updated_probe, sample=new_sample, positions=data.positions, wavelength=data.wavelength, dx=data.dx, ) return updated_data
[docs] @jaxtyped(typechecker=beartype) def mixed_state_reconstruct( data: MixedStatePtychoData, num_iterations: ScalarInteger = 100, learning_rate: float = 1e-3, update_object: bool = True, update_modes: bool = True, update_weights: bool = False, ) -> tuple[MixedStatePtychoData, Float[Array, " I"]]: """Run mixed-state ptychography reconstruction. Parameters ---------- data : MixedStatePtychoData Initial state with probe modes and object estimate. num_iterations : int Number of gradient descent iterations. learning_rate : float Step size. update_object : bool Whether to reconstruct object. update_modes : bool Whether to reconstruct probe modes. update_weights : bool Whether to learn mode weights. Returns ------- final_data : MixedStatePtychoData Reconstructed state. loss_history : Float[Array, " I"] Loss at each iteration. """ def _step_int( carry: MixedStatePtychoData, _: None ) -> tuple[MixedStatePtychoData, Float[Array, " "]]: """Single gradient descent iteration.""" current_data: MixedStatePtychoData = carry loss: Float[Array, " "] = mixed_state_loss(current_data) updated: MixedStatePtychoData = mixed_state_gradient_step( current_data, learning_rate, update_object, update_modes, update_weights, ) return updated, loss final_data: MixedStatePtychoData loss_history: Float[Array, " I"] final_data, loss_history = jax.lax.scan( _step_int, data, None, length=int(num_iterations) ) return final_data, loss_history