"""Ptychography algorithms and optimization.
Extended Summary
----------------
High-level ptychography reconstruction algorithms that combine
optimization strategies with forward models. Provides complete reconstruction
pipelines for recovering complex-valued sample functions from intensity
measurements.
Routine Listings
----------------
optimal_cg_params : function
Calculate optimal conjugate gradient parameters for memory constraints
profile_gn_memory : function
Profile device (GPU/TPU) memory usage during Gauss-Newton optimization
simple_microscope_optim : function
Performs ptychography reconstruction using gradient-based optimization
simple_microscope_epie : function
Performs ptychography reconstruction using extended PIE algorithm
simple_microscope_gn : function
Performs ptychography reconstruction using Gauss-Newton optimization
_gn_state_to_ptychography_reconstruction : function, internal
Packs Gauss-Newton state and geometry into a PtychographyReconstruction.
Notes
-----
These functions provide complete reconstruction pipelines that can be
directly applied to experimental data. All functions support JAX
transformations and automatic differentiation for gradient-based optimization.
"""
import jax
import jax.numpy as jnp
import optax
from beartype import beartype
from beartype.typing import Callable, Tuple
from jax import lax
from jaxtyping import Array, Complex, Float, Int, jaxtyped
from janssen.scopes import simple_microscope
from janssen.types import (
EpieData,
GaussNewtonState,
MicroscopeData,
OpticalWavefront,
PtychographyParams,
PtychographyReconstruction,
SampleFunction,
make_epie_data,
make_gauss_newton_state,
make_optical_wavefront,
make_ptychography_reconstruction,
make_sample_function,
)
from janssen.utils import (
fourier_shift,
get_device_memory_gb,
gn_loss_history,
gn_solve,
unflatten_params,
)
from .initialization import init_simple_epie
from .loss_functions import create_loss_function
# Optimizer factories indexed by ``PtychographyParams.optimizer_type``
# (0 -> adam, 1 -> adagrad, 2 -> rmsprop, 3 -> sgd). Each entry is a
# constructor that is called with the learning rate to build the actual
# ``optax.GradientTransformation``; the entries are not transformations
# themselves.
OPTIMIZERS: Tuple[
    Callable[..., optax.GradientTransformation],
    Callable[..., optax.GradientTransformation],
    Callable[..., optax.GradientTransformation],
    Callable[..., optax.GradientTransformation],
] = (
    optax.adam,
    optax.adagrad,
    optax.rmsprop,
    optax.sgd,
)
# Loss names indexed by ``PtychographyParams.loss_type``; the selected
# name is passed to ``create_loss_function``.
LOSS_TYPES: Tuple[str, str, str] = ("mse", "mae", "poisson")
[docs]
@jaxtyped(typechecker=beartype)
def simple_microscope_optim(  # noqa: PLR0915
    experimental_data: MicroscopeData,
    reconstruction: PtychographyReconstruction,
    params: PtychographyParams,
) -> PtychographyReconstruction:
    """Continue ptychographic reconstruction from a previous state.

    Reconstructs a sample from experimental diffraction patterns using
    gradient-based optimization (optax). Takes a PtychographyReconstruction
    (from init_simple_microscope or a previous call) and runs additional
    iterations, appending results to the intermediate arrays.

    This enables resumable reconstruction: run 20 iterations, save the
    result, then later resume from iteration 21. Uses jax.lax.scan for
    the iteration loop.

    Only the complex sample field is optimized. The probe/lightwave and
    the optical parameters (zoom factor, aperture diameter, travel
    distance, aperture center) are carried through unchanged; the bounds
    in ``params`` are only applied by clipping those values inside the
    loss evaluation.

    Parameters
    ----------
    experimental_data : MicroscopeData
        The experimental diffraction patterns collected at different
        positions. Positions should be in meters.
    reconstruction : PtychographyReconstruction
        Previous reconstruction state from init_simple_microscope or
        a previous call to this function. Contains sample, lightwave,
        positions, optical parameters, and intermediate history.
    params : PtychographyParams
        Optimization parameters including camera_pixel_size, num_iterations,
        learning_rate, loss_type (index into LOSS_TYPES), optimizer_type
        (index into OPTIMIZERS), and bounds for optical parameters.

    Returns
    -------
    reconstruction : PtychographyReconstruction
        Updated reconstruction with:
        - sample : Final optimized sample
        - lightwave : Unchanged from input (not optimized)
        - translated_positions : Unchanged from input
        - Optical parameters : Unchanged from input (the stored values
          are not clipped; clipping happens only inside the loss)
        - intermediate_* : Previous history + new iterations appended
        - losses : Previous history + new ``[iteration, loss]`` rows,
          with iteration numbers continuing from the previous history.
          Each loss is evaluated at the sample *before* that iteration's
          update.

    Notes
    -----
    For a real-valued loss ``f`` of a complex argument ``z``, ``jax.grad``
    returns the complex conjugate of the steepest-ascent direction. The
    gradient is therefore conjugated before it is handed to the optax
    optimizer, so that the descent step is correct for both the real and
    the imaginary part of the sample.

    See Also
    --------
    init_simple_microscope : Create initial reconstruction state.
    """
    guess_sample: SampleFunction = reconstruction.sample
    guess_lightwave: OpticalWavefront = reconstruction.lightwave
    translated_positions: Float[Array, " N 2"] = (
        reconstruction.translated_positions
    )
    zoom_factor: Float[Array, " "] = reconstruction.zoom_factor
    aperture_diameter: Float[Array, " "] = reconstruction.aperture_diameter
    travel_distance: Float[Array, " "] = reconstruction.travel_distance
    # A missing aperture center defaults to (0, 0).
    aperture_center: Float[Array, " 2"] = (
        jnp.zeros(2)
        if reconstruction.aperture_center is None
        else reconstruction.aperture_center
    )
    prev_intermediate_samples: Complex[Array, " H W S"] = (
        reconstruction.intermediate_samples
    )
    prev_intermediate_lightwaves: Complex[Array, " H W S"] = (
        reconstruction.intermediate_lightwaves
    )
    prev_intermediate_zoom_factors: Float[Array, " S"] = (
        reconstruction.intermediate_zoom_factors
    )
    prev_intermediate_aperture_diameters: Float[Array, " S"] = (
        reconstruction.intermediate_aperture_diameters
    )
    prev_intermediate_aperture_centers: Float[Array, " 2 S"] = (
        reconstruction.intermediate_aperture_centers
    )
    prev_intermediate_travel_distances: Float[Array, " S"] = (
        reconstruction.intermediate_travel_distances
    )
    prev_losses: Float[Array, " N 2"] = reconstruction.losses
    camera_pixel_size: Float[Array, " "] = params.camera_pixel_size
    num_iterations: Int[Array, " "] = params.num_iterations
    learning_rate: Float[Array, " "] = params.learning_rate
    loss_type: Int[Array, " "] = params.loss_type
    optimizer_type: Int[Array, " "] = params.optimizer_type
    zoom_factor_bounds: Float[Array, " 2"] = params.zoom_factor_bounds
    aperture_diameter_bounds: Float[Array, " 2"] = (
        params.aperture_diameter_bounds
    )
    travel_distance_bounds: Float[Array, " 2"] = params.travel_distance_bounds
    aperture_center_bounds: Float[Array, " 2 2"] = (
        params.aperture_center_bounds
    )
    # New loss rows are numbered after the existing history so that a
    # resumed run continues the iteration count.
    start_iteration: Int[Array, " "] = jnp.array(
        prev_losses.shape[0], dtype=jnp.int64
    )
    # These values select shapes, table entries and the optimizer
    # constructor, so they are converted to concrete Python values here.
    num_iterations_int: int = int(num_iterations)
    sample_dx: Float[Array, " "] = guess_sample.dx
    guess_sample_field: Complex[Array, " H W"] = guess_sample.sample
    loss_type_str: str = LOSS_TYPES[int(loss_type)]

    def _forward_fn(
        sample_field: Complex[Array, " H W"],
        lightwave_field: Complex[Array, " H W"],
        zf: Float[Array, " "],
        ad: Float[Array, " "],
        td: Float[Array, " "],
        ac: Float[Array, " 2"],
    ) -> Float[Array, " N H W"]:
        """Simulate the diffraction images for the given fields/params."""
        sample: SampleFunction = make_sample_function(
            sample=sample_field, dx=sample_dx
        )
        lightwave: OpticalWavefront = make_optical_wavefront(
            field=lightwave_field,
            wavelength=guess_lightwave.wavelength,
            dx=guess_lightwave.dx,
            z_position=guess_lightwave.z_position,
        )
        simulated_data: MicroscopeData = simple_microscope(
            sample=sample,
            positions=translated_positions,
            lightwave=lightwave,
            zoom_factor=zf,
            aperture_diameter=ad,
            travel_distance=td,
            camera_pixel_size=camera_pixel_size,
            aperture_center=ac,
        )
        return simulated_data.image_data

    loss_func: Callable[..., Float[Array, " "]] = create_loss_function(
        _forward_fn, experimental_data.image_data, loss_type_str
    )

    def _compute_loss(
        sample_field: Complex[Array, " H W"],
        lightwave_field: Complex[Array, " H W"],
        zf: Float[Array, " "],
        ad: Float[Array, " "],
        td: Float[Array, " "],
        ac: Float[Array, " 2"],
    ) -> Float[Array, " "]:
        """Evaluate the loss with optical parameters clipped to bounds."""
        bounded_zf: Float[Array, " "] = jnp.clip(
            zf, zoom_factor_bounds[0], zoom_factor_bounds[1]
        )
        bounded_ad: Float[Array, " "] = jnp.clip(
            ad, aperture_diameter_bounds[0], aperture_diameter_bounds[1]
        )
        bounded_td: Float[Array, " "] = jnp.clip(
            td, travel_distance_bounds[0], travel_distance_bounds[1]
        )
        bounded_ac: Float[Array, " 2"] = jnp.clip(
            ac, aperture_center_bounds[0], aperture_center_bounds[1]
        )
        return loss_func(
            sample_field,
            lightwave_field,
            bounded_zf,
            bounded_ad,
            bounded_td,
            bounded_ac,
        )

    optimizer: optax.GradientTransformation = OPTIMIZERS[int(optimizer_type)](
        float(learning_rate)
    )
    sample_opt_state: optax.OptState = optimizer.init(guess_sample_field)
    sample_field: Complex[Array, " H W"] = guess_sample_field
    lightwave_field: Complex[Array, " H W"] = guess_lightwave.field

    def _scan_body(
        carry: Tuple[
            Complex[Array, " H W"],
            Complex[Array, " H W"],
            Float[Array, " "],
            Float[Array, " "],
            Float[Array, " "],
            Float[Array, " 2"],
            optax.OptState,
        ],
        _iteration: Int[Array, " "],
    ) -> Tuple[
        Tuple[
            Complex[Array, " H W"],
            Complex[Array, " H W"],
            Float[Array, " "],
            Float[Array, " "],
            Float[Array, " "],
            Float[Array, " 2"],
            optax.OptState,
        ],
        Tuple[
            Complex[Array, " H W"],
            Complex[Array, " H W"],
            Float[Array, " "],
            Float[Array, " "],
            Float[Array, " "],
            Float[Array, " 2"],
            Float[Array, " "],
        ],
    ]:
        """One optimizer step on the sample field; other state passes."""
        sf, lf, zf, ad, td, ac, opt_state = carry
        loss_val, grad = jax.value_and_grad(_compute_loss, argnums=0)(
            sf, lf, zf, ad, td, ac
        )
        # jax.grad of a real loss w.r.t. a complex argument returns the
        # conjugate of the steepest-ascent direction, while optax applies
        # ``params - lr * grad``-style updates. Conjugate so the imaginary
        # part also descends. This is a no-op for real-valued fields.
        descent_grad = jnp.conj(grad)
        updates, new_opt_state = optimizer.update(descent_grad, opt_state, sf)
        new_sf = optax.apply_updates(sf, updates)
        new_carry = (new_sf, lf, zf, ad, td, ac, new_opt_state)
        output = (new_sf, lf, zf, ad, td, ac, loss_val)
        return new_carry, output

    init_carry = (
        sample_field,
        lightwave_field,
        zoom_factor,
        aperture_diameter,
        travel_distance,
        aperture_center,
        sample_opt_state,
    )
    iterations: Int[Array, " N"] = jnp.arange(
        num_iterations_int, dtype=jnp.int64
    )
    final_carry, outputs = lax.scan(_scan_body, init_carry, iterations)
    (
        intermediate_samples_new,
        intermediate_lightwaves_new,
        intermediate_zoom_factors_new,
        intermediate_aperture_diameters_new,
        intermediate_travel_distances_new,
        intermediate_aperture_centers_new,
        losses_new,
    ) = outputs
    # lax.scan stacks along axis 0; the reconstruction stores the
    # iteration axis last.
    intermediate_samples: Complex[Array, " H W S"] = jnp.transpose(
        intermediate_samples_new, (1, 2, 0)
    )
    intermediate_lightwaves: Complex[Array, " H W S"] = jnp.transpose(
        intermediate_lightwaves_new, (1, 2, 0)
    )
    intermediate_aperture_centers: Float[Array, " 2 S"] = jnp.transpose(
        intermediate_aperture_centers_new, (1, 0)
    )
    iteration_numbers: Float[Array, " N"] = start_iteration + jnp.arange(
        num_iterations_int, dtype=jnp.float64
    )
    losses: Float[Array, " N 2"] = jnp.stack(
        [iteration_numbers, losses_new], axis=1
    )
    (
        final_sample_field,
        final_lightwave_field,
        current_zoom_factor,
        current_aperture_diameter,
        current_travel_distance,
        current_aperture_center,
        _,
    ) = final_carry
    final_sample: SampleFunction = make_sample_function(
        sample=final_sample_field, dx=sample_dx
    )
    final_lightwave: OpticalWavefront = make_optical_wavefront(
        field=final_lightwave_field,
        wavelength=guess_lightwave.wavelength,
        dx=guess_lightwave.dx,
        z_position=guess_lightwave.z_position,
    )
    combined_intermediate_samples: Complex[Array, " H W S"] = jnp.concatenate(
        [prev_intermediate_samples, intermediate_samples], axis=-1
    )
    combined_intermediate_lightwaves: Complex[Array, " H W S"] = (
        jnp.concatenate(
            [prev_intermediate_lightwaves, intermediate_lightwaves], axis=-1
        )
    )
    combined_intermediate_zoom_factors: Float[Array, " S"] = jnp.concatenate(
        [prev_intermediate_zoom_factors, intermediate_zoom_factors_new],
        axis=-1,
    )
    combined_intermediate_aperture_diameters: Float[Array, " S"] = (
        jnp.concatenate(
            [
                prev_intermediate_aperture_diameters,
                intermediate_aperture_diameters_new,
            ],
            axis=-1,
        )
    )
    combined_intermediate_aperture_centers: Float[Array, " 2 S"] = (
        jnp.concatenate(
            [
                prev_intermediate_aperture_centers,
                intermediate_aperture_centers,
            ],
            axis=-1,
        )
    )
    combined_intermediate_travel_distances: Float[Array, " S"] = (
        jnp.concatenate(
            [
                prev_intermediate_travel_distances,
                intermediate_travel_distances_new,
            ],
            axis=-1,
        )
    )
    combined_losses: Float[Array, " N 2"] = jnp.concatenate(
        [prev_losses, losses], axis=0
    )
    full_and_intermediate: PtychographyReconstruction = (
        make_ptychography_reconstruction(
            sample=final_sample,
            lightwave=final_lightwave,
            translated_positions=translated_positions,
            zoom_factor=current_zoom_factor,
            aperture_diameter=current_aperture_diameter,
            aperture_center=current_aperture_center,
            travel_distance=current_travel_distance,
            intermediate_samples=combined_intermediate_samples,
            intermediate_lightwaves=combined_intermediate_lightwaves,
            intermediate_zoom_factors=combined_intermediate_zoom_factors,
            intermediate_aperture_diameters=(
                combined_intermediate_aperture_diameters
            ),
            intermediate_aperture_centers=(
                combined_intermediate_aperture_centers
            ),
            intermediate_travel_distances=(
                combined_intermediate_travel_distances
            ),
            losses=combined_losses,
        )
    )
    return full_and_intermediate
def _sm_epie_core(
    epie_data: EpieData,
    iterations: Int[Array, " N"],
    alpha: float = 1.0,
    beta: float = 1.0,
) -> EpieData:
    """Run sequential ePIE sweeps using Fourier-domain probe shifting.

    The object, the probe and every diffraction pattern share one grid
    size. Instead of cropping the object at each scan position, the probe
    is translated onto the object with ``fourier_shift`` (a phase ramp in
    Fourier space), which permits sub-pixel scan positions.

    Parameters
    ----------
    epie_data : EpieData
        Preprocessed data (e.g. from init_simple_epie) providing:
        - diffraction_patterns: measured intensities, one per position
          (N, H, W)
        - probe: starting probe (H, W)
        - sample: starting object estimate on the probe's grid (H, W)
        - positions: per-position shifts in pixels relative to the
          array center
    iterations : Int[Array, " N"]
        Scanned over with ``lax.scan``; its length sets the number of
        full sweeps. The values themselves are not used.
    alpha : float, optional
        Object update step size. Default is 1.0.
    beta : float, optional
        Probe update step size. Default is 1.0. Passing 0 leaves the
        probe unchanged, so only the object is refined.

    Returns
    -------
    EpieData
        ``epie_data`` with ``sample`` and ``probe`` replaced by the
        refined estimates; every other field is passed through.

    Notes
    -----
    Each sweep visits the positions in order, feeding the updated object
    and probe into the next position. At a position with shift (sx, sy):

    1. P_s = shift(P, sx, sy)
    2. psi = O * P_s
    3. Psi = fftshift(FFT(psi))
    4. Psi' = Psi * sqrt(max(I, 0)) / (|Psi| + eps)
    5. psi' = IFFT(ifftshift(Psi'))
    6. O <- O + alpha * conj(P_s) * (psi' - psi) / (max|P_s|^2 + eps)
    7. P <- P + shift(beta * conj(O_old) * (psi' - psi)
                      / (max|O_old|^2 + eps), -sx, -sy)

    The probe correction uses the object from *before* step 6 and is
    shifted back by (-sx, -sy) so the probe stays in its unshifted frame.
    """
    eps: float = 1e-8
    patterns: Float[Array, " N H W"] = epie_data.diffraction_patterns
    scan_shifts: Float[Array, " N 2"] = epie_data.positions

    def _modulus_projection(
        wave: Complex[Array, " H W"],
        intensity: Float[Array, " H W"],
    ) -> Complex[Array, " H W"]:
        """Impose the measured Fourier modulus on ``wave``, keep its phase."""
        spectrum = jnp.fft.fftshift(jnp.fft.fft2(wave))
        target_modulus = jnp.sqrt(jnp.maximum(intensity, 0.0))
        corrected = spectrum * target_modulus / (jnp.abs(spectrum) + eps)
        return jnp.fft.ifft2(jnp.fft.ifftshift(corrected))

    def _pie_correction(
        step: float,
        weight: Complex[Array, " H W"],
        residual: Complex[Array, " H W"],
    ) -> Complex[Array, " H W"]:
        """Return step * conj(weight) * residual / (max|weight|^2 + eps)."""
        peak_intensity = jnp.max(jnp.abs(weight) ** 2)
        return step * jnp.conj(weight) * residual / (peak_intensity + eps)

    def _visit_position(
        state: Tuple[Complex[Array, " H W"], Complex[Array, " H W"]],
        measurement: Tuple[Float[Array, " H W"], Float[Array, " 2"]],
    ) -> Tuple[
        Tuple[Complex[Array, " H W"], Complex[Array, " H W"]],
        None,
    ]:
        """Apply the object and probe corrections for one scan position."""
        obj, probe = state
        intensity, shift = measurement
        sx = shift[0]
        sy = shift[1]
        illumination = fourier_shift(probe, sx, sy)
        exit_wave = obj * illumination
        residual = _modulus_projection(exit_wave, intensity) - exit_wave
        obj_step = _pie_correction(alpha, illumination, residual)
        # Built from the pre-update object, then moved back to the
        # probe's own frame.
        probe_step = fourier_shift(
            _pie_correction(beta, obj, residual), -sx, -sy
        )
        return (obj + obj_step, probe + probe_step), None

    def _sweep(
        state: Tuple[Complex[Array, " H W"], Complex[Array, " H W"]],
        _sweep_index: Int[Array, " "],
    ) -> Tuple[
        Tuple[Complex[Array, " H W"], Complex[Array, " H W"]],
        None,
    ]:
        """Visit every scan position once, in order."""
        swept_state, _per_position = lax.scan(
            _visit_position, state, (patterns, scan_shifts)
        )
        return swept_state, None

    (refined_sample, refined_probe), _per_sweep = lax.scan(
        _sweep, (epie_data.sample, epie_data.probe), iterations
    )
    return make_epie_data(
        diffraction_patterns=patterns,
        probe=refined_probe,
        sample=refined_sample,
        positions=scan_shifts,
        effective_dx=epie_data.effective_dx,
        wavelength=epie_data.wavelength,
        original_camera_pixel_size=epie_data.original_camera_pixel_size,
        zoom_factor=epie_data.zoom_factor,
    )
[docs]
@jaxtyped(typechecker=beartype)
def simple_microscope_epie(  # noqa: PLR0914, PLR0915
    experimental_data: MicroscopeData,
    reconstruction: PtychographyReconstruction,
    params: PtychographyParams,
) -> PtychographyReconstruction:
    """Ptychographic reconstruction using the extended PIE algorithm.

    High-level orchestration function that preprocesses data, runs the
    FFT-based ePIE algorithm, and returns results in
    PtychographyReconstruction format. Supports resuming from a previous
    reconstruction whose sample lives on the ePIE grid.

    Parameters
    ----------
    experimental_data : MicroscopeData
        Experimental diffraction patterns collected at different positions.
        Positions should be in meters.
    reconstruction : PtychographyReconstruction
        Previous reconstruction state from init_simple_microscope or a
        previous call. Contains sample, lightwave, positions, optical
        parameters, and intermediate history.
    params : PtychographyParams
        Optimization parameters:
        - learning_rate: ePIE object step size (alpha); the probe step
          size (beta) is fixed at 1.0
        - num_iterations: Number of complete sweeps over all positions
        - camera_pixel_size: Physical size of camera pixels in meters

    Returns
    -------
    reconstruction : PtychographyReconstruction
        Updated reconstruction with:
        - sample: Final ePIE sample estimate
        - lightwave: Final ePIE probe estimate
        - translated_positions: ePIE scan positions (pixels, from
          init_simple_epie) multiplied by the effective pixel size
        - Optical parameters: Unchanged from input
        - intermediate_*: One entry per new iteration. ePIE does not
          record per-iteration states, so every new entry is a copy of
          the final state. Previous history is kept only when resuming
          (see Notes).
        - losses: Rows ``[iteration, 0.0]``. ePIE does not compute a
          loss here, so the loss column is a zero placeholder. When
          resuming, iteration numbers continue after the previous history.

    Notes
    -----
    **Workflow**
    1. Preprocess data using init_simple_epie (scales to FFT coordinates)
    2. If resuming, use previous sample/probe as starting point
    3. Run _sm_epie_core for the requested iterations
    4. Convert results back to PtychographyReconstruction format

    **Resume Support**
    The call counts as a resume when ``reconstruction.sample`` has the
    same shape as the sample produced by init_simple_epie. In that case
    the existing sample and probe are the starting point, and the previous
    intermediate/loss history is kept and extended. Otherwise the freshly
    initialized values from init_simple_epie are used, and the previous
    history is discarded (its sample shape would not be compatible).

    See Also
    --------
    init_simple_microscope : Create initial reconstruction state.
    init_simple_epie : Preprocessing for FFT-compatible ePIE.
    _sm_epie_core : Core ePIE algorithm.
    simple_microscope_optim : Gradient-based reconstruction.
    """
    guess_sample: SampleFunction = reconstruction.sample
    guess_lightwave: OpticalWavefront = reconstruction.lightwave
    zoom_factor: Float[Array, " "] = reconstruction.zoom_factor
    aperture_diameter: Float[Array, " "] = reconstruction.aperture_diameter
    travel_distance: Float[Array, " "] = reconstruction.travel_distance
    # A missing aperture center defaults to (0, 0).
    aperture_center: Float[Array, " 2"] = (
        jnp.zeros(2)
        if reconstruction.aperture_center is None
        else reconstruction.aperture_center
    )
    prev_intermediate_samples: Complex[Array, " H W S"] = (
        reconstruction.intermediate_samples
    )
    prev_intermediate_lightwaves: Complex[Array, " H W S"] = (
        reconstruction.intermediate_lightwaves
    )
    prev_intermediate_zoom_factors: Float[Array, " S"] = (
        reconstruction.intermediate_zoom_factors
    )
    prev_intermediate_aperture_diameters: Float[Array, " S"] = (
        reconstruction.intermediate_aperture_diameters
    )
    prev_intermediate_aperture_centers: Float[Array, " 2 S"] = (
        reconstruction.intermediate_aperture_centers
    )
    prev_intermediate_travel_distances: Float[Array, " S"] = (
        reconstruction.intermediate_travel_distances
    )
    prev_losses: Float[Array, " N 2"] = reconstruction.losses
    camera_pixel_size: Float[Array, " "] = params.camera_pixel_size
    num_iterations: Int[Array, " "] = params.num_iterations
    alpha: Float[Array, " "] = params.learning_rate
    num_iterations_int: int = int(num_iterations)
    epie_data: EpieData = init_simple_epie(
        experimental_data=experimental_data,
        effective_dx=guess_sample.dx,
        wavelength=guess_lightwave.wavelength,
        zoom_factor=zoom_factor,
        aperture_diameter=aperture_diameter,
        travel_distance=travel_distance,
        camera_pixel_size=camera_pixel_size,
    )
    # Resume detection is shape-based: a sample already on the ePIE grid
    # can be used directly as the starting estimate.
    epie_sample_shape: tuple[int, int] = epie_data.sample.shape
    recon_sample_shape: tuple[int, int] = guess_sample.sample.shape
    is_epie_resume: bool = epie_sample_shape == recon_sample_shape
    if is_epie_resume:
        epie_data = make_epie_data(
            diffraction_patterns=epie_data.diffraction_patterns,
            probe=guess_lightwave.field,
            sample=guess_sample.sample,
            positions=epie_data.positions,
            effective_dx=epie_data.effective_dx,
            wavelength=epie_data.wavelength,
            original_camera_pixel_size=epie_data.original_camera_pixel_size,
            zoom_factor=epie_data.zoom_factor,
        )
    iterations: Int[Array, " N"] = jnp.arange(
        num_iterations_int, dtype=jnp.int64
    )
    result_epie: EpieData = _sm_epie_core(
        epie_data=epie_data,
        iterations=iterations,
        alpha=float(alpha),
    )
    final_sample_field: Complex[Array, " Hs Ws"] = result_epie.sample
    final_probe_field: Complex[Array, " H W"] = result_epie.probe
    output_dx: Float[Array, " "] = result_epie.effective_dx
    final_sample: SampleFunction = make_sample_function(
        sample=final_sample_field, dx=output_dx
    )
    final_lightwave: OpticalWavefront = make_optical_wavefront(
        field=final_probe_field,
        wavelength=guess_lightwave.wavelength,
        dx=output_dx,
        z_position=guess_lightwave.z_position,
    )
    # _sm_epie_core does not report a loss; record a zero placeholder.
    loss_val: Float[Array, " "] = jnp.array(0.0)
    sample_shape: tuple[int, ...] = (
        *final_sample_field.shape,
        num_iterations_int,
    )
    probe_shape: tuple[int, ...] = (
        *final_probe_field.shape,
        num_iterations_int,
    )
    # No per-iteration snapshots are available from the core loop, so the
    # final state is replicated once per new iteration.
    intermediate_samples: Complex[Array, " Hs Ws N"] = jnp.broadcast_to(
        final_sample_field[..., None], sample_shape
    )
    intermediate_lightwaves: Complex[Array, " H W N"] = jnp.broadcast_to(
        final_probe_field[..., None], probe_shape
    )
    intermediate_zoom_factors: Float[Array, " N"] = jnp.full(
        num_iterations_int, zoom_factor
    )
    intermediate_aperture_diameters: Float[Array, " N"] = jnp.full(
        num_iterations_int, aperture_diameter
    )
    intermediate_travel_distances: Float[Array, " N"] = jnp.full(
        num_iterations_int, travel_distance
    )
    intermediate_aperture_centers: Float[Array, " 2 N"] = jnp.broadcast_to(
        aperture_center[:, None], (2, num_iterations_int)
    )
    # When resuming, the new rows are appended to the previous history, so
    # numbering must continue after it (as simple_microscope_optim does).
    # A fresh start discards the history, so numbering restarts at 0.
    start_iteration: int = prev_losses.shape[0] if is_epie_resume else 0
    iteration_numbers: Float[Array, " N"] = start_iteration + jnp.arange(
        num_iterations_int, dtype=jnp.float64
    )
    losses_arr: Float[Array, " N"] = jnp.full(num_iterations_int, loss_val)
    losses: Float[Array, " N 2"] = jnp.stack(
        [iteration_numbers, losses_arr], axis=1
    )
    if is_epie_resume:
        intermediate_samples = jnp.concatenate(
            [prev_intermediate_samples, intermediate_samples], axis=-1
        )
        intermediate_lightwaves = jnp.concatenate(
            [prev_intermediate_lightwaves, intermediate_lightwaves], axis=-1
        )
        intermediate_zoom_factors = jnp.concatenate(
            [prev_intermediate_zoom_factors, intermediate_zoom_factors],
            axis=-1,
        )
        intermediate_aperture_diameters = jnp.concatenate(
            [
                prev_intermediate_aperture_diameters,
                intermediate_aperture_diameters,
            ],
            axis=-1,
        )
        intermediate_aperture_centers = jnp.concatenate(
            [
                prev_intermediate_aperture_centers,
                intermediate_aperture_centers,
            ],
            axis=-1,
        )
        intermediate_travel_distances = jnp.concatenate(
            [
                prev_intermediate_travel_distances,
                intermediate_travel_distances,
            ],
            axis=-1,
        )
        losses = jnp.concatenate([prev_losses, losses], axis=0)
    # ePIE positions are in pixels; convert back with the effective dx.
    positions_meters: Float[Array, " N 2"] = epie_data.positions * output_dx
    result: PtychographyReconstruction = make_ptychography_reconstruction(
        sample=final_sample,
        lightwave=final_lightwave,
        translated_positions=positions_meters,
        zoom_factor=zoom_factor,
        aperture_diameter=aperture_diameter,
        aperture_center=aperture_center,
        travel_distance=travel_distance,
        intermediate_samples=intermediate_samples,
        intermediate_lightwaves=intermediate_lightwaves,
        intermediate_zoom_factors=intermediate_zoom_factors,
        intermediate_aperture_diameters=intermediate_aperture_diameters,
        intermediate_aperture_centers=intermediate_aperture_centers,
        intermediate_travel_distances=intermediate_travel_distances,
        losses=losses,
    )
    return result
[docs]
def profile_gn_memory(
    experimental_data: MicroscopeData,
    reconstruction: PtychographyReconstruction,
    cg_maxiter: int = 5,
    verbose: bool = True,
) -> dict:
    """Profile memory usage during Gauss-Newton optimization.

    Runs a single Gauss-Newton iteation via ``simple_microscope_gn`` and
    records device memory statistics before and after it, to help
    diagnose OOM issues. Works with GPU, TPU, and CPU backends. Memory
    statistics are only collected where ``device.memory_stats()`` provides
    them (GPU and TPU). Useful for comparing actual XLA memory behavior
    with theoretical predictions.

    Parameters
    ----------
    experimental_data : MicroscopeData
        Experimental diffraction patterns
    reconstruction : PtychographyReconstruction
        Initial reconstruction state
    cg_maxiter : int, optional
        CG iterations to test. Start low (3-5) to avoid OOM. Default is 5.
    verbose : bool, optional
        Print detailed memory snapshots. Default is True.

    Returns
    -------
    memory_profile : dict
        Dictionary with keys:
        - 'baseline': Memory snapshot before the GN step
        - 'after_gn': Memory snapshot after 1 GN iteration, or taken at
          the point of failure if the GN step raised
        - 'peak_per_device_gb': Largest ``peak_bytes_in_use`` reported by
          any device in the snapshots, in GB (1e9 bytes), or None when no
          device reports memory statistics. Always present.
        - 'succeeded': Whether the GN step completed without raising
        - 'error': String form of the exception (only present when
          'succeeded' is False)

    Notes
    -----
    Memory profiling is supported on:
    - GPU (CUDA, ROCm): Full support via device.memory_stats()
    - TPU: Full support via device.memory_stats()
    - CPU: Limited support (profiling runs but memory stats unavailable)
    On unsupported platforms, profiling still runs to test for OOM,
    but peak_per_device_gb will be None.

    Examples
    --------
    >>> profile = profile_gn_memory(data, init_recon, cg_maxiter=3)
    >>> if profile["peak_per_device_gb"] is not None:
    ...     print(f"Peak memory: {profile['peak_per_device_gb']:.2f} GB")
    """

    def get_memory_snapshot(label: str) -> dict:
        """Capture memory across all devices (GPU/TPU/CPU)."""
        snapshot = {"label": label, "devices": []}
        supported_devices = 0
        for i, device in enumerate(jax.devices()):
            try:
                stats = device.memory_stats()
                snapshot["devices"].append(
                    {
                        "id": i,
                        "platform": device.platform,
                        "bytes": stats.get("bytes_in_use", 0),
                        "peak": stats.get("peak_bytes_in_use", 0),
                    }
                )
                supported_devices += 1
            except (AttributeError, NotImplementedError, KeyError):
                # AttributeError also covers memory_stats() returning None.
                snapshot["devices"].append(
                    {
                        "id": i,
                        "platform": device.platform,
                        "unsupported": True,
                    }
                )
        if verbose and supported_devices > 0:
            print(f"\n{label}:")
            total_gb = (
                sum(d.get("bytes", 0) for d in snapshot["devices"]) / 1e9
            )
            peak_gb = max(d.get("peak", 0) for d in snapshot["devices"]) / 1e9
            print(
                f" Total: {total_gb:.2f} GB, "
                f"Peak/device: {peak_gb:.2f} GB "
                f"({supported_devices}/{len(jax.devices())} devices)"
            )
        elif verbose and supported_devices == 0:
            platforms = {d["platform"] for d in snapshot["devices"]}
            print(f"\n{label}:")
            print(
                f" Memory profiling not supported on "
                f"{', '.join(platforms)}"
            )
        return snapshot

    profile: dict = {}
    profile["baseline"] = get_memory_snapshot("Baseline")
    try:
        if verbose:
            print(f"\nRunning GN with cg_maxiter={cg_maxiter}...")
        _ = simple_microscope_gn(
            experimental_data,
            reconstruction,
            num_iterations=1,
            cg_maxiter=cg_maxiter,
            cg_tol=1e-2,
        )
    except Exception as e:  # noqa: BLE001
        # Deliberately broad: the purpose of profiling is to survive and
        # report failures such as device OOM rather than crash.
        profile["after_gn"] = get_memory_snapshot("At failure")
        profile["succeeded"] = False
        profile["error"] = str(e)
        if verbose:
            print(f"\n✗ Profiling failed: {e}")
    else:
        profile["after_gn"] = get_memory_snapshot("After 1 GN iteration")
        profile["succeeded"] = True

    # Compute the peak on both paths, so the documented key is always
    # present. The peak at failure is exactly what OOM diagnosis needs.
    peak_values = [
        d["peak"]
        for snapshot in (profile["baseline"], profile["after_gn"])
        for d in snapshot["devices"]
        if "peak" in d
    ]
    profile["peak_per_device_gb"] = (
        max(peak_values) / 1e9 if peak_values else None
    )
    if verbose and profile["succeeded"]:
        if profile["peak_per_device_gb"] is not None:
            print("\n✓ Profiling succeeded!")
            peak_mem = profile["peak_per_device_gb"]
            print(f" Peak memory: {peak_mem:.2f} GB/device")
        else:
            print("\n✓ Profiling succeeded (memory stats unavailable)")
    return profile
@jaxtyped(typechecker=beartype)
def optimal_cg_params(
    experimental_data: MicroscopeData,
    reconstruction: PtychographyReconstruction,
    memory_per_device_gb: float = -1.0,
    safety_factor: float = 0.3,
) -> Tuple[int, float]:
    """Calculate optimal conjugate gradient parameters for memory constraints.

    Automatically determines cg_maxiter and cg_tol that will fit in
    available device memory while maintaining good convergence. Accounts
    for problem size (number of positions, sample/probe dimensions),
    device count, and memory constraints.

    Implementation Logic
    --------------------
    The calculation follows a four-step process:

    1. **Problem Size Analysis**:
       - Extracts sample dimensions (Hs, Ws) from reconstruction.sample
       - Extracts probe dimensions (Hp, Wp) from reconstruction.lightwave
       - Total parameters: (Hs × Ws) + (Hp × Wp)
       - Each parameter is budgeted as complex128 (16 bytes)

    2. **Memory Estimation**:
       - CG memory per iteration has two components:

         a. Parameter space: 3 × total_params × 16 bytes
            (x, r, p vectors in parameter space)
         b. Residual space: 40 × N × H × W × 4 bytes
            (empirical factor; see "Memory Model" in Notes)

       - Baseline memory: 2 × (N × H × W × 4 bytes) for diffractograms
       - Available memory:
         devices × memory_per_device × safety_factor − baseline

    3. **CG Iteration Calculation**:
       - Max iterations: floor(available_memory / memory_per_iteration)
       - Clamped to [5, 20] (min 5 for convergence, max 20 for
         conservative memory safety with compilation overhead)
       - If the available memory is not positive, cg_maxiter falls back
         to 5 and cg_tol to 1e-2

    4. **Tolerance Selection**:
       - Continuous log-linear relationship: tol = 10^(-(maxiter + 25)/15)
       - More iterations → tighter tolerance (lower value)
       - Fewer iterations → looser tolerance (higher value)
       - Over the reachable range: maxiter=5 → 1e-2,
         maxiter=10 → ≈4.6e-3, maxiter=20 → 1e-3

    Parameters
    ----------
    experimental_data : MicroscopeData
        Experimental diffraction patterns. ``image_data.shape`` is read as
        (N, H, W): number of positions and diffractogram dimensions.
    reconstruction : PtychographyReconstruction
        Reconstruction state containing sample and probe. Used to
        determine parameter dimensions.
    memory_per_device_gb : float, optional
        Available memory per device in GB. Default is -1.0; any negative
        value triggers automatic detection via get_device_memory_gb()
        (the second element of its returned pair is used). Typical GPU
        values: V100/RTX 6000 = 16 GB, A100 = 40 GB, H100 = 80 GB.
    safety_factor : float, optional
        Fraction of memory to use for CG (0-1). Default is 0.3 to leave
        headroom for compilation overhead (2-3× runtime memory),
        fragmentation, and peak allocations. Increase to 0.4-0.5 for
        problems with <100 positions or after warm-up compilation.

    Returns
    -------
    cg_maxiter : int
        Recommended maximum conjugate gradient iterations. Clamped to
        [5, 20].
    cg_tol : float
        Recommended CG convergence tolerance, 10^(-(cg_maxiter + 25)/15),
        which lies in [1e-3, 1e-2] for the clamped iteration range.

    Notes
    -----
    **Memory Model**:
    The memory estimation accounts for XLA's actual buffer allocation:

    - Theoretical minimum: ~6 residual evaluations per CG iteration
    - Empirical overhead: ~40× residual size per iteration (measured);
      this is the factor used in the code
    - Overhead sources: XLA buffer stacking, compilation temporaries,
      CG intermediate storage, sharding boundary copies
    - safety_factor=0.3 accounts for additional ~3× peak memory during
      compilation and runtime fluctuations
    - For very large problems (N > 500), reduce to 0.2 if needed
    - For small problems (<100 positions) after warm-up, increase to
      0.4-0.5 for better accuracy

    **Accuracy vs Memory Tradeoff**:
    Lower cg_maxiter means:

    - Less accurate Gauss-Newton steps (higher residual after CG)
    - More GN iterations needed to reach the same MSE
    - Faster per-iteration time (less CG work)
    - Lower memory usage (fewer intermediate vectors)

    Higher cg_maxiter means:

    - More accurate GN steps (lower residual after CG)
    - Fewer GN iterations to convergence
    - Slower per-iteration time (more CG work)
    - Higher memory usage (more intermediate vectors)

    **Example Calculations**:
    For N=400, H=W=256, sample=512×512, probe=256×256, 7 GPUs, 16GB each:

    - Total params: 512² + 256² = 327,680 complex128 = 5.24 MB
    - Param space memory/iter: 3 × 5.24 MB = 15.72 MB
    - Residual space memory/iter: 40 × 400 × 256 × 256 × 4 = 4.19 GB
    - CG memory/iter: 15.72 MB + 4.19 GB ≈ 4.21 GB
    - Baseline memory: 2 × (400 × 256 × 256 × 4) = 210 MB
    - Available: 7 × 16GB × 0.3 - 0.21 GB ≈ 33.4 GB
    - Max iters: 33.4 GB / 4.21 GB ≈ 7
    - Recommended: cg_maxiter=7, tol≈7.4e-3

    For the same problem with 1 GPU:

    - Available: 1 × 16GB × 0.3 - 0.21 GB ≈ 4.6 GB
    - Max iters: 4.6 GB / 4.21 GB ≈ 1
    - Recommended: cg_maxiter=5 (minimum), tol=1e-2

    For N=400, sample=1024×1024, probe=512×512, 7 GPUs:

    - Total params: 1024² + 512² = 1,310,720 complex128 = 21 MB
    - Param space: 3 × 21 MB = 63 MB, Residual space: 4.19 GB
    - CG memory/iter: 4.25 GB, Available: ≈ 33.4 GB
    - Max iters: 33.4 GB / 4.25 GB ≈ 7
    - Recommended: cg_maxiter=7, tol≈7.4e-3

    For N=400, sample=2048×2048, probe=512×512, 7 GPUs:

    - Total params: 2048² + 512² = 4,456,448 complex128 = 71.3 MB
    - Param space: 3 × 71.3 MB = 214 MB, Residual space: 4.19 GB
    - CG memory/iter: 4.40 GB, Available: ≈ 33.4 GB
    - Max iters: 33.4 GB / 4.40 GB ≈ 7
    - Recommended: cg_maxiter=7, tol≈7.4e-3

    For N=400, sample=4096×4096, probe=1024×1024, 7 GPUs:

    - Total params: 4096² + 1024² = 17,825,792 complex128 = 285 MB
    - Param space: 3 × 285 MB = 855 MB, Residual space: 4.19 GB
    - CG memory/iter: 5.05 GB, Available: ≈ 33.4 GB
    - Max iters: 33.4 GB / 5.05 GB ≈ 6
    - Recommended: cg_maxiter=6, tol≈8.6e-3

    For N=400, sample=8192×8192, probe=2048×2048, 7 GPUs, 16GB:

    - Total params: 8192² + 2048² = 71,303,168 complex128 = 1.14 GB
    - Param space: 3 × 1.14 GB = 3.42 GB, Residual space: 4.19 GB
    - CG memory/iter: 7.61 GB, Available: ≈ 33.4 GB
    - Max iters: 33.4 GB / 7.61 GB ≈ 4
    - Recommended: cg_maxiter=5 (minimum), tol=1e-2

    **Design Decisions**:

    - safety_factor=0.3 chosen empirically: provides reliable OOM
      avoidance accounting for ~3× compilation overhead
    - Clamp minimum to 5: CG needs at least a few iterations for any
      progress
    - Clamp maximum to 20: conservative cap for memory safety with
      compilation overhead; beyond 20, better to do more GN iterations
      than risk OOM
    - Continuous tolerance formula tol = 10^(-(maxiter + 25)/15) provides
      smooth scaling: more iterations enable tighter tolerances, fewer
      iterations require looser tolerances for CG convergence
    - If the per-iteration estimate is zero (degenerate zero-size
      inputs), memory is not a constraint and the cap of 20 is used
      instead of dividing by zero

    See Also
    --------
    simple_microscope_gn : Uses these parameters for optimization

    Examples
    --------
    >>> # Automatic device memory detection (default)
    >>> data = MicroscopeData(...)
    >>> init_recon = init_simple_microscope(data, ...)
    >>> cg_maxiter, cg_tol = optimal_cg_params(data, init_recon)
    >>> print(f"Recommended: cg_maxiter={cg_maxiter}, cg_tol={cg_tol}")
    >>> result = simple_microscope_gn(
    ...     data, init_recon, cg_maxiter=cg_maxiter, cg_tol=cg_tol
    ... )
    >>>
    >>> # Manual override for specific GPU memory
    >>> cg_maxiter, cg_tol = optimal_cg_params(
    ...     data, init_recon, memory_per_device_gb=40.0
    ... )
    """
    if memory_per_device_gb < 0:
        # get_device_memory_gb() returns a pair; the second element is the
        # per-device memory in GB used for the budget below.
        _, memory_per_device_gb = get_device_memory_gb()

    # Model constants (see "Memory Model" in the docstring).
    bytes_per_complex128: int = 16
    bytes_per_float32: int = 4
    param_space_vectors: int = 3
    residual_evaluations_per_iter: int = 40
    min_cg_iters: int = 5
    conservative_cap: int = 20

    num_positions: int = experimental_data.image_data.shape[0]
    diff_height: int = experimental_data.image_data.shape[1]
    diff_width: int = experimental_data.image_data.shape[2]
    sample_shape: Tuple[int, int] = reconstruction.sample.sample.shape
    probe_shape: Tuple[int, int] = reconstruction.lightwave.field.shape

    total_params: int = (
        sample_shape[0] * sample_shape[1] + probe_shape[0] * probe_shape[1]
    )
    param_memory_per_iter_bytes: float = (
        total_params * bytes_per_complex128 * param_space_vectors
    )
    residual_size: int = num_positions * diff_height * diff_width
    residual_memory_per_iter_bytes: float = (
        residual_size * bytes_per_float32 * residual_evaluations_per_iter
    )
    cg_memory_per_iter_gb: float = (
        param_memory_per_iter_bytes + residual_memory_per_iter_bytes
    ) / 1e9

    num_devices: int = len(jax.devices())
    total_memory_gb: float = memory_per_device_gb * num_devices
    # Forward model + gradients: about two diffractogram-sized buffers.
    diffractogram_memory_gb: float = (
        residual_size * bytes_per_float32
    ) / 1e9
    baseline_memory_gb: float = diffractogram_memory_gb * 2
    available_memory_gb: float = (
        total_memory_gb * safety_factor - baseline_memory_gb
    )

    cg_maxiter: int
    cg_tol: float
    if available_memory_gb <= 0:
        # No room left after the baseline: use the most conservative
        # setting (same values the tolerance formula gives at maxiter=5).
        cg_maxiter = min_cg_iters
        cg_tol = 1e-2
    else:
        if cg_memory_per_iter_gb <= 0:
            # Degenerate zero-size problem: memory is not a constraint.
            max_cg_iters: int = conservative_cap
        else:
            max_cg_iters = int(available_memory_gb / cg_memory_per_iter_gb)
        cg_maxiter = max(min_cg_iters, min(conservative_cap, max_cg_iters))
        cg_tol = 10.0 ** (-(cg_maxiter + 25.0) / 15.0)
    return (cg_maxiter, cg_tol)
@jaxtyped(typechecker=beartype)
def simple_microscope_gn(  # noqa: PLR0915
    experimental_data: MicroscopeData,
    reconstruction: PtychographyReconstruction,
    num_iterations: int = 10,
    initial_damping: float = 1e-3,
    cg_maxiter: int = -1,
    cg_tol: float = -1.0,
    save_every: int = 10,
) -> PtychographyReconstruction:
    """Perform ptychographic reconstruction using Gauss-Newton optimization.

    Uses second-order Gauss-Newton optimization with Levenberg-Marquardt
    damping for ptychography reconstruction. Solves the nonlinear
    least-squares problem:

        min_{sample, probe} 0.5 * ||sqrt(I_exp) - sqrt(I_pred)||^2

    where I_exp are experimental diffraction patterns and I_pred are
    simulated patterns from the forward model. Each Gauss-Newton step is
    solved with conjugate gradient by the gn_solve / gn_loss_history
    utilities.

    For large problems (>50 positions) a single warm-up Gauss-Newton
    iteration is first run on a 25-position subset. By default the CG
    parameters are chosen by optimal_cg_params from the detected device
    memory. Any multi-GPU sharding comes from simple_microscope, which is
    called inside the residual function.

    Implementation Logic
    --------------------
    The optimization follows a four-stage pipeline:

    1. **Residual Function Definition**:
       - A residual closure unflattens params into sample and probe
       - Calls simple_microscope to compute simulated patterns
       - Computes amplitude residuals: sqrt(I_exp) - sqrt(I_pred), with
         both intensities floored at 1e-12 before the square root
       - Returns the flattened residual vector of length N × H × W
       - The same closure factory builds both the full-problem and the
         warm-up residuals, so they cannot drift apart

    2. **Warm-up Compilation** (only when num_iterations > 0 and N > 50):
       - Builds a residual over the first 25 positions
       - Runs gn_solve for 1 iteration and discards the result
       - Intent: trigger JIT compilation on a small problem so that
         compilation memory/time stays bounded
       - NOTE(review): the main loop calls gn_loss_history with a
         different residual closure, so how much compiled code it
         actually reuses depends on those helpers' caching — confirm

    3. **Chunked Gauss-Newton Optimization**:
       - Runs gn_loss_history in chunks of ``save_every`` iterations
         (the last chunk may be shorter)
       - Each GN iteration (inside the solver):
         a. Computes residuals r(θ) and loss = 0.5 * ||r||^2
         b. Forms the damped (J^T J + λI) operator
         c. Solves (J^T J + λI) δ = -J^T r via conjugate gradient
         d. Updates parameters: θ_new = θ + δ
         e. Adapts damping λ based on step quality
       - Stores the full per-iteration loss history across all chunks
       - Stores sample/probe snapshots only at chunk boundaries, which
         keeps memory lower than storing full state every step

    4. **Result Packaging**:
       - Converts the final GaussNewtonState to a
         PtychographyReconstruction
       - Appends sparse intermediate_* snapshots and the full loss
         history to whatever history ``reconstruction`` already holds

    Parameters
    ----------
    experimental_data : MicroscopeData
        Experimental diffraction patterns; ``image_data`` is indexed as
        (N, H, W) and ``dx`` is used as the camera pixel size.
    reconstruction : PtychographyReconstruction
        Initial reconstruction state from init_simple_microscope.
    num_iterations : int, optional
        Number of Gauss-Newton iterations. Default is 10. A value <= 0
        runs no iterations (and no warm-up) and returns the input
        sample/probe with empty new history.
    initial_damping : float, optional
        Initial Levenberg-Marquardt damping parameter λ. Default is 1e-3.
        Adapts automatically based on step quality.
    cg_maxiter : int, optional
        Maximum conjugate gradient iterations per GN step. Default is -1;
        any negative value selects the value from optimal_cg_params.
        Pass a non-negative value to override.
    cg_tol : float, optional
        CG convergence tolerance. Default is -1.0; any negative value
        selects the value from optimal_cg_params. Pass a non-negative
        value to override.
    save_every : int, optional
        Save sample/probe snapshots in intermediate history once every
        ``save_every`` iterations. Full loss history is still recorded at
        every iteration. Default is 10.

    Returns
    -------
    reconstruction : PtychographyReconstruction
        Updated reconstruction with optimized sample and lightwave.

    Raises
    ------
    ValueError
        If ``save_every`` is not positive. Checked before any memory
        auto-detection or computation.

    Notes
    -----
    **Automatic CG Parameter Optimization**:
    With the defaults (cg_maxiter=-1, cg_tol=-1.0) optimal_cg_params is
    called, which uses:

    - Problem size (number of positions, sample/probe dimensions)
    - Per-device memory detected via get_device_memory_gb()
    - Number of devices reported by jax.devices()

    The resulting cg_maxiter lies in [5, 20] and cg_tol in [1e-3, 1e-2];
    see optimal_cg_params for the memory model. Only the sentinel
    (negative) arguments are replaced, so one of the two may be fixed
    manually while the other is tuned automatically.

    **Accuracy vs Memory**:
    CG memory grows with cg_maxiter, dominated for typical problems by
    residual-sized (N × H × W) buffers rather than parameter-sized ones.
    Lower cg_maxiter / looser cg_tol give less accurate GN steps, so more
    GN iterations may be needed to reach the same loss.

    **Multi-GPU Sharding**:
    No explicit sharding code is used here; any sharding is whatever
    simple_microscope applies inside the residual function.

    **Design Decisions**:

    - Amplitude residuals (sqrt(I)) rather than intensity residuals (I):
      better conditioning, noise model closer to Poisson
    - Warm-up threshold of 50 positions chosen empirically: smaller
      problems compile fast enough without warm-up
    - Warm-up size of 25 positions: 50% of the threshold
    - Single warm-up iteration: only compilation is needed, not
      convergence
    - Sentinel values (-1) enable automatic CG tuning by default while
      still allowing manual overrides

    See Also
    --------
    optimal_cg_params : Calculate optimal CG parameters for your problem
    simple_microscope_optim : First-order gradient-based optimization
    simple_microscope_epie : Extended PIE algorithm
    gn_solve : General-purpose Gauss-Newton solver
    gn_loss_history : GN solver with loss-only history

    Examples
    --------
    Basic usage with automatic CG parameter optimization (default):

    >>> data = MicroscopeData(...)
    >>> init_recon = init_simple_microscope(data, ...)
    >>> final_recon = simple_microscope_gn(data, init_recon, num_iterations=20)

    Manual CG parameter override:

    >>> final_recon = simple_microscope_gn(
    ...     data, init_recon, num_iterations=20,
    ...     cg_maxiter=15, cg_tol=1e-4
    ... )
    """
    # Validate cheap arguments before any memory auto-detection runs.
    if save_every <= 0:
        msg: str = f"save_every must be positive, got {save_every}"
        raise ValueError(msg)

    sample: SampleFunction = reconstruction.sample
    lightwave: OpticalWavefront = reconstruction.lightwave
    translated_positions: Float[Array, " N 2"] = (
        reconstruction.translated_positions
    )
    aperture_center: Float[Array, " 2"] = (
        jnp.zeros(2)
        if reconstruction.aperture_center is None
        else reconstruction.aperture_center
    )

    # Negative values are sentinels meaning "choose automatically".
    cg_maxiter_used: int = cg_maxiter
    cg_tol_used: float = cg_tol
    if cg_maxiter < 0 or cg_tol < 0:
        optimal_maxiter: int
        optimal_tol: float
        optimal_maxiter, optimal_tol = optimal_cg_params(
            experimental_data, reconstruction
        )
        if cg_maxiter < 0:
            cg_maxiter_used = optimal_maxiter
        if cg_tol < 0:
            cg_tol_used = optimal_tol

    sample_shape: Tuple[int, int] = sample.sample.shape
    probe_shape: Tuple[int, int] = lightwave.field.shape

    def _make_amplitude_residuals(
        positions: Float[Array, " P 2"],
        measured: Float[Array, " P H W"],
    ) -> Callable[[Float[Array, " n"]], Float[Array, " m"]]:
        """Build an amplitude-residual function over the given positions."""

        def _residuals(params: Float[Array, " n"]) -> Float[Array, " m"]:
            sample_field: Complex[Array, " Hs Ws"]
            probe_field: Complex[Array, " Hp Wp"]
            sample_field, probe_field = unflatten_params(
                params, sample_shape, probe_shape
            )
            sample_fn: SampleFunction = make_sample_function(
                sample=sample_field, dx=sample.dx
            )
            lightwave_fn: OpticalWavefront = make_optical_wavefront(
                field=probe_field,
                wavelength=lightwave.wavelength,
                dx=lightwave.dx,
                z_position=lightwave.z_position,
            )
            simulated: MicroscopeData = simple_microscope(
                sample=sample_fn,
                positions=positions,
                lightwave=lightwave_fn,
                zoom_factor=reconstruction.zoom_factor,
                aperture_diameter=reconstruction.aperture_diameter,
                travel_distance=reconstruction.travel_distance,
                camera_pixel_size=experimental_data.dx,
                aperture_center=aperture_center,
            )
            # Floor intensities before sqrt to keep gradients finite.
            amp_measured: Float[Array, " P H W"] = jnp.sqrt(
                jnp.maximum(measured, 1e-12)
            )
            amp_predicted: Float[Array, " P H W"] = jnp.sqrt(
                jnp.maximum(simulated.image_data, 1e-12)
            )
            return (amp_measured - amp_predicted).ravel()

        return _residuals

    amplitude_residuals: Callable[
        [Float[Array, " n"]], Float[Array, " m"]
    ] = _make_amplitude_residuals(
        translated_positions, experimental_data.image_data
    )

    state: GaussNewtonState = make_gauss_newton_state(
        sample=sample.sample,
        probe=lightwave.field,
        iteration=0,
        loss=jnp.inf,
        damping=initial_damping,
        converged=False,
    )

    num_positions: int = experimental_data.image_data.shape[0]
    warmup_threshold: int = 50
    warmup_positions: int = 25
    if num_iterations > 0 and num_positions > warmup_threshold:
        # One GN step on a small subset; the result is discarded.
        warmup_residuals: Callable[
            [Float[Array, " n"]], Float[Array, " m"]
        ] = _make_amplitude_residuals(
            translated_positions[:warmup_positions],
            experimental_data.image_data[:warmup_positions],
        )
        _ = gn_solve(
            state,
            warmup_residuals,
            max_iterations=1,
            cg_maxiter=cg_maxiter_used,
            cg_tol=cg_tol_used,
        )

    full_loss_chunks: list[Float[Array, " C"]] = []
    sample_snapshots: list[Complex[Array, " Hs Ws"]] = []
    probe_snapshots: list[Complex[Array, " Hp Wp"]] = []
    current_state: GaussNewtonState = state
    remaining: int = num_iterations
    while remaining > 0:
        this_chunk: int = min(save_every, remaining)
        current_state, chunk_losses = gn_loss_history(
            current_state,
            amplitude_residuals,
            max_iterations=this_chunk,
            cg_maxiter=cg_maxiter_used,
            cg_tol=cg_tol_used,
        )
        full_loss_chunks.append(chunk_losses)
        # Snapshot only at chunk boundaries to bound memory.
        sample_snapshots.append(current_state.sample)
        probe_snapshots.append(current_state.probe)
        remaining -= this_chunk

    if full_loss_chunks:
        all_losses: Float[Array, " N"] = jnp.concatenate(
            full_loss_chunks, axis=0
        )
        snapshot_samples: Complex[Array, " K Hs Ws"] = jnp.stack(
            sample_snapshots, axis=0
        )
        snapshot_probes: Complex[Array, " K Hp Wp"] = jnp.stack(
            probe_snapshots, axis=0
        )
    else:
        all_losses = jnp.zeros((0,), dtype=jnp.float64)
        snapshot_samples = jnp.zeros(
            (0, *sample_shape), dtype=jnp.complex128
        )
        snapshot_probes = jnp.zeros(
            (0, *probe_shape), dtype=jnp.complex128
        )

    return _gn_state_to_ptychography_reconstruction(
        current_state,
        snapshot_samples,
        snapshot_probes,
        all_losses,
        reconstruction,
        sample.dx,
    )
def _gn_state_to_ptychography_reconstruction(
    final_state: GaussNewtonState,
    snapshot_samples: Complex[Array, " K H W"],
    snapshot_probes: Complex[Array, " K H W"],
    all_losses: Float[Array, " N"],
    reconstruction: PtychographyReconstruction,
    sample_dx: Float[Array, " "],
) -> PtychographyReconstruction:
    """Pack Gauss-Newton state and geometry into a PtychographyReconstruction.

    Builds the final sample/lightwave from ``final_state``, stores the
    sparse sample/probe snapshots and the per-iteration loss history, and
    appends them to any history already held by ``reconstruction`` so
    that repeated (resumed) calls accumulate.

    Parameters
    ----------
    final_state : GaussNewtonState
        Solver state after the last iteration. Its ``sample`` and
        ``probe`` become the returned sample and lightwave field.
    snapshot_samples : Complex[Array, "K H W"]
        Sample snapshots stacked along axis 0 (K may be 0).
    snapshot_probes : Complex[Array, "K H W"]
        Probe snapshots stacked along axis 0, one per sample snapshot.
    all_losses : Float[Array, "N"]
        Loss value for each iteration of this call (N may be 0).
    reconstruction : PtychographyReconstruction
        Input reconstruction. Supplies the geometry (positions, zoom,
        aperture, travel distance), the wavelength/dx/z of the lightwave,
        and any previous intermediates and losses.
    sample_dx : Float[Array, " "]
        Pixel size for the returned sample function.

    Returns
    -------
    PtychographyReconstruction
        Reconstruction with the final sample/lightwave, geometry copied
        from ``reconstruction``, and intermediates stacked along the LAST
        axis (e.g. samples are (H, W, K_total)). ``losses`` rows are
        ``[iteration_number, loss]``.

    Notes
    -----
    Iteration numbers continue from ``reconstruction.losses[-1, 0] + 1``
    when previous losses exist, and start at 0 otherwise. Previous
    intermediates are concatenated only when previous losses exist.
    """
    lightwave: OpticalWavefront = reconstruction.lightwave
    translated_positions: Float[Array, " N 2"] = (
        reconstruction.translated_positions
    )
    zoom_factor: Float[Array, " "] = reconstruction.zoom_factor
    aperture_diameter: Float[Array, " "] = reconstruction.aperture_diameter
    travel_distance: Float[Array, " "] = reconstruction.travel_distance
    aperture_center: Float[Array, " 2"] = (
        jnp.zeros(2)
        if reconstruction.aperture_center is None
        else reconstruction.aperture_center
    )
    final_sample: SampleFunction = make_sample_function(
        sample=final_state.sample, dx=sample_dx
    )
    final_lightwave: OpticalWavefront = make_optical_wavefront(
        field=final_state.probe,
        wavelength=lightwave.wavelength,
        dx=lightwave.dx,
        z_position=lightwave.z_position,
    )

    num_iterations: int = int(all_losses.shape[0])
    num_snapshots: int = int(snapshot_samples.shape[0])

    if num_snapshots == 0:
        intermediate_samples: Complex[Array, " H W K"] = jnp.zeros(
            (*final_state.sample.shape, 0), dtype=jnp.complex128
        )
        intermediate_lightwaves: Complex[Array, " H W K"] = jnp.zeros(
            (*final_state.probe.shape, 0), dtype=jnp.complex128
        )
        intermediate_zoom_factors: Float[Array, " K"] = jnp.zeros(
            (0,), dtype=jnp.float64
        )
        intermediate_aperture_diameters: Float[Array, " K"] = jnp.zeros(
            (0,), dtype=jnp.float64
        )
        intermediate_aperture_centers: Float[Array, " 2 K"] = jnp.zeros(
            (2, 0), dtype=jnp.float64
        )
        intermediate_travel_distances: Float[Array, " K"] = jnp.zeros(
            (0,), dtype=jnp.float64
        )
    else:
        # Snapshots arrive stacked on axis 0; the reconstruction stores
        # them along the last axis.
        intermediate_samples = jnp.transpose(snapshot_samples, (1, 2, 0))
        intermediate_lightwaves = jnp.transpose(snapshot_probes, (1, 2, 0))
        # Geometry is not optimized here, so it is repeated per snapshot.
        intermediate_zoom_factors = jnp.full(
            num_snapshots, zoom_factor, dtype=jnp.float64
        )
        intermediate_aperture_diameters = jnp.full(
            num_snapshots, aperture_diameter, dtype=jnp.float64
        )
        intermediate_aperture_centers = jnp.broadcast_to(
            aperture_center[:, None], (2, num_snapshots)
        )
        intermediate_travel_distances = jnp.full(
            num_snapshots, travel_distance, dtype=jnp.float64
        )

    prev_losses: Float[Array, " M 2"] = reconstruction.losses
    has_previous: bool = prev_losses.shape[0] > 0

    losses: Float[Array, " L 2"]
    if num_iterations > 0:
        # Python-level branch on purpose: jnp.where would evaluate
        # prev_losses[-1, 0] even when prev_losses is empty, and integer
        # indexing into a size-0 axis raises IndexError in JAX.
        start_iteration = prev_losses[-1, 0] + 1.0 if has_previous else 0.0
        iteration_numbers: Float[Array, " N"] = start_iteration + jnp.arange(
            num_iterations, dtype=jnp.float64
        )
        losses = jnp.stack([iteration_numbers, all_losses], axis=1)
    else:
        losses = jnp.zeros((0, 2), dtype=jnp.float64)

    if has_previous:

        def _append(previous: Array, new: Array) -> Array:
            """Append this call's history after the previous history."""
            return jnp.concatenate([previous, new], axis=-1)

        intermediate_samples = _append(
            reconstruction.intermediate_samples, intermediate_samples
        )
        intermediate_lightwaves = _append(
            reconstruction.intermediate_lightwaves, intermediate_lightwaves
        )
        intermediate_zoom_factors = _append(
            reconstruction.intermediate_zoom_factors,
            intermediate_zoom_factors,
        )
        intermediate_aperture_diameters = _append(
            reconstruction.intermediate_aperture_diameters,
            intermediate_aperture_diameters,
        )
        intermediate_aperture_centers = _append(
            reconstruction.intermediate_aperture_centers,
            intermediate_aperture_centers,
        )
        intermediate_travel_distances = _append(
            reconstruction.intermediate_travel_distances,
            intermediate_travel_distances,
        )
        losses = jnp.concatenate([prev_losses, losses], axis=0)

    return make_ptychography_reconstruction(
        sample=final_sample,
        lightwave=final_lightwave,
        translated_positions=translated_positions,
        zoom_factor=zoom_factor,
        aperture_diameter=aperture_diameter,
        aperture_center=aperture_center,
        travel_distance=travel_distance,
        intermediate_samples=intermediate_samples,
        intermediate_lightwaves=intermediate_lightwaves,
        intermediate_zoom_factors=intermediate_zoom_factors,
        intermediate_aperture_diameters=intermediate_aperture_diameters,
        intermediate_aperture_centers=intermediate_aperture_centers,
        intermediate_travel_distances=intermediate_travel_distances,
        losses=losses,
    )