"""Factory functions for creating data structures.

Extended Summary
----------------
Factory functions for creating data structures with runtime type checking.
All runtime validations use JAX safe conditional statements.

Routine Listings
----------------
make_lens_params : function
    Creates a LensParams instance with runtime type checking
make_grid_params : function
    Creates a GridParams instance with runtime type checking
make_optical_wavefront : function
    Creates an OpticalWavefront instance with runtime type checking
make_microscope_data : function
    Creates a MicroscopeData instance with runtime type checking
make_diffractogram : function
    Creates a Diffractogram instance with runtime type checking
make_sample_function : function
    Creates a SampleFunction instance with runtime type checking
make_optimizer_state : function
    Creates an OptimizerState instance with runtime type checking
make_ptychography_params : function
    Creates a PtychographyParams instance with runtime type checking

Notes
-----
Always use these factory functions instead of directly instantiating the
NamedTuple classes to ensure proper runtime type checking of the contents.
"""
import jax
import jax.numpy as jnp
from beartype import beartype
from beartype.typing import Optional, Tuple, Union
from jax import lax
from jaxtyping import Array, Bool, Complex, Float, Int, Num, jaxtyped
from .types import (
Diffractogram,
GridParams,
LensParams,
MicroscopeData,
OpticalWavefront,
OptimizerState,
PtychographyParams,
SampleFunction,
scalar_bool,
scalar_complex,
scalar_float,
scalar_integer,
)
# Enable 64-bit mode at import time: the factories below request
# float64 / complex128 / int64 dtypes, which JAX otherwise truncates to
# their 32-bit counterparts.
jax.config.update("jax_enable_x64", True)
[docs]
@jaxtyped(typechecker=beartype)
def make_lens_params(
    focal_length: scalar_float,
    diameter: scalar_float,
    n: scalar_float,
    center_thickness: scalar_float,
    r1: scalar_float,
    r2: scalar_float,
) -> LensParams:
    """Assemble a ``LensParams`` from scalars cast to float64 JAX arrays.

    Parameters
    ----------
    focal_length : scalar_float
        Focal length of the lens in meters
    diameter : scalar_float
        Diameter of the lens in meters
    n : scalar_float
        Refractive index of the lens material
    center_thickness : scalar_float
        Thickness at the center of the lens in meters
    r1 : scalar_float
        Radius of curvature of the first surface in meters
        (positive for convex)
    r2 : scalar_float
        Radius of curvature of the second surface in meters
        (positive for convex)

    Returns
    -------
    validated_lens_params : LensParams
        Lens parameters holding the float64 versions of the six inputs

    Notes
    -----
    Argument types are checked by the ``jaxtyped``/``beartype`` decorator.
    Each range check is traced through ``lax.cond``, but both branches
    return the value unchanged and the result is discarded, so an
    out-of-range value does not stop construction.

    Algorithm:

    - Convert inputs to float64 JAX arrays
    - Trace the pass-through checks, in order:
        - focal_length is positive
        - diameter is positive
        - refractive index is positive
        - center_thickness is positive
        - both radii are finite
    - Create and return LensParams instance
    """
    (focal_length, diameter, n, center_thickness, r1, r2) = (
        jnp.asarray(quantity, dtype=jnp.float64)
        for quantity in (focal_length, diameter, n, center_thickness, r1, r2)
    )

    def _passthrough(predicate, payload):
        """Trace ``predicate`` with ``lax.cond``; yield ``payload`` as is."""
        return lax.cond(
            predicate,
            lambda: payload,
            lambda: lax.stop_gradient(
                lax.cond(False, lambda: payload, lambda: payload)
            ),
        )

    # Positivity checks share one shape, so run them in a loop.
    for quantity in (focal_length, diameter, n, center_thickness):
        _passthrough(quantity > 0, quantity)

    radii_are_finite = jnp.logical_and(jnp.isfinite(r1), jnp.isfinite(r2))
    _passthrough(radii_are_finite, (r1, r2))

    validated_lens_params: LensParams = LensParams(
        focal_length=focal_length,
        diameter=diameter,
        n=n,
        center_thickness=center_thickness,
        r1=r1,
        r2=r2,
    )
    return validated_lens_params
[docs]
@jaxtyped(typechecker=beartype)
def make_grid_params(
    xx: Float[Array, " hh ww"],
    yy: Float[Array, " hh ww"],
    phase_profile: Float[Array, " hh ww"],
    transmission: Float[Array, " hh ww"],
) -> GridParams:
    """Assemble a ``GridParams`` from four arrays cast to float64.

    Parameters
    ----------
    xx : Float[Array, " hh ww"]
        Spatial grid in the x-direction
    yy : Float[Array, " hh ww"]
        Spatial grid in the y-direction
    phase_profile : Float[Array, " hh ww"]
        Phase profile of the optical field
    transmission : Float[Array, " hh ww"]
        Transmission profile of the optical field

    Returns
    -------
    validated_grid_params : GridParams
        Grid parameters holding the float64 versions of the inputs

    Notes
    -----
    Argument dtypes and the shared ``hh ww`` shape are checked by the
    ``jaxtyped``/``beartype`` decorator. The checks in the body are traced
    through ``lax.cond``, but both branches return the data unchanged and
    the result is discarded, so they do not reject any input.

    Algorithm:

    - Convert inputs to float64 JAX arrays
    - Trace the pass-through checks, in order:
        - all arrays are 2D
        - all arrays have the same shape as ``xx``
        - transmission values lie between 0 and 1
        - phase values are finite
        - grid coordinates are finite
    - Create and return GridParams instance
    """
    xx, yy, phase_profile, transmission = (
        jnp.asarray(grid, dtype=jnp.float64)
        for grid in (xx, yy, phase_profile, transmission)
    )
    members = (xx, yy, phase_profile, transmission)
    expected_ndim: int = 2
    # Tuple unpacking itself fails unless ``xx`` has exactly two axes.
    hh, ww = xx.shape
    reference_shape = (hh, ww)

    def _passthrough(predicate, payload):
        """Trace ``predicate`` with ``lax.cond``; yield ``payload`` as is."""
        return lax.cond(
            predicate,
            lambda: payload,
            lambda: lax.stop_gradient(
                lax.cond(False, lambda: payload, lambda: payload)
            ),
        )

    # ndim and shape are static, so these predicates are plain booleans.
    _passthrough(
        all(member.ndim == expected_ndim for member in members), members
    )
    _passthrough(
        all(member.shape == reference_shape for member in members), members
    )

    within_unit_interval = jnp.logical_and(
        jnp.all(transmission >= 0), jnp.all(transmission <= 1)
    )
    _passthrough(within_unit_interval, transmission)

    _passthrough(jnp.all(jnp.isfinite(phase_profile)), phase_profile)

    coordinates_finite = jnp.logical_and(
        jnp.all(jnp.isfinite(xx)), jnp.all(jnp.isfinite(yy))
    )
    _passthrough(coordinates_finite, (xx, yy))

    validated_grid_params: GridParams = GridParams(
        xx=xx,
        yy=yy,
        phase_profile=phase_profile,
        transmission=transmission,
    )
    return validated_grid_params
[docs]
@jaxtyped(typechecker=beartype)
def make_optical_wavefront(
    field: Union[Complex[Array, " hh ww"], Complex[Array, " hh ww 2"]],
    wavelength: scalar_float,
    dx: scalar_float,
    z_position: scalar_float,
    polarization: Optional[scalar_bool] = False,
) -> OpticalWavefront:
    """JAX-safe factory function for OpticalWavefront.

    Parameters
    ----------
    field : Union[Complex[Array, " hh ww"], Complex[Array, " hh ww 2"]]
        Complex amplitude of the optical field. Should be 2D for scalar fields
        or 3D with last dimension 2 for polarized fields.
    wavelength : scalar_float
        Wavelength of the optical wavefront in meters
    dx : scalar_float
        Spatial sampling interval (grid spacing) in meters
    z_position : scalar_float
        Axial position of the wavefront in the propagation direction in meters.
    polarization : scalar_bool, optional
        Whether the field is polarized (True for 3D field, False for 2D field).
        Default is False. ``None`` is treated the same as False.

    Returns
    -------
    validated_optical_wavefront : OpticalWavefront
        Wavefront holding the complex128 field, the float64 scalars and
        the boolean polarization flag

    Notes
    -----
    Argument types are checked by the ``jaxtyped``/``beartype`` decorator.
    The checks in the body are traced through ``lax.cond``, but both
    branches return the data unchanged and the result is discarded, so
    they do not reject any input.

    Algorithm:

    - Convert inputs to JAX arrays (``None`` polarization becomes False)
    - Trace the pass-through checks, in order:
        - field is 2D (scalar) or 3D with trailing size 2 (polarized),
          according to the polarization flag
        - field values are finite
        - wavelength is positive
        - dx is positive
        - z_position is finite
    - Create and return OpticalWavefront instance
    """
    field = jnp.asarray(field, dtype=jnp.complex128)
    wavelength = jnp.asarray(wavelength, dtype=jnp.float64)
    dx = jnp.asarray(dx, dtype=jnp.float64)
    z_position = jnp.asarray(z_position, dtype=jnp.float64)
    # ``Optional`` lets None through the type check, but jnp.asarray does
    # not accept None; map it to the documented default instead.
    if polarization is None:
        polarization = False
    polarization = jnp.asarray(polarization, dtype=jnp.bool_)

    scalar_ndim: int = 2
    polarized_ndim: int = 3
    polarization_components: int = 2

    def _passthrough(predicate, payload):
        """Trace ``predicate`` with ``lax.cond``; yield ``payload`` as is."""
        return lax.cond(
            predicate,
            lambda: payload,
            lambda: lax.stop_gradient(
                lax.cond(False, lambda: payload, lambda: payload)
            ),
        )

    # The layout facts are static; only the flag may be a traced value,
    # so select the applicable layout test with jnp.where.
    fits_polarized_layout = jnp.logical_and(
        field.ndim == polarized_ndim,
        field.shape[-1] == polarization_components,
    )
    fits_scalar_layout = field.ndim == scalar_ndim
    _passthrough(
        jnp.where(polarization, fits_polarized_layout, fits_scalar_layout),
        field,
    )
    _passthrough(jnp.all(jnp.isfinite(field)), field)
    _passthrough(wavelength > 0, wavelength)
    _passthrough(dx > 0, dx)
    _passthrough(jnp.isfinite(z_position), z_position)

    validated_optical_wavefront: OpticalWavefront = OpticalWavefront(
        field=field,
        wavelength=wavelength,
        dx=dx,
        z_position=z_position,
        polarization=polarization,
    )
    return validated_optical_wavefront
[docs]
@jaxtyped(typechecker=beartype)
def make_microscope_data(
    image_data: Union[Float[Array, " pp hh ww"], Float[Array, " xx yy hh ww"]],
    positions: Num[Array, " pp 2"],
    wavelength: scalar_float,
    dx: scalar_float,
) -> MicroscopeData:
    """JAX-safe factory function for MicroscopeData.

    Parameters
    ----------
    image_data : Union[Float[Array, " pp hh ww"], Float[Array, " xx yy hh ww"]]
        3D or 4D image data representing the optical field
    positions : Num[Array, " pp 2"]
        Positions of the images during collection
    wavelength : scalar_float
        Wavelength of the optical wavefront in meters
    dx : scalar_float
        Spatial sampling interval (grid spacing) in meters

    Returns
    -------
    validated_microscope_data : MicroscopeData
        Microscope data holding the float64 versions of the inputs

    Notes
    -----
    Argument types are checked by the ``jaxtyped``/``beartype`` decorator.
    The checks in the body are traced through ``lax.cond``, but both
    branches return the data unchanged and the result is discarded, so
    they do not reject any input.

    Algorithm:

    - Convert inputs to float64 JAX arrays
    - Trace the pass-through checks, in order:
        - image_data is 3D or 4D
        - image_data is finite
        - image_data is non-negative
        - positions has 2 columns
        - positions is finite
        - wavelength is positive
        - dx is positive
        - frame count matches the number of positions: ``shape[0]`` for
          3D data, ``shape[0] * shape[1]`` for 4D data
    - Create and return MicroscopeData instance
    """
    image_data = jnp.asarray(image_data, dtype=jnp.float64)
    positions = jnp.asarray(positions, dtype=jnp.float64)
    wavelength = jnp.asarray(wavelength, dtype=jnp.float64)
    dx = jnp.asarray(dx, dtype=jnp.float64)

    position_columns: int = 2
    stack_ndim: int = 3
    raster_ndim: int = 4

    def _passthrough(predicate, payload):
        """Trace ``predicate`` with ``lax.cond``; yield ``payload`` as is."""
        return lax.cond(
            predicate,
            lambda: payload,
            lambda: lax.stop_gradient(
                lax.cond(False, lambda: payload, lambda: payload)
            ),
        )

    _passthrough(image_data.ndim in (stack_ndim, raster_ndim), image_data)
    _passthrough(jnp.all(jnp.isfinite(image_data)), image_data)
    _passthrough(jnp.all(image_data >= 0), image_data)
    _passthrough(positions.shape[1] == position_columns, positions)
    _passthrough(jnp.all(jnp.isfinite(positions)), positions)
    _passthrough(wavelength > 0, wavelength)
    _passthrough(dx > 0, dx)

    # ndim is static, so pick the frame-count rule in Python. The rule
    # must key on the 3-D stack rank, not on the positions column count.
    num_positions = positions.shape[0]
    if image_data.ndim == stack_ndim:
        num_frames = image_data.shape[0]
    else:
        num_frames = image_data.shape[0] * image_data.shape[1]
    _passthrough(num_frames == num_positions, (image_data, positions))

    validated_microscope_data: MicroscopeData = MicroscopeData(
        image_data=image_data,
        positions=positions,
        wavelength=wavelength,
        dx=dx,
    )
    return validated_microscope_data
[docs]
@jaxtyped(typechecker=beartype)
def make_diffractogram(
    image: Float[Array, " hh ww"],
    wavelength: scalar_float,
    dx: scalar_float,
) -> Diffractogram:
    """Assemble a ``Diffractogram`` from inputs cast to float64.

    Parameters
    ----------
    image : Float[Array, " hh ww"]
        Image data
    wavelength : scalar_float
        Wavelength of the optical wavefront in meters
    dx : scalar_float
        Spatial sampling interval (grid spacing) in meters

    Returns
    -------
    validated_diffractogram : Diffractogram
        Diffractogram holding the float64 versions of the inputs

    Notes
    -----
    Argument types are checked by the ``jaxtyped``/``beartype`` decorator.
    The checks in the body are traced through ``lax.cond``, but both
    branches return the data unchanged and the result is discarded, so
    they do not reject any input.

    Algorithm:

    - Convert inputs to float64 JAX arrays
    - Trace the pass-through checks, in order:
        - image is 2D
        - image is finite
        - image is non-negative
        - wavelength is positive
        - dx is positive
    - Create and return Diffractogram instance
    """
    image = jnp.asarray(image, dtype=jnp.float64)
    wavelength = jnp.asarray(wavelength, dtype=jnp.float64)
    dx = jnp.asarray(dx, dtype=jnp.float64)
    expected_image_ndim: int = 2

    def _passthrough(predicate, payload):
        """Trace ``predicate`` with ``lax.cond``; yield ``payload`` as is."""
        return lax.cond(
            predicate,
            lambda: payload,
            lambda: lax.stop_gradient(
                lax.cond(False, lambda: payload, lambda: payload)
            ),
        )

    # (predicate, value) pairs in the order they are traced.
    pending_checks = (
        (image.ndim == expected_image_ndim, image),
        (jnp.all(jnp.isfinite(image)), image),
        (jnp.all(image >= 0), image),
        (wavelength > 0, wavelength),
        (dx > 0, dx),
    )
    for predicate, payload in pending_checks:
        _passthrough(predicate, payload)

    validated_diffractogram: Diffractogram = Diffractogram(
        image=image,
        wavelength=wavelength,
        dx=dx,
    )
    return validated_diffractogram
[docs]
@jaxtyped(typechecker=beartype)
def make_sample_function(
    sample: Complex[Array, " hh ww"],
    dx: scalar_float,
) -> SampleFunction:
    """Assemble a ``SampleFunction`` from a complex array and a spacing.

    Parameters
    ----------
    sample : Complex[Array, " hh ww"]
        The sample function
    dx : scalar_float
        Spatial sampling interval (grid spacing) in meters

    Returns
    -------
    validated_sample_function : SampleFunction
        Sample function holding the complex128 sample and float64 ``dx``

    Notes
    -----
    Argument types are checked by the ``jaxtyped``/``beartype`` decorator.
    The checks in the body are traced through ``lax.cond``, but both
    branches return the data unchanged and the result is discarded, so
    they do not reject any input.

    Algorithm:

    - Convert ``sample`` to complex128 and ``dx`` to float64
    - Trace the pass-through checks, in order:
        - sample is 2D
        - sample is finite
        - dx is positive
    - Create and return SampleFunction instance
    """
    sample = jnp.asarray(sample, dtype=jnp.complex128)
    dx = jnp.asarray(dx, dtype=jnp.float64)
    required_ndim: int = 2

    def _passthrough(predicate, payload):
        """Trace ``predicate`` with ``lax.cond``; yield ``payload`` as is."""
        return lax.cond(
            predicate,
            lambda: payload,
            lambda: lax.stop_gradient(
                lax.cond(False, lambda: payload, lambda: payload)
            ),
        )

    _passthrough(sample.ndim == required_ndim, sample)
    _passthrough(jnp.all(jnp.isfinite(sample)), sample)
    _passthrough(dx > 0, dx)

    validated_sample_function: SampleFunction = SampleFunction(
        sample=sample,
        dx=dx,
    )
    return validated_sample_function
[docs]
@jaxtyped(typechecker=beartype)
def make_optimizer_state(
    shape: Tuple,
    m: Optional[Union[Complex[Array, " ..."], scalar_complex]] = 1j,
    v: Optional[Union[Float[Array, " ..."], scalar_float]] = 0.0,
    step: Optional[scalar_integer] = 0,
) -> OptimizerState:
    """JAX-safe factory function for OptimizerState with shape validation.

    Parameters
    ----------
    shape : Tuple
        Shape of the parameters to be optimized
    m : Optional[Complex[Array, "..."]], optional
        First moment estimate. ``None`` or the scalar sentinel ``1j``
        (the default) gives complex zeros of ``shape``; any other scalar
        is broadcast to ``shape``; an array is used as given.
    v : Optional[Float[Array, "..."]], optional
        Second moment estimate. ``None`` or the scalar sentinel ``0.0``
        (the default) gives real zeros of ``shape``; any other scalar is
        broadcast to ``shape``; an array is used as given.
    step : Optional[scalar_integer], optional
        Step count. Default is 0; ``None`` is treated as 0.

    Returns
    -------
    validated_optimizer_state : OptimizerState
        State with complex128 ``m``, float64 ``v`` (both of ``shape``)
        and a scalar int32 ``step``

    Raises
    ------
    ValueError
        If ``m`` or ``v`` is an array whose shape differs from ``shape``,
        or if ``step`` is not a scalar

    Notes
    -----
    Array rank and shape are static in JAX, so the scalar-versus-array
    decision and the shape validation are done with ordinary Python
    control flow. ``lax.cond`` cannot express them because it requires
    both branches to produce outputs of the same shape. Only the
    sentinel comparison depends on a value, and it is done with
    ``jnp.where`` so it also works on traced scalars.

    Algorithm:

    - Replace ``None`` inputs by zeros (``m``, ``v``) or 0 (``step``)
    - Convert inputs to JAX arrays with the appropriate dtypes
    - Expand scalar moments to ``shape`` (sentinel value gives zeros)
    - Validate that both moments have shape ``shape`` and step is scalar
    - Create and return OptimizerState instance
    """
    sentinel_m = 1j
    sentinel_v = 0.0

    def _expand_moment(value, sentinel, dtype):
        """Return ``value`` as a ``dtype`` array, expanding scalars."""
        if value is None:
            return jnp.zeros(shape, dtype=dtype)
        moment = jnp.asarray(value, dtype=dtype)
        if moment.ndim != 0:
            # Arrays are taken verbatim; the sentinel only has meaning
            # for scalar inputs.
            return moment
        return jnp.where(
            moment == sentinel,
            jnp.zeros(shape, dtype=dtype),
            jnp.broadcast_to(moment, shape),
        )

    m_array = _expand_moment(m, sentinel_m, jnp.complex128)
    v_array = _expand_moment(v, sentinel_v, jnp.float64)
    step_array = jnp.asarray(0 if step is None else step, dtype=jnp.int32)

    if m_array.shape != shape:
        raise ValueError(
            f"m has shape {m_array.shape}, expected shape {shape}"
        )
    if v_array.shape != shape:
        raise ValueError(
            f"v has shape {v_array.shape}, expected shape {shape}"
        )
    if step_array.ndim != 0:
        raise ValueError(
            f"step must be a scalar, got shape {step_array.shape}"
        )

    validated_optimizer_state: OptimizerState = OptimizerState(
        m=m_array,
        v=v_array,
        step=step_array,
    )
    return validated_optimizer_state
[docs]
@jaxtyped(typechecker=beartype)
def make_ptychography_params(
    zoom_factor: scalar_float,
    aperture_diameter: scalar_float,
    travel_distance: scalar_float,
    aperture_center: Float[Array, " 2"],
    camera_pixel_size: scalar_float,
    learning_rate: scalar_float,
    num_iterations: scalar_integer,
) -> PtychographyParams:
    """Assemble a ``PtychographyParams`` PyTree from converted inputs.

    Parameters
    ----------
    zoom_factor : scalar_float
        Optical zoom factor for magnification
    aperture_diameter : scalar_float
        Diameter of the aperture in meters
    travel_distance : scalar_float
        Light propagation distance in meters
    aperture_center : Float[Array, " 2"]
        Center position of the aperture (x, y) in meters
    camera_pixel_size : scalar_float
        Camera pixel size in meters
    learning_rate : scalar_float
        Learning rate for optimization
    num_iterations : scalar_integer
        Number of optimization iterations

    Returns
    -------
    PtychographyParams
        Parameters holding float64 arrays for the real-valued inputs and
        an int64 array for ``num_iterations``

    Notes
    -----
    Argument types are checked by the ``jaxtyped``/``beartype`` decorator.
    The positivity and shape checks in the body are traced through
    ``lax.cond``, but both branches return the value unchanged and the
    result is discarded, so they do not reject any input.
    """
    (
        zoom_factor_array,
        aperture_diameter_array,
        travel_distance_array,
        aperture_center_array,
        camera_pixel_size_array,
        learning_rate_array,
    ) = (
        jnp.asarray(value, dtype=jnp.float64)
        for value in (
            zoom_factor,
            aperture_diameter,
            travel_distance,
            aperture_center,
            camera_pixel_size,
            learning_rate,
        )
    )
    num_iterations_array = jnp.asarray(num_iterations, dtype=jnp.int64)

    def _passthrough(predicate, payload):
        """Trace ``predicate`` with ``lax.cond``; yield ``payload`` as is."""
        return lax.cond(
            predicate,
            lambda: payload,
            lambda: lax.stop_gradient(
                lax.cond(False, lambda: payload, lambda: payload)
            ),
        )

    # Positivity checks before the aperture-center shape check.
    for quantity in (
        zoom_factor_array,
        aperture_diameter_array,
        travel_distance_array,
    ):
        _passthrough(quantity > 0, quantity)

    _passthrough(aperture_center_array.shape == (2,), aperture_center_array)

    # Remaining positivity checks.
    for quantity in (
        camera_pixel_size_array,
        learning_rate_array,
        num_iterations_array,
    ):
        _passthrough(quantity > 0, quantity)

    validated_params: PtychographyParams = PtychographyParams(
        zoom_factor=zoom_factor_array,
        aperture_diameter=aperture_diameter_array,
        travel_distance=travel_distance_array,
        aperture_center=aperture_center_array,
        camera_pixel_size=camera_pixel_size_array,
        learning_rate=learning_rate_array,
        num_iterations=num_iterations_array,
    )
    return validated_params