API Reference

janssen.invert

Inversion algorithms for phase retrieval and ptychography.

Extended Summary

Comprehensive algorithms for phase retrieval and ptychographic reconstruction using differentiable programming techniques. Includes various optimization strategies and loss functions for reconstructing complex-valued fields.

Submodules

engine

Reconstruction engine

ptychography

Ptychographic algorithms

optimizers

Optimization routines

loss_functions

Loss function definitions

Routine Listings

create_loss_function : function

Factory function for creating various loss functions

simple_microscope_ptychography : function

Main ptychography reconstruction algorithm using PtychographyParams

epie_optical : function

Extended PIE algorithm for optical ptychography

single_pie_iteration : function

Single iteration of PIE algorithm

single_pie_sequential : function

Sequential PIE implementation for multiple positions

single_pie_vmap : function

Vectorized PIE implementation using vmap

init_adam : function

Initialize Adam optimizer state

init_adagrad : function

Initialize Adagrad optimizer state

init_rmsprop : function

Initialize RMSprop optimizer state

Notes

All functions are JAX-compatible and support automatic differentiation. The algorithms can be composed with JIT compilation for improved performance.

janssen.invert.epie_optical(microscope_data: MicroscopeData, initial_object: OpticalWavefront, initial_surface: SampleFunction, pixel_mask: Float[Array, 'H W'], propagation_distance_1: float | Float[Array, ''], propagation_distance_2: float | Float[Array, ''], magnification: int | Int[Array, ''], vmap_iterations: int | Int[Array, ''] | None = 0, alpha_object: float | Float[Array, ''] | None = 0.1, gamma_object: float | Float[Array, ''] | None = 0.5, alpha_surface: float | Float[Array, ''] | None = 0.1, gamma_surface: float | Float[Array, ''] | None = 0.5, num_loops: int | Int[Array, ''] | None = 10) → tuple[OpticalWavefront, SampleFunction]

Reconstruct ptychography using the extended PIE algorithm.

Parameters:
  • microscope_data (MicroscopeData) – Measured intensity data with positions.

  • initial_object (OpticalWavefront) – Initial guess for object wavefront.

  • initial_surface (SampleFunction) – Initial guess for surface pattern.

  • pixel_mask (Float[Array, " H W"]) – Pixel response mask for modeling sensor characteristics.

  • propagation_distance_1 (float) – Distance from object to diffuser plane in meters.

  • propagation_distance_2 (float) – Distance from diffuser to sensor plane in meters.

  • magnification (scalar_integer) – Magnification factor for downsampling.

  • vmap_iterations (scalar_integer, optional) – Number of initial iterations to run in vmap mode for rapid convergence. If 0, use sequential mode for all iterations. If > 0, use vmap for first N iterations, then switch to sequential. Default is 0.

  • alpha_object (float, optional) – Object update mixing parameter. Default is 0.1.

  • gamma_object (float, optional) – Object update step size. Default is 0.5.

  • alpha_surface (float, optional) – Surface update mixing parameter. Default is 0.1.

  • gamma_surface (float, optional) – Surface update step size. Default is 0.5.

  • num_loops (scalar_integer, optional) – Number of iteration loops. Default is 10.

Returns:

  • recovered_object : OpticalWavefront

    Reconstructed object wavefront.

  • recovered_surface : SampleFunction

    Reconstructed surface pattern.

Return type:

tuple of (OpticalWavefront, SampleFunction)

Notes

Algorithm:

  • Compute image data

  • Compute positions

  • Compute frequency grids

  • Compute object recovery propagation field

  • Compute surface pattern

  • Define loop body

  • Apply fori_loop over loops

  • Compute final object field

  • Compute final object wavefront

  • Compute final surface pattern

  • Return final object and surface

janssen.invert.single_pie_iteration(object_prop_ft: Complex[Array, 'H W'], surface_pattern: Complex[Array, 'H W'], measurement: Float[Array, 'H W'], position: Float[Array, '2'], frequency_x_grid: Float[Array, 'H W'], frequency_y_grid: Float[Array, 'H W'], pixel_mask: Float[Array, 'H W'], propagation_distance_2: float | Float[Array, ''], magnification: int | Int[Array, ''], alpha_object: float | Float[Array, ''], gamma_object: float | Float[Array, ''], alpha_surface: float | Float[Array, ''], gamma_surface: float | Float[Array, ''], wavelength: float | Float[Array, ''], dx: float | Float[Array, '']) → tuple[Complex[Array, 'H W'], Complex[Array, 'H W']]

Single iteration of the extended PIE algorithm.

Parameters:
  • object_prop_ft (Complex[Array, " H W"]) – Object wavefront at diffuser plane in Fourier domain.

  • surface_pattern (Complex[Array, " H W"]) – Surface pattern function.

  • measurement (Float[Array, " H W"]) – Measured intensity at current position.

  • position (Float[Array, " 2"]) – Current scanning position [x, y].

  • frequency_x_grid (Float[Array, " H W"]) – Frequency grid in x direction.

  • frequency_y_grid (Float[Array, " H W"]) – Frequency grid in y direction.

  • pixel_mask (Float[Array, " H W"]) – Pixel response mask for sensor modeling.

  • propagation_distance_2 (float) – Distance from diffuser to sensor.

  • magnification (scalar_integer) – Downsampling magnification factor.

  • alpha_object (float) – Object update mixing parameter.

  • gamma_object (float) – Object update step size.

  • alpha_surface (float) – Surface update mixing parameter.

  • gamma_surface (float) – Surface update step size.

  • wavelength (float) – Wavelength of light in meters.

  • dx (float) – Pixel spacing in meters.

Returns:

  • updated_object_ft : Complex[Array, " H W"]

    Updated object wavefront in Fourier domain.

  • updated_surface : Complex[Array, " H W"]

    Updated surface pattern.

Return type:

tuple of (Complex[Array, " H W"], Complex[Array, " H W"])

Notes

Algorithm:

  • Compute object shifted

  • Compute surface plane

  • Compute surface propagation kernel

  • Compute sensor plane

  • Compute sensor intensity

  • Compute ratio map

  • Compute ratio map upsampled

  • Compute sensor plane new

  • Compute sensor plane new in Fourier domain

  • Compute CTF conjugate

  • Compute CTF maximum squared

  • Compute surface propagation kernel

  • Compute updated surface pattern

  • Compute updated object wavefront

  • Compute updated object wavefront in Fourier domain

  • Return updated object and surface
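The ratio-map and CTF steps above follow the general shape of a PIE-family update. As a hedged illustration only, the sketch below shows a generic far-field modulus constraint and the regularized PIE (rPIE) update, in which a mixing parameter blends the local weight |P|² with its global maximum and a step size scales the correction. janssen's own iteration also propagates to a sensor plane and applies the pixel mask and magnification, which this sketch omits. `modulus_constraint` and `rpie_update` are hypothetical names, and reading `alpha_*` / `gamma_*` as rPIE's mixing and step parameters is an assumption. NumPy is used so the block runs standalone.

```python
import numpy as np


def modulus_constraint(exit_wave, measurement, eps=1e-12):
    """Replace the Fourier modulus of exit_wave by sqrt(measurement), keeping the phase."""
    ft = np.fft.fft2(exit_wave)
    ft_new = np.sqrt(measurement) * ft / (np.abs(ft) + eps)
    return np.fft.ifft2(ft_new)


def rpie_update(estimate, partner, exit_old, exit_new, alpha=0.1, gamma=0.5):
    """Generic rPIE-style update of `estimate` (hypothetical helper).

    partner  : the other factor of the exit wave (probe, surface, ...).
    exit_old : exit wave before the measurement constraint.
    exit_new : exit wave after the measurement constraint.
    alpha    : mixes the local |partner|^2 with its global maximum.
    gamma    : step size.
    """
    weight = np.abs(partner) ** 2
    # rPIE denominator: blend of local weight and its maximum.
    denom = (1.0 - alpha) * weight + alpha * weight.max()
    return estimate + gamma * np.conj(partner) * (exit_new - exit_old) / denom
```

With `alpha = 1` the denominator becomes the global maximum alone (the classic ePIE normalization); smaller values let weakly illuminated pixels take larger steps.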

janssen.invert.single_pie_sequential(object_prop_ft: Complex[Array, 'H W'], surface_pattern: Complex[Array, 'H W'], image_data: Float[Array, 'P H W'], positions: Float[Array, 'P 2'], frequency_x_grid: Float[Array, 'H W'], frequency_y_grid: Float[Array, 'H W'], pixel_mask: Float[Array, 'H W'], propagation_distance_2: float | Float[Array, ''], magnification: int | Int[Array, ''], alpha_object: float | Float[Array, ''], gamma_object: float | Float[Array, ''], alpha_surface: float | Float[Array, ''], gamma_surface: float | Float[Array, ''], wavelength: float | Float[Array, ''], dx: float | Float[Array, '']) → tuple[Complex[Array, 'H W'], Complex[Array, 'H W']]

Sequential processing over positions using fori_loop for proper PIE convergence.

Parameters:
  • object_prop_ft (Complex[Array, " H W"]) – Current object wavefront in Fourier domain.

  • surface_pattern (Complex[Array, " H W"]) – Current surface pattern.

  • image_data (Float[Array, " P H W"]) – Measurement data for all positions.

  • positions (Float[Array, " P 2"]) – Position coordinates for all measurements.

  • frequency_x_grid (Float[Array, " H W"]) – Frequency grid in x direction.

  • frequency_y_grid (Float[Array, " H W"]) – Frequency grid in y direction.

  • pixel_mask (Float[Array, " H W"]) – Pixel response mask for sensor modeling.

  • propagation_distance_2 (float) – Distance from diffuser to sensor.

  • magnification (scalar_integer) – Downsampling magnification factor.

  • alpha_object (float) – Object update mixing parameter.

  • gamma_object (float) – Object update step size.

  • alpha_surface (float) – Surface update mixing parameter.

  • gamma_surface (float) – Surface update step size.

  • wavelength (float) – Wavelength of light in meters.

  • dx (float) – Pixel spacing in meters.

Returns:

Updated object and surface state after sequential processing.

Return type:

tuple of (Complex[Array, " H W"], Complex[Array, " H W"])

Notes

Algorithm:

  • Compute number of positions

  • Define position body

  • Apply fori_loop over positions

  • Return updated state
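The sequential scheme can be illustrated with a plain Python loop standing in for `jax.lax.fori_loop(0, P, body, state)`. The point it shows is that each position's update starts from the state left by the previous position. `sequential_updates` and `toy_update` are hypothetical stand-ins; in janssen the per-position body would call `single_pie_iteration` with `image_data[i]` and `positions[i]`.

```python
import numpy as np


def sequential_updates(state, measurements, update_fn):
    """Thread `state` through one update per position, in order.

    Plain-loop stand-in for jax.lax.fori_loop(0, P, body, state).
    """
    for p in range(measurements.shape[0]):
        state = update_fn(state, measurements[p])
    return state


def toy_update(state, measurement):
    # Hypothetical update: move the state halfway toward the measurement.
    return state + 0.5 * (measurement - state)


data = np.stack([np.ones((2, 2)), 3.0 * np.ones((2, 2))])
result = sequential_updates(np.zeros((2, 2)), data, toy_update)
```

Here the state goes 0 → 0.5 → 1.75, so the second position builds on the first position's correction.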

janssen.invert.single_pie_vmap(object_prop_ft: Complex[Array, 'H W'], surface_pattern: Complex[Array, 'H W'], image_data: Float[Array, 'P H W'], positions: Float[Array, 'P 2'], frequency_x_grid: Float[Array, 'H W'], frequency_y_grid: Float[Array, 'H W'], pixel_mask: Float[Array, 'H W'], propagation_distance_2: float | Float[Array, ''], magnification: int | Int[Array, ''], alpha_object: float | Float[Array, ''], gamma_object: float | Float[Array, ''], alpha_surface: float | Float[Array, ''], gamma_surface: float | Float[Array, ''], wavelength: float | Float[Array, ''], dx: float | Float[Array, '']) → tuple[Complex[Array, 'H W'], Complex[Array, 'H W']]

Parallel processing over positions using vmap for faster but approximate PIE.

All positions start from the same initial state, and their updates are then averaged.

Parameters:
  • object_prop_ft (Complex[Array, " H W"]) – Current object wavefront in Fourier domain.

  • surface_pattern (Complex[Array, " H W"]) – Current surface pattern.

  • image_data (Float[Array, " P H W"]) – Measurement data for all positions.

  • positions (Float[Array, " P 2"]) – Position coordinates for all measurements.

  • frequency_x_grid (Float[Array, " H W"]) – Frequency grid in x direction.

  • frequency_y_grid (Float[Array, " H W"]) – Frequency grid in y direction.

  • pixel_mask (Float[Array, " H W"]) – Pixel response mask for sensor modeling.

  • propagation_distance_2 (float) – Distance from diffuser to sensor.

  • magnification (scalar_integer) – Downsampling magnification factor.

  • alpha_object (float) – Object update mixing parameter.

  • gamma_object (float) – Object update step size.

  • alpha_surface (float) – Surface update mixing parameter.

  • gamma_surface (float) – Surface update step size.

  • wavelength (float) – Wavelength of light in meters.

  • dx (float) – Pixel spacing in meters.

Returns:

Updated object and surface state after parallel processing and averaging.

Return type:

tuple of (Complex[Array, " H W"], Complex[Array, " H W"])

Notes

Algorithm:

  • Apply vmap over all positions using same initial state

  • Compute average of all object updates

  • Compute average of all surface updates

  • Return averaged states
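The averaging scheme can be sketched the same way. A list comprehension stands in for `jax.vmap(update_fn, in_axes=(None, 0))`, which maps over the position axis while broadcasting the shared starting state. `averaged_updates` and `toy_update` are hypothetical names, not janssen functions.

```python
import numpy as np


def averaged_updates(state, measurements, update_fn):
    """Apply update_fn to every position from the SAME starting state, then average.

    The list comprehension stands in for jax.vmap(update_fn, in_axes=(None, 0)).
    """
    updated = np.stack([update_fn(state, m) for m in measurements])
    return updated.mean(axis=0)


def toy_update(state, measurement):
    # Hypothetical update: move the state halfway toward the measurement.
    return state + 0.5 * (measurement - state)


data = np.stack([np.ones((2, 2)), 3.0 * np.ones((2, 2))])
result = averaged_updates(np.zeros((2, 2)), data, toy_update)
```

The two independent updates give 0.5 and 1.5, averaging to 1.0. Applying the same two updates in order would instead end at 1.75, since no position here sees another's correction. This is why the averaged mode is cheaper to parallelize but only approximates sequential PIE.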

janssen.invert.create_loss_function(forward_function: Callable[[...], Array], experimental_data: Array, loss_type: str = 'mae') → Callable[[...], Float[Array, '']]

Create a JIT-compatible loss function.

This function returns a new function that computes the loss between the output of a forward model and experimental data. The returned function is JIT-compatible and can be used with various optimization algorithms.

Parameters:
  • forward_function (Callable[..., Array]) – The forward model function (e.g., stem_4d).

  • experimental_data (Array) – The experimental data to compare against.

  • loss_type (str, optional) – The type of loss to use. Options are “mae” (Mean Absolute Error), “mse” (Mean Squared Error), or “rmse” (Root Mean Squared Error), by default “mae”.

Returns:

loss_fn – A JIT-compatible function that computes the loss given the model parameters and any additional arguments required by the forward function.

Return type:

Callable[[PyTree, ...], Float[Array, " "]]

Notes

  • Define internal loss functions (mae_loss, mse_loss, rmse_loss).

  • Select the appropriate loss function based on loss_type.

  • Create a JIT-compiled function that:
    • Computes the forward model output.

    • Calculates the difference between model and experimental data.

    • Applies the selected loss function.

  • Return the compiled loss function.
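The steps above describe a closure-based factory. Below is a minimal NumPy sketch of that pattern, with the loss selected once in Python so the returned function contains no string logic and could be wrapped in `jax.jit` in a JAX version. `create_loss_function_sketch` is a hypothetical name, and using |diff|² (so complex residuals also work) is an assumption.

```python
import numpy as np


def create_loss_function_sketch(forward_function, experimental_data, loss_type="mae"):
    """Hypothetical re-creation of the factory pattern described above."""

    def mae_loss(diff):
        return np.mean(np.abs(diff))

    def mse_loss(diff):
        return np.mean(np.abs(diff) ** 2)

    def rmse_loss(diff):
        return np.sqrt(np.mean(np.abs(diff) ** 2))

    losses = {"mae": mae_loss, "mse": mse_loss, "rmse": rmse_loss}
    if loss_type not in losses:
        raise ValueError(f"Unknown loss_type: {loss_type!r}")
    selected = losses[loss_type]  # chosen at creation time, not at call time

    def loss_fn(params, *args):
        # Run the forward model, compare with the data, reduce to a scalar.
        simulated = forward_function(params, *args)
        return selected(simulated - experimental_data)

    return loss_fn
```

For example, with `forward = lambda p: p * np.ones(3)` and data `[1, 2, 3]`, the residuals at `p = 2.0` are `[1, 0, -1]`, so both MAE and MSE equal 2/3.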

janssen.invert.init_adagrad(shape: tuple) → OptimizerState

Initialize Adagrad optimizer state.

Parameters:

shape (Tuple) – Shape of the parameters to be optimized

Returns:

state – Initialized Adagrad optimizer state with zero accumulated gradients

Return type:

OptimizerState
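As an illustration of what an Adagrad state holds and how it is used, here is a minimal NumPy sketch of the textbook Adagrad rule. `AdagradState`, `init_adagrad_sketch` and `adagrad_step` are hypothetical. janssen's `OptimizerState` fields and update function may differ.

```python
from typing import NamedTuple

import numpy as np


class AdagradState(NamedTuple):
    # Hypothetical container; janssen's OptimizerState fields may differ.
    sum_sq: np.ndarray  # running sum of squared gradient magnitudes


def init_adagrad_sketch(shape):
    return AdagradState(sum_sq=np.zeros(shape))


def adagrad_step(params, grads, state, learning_rate=0.1, eps=1e-8):
    sum_sq = state.sum_sq + np.abs(grads) ** 2
    # Per-element step shrinks as squared gradients accumulate.
    new_params = params - learning_rate * grads / (np.sqrt(sum_sq) + eps)
    return new_params, AdagradState(sum_sq)
```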

janssen.invert.init_adam(shape: tuple) → OptimizerState

Initialize Adam optimizer state.

Parameters:

shape (Tuple) – Shape of the parameters to be optimized

Returns:

state – Initialized Adam optimizer state with zero moments and step=0

Return type:

OptimizerState
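The "zero moments and step=0" state is what the standard Adam rule starts from. The NumPy sketch below shows that rule with bias correction. `AdamState`, `init_adam_sketch` and `adam_step` are hypothetical names, and the default hyperparameters are the usual Adam defaults, not values taken from janssen.

```python
from typing import NamedTuple

import numpy as np


class AdamState(NamedTuple):
    m: np.ndarray  # first-moment (mean) estimate
    v: np.ndarray  # second-moment (uncentered variance) estimate
    step: int


def init_adam_sketch(shape):
    return AdamState(m=np.zeros(shape), v=np.zeros(shape), step=0)


def adam_step(params, grads, state, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    step = state.step + 1
    m = b1 * state.m + (1 - b1) * grads
    v = b2 * state.v + (1 - b2) * np.abs(grads) ** 2
    # Bias correction compensates for the zero initialization of m and v.
    m_hat = m / (1 - b1**step)
    v_hat = v / (1 - b2**step)
    new_params = params - lr * m_hat / (np.sqrt(v_hat) + eps)
    return new_params, AdamState(m, v, step)
```

Because of the bias correction, the very first step has magnitude close to `lr` regardless of the gradient's scale.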

janssen.invert.init_rmsprop(shape: tuple) → OptimizerState

Initialize RMSprop optimizer state.

Parameters:

shape (Tuple) – Shape of the parameters to be optimized

Returns:

state – Initialized RMSprop optimizer state with zero moving average

Return type:

OptimizerState
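Similarly, the "zero moving average" state feeds the standard RMSprop rule, sketched here in NumPy. `RMSpropState`, `init_rmsprop_sketch` and `rmsprop_step` are hypothetical, and the decay and learning-rate defaults are common choices, not janssen's.

```python
from typing import NamedTuple

import numpy as np


class RMSpropState(NamedTuple):
    avg_sq: np.ndarray  # exponential moving average of squared gradients


def init_rmsprop_sketch(shape):
    return RMSpropState(avg_sq=np.zeros(shape))


def rmsprop_step(params, grads, state, lr=1e-3, decay=0.9, eps=1e-8):
    avg_sq = decay * state.avg_sq + (1 - decay) * np.abs(grads) ** 2
    # Unlike Adagrad, old gradients are forgotten, so the step does not vanish.
    new_params = params - lr * grads / (np.sqrt(avg_sq) + eps)
    return new_params, RMSpropState(avg_sq)
```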

janssen.invert.simple_microscope_ptychography(experimental_data: MicroscopeData, guess_sample: SampleFunction, guess_lightwave: OpticalWavefront, params: PtychographyParams, save_every: int | Int[Array, ''] | None = 10, loss_type: str | None = 'mse', optimizer_name: str | None = 'adam', zoom_factor_bounds: tuple[float | Float[Array, ''], float | Float[Array, '']] | None = None, aperture_diameter_bounds: tuple[float | Float[Array, ''], float | Float[Array, '']] | None = None, travel_distance_bounds: tuple[float | Float[Array, ''], float | Float[Array, '']] | None = None, aperture_center_bounds: tuple[Float[Array, '2'], Float[Array, '2']] | None = None) → tuple[tuple[SampleFunction, OpticalWavefront, float | Float[Array, ''], float | Float[Array, ''], Float[Array, '2'] | None, float | Float[Array, '']], tuple[Complex[Array, 'H W S'], Complex[Array, 'H W S'], Float[Array, 'S'], Float[Array, 'S'], Float[Array, '2 S'], Float[Array, 'S']]]

Solve the optical ptychography inverse problem.

Experimental diffraction patterns are used to reconstruct the sample, the lightwave, and the optical system parameters.

Parameters:
  • experimental_data (MicroscopeData) – The experimental diffraction patterns collected at different positions.

  • guess_sample (SampleFunction) – Initial guess for the sample properties.

  • guess_lightwave (OpticalWavefront) – Initial guess for the lightwave.

  • params (PtychographyParams) – Ptychography parameters, including:
    • zoom_factor: Optical zoom factor for magnification

    • aperture_diameter: Diameter of the aperture in meters

    • travel_distance: Light propagation distance in meters

    • aperture_center: Center position of the aperture (x, y)

    • camera_pixel_size: Camera pixel size in meters

    • learning_rate: Learning rate for optimization

    • num_iterations: Number of optimization iterations

  • save_every (scalar_integer, optional) – Save intermediate results every n iterations. Default is 10.

  • loss_type (str, optional) – Type of loss function to use. Default is “mse”.

  • optimizer_name (str, optional) – Name of the optimizer to use. Default is “adam”.

  • zoom_factor_bounds (Tuple[scalar_float, scalar_float], optional) – Lower and upper bounds for zoom factor optimization.

  • aperture_diameter_bounds (Tuple[scalar_float, scalar_float], optional) – Lower and upper bounds for aperture diameter optimization.

  • travel_distance_bounds (Tuple[scalar_float, scalar_float], optional) – Lower and upper bounds for travel distance optimization.

  • aperture_center_bounds (Tuple[Float[Array, " 2"], Float[Array, " 2"]], optional) – Lower and upper bounds for aperture center optimization.

Returns:

Tuple containing:

  • Final results tuple:
    • final_sample : SampleFunction

      Optimized sample properties.

    • final_lightwave : OpticalWavefront

      Optimized lightwave.

    • final_zoom_factor : scalar_float

      Optimized zoom factor.

    • final_aperture_diameter : scalar_float

      Optimized aperture diameter.

    • final_aperture_center : Float[Array, " 2"] or None

      Optimized aperture center.

    • final_travel_distance : scalar_float

      Optimized travel distance.

  • Intermediate results tuple:
    • intermediate_samples : Complex[Array, " H W S"]

      Intermediate samples during optimization.

    • intermediate_lightwaves : Complex[Array, " H W S"]

      Intermediate lightwaves during optimization.

    • intermediate_zoom_factors : Float[Array, " S"]

      Intermediate zoom factors during optimization.

    • intermediate_aperture_diameters : Float[Array, " S"]

      Intermediate aperture diameters during optimization.

    • intermediate_aperture_centers : Float[Array, " 2 S"]

      Intermediate aperture centers during optimization.

    • intermediate_travel_distances : Float[Array, " S"]

      Intermediate travel distances during optimization.

Return type:

Tuple[Tuple[...], Tuple[...]]

janssen.lenses

Lens implementations and optical calculations.

Extended Summary

Comprehensive lens modeling and optical propagation algorithms for simulating light propagation through various optical elements. Includes implementations of common lens types and propagation methods based on wave optics.

Submodules

lens_elements

Lens elements for optical simulations

lens_prop

Lens propagation functions

Routine Listings

create_lens_phase : function

Create phase profile for a lens based on its parameters

double_concave_lens : function

Create parameters for a double concave lens

double_convex_lens : function

Create parameters for a double convex lens

lens_focal_length : function

Calculate focal length from lens parameters

lens_thickness_profile : function

Calculate thickness profile of a lens

meniscus_lens : function

Create parameters for a meniscus lens

plano_concave_lens : function

Create parameters for a plano-concave lens

plano_convex_lens : function

Create parameters for a plano-convex lens

propagate_through_lens : function

Propagate optical wavefront through a lens

angular_spectrum_prop : function

Angular spectrum propagation method

digital_zoom : function

Digital zoom transformation for optical fields

fraunhofer_prop : function

Fraunhofer (far-field) propagation

fresnel_prop : function

Fresnel (near-field) propagation

lens_propagation : function

General lens-based propagation

optical_zoom : function

Optical zoom transformation

Notes

All propagation functions are JAX-compatible and support automatic differentiation. The lens functions can model both ideal and realistic optical elements with aberrations.

janssen.lenses.create_lens_phase(xx: Float[Array, 'hh ww'], yy: Float[Array, 'hh ww'], params: LensParams, wavelength: float | Float[Array, '']) → tuple[Float[Array, 'hh ww'], Float[Array, 'hh ww']]

Create the phase profile and transmission mask for a lens.

Parameters:
  • xx (Float[Array, " hh ww"]) – X coordinates grid.

  • yy (Float[Array, " hh ww"]) – Y coordinates grid.

  • params (LensParams) – Lens parameters.

  • wavelength (float) – Wavelength of light.

Returns:

  • phase_profile (Float[Array, " hh ww"]) – Phase profile of the lens.

  • transmission (Float[Array, " hh ww"]) – Transmission mask of the lens.

Return type:

tuple[Float[Array, 'hh ww'], Float[Array, 'hh ww']]

Notes

  • Calculate radial coordinates.

  • Calculate thickness profile.

  • Calculate phase profile.

  • Create transmission mask.

  • Return phase and transmission.
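The steps above match the usual thin-element model, where the lens adds an optical path of (n − 1)·t(r) and so a phase of (2π/λ)(n − 1)t(r), and the aperture is a hard circular mask. The NumPy sketch below illustrates that model. `lens_phase_sketch` and its `thickness_fn` argument are hypothetical (janssen takes a `LensParams` instead), and the hard `r <= diameter / 2` mask is an assumption.

```python
import numpy as np


def lens_phase_sketch(xx, yy, thickness_fn, n, diameter, wavelength):
    """Thin-element lens model (hypothetical helper, not janssen's signature)."""
    r = np.sqrt(xx**2 + yy**2)  # radial coordinate
    thickness = thickness_fn(r)  # glass thickness at each point
    # Extra optical path (n - 1) * t converted to phase.
    phase_profile = (2 * np.pi / wavelength) * (n - 1) * thickness
    transmission = (r <= diameter / 2).astype(float)  # 1 inside aperture, 0 outside
    return phase_profile, transmission
```

For instance, a uniform 1 µm slab with n = 1.5 at λ = 500 nm adds exactly one wave (2π) of phase.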

janssen.lenses.double_concave_lens(focal_length: float | Float[Array, ''], diameter: float | Float[Array, ''], n: float | Float[Array, ''], center_thickness: float | Float[Array, ''], r_ratio: float | Float[Array, ''] | None = 1.0) → LensParams

Create parameters for a double concave lens.

Parameters:
  • focal_length (float) – Desired focal length.

  • diameter (float) – Lens diameter.

  • n (float) – Refractive index.

  • center_thickness (float) – Center thickness.

  • r_ratio (float, optional) – Ratio of R2/R1, by default 1.0 for symmetric lens.

Returns:

params – Lens parameters.

Return type:

LensParams

Notes

  • Calculate R1 using lensmaker’s equation.

  • Calculate R2 using r_ratio.

  • Create and return LensParams.

janssen.lenses.double_convex_lens(focal_length: float | Float[Array, ''], diameter: float | Float[Array, ''], n: float | Float[Array, ''], center_thickness: float | Float[Array, ''], r_ratio: float | Float[Array, ''] | None = 1.0) → LensParams

Create parameters for a double convex lens.

Parameters:
  • focal_length (float) – Desired focal length.

  • diameter (float) – Lens diameter.

  • n (float) – Refractive index.

  • center_thickness (float) – Center thickness.

  • r_ratio (float, optional) – Ratio of r2/r1, by default 1.0 for symmetric lens.

Returns:

params – Lens parameters.

Return type:

LensParams

Notes

  • Calculate r1 using lensmaker’s equation.

  • Calculate r2 using r_ratio.

  • Create and return LensParams.
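Solving the lensmaker's equation for r1 can be illustrated under two stated assumptions: the thin-lens form (center_thickness ignored), and the per-surface convention used by `lens_focal_length` below, in which both radii are positive for convex surfaces. With R2 = r_ratio · R1 this gives 1/f = (n − 1)(1 + 1/r_ratio)/R1. `double_convex_radii_sketch` is a hypothetical helper, and janssen may instead solve a thick-lens version.

```python
def double_convex_radii_sketch(focal_length, n, r_ratio=1.0):
    """Thin-lens estimate of (R1, R2) with R2 = r_ratio * R1 (hypothetical helper).

    Assumes both radii are positive for convex surfaces, so
    1/f = (n - 1) * (1/R1 + 1/R2)  =>  R1 = (n - 1) * f * (1 + 1/r_ratio).
    """
    r1 = (n - 1) * focal_length * (1 + 1 / r_ratio)
    r2 = r_ratio * r1
    return r1, r2
```

For a symmetric lens with n = 1.5 this reproduces the familiar rule R1 = R2 = f.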

janssen.lenses.lens_focal_length(n: float | Float[Array, ''], r1: int | float | complex | Num[Array, ''], r2: int | float | complex | Num[Array, '']) → float | Float[Array, '']

Calculate the focal length of a lens using the lensmaker’s equation.

Parameters:
  • n (float) – Refractive index of the lens material.

  • r1 (scalar_numeric) – Radius of curvature of the first surface (positive for convex).

  • r2 (scalar_numeric) – Radius of curvature of the second surface (positive for convex).

Returns:

f – Focal length of the lens.

Return type:

float

Notes

  • Apply the lensmaker’s equation.

  • Return the calculated focal length.
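Under the sign convention stated in the parameter list (each radius positive for a convex surface), the thin-lens lensmaker's equation reads 1/f = (n − 1)(1/r1 + 1/r2). Many textbooks write the equivalent 1/f = (n − 1)(1/R1 − 1/R2) in the Cartesian convention, where R2 is negative for a biconvex lens. `thin_lens_focal_length` is a hypothetical helper, and whether janssen adds a thick-lens correction is not stated here.

```python
def thin_lens_focal_length(n, r1, r2):
    """Thin-lens lensmaker's equation (hypothetical helper).

    Each radius is positive for a convex surface, so
    1/f = (n - 1) * (1/r1 + 1/r2). A flat surface is r = inf.
    """
    return 1.0 / ((n - 1) * (1.0 / r1 + 1.0 / r2))
```

A flat second surface (`float("inf")`) recovers the plano-convex result f = R/(n − 1).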

janssen.lenses.lens_thickness_profile(r: Float[Array, 'H W'], r1: float | Float[Array, ''], r2: float | Float[Array, ''], center_thickness: float | Float[Array, ''], diameter: float | Float[Array, '']) → Float[Array, 'H W']

Calculate the thickness profile of a lens.

Parameters:
  • r (Float[Array, " H W"]) – Radial distance from the optical axis.

  • r1 (float) – Radius of curvature of the first surface.

  • r2 (float) – Radius of curvature of the second surface.

  • center_thickness (float) – Thickness at the center of the lens.

  • diameter (float) – Diameter of the lens.

Returns:

thickness – Thickness profile of the lens.

Return type:

Float[Array, " H W"]

Notes

  • Calculate surface sag for both surfaces, only where the aperture mask holds and r is finite.

  • Combine sags with center thickness.

  • Return thickness profile.
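The sag-and-combine steps can be sketched with the spherical sag written in curvature form, sag(r) = c r² / (1 + √(1 − c² r²)) with c = 1/R. This form equals R − √(R² − r²) for positive R, is negative for concave (negative R) surfaces, and gives exactly zero for a flat surface (R = ∞). `surface_sag` and `thickness_profile_sketch` are hypothetical. Subtracting both sags from the center thickness and zeroing outside the aperture are assumptions about how janssen combines them.

```python
import numpy as np


def surface_sag(r, radius):
    """Spherical sag c r^2 / (1 + sqrt(1 - c^2 r^2)) with curvature c = 1/radius.

    Positive radius (convex) gives positive sag; np.inf (flat) gives zero.
    """
    c = 1.0 / radius
    arg = np.clip(1.0 - (c * r) ** 2, 0.0, None)  # guard against r > |radius|
    return c * r**2 / (1.0 + np.sqrt(arg))


def thickness_profile_sketch(r, r1, r2, center_thickness, diameter):
    inside = r <= diameter / 2
    # Convex sags thin the lens toward the edge; concave sags thicken it.
    t = center_thickness - surface_sag(r, r1) - surface_sag(r, r2)
    return np.where(inside, t, 0.0)
```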

janssen.lenses.meniscus_lens(focal_length: float | Float[Array, ''], diameter: float | Float[Array, ''], n: float | Float[Array, ''], center_thickness: float | Float[Array, ''], r_ratio: float | Float[Array, ''], convex_first: bool | Bool[Array, ''] | None = True) → LensParams

Create parameters for a meniscus (concavo-convex) lens.

For a meniscus lens, one surface is convex (positive R) and one is concave (negative R).

Parameters:
  • focal_length (float) – Desired focal length in meters.

  • diameter (float) – Lens diameter in meters.

  • n (float) – Refractive index of lens material.

  • center_thickness (float) – Center thickness in meters.

  • r_ratio (float) – Absolute ratio of R2/R1.

  • convex_first (scalar_bool, optional) – If True, first surface is convex, by default True.

Returns:

params – Lens parameters.

Return type:

LensParams

Notes

  • Calculate magnitude of R1 using lensmaker’s equation.

  • Calculate R2 magnitude using r_ratio.

  • Assign correct signs based on convex_first.

  • Create and return LensParams.

janssen.lenses.plano_concave_lens(focal_length: float | Float[Array, ''], diameter: float | Float[Array, ''], n: float | Float[Array, ''], center_thickness: float | Float[Array, ''], concave_first: bool | Bool[Array, ''] | None = True) → LensParams

Create parameters for a plano-concave lens.

Parameters:
  • focal_length (float) – Desired focal length.

  • diameter (float) – Lens diameter.

  • n (float) – Refractive index.

  • center_thickness (float) – Center thickness.

  • concave_first (scalar_bool, optional) – If True, first surface is concave, by default True.

Returns:

params – Lens parameters.

Return type:

LensParams

Notes

  • Calculate R for curved surface.

  • Set other R to infinity (flat surface).

  • Create and return LensParams.

janssen.lenses.plano_convex_lens(focal_length: float | Float[Array, ''], diameter: float | Float[Array, ''], n: float | Float[Array, ''], center_thickness: float | Float[Array, ''], convex_first: bool | Bool[Array, ''] | None = True) → LensParams

Create parameters for a plano-convex lens.

Parameters:
  • focal_length (float) – Desired focal length.

  • diameter (float) – Lens diameter.

  • n (float) – Refractive index.

  • center_thickness (float) – Center thickness.

  • convex_first (scalar_bool, optional) – If True, first surface is convex, by default True.

Returns:

params – Lens parameters.

Return type:

LensParams

Notes

  • Calculate R for curved surface.

  • Set other R to infinity (flat surface).

  • Create and return LensParams.

janssen.lenses.propagate_through_lens(field: Complex[Array, 'hh ww'], phase_profile: Float[Array, 'hh ww'], transmission: Float[Array, 'hh ww']) → Complex[Array, 'hh ww']

Propagate a field through a lens.

Parameters:
  • field (Complex[Array, " hh ww"]) – Input complex field.

  • phase_profile (Float[Array, " hh ww"]) – Phase profile of the lens.

  • transmission (Float[Array, " hh ww"]) – Transmission mask of the lens.

Returns:

output_field – Field after passing through the lens.

Return type:

Complex[Array, " hh ww"]

Notes

  • Apply transmission mask.

  • Add the phase profile to the field's phase.

  • Return modified field.
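Applying the mask and adding the phase amounts to multiplying the field by transmission · exp(i·phase). The one-line NumPy sketch below shows this. `through_lens_sketch` is a hypothetical name, and the `+i` sign of the exponent is an assumption, since janssen's phase sign convention is not stated here.

```python
import numpy as np


def through_lens_sketch(field, phase_profile, transmission):
    # Amplitude mask, then phase factor exp(+i * phase) (sign convention assumed).
    return field * transmission * np.exp(1j * phase_profile)
```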

janssen.lenses.angular_spectrum_prop(incoming: OpticalWavefront, z_move: int | float | complex | Num[Array, ''], refractive_index: int | float | complex | Num[Array, ''] | None = 1.0) → OpticalWavefront

Propagate a complex field using the angular spectrum method.

Parameters:
  • incoming (OpticalWavefront) –

    PyTree with the following parameters:

      • field : Complex[Array, " hh ww"] – Input complex field.

      • wavelength : Float[Array, " "] – Wavelength of light in meters.

      • dx : Float[Array, " "] – Grid spacing in meters.

      • z_position : Float[Array, " "] – Wavefront position in meters.

  • z_move (scalar_numeric) – Propagation distance in meters. This is in free space.

  • refractive_index (Optional[scalar_numeric], optional) – Index of refraction of the medium. Default is 1.0 (vacuum).

Returns:

Propagated wavefront.

Return type:

OpticalWavefront

Notes

Algorithm:

  • Get the shape of the input field

  • Calculate the wavenumber

  • Compute the path length

  • Create spatial frequency coordinates

  • Compute the squared spatial frequencies

  • Angular spectrum transfer function

  • Ensure evanescent waves are properly handled

  • Fourier transform of the input field

  • Apply the transfer function in the Fourier domain

  • Inverse Fourier transform to get the propagated field

  • Return the propagated field
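The algorithm above can be sketched on a bare NumPy array. The transfer function is H = exp(2πi z √((n/λ)² − fx² − fy²)), and evanescent components (negative square-root argument) are dropped here. Discarding them is one common way to "handle" evanescent waves and is an assumption about janssen's choice. `angular_spectrum_sketch` is hypothetical and works on an array rather than an `OpticalWavefront`.

```python
import numpy as np


def angular_spectrum_sketch(field, wavelength, dx, z, refractive_index=1.0):
    ny, nx = field.shape
    FX, FY = np.meshgrid(np.fft.fftfreq(nx, d=dx), np.fft.fftfreq(ny, d=dx))
    # Squared longitudinal spatial frequency in the medium.
    fz_sq = (refractive_index / wavelength) ** 2 - FX**2 - FY**2
    propagating = fz_sq > 0
    fz = np.sqrt(np.where(propagating, fz_sq, 0.0))
    # Transfer function; evanescent components are discarded (assumed choice).
    H = np.where(propagating, np.exp(2j * np.pi * fz * z), 0.0)
    return np.fft.ifft2(np.fft.fft2(field) * H)
```

Because |H| = 1 on every propagating frequency, a band-limited field keeps its total power, and a uniform plane wave only picks up a global phase.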

janssen.lenses.digital_zoom(wavefront: OpticalWavefront, zoom_factor: int | float | complex | Num[Array, '']) → OpticalWavefront

Zoom an optical wavefront by a specified factor.

The returned wavefront has the same array shape as the original wavefront.

Parameters:
  • wavefront (OpticalWavefront) – Incoming optical wavefront.

  • zoom_factor (scalar_numeric) – Zoom factor (greater than 1 to zoom in, less than 1 to zoom out).

Returns:

zoomed_wavefront – Zoomed optical wavefront of the same spatial dimensions.

Return type:

OpticalWavefront

Notes

Algorithm:

For zoom in (zoom_factor >= 1.0):

  • Calculate the crop fraction (1 / zoom_factor) to determine the central region to extract

  • Create interpolation coordinates for the zoomed region centered on the image

  • Use scipy.ndimage.map_coordinates with bilinear interpolation to sample the field

  • Return the zoomed field with adjusted pixel size (dx / zoom_factor)

For zoom out (zoom_factor < 1.0):

  • Calculate the shrink fraction (zoom_factor) to determine the final image size

  • Create a coordinate mapping from the full image to the shrunken region

  • Use scipy.ndimage.map_coordinates to interpolate the original field

  • Apply a mask to zero out regions outside the shrunken area (padding effect)

  • Return the zoomed field with adjusted pixel size (dx / zoom_factor)

janssen.lenses.fraunhofer_prop(incoming: OpticalWavefront, z_move: float | Float[Array, ''], refractive_index: float | Float[Array, ''] | None = 1.0) → OpticalWavefront

Propagate a complex field using the Fraunhofer approximation.

Parameters:
  • incoming (OpticalWavefront) –

    PyTree with the following parameters:

      • field : Complex[Array, " hh ww"] – Input complex field.

      • wavelength : Float[Array, " "] – Wavelength of light in meters.

      • dx : Float[Array, " "] – Grid spacing in meters.

      • z_position : Float[Array, " "] – Wavefront position in meters.

  • z_move (float) – Propagation distance in meters. This is in free space.

  • refractive_index (float, optional) – Index of refraction of the medium. Default is 1.0 (vacuum).

Returns:

Propagated wavefront.

Return type:

OpticalWavefront

Notes

Algorithm:

  • Get the shape of the input field

  • Calculate the spatial frequency coordinates

  • Create the meshgrid of spatial frequencies

  • Compute the transfer function for Fraunhofer propagation

  • Compute the Fourier transform of the input field

  • Apply the transfer function in the Fourier domain

  • Inverse Fourier transform to get the propagated field

  • Return the propagated field

janssen.lenses.fresnel_prop(incoming: OpticalWavefront, z_move: int | float | complex | Num[Array, ''], refractive_index: int | float | complex | Num[Array, ''] | None = 1.0) → OpticalWavefront

Propagate a complex field using the Fresnel approximation.

Parameters:
  • incoming (OpticalWavefront) –

    PyTree with the following parameters:

      • field : Complex[Array, " hh ww"] – Input complex field.

      • wavelength : Float[Array, " "] – Wavelength of light in meters.

      • dx : Float[Array, " "] – Grid spacing in meters.

      • z_position : Float[Array, " "] – Wavefront position in meters.

  • z_move (scalar_numeric) – Propagation distance in meters. This is in free space.

  • refractive_index (Optional[scalar_numeric], optional) – Index of refraction of the medium. Default is 1.0 (vacuum).

Returns:

Propagated wavefront.

Return type:

OpticalWavefront

Notes

Algorithm:

  • Calculate the wavenumber

  • Create spatial coordinates

  • Quadratic phase factor for Fresnel approximation (pre-free-space propagation)

  • Apply quadratic phase to the input field

  • Compute Fourier transform of the input field

  • Compute spatial frequency coordinates

  • Transfer function for Fresnel propagation

  • Apply the transfer function in the Fourier domain

  • Inverse Fourier transform to get the propagated field

  • Final quadratic phase factor (post-free-space propagation)

  • Apply final quadratic phase factor

  • Return the propagated field
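For comparison, the simplest Fresnel propagator is the transfer-function form, H = exp(ikz) · exp(−iπλz(fx² + fy²)), applied between one FFT pair. The sketch below uses that textbook form in place of the chirp-multiply steps listed above, so it illustrates the Fresnel approximation rather than janssen's exact sequence. `fresnel_tf_sketch` is hypothetical, and dividing the wavelength by the refractive index to get the wavelength in the medium is an assumption.

```python
import numpy as np


def fresnel_tf_sketch(field, wavelength, dx, z, refractive_index=1.0):
    ny, nx = field.shape
    lam = wavelength / refractive_index  # wavelength in the medium (assumed)
    k = 2 * np.pi / lam
    FX, FY = np.meshgrid(np.fft.fftfreq(nx, d=dx), np.fft.fftfreq(ny, d=dx))
    # Paraxial transfer function: exp(ikz) * exp(-i*pi*lam*z*(fx^2 + fy^2)).
    H = np.exp(1j * k * z) * np.exp(-1j * np.pi * lam * z * (FX**2 + FY**2))
    return np.fft.ifft2(np.fft.fft2(field) * H)
```

Since |H| = 1 everywhere, this form conserves total power exactly (by Parseval's theorem), which is a convenient check for any Fresnel implementation.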

janssen.lenses.lens_propagation(incoming: OpticalWavefront, lens: LensParams) → OpticalWavefront

Propagate an optical wavefront through a lens.

The lens is modeled as a thin lens with a given focal length and diameter.

Parameters:
  • incoming (OpticalWavefront) – The incoming optical wavefront

  • lens (LensParams) – The lens parameters including focal length and diameter

Returns:

The propagated optical wavefront after passing through the lens

Return type:

OpticalWavefront

Notes

Algorithm:

  • Create a meshgrid of coordinates based on the incoming wavefront’s shape and pixel size.

  • Calculate the phase profile and transmission function of the lens.

  • Apply the phase screen to the incoming wavefront’s field.

  • Return the new optical wavefront with the updated field, wavelength, and pixel size.
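A minimal sketch of the thin-lens idea follows, with one substitution stated plainly. Instead of janssen's thickness-based `create_lens_phase`, it uses the ideal paraxial thin-lens phase −k r² / (2f) inside a circular aperture, with coordinates centered on pixel (ny // 2, nx // 2). `thin_lens_sketch` is hypothetical and works on a bare array.

```python
import numpy as np


def thin_lens_sketch(field, wavelength, dx, focal_length, diameter):
    ny, nx = field.shape
    x = (np.arange(nx) - nx // 2) * dx
    y = (np.arange(ny) - ny // 2) * dx
    xx, yy = np.meshgrid(x, y)
    r_sq = xx**2 + yy**2
    k = 2 * np.pi / wavelength
    # Ideal paraxial thin-lens phase; converging for focal_length > 0.
    phase = -k * r_sq / (2 * focal_length)
    transmission = (np.sqrt(r_sq) <= diameter / 2).astype(float)
    return field * transmission * np.exp(1j * phase)
```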

janssen.lenses.optical_zoom(wavefront: OpticalWavefront, zoom_factor: int | float | complex | Num[Array, '']) → OpticalWavefront

Modify the calibration of an optical wavefront without changing its field.

Parameters:
  • wavefront (OpticalWavefront) – Incoming optical wavefront.

  • zoom_factor (scalar_numeric) – Zoom factor (greater than 1 to zoom in, less than 1 to zoom out).

Returns:

Zoomed optical wavefront of the same spatial dimensions.

Return type:

OpticalWavefront

janssen.simul

Differentiable optical simulation toolkit.

Extended Summary

Comprehensive optical simulation framework for modeling light propagation through various optical elements. All components are differentiable and optimized for JAX transformations, enabling gradient-based optimization of optical systems.

Submodules

apertures

Aperture functions for optical microscopy

elements

Optical element transformations

microscope

Microscopy simulation pipelines

helper

Helper functions for optical propagation

Routine Listings

annular_aperture : function

Apply an annular (ring-shaped) aperture

circular_aperture : function

Apply a circular aperture

gaussian_apodizer : function

Apply Gaussian apodization to a field

gaussian_apodizer_elliptical : function

Apply elliptical Gaussian apodization

rectangular_aperture : function

Apply a rectangular aperture

supergaussian_apodizer : function

Apply super-Gaussian apodization

supergaussian_apodizer_elliptical : function

Apply elliptical super-Gaussian apodization

variable_transmission_aperture : function

Apply a uniform or spatially varying transmission

amplitude_grating_binary : function

Apply a binary amplitude grating

apply_phase_mask : function

Apply a phase mask to a field

apply_phase_mask_fn : function

Apply a phase mask built from a function of the coordinate grids

beam_splitter : function

Split a field into transmitted and reflected arms

half_waveplate : function

Half-wave plate transformation

mirror_reflection : function

Model mirror reflection

nd_filter : function

Neutral density filter

phase_grating_blazed_elliptical : function

Elliptical blazed phase grating

phase_grating_sawtooth : function

Sawtooth phase grating

phase_grating_sine : function

Sinusoidal phase grating

polarizer_jones : function

Linear polarizer in Jones calculus

prism_phase_ramp : function

Linear phase ramp simulating prism deflection

quarter_waveplate : function

Quarter-wave plate transformation

waveplate_jones : function

General waveplate in Jones calculus

add_phase_screen : function

Add a phase screen to a field

create_spatial_grid : function

Create a computational spatial grid

field_intensity : function

Calculate field intensity

normalize_field : function

Normalize a field to unit power

scale_pixel : function

Rescale the pixel size of a wavefront

linear_interaction : function

Linear light-matter interaction

simple_diffractogram : function

Generate a diffraction pattern

simple_microscope : function

Simple microscope forward model

Notes

All simulation functions support automatic differentiation and can be composed to model complex optical systems. The toolkit is optimized for both forward simulation and inverse problems in optics.

janssen.simul.annular_aperture(incoming: OpticalWavefront, inner_diameter: float | Float[Array, ''], outer_diameter: float | Float[Array, ''], center: Float[Array, '2'] | None = Array([0., 0.], dtype=float64), transmittivity: float | Float[Array, ''] | None = 1.0) OpticalWavefront[source]

Apply an annular (ring) aperture with inner and outer diameters.

Parameters:
  • incoming (OpticalWavefront) – Input wavefront PyTree.

  • inner_diameter (float) – Inner blocked diameter in meters.

  • outer_diameter (float) – Outer clear aperture diameter in meters.

  • center (Optional[Float[Array, " 2"]], optional) – Ring center [x0, y0] in meters, by default [0, 0].

  • transmittivity (Optional[scalar_float], optional) – Uniform transmittivity in the ring (0..1), by default 1.0.

Returns:

apertured – Wavefront after applying the annular aperture.

Return type:

OpticalWavefront

Notes

  • Build centered (x, y) grids in meters.

  • Compute radial distance from center.

  • Create mask for inner_radius < r <= outer_radius.

  • Multiply by transmittivity (clipped to [0, 1]), apply to the field, and return.
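The mask described in the notes can be sketched as follows. `annular_mask_sketch` is a hypothetical helper that returns only the transmission map on a centered grid with pixel size `dx`; the janssen function additionally multiplies it into the wavefront.

```python
import jax.numpy as jnp


def annular_mask_sketch(shape, dx, inner_diameter, outer_diameter,
                        center=(0.0, 0.0), transmittivity=1.0):
    """Annular transmission mask on a centered grid (hypothetical helper)."""
    ny, nx = shape
    # Centered coordinates in meters, shifted so the ring sits at `center`.
    x = (jnp.arange(nx) - nx // 2) * dx - center[0]
    y = (jnp.arange(ny) - ny // 2) * dx - center[1]
    xx, yy = jnp.meshgrid(x, y)
    r = jnp.sqrt(xx**2 + yy**2)                 # radial distance from center
    inside = (r > inner_diameter / 2.0) & (r <= outer_diameter / 2.0)
    return inside * jnp.clip(transmittivity, 0.0, 1.0)
```

Applying it is then a single product, `field * annular_mask_sketch(field.shape, dx, d_inner, d_outer)`.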

janssen.simul.circular_aperture(incoming: OpticalWavefront, diameter: float | Float[Array, ''], center: Float[Array, '2'] | None = Array([0., 0.], dtype=float64), transmittivity: float | Float[Array, ''] | None = 1.0) OpticalWavefront[source]

Apply a circular aperture to the incoming wavefront.

The aperture is defined by its physical diameter and (optional) center.

Parameters:
  • incoming (OpticalWavefront) – Input wavefront PyTree.

  • diameter (float) – Aperture diameter in meters.

  • center (Optional[Float[Array, " 2"]], optional) – Physical center [x0, y0] of the aperture in meters, by default [0, 0].

  • transmittivity (Optional[scalar_float], optional) – Uniform transmittivity inside the aperture (0..1), by default 1.0.

Returns:

apertured – Wavefront after applying the circular aperture.

Return type:

OpticalWavefront

Notes

  • Build centered (x, y) grids in meters.

  • Compute radial distance from the specified center.

  • Create a binary mask for r <= diameter/2.

  • Multiply by transmittivity (clipped to [0, 1]).

  • Apply to the complex field and return.

janssen.simul.gaussian_apodizer(incoming: OpticalWavefront, sigma: float | Float[Array, ''], center: Float[Array, '2'] | None = Array([0., 0.], dtype=float64), peak_transmittivity: float | Float[Array, ''] | None = 1.0) OpticalWavefront[source]

Apply a Gaussian apodizer (smooth transmission mask) to the wavefront.

Parameters:
  • incoming (OpticalWavefront) – Input optical wavefront.

  • sigma (float) – Gaussian width parameter in meters.

  • center (Optional[Float[Array, " 2"]], optional) – Physical center [x0, y0] of the Gaussian in meters, by default [0, 0].

  • peak_transmittivity (Optional[scalar_float], optional) – Maximum transmission at the Gaussian center, by default 1.0.

Returns:

apertured – Wavefront after applying Gaussian apodization.

Return type:

OpticalWavefront

Notes

  • Build centered (x, y) grids.

  • Compute squared radial distance from center.

  • Evaluate Gaussian exp(-r^2 / (2*sigma^2)).

  • Scale by peak transmittivity, clip to [0,1].

  • Multiply with incoming field and return.

janssen.simul.gaussian_apodizer_elliptical(incoming: OpticalWavefront, sigma_x: float | Float[Array, ''], sigma_y: float | Float[Array, ''], theta: float | Float[Array, ''] | None = 0.0, center: Float[Array, '2'] | None = Array([0., 0.], dtype=float64), peak_transmittivity: float | Float[Array, ''] | None = 1.0) OpticalWavefront[source]

Apply an elliptical Gaussian apodizer to the wavefront, with optional rotation through an angle theta.

Parameters:
  • incoming (OpticalWavefront) – Input optical wavefront.

  • sigma_x (float) – Gaussian width along the x’-axis (meters) after rotation by theta.

  • sigma_y (float) – Gaussian width along the y’-axis (meters) after rotation by theta.

  • theta (Optional[scalar_float], optional) – Rotation angle in radians (counter-clockwise), by default 0.0.

  • center (Optional[Float[Array, " 2"]], optional) – Physical center [x0, y0] in meters, by default [0, 0].

  • peak_transmittivity (Optional[scalar_float], optional) – Maximum transmission at the center, by default 1.0.

Returns:

apertured – Wavefront after applying elliptical Gaussian apodization.

Return type:

OpticalWavefront

Notes

  • Build centered (x, y) grids.

  • Translate by center, rotate by theta → (x’, y’).

  • Evaluate exp(-0.5 * ( (x’/sigma_x)^2 + (y’/sigma_y)^2 )).

  • Scale by peak_transmittivity, clip to [0, 1].

  • Multiply with incoming field and return.
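A sketch of the translate, rotate, evaluate and clip steps, assuming the rotation convention x' = x cos θ + y sin θ, y' = -x sin θ + y cos θ (so x' points along θ, measured counter-clockwise from x). The helper name is hypothetical and it takes precomputed coordinate grids.

```python
import jax.numpy as jnp


def elliptical_gaussian_sketch(xx, yy, sigma_x, sigma_y, theta=0.0,
                               center=(0.0, 0.0), peak=1.0):
    """Rotated elliptical Gaussian transmission (hypothetical helper)."""
    # Translate so the apodizer is centered at (x0, y0).
    xs = xx - center[0]
    ys = yy - center[1]
    # Rotate into the apodizer frame (x', y'); x' lies at angle theta CCW from x.
    xr = jnp.cos(theta) * xs + jnp.sin(theta) * ys
    yr = -jnp.sin(theta) * xs + jnp.cos(theta) * ys
    g = jnp.exp(-0.5 * ((xr / sigma_x) ** 2 + (yr / sigma_y) ** 2))
    return jnp.clip(peak * g, 0.0, 1.0)
```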

janssen.simul.rectangular_aperture(incoming: OpticalWavefront, width: float | Float[Array, ''], height: float | Float[Array, ''], center: Float[Array, '2'] | None = Array([0., 0.], dtype=float64), transmittivity: float | Float[Array, ''] | None = 1.0) OpticalWavefront[source]

Apply an axis-aligned rectangular aperture to the incoming wavefront.

Parameters:
  • incoming (OpticalWavefront) – Input wavefront PyTree.

  • width (float) – Rectangle width along x in meters.

  • height (float) – Rectangle height along y in meters.

  • center (Optional[Float[Array, " 2"]], optional) – Rectangle center [x0, y0] in meters, by default [0, 0].

  • transmittivity (Optional[scalar_float], optional) – Uniform transmittivity inside the rectangle (0..1), by default 1.0.

Returns:

apertured – Wavefront after applying the rectangular aperture.

Return type:

OpticalWavefront

Notes

  • Build centered (x, y) grids in meters.

  • Compute half-width/half-height and an inside-rectangle mask.

  • Multiply by transmittivity (clipped).

  • Apply to the complex field and return.

janssen.simul.supergaussian_apodizer(incoming: OpticalWavefront, sigma: float | Float[Array, ''], m: int | float | complex | Num[Array, ''], center: Float[Array, '2'] | None = Array([0., 0.], dtype=float64), peak_transmittivity: float | Float[Array, ''] | None = 1.0) OpticalWavefront[source]

Apply a super-Gaussian apodizer to the wavefront.

Transmission profile: exp(- (r^2 / sigma^2)^m ).

Parameters:
  • incoming (OpticalWavefront) – Input optical wavefront.

  • sigma (float) – Width parameter in meters (sets the roll-off scale).

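  • m (scalar_numeric) – Super-Gaussian order (m=1 → Gaussian exp(-r^2 / sigma^2), whose width convention differs from the exp(-r^2 / (2*sigma^2)) of gaussian_apodizer; m>1 → flatter top).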

  • center (Optional[Float[Array, " 2"]], optional) – Physical center [x0, y0] of the profile in meters, by default [0, 0].

  • peak_transmittivity (Optional[scalar_float], optional) – Maximum transmission at the center, by default 1.0.

Returns:

apertured – Wavefront after applying super-Gaussian apodization.

Return type:

OpticalWavefront

Notes

  • Build centered (x, y) grids.

  • Compute squared radial distance from center.

  • Evaluate exp(- (r^2 / sigma^2)^m ).

  • Scale by peak transmittivity, clip to [0,1].

  • Multiply with incoming field and return.
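To illustrate the role of `m`, here is a minimal radial sketch of exp(-(r^2/sigma^2)^m) with the peak scaling and clipping from the notes (hypothetical helper name). Every order passes through e^-1 at r = sigma; a larger `m` keeps the transmission closer to the peak inside that radius and drops it faster outside.

```python
import jax.numpy as jnp


def supergaussian_sketch(r, sigma, m, peak=1.0):
    """Radial super-Gaussian transmission exp(-(r^2/sigma^2)^m) (hypothetical helper)."""
    t = jnp.exp(-((r**2 / sigma**2) ** m))
    # Scale by the peak transmittivity and keep the result physical.
    return jnp.clip(peak * t, 0.0, 1.0)
```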

janssen.simul.supergaussian_apodizer_elliptical(incoming: OpticalWavefront, sigma_x: float | Float[Array, ''], sigma_y: float | Float[Array, ''], m: int | float | complex | Num[Array, ''], theta: float | Float[Array, ''] | None = 0.0, center: Float[Array, '2'] | None = Array([0., 0.], dtype=float64), peak_transmittivity: float | Float[Array, ''] | None = 1.0) OpticalWavefront[source]

Apply an elliptical super-Gaussian apodizer with optional rotation.

Transmission profile: exp( - ( (x’/sigma_x)^2 + (y’/sigma_y)^2 )^m ).

Parameters:
  • incoming (OpticalWavefront) – Input optical wavefront.

  • sigma_x (float) – Width along x’ (meters) after rotation by theta.

  • sigma_y (float) – Width along y’ (meters) after rotation by theta.

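  • m (scalar_numeric) – Super-Gaussian order (m=1 → Gaussian; m>1 → flatter top, sharper edges). At m=1 the profile is exp(-((x’/sigma_x)^2 + (y’/sigma_y)^2)), without the factor 0.5 used by gaussian_apodizer_elliptical.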

  • theta (Optional[scalar_float], optional) – Rotation angle in radians (counter-clockwise), by default 0.0.

  • center (Optional[Float[Array, " 2"]], optional) – Physical center [x0, y0] in meters, by default [0, 0].

  • peak_transmittivity (Optional[scalar_float], optional) – Maximum transmission at the center, by default 1.0.

Returns:

apertured – Wavefront after applying elliptical super-Gaussian apodization.

Return type:

OpticalWavefront

Notes

  • Build centered (x, y) grids.

  • Translate by center, rotate by theta → (x’, y’).

  • Evaluate exp( - ( (x’/sigma_x)^2 + (y’/sigma_y)^2 )^m ).

  • Scale by peak_transmittivity, clip to [0, 1].

  • Multiply with incoming field and return.

janssen.simul.variable_transmission_aperture(incoming: OpticalWavefront, transmission: float | Float[Array, ''] | Float[Array, '...']) OpticalWavefront[source]

Apply an arbitrary (spatially varying) transmission to the wavefront.

Parameters:
  • incoming (OpticalWavefront) – Input wavefront PyTree.

  • transmission (Union[scalar_float, Float[Array, " H W"]]) – Precomputed transmission map (0..1) with shape “H W”, or a scalar attenuation factor for uniform transmission.

Returns:

apertured – Wavefront after applying the transmission.

Return type:

OpticalWavefront

Examples

Uniform attenuation:

>>> wf2 = variable_transmission_aperture(wf, 0.5)  # 50% transmission

Spatially varying transmission:

>>> tmap = create_transmission_map(...)  # placeholder: any (H, W) array with values in [0, 1]
>>> wf2 = variable_transmission_aperture(wf, tmap)

Notes

  • For scalar transmission: applies uniform attenuation.

  • For array transmission: applies spatially varying transmission map.

  • Transmission values are clipped to [0, 1].

  • This function is fully JAX-compatible and uses jax.lax.cond.

janssen.simul.amplitude_grating_binary(incoming: OpticalWavefront, period: float | Float[Array, ''], duty_cycle: float | Float[Array, ''] | None = 0.5, theta: float | Float[Array, ''] | None = 0.0, trans_high: float | Float[Array, ''] | None = 1.0, trans_low: float | Float[Array, ''] | None = 0.0) OpticalWavefront[source]

Binary amplitude grating with given duty cycle.

Parameters:
  • incoming (OpticalWavefront) – Input field.

  • period (float) – Period in meters.

  • duty_cycle (float, optional) – Fraction of period in ‘high’ state (0..1), by default 0.5.

  • theta (float, optional) – Orientation (radians), by default 0.0.

  • trans_high (float, optional) – Amplitude transmittance for ‘high’ bars, by default 1.0.

  • trans_low (float, optional) – Amplitude transmittance for ‘low’ bars, by default 0.0.

Returns:

Field after amplitude modulation.

Return type:

OpticalWavefront

Notes

  • Compute u along grating direction.

  • Map u modulo period → binary mask via duty cycle.

  • Apply amplitude levels to field.
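A sketch of the three steps in the notes, returning only the amplitude transmittance. The helper is hypothetical, and it assumes that `theta` rotates the grating vector counter-clockwise from x and that the 'high' level occupies the first `duty_cycle` fraction of each period.

```python
import jax.numpy as jnp


def binary_grating_sketch(xx, yy, period, duty_cycle=0.5, theta=0.0,
                          trans_high=1.0, trans_low=0.0):
    """Binary amplitude transmittance of a grating (hypothetical helper)."""
    # Coordinate along the grating vector, rotated CCW by theta from x.
    u = xx * jnp.cos(theta) + yy * jnp.sin(theta)
    frac = jnp.mod(u, period) / period           # position within one period, [0, 1)
    return jnp.where(frac < duty_cycle, trans_high, trans_low)
```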

janssen.simul.apply_phase_mask(incoming: OpticalWavefront, phase_map: Float[Array, 'H W']) OpticalWavefront[source]

Apply an arbitrary phase mask (e.g., SLM, turbulence screen).

Field_out = field_in * exp(i * phase_map).

Parameters:
  • incoming (OpticalWavefront) – Input field.

  • phase_map (Float[Array, " H W"]) – Phase in radians, same spatial shape as field.

Returns:

Field with added phase.

Return type:

OpticalWavefront

janssen.simul.apply_phase_mask_fn(incoming: OpticalWavefront, phase_fn: Callable[[Float[Array, 'H W'], Float[Array, 'H W']], Float[Array, 'H W']]) OpticalWavefront[source]

Build and apply a phase mask from a callable phase_fn(xx, yy).

Parameters:
  • incoming (OpticalWavefront) – Input field.

  • phase_fn (callable) – Function producing a phase map (radians) given centered grids xx, yy (meters).

Returns:

Field with added phase.

Return type:

OpticalWavefront

janssen.simul.beam_splitter(incoming: OpticalWavefront, t2: float | Float[Array, ''] | None = 0.5, r2: float | Float[Array, ''] | None = 0.5, normalize: bool | None = True) tuple[OpticalWavefront, OpticalWavefront][source]

Split an input field into transmitted and reflected components.

Parameters:
  • incoming (OpticalWavefront) – Input wavefront (scalar field).

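  • t2 (float, optional) – Transmission amplitude t (may be complex). The signature default is 0.5; with normalize=True, t = r = 0.5 is rescaled to jnp.sqrt(0.5) each.

  • r2 (float, optional) – Reflection amplitude r (may be complex). The signature default is 0.5; passing 1j * jnp.sqrt(0.5) gives the symmetric 50/50 convention with a π/2 phase on reflection.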

  • normalize (bool, optional) – If True, scale (t, r) so that |t|^2 + |r|^2 = 1, by default True.

Return type:

tuple[OpticalWavefront, OpticalWavefront]

Returns:

  • wf_T (OpticalWavefront) – Transmitted arm (t * field).

  • wf_R (OpticalWavefront) – Reflected arm (r * field).

Notes

  • Optionally renormalize (t, r).

  • Multiply field by t and r.

  • Return two wavefronts sharing same metadata.
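A sketch of the optional renormalization, using a hypothetical helper on a bare array. It shows that equal inputs such as t = r = 0.5 become sqrt(0.5) each once |t|^2 + |r|^2 = 1 is enforced.

```python
import jax.numpy as jnp


def beam_splitter_sketch(field, t=0.5, r=0.5, normalize=True):
    """Split a field into transmitted and reflected arms (hypothetical helper)."""
    t = jnp.asarray(t, dtype=jnp.complex64)
    r = jnp.asarray(r, dtype=jnp.complex64)
    if normalize:  # Python bool, so this branch is resolved at trace time
        # Rescale so that |t|^2 + |r|^2 = 1 (lossless splitter).
        norm = jnp.sqrt(jnp.abs(t) ** 2 + jnp.abs(r) ** 2)
        t, r = t / norm, r / norm
    return t * field, r * field
```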

janssen.simul.half_waveplate(incoming: OpticalWavefront, theta: float | Float[Array, ''] | None = 0.0) OpticalWavefront[source]

Apply a half-wave plate (δ = π) with fast-axis angle theta.

Parameters:
  • incoming (OpticalWavefront) – Vector field Complex[H, W, 2] (Jones: ex, ey).

  • theta (float, optional) – Fast-axis angle in radians (CCW from x), by default 0.0.

Returns:

hw_wavefront – Retarded field after half-wave plate.

Return type:

OpticalWavefront

Notes

Call waveplate_jones with delta = π.

janssen.simul.mirror_reflection(incoming: OpticalWavefront, flip_x: bool | None = True, flip_y: bool | None = False, add_pi_phase: bool | None = True, conjugate: bool | None = True) OpticalWavefront[source]

Mirror reflection: coordinate flips with optional π-phase and conjugation.

Parameters:
  • incoming (OpticalWavefront) – Input wavefront.

  • flip_x (bool, optional) – Flip along x-axis (columns), by default True.

  • flip_y (bool, optional) – Flip along y-axis (rows), by default False.

  • add_pi_phase (bool, optional) – Multiply by exp(i*pi) = -1 to simulate phase inversion on reflection. Default True.

  • conjugate (bool, optional) – Conjugate the complex field, useful when reversing propagation direction. Default is True.

Returns:

Reflected wavefront.

Return type:

OpticalWavefront

Notes

  • Flip axes as requested (jnp.flip).

  • Optional complex conjugation.

  • Optional -1 factor for π phase.

janssen.simul.nd_filter(incoming: OpticalWavefront, optical_density: float | Float[Array, ''] | None = 0.0, transmittance: float | Float[Array, ''] | None = -1.0) OpticalWavefront[source]

Neutral density (ND) filter as a uniform amplitude attenuator.

Parameters:
  • incoming (OpticalWavefront) – Input field.

  • optical_density (float, optional) – Optical density (OD); intensity transmittance T = 10^(-OD). A nonzero OD overrides transmittance. Default is 0.0.

  • transmittance (float, optional) – Intensity transmittance T in [0, 1], used only if optical_density is 0. Default is -1.0.

Returns:

nd_wavefront – Attenuated wavefront.

Return type:

OpticalWavefront

Notes

  • Determine intensity T from OD or provided T.

  • Amplitude factor a = sqrt(T).

  • Multiply field by a and return.
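A sketch of the amplitude factor. How janssen chooses between the two arguments is an assumption here: a nonzero OD wins, otherwise a non-negative `transmittance` is used, and the default -1.0 is treated as "not set" (no attenuation). The helper name is hypothetical.

```python
import jax.numpy as jnp


def nd_amplitude_sketch(optical_density=0.0, transmittance=-1.0):
    """Amplitude factor of an ND filter (hypothetical helper)."""
    od = jnp.asarray(optical_density, dtype=jnp.float32)
    t_user = jnp.asarray(transmittance, dtype=jnp.float32)
    # Assumption: a negative transmittance means "unset", i.e. T = 1.
    t_user = jnp.where(t_user >= 0.0, jnp.clip(t_user, 0.0, 1.0), 1.0)
    # Assumption: a nonzero OD overrides the user-supplied transmittance.
    intensity_t = jnp.where(od != 0.0, 10.0 ** (-od), t_user)
    return jnp.sqrt(intensity_t)                 # amplitude a = sqrt(T)
```

For example, OD = 2 gives T = 0.01 and an amplitude factor of 0.1.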

janssen.simul.phase_grating_blazed_elliptical(incoming: OpticalWavefront, period_x: float | Float[Array, ''], period_y: float | Float[Array, ''], theta: float | Float[Array, ''] | None = 0.0, depth: float | Float[Array, ''] | None = 6.283185307179586, two_dim: bool | None = False) OpticalWavefront[source]

Orientation-aware elliptical blazed grating.

Supports anisotropic periods along rotated axes (x’, y’) and optional 2D blaze.

Parameters:
  • incoming (OpticalWavefront) – Input scalar wavefront.

  • period_x (float) – Blaze period along x’ in meters (after rotation by theta).

  • period_y (float) – Blaze period along y’ in meters (after rotation by theta).

  • theta (float, optional) – Grating orientation angle in radians (CCW from x), by default 0.0.

  • depth (float, optional) – Peak-to-peak phase depth in radians, by default 2π.

  • two_dim (bool, optional) – If False (default), apply a 1D blaze along x’ only. If True, create a 2D blazed lattice using both x’ and y’.

Returns:

phase_grating_wavefront – Field after applying the elliptical blazed phase.

Return type:

OpticalWavefront

Notes

  • Build centered grids xx, yy (meters) and rotate → (x’, y’).

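  • Compute the fractional coordinates

    $$f_u = \operatorname{frac}\left(\frac{x'}{\text{period}_x}\right), \qquad f_v = \operatorname{frac}\left(\frac{y'}{\text{period}_y}\right)$$

  • If two_dim is True:

    $$\text{phase} = \text{depth} \cdot \operatorname{frac}(f_u + f_v)$$

    otherwise:

    $$\text{phase} = \text{depth} \cdot f_u$$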

  • Multiply by exp(i * phase) and return.
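A sketch of the phase computation, assuming frac(a) = a - floor(a) (implemented with `jnp.mod`) and the rotation x' = x cos θ + y sin θ, y' = -x sin θ + y cos θ. The helper name is hypothetical; the janssen function also multiplies the field by exp(i * phase).

```python
import jax.numpy as jnp


def blazed_phase_sketch(xx, yy, period_x, period_y, theta=0.0,
                        depth=2.0 * jnp.pi, two_dim=False):
    """Phase of an elliptical blazed grating (hypothetical helper)."""
    # Rotate the centered grids into the grating frame (x', y').
    xr = xx * jnp.cos(theta) + yy * jnp.sin(theta)
    yr = -xx * jnp.sin(theta) + yy * jnp.cos(theta)
    fu = jnp.mod(xr / period_x, 1.0)             # frac(x' / period_x)
    fv = jnp.mod(yr / period_y, 1.0)             # frac(y' / period_y)
    frac_sum = jnp.mod(fu + fv, 1.0)             # frac(fu + fv)
    # two_dim is a Python bool here, chosen at trace time.
    return depth * (frac_sum if two_dim else fu)
```

The field is then modulated as `field * jnp.exp(1j * blazed_phase_sketch(xx, yy, px, py))`.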

janssen.simul.phase_grating_sawtooth(incoming: OpticalWavefront, period: float | Float[Array, ''], depth: float | Float[Array, ''], theta: float | Float[Array, ''] = 0.0) OpticalWavefront[source]

Sawtooth phase grating with peak-to-peak depth (radians).

Parameters:
  • incoming (OpticalWavefront) – Input field.

  • period (float) – Grating period in meters.

  • depth (float) – Phase depth over one period in radians.

  • theta (float, optional) – Orientation (radians), by default 0.0.

Returns:

Field after blazed phase modulation.

Return type:

OpticalWavefront

Notes

  • Compute fractional coordinate within each period.

  • Form the sawtooth phase in [0, depth); it is kept in [0, depth) rather than shifted to zero mean.

  • Apply phase with exp(i*phase).

janssen.simul.phase_grating_sine(incoming: OpticalWavefront, period: float | Float[Array, ''], depth: float | Float[Array, ''], theta: float | Float[Array, ''] | None = 0.0) OpticalWavefront[source]

Sinusoidal phase grating.

Phase = depth * sin(2π * u / period), where u is the coordinate along the grating direction.

Parameters:
  • incoming (OpticalWavefront) – Input field.

  • period (float) – Grating period in meters.

  • depth (float) – Phase modulation depth in radians.

  • theta (float, optional) – Grating orientation (radians, CCW from x), by default 0.0.

Returns:

Field after phase modulation.

Return type:

OpticalWavefront

janssen.simul.polarizer_jones(incoming: OpticalWavefront, theta: float | Float[Array, ''] = 0.0) OpticalWavefront[source]

Linear polarizer at angle theta (radians, CCW from x-axis).

Applied to a 2-component Jones field (ex, ey) stored in the last dimension.

Parameters:
  • incoming (OpticalWavefront) – Field shape must be Complex[H, W, 2].

  • theta (float, optional) – Transmission axis angle (radians), by default 0.0.

Returns:

Polarized field with same shape.

Return type:

OpticalWavefront

Notes

  • Jones matrix: P = R(-θ) @ [[1, 0],[0, 0]] @ R(θ).

  • Apply P to [ex, ey] at each pixel.
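A sketch of the Jones-matrix application, assuming R(θ) = [[cos θ, sin θ], [-sin θ, cos θ]]. With that convention, R(-θ) @ diag(1, 0) @ R(θ) projects onto the unit vector (cos θ, sin θ), i.e. a transmission axis at θ counter-clockwise from x. The helper names are hypothetical and the field is a bare (H, W, 2) array.

```python
import jax.numpy as jnp


def _rotation(theta):
    """R(theta) = [[cos, sin], [-sin, cos]] (convention assumed here)."""
    c, s = jnp.cos(theta), jnp.sin(theta)
    return jnp.array([[c, s], [-s, c]])


def polarizer_sketch(field, theta=0.0):
    """Apply P = R(-theta) @ diag(1, 0) @ R(theta) to an (H, W, 2) Jones field."""
    p = _rotation(-theta) @ jnp.diag(jnp.array([1.0, 0.0])) @ _rotation(theta)
    # Contract the 2x2 matrix with the last (ex, ey) axis at every pixel.
    return jnp.einsum("ij,hwj->hwi", p.astype(field.dtype), field)
```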

janssen.simul.prism_phase_ramp(incoming: OpticalWavefront, deflect_x: float | Float[Array, ''] | None = 0.0, deflect_y: float | Float[Array, ''] | None = 0.0, use_small_angle: bool | None = True) OpticalWavefront[source]

Apply a linear phase ramp to simulate a prism-induced beam deviation.

Parameters:
  • incoming (OpticalWavefront) – Input scalar wavefront.

  • deflect_x (float, optional) – Deflection along +x. If use_small_angle is True, interpreted as angle (rad). Otherwise interpreted as spatial frequency kx [rad/m], by default 0.0.

  • deflect_y (float, optional) – Deflection along +y (angle or ky), by default 0.0.

  • use_small_angle (bool, optional) – If True, convert small angles to kx, ky via k*sin(angle) ~ k*angle. Default True.

Returns:

Wavefront with added linear phase.

Return type:

OpticalWavefront

Notes

  • Build xx, yy grids (m).

  • Compute kx, ky from deflections.

  • Phase = kx*xx + ky*yy; multiply by exp(i*phase).
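A sketch of the ramp, taking the wavenumber as k = 2π/λ in vacuum (an assumption). The helper name is hypothetical, and its explicit `wavelength` argument stands in for the wavelength carried by the wavefront.

```python
import jax.numpy as jnp


def prism_ramp_sketch(field, xx, yy, wavelength, deflect_x=0.0, deflect_y=0.0,
                      use_small_angle=True):
    """Multiply a field by a linear phase ramp (hypothetical helper)."""
    k = 2.0 * jnp.pi / wavelength
    if use_small_angle:
        # Small-angle approximation: k * sin(angle) ~ k * angle.
        kx, ky = k * deflect_x, k * deflect_y
    else:
        # Deflections are already spatial frequencies in rad/m.
        kx, ky = deflect_x, deflect_y
    return field * jnp.exp(1j * (kx * xx + ky * yy))
```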

janssen.simul.quarter_waveplate(incoming: OpticalWavefront, theta: float | Float[Array, ''] | None = 0.0) OpticalWavefront[source]

Apply a quarter-wave plate (δ = π/2) with fast-axis angle theta.

Parameters:
  • incoming (OpticalWavefront) – Vector field Complex[H, W, 2] (Jones: ex, ey).

  • theta (float, optional) – Fast-axis angle in radians (CCW from x), by default 0.0.

Returns:

qw_wavefront – Retarded field after quarter-wave plate.

Return type:

OpticalWavefront

Notes

Call waveplate_jones with delta = π/2.

janssen.simul.waveplate_jones(incoming: OpticalWavefront, delta: float | Float[Array, ''], theta: float | Float[Array, ''] = 0.0) OpticalWavefront[source]

Waveplate/retarder with retardance delta and fast-axis angle theta.

Special cases: quarter-wave (delta=π/2), half-wave (delta=π).

Parameters:
  • incoming (OpticalWavefront) – Field shape must be Complex[H, W, 2].

  • delta (float) – Phase delay between fast and slow axes in radians.

  • theta (float, optional) – Fast-axis angle (radians, CCW from x), by default 0.0.

Returns:

jones_wavefront – Retarded field with same shape.

Return type:

OpticalWavefront

Notes

  • Jones matrix: J = R(-θ) @ diag(1, e^{iδ}) @ R(θ).

  • Apply J to [ex, ey] per pixel.
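Using the assumed rotation convention R(θ) = [[cos θ, sin θ], [-sin θ, cos θ]], here is a sketch of the retarder on a bare (H, W, 2) Jones array (hypothetical helper names). With delta = π and theta = π/4 it turns x-polarized light into y-polarized light; with delta = π/2 it gives equal |ex| and |ey|, i.e. circular polarization.

```python
import jax.numpy as jnp


def _rotation(theta):
    """R(theta) = [[cos, sin], [-sin, cos]] (convention assumed here)."""
    c, s = jnp.cos(theta), jnp.sin(theta)
    return jnp.array([[c, s], [-s, c]])


def waveplate_sketch(field, delta, theta=0.0):
    """Apply J = R(-theta) @ diag(1, e^{i delta}) @ R(theta) at every pixel."""
    # diag(e^{i0}, e^{i delta}) = diag(1, e^{i delta}): retardance on the slow axis.
    retarder = jnp.diag(jnp.exp(1j * jnp.array([0.0, delta])))
    j = _rotation(-theta) @ retarder @ _rotation(theta)
    # Contract the Jones matrix with the last (ex, ey) axis.
    return jnp.einsum("ij,hwj->hwi", j, field)
```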

janssen.simul.add_phase_screen(field: Num[Array, 'hh ww'], phase: Float[Array, 'hh ww']) Complex[Array, 'H W'][source]

Add a phase screen to a complex field.

Parameters:
  • field (Num[Array, " hh ww"]) – Input complex field.

  • phase (Float[Array, " hh ww"]) – Phase screen to add.

Returns:

screened_field – Field with phase screen added.

Return type:

Complex[Array, " hh ww"]

Notes

  • Multiply the input field by the exponential of the phase screen.

  • Return the screened field.

janssen.simul.create_spatial_grid(diameter: Num[Array, ''], num_points: Int[Array, '']) tuple[Float[Array, 'nn nn'], Float[Array, 'nn nn']][source]

Create a 2D spatial grid for optical propagation.

Parameters:
  • diameter (Num[Array, " "]) – Physical size of the grid in meters.

  • num_points (Int[Array, " "]) – Number of points in each dimension.

Return type:

tuple[Float[Array, 'nn nn'], Float[Array, 'nn nn']]

Returns:

  • xx (Float[Array, " nn nn"]) – X coordinate grid in meters.

  • yy (Float[Array, " nn nn"]) – Y coordinate grid in meters.

Notes

  • Create a linear space of points along the x-axis.

  • Create a linear space of points along the y-axis.

  • Create a meshgrid of spatial coordinates.

  • Return the meshgrid.
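A sketch of the grid construction. Whether the janssen grid includes both endpoints ±diameter/2 is not stated, so `jnp.linspace` with endpoints is an assumption here. Note also that `jnp.linspace` needs a concrete Python int for the number of points, so this sketch cannot trace `num_points` under `jit`.

```python
import jax.numpy as jnp


def spatial_grid_sketch(diameter, num_points):
    """Centered square grid spanning [-diameter/2, diameter/2] (assumed convention)."""
    x = jnp.linspace(-diameter / 2.0, diameter / 2.0, num_points)
    y = jnp.linspace(-diameter / 2.0, diameter / 2.0, num_points)
    xx, yy = jnp.meshgrid(x, y)                  # xx varies along columns, yy along rows
    return xx, yy
```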

janssen.simul.field_intensity(field: Complex[Array, 'hh ww']) Float[Array, 'hh ww'][source]

Calculate intensity from complex field.

Parameters:

field (Complex[Array, " hh ww"]) – Input complex field.

Returns:

intensity – Intensity of the field.

Return type:

Float[Array, " hh ww"]

Notes

  • Calculate the intensity as the square of the absolute value of the field.

  • Return the intensity.

janssen.simul.normalize_field(field: Complex[Array, 'hh ww']) Complex[Array, 'hh ww'][source]

Normalize complex field to unit power.

Parameters:

field (Complex[Array, " hh ww"]) – Input complex field.

Returns:

normalized_field – Normalized complex field.

Return type:

Complex[Array, " hh ww"]

Notes

  • Calculate the power of the field as the sum of the square of the absolute value of the field.

  • Normalize the field by dividing by the square root of the power.

  • Return the normalized field.
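Both field_intensity and normalize_field reduce to one-liners. A sketch with hypothetical names, where "unit power" means that the sum of |E|^2 over all pixels equals 1 (no pixel-area weighting, as in the notes).

```python
import jax.numpy as jnp


def field_intensity_sketch(field):
    """Intensity |E|^2 of a complex field."""
    return jnp.abs(field) ** 2


def normalize_field_sketch(field):
    """Scale a field so that sum(|E|^2) == 1."""
    power = jnp.sum(field_intensity_sketch(field))
    return field / jnp.sqrt(power)
```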

janssen.simul.scale_pixel(wavefront: OpticalWavefront, new_dx: float | Float[Array, '']) OpticalWavefront[source]

Rescale OpticalWavefront pixel size while keeping array shape fixed.

JAX-compatible (jit/vmap-safe). Crops or pads to preserve shape.

Parameters:
  • wavefront (OpticalWavefront) – OpticalWavefront to be resized.

  • new_dx (float) – New pixel size (meters).

Returns:

scaled_wavefront – Resized OpticalWavefront with the updated pixel size and a resampled field that has the same array shape as the original.

Return type:

OpticalWavefront

Notes

  • If the new pixel size is smaller than the old one, the new FOV is also smaller at the same array size. The field is first cropped to the new, smaller FOV at the current pixel size and then resampled to the new pixel size, so that the array size stays the same. Here the order is crop, then resize.

  • If the new pixel size is larger than the old one, the new FOV is also larger. The field is first resized to the new pixel size and then padded back to the original array size. Here the order is resize, then pad.

  • Return the resized OpticalWavefront.

janssen.simul.linear_interaction(sample: SampleFunction, light: OpticalWavefront) OpticalWavefront[source]

Propagate an optical wavefront through a sample.

The sample is modeled as a complex function that modifies the incoming wavefront through a linear interaction.

Parameters:
  • sample (SampleFunction) – The sample function representing the optical properties of the sample

  • light (OpticalWavefront) – The incoming optical wavefront

Returns:

The propagated optical wavefront after passing through the sample

Return type:

OpticalWavefront

janssen.simul.simple_diffractogram(sample_cut: SampleFunction, lightwave: OpticalWavefront, zoom_factor: float | Float[Array, ''], aperture_diameter: float | Float[Array, ''], travel_distance: float | Float[Array, ''], camera_pixel_size: float | Float[Array, ''], aperture_center: Float[Array, '2'] | None = None) Diffractogram[source]

Calculate the diffractogram of a sample using a simple model.

The lightwave interacts with the sample linearly, and is then zoomed optically. Following this it interacts with a circular aperture before propagating to the camera plane. The camera image is then scaled to the pixel size of the camera. The diffractogram is created from the camera image.

Parameters:
  • sample_cut (SampleFunction) – The sample function representing the optical properties of the sample

  • lightwave (OpticalWavefront) – The incoming optical wavefront

  • zoom_factor (float) – The zoom factor for the optical system

  • aperture_diameter (float) – The diameter of the aperture in meters

  • travel_distance (float) – The distance traveled by the light in meters

  • camera_pixel_size (float) – The pixel size of the camera in meters

  • aperture_center (Optional[Float[Array, " 2"]], optional) – The center of the aperture in pixels

Returns:

The calculated diffractogram of the sample

Return type:

Diffractogram

Notes

Algorithm:

  • Propagate the lightwave through the sample using linear interaction

  • Apply optical zoom to the wavefront

  • Apply a circular aperture to the zoomed wavefront

  • Propagate the wavefront to the camera plane using Fraunhofer propagation

  • Scale the pixel size of the camera image

  • Calculate the field intensity of the camera image

  • Create a diffractogram from the camera image
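A sketch of the Fraunhofer and intensity steps. It uses the standard result that a centered FFT of the exit field gives the far-field amplitude up to a phase and a constant, sampled at λz/(N·dx) in the detector plane. That sampling generally differs from `camera_pixel_size`, which is why a pixel-rescaling step follows. The helper is hypothetical and assumes a square N × N field.

```python
import jax.numpy as jnp


def fraunhofer_intensity_sketch(exit_field, dx, wavelength, distance):
    """Far-field intensity and its pixel size (hypothetical helper)."""
    n = exit_field.shape[0]                      # assumes a square N x N field
    # Centered 2D FFT: far-field amplitude up to a phase factor and a constant.
    far = jnp.fft.fftshift(jnp.fft.fft2(jnp.fft.ifftshift(exit_field)))
    intensity = jnp.abs(far) ** 2
    # Fraunhofer sampling: detector-plane pixel = wavelength * z / (N * dx).
    dx_far = wavelength * distance / (n * dx)
    return intensity, dx_far
```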

janssen.simul.simple_microscope(sample: SampleFunction, positions: Num[Array, 'n 2'], lightwave: OpticalWavefront, zoom_factor: float | Float[Array, ''], aperture_diameter: float | Float[Array, ''], travel_distance: float | Float[Array, ''], camera_pixel_size: float | Float[Array, ''], aperture_center: Float[Array, '2'] | None = None) MicroscopeData[source]

Calculate the 3D stack of diffractograms for the entire scan.

At every scan position, this cuts out the corresponding region of the sample and generates a diffractogram with the desired camera pixel size; all positions are computed in parallel.

Parameters:
  • sample (SampleFunction) – The sample function representing the optical properties of the sample

  • positions (Num[Array, " n 2"]) – The positions in the sample plane where the diffractograms are calculated.

  • lightwave (OpticalWavefront) – The incoming optical wavefront

  • zoom_factor (float) – The zoom factor for the optical system

  • aperture_diameter (float) – The diameter of the aperture in meters

  • travel_distance (float) – The distance traveled by the light in meters

  • camera_pixel_size (float) – The pixel size of the camera in meters

  • aperture_center (Optional[Float[Array, " 2"]], optional) – The center of the aperture in pixels

Returns:

The calculated diffractograms of the sample at the specified positions

Return type:

MicroscopeData

Notes

Algorithm:

  • Get the size of the lightwave field

  • Calculate the pixel positions in the sample plane

  • For each position, cut out the sample and calculate the diffractogram

  • Combine the diffractograms into a single MicroscopeData object

  • Return the MicroscopeData object
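The per-position loop maps naturally onto `jax.vmap`. A sketch, assuming positions are integer pixel offsets of each cut's top-left corner (janssen may use physical units or centers) and omitting the zoom, aperture and pixel-rescaling steps. The helper name is hypothetical.

```python
import jax
import jax.numpy as jnp


def scan_intensities_sketch(sample, probe, positions):
    """Far-field intensities for each scan position (hypothetical helper).

    sample:    complex (Hs, Ws) transmission function
    probe:     complex (h, w) illumination field
    positions: int (n, 2) top-left pixel offsets [row, col] of each cut
    """
    h, w = probe.shape

    def one_position(pos):
        # Cut out the region of the sample illuminated at this position.
        cut = jax.lax.dynamic_slice(sample, (pos[0], pos[1]), (h, w))
        exit_wave = cut * probe                  # linear (multiplicative) interaction
        far = jnp.fft.fftshift(jnp.fft.fft2(exit_wave))
        return jnp.abs(far) ** 2

    # vmap evaluates all positions in parallel; the result has shape (n, h, w).
    return jax.vmap(one_position)(positions)
```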

janssen.utils

Common utility functions used throughout the code.

Extended Summary

Core utilities for the janssen package including type definitions, factory functions, and decorators for type checking and validation. Provides the foundation for type-safe JAX programming with PyTrees.

Submodules

factory

Factory functions for creating data structures

types

Type definitions and PyTrees

Routine Listings

Diffractogram : PyTree

PyTree for storing diffraction patterns

GridParams : PyTree

PyTree for computational grid parameters

LensParams : PyTree

PyTree for lens optical parameters

MicroscopeData : PyTree

PyTree for microscopy data

OpticalWavefront : PyTree

PyTree for optical wavefront representation

OptimizerState : PyTree

PyTree for optimizer state tracking

PtychographyParams : PyTree

PyTree for ptychography reconstruction parameters

SampleFunction : PyTree

PyTree for sample representation

make_diffractogram : function

Factory function for Diffractogram creation

make_grid_params : function

Factory function for GridParams creation

make_lens_params : function

Factory function for LensParams creation

make_microscope_data : function

Factory function for MicroscopeData creation

make_optical_wavefront : function

Factory function for OpticalWavefront creation

make_optimizer_state : function

Factory function for OptimizerState creation

make_ptychography_params : function

Factory function for PtychographyParams creation

make_sample_function : function

Factory function for SampleFunction creation

non_jax_number : TypeAlias

Type alias for Python numeric types

scalar_bool : TypeAlias

Type alias for scalar boolean values

scalar_complex : TypeAlias

Type alias for scalar complex values

scalar_float : TypeAlias

Type alias for scalar float values

scalar_integer : TypeAlias

Type alias for scalar integer values

scalar_numeric : TypeAlias

Type alias for any scalar numeric value

Notes

Always use factory functions for creating PyTree instances to ensure proper type checking and validation. All PyTrees are registered with JAX and support automatic differentiation.
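To illustrate the PyTree behaviour described above: a plain `typing.NamedTuple` already behaves as a PyTree in JAX. Its fields become leaves, and `jax.grad` returns gradients with the same structure. `ToySample` below is a hypothetical stand-in, not a janssen class.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToySample(NamedTuple):
    """Hypothetical stand-in for a janssen PyTree such as SampleFunction."""

    sample: jnp.ndarray
    dx: jnp.ndarray


def total_power(s: ToySample) -> jnp.ndarray:
    # Real-valued scalar function of the PyTree, so jax.grad applies
    return jnp.sum(jnp.abs(s.sample) ** 2) * s.dx**2


s = ToySample(sample=jnp.full((4, 4), 1.0 + 1.0j, dtype=jnp.complex64), dx=jnp.asarray(0.5))

leaves = jax.tree_util.tree_leaves(s)  # [s.sample, s.dx], in field order
grads = jax.grad(total_power)(s)       # a ToySample holding one gradient per field
```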

janssen.utils.make_diffractogram(image: Float[Array, 'hh ww'], wavelength: float | Float[Array, ''], dx: float | Float[Array, '']) → Diffractogram

JAX-safe factory function for Diffractogram with data validation.

Parameters:
  • image (Float[Array, " hh ww"]) – 2D diffraction-pattern image (finite, non-negative intensities)

  • wavelength (float) – Wavelength of the optical wavefront in meters

  • dx (float) – Spatial sampling interval (grid spacing) in meters

Returns:

validated_diffractogram – Validated diffractogram instance

Return type:

Diffractogram

Raises:

ValueError – If data is invalid or parameters are out of valid ranges

Notes

Algorithm:

  • Convert inputs to JAX arrays

  • Validate image array:
    • Check it’s 2D

    • Ensure all values are finite and non-negative

  • Validate parameters:
    • Check wavelength is positive

    • Check dx is positive

  • Create and return Diffractogram instance
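A minimal sketch of these steps, using plain Python-level checks. `make_diffractogram_sketch` and `DiffractogramSketch` are hypothetical names, not janssen's code. Checks such as `bool(jnp.all(...))` need concrete values, so this version is meant to be called outside `jax.jit`. How janssen keeps its own validation JAX-safe is not shown in this reference.

```python
from typing import NamedTuple

import jax.numpy as jnp


class DiffractogramSketch(NamedTuple):
    # Hypothetical mirror of janssen.utils.Diffractogram's fields
    image: jnp.ndarray
    wavelength: jnp.ndarray
    dx: jnp.ndarray


def make_diffractogram_sketch(image, wavelength, dx) -> DiffractogramSketch:
    # Convert inputs to JAX arrays
    image = jnp.asarray(image, dtype=jnp.float32)
    wavelength = jnp.asarray(wavelength, dtype=jnp.float32)
    dx = jnp.asarray(dx, dtype=jnp.float32)
    # Validate the image: 2D, finite, non-negative (eager checks, not jit-traceable)
    if image.ndim != 2:
        raise ValueError(f"image must be 2D, got shape {image.shape}")
    if not bool(jnp.all(jnp.isfinite(image))) or bool(jnp.any(image < 0)):
        raise ValueError("image values must be finite and non-negative")
    # Validate the scalar parameters
    if not bool(wavelength > 0) or not bool(dx > 0):
        raise ValueError("wavelength and dx must be positive")
    return DiffractogramSketch(image=image, wavelength=wavelength, dx=dx)
```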

janssen.utils.make_grid_params(xx: Float[Array, 'hh ww'], yy: Float[Array, 'hh ww'], phase_profile: Float[Array, 'hh ww'], transmission: Float[Array, 'hh ww']) → GridParams

JAX-safe factory function for GridParams with data validation.

Parameters:
  • xx (Float[Array, " hh ww"]) – Spatial grid in the x-direction

  • yy (Float[Array, " hh ww"]) – Spatial grid in the y-direction

  • phase_profile (Float[Array, " hh ww"]) – Phase profile of the optical field

  • transmission (Float[Array, " hh ww"]) – Transmission profile of the optical field

Returns:

validated_grid_params – Validated grid parameters instance

Return type:

GridParams

Raises:

ValueError – If array shapes are inconsistent or data is invalid

Notes

Algorithm:

  • Convert inputs to JAX arrays

  • Validate array shapes:
    • Check all arrays are 2D

    • Check all arrays have the same shape

  • Validate data:
    • Ensure transmission values are between 0 and 1

    • Ensure phase values are finite

    • Ensure grid coordinates are finite

  • Create and return GridParams instance
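The shape and value checks above could look like this. `validate_grid_arrays` is a hypothetical helper, and the 64 x 64 grid with a circular aperture is illustrative input only.

```python
import jax.numpy as jnp


def validate_grid_arrays(xx, yy, phase_profile, transmission):
    """Hypothetical sketch of the GridParams checks listed above (eager only)."""
    arrays = [jnp.asarray(a, dtype=jnp.float32) for a in (xx, yy, phase_profile, transmission)]
    # Shape checks: every array 2D, and all shapes identical
    if any(a.ndim != 2 for a in arrays):
        raise ValueError("all grid arrays must be 2D")
    if len({a.shape for a in arrays}) != 1:
        raise ValueError("all grid arrays must have the same shape")
    xx, yy, phase_profile, transmission = arrays
    # Value checks: transmission in [0, 1]; phase and coordinates finite
    if bool(jnp.any((transmission < 0) | (transmission > 1))):
        raise ValueError("transmission values must be between 0 and 1")
    if not all(bool(jnp.all(jnp.isfinite(a))) for a in (phase_profile, xx, yy)):
        raise ValueError("phase_profile, xx and yy must be finite")
    return xx, yy, phase_profile, transmission


# Example input: a centered 64 x 64 grid with 1 um spacing and a circular aperture
n, dx = 64, 1e-6
coords = (jnp.arange(n) - n // 2) * dx
xx, yy = jnp.meshgrid(coords, coords)
transmission = (xx**2 + yy**2 <= (20 * dx) ** 2).astype(jnp.float32)
checked = validate_grid_arrays(xx, yy, jnp.zeros((n, n)), transmission)
```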

janssen.utils.make_lens_params(focal_length: float | Float[Array, ''], diameter: float | Float[Array, ''], n: float | Float[Array, ''], center_thickness: float | Float[Array, ''], r1: float | Float[Array, ''], r2: float | Float[Array, '']) → LensParams

JAX-safe factory function for LensParams with data validation.

Parameters:
  • focal_length (float) – Focal length of the lens in meters

  • diameter (float) – Diameter of the lens in meters

  • n (float) – Refractive index of the lens material

  • center_thickness (float) – Thickness at the center of the lens in meters

  • r1 (float) – Radius of curvature of the first surface in meters (positive for convex)

  • r2 (float) – Radius of curvature of the second surface in meters (positive for convex)

Returns:

validated_lens_params – Validated lens parameters instance

Return type:

LensParams

Raises:

ValueError – If parameters are invalid or out of valid ranges

Notes

Algorithm:

  • Convert inputs to JAX arrays

  • Validate parameters:
    • Check focal_length is positive

    • Check diameter is positive

    • Check refractive index is positive

    • Check center_thickness is positive

    • Check radii are finite

  • Create and return LensParams instance
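A minimal sketch of the lens-parameter checks. It uses a hypothetical helper, `validate_lens_params`, which returns the converted values in a dict rather than a `LensParams` instance.

```python
import jax.numpy as jnp


def validate_lens_params(focal_length, diameter, n, center_thickness, r1, r2):
    """Hypothetical sketch of the LensParams checks listed above (eager only)."""
    params = {
        "focal_length": focal_length,
        "diameter": diameter,
        "n": n,
        "center_thickness": center_thickness,
        "r1": r1,
        "r2": r2,
    }
    # Convert every input to a float32 JAX scalar
    params = {k: jnp.asarray(v, dtype=jnp.float32) for k, v in params.items()}
    # These four quantities must be strictly positive
    for key in ("focal_length", "diameter", "n", "center_thickness"):
        if not bool(params[key] > 0):
            raise ValueError(f"{key} must be positive")
    # The radii only have to be finite; their sign encodes surface orientation
    for key in ("r1", "r2"):
        if not bool(jnp.isfinite(params[key])):
            raise ValueError(f"{key} must be finite")
    return params
```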

janssen.utils.make_microscope_data(image_data: Float[Array, 'pp hh ww'] | Float[Array, 'xx yy hh ww'], positions: Num[Array, 'pp 2'], wavelength: float | Float[Array, ''], dx: float | Float[Array, '']) → MicroscopeData

JAX-safe factory function for MicroscopeData with data validation.

Parameters:
  • image_data (Union[Float[Array, " pp hh ww"], Float[Array, " xx yy hh ww"]]) – 3D or 4D image data representing the optical field

  • positions (Num[Array, " pp 2"]) – Scan positions at which the images were collected

  • wavelength (float) – Wavelength of the optical wavefront in meters

  • dx (float) – Spatial sampling interval (grid spacing) in meters

Returns:

validated_microscope_data – Validated microscope data instance

Return type:

MicroscopeData

Raises:

ValueError – If data is invalid or parameters are out of valid ranges

Notes

Algorithm:

  • Convert inputs to JAX arrays

  • Validate image_data:
    • Check it’s 3D or 4D

    • Ensure all values are finite and non-negative

  • Validate positions:
    • Check it’s 2D with shape (pp, 2)

    • Ensure all values are finite

  • Validate parameters:
    • Check wavelength is positive

    • Check dx is positive

  • Validate consistency:
    • Check that the number of positions (pp) matches between image_data and positions

  • Create and return MicroscopeData instance
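Below is a sketch of the MicroscopeData checks, using a hypothetical `validate_microscope_data` helper. The text above does not say how the position count relates to 4D data. The sketch therefore assumes `pp = xx * yy` for a 4D `(xx, yy, hh, ww)` stack, and marks that assumption in the code.

```python
import jax.numpy as jnp


def validate_microscope_data(image_data, positions, wavelength, dx):
    """Hypothetical sketch of the MicroscopeData checks listed above (eager only)."""
    image_data = jnp.asarray(image_data, dtype=jnp.float32)
    positions = jnp.asarray(positions)
    # image_data: 3D or 4D, finite, non-negative
    if image_data.ndim not in (3, 4):
        raise ValueError("image_data must be 3D (pp, hh, ww) or 4D (xx, yy, hh, ww)")
    if not bool(jnp.all(jnp.isfinite(image_data))) or bool(jnp.any(image_data < 0)):
        raise ValueError("image_data must be finite and non-negative")
    # positions: shape (pp, 2), finite
    if positions.ndim != 2 or positions.shape[1] != 2:
        raise ValueError("positions must have shape (pp, 2)")
    if not bool(jnp.all(jnp.isfinite(positions))):
        raise ValueError("positions must be finite")
    # Scalar parameters
    if not (float(wavelength) > 0 and float(dx) > 0):
        raise ValueError("wavelength and dx must be positive")
    # Consistency; ASSUMPTION: a 4D (xx, yy, hh, ww) stack holds xx * yy images
    n_images = image_data.shape[0]
    if image_data.ndim == 4:
        n_images *= image_data.shape[1]
    if n_images != positions.shape[0]:
        raise ValueError("number of images does not match number of positions")
    return image_data, positions
```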

janssen.utils.make_optical_wavefront(field: Complex[Array, 'hh ww'] | Complex[Array, 'hh ww 2'], wavelength: float | Float[Array, ''], dx: float | Float[Array, ''], z_position: float | Float[Array, ''], polarization: bool | Bool[Array, ''] | None = False) → OpticalWavefront

JAX-safe factory function for OpticalWavefront with data validation.

Parameters:
  • field (Union[Complex[Array, " hh ww"], Complex[Array, " hh ww 2"]]) – Complex amplitude of the optical field. Should be 2D for scalar fields or 3D with last dimension 2 for polarized fields.

  • wavelength (float) – Wavelength of the optical wavefront in meters

  • dx (float) – Spatial sampling interval (grid spacing) in meters

  • z_position (float) – Axial position of the wavefront in the propagation direction in meters.

  • polarization (scalar_bool, optional) – Whether the field is polarized (True for 3D field, False for 2D field). Default is False.

Returns:

validated_optical_wavefront – Validated optical wavefront instance

Return type:

OpticalWavefront

Raises:

ValueError – If data is invalid or parameters are out of valid ranges

Notes

Algorithm:

  • Convert inputs to JAX arrays

  • Validate field array:
    • Check it’s 2D (scalar field) or 3D with a last dimension of 2 (polarized field)

    • Ensure all values are finite

  • Validate parameters:
    • Check wavelength is positive

    • Check dx is positive

    • Check z_position is finite

  • Create and return OpticalWavefront instance
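According to the parameter description, `field` is 2D for scalar fields and `(hh, ww, 2)` for polarized ones. The hypothetical helper below checks that the array shape agrees with the `polarization` flag and builds one example of each kind. It is a sketch, not janssen's validation code.

```python
import jax.numpy as jnp


def check_wavefront_field(field, polarization=False):
    """Hypothetical sketch: the field's shape must agree with the polarization flag."""
    field = jnp.asarray(field, dtype=jnp.complex64)
    if polarization:
        shape_ok = field.ndim == 3 and field.shape[-1] == 2  # (hh, ww, 2)
    else:
        shape_ok = field.ndim == 2  # (hh, ww)
    if not shape_ok:
        raise ValueError(f"field shape {field.shape} does not match polarization={polarization}")
    if not bool(jnp.all(jnp.isfinite(field))):
        raise ValueError("field values must be finite")
    return field


# A scalar plane wave, and the same wave as a two-component field
# with all amplitude in the first component
scalar = jnp.ones((32, 32), dtype=jnp.complex64)
polarized = jnp.stack([scalar, jnp.zeros_like(scalar)], axis=-1)  # shape (32, 32, 2)
```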

janssen.utils.make_optimizer_state(shape: tuple, m: Complex[Array, '...'] | complex | Complex[Array, ''] | None = 1j, v: Float[Array, '...'] | float | Float[Array, ''] | None = 0.0, step: int | Int[Array, ''] | None = 0) → OptimizerState

JAX-safe factory function for OptimizerState with data validation.

Parameters:
  • shape (Tuple) – Shape of the parameters to be optimized

  • m (Optional[Complex[Array, "..."]], optional) – First moment estimate. If None, initialized to zeros with given shape. Default is 1j.

  • v (Optional[Float[Array, "..."]], optional) – Second moment estimate. If None, initialized to zeros with given shape. Default is 0.0.

  • step (Optional[scalar_integer], optional) – Step count. Default is 0.

Returns:

validated_optimizer_state – Validated optimizer state instance

Return type:

OptimizerState

Raises:

ValueError – If arrays have incompatible shapes with the given shape parameter

Notes

Algorithm:

  • If m is None, initialize with complex zeros of given shape

  • If v is None, initialize with real zeros of given shape

  • If step is None, initialize to 0

  • Convert all inputs to JAX arrays with appropriate dtypes

  • Validate arrays have compatible shapes

  • Create and return OptimizerState instance
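A minimal sketch of the initialization logic. The documented defaults `m=1j` and `v=0.0` are scalars, so this sketch assumes scalar values are broadcast to `shape`. That is one plausible reading of "validate arrays have compatible shapes", and janssen may handle it differently. `OptimizerStateSketch` and `make_optimizer_state_sketch` are hypothetical names.

```python
from typing import NamedTuple

import jax.numpy as jnp


class OptimizerStateSketch(NamedTuple):
    # Hypothetical mirror of janssen.utils.OptimizerState's fields
    m: jnp.ndarray
    v: jnp.ndarray
    step: jnp.ndarray


def make_optimizer_state_sketch(shape, m=None, v=None, step=None):
    # None -> zeros of the requested shape (complex for m, real for v), step -> 0
    m = jnp.zeros(shape, dtype=jnp.complex64) if m is None else m
    v = jnp.zeros(shape, dtype=jnp.float32) if v is None else v
    step = 0 if step is None else step
    # ASSUMPTION: scalars are broadcast to the full shape; broadcast_to also
    # rejects arrays whose shape is incompatible with `shape`
    m = jnp.broadcast_to(jnp.asarray(m, dtype=jnp.complex64), shape)
    v = jnp.broadcast_to(jnp.asarray(v, dtype=jnp.float32), shape)
    step = jnp.asarray(step, dtype=jnp.int32)
    return OptimizerStateSketch(m=m, v=v, step=step)
```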

janssen.utils.make_ptychography_params(zoom_factor: float | Float[Array, ''], aperture_diameter: float | Float[Array, ''], travel_distance: float | Float[Array, ''], aperture_center: Float[Array, '2'], camera_pixel_size: float | Float[Array, ''], learning_rate: float | Float[Array, ''], num_iterations: int | Int[Array, '']) → PtychographyParams

Create a PtychographyParams PyTree with validated parameters.

Parameters:
  • zoom_factor (float) – Optical zoom factor for magnification (must be positive)

  • aperture_diameter (float) – Diameter of the aperture in meters (must be positive)

  • travel_distance (float) – Light propagation distance in meters (must be positive)

  • aperture_center (Float[Array, " 2"]) – Center position of the aperture (x, y) in meters

  • camera_pixel_size (float) – Camera pixel size in meters (must be positive)

  • learning_rate (float) – Learning rate for optimization (must be positive)

  • num_iterations (scalar_integer) – Number of optimization iterations (must be positive)

Returns:

Validated ptychography parameters as a PyTree

Return type:

PtychographyParams

Notes

This function performs runtime validation to ensure all parameters are properly formatted and within valid ranges before creating the PtychographyParams PyTree.
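The runtime validation described above might look roughly like this. `validate_ptychography_params` is a hypothetical helper that returns a dict of converted values instead of a `PtychographyParams` instance. It checks only the constraints listed in the parameter descriptions.

```python
import jax.numpy as jnp


def validate_ptychography_params(zoom_factor, aperture_diameter, travel_distance,
                                 aperture_center, camera_pixel_size,
                                 learning_rate, num_iterations):
    """Hypothetical sketch of the documented range checks (eager only)."""
    positive_floats = {
        "zoom_factor": zoom_factor,
        "aperture_diameter": aperture_diameter,
        "travel_distance": travel_distance,
        "camera_pixel_size": camera_pixel_size,
        "learning_rate": learning_rate,
    }
    checked = {}
    # Every float parameter must be a positive scalar
    for name, value in positive_floats.items():
        value = jnp.asarray(value, dtype=jnp.float32)
        if value.ndim != 0 or not bool(value > 0):
            raise ValueError(f"{name} must be a positive scalar")
        checked[name] = value
    # aperture_center is an (x, y) pair
    center = jnp.asarray(aperture_center, dtype=jnp.float32)
    if center.shape != (2,):
        raise ValueError("aperture_center must have shape (2,)")
    iterations = jnp.asarray(num_iterations, dtype=jnp.int32)
    if not bool(iterations > 0):
        raise ValueError("num_iterations must be positive")
    checked.update(aperture_center=center, num_iterations=iterations)
    return checked
```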

janssen.utils.make_sample_function(sample: Complex[Array, 'hh ww'], dx: float | Float[Array, '']) → SampleFunction

JAX-safe factory function for SampleFunction with data validation.

Parameters:
  • sample (Complex[Array, " hh ww"]) – Complex-valued 2D sample function

  • dx (float) – Spatial sampling interval (grid spacing) in meters

Returns:

validated_sample_function – Validated sample function instance

Return type:

SampleFunction

Raises:

ValueError – If data is invalid or parameters are out of valid ranges

Notes

Algorithm:

  • Convert inputs to JAX arrays

  • Validate sample array:
    • Check it’s 2D

    • Ensure all values are finite

  • Validate parameters:
    • Check dx is positive

  • Create and return SampleFunction instance
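As a usage-oriented sketch, the hypothetical helper below builds a pure-phase sample, exp(i * phase), which has unit modulus everywhere. It then applies the checks listed above. It returns a plain `(sample, dx)` tuple rather than a `SampleFunction`.

```python
import jax.numpy as jnp


def make_phase_object(phase, dx):
    """Hypothetical helper: a pure-phase sample exp(i * phase), with the checks above."""
    sample = jnp.exp(1j * jnp.asarray(phase, dtype=jnp.float32))
    dx = jnp.asarray(dx, dtype=jnp.float32)
    if sample.ndim != 2:
        raise ValueError("sample must be 2D")
    if not bool(jnp.all(jnp.isfinite(sample))):
        raise ValueError("sample values must be finite")
    if not bool(dx > 0):
        raise ValueError("dx must be positive")
    return sample, dx


# A phase ramp across the columns: |sample| stays 1 everywhere
ramp = jnp.broadcast_to(jnp.linspace(0.0, jnp.pi, 8), (8, 8))
sample, dx = make_phase_object(ramp, 1e-6)
```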

class janssen.utils.Diffractogram(image: Float[Array, 'hh ww'], wavelength: Float[Array, ''], dx: Float[Array, ''])

Bases: NamedTuple

PyTree structure for representing a single diffractogram.

image

Image data.

Type:

Float[Array, " hh ww"]

wavelength

Wavelength of the optical wavefront in meters.

Type:

Float[Array, " "]

dx

Spatial sampling interval (grid spacing) in meters.

Type:

Float[Array, " "]

image: Float[Array, 'hh ww']

Alias for field number 0

wavelength: Float[Array, '']

Alias for field number 1

dx: Float[Array, '']

Alias for field number 2
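"Alias for field number N" means that each attribute can also be reached by tuple index, because `Diffractogram` is a `NamedTuple`. The sketch below uses a hypothetical mirror class with the same three fields to show indexing, unpacking and `_replace`. Note that `_replace` builds a new instance directly, so it skips any checks performed by a factory function such as `make_diffractogram`.

```python
from typing import NamedTuple

import jax.numpy as jnp


class DiffractogramSketch(NamedTuple):
    # Hypothetical mirror of janssen.utils.Diffractogram's three fields
    image: jnp.ndarray
    wavelength: jnp.ndarray
    dx: jnp.ndarray


d = DiffractogramSketch(jnp.zeros((4, 4)), jnp.asarray(633e-9), jnp.asarray(1e-6))
same_image = d[0]                            # "Alias for field number 0"
image, wavelength, dx = d                    # tuple unpacking in field order
finer = d._replace(dx=jnp.asarray(0.5e-6))   # immutable: _replace returns a copy
```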

class janssen.utils.GridParams(xx: Float[Array, 'hh ww'], yy: Float[Array, 'hh ww'], phase_profile: Float[Array, 'hh ww'], transmission: Float[Array, 'hh ww'])

Bases: NamedTuple

PyTree structure for computational grid parameters.

xx

Spatial grid in the x-direction

Type:

Float[Array, " hh ww"]

yy

Spatial grid in the y-direction

Type:

Float[Array, " hh ww"]

phase_profile

Phase profile of the optical field

Type:

Float[Array, " hh ww"]

transmission

Transmission profile of the optical field

Type:

Float[Array, " hh ww"]

Notes

This class is registered as a PyTree node, making it compatible with JAX transformations like jit, grad, and vmap. The auxiliary data in tree_flatten is None as all relevant data is stored in JAX arrays.

xx: Float[Array, 'hh ww']

Alias for field number 0

yy: Float[Array, 'hh ww']

Alias for field number 1

phase_profile: Float[Array, 'hh ww']

Alias for field number 2

transmission: Float[Array, 'hh ww']

Alias for field number 3
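Because `GridParams` is a PyTree, a whole instance can be passed as a single argument through `jax.jit` and `jax.vmap`. The sketch uses a hypothetical mirror class, `GridSketch`, with the same four fields. `apply_grid` is an illustrative function that multiplies a field by the transmission and phase profiles.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class GridSketch(NamedTuple):
    # Hypothetical mirror of janssen.utils.GridParams's fields
    xx: jnp.ndarray
    yy: jnp.ndarray
    phase_profile: jnp.ndarray
    transmission: jnp.ndarray


@jax.jit
def apply_grid(grid: GridSketch, field: jnp.ndarray) -> jnp.ndarray:
    # The whole NamedTuple is traced through jit as one PyTree argument
    return field * grid.transmission * jnp.exp(1j * grid.phase_profile)


n = 16
coords = jnp.linspace(-1.0, 1.0, n)
xx, yy = jnp.meshgrid(coords, coords)
grid = GridSketch(xx, yy, jnp.zeros((n, n)), (xx**2 + yy**2 <= 0.5).astype(jnp.float32))
out = apply_grid(grid, jnp.ones((n, n), dtype=jnp.complex64))

# vmap over a batch of fields while sharing one grid (in_axes=None for the grid)
batch = jnp.ones((3, n, n), dtype=jnp.complex64)
outs = jax.vmap(apply_grid, in_axes=(None, 0))(grid, batch)
```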

class janssen.utils.LensParams(focal_length: Float[Array, ''], diameter: Float[Array, ''], n: Float[Array, ''], center_thickness: Float[Array, ''], r1: Float[Array, ''], r2: Float[Array, ''])

Bases: NamedTuple

PyTree structure for lens parameters.

focal_length

Focal length of the lens in meters

Type:

Float[Array, " "]

diameter

Diameter of the lens in meters

Type:

Float[Array, " "]

n

Refractive index of the lens material

Type:

Float[Array, " "]

center_thickness

Thickness at the center of the lens in meters

Type:

Float[Array, " "]

r1

Radius of curvature of the first surface in meters (positive for convex)

Type:

Float[Array, " "]

r2

Radius of curvature of the second surface in meters (positive for convex)

Type:

Float[Array, " "]

focal_length: Float[Array, '']

Alias for field number 0

diameter: Float[Array, '']

Alias for field number 1

n: Float[Array, '']

Alias for field number 2

center_thickness: Float[Array, '']

Alias for field number 3

r1: Float[Array, '']

Alias for field number 4

r2: Float[Array, '']

Alias for field number 5

class janssen.utils.MicroscopeData(image_data: Float[Array, 'pp hh ww'] | Float[Array, 'xx yy hh ww'], positions: Num[Array, 'pp 2'], wavelength: Float[Array, ''], dx: Float[Array, ''])

Bases: NamedTuple

PyTree structure for representing a 3D or 4D stack of microscope images.

image_data

3D or 4D image data representing the optical field.

Type:

Float[Array, " pp hh ww"] | Float[Array, " xx yy hh ww"]

positions

Scan positions at which the images were collected.

Type:

Num[Array, " pp 2"]

wavelength

Wavelength of the optical wavefront in meters.

Type:

Float[Array, " "]

dx

Spatial sampling interval (grid spacing) in meters.

Type:

Float[Array, " "]

image_data: Union[Float[Array, 'pp hh ww'], Float[Array, 'xx yy hh ww']]

Alias for field number 0

positions: Num[Array, 'pp 2']

Alias for field number 1

wavelength: Float[Array, '']

Alias for field number 2

dx: Float[Array, '']

Alias for field number 3

class janssen.utils.OpticalWavefront(field: Complex[Array, 'hh ww'] | Complex[Array, 'hh ww 2'], wavelength: Float[Array, ''], dx: Float[Array, ''], z_position: Float[Array, ''], polarization: Bool[Array, ''])

Bases: NamedTuple

PyTree structure for representing an optical wavefront.

field

Complex amplitude of the optical field. Can be scalar (H, W) or polarized with two components (H, W, 2).

Type:

Union[Complex[Array, " hh ww"], Complex[Array, " hh ww 2"]]

wavelength

Wavelength of the optical wavefront in meters.

Type:

Float[Array, " "]

dx

Spatial sampling interval (grid spacing) in meters.

Type:

Float[Array, " "]

z_position

Axial position of the wavefront along the propagation direction, in meters.

Type:

Float[Array, " "]

polarization

Whether the field is polarized (True for 3D field, False for 2D field).

Type:

Bool[Array, " "]

field: Union[Complex[Array, 'hh ww'], Complex[Array, 'hh ww 2']]

Alias for field number 0

wavelength: Float[Array, '']

Alias for field number 1

dx: Float[Array, '']

Alias for field number 2

z_position: Float[Array, '']

Alias for field number 3

polarization: Bool[Array, '']

Alias for field number 4

class janssen.utils.OptimizerState(m: Complex[Array, '...'], v: Float[Array, '...'], step: Int[Array, ''])

Bases: NamedTuple

PyTree structure for maintaining optimizer state.

m

First moment estimate (for Adam-like optimizers)

Type:

Complex[Array, "..."]

v

Second moment estimate (for Adam-like optimizers)

Type:

Float[Array, "..."]

step

Step count

Type:

Int[Array, " "]

m: Complex[Array, '...']

Alias for field number 0

v: Float[Array, '...']

Alias for field number 1

step: Int[Array, '']

Alias for field number 2
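To show how `m`, `v` and `step` are typically used, here is one Adam-style update for complex parameters. This is a hypothetical sketch of a common complex-valued variant, in which the first moment is complex and the second moment is built from |grad|^2. It is not janssen's optimizer code. The hyperparameter defaults are the usual Adam values.

```python
from typing import NamedTuple

import jax.numpy as jnp


class OptimizerStateSketch(NamedTuple):
    # Hypothetical mirror of janssen.utils.OptimizerState
    m: jnp.ndarray     # complex first moment
    v: jnp.ndarray     # real second moment
    step: jnp.ndarray  # integer step count


def adam_step_sketch(params, grad, state, lr=1e-2, b1=0.9, b2=0.999, eps=1e-8):
    """One Adam-style update for complex parameters (hypothetical sketch)."""
    step = state.step + 1
    m = b1 * state.m + (1 - b1) * grad                # complex running mean of gradients
    v = b2 * state.v + (1 - b2) * jnp.abs(grad) ** 2  # real running mean of |grad|^2
    m_hat = m / (1 - b1**step)                        # bias corrections
    v_hat = v / (1 - b2**step)
    new_params = params - lr * m_hat / (jnp.sqrt(v_hat) + eps)
    return new_params, OptimizerStateSketch(m=m, v=v, step=step)
```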

class janssen.utils.PtychographyParams(zoom_factor: Float[Array, ''], aperture_diameter: Float[Array, ''], travel_distance: Float[Array, ''], aperture_center: Float[Array, '2'], camera_pixel_size: Float[Array, ''], learning_rate: Float[Array, ''], num_iterations: Int[Array, ''])

Bases: NamedTuple

PyTree structure for ptychography reconstruction parameters.

zoom_factor

Optical zoom factor for magnification

Type:

Float[Array, " "]

aperture_diameter

Diameter of the aperture in meters

Type:

Float[Array, " "]

travel_distance

Light propagation distance in meters

Type:

Float[Array, " "]

aperture_center

Center position of the aperture (x, y) in meters

Type:

Float[Array, " 2"]

camera_pixel_size

Camera pixel size in meters (typically fixed)

Type:

Float[Array, " "]

learning_rate

Learning rate for optimization

Type:

Float[Array, " "]

num_iterations

Number of optimization iterations

Type:

Int[Array, " "]

Notes

This class encapsulates all the optical and optimization parameters used in ptychographic reconstruction. It is registered as a PyTree node to enable JAX transformations and gradient-based optimization of these parameters.

zoom_factor: Float[Array, '']

Alias for field number 0

aperture_diameter: Float[Array, '']

Alias for field number 1

travel_distance: Float[Array, '']

Alias for field number 2

aperture_center: Float[Array, '2']

Alias for field number 3

camera_pixel_size: Float[Array, '']

Alias for field number 4

learning_rate: Float[Array, '']

Alias for field number 5

num_iterations: Int[Array, '']

Alias for field number 6
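By default, `jax.grad` rejects integer-valued inputs, so differentiating a whole `PtychographyParams` instance would fail on `num_iterations`. One way around this is to differentiate with respect to a single float field and rebuild the PyTree with `_replace`. The sketch below shows this with a hypothetical three-field mirror class and a toy loss.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ParamsSketch(NamedTuple):
    # Hypothetical subset of PtychographyParams' fields
    zoom_factor: jnp.ndarray
    travel_distance: jnp.ndarray
    num_iterations: jnp.ndarray  # integer, like the real field


def toy_loss(params: ParamsSketch) -> jnp.ndarray:
    # Hypothetical stand-in for a reconstruction loss
    return (params.zoom_factor * params.travel_distance - 1.0) ** 2


params = ParamsSketch(jnp.asarray(2.0), jnp.asarray(0.25), jnp.asarray(10, dtype=jnp.int32))


def loss_wrt_zoom(zoom):
    # Swap in the traced zoom value; the integer field stays a constant
    return toy_loss(params._replace(zoom_factor=zoom))


d_zoom = jax.grad(loss_wrt_zoom)(params.zoom_factor)
```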

class janssen.utils.SampleFunction(sample: Complex[Array, 'hh ww'], dx: Float[Array, ''])

Bases: NamedTuple

PyTree structure for representing a sample function.

sample

The sample function.

Type:

Complex[Array, " hh ww"]

dx

Spatial sampling interval (grid spacing) in meters.

Type:

Float[Array, " "]

sample: Complex[Array, 'hh ww']

Alias for field number 0

dx: Float[Array, '']

Alias for field number 1