janssen.invert
Inversion algorithms for phase retrieval and ptychography.
Extended Summary
Algorithms for phase retrieval and ptychographic reconstruction built on differentiable programming. The module provides several optimization strategies (ePIE, gradient descent, Gauss-Newton) and loss functions for reconstructing complex-valued fields.
Routine Listings
MixedStatePtychoData – PyTree for mixed-state ptychography reconstruction state.
coherence_parameterized_loss() – Loss function with coherence width as an optimizable parameter.
compute_fov_and_positions() – Compute FOV size and normalized positions from experimental data.
create_loss_function() – Factory function for creating various loss functions.
epie_optical() – Extended PIE algorithm for optical ptychography.
init_simple_epie() – Initialize ePIE reconstruction.
init_simple_microscope() – Initialize reconstruction by inverting the simple microscope forward model.
make_mixed_state_ptycho_data() – Factory function for MixedStatePtychoData creation.
mixed_state_forward() – Compute predicted diffraction patterns for all positions.
mixed_state_forward_single_position() – Forward model for one scan position with mixed-state illumination.
mixed_state_gradient_step() – Single gradient descent step for mixed-state reconstruction.
mixed_state_loss() – Compute reconstruction loss for mixed-state ptychography.
mixed_state_reconstruct() – Run mixed-state ptychography reconstruction.
profile_gn_memory() – Profile memory usage during Gauss-Newton optimization.
simple_microscope_epie() – Ptychography reconstruction using the extended PIE algorithm.
simple_microscope_gn() – Ptychography reconstruction using Gauss-Newton optimization.
simple_microscope_optim() – Resumable ptychography reconstruction using gradient-based optimization.
single_pie_iteration() – Single iteration of the PIE algorithm.
single_pie_sequential() – Sequential PIE implementation for multiple positions.
single_pie_vmap() – Vectorized PIE implementation using vmap.
Notes
All functions are JAX-compatible and support automatic differentiation. The algorithms can be composed with JIT compilation for improved performance.
- class janssen.invert.MixedStatePtychoData(diffraction_patterns: Float[Array, 'N H W'], probe_modes: CoherentModeSet, sample: Complex[Array, 'Hs Ws'], positions: Float[Array, 'N 2'], wavelength: Float[Array, ''], dx: Float[Array, ''])
Bases: NamedTuple
PyTree structure for mixed-state ptychography reconstruction.
Extends standard ptychography data to support partially coherent illumination via coherent mode decomposition.
- diffraction_patterns
Measured diffraction intensities at each scan position.
- Type:
Float[Array, "N H W"]
- probe_modes
Coherent mode decomposition of the partially coherent probe.
- Type:
CoherentModeSet
- sample
Object transmission function estimate.
- Type:
Complex[Array, "Hs Ws"]
- positions
Scan positions (x, y) in pixels.
- Type:
Float[Array, "N 2"]
- wavelength
Wavelength in meters.
- Type:
Float[Array, ""]
- dx
Pixel spacing in meters.
- Type:
Float[Array, ""]
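Notes
- The forward model is:
$$I_i = \sum_n w_n \left|\mathrm{FFT}\left(P_n \cdot \mathrm{shift}(O, r_i)\right)\right|^2$$
where P_n is the n-th probe mode, w_n its weight, O the object (sample), and r_i the i-th scan position.
Gradients flow through probe_modes.modes, probe_modes.weights, and sample, enabling joint optimization of all parameters.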
Field order (NamedTuple indices): diffraction_patterns = 0, probe_modes = 1, sample = 2, positions = 3, wavelength = 4, dx = 5.
- janssen.invert.make_mixed_state_ptycho_data(diffraction_patterns: Float[Array, 'N H W'], probe_modes: CoherentModeSet, sample: Complex[Array, 'Hs Ws'], positions: Float[Array, 'N 2'], wavelength: int | float | complex | Num[Array, ''], dx: int | float | complex | Num[Array, '']) → MixedStatePtychoData
Create validated MixedStatePtychoData.
Factory function that validates inputs and creates a MixedStatePtychoData PyTree suitable for mixed-state ptychography reconstruction.
- Parameters:
diffraction_patterns (Float[Array, "N H W"]) – Measured diffraction patterns.
probe_modes (CoherentModeSet) – Partially coherent probe as coherent modes.
sample (Complex[Array, "Hs Ws"]) – Initial object estimate.
positions (Float[Array, "N 2"]) – Scan positions (x, y) in pixels.
wavelength (numeric) – Wavelength in meters. Must be positive.
dx (numeric) – Pixel size in meters. Must be positive.
- Returns:
data – Validated data structure.
- Return type:
MixedStatePtychoData
- janssen.invert.epie_optical(microscope_data: MicroscopeData, initial_object: OpticalWavefront, initial_surface: SampleFunction, pixel_mask: Float[Array, 'H W'], propagation_distance_1: float | Float[Array, ''], propagation_distance_2: float | Float[Array, ''], magnification: int | Int[Array, ''], vmap_iterations: int | Int[Array, ''] | None = 0, alpha_object: float | Float[Array, ''] | None = 0.1, gamma_object: float | Float[Array, ''] | None = 0.5, alpha_surface: float | Float[Array, ''] | None = 0.1, gamma_surface: float | Float[Array, ''] | None = 0.5, num_loops: int | Int[Array, ''] | None = 10) → tuple[OpticalWavefront, SampleFunction]
Perform ptychographic reconstruction using the extended PIE (ePIE) algorithm.
- Parameters:
microscope_data (MicroscopeData) – Measured intensity data with positions.
initial_object (OpticalWavefront) – Initial guess for the object wavefront.
initial_surface (SampleFunction) – Initial guess for the surface pattern.
pixel_mask (Float[Array, "H W"]) – Pixel response mask for modeling sensor characteristics.
propagation_distance_1 (float) – Distance from the object to the diffuser plane in meters.
propagation_distance_2 (float) – Distance from the diffuser to the sensor plane in meters.
magnification (int) – Magnification factor for downsampling.
vmap_iterations (int, optional) – Number of initial iterations to run in vmap mode for rapid convergence. If 0, sequential mode is used for all iterations; if > 0, vmap is used for the first N iterations, after which the algorithm switches to sequential mode. Default is 0.
alpha_object (float, optional) – Object update mixing parameter. Default is 0.1.
gamma_object (float, optional) – Object update step size. Default is 0.5.
alpha_surface (float, optional) – Surface update mixing parameter. Default is 0.1.
gamma_surface (float, optional) – Surface update step size. Default is 0.5.
num_loops (int, optional) – Number of iteration loops. Default is 10.
- Returns:
- recovered_object : OpticalWavefront
Reconstructed object wavefront.
- recovered_surface : SampleFunction
Reconstructed surface pattern.
- Return type:
tuple of (OpticalWavefront, SampleFunction)
Notes
Algorithm:
- Compute image data
- Compute positions
- Compute frequency grids
- Compute object recovery propagation field
- Compute surface pattern
- Define loop body
- Apply fori_loop over loops
- Compute final object field
- Compute final object wavefront
- Compute final surface pattern
- Return final object and surface
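The `vmap_iterations` switch described in the parameters can be expressed with `jax.lax.fori_loop` and `jax.lax.cond`, as in the minimal sketch below. It is not the janssen implementation: `run_loops`, `vmap_step`, and `seq_step` are hypothetical names, and the toy step functions only stand in for the vmap and sequential PIE sweeps.

```python
import jax
import jax.numpy as jnp


def run_loops(state, vmap_step, seq_step, num_loops, vmap_iterations):
    """Run num_loops sweeps: vmap_step for the first vmap_iterations sweeps,
    seq_step afterwards (hypothetical helper, not janssen's API)."""

    def loop_body(i, carry):
        # lax.cond picks the branch at run time, so the switch point can be
        # a traced value inside jit.
        return jax.lax.cond(i < vmap_iterations, vmap_step, seq_step, carry)

    return jax.lax.fori_loop(0, num_loops, loop_body, state)


# Toy steps: count how many sweeps each mode ran.
vmap_step = lambda s: s + jnp.array([1.0, 0.0])
seq_step = lambda s: s + jnp.array([0.0, 1.0])
counts = run_loops(jnp.zeros(2), vmap_step, seq_step, 10, 3)  # [3., 7.]
```

Because `lax.cond` selects the branch at run time, both branches must return arrays of identical shape and dtype, which is the case for an object/surface state carried from sweep to sweep.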
- janssen.invert.single_pie_iteration(object_prop_ft: Complex[Array, 'H W'], surface_pattern: Complex[Array, 'H W'], measurement: Float[Array, 'H W'], position: Float[Array, '2'], frequency_x_grid: Float[Array, 'H W'], frequency_y_grid: Float[Array, 'H W'], pixel_mask: Float[Array, 'H W'], propagation_distance_2: float | Float[Array, ''], magnification: int | Int[Array, ''], alpha_object: float | Float[Array, ''], gamma_object: float | Float[Array, ''], alpha_surface: float | Float[Array, ''], gamma_surface: float | Float[Array, ''], wavelength: float | Float[Array, ''], dx: float | Float[Array, '']) → tuple[Complex[Array, 'H W'], Complex[Array, 'H W']]
Single iteration of the extended PIE algorithm.
- Parameters:
object_prop_ft (Complex[Array, "H W"]) – Object wavefront at the diffuser plane in the Fourier domain.
surface_pattern (Complex[Array, "H W"]) – Surface pattern function.
measurement (Float[Array, "H W"]) – Measured intensity at the current position.
position (Float[Array, "2"]) – Current scanning position [x, y].
frequency_x_grid (Float[Array, "H W"]) – Frequency grid in the x direction.
frequency_y_grid (Float[Array, "H W"]) – Frequency grid in the y direction.
pixel_mask (Float[Array, "H W"]) – Pixel response mask for sensor modeling.
propagation_distance_2 (float) – Distance from the diffuser to the sensor.
magnification (int) – Downsampling magnification factor.
alpha_object (float) – Object update mixing parameter.
gamma_object (float) – Object update step size.
alpha_surface (float) – Surface update mixing parameter.
gamma_surface (float) – Surface update step size.
wavelength (float) – Wavelength of light in meters.
dx (float) – Pixel spacing in meters.
- Returns:
- updated_object_ft : Complex[Array, "H W"]
Updated object wavefront in the Fourier domain.
- updated_surface : Complex[Array, "H W"]
Updated surface pattern.
- Return type:
tuple of (Complex[Array, "H W"], Complex[Array, "H W"])
Notes
Algorithm:
- Compute the shifted object
- Compute the surface plane
- Compute the surface propagation kernel
- Compute the sensor plane
- Compute the sensor intensity
- Compute the ratio map
- Compute the upsampled ratio map
- Compute the new sensor plane
- Compute the new sensor plane in the Fourier domain
- Compute the CTF conjugate
- Compute the CTF maximum squared
- Compute the surface propagation kernel
- Compute the updated surface pattern
- Compute the updated object wavefront
- Compute the updated object wavefront in the Fourier domain
- Return the updated object and surface
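For orientation, here is a textbook PIE-style object and probe update for a plain far-field (single-FFT) model. It assumes, without confirmation from the source, that the alpha/gamma parameters follow the regularized PIE (rPIE) convention: alpha mixes the local |P|² with its maximum in the denominator (alpha = 1 gives the classic ePIE denominator) and gamma scales the step. janssen's version additionally propagates through the diffuser, applies the pixel mask, and resamples by the magnification; `rpie_update` is a hypothetical name.

```python
import jax.numpy as jnp


def rpie_update(obj, probe, measurement, alpha=0.1, gamma=0.5, eps=1e-12):
    """One rPIE-style object/probe update for a far-field FFT model.

    Hypothetical sketch; not janssen's diffuser-based single_pie_iteration.
    """
    exit_wave = obj * probe
    far_field = jnp.fft.fft2(exit_wave)
    # Modulus constraint: keep the phase, impose the measured amplitude.
    amplitude = jnp.sqrt(jnp.maximum(measurement, 0.0))
    corrected = amplitude * jnp.exp(1j * jnp.angle(far_field))
    diff = jnp.fft.ifft2(corrected) - exit_wave

    # alpha blends local intensity with its maximum; gamma scales the step.
    probe_int = jnp.abs(probe) ** 2
    obj_int = jnp.abs(obj) ** 2
    obj_denom = (1.0 - alpha) * probe_int + alpha * probe_int.max() + eps
    probe_denom = (1.0 - alpha) * obj_int + alpha * obj_int.max() + eps
    new_obj = obj + gamma * jnp.conj(probe) * diff / obj_denom
    new_probe = probe + gamma * jnp.conj(obj) * diff / probe_denom
    return new_obj, new_probe


# Consistency check: data generated by the model itself leaves the state unchanged.
x = jnp.linspace(-1.0, 1.0, 16)
probe = jnp.exp(-(x[:, None] ** 2 + x[None, :] ** 2)).astype(jnp.complex64)
obj = jnp.exp(1j * jnp.outer(x, x))
meas = jnp.abs(jnp.fft.fft2(obj * probe)) ** 2
new_obj, new_probe = rpie_update(obj, probe, meas)
```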
- janssen.invert.single_pie_sequential(object_prop_ft: Complex[Array, 'H W'], surface_pattern: Complex[Array, 'H W'], image_data: Float[Array, 'P H W'], positions: Float[Array, 'P 2'], frequency_x_grid: Float[Array, 'H W'], frequency_y_grid: Float[Array, 'H W'], pixel_mask: Float[Array, 'H W'], propagation_distance_2: float | Float[Array, ''], magnification: int | Int[Array, ''], alpha_object: float | Float[Array, ''], gamma_object: float | Float[Array, ''], alpha_surface: float | Float[Array, ''], gamma_surface: float | Float[Array, ''], wavelength: float | Float[Array, ''], dx: float | Float[Array, '']) → tuple[Complex[Array, 'H W'], Complex[Array, 'H W']]
Process scan positions sequentially with fori_loop, as proper PIE convergence requires: each position's update starts from the state produced by the previous position.
- Parameters:
object_prop_ft (Complex[Array, "H W"]) – Current object wavefront in the Fourier domain.
surface_pattern (Complex[Array, "H W"]) – Current surface pattern.
image_data (Float[Array, "P H W"]) – Measurement data for all positions.
positions (Float[Array, "P 2"]) – Position coordinates for all measurements.
frequency_x_grid (Float[Array, "H W"]) – Frequency grid in the x direction.
frequency_y_grid (Float[Array, "H W"]) – Frequency grid in the y direction.
pixel_mask (Float[Array, "H W"]) – Pixel response mask for sensor modeling.
propagation_distance_2 (float) – Distance from the diffuser to the sensor.
magnification (int) – Downsampling magnification factor.
alpha_object (float) – Object update mixing parameter.
gamma_object (float) – Object update step size.
alpha_surface (float) – Surface update mixing parameter.
gamma_surface (float) – Surface update step size.
wavelength (float) – Wavelength of light in meters.
dx (float) – Pixel spacing in meters.
- Returns:
Updated object and surface state after sequential processing.
- Return type:
tuple of (Complex[Array, "H W"], Complex[Array, "H W"])
Notes
Algorithm:
- Compute the number of positions
- Define the per-position loop body
- Apply fori_loop over positions
- Return the updated state
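A minimal sketch of the sequential sweep, assuming a per-position update function with the hypothetical signature `update_fn(state, measurement, position) -> state`. The point is the data dependence: position i starts from the state left by position i - 1, which a toy recurrence makes visible.

```python
import jax
import jax.numpy as jnp


def sequential_sweep(state, image_data, positions, update_fn):
    """Apply update_fn(state, measurement, position) once per position,
    in scan order (hypothetical helper)."""
    num_positions = image_data.shape[0]

    def position_body(i, carry):
        # Position i starts from the state left by position i - 1.
        return update_fn(carry, image_data[i], positions[i])

    return jax.lax.fori_loop(0, num_positions, position_body, state)


# Toy usage: the recurrence s -> 2 s + sum(m) depends on the visiting order.
images = jnp.stack([jnp.full((2, 2), v) for v in (0.25, 0.5, 0.75)])
scan_pos = jnp.array([[1.0, 0.0], [2.0, 0.0], [3.0, 0.0]])
toy_update = lambda st, m, p: (2.0 * st[0] + m.sum(), st[1] + p[0])
s_out, t_out = sequential_sweep(
    (jnp.zeros(()), jnp.zeros(())), images, scan_pos, toy_update
)
# s_out == 11.0 (0 -> 1 -> 4 -> 11), t_out == 6.0
```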
- janssen.invert.single_pie_vmap(object_prop_ft: Complex[Array, 'H W'], surface_pattern: Complex[Array, 'H W'], image_data: Float[Array, 'P H W'], positions: Float[Array, 'P 2'], frequency_x_grid: Float[Array, 'H W'], frequency_y_grid: Float[Array, 'H W'], pixel_mask: Float[Array, 'H W'], propagation_distance_2: float | Float[Array, ''], magnification: int | Int[Array, ''], alpha_object: float | Float[Array, ''], gamma_object: float | Float[Array, ''], alpha_surface: float | Float[Array, ''], gamma_surface: float | Float[Array, ''], wavelength: float | Float[Array, ''], dx: float | Float[Array, '']) → tuple[Complex[Array, 'H W'], Complex[Array, 'H W']]
Parallel processing over positions using vmap for faster but approximate PIE.
All positions use the same initial state, then updates are averaged.
- Parameters:
object_prop_ft (Complex[Array, "H W"]) – Current object wavefront in the Fourier domain.
surface_pattern (Complex[Array, "H W"]) – Current surface pattern.
image_data (Float[Array, "P H W"]) – Measurement data for all positions.
positions (Float[Array, "P 2"]) – Position coordinates for all measurements.
frequency_x_grid (Float[Array, "H W"]) – Frequency grid in the x direction.
frequency_y_grid (Float[Array, "H W"]) – Frequency grid in the y direction.
pixel_mask (Float[Array, "H W"]) – Pixel response mask for sensor modeling.
propagation_distance_2 (float) – Distance from the diffuser to the sensor.
magnification (int) – Downsampling magnification factor.
alpha_object (float) – Object update mixing parameter.
gamma_object (float) – Object update step size.
alpha_surface (float) – Surface update mixing parameter.
gamma_surface (float) – Surface update step size.
wavelength (float) – Wavelength of light in meters.
dx (float) – Pixel spacing in meters.
- Returns:
Updated object and surface state after parallel processing and averaging.
- Return type:
tuple of (Complex[Array, "H W"], Complex[Array, "H W"])
Notes
Algorithm:
- Apply vmap over all positions using the same initial state
- Average all object updates
- Average all surface updates
- Return the averaged states
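The vmap variant can be sketched in the same way, again with a hypothetical `update_fn(obj, surface, measurement, position) -> (obj, surface)`. `in_axes=None` broadcasts the shared starting state to every position and the per-position results are averaged, which is why the outcome generally differs from the sequential sweep.

```python
import jax
import jax.numpy as jnp


def parallel_average_sweep(obj, surface, image_data, positions, update_fn):
    """vmap update_fn(obj, surface, measurement, position) over all positions
    from one shared state, then average (hypothetical helper)."""
    # in_axes=None broadcasts the shared state; 0 maps over positions.
    objs, surfaces = jax.vmap(update_fn, in_axes=(None, None, 0, 0))(
        obj, surface, image_data, positions
    )
    return objs.mean(axis=0), surfaces.mean(axis=0)


# Toy usage: each position nudges the state by its own data.
images = jnp.stack([jnp.full((2, 2), v) for v in (1.0, 2.0, 3.0)])
scan_pos = jnp.array([[1.0, 0.0], [2.0, 0.0], [3.0, 0.0]])
toy_update = lambda o, s, m, p: (o + m, s + p[0])
obj_avg, surf_avg = parallel_average_sweep(
    jnp.zeros((2, 2)), jnp.zeros((2, 2)), images, scan_pos, toy_update
)
# Both averages equal 2.0 everywhere, because every update started from zeros.
```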
- janssen.invert.compute_fov_and_positions(experimental_data: MicroscopeData, probe_lightwave: OpticalWavefront, padding: int | Int[Array, ''] | None = None) → tuple[int, int, Float[Array, 'N 2'], Float[Array, '']]
Compute FOV dimensions and normalized positions.
Converts scan positions from meters to pixels, computes the required FOV size to contain all positions plus probe size and padding, then normalizes positions so they start at (padding + half_probe) in the FOV coordinate system.
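- Parameters:
experimental_data (MicroscopeData) – Experimental diffraction patterns with positions in meters.
probe_lightwave (OpticalWavefront) – The probe/lightwave with field shape and pixel size.
padding (int, optional) – Additional padding in pixels. If None, defaults to half the probe size.
- Returns:
FOV dimensions (two integers), the normalized scan positions in pixels, and a scalar array, in the order given by the return type.
- Return type:
tuple[int, int, Float[Array, "N 2"], Float[Array, ""]]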
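The bookkeeping described in the summary might look like the NumPy sketch below. The helper name, the (x, y) column order, the use of `ceil` for the scan extent, and half the larger probe dimension as the default padding are assumptions for illustration; the real function reads these quantities from `MicroscopeData` and `OpticalWavefront`.

```python
import numpy as np


def fov_and_positions(positions_m, dx, probe_shape, padding=None):
    """Hypothetical sketch of the FOV bookkeeping described above.

    positions_m: (N, 2) scan positions in meters, columns assumed (x, y).
    dx: sample/probe pixel size in meters.
    probe_shape: (H, W) of the probe field.
    """
    probe_h, probe_w = probe_shape
    if padding is None:
        padding = max(probe_h, probe_w) // 2  # assumed default: half probe
    pos_px = np.asarray(positions_m, dtype=float) / dx
    pos_px = pos_px - pos_px.min(axis=0)  # scan now starts at (0, 0)
    extent_x, extent_y = np.ceil(pos_px.max(axis=0))
    # Scan extent + one full probe + padding on both sides.
    fov_w = int(extent_x) + probe_w + 2 * padding
    fov_h = int(extent_y) + probe_h + 2 * padding
    # Shift so the smallest scan position sits at (padding + half_probe).
    offset = np.array([padding + probe_w // 2, padding + probe_h // 2])
    return fov_h, fov_w, pos_px + offset


fov_h, fov_w, norm_pos = fov_and_positions(
    [[0.0, 0.0], [5.0, 0.0], [0.0, 10.0]], dx=0.5, probe_shape=(8, 8)
)
# fov_h == 36, fov_w == 26, norm_pos[0] == [8., 8.]
```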
- janssen.invert.init_simple_epie(experimental_data: MicroscopeData, effective_dx: float | Float[Array, ''], wavelength: float | Float[Array, ''], zoom_factor: float | Float[Array, ''], aperture_diameter: float | Float[Array, ''], travel_distance: float | Float[Array, ''], camera_pixel_size: float | Float[Array, ''], padding: int | Int[Array, ''] | None = None) → EpieData
Initialize data for FFT-compatible ePIE reconstruction.
Preprocesses experimental data for FFT-based ePIE. All quantities are converted to pixels so that _sm_epie_core works purely in pixel space.
- Parameters:
experimental_data (MicroscopeData) – Experimental diffraction patterns with positions in meters. Shape of image_data: (N, H_cam, W_cam), where N is the number of positions.
effective_dx (float) – Desired pixel size in meters for the sample/probe reconstruction. This user input determines the reconstruction resolution.
wavelength (float) – Wavelength of light in meters.
zoom_factor (float) – Optical zoom factor (magnification) of the microscope. Used to scale the aperture diameter for the probe.
aperture_diameter (float) – Physical aperture diameter in meters (before zoom scaling).
travel_distance (float) – Propagation distance from sample to camera in meters.
camera_pixel_size (float) – Physical size of camera pixels in meters.
padding (int, optional) – Additional padding in pixels around the scanned region. If None, defaults to half the aperture size in pixels.
- Returns:
Preprocessed data ready for FFT-based ePIE reconstruction containing:
diffraction_patterns: Rescaled and padded/cropped to image size
probe: Plane wave with circular aperture at image size
sample: Initial estimate (ones), same size as probe
positions: Scan positions in pixels relative to center (0, 0)
effective_dx: The user-provided pixel size
wavelength, original_camera_pixel_size, zoom_factor for reference
- Return type:
EpieData
Notes
Workflow
1. Convert scan positions from meters to pixels: pos_px = pos_m / dx
2. Unzoom the aperture and convert it to pixels: aperture_px = (D / zoom) / dx
3. Compute the image size from aperture + padding + scan FOV
4. Create the probe: a plane wave with aperture, calibrated at effective_dx
5. Compute the FFT pixel size in inverse meters from the image size and dx
6. Scale camera images: unzoom pixels, apply Fraunhofer scaling to get meter⁻¹, then rescale to match the FFT pixel size
7. Pad or crop camera images to match the image size
Aperture Scaling
The effective aperture is aperture_diameter / zoom_factor. This absorbs the zoom into the probe geometry.
Camera Image Rescaling
The camera pixel size after unzooming is camera_pixel_size / zoom. In Fraunhofer diffraction, a detector pixel maps to the spatial-frequency step
$$\Delta f = \frac{1}{\lambda z}\,\Delta x_{\mathrm{cam,unzoomed}},$$
where z is travel_distance. The FFT expects a pixel spacing of
$$\Delta f_{\mathrm{FFT}} = \frac{1}{N\,\Delta x_{\mathrm{eff}}},$$
where N is the image size and Δx_eff is effective_dx. The images are rescaled by scale = Δf / Δf_FFT to match.
Position Convention
Positions are centered so (0, 0) is the center of the scan region. The probe is initially at the center, and FFT shifting moves it relative to this point.
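The scaling relations above reduce to a few lines of arithmetic. The sketch below uses hypothetical helper names and plain Python floats; it computes the aperture size in pixels and the camera-to-FFT rescale factor, but not the image interpolation itself.

```python
def aperture_pixels(aperture_diameter, zoom_factor, dx):
    """Unzoomed aperture diameter in pixels: (D / zoom) / dx."""
    return (aperture_diameter / zoom_factor) / dx


def camera_to_fft_scale(camera_pixel_size, zoom_factor, wavelength,
                        travel_distance, image_size, effective_dx):
    """Rescale factor df / df_fft for the camera images (sketch)."""
    camera_dx_unzoomed = camera_pixel_size / zoom_factor
    # Fraunhofer mapping: one unzoomed detector pixel spans df (1/m).
    df = camera_dx_unzoomed / (wavelength * travel_distance)
    # Frequency spacing of an image_size-point FFT with spacing effective_dx.
    df_fft = 1.0 / (image_size * effective_dx)
    return df / df_fft


# Toy numbers chosen to be exact in binary floating point.
print(camera_to_fft_scale(2.0, 2.0, 0.5, 2.0, 4, 0.25))  # 1.0
print(camera_to_fft_scale(2.0, 2.0, 0.5, 2.0, 8, 0.25))  # 2.0
```

Doubling the image size halves the FFT frequency spacing, so the same camera pixel must cover twice as many FFT pixels.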
- janssen.invert.init_simple_microscope(experimental_data: MicroscopeData, probe_lightwave: OpticalWavefront, zoom_factor: float | Float[Array, ''], aperture_diameter: float | Float[Array, ''], travel_distance: float | Float[Array, ''], camera_pixel_size: float | Float[Array, ''], aperture_center: Float[Array, '2'] | None = None, padding: int | Int[Array, ''] | None = None, regularization: float = 1e-06, seed: int = 42) → PtychographyReconstruction
Initialize sample by inverting the simple microscope forward model.
Runs the microscope forward model in reverse to create an initial sample estimate from experimental diffraction patterns. This serves as iteration 0 of the reconstruction, returning a PtychographyReconstruction that can be passed to simple_microscope_ptychography for further optimization.
The function automatically distributes computation across all available GPUs using JAX's sharding API (Mesh and NamedSharding), and validates the reconstruction quality on a subset of positions to avoid out-of-memory errors during initialization.
Implementation Logic
The initialization follows a six-stage pipeline:
1. Field of View Setup:
   - Computes FOV dimensions to contain all scan positions plus padding
   - Calculates sample pixel size from probe lightwave
   - Translates positions to FOV-centered coordinates
2. GPU Sharding (Automatic Multi-GPU Distribution):
   - Detects available devices via jax.devices()
   - Pads diffraction patterns and random phases to the nearest multiple of the device count (example: 400 patterns on 7 GPUs → pad to 406 = 7 × 58)
   - Creates a Mesh with all devices along the "data" axis
   - Defines NamedSharding(P("data", None, None)) for the first dimension
   - Uses jax.device_put to explicitly distribute image_data and random_phases
   - Critical for memory: distributing both the inputs and the inverted samples prevents OOM
3. Parallel Pattern Inversion (vmap over sharded data). For each diffraction pattern in parallel:
   a. Takes the square root of the intensity to recover the amplitude and assigns a random phase
   b. Propagates backwards via inverse scaled Fraunhofer propagation
   c. Applies the inverse aperture mask (with regularization)
   d. Applies inverse optical zoom to recover the original pixel size
   e. Divides by the probe to isolate the sample contribution (regularized division)
4. Weighted Sample Stitching:
   - Trims padded results back to the original pattern count via [:N]
   - Uses jax.lax.scan to accumulate sample estimates at scan positions
   - Weights by probe intensity (bright regions weighted higher)
   - Averages overlapping regions; fills regions without coverage with 1.0 (transparent)
   - Normalizes the final result to a mean amplitude of ≈ 1.0
5. Subset Validation (Memory Optimization):
   - Uses only min(50, N) positions for the initial forward-model validation
   - Computes simulated_data via simple_microscope for the subset
   - Calculates the MSE between experimental and simulated patterns
   - Critical: full validation would OOM on large datasets (N > 200)
   - Subset MSE is representative: correlation with full MSE > 0.95
6. Reconstruction Assembly:
   - Packages sample_function, lightwave, positions, and optical parameters
   - Stores initial losses as [0, nan] (iteration 0, no prior MSE)
   - intermediate_* arrays have shape [..., 1] for iteration 0
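The padding-and-sharding step (stage 2) can be sketched with JAX's public sharding API (`jax.sharding.Mesh`, `NamedSharding`, `PartitionSpec`). The helper names below are hypothetical, and the sketch only places the data; it does not reproduce janssen's inversion or stitching. On a single-device machine it degenerates to no padding and a trivial mesh.

```python
import math

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def padded_count(n, n_devices):
    """Smallest multiple of n_devices that is >= n."""
    return math.ceil(n / n_devices) * n_devices


def shard_patterns(image_data):
    """Pad the position axis to a multiple of the device count and shard it
    along a 1-D "data" mesh. Returns (sharded_array, original_count)."""
    devices = jax.devices()
    n = image_data.shape[0]
    n_pad = padded_count(n, len(devices)) - n
    pad_width = [(0, n_pad)] + [(0, 0)] * (image_data.ndim - 1)
    padded = jnp.pad(image_data, pad_width)
    mesh = Mesh(np.array(devices), ("data",))
    # Shard only the leading (position) axis; other axes are replicated.
    spec = P("data", *([None] * (image_data.ndim - 1)))
    return jax.device_put(padded, NamedSharding(mesh, spec)), n


print(padded_count(400, 7))  # 406, as in the 7-GPU example
sharded, n = shard_patterns(jnp.ones((5, 4, 4)))
```

After a `jax.vmap` over the sharded axis, slicing `[:n]` drops the padded entries, mirroring the trimming step in stage 4.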
- Parameters:
experimental_data (MicroscopeData) – Experimental diffraction patterns with positions in meters. Shape of image_data: (N, H, W), where N is the number of positions.
probe_lightwave (OpticalWavefront) – The probe/lightwave used in the experiment.
zoom_factor (float | Float[Array, ""]) – Optical zoom factor for magnification.
aperture_diameter (float | Float[Array, ""]) – Diameter of the aperture in meters.
travel_distance (float | Float[Array, ""]) – Light propagation distance in meters.
camera_pixel_size (float | Float[Array, ""]) – Physical size of camera pixels in meters.
aperture_center (Float[Array, "2"], optional) – Center position of the aperture (x, y) in meters. Default is None (centered at origin).
padding (int, optional) – Additional padding in pixels around the scanned region. If None, defaults to half the probe size.
regularization (float, optional) – Small value for numerical stability in divisions. Default is 1e-6.
seed (int, optional) – Random seed for initial phase assignment. Default is 42.
- Returns:
reconstruction – Initial reconstruction state containing:
  - sample: Initialized sample function
  - lightwave: The input probe lightwave
  - translated_positions: Positions in FOV coordinates
  - All optical parameters
  - intermediate_* arrays with shape [..., 1] for iteration 0
  - losses array with shape (1, 2) containing [0, nan]
- Return type:
PtychographyReconstruction
Notes
Multi-GPU Performance:
Sharding overhead is ~50-100ms for pattern inversion
Speedup is near-linear with device count for N > 100
Memory per device: ~(N_total / N_devices) × pattern_size × 2 (factor of 2 for complex128)
Padding waste: at most (N_devices - 1) extra inversions, typically <2% overhead
Subset Validation Rationale:
Without subset validation, initialization OOMs for N > 200 on 16 GB GPUs:
- Full forward model: N × H × W × 4 bytes (float32)
- For N = 400, H = W = 256: 400 × 256 × 256 × 4 ≈ 100 MB (tractable)
- Peak memory during vmap: ~3× output size = 300 MB per device
- Compilation overhead: ~2-5× runtime memory = 1.5 GB per device
- Total: ~2 GB per device for the subset (50 positions)
- Full validation would require ~16 GB per device → OOM
Subset validation provides a sufficient quality check:
- Detects gross errors in optical parameters (wrong travel_distance or zoom_factor)
- MSE > 1.0 indicates a parameter mismatch
- MSE < 0.1 indicates good initialization
- Correlation between subset and full MSE: r > 0.95 empirically
Design Decisions:
NamedSharding chosen over legacy PositionalSharding for JAX 0.7.1+ compatibility
P(“data”, None, None) shards only position dimension, broadcasts probe parameters
Subset size of 50 chosen empirically: large enough for statistical validity, small enough to avoid OOM
Regularization = 1e-6 provides numerical stability without biasing results
Random phase initialization critical: enables gradient descent from iteration 1 onward
- janssen.invert.create_loss_function(forward_function: Callable[[...], Array], experimental_data: Array, loss_type: str = 'mae') → Callable[[...], Float[Array, '']]
Create a JIT-compatible loss function.
This function returns a new function that computes the loss between the output of a forward model and experimental data. The returned function is JIT-compatible and can be used with various optimization algorithms.
- Parameters:
forward_function (Callable[..., Array]) – The forward model function (e.g., stem_4d).
experimental_data (Array) – The experimental data to compare against.
loss_type (str, optional) – The type of loss to use: "mae" (mean absolute error), "mse" (mean squared error), or "rmse" (root mean squared error). Default is "mae".
- Returns:
loss_fn – A JIT-compatible function that computes the loss given the model parameters and any additional arguments required by the forward function.
- Return type:
Callable[[PyTree, ...], Float[Array, ""]]
Notes
1. Define internal loss functions (mae_loss, mse_loss, rmse_loss).
2. Select the appropriate loss function based on loss_type.
3. Create a JIT-compiled function that:
   - computes the forward model output,
   - calculates the difference between model and experimental data, and
   - applies the selected loss function.
4. Return the compiled loss function.
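A minimal sketch of such a factory follows. `make_loss` is a hypothetical name, and it assumes the forward function takes the parameters first; applying `jnp.abs` before squaring lets the same code handle complex-valued residuals, which may differ from janssen's choice.

```python
from typing import Callable

import jax
import jax.numpy as jnp


def make_loss(forward_function: Callable, experimental_data, loss_type="mae"):
    """Return a jitted loss(params, *args) comparing the forward model
    output with experimental_data (hypothetical sketch)."""
    reductions = {
        "mae": lambda d: jnp.mean(jnp.abs(d)),
        "mse": lambda d: jnp.mean(jnp.abs(d) ** 2),
        "rmse": lambda d: jnp.sqrt(jnp.mean(jnp.abs(d) ** 2)),
    }
    if loss_type not in reductions:
        raise ValueError(f"unknown loss_type: {loss_type!r}")
    reduce = reductions[loss_type]  # chosen in Python, outside the trace

    @jax.jit
    def loss_fn(params, *args):
        prediction = forward_function(params, *args)
        return reduce(prediction - experimental_data)

    return loss_fn


data = jnp.array([1.0, 2.0, 3.0])
mse = make_loss(lambda p: 2.0 * p, data, "mse")
loss_at_zero = mse(jnp.zeros(3))          # mean of [1, 4, 9]
grad_at_fit = jax.grad(mse)(data / 2.0)   # zeros: exact fit
```

Selecting the reduction at factory time keeps the string `loss_type` out of the traced function, so `jax.jit` compiles a single branch.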
- janssen.invert.coherence_parameterized_loss(coherence_width: Float[Array, ''], sample: Complex[Array, 'H W'], diffraction_patterns: Float[Array, 'N H W'], positions: Float[Array, 'N 2'], beam_width: Float[Array, ''], wavelength: Float[Array, ''], dx: Float[Array, ''], num_modes: int | Int[Array, ''] = 10) → Float[Array, '']
Loss function with coherence width as the optimizable parameter.
This enables gradient-based recovery of the source coherence:
$$\sigma_c^{*} = \arg\min_{\sigma_c} L(\sigma_c;\, I_{\mathrm{measured}})$$
- Parameters:
coherence_width (Float[Array, ""]) – Coherence width to optimize.
sample (Complex[Array, "H W"]) – Object (fixed or jointly optimized).
diffraction_patterns (Float[Array, "N H W"]) – Measured data.
positions (Float[Array, "N 2"]) – Scan positions.
beam_width (Float[Array, ""]) – Known beam intensity width.
wavelength (Float[Array, ""]) – Wavelength.
dx (Float[Array, ""]) – Pixel size.
num_modes (int) – Number of Gaussian Schell-model (GSM) modes. Default is 10.
- Returns:
loss – Scalar loss.
- Return type:
Float[Array, ""]
Notes
The gradient ∂L/∂σ_c flows through:
1. σ_c → eigenvalues λₙ (analytical for GSM)
2. σ_c → mode width w_mode → mode shapes φₙ
3. modes → forward model → loss
Automatic differentiation handles this whole chain, so no hand-derived update rules are needed.
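To illustrate the first link in that chain, the sketch below computes normalized one-dimensional Gaussian Schell-model mode weights from the classic closed-form eigenvalue ratio and differentiates a toy loss with respect to the coherence width. Two assumptions to flag: `coherence_width` is taken here to be the GSM coherence length σ_g (janssen's parameterization is not confirmed by the source), and the toy loss compares weights directly instead of running the ptychographic forward model. `gsm_mode_weights` and `toy_loss` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def gsm_mode_weights(coherence_width, beam_width, num_modes=10):
    """Normalized 1-D Gaussian Schell-model mode weights (sketch).

    Uses lambda_n proportional to q**n with q = b / (a + b + c), where
    a = 1/(4 sigma_s^2), b = 1/(2 sigma_g^2), c = sqrt(a^2 + 2ab).
    """
    a = 1.0 / (4.0 * beam_width**2)
    b = 1.0 / (2.0 * coherence_width**2)
    c = jnp.sqrt(a**2 + 2.0 * a * b)
    q = b / (a + b + c)  # in (0, 1); q -> 0 as the beam becomes coherent
    weights = q ** jnp.arange(num_modes)
    return weights / weights.sum()  # truncate to num_modes and renormalize


def toy_loss(coherence_width, target_weights, beam_width):
    # Stand-in for the forward model: compare weights directly.
    w = gsm_mode_weights(coherence_width, beam_width, target_weights.shape[0])
    return jnp.sum((w - target_weights) ** 2)


grad_fn = jax.grad(toy_loss)  # d(loss)/d(coherence_width)

target = gsm_mode_weights(jnp.array(1.0), jnp.array(2.0), num_modes=8)
g = grad_fn(jnp.array(0.5), target, jnp.array(2.0))  # finite, nonzero
```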
- janssen.invert.mixed_state_forward(data: MixedStatePtychoData) → Float[Array, 'N H W']
Compute predicted diffraction patterns for all positions.
- Parameters:
data (MixedStatePtychoData) – Ptychography data with probe modes and object.
- Returns:
predicted – Predicted diffraction intensities.
- Return type:
Float[Array, "N H W"]
- janssen.invert.mixed_state_forward_single_position(probe_modes: Complex[Array, 'M H W'], mode_weights: Float[Array, 'M'], obj: Complex[Array, 'H W'], shift_x: Float[Array, ''], shift_y: Float[Array, '']) → Float[Array, 'H W']
Forward model for one scan position with mixed-state illumination.
- Computes:
$$I = \sum_n w_n \left|\mathrm{FFT}\left(P_n \cdot \mathrm{shift}(O, r_i)\right)\right|^2$$
where P_n are the probe modes, w_n the mode weights, O the object, and r_i the scan position.
- Parameters:
probe_modes (Complex[Array, "M H W"]) – Coherent probe modes.
mode_weights (Float[Array, "M"]) – Mode weights (eigenvalues).
obj (Complex[Array, "H W"]) – Object transmission.
shift_x (Float[Array, ""]) – x component of the scan position in pixels.
shift_y (Float[Array, ""]) – y component of the scan position in pixels.
- Returns:
intensity – Diffraction intensity (incoherent sum over modes).
- Return type:
Float[Array, "H W"]
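A self-contained sketch of this forward model follows, including a `jax.vmap` over positions in the spirit of `mixed_state_forward`. It assumes the object and probe share the same H × W grid and implements `shift` as a circular Fourier-domain shift that moves the object content by +r (matching `jnp.roll` for integer shifts); janssen's sign convention and boundary handling may differ. All function names are hypothetical.

```python
import jax
import jax.numpy as jnp


def fourier_shift(field, shift_x, shift_y):
    """Circularly shift a 2-D field by (shift_x, shift_y) pixels using the
    Fourier shift theorem; sub-pixel shifts are allowed."""
    h, w = field.shape
    ky = jnp.fft.fftfreq(h)[:, None]
    kx = jnp.fft.fftfreq(w)[None, :]
    ramp = jnp.exp(-2j * jnp.pi * (kx * shift_x + ky * shift_y))
    return jnp.fft.ifft2(jnp.fft.fft2(field) * ramp)


def forward_single(probe_modes, mode_weights, obj, shift_x, shift_y):
    """I = sum_n w_n |FFT(P_n * shift(O, r))|^2 for one scan position."""
    exit_waves = probe_modes * fourier_shift(obj, shift_x, shift_y)[None]
    far_fields = jnp.fft.fft2(exit_waves, axes=(-2, -1))
    # Incoherent sum: weight each mode's intensity, then add.
    return jnp.einsum("m,mhw->hw", mode_weights, jnp.abs(far_fields) ** 2)


def forward_all(probe_modes, mode_weights, obj, positions):
    """vmap forward_single over an (N, 2) array of (x, y) positions."""
    return jax.vmap(forward_single, in_axes=(None, None, None, 0, 0))(
        probe_modes, mode_weights, obj, positions[:, 0], positions[:, 1]
    )


h = w = 8
obj = jnp.exp(1j * jnp.arange(h * w, dtype=jnp.float32).reshape(h, w) / 10.0)
modes = jnp.stack([jnp.ones((h, w)), 0.5 * jnp.ones((h, w))]).astype(jnp.complex64)
weights = jnp.array([0.7, 0.3])
positions = jnp.array([[0.0, 0.0], [1.0, 2.0], [2.5, 0.5]])
patterns = forward_all(modes, weights, obj, positions)  # shape (3, 8, 8)
```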
- janssen.invert.mixed_state_gradient_step(data: MixedStatePtychoData, learning_rate: float = 0.001, update_object: bool = True, update_modes: bool = True, update_weights: bool = False) → MixedStatePtychoData
Single gradient descent step for mixed-state reconstruction.
- Parameters:
data (MixedStatePtychoData) – Current state.
learning_rate (float) – Step size. Default is 0.001.
update_object (bool) – Whether to update the object. Default is True.
update_modes (bool) – Whether to update the probe mode fields. Default is True.
update_weights (bool) – Whether to update the mode weights. Default is False.
- Returns:
updated_data – State after one gradient step.
- Return type:
MixedStatePtychoData
- janssen.invert.mixed_state_loss(data: MixedStatePtychoData, loss_type: str = 'amplitude') → Float[Array, '']
Compute reconstruction loss for mixed-state ptychography.
- Parameters:
data (MixedStatePtychoData) – Current state with probe modes and object.
loss_type (str) – One of:
  - "amplitude": ||√I_exp - √I_pred||² (default, robust)
  - "intensity": ||I_exp - I_pred||²
  - "poisson": Poisson negative log-likelihood
- Returns:
loss – Scalar loss value.
- Return type:
Float[Array, ""]
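The three loss options can be sketched as below. Summation over all pixels, the small `eps` guard, and dropping the constant log(I_exp!) term from the Poisson negative log-likelihood are assumptions; janssen may normalize differently. `mixed_state_style_loss` is a hypothetical name.

```python
import jax.numpy as jnp


def mixed_state_style_loss(i_exp, i_pred, loss_type="amplitude", eps=1e-8):
    """Sum-reduced loss between measured and predicted intensities (sketch)."""
    if loss_type == "amplitude":
        # eps keeps sqrt differentiable where an intensity is exactly zero.
        return jnp.sum((jnp.sqrt(i_exp + eps) - jnp.sqrt(i_pred + eps)) ** 2)
    if loss_type == "intensity":
        return jnp.sum((i_exp - i_pred) ** 2)
    if loss_type == "poisson":
        # Poisson NLL without the constant log(I_exp!) term.
        return jnp.sum(i_pred - i_exp * jnp.log(i_pred + eps))
    raise ValueError(f"unknown loss_type: {loss_type!r}")


measured = jnp.array([1.0, 4.0, 9.0])
predicted = jnp.array([0.0, 1.0, 4.0])
amp_loss = mixed_state_style_loss(measured, predicted)  # about 3.0
```

Comparing square roots compresses the dynamic range of diffraction patterns, which is one common reason the amplitude form is described as robust.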
- janssen.invert.mixed_state_reconstruct(data: MixedStatePtychoData, num_iterations: int | Int[Array, ''] = 100, learning_rate: float = 0.001, update_object: bool = True, update_modes: bool = True, update_weights: bool = False) → tuple[MixedStatePtychoData, Float[Array, 'I']]
Run mixed-state ptychography reconstruction.
- Parameters:
data (MixedStatePtychoData) – Initial state with probe modes and object estimate.
num_iterations (int) – Number of gradient descent iterations. Default is 100.
learning_rate (float) – Step size. Default is 0.001.
update_object (bool) – Whether to reconstruct the object. Default is True.
update_modes (bool) – Whether to reconstruct the probe modes. Default is True.
update_weights (bool) – Whether to learn the mode weights. Default is False.
- Returns:
final_data (MixedStatePtychoData) – Reconstructed state.
loss_history (Float[Array, "I"]) – Loss at each iteration.
- Return type:
tuple[MixedStatePtychoData, Float[Array, "I"]]
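The gradient step and the iteration loop can be sketched on a plain dict of parameters, with update flags that freeze selected entries. This is an illustration under assumptions, not the janssen implementation: the dict layout, `gradient_step`, and `reconstruct` are hypothetical, and plain gradient descent stands in for whatever update rule the library uses. One JAX-specific detail does carry over: for complex parameters `jax.grad` returns the conjugate of the steepest-ascent direction, so the step uses `jnp.conj` of the gradient.

```python
import jax
import jax.numpy as jnp


def gradient_step(params, loss_fn, learning_rate=1e-3, update=None):
    """One descent step on a dict of real or complex arrays.

    `update` maps parameter names to bools; names set to False stay frozen.
    Returns (new_params, loss evaluated before the step).
    """
    update = update or {}
    loss, grads = jax.value_and_grad(loss_fn)(params)
    new_params = {}
    for name, value in params.items():
        if update.get(name, True):
            # For complex inputs jax.grad returns the conjugate of the
            # steepest-ascent direction, hence the conj before stepping.
            value = value - learning_rate * jnp.conj(grads[name])
        new_params[name] = value
    return new_params, loss


def reconstruct(params, loss_fn, num_iterations=100, learning_rate=1e-3,
                update=None):
    """Iterate gradient_step with lax.scan; returns (final_params, losses)."""

    def body(carry, _):
        return gradient_step(carry, loss_fn, learning_rate, update)

    return jax.lax.scan(body, params, None, length=num_iterations)


target = jnp.array([1.0 + 2.0j, -1.0j], dtype=jnp.complex64)


def toy_loss(p):
    d = p["obj"] - target
    # real(d * conj(d)) equals |d|^2 but stays smooth at d == 0.
    return jnp.sum(jnp.real(d * jnp.conj(d))) + jnp.sum(p["weights"] ** 2)


init = {"obj": jnp.zeros(2, dtype=jnp.complex64), "weights": jnp.array([0.5, 0.5])}
final, history = reconstruct(
    init, toy_loss, num_iterations=200, learning_rate=0.1,
    update={"obj": True, "weights": False},
)
```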
- janssen.invert.profile_gn_memory(experimental_data: MicroscopeData, reconstruction: PtychographyReconstruction, cg_maxiter: int = 5, verbose: bool = True) → dict
Profile memory usage during Gauss-Newton optimization.
Tracks memory allocation at key stages to diagnose OOM issues. Works with GPU, TPU, and CPU backends (memory statistics are available on GPU and TPU only). Useful for comparing actual XLA memory behavior with theoretical predictions.
- Parameters:
experimental_data (MicroscopeData) – Experimental diffraction patterns.
reconstruction (PtychographyReconstruction) – Initial reconstruction state.
cg_maxiter (int, optional) – CG iterations to test. Start low (3-5) to avoid OOM. Default is 5.
verbose (bool, optional) – Print detailed memory snapshots. Default is True.
- Returns:
memory_profile – Dictionary with keys:
  - 'baseline': Memory before the GN step
  - 'after_warmup': Memory after warmup compilation
  - 'after_gn': Memory after 1 GN iteration
  - 'peak_per_device_gb': Peak memory per device (float or None)
  - 'succeeded': Whether profiling completed without OOM
- Return type:
dict
Notes
Memory profiling is supported on:
  - GPU (CUDA, ROCm): full support via device.memory_stats()
  - TPU: full support via device.memory_stats()
  - CPU: limited support (profiling runs but memory statistics are unavailable)
On unsupported platforms, profiling still runs to test for OOM, but peak_per_device_gb will be None.
Examples
>>> profile = profile_gn_memory(data, init_recon, cg_maxiter=3)
>>> print(f"Peak memory: {profile['peak_per_device_gb']:.2f} GB")
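Reading per-device peak memory can be sketched with `jax.Device.memory_stats()`, which returns allocator statistics on GPU and TPU and typically `None` on CPU. The `peak_bytes_in_use` key, the GB conversion (10⁹ bytes), and the helper name are assumptions about how such a profile could be assembled, not a description of the internals of `profile_gn_memory`.

```python
import jax


def peak_memory_gb(device=None):
    """Peak bytes in use on `device` in GB (1e9 bytes), or None when the
    backend reports no allocator statistics (typically CPU)."""
    device = device if device is not None else jax.devices()[0]
    try:
        stats = device.memory_stats()
    except (NotImplementedError, RuntimeError):
        return None
    if not stats or "peak_bytes_in_use" not in stats:
        return None
    return stats["peak_bytes_in_use"] / 1e9


per_device = {str(d): peak_memory_gb(d) for d in jax.devices()}
```

Because the value can be `None`, a caller formatting it with `:.2f`, as in the example above, should check for `None` first when running on CPU.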
- janssen.invert.simple_microscope_epie(experimental_data: MicroscopeData, reconstruction: PtychographyReconstruction, params: PtychographyParams) → PtychographyReconstruction
Ptychographic reconstruction using extended PIE algorithm.
High-level orchestration function that preprocesses data, runs the FFT-based ePIE algorithm, and returns results in PtychographyReconstruction format. Supports resuming from previous reconstructions.
- Parameters:
experimental_data (MicroscopeData) – Experimental diffraction patterns collected at different positions. Positions should be in meters.
reconstruction (PtychographyReconstruction) – Previous reconstruction state from init_simple_microscope or a previous call. Contains sample, lightwave, positions, optical parameters, and intermediate history.
params (PtychographyParams) – Optimization parameters:
  - learning_rate: Controls the ePIE step size (alpha parameter)
  - num_iterations: Number of complete sweeps over all positions
  - camera_pixel_size: Physical size of camera pixels in meters
- Returns:
reconstruction – Updated reconstruction with:
sample: Final optimized sample
lightwave: Final optimized probe/lightwave
translated_positions: Unchanged from input
Optical parameters: Unchanged from input
intermediate_*: Previous history + new iterations appended
losses: Previous history + new iterations appended
- Return type:
PtychographyReconstruction
Notes
Workflow
1. Preprocess data using init_simple_epie (scales to FFT coordinates).
2. If resuming, use the previous sample/probe as the starting point.
3. Run _sm_epie_core for the requested iterations.
4. Convert the results back to PtychographyReconstruction format.
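For orientation, the textbook ePIE update at one scan position (Maiden and Rodenburg) can be sketched as below. This is an illustration under stated assumptions, not _sm_epie_core: epie_update is a hypothetical name, the probe step size beta is an assumed extra parameter, alpha plays the role that learning_rate plays here, and janssen's FFT scaling and position handling may differ.

```python
import numpy as np


def epie_update(obj, probe, pattern, top_left, alpha=1.0, beta=1.0, eps=1e-12):
    """One textbook ePIE update at a single scan position (not janssen's code)."""
    r, c = top_left
    h, w = probe.shape
    patch = obj[r:r + h, c:c + w].copy()
    exit_wave = probe * patch
    far_field = np.fft.fft2(exit_wave)
    # Fourier-magnitude constraint: keep the phase, impose the measured amplitude
    far_field = np.sqrt(pattern) * far_field / (np.abs(far_field) + eps)
    diff = np.fft.ifft2(far_field) - exit_wave
    # Object step (alpha) and probe step (beta), each normalized by the other's max power
    new_obj = obj.copy()
    new_obj[r:r + h, c:c + w] = patch + alpha * np.conj(probe) * diff / (np.abs(probe).max() ** 2 + eps)
    new_probe = probe + beta * np.conj(patch) * diff / (np.abs(patch).max() ** 2 + eps)
    return new_obj, new_probe


# Consistency check: a pattern generated from the current estimates leaves them unchanged
rng = np.random.default_rng(0)
obj = np.exp(1j * rng.uniform(0.0, 1.0, (16, 16)))
probe = rng.normal(size=(8, 8)) + 1j * rng.normal(size=(8, 8))
pattern = np.abs(np.fft.fft2(probe * obj[4:12, 2:10])) ** 2
new_obj, new_probe = epie_update(obj, probe, pattern, (4, 2))
print(np.max(np.abs(new_obj - obj)), np.max(np.abs(new_probe - probe)))
```

One sweep applies this update sequentially over all positions; num_iterations counts such sweeps.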
Resume Support
When prev_losses has entries, the function uses the existing sample and probe from the reconstruction as the starting point instead of the freshly initialized values from init_simple_epie.
See also
init_simple_microscope – Create initial reconstruction state.
init_simple_epie – Preprocessing for FFT-compatible ePIE.
_sm_epie_core – Core ePIE algorithm.
simple_microscope_ptychography – Gradient-based reconstruction.
- janssen.invert.simple_microscope_gn(experimental_data: MicroscopeData, reconstruction: PtychographyReconstruction, num_iterations: int = 10, initial_damping: float = 0.001, cg_maxiter: int = -1, cg_tol: float = -1.0, save_every: int = 10) → PtychographyReconstruction
Perform ptychographic reconstruction using Gauss-Newton optimization.
Uses second-order Gauss-Newton optimization with Levenberg-Marquardt damping for ptychography reconstruction. Solves the nonlinear least-squares problem:
min_{sample, probe} 0.5 * ||sqrt(I_exp) - sqrt(I_pred)||^2
where I_exp are the experimental diffraction patterns and I_pred are the patterns simulated by the forward model. The Jacobian is never formed explicitly: JAX's autodiff supplies Jacobian-vector products, and the damped normal equations are solved with conjugate gradient.
The function automatically includes warm-up compilation for large problems (>50 positions) and uses memory-optimized defaults for conjugate gradient to fit in 16GB GPU memory. Multi-GPU sharding is inherited from simple_microscope via the residual function.
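The least-squares objective above can be sketched in jax.numpy as follows. forward_intensity is a hypothetical toy far-field model standing in for simple_microscope (no propagation, pixel size, or sharding); the residual is flattened to the N × H × W vector that the Gauss-Newton solver consumes.

```python
import jax.numpy as jnp


def forward_intensity(patches, probe):
    """Toy far-field model: |FFT2(probe * patch)|^2 for N patches of shape (N, H, W)."""
    return jnp.abs(jnp.fft.fft2(probe[None, :, :] * patches)) ** 2


def amplitude_residuals(intensity_exp, intensity_pred):
    """Flattened sqrt(I_exp) - sqrt(I_pred), a vector of length N * H * W."""
    return (jnp.sqrt(intensity_exp) - jnp.sqrt(intensity_pred)).ravel()


def amplitude_loss(intensity_exp, intensity_pred):
    """0.5 * ||sqrt(I_exp) - sqrt(I_pred)||^2."""
    r = amplitude_residuals(intensity_exp, intensity_pred)
    return 0.5 * jnp.sum(r ** 2)


i_exp = jnp.array([[[4.0, 9.0]]])  # N=1, H=1, W=2
i_pred = jnp.array([[[1.0, 9.0]]])
print(amplitude_loss(i_exp, i_pred))  # residuals (1, 0) -> 0.5
```

When differentiating through this loss, note that the derivative of sqrt is unbounded at zero intensity, so a small epsilon is often added inside the square root in practice.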
Implementation Logic
The optimization follows a four-stage pipeline:
1. Residual Function Definition:
- _amplitude_residuals(params) unflattens params into sample and probe
- Calls simple_microscope to compute simulated patterns (automatically sharded)
- Computes amplitude residuals: sqrt(I_exp) - sqrt(I_pred)
- Returns a flattened residual vector of length N × H × W
2. Warm-up Compilation (automatic for N > 50):
- Creates warmup_data with the first 25 positions
- Defines _warmup_residuals using the warm-up subset
- Runs gn_solve for a single iteration (max_iterations=1)
- Critical: triggers JIT compilation on a small problem
- Compilation time: ~30 s for warm-up vs 5+ minutes for the full problem
- Memory during compilation: ~2× runtime memory
- After warm-up, the full problem reuses the compiled kernels
3. Chunked Gauss-Newton Optimization:
- Runs gn_loss_history in chunks of size save_every
- Each GN iteration:
  a. Computes residuals r(θ) and loss = 0.5 * ||r||^2
  b. Forms the (J^T J + λI) operator via jtj_matvec
  c. Solves (J^T J + λI) δ = -J^T r via conjugate gradient
  d. Updates parameters: θ_new = θ + δ
  e. Adapts damping λ based on a trust-region criterion
- Stores the full per-iteration loss history across all chunks
- Stores sample/probe snapshots only at chunk boundaries (every save_every iterations)
- Reduces memory compared to storing the full state history at every step
4. Result Packaging:
- Converts GaussNewtonState to PtychographyReconstruction
- Unflattens the final parameters into sample and probe
- Appends sparse intermediate_* snapshots and the full loss history
- Returns the updated reconstruction
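Steps 3a to 3e can be sketched with jax.jvp and jax.vjp supplying the products with J and J^T, and jax.scipy.sparse.linalg.cg solving the damped normal equations. This is a minimal illustration, not janssen's gn_solve or gn_loss_history: gn_lm_step and gn_lm_solve are hypothetical names, the parameter vector is assumed real (complex sample/probe values would first be split into real and imaginary parts), and a simple halve-on-success, double-on-failure damping rule stands in for the trust-region criterion.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg


def gn_lm_step(residual_fn, theta, damping, cg_maxiter=10, cg_tol=1e-3):
    """Solve (J^T J + damping*I) delta = -J^T r matrix-free and return theta + delta."""
    r, vjp_fn = jax.vjp(residual_fn, theta)
    grad = vjp_fn(r)[0]  # J^T r

    def jtj_matvec(v):
        jv = jax.jvp(residual_fn, (theta,), (v,))[1]  # J v
        return vjp_fn(jv)[0] + damping * v  # J^T (J v) + damping * v

    delta, _ = cg(jtj_matvec, -grad, tol=cg_tol, maxiter=cg_maxiter)
    return theta + delta, 0.5 * jnp.sum(r ** 2)


def gn_lm_solve(residual_fn, theta, num_iterations=10, damping=1e-3, cg_maxiter=10, cg_tol=1e-3):
    """Damped Gauss-Newton loop; losses[k] is the loss at the start of iteration k."""
    def loss_fn(t):
        return 0.5 * jnp.sum(residual_fn(t) ** 2)

    losses = []
    for _ in range(num_iterations):
        candidate, loss = gn_lm_step(residual_fn, theta, damping, cg_maxiter, cg_tol)
        if loss_fn(candidate) < loss:
            theta, damping = candidate, damping * 0.5  # accept: trust the GN model more
        else:
            damping = damping * 2.0  # reject: lean toward gradient descent
        losses.append(float(loss))
    return theta, losses


# Usage on a small linear least-squares problem r(theta) = A theta - b
A = jnp.array([[2.0, 0.0, 1.0], [0.0, 1.0, 0.0], [1.0, 1.0, 3.0], [0.5, 0.0, 0.0], [0.0, 2.0, 1.0]])
b = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
theta, losses = gn_lm_solve(lambda t: A @ t - b, jnp.zeros(3), cg_maxiter=50, cg_tol=1e-5)
print(theta, losses[0], losses[-1])
```

Each CG iteration costs one JVP and one VJP through the residual function, which is why cg_maxiter drives both runtime and (per the Notes) memory.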
- Parameters:
experimental_data (MicroscopeData) – Experimental diffraction patterns and scan positions.
reconstruction (PtychographyReconstruction) – Initial reconstruction state from init_simple_microscope.
num_iterations (int, optional) – Number of Gauss-Newton iterations. Default is 10.
initial_damping (float, optional) – Initial Levenberg-Marquardt damping parameter λ. Default is 1e-3. Adapts automatically based on step quality.
cg_maxiter (int, optional) – Maximum conjugate gradient iterations per GN step. Default is -1, which calculates an optimal value automatically via optimal_cg_params based on problem size and available memory. Set to a positive value to override the automatic calculation.
cg_tol (float, optional) – CG convergence tolerance. Default is -1.0, which calculates an optimal value automatically via optimal_cg_params. Set to a positive value to override the automatic calculation.
save_every (int, optional) – Save sample/probe snapshots in the intermediate history once every save_every iterations. The full loss history is still recorded at every iteration. Default is 10.
- Returns:
reconstruction – Updated reconstruction with optimized sample and lightwave.
- Return type:
PtychographyReconstruction
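The save_every behaviour (loss every iteration, snapshots only at chunk boundaries) can be sketched in plain Python. run_with_snapshots and the toy step function halve are hypothetical; whether the final, shorter chunk also produces a snapshot is an assumption of this sketch.

```python
def run_with_snapshots(step_fn, state, num_iterations, save_every=10):
    """Record the loss every iteration; keep a state snapshot only at chunk boundaries."""
    losses, snapshots = [], []
    for start in range(0, num_iterations, save_every):
        for _ in range(min(save_every, num_iterations - start)):
            state, loss = step_fn(state)
            losses.append(loss)
        # Assumption: the final, possibly shorter chunk also ends with a snapshot
        snapshots.append(state)
    return state, losses, snapshots


def halve(x):
    """Toy step: the 'loss' is the current value, which is then halved."""
    return x * 0.5, x


final_state, losses, snapshots = run_with_snapshots(halve, 1.0, num_iterations=25, save_every=10)
print(len(losses), len(snapshots))  # 25 3
```

Snapshot memory therefore grows with num_iterations / save_every rather than num_iterations, while the loss curve keeps full resolution.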
Notes
Warm-up Compilation Rationale:
Without warm-up, JIT compilation happens during the first GN iteration at the full problem size:
- Compilation memory: ~2× runtime memory for N=400 positions
- For 256×256 diffractograms on 7 GPUs: ~4 GB per device
- 16GB GPUs: OOM during compilation
- Compilation time: 5-7 minutes for the full problem
With warm-up (25 positions):
- Compilation memory: ~2× runtime memory for N=25
- For 256×256 diffractograms: ~250 MB per device → no OOM
- Compilation time: ~30 seconds
- The full problem reuses the compiled kernels (only shape-dependent operations are recompiled)
Total time saved: 4-6 minutes
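The warm-up threshold logic, and the way JAX caches a jitted function after its first call, can be sketched as follows. warmup_subset and amplitude_sum are hypothetical stand-ins, not janssen internals. The trace counter works because Python side effects inside a jitted function run only while JAX traces it. JAX keys its compilation cache on input shapes and dtypes, so this sketch demonstrates reuse only for repeated same-shape calls.

```python
import jax
import jax.numpy as jnp

trace_count = {"n": 0}


@jax.jit
def amplitude_sum(patterns):
    trace_count["n"] += 1  # Python side effect: runs only when JAX traces the function
    return jnp.sum(jnp.sqrt(patterns))


def warmup_subset(patterns, threshold=50, size=25):
    """First `size` positions when the problem exceeds `threshold`, else None."""
    return patterns[:size] if patterns.shape[0] > threshold else None


patterns = jnp.ones((400, 8, 8))
subset = warmup_subset(patterns)
amplitude_sum(subset)  # first call for shape (25, 8, 8): trace + compile
amplitude_sum(subset)  # same shape and dtype: served from the compilation cache
print(subset.shape, trace_count["n"])  # (25, 8, 8) 1
```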
Automatic CG Parameter Optimization:
By default (cg_maxiter=-1, cg_tol=-1.0), the function calculates CG parameters automatically via optimal_cg_params based on:
- Problem size (number of positions, sample/probe dimensions)
- Available GPU memory (assumes 16GB per device)
- Number of devices detected via jax.devices()
This keeps the solver within memory while maximizing accuracy. Pass positive values to override them manually.
Memory-Optimized CG Behavior:
Conjugate gradient memory scales with maxiter:
- Each CG iteration stores the residual, direction, and Ap vectors
- Memory per iteration: ~3 × parameter_size
- For sample (512×512) + probe (256×256): ~2.5 GB per CG iteration
- cg_maxiter=50: ~125 GB peak memory → OOM on 16GB GPUs
- cg_maxiter=20: ~50 GB peak memory → still OOM on 16GB GPUs
- cg_maxiter=10: ~25 GB peak memory → fits with 7-GPU sharding
Quality impact of reduced maxiter:
- cg_maxiter=50, tol=1e-5: δ accurate to ~1e-5
- cg_maxiter=10, tol=1e-3: δ accurate to ~1e-3
- GN convergence: relative MSE decrease per iteration ~5-10%
- Impact: ~2-4 extra GN iterations to reach the same MSE
- Tradeoff: 80% memory reduction for 20-40% more GN iterations
Multi-GPU Sharding:
Sharding is automatic via simple_microscope in _amplitude_residuals:
- The forward model distributes positions across devices
- Jacobian-vector products (jvp/vjp) respect the sharding
- CG operates on sharded vectors, so memory is distributed
- No explicit sharding code is needed in this function
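As an illustration of splitting the scan-position axis across devices, the following sketch uses jax.sharding directly. It is not how simple_microscope shards internally; the mesh axis name "positions" and the requirement that N be a multiple of the device count are assumptions of this sketch. On a single CPU it runs with one shard.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

devices = np.array(jax.devices())
mesh = Mesh(devices, ("positions",))
# Split the leading scan-position axis across devices; each device holds whole H x W patterns
sharding = NamedSharding(mesh, PartitionSpec("positions", None, None))

num_positions = 4 * len(devices)  # assumption: N is a multiple of the device count
patterns = jax.device_put(jnp.ones((num_positions, 16, 16)), sharding)


@jax.jit
def total_amplitude(x):
    # jit compiles against the input sharding; XLA inserts the cross-device reduction
    return jnp.sum(jnp.sqrt(x))


total = total_amplitude(patterns)
print(len(patterns.addressable_shards), float(total))
```

Because jit propagates input shardings through jvp and vjp as well, downstream derivative computations inherit the same split without extra annotations.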
Design Decisions:
Amplitude residuals (sqrt(I)) rather than intensity residuals (I): better conditioning, noise model closer to Poisson
Warm-up threshold of 50 positions chosen empirically: problems smaller than 50 compile fast enough without warm-up
Warm-up size of 25 positions: 50% of threshold, large enough for stable compilation
Single warm-up iteration (max_iterations=1): only need compilation, not convergence
Automatic CG parameter optimization (cg_maxiter=-1, cg_tol=-1.0): enabled by default to eliminate manual tuning for novice users
Sentinel values allow expert users to override with manual settings when needed
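The sentinel convention can be sketched as a small resolver. resolve_cg_params is hypothetical, and auto_params is a placeholder for optimal_cg_params whose real heuristic is not reproduced here; the values returned by fake_auto are illustrative only.

```python
def resolve_cg_params(cg_maxiter, cg_tol, auto_params):
    """Replace non-positive sentinels with automatic values; keep positive overrides.

    auto_params is a hypothetical zero-argument stand-in for optimal_cg_params
    returning (maxiter, tol); it is only called when a sentinel needs resolving.
    """
    if cg_maxiter > 0 and cg_tol > 0:
        return cg_maxiter, cg_tol
    auto_maxiter, auto_tol = auto_params()
    maxiter = cg_maxiter if cg_maxiter > 0 else auto_maxiter
    tol = cg_tol if cg_tol > 0 else auto_tol
    return maxiter, tol


def fake_auto():
    return 10, 1e-3  # illustrative values only, not optimal_cg_params output


print(resolve_cg_params(-1, -1.0, fake_auto))  # (10, 0.001)
print(resolve_cg_params(15, 1e-4, fake_auto))  # (15, 0.0001)
```

Resolving each sentinel independently lets a user pin cg_maxiter while still letting the tolerance be chosen automatically.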
See also
optimal_cg_params – Calculate optimal CG parameters for your problem.
simple_microscope_optim – First-order gradient-based optimization.
simple_microscope_epie – Extended PIE algorithm.
gn_solve – General-purpose Gauss-Newton solver.
gn_loss_history – GN solver with loss-only history.
Examples
Basic usage with automatic CG parameter optimization (default):
>>> data = MicroscopeData(...)
>>> init_recon = init_simple_microscope(data, ...)
>>> # CG parameters are calculated automatically from problem size and memory
>>> final_recon = simple_microscope_gn(data, init_recon, num_iterations=20)
Manual CG parameter override (for expert users):
>>> final_recon = simple_microscope_gn(
...     data, init_recon, num_iterations=20,
...     cg_maxiter=15, cg_tol=1e-4
... )
- janssen.invert.simple_microscope_optim(experimental_data: MicroscopeData, reconstruction: PtychographyReconstruction, params: PtychographyParams) → PtychographyReconstruction
Continue ptychographic reconstruction from a previous state.
Reconstructs a sample from experimental diffraction patterns using gradient-based optimization. Takes a PtychographyReconstruction (from init_simple_microscope or a previous call) and runs additional iterations, appending results to the intermediate arrays.
This enables resumable reconstruction: run 20 iterations, save the result, then later resume from iteration 21. Uses jax.lax.scan for efficient iteration and full JAX compatibility.
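The resumable pattern (run, save, continue, append to history) can be sketched with jax.lax.scan. continue_optimization is a hypothetical function, and plain gradient descent stands in for whatever optimizer_type selects; the point is that splitting 30 iterations into 20 + 10 reproduces a single 30-iteration run.

```python
import jax
import jax.numpy as jnp


def continue_optimization(loss_fn, params, prev_losses, num_iterations, learning_rate=0.1):
    """Run more gradient-descent steps with lax.scan and append losses to prior history."""
    value_and_grad = jax.value_and_grad(loss_fn)

    def step(p, _):
        loss, grad = value_and_grad(p)
        return p - learning_rate * grad, loss  # carry updated params, emit this step's loss

    params, new_losses = jax.lax.scan(step, params, xs=None, length=num_iterations)
    return params, jnp.concatenate([prev_losses, new_losses])


def quadratic(p):
    return jnp.sum((p - 3.0) ** 2)


# Run 20 iterations, "save", then resume for 10 more
p20, hist20 = continue_optimization(quadratic, jnp.zeros(2), jnp.zeros((0,)), 20)
p30, hist30 = continue_optimization(quadratic, p20, hist20, 10)
print(hist30.shape)  # (30,)
```

A stateful optimizer would also need its internal state (for example, moment estimates) carried across calls for the resumed run to match exactly.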
- Parameters:
experimental_data (MicroscopeData) – The experimental diffraction patterns collected at different positions. Positions should be in meters.
reconstruction (PtychographyReconstruction) – Previous reconstruction state from init_simple_microscope or a previous call to this function. Contains sample, lightwave, positions, optical parameters, and intermediate history.
params (PtychographyParams) – Optimization parameters including camera_pixel_size, num_iterations, learning_rate, loss_type, optimizer_type, and bounds for optical parameters.
- Returns:
reconstruction – Updated reconstruction with:
- sample: Final optimized sample
- lightwave: Final optimized probe/lightwave
- translated_positions: Unchanged from input
- Optical parameters: May be updated if bounds optimization is enabled
- intermediate_*: Previous history + new iterations appended
- losses: Previous history + new iterations appended
- Return type:
PtychographyReconstruction
See also
init_simple_microscope – Create initial reconstruction state.