"""Ptychography algorithms and optimization.
Extended Summary
----------------
High-level ptychography reconstruction algorithms that combine optimization
strategies with forward models. Provides complete reconstruction pipelines
for recovering complex-valued sample functions from intensity measurements.
Routine Listings
----------------
get_optimizer : function
Returns an optimizer object based on the specified name
simple_microscope_ptychography : function
Performs ptychography reconstruction using a simple microscope model
Notes
-----
These functions provide complete reconstruction pipelines that can be
directly applied to experimental data. All functions support JAX transformations
and automatic differentiation for gradient-based optimization.
"""
import jax
import jax.numpy as jnp
from beartype import beartype
from beartype.typing import Dict, Optional, Tuple
from jaxtyping import Array, Complex, Float, jaxtyped
from janssen.simul import simple_microscope
from janssen.utils import (
MicroscopeData,
OpticalWavefront,
PtychographyParams,
SampleFunction,
make_optical_wavefront,
make_sample_function,
scalar_float,
scalar_integer,
)
from .loss_functions import create_loss_function
from .optimizers import (
Optimizer,
adagrad_update,
adam_update,
init_adagrad,
init_adam,
init_rmsprop,
rmsprop_update,
)
jax.config.update("jax_enable_x64", True)
OPTIMIZERS: Dict[str, Optimizer] = {
"adam": Optimizer(init_adam, adam_update),
"adagrad": Optimizer(init_adagrad, adagrad_update),
"rmsprop": Optimizer(init_rmsprop, rmsprop_update),
}
def get_optimizer(optimizer_name: str) -> Optimizer:
"""Get the optimizer function based on the optimizer name.
Parameters
----------
optimizer_name : str
The name of the optimizer to get.
Returns
-------
Optimizer
The optimizer function.
"""
if optimizer_name not in OPTIMIZERS:
raise ValueError(f"Unknown optimizer: {optimizer_name}")
return OPTIMIZERS[optimizer_name]
[docs]
@jaxtyped(typechecker=beartype)
def simple_microscope_ptychography(
experimental_data: MicroscopeData,
guess_sample: SampleFunction,
guess_lightwave: OpticalWavefront,
params: PtychographyParams,
save_every: Optional[scalar_integer] = 10,
loss_type: Optional[str] = "mse",
optimizer_name: Optional[str] = "adam",
zoom_factor_bounds: Optional[Tuple[scalar_float, scalar_float]] = None,
aperture_diameter_bounds: Optional[
Tuple[scalar_float, scalar_float]
] = None,
travel_distance_bounds: Optional[Tuple[scalar_float, scalar_float]] = None,
aperture_center_bounds: Optional[
Tuple[Float[Array, " 2"], Float[Array, " 2"]]
] = None,
) -> Tuple[
Tuple[
SampleFunction, # final_sample
OpticalWavefront, # final_lightwave
scalar_float, # final_zoom_factor
scalar_float, # final_aperture_diameter
Optional[Float[Array, " 2"]], # final_aperture_center
scalar_float, # final_travel_distance
],
Tuple[
Complex[Array, " H W S"], # intermediate_samples
Complex[Array, " H W S"], # intermediate_lightwaves
Float[Array, " S"], # intermediate_zoom_factors
Float[Array, " S"], # intermediate_aperture_diameters
Float[Array, " 2 S"], # intermediate_aperture_centers
Float[Array, " S"], # intermediate_travel_distances
],
]:
"""Solve the optical ptychography inverse problem.
Here experimental diffraction patterns are used to reconstruct a sample,
lightwave, and optical system parameters.
Parameters
----------
experimental_data : MicroscopeData
The experimental diffraction patterns collected at different positions.
guess_sample : SampleFunction
Initial guess for the sample properties.
guess_lightwave : OpticalWavefront
Initial guess for the lightwave.
params : PtychographyParams
Ptychography parameters including:
- zoom_factor: Optical zoom factor for magnification
- aperture_diameter: Diameter of the aperture in meters
- travel_distance: Light propagation distance in meters
- aperture_center: Center position of the aperture (x, y)
- camera_pixel_size: Camera pixel size in meters
- learning_rate: Learning rate for optimization
- num_iterations: Number of optimization iterations
save_every : scalar_integer, optional
Save intermediate results every n iterations. Default is 10.
loss_type : str, optional
Type of loss function to use. Default is "mse".
optimizer_name : str, optional
Name of the optimizer to use. Default is "adam".
zoom_factor_bounds : Tuple[scalar_float, scalar_float], optional
Lower and upper bounds for zoom factor optimization.
aperture_diameter_bounds : Tuple[scalar_float, scalar_float], optional
Lower and upper bounds for aperture diameter optimization.
travel_distance_bounds : Tuple[scalar_float, scalar_float], optional
Lower and upper bounds for travel distance optimization.
aperture_center_bounds : Tuple[Float[Array, " 2"], Float[Array, " 2"]], optional
Lower and upper bounds for aperture center optimization.
Returns
-------
Tuple[Tuple[...], Tuple[...]]
Tuple containing:
- Final results tuple:
- final_sample : SampleFunction
Optimized sample properties.
- final_lightwave : OpticalWavefront
Optimized lightwave.
- final_zoom_factor : scalar_float
Optimized zoom factor.
- final_aperture_diameter : scalar_float
Optimized aperture diameter.
- final_aperture_center : Float[Array, " 2"] or None
Optimized aperture center.
- final_travel_distance : scalar_float
Optimized travel distance.
- Intermediate results tuple:
- intermediate_samples : Complex[Array, " H W S"]
Intermediate samples during optimization.
- intermediate_lightwaves : Complex[Array, " H W S"]
Intermediate lightwaves during optimization.
- intermediate_zoom_factors : Float[Array, " S"]
Intermediate zoom factors during optimization.
- intermediate_aperture_diameters : Float[Array, " S"]
Intermediate aperture diameters during optimization.
- intermediate_aperture_centers : Float[Array, " 2 S"]
Intermediate aperture centers during optimization.
- intermediate_travel_distances : Float[Array, " S"]
Intermediate travel distances during optimization.
"""
# Extract parameters from PtychographyParams
zoom_factor = params.zoom_factor
aperture_diameter = params.aperture_diameter
travel_distance = params.travel_distance
aperture_center = params.aperture_center
camera_pixel_size = params.camera_pixel_size
learning_rate = params.learning_rate
num_iterations = params.num_iterations
# Define bound enforcement functions
def enforce_bounds(param, param_bounds):
if param_bounds is None:
return param
lower, upper = param_bounds
return jnp.clip(param, lower, upper)
def enforce_bounds_2d(param, param_bounds):
if param_bounds is None:
return param
lower, upper = param_bounds
return jnp.clip(param, lower, upper)
# Define the forward model function for the loss calculation
def forward_fn(
sample_field,
lightwave_field,
zoom_factor,
aperture_diameter,
travel_distance,
aperture_center,
):
# Reconstruct PyTree objects from arrays
sample = make_sample_function(sample=sample_field, dx=guess_sample.dx)
lightwave = make_optical_wavefront(
field=lightwave_field,
wavelength=guess_lightwave.wavelength,
dx=guess_lightwave.dx,
z_position=guess_lightwave.z_position,
)
# Generate the microscope data using the forward model
simulated_data = simple_microscope(
sample=sample,
positions=experimental_data.positions,
lightwave=lightwave,
zoom_factor=zoom_factor,
aperture_diameter=aperture_diameter,
travel_distance=travel_distance,
camera_pixel_size=camera_pixel_size,
aperture_center=aperture_center,
)
return simulated_data.image_data
# Create loss function using the tools module
loss_func = create_loss_function(
forward_fn, experimental_data.image_data, loss_type
)
# Define function to compute loss and gradients
@jax.jit
def loss_and_grad(
sample_field,
lightwave_field,
zoom_factor,
aperture_diameter,
travel_distance,
aperture_center,
):
def loss_wrapped(
sample_field,
lightwave_field,
zoom_factor,
aperture_diameter,
travel_distance,
aperture_center,
):
# Enforce bounds before calculating loss
bounded_zoom_factor = enforce_bounds(
zoom_factor, zoom_factor_bounds
)
bounded_aperture_diameter = enforce_bounds(
aperture_diameter, aperture_diameter_bounds
)
bounded_travel_distance = enforce_bounds(
travel_distance, travel_distance_bounds
)
bounded_aperture_center = enforce_bounds_2d(
aperture_center, aperture_center_bounds
)
return loss_func(
sample_field,
lightwave_field,
bounded_zoom_factor,
bounded_aperture_diameter,
bounded_travel_distance,
bounded_aperture_center,
)
loss, grads = jax.value_and_grad(
loss_wrapped, argnums=(0, 1, 2, 3, 4, 5)
)(
sample_field,
lightwave_field,
zoom_factor,
aperture_diameter,
travel_distance,
aperture_center,
)
return loss, {
"sample": grads[0],
"lightwave": grads[1],
"zoom_factor": grads[2],
"aperture_diameter": grads[3],
"travel_distance": grads[4],
"aperture_center": grads[5],
}
# Get the selected optimizer
optimizer = get_optimizer(optimizer_name)
# Initialize optimizer states
sample_state = optimizer.init(guess_sample.sample.shape)
lightwave_state = optimizer.init(guess_lightwave.field.shape)
zoom_factor_state = optimizer.init(()) # Scalar param
aperture_diameter_state = optimizer.init(()) # Scalar param
travel_distance_state = optimizer.init(()) # Scalar param
aperture_center_state = optimizer.init(
(2,) if aperture_center is not None else ()
)
# Initialize parameters
sample_field = guess_sample.sample
lightwave_field = guess_lightwave.field
current_zoom_factor = zoom_factor
current_aperture_diameter = aperture_diameter
current_travel_distance = travel_distance
current_aperture_center = (
jnp.zeros(2) if aperture_center is None else aperture_center
)
# Set up intermediate result storage
num_saves = jnp.floor(num_iterations / save_every).astype(int)
intermediate_samples = jnp.zeros(
(sample_field.shape[0], sample_field.shape[1], num_saves),
dtype=sample_field.dtype,
)
intermediate_lightwaves = jnp.zeros(
(lightwave_field.shape[0], lightwave_field.shape[1], num_saves),
dtype=lightwave_field.dtype,
)
intermediate_zoom_factors = jnp.zeros(num_saves, dtype=jnp.float64)
intermediate_aperture_diameters = jnp.zeros(num_saves, dtype=jnp.float64)
intermediate_travel_distances = jnp.zeros(num_saves, dtype=jnp.float64)
intermediate_aperture_centers = jnp.zeros(
(2, num_saves), dtype=jnp.float64
)
@jax.jit
def update_step(
sample_field,
lightwave_field,
zoom_factor,
aperture_diameter,
travel_distance,
aperture_center,
sample_state,
lightwave_state,
zoom_factor_state,
aperture_diameter_state,
travel_distance_state,
aperture_center_state,
):
loss, grads = loss_and_grad(
sample_field,
lightwave_field,
zoom_factor,
aperture_diameter,
travel_distance,
aperture_center,
)
# Update sample
sample_field, sample_state = optimizer.update(
sample_field, grads["sample"], sample_state, learning_rate
)
# Update lightwave
lightwave_field, lightwave_state = optimizer.update(
lightwave_field, grads["lightwave"], lightwave_state, learning_rate
)
# Update zoom factor
zoom_factor, zoom_factor_state = optimizer.update(
zoom_factor, grads["zoom_factor"], zoom_factor_state, learning_rate
)
zoom_factor = enforce_bounds(zoom_factor, zoom_factor_bounds)
# Update aperture diameter
aperture_diameter, aperture_diameter_state = optimizer.update(
aperture_diameter,
grads["aperture_diameter"],
aperture_diameter_state,
learning_rate,
)
aperture_diameter = enforce_bounds(
aperture_diameter, aperture_diameter_bounds
)
# Update travel distance
travel_distance, travel_distance_state = optimizer.update(
travel_distance,
grads["travel_distance"],
travel_distance_state,
learning_rate,
)
travel_distance = enforce_bounds(
travel_distance, travel_distance_bounds
)
# Update aperture center
aperture_center, aperture_center_state = optimizer.update(
aperture_center,
grads["aperture_center"],
aperture_center_state,
learning_rate,
)
aperture_center = enforce_bounds_2d(
aperture_center, aperture_center_bounds
)
return (
sample_field,
lightwave_field,
zoom_factor,
aperture_diameter,
travel_distance,
aperture_center,
sample_state,
lightwave_state,
zoom_factor_state,
aperture_diameter_state,
travel_distance_state,
aperture_center_state,
loss,
)
# Run optimization loop
for ii in range(num_iterations):
(
sample_field,
lightwave_field,
current_zoom_factor,
current_aperture_diameter,
current_travel_distance,
current_aperture_center,
sample_state,
lightwave_state,
zoom_factor_state,
aperture_diameter_state,
travel_distance_state,
aperture_center_state,
loss,
) = update_step(
sample_field,
lightwave_field,
current_zoom_factor,
current_aperture_diameter,
current_travel_distance,
current_aperture_center,
sample_state,
lightwave_state,
zoom_factor_state,
aperture_diameter_state,
travel_distance_state,
aperture_center_state,
)
# Save intermediate results
if ii % save_every == 0:
print(f"Iteration {ii}, Loss: {loss}")
save_idx = ii // save_every
if save_idx < num_saves:
intermediate_samples = intermediate_samples.at[
:, :, save_idx
].set(sample_field)
intermediate_lightwaves = intermediate_lightwaves.at[
:, :, save_idx
].set(lightwave_field)
intermediate_zoom_factors = intermediate_zoom_factors.at[
save_idx
].set(current_zoom_factor)
intermediate_aperture_diameters = (
intermediate_aperture_diameters.at[save_idx].set(
current_aperture_diameter
)
)
intermediate_travel_distances = (
intermediate_travel_distances.at[save_idx].set(
current_travel_distance
)
)
intermediate_aperture_centers = (
intermediate_aperture_centers.at[:, save_idx].set(
current_aperture_center
)
)
# Create final objects
final_sample = make_sample_function(
sample=sample_field, dx=guess_sample.dx
)
final_lightwave = make_optical_wavefront(
field=lightwave_field,
wavelength=guess_lightwave.wavelength,
dx=guess_lightwave.dx,
z_position=guess_lightwave.z_position,
)
# Create final values tuple
final_values = (
final_sample,
final_lightwave,
current_zoom_factor,
current_aperture_diameter,
current_aperture_center,
current_travel_distance,
)
# Create intermediate values tuple
intermediate_values = (
intermediate_samples,
intermediate_lightwaves,
intermediate_zoom_factors,
intermediate_aperture_diameters,
intermediate_aperture_centers,
intermediate_travel_distances,
)
# Return both tuples as a single tuple of tuples
return (final_values, intermediate_values)