API Reference
janssen.invert
Inversion algorithms for phase retrieval and ptychography.
Extended Summary
Comprehensive algorithms for phase retrieval and ptychographic reconstruction using differentiable programming techniques. Includes various optimization strategies and loss functions for reconstructing complex-valued fields.
Submodules
- engine
Reconstruction engine
- ptychography
Ptychographic algorithms
- optimizers
Optimization routines
- loss_functions
Loss function definitions
Routine Listings
- create_loss_function : function
Factory function for creating various loss functions
- simple_microscope_ptychography : function
Main ptychography reconstruction algorithm using PtychographyParams
- epie_optical : function
Extended PIE algorithm for optical ptychography
- single_pie_iteration : function
Single iteration of PIE algorithm
- single_pie_sequential : function
Sequential PIE implementation for multiple positions
- single_pie_vmap : function
Vectorized PIE implementation using vmap
- init_adam : function
Initialize Adam optimizer state
- init_adagrad : function
Initialize Adagrad optimizer state
- init_rmsprop : function
Initialize RMSprop optimizer state
Notes
All functions are JAX-compatible and support automatic differentiation. The algorithms can be composed with JIT compilation for improved performance.
- janssen.invert.epie_optical(microscope_data: MicroscopeData, initial_object: OpticalWavefront, initial_surface: SampleFunction, pixel_mask: Float[Array, 'H W'], propagation_distance_1: float | Float[Array, ''], propagation_distance_2: float | Float[Array, ''], magnification: int | Int[Array, ''], vmap_iterations: int | Int[Array, ''] | None = 0, alpha_object: float | Float[Array, ''] | None = 0.1, gamma_object: float | Float[Array, ''] | None = 0.5, alpha_surface: float | Float[Array, ''] | None = 0.1, gamma_surface: float | Float[Array, ''] | None = 0.5, num_loops: int | Int[Array, ''] | None = 10) -> tuple[OpticalWavefront, SampleFunction]
Reconstruct the object wavefront and surface pattern from ptychographic data using the extended PIE (ePIE) algorithm.
- Parameters:
  - microscope_data (MicroscopeData) – Measured intensity data with positions.
  - initial_object (OpticalWavefront) – Initial guess for the object wavefront.
  - initial_surface (SampleFunction) – Initial guess for the surface pattern.
  - pixel_mask (Float[Array, " H W"]) – Pixel response mask for modeling sensor characteristics.
  - propagation_distance_1 (float) – Distance from object to diffuser plane in meters.
  - propagation_distance_2 (float) – Distance from diffuser to sensor plane in meters.
  - magnification (scalar_integer) – Magnification factor for downsampling.
  - vmap_iterations (scalar_integer, optional) – Number of initial iterations to run in vmap mode for rapid convergence. If 0, sequential mode is used for all iterations. If > 0, vmap mode is used for the first N iterations, after which the algorithm switches to sequential mode. Default is 0.
  - alpha_object (float, optional) – Object update mixing parameter. Default is 0.1.
  - gamma_object (float, optional) – Object update step size. Default is 0.5.
  - alpha_surface (float, optional) – Surface update mixing parameter. Default is 0.1.
  - gamma_surface (float, optional) – Surface update step size. Default is 0.5.
  - num_loops (scalar_integer, optional) – Number of iteration loops. Default is 10.
- Returns:
  - recovered_object (OpticalWavefront) – Reconstructed object wavefront.
  - recovered_surface (SampleFunction) – Reconstructed surface pattern.
- Return type:
tuple of (OpticalWavefront, SampleFunction)
Notes
Algorithm:
Compute image data
Compute positions
Compute frequency grids
Compute object recovery propagation field
Compute surface pattern
Define loop body
Apply fori_loop over loops
Compute final object field
Compute final object wavefront
Compute final surface pattern
Return final object and surface
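The interplay of num_loops and vmap_iterations can be sketched as a plain outer loop. This assumes, which the source does not confirm, that the vmap warm-up sweeps are counted within num_loops; `run_epie_schedule`, `sequential_pass` and `parallel_pass` are hypothetical stand-ins for one sweep of single_pie_sequential and single_pie_vmap respectively:

```python
from typing import Callable, TypeVar

State = TypeVar("State")


def run_epie_schedule(
    state: State,
    sequential_pass: Callable[[State], State],
    parallel_pass: Callable[[State], State],
    num_loops: int = 10,
    vmap_iterations: int = 0,
) -> State:
    """Hypothetical outer loop: averaged parallel sweeps first, then sequential.

    Assumes the warm-up sweeps are part of the num_loops total.
    """
    for loop in range(num_loops):
        # The first vmap_iterations sweeps use the fast, averaged parallel update
        step = parallel_pass if loop < vmap_iterations else sequential_pass
        state = step(state)
    return state
```

Inside jax.jit the loop index is traced, so a JAX version would typically select the branch with jax.lax.cond inside jax.lax.fori_loop rather than a Python if.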
- janssen.invert.single_pie_iteration(object_prop_ft: Complex[Array, 'H W'], surface_pattern: Complex[Array, 'H W'], measurement: Float[Array, 'H W'], position: Float[Array, '2'], frequency_x_grid: Float[Array, 'H W'], frequency_y_grid: Float[Array, 'H W'], pixel_mask: Float[Array, 'H W'], propagation_distance_2: float | Float[Array, ''], magnification: int | Int[Array, ''], alpha_object: float | Float[Array, ''], gamma_object: float | Float[Array, ''], alpha_surface: float | Float[Array, ''], gamma_surface: float | Float[Array, ''], wavelength: float | Float[Array, ''], dx: float | Float[Array, '']) -> tuple[Complex[Array, 'H W'], Complex[Array, 'H W']]
Single iteration of the extended PIE algorithm.
- Parameters:
  - object_prop_ft (Complex[Array, " H W"]) – Object wavefront at the diffuser plane in the Fourier domain.
  - surface_pattern (Complex[Array, " H W"]) – Surface pattern function.
  - measurement (Float[Array, " H W"]) – Measured intensity at the current position.
  - position (Float[Array, " 2"]) – Current scanning position [x, y].
  - frequency_x_grid (Float[Array, " H W"]) – Frequency grid in the x direction.
  - frequency_y_grid (Float[Array, " H W"]) – Frequency grid in the y direction.
  - pixel_mask (Float[Array, " H W"]) – Pixel response mask for sensor modeling.
  - propagation_distance_2 (float) – Distance from diffuser to sensor.
  - magnification (scalar_integer) – Downsampling magnification factor.
  - alpha_object (float) – Object update mixing parameter.
  - gamma_object (float) – Object update step size.
  - alpha_surface (float) – Surface update mixing parameter.
  - gamma_surface (float) – Surface update step size.
  - wavelength (float) – Wavelength of light in meters.
  - dx (float) – Pixel spacing in meters.
- Returns:
  - updated_object_ft (Complex[Array, " H W"]) – Updated object wavefront in the Fourier domain.
  - updated_surface (Complex[Array, " H W"]) – Updated surface pattern.
- Return type:
tuple of (Complex[Array, " H W"], Complex[Array, " H W"])
Notes
Algorithm:
Compute object shifted
Compute surface plane
Compute surface propagation kernel
Compute sensor plane
Compute sensor intensity
Compute ratio map
Compute ratio map upsampled
Compute sensor plane new
Compute sensor plane new in Fourier domain
Compute CTF conjugate
Compute CTF maximum squared
Compute surface propagation kernel
Compute updated surface pattern
Compute updated object wavefront
Compute updated object wavefront in Fourier domain
Return updated object and surface
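The heart of a PIE-style iteration is the amplitude constraint (the "ratio map") followed by a conjugate-weighted correction normalized by a maximum-squared term. The NumPy sketch below shows a regularized-PIE (rPIE) object update for one position. It is a simplification under stated assumptions: a plain far-field FFT stands in for janssen's diffuser-to-sensor propagation, the pixel mask, upsampling and surface update are omitted, and the roles given to alpha and gamma here are assumed rather than taken from the library. `pie_object_update` is a hypothetical name:

```python
import numpy as np


def pie_object_update(obj, probe, measured_intensity, alpha=0.1, gamma=0.5, eps=1e-12):
    """One rPIE-style object update for a single scan position (sketch).

    obj, probe: complex arrays of shape (H, W).
    measured_intensity: real array of shape (H, W), target for |FFT(exit wave)|^2.
    """
    exit_wave = obj * probe
    exit_ft = np.fft.fft2(exit_wave)
    # Ratio map: rescale the modulus to the measured amplitude, keep the phase
    ratio = np.sqrt(measured_intensity / (np.abs(exit_ft) ** 2 + eps))
    revised_exit = np.fft.ifft2(exit_ft * ratio)
    # alpha mixes the local probe power with its maximum to regularize the division
    probe_power = np.abs(probe) ** 2
    denom = (1.0 - alpha) * probe_power + alpha * probe_power.max()
    # gamma scales the step along the conjugate-probe-weighted correction
    return obj + gamma * np.conj(probe) * (revised_exit - exit_wave) / denom
```

An estimate that already reproduces the measured intensity is a fixed point of this update, which is a quick sanity check for any implementation of the step.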
- janssen.invert.single_pie_sequential(object_prop_ft: Complex[Array, 'H W'], surface_pattern: Complex[Array, 'H W'], image_data: Float[Array, 'P H W'], positions: Float[Array, 'P 2'], frequency_x_grid: Float[Array, 'H W'], frequency_y_grid: Float[Array, 'H W'], pixel_mask: Float[Array, 'H W'], propagation_distance_2: float | Float[Array, ''], magnification: int | Int[Array, ''], alpha_object: float | Float[Array, ''], gamma_object: float | Float[Array, ''], alpha_surface: float | Float[Array, ''], gamma_surface: float | Float[Array, ''], wavelength: float | Float[Array, ''], dx: float | Float[Array, '']) -> tuple[Complex[Array, 'H W'], Complex[Array, 'H W']]
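Sequential processing over positions using fori_loop: each position's update starts from the state left by the previous position, which is the ordering PIE relies on for convergence.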
- Parameters:
  - object_prop_ft (Complex[Array, " H W"]) – Current object wavefront in the Fourier domain.
  - surface_pattern (Complex[Array, " H W"]) – Current surface pattern.
  - image_data (Float[Array, " P H W"]) – Measurement data for all positions.
  - positions (Float[Array, " P 2"]) – Position coordinates for all measurements.
  - frequency_x_grid (Float[Array, " H W"]) – Frequency grid in the x direction.
  - frequency_y_grid (Float[Array, " H W"]) – Frequency grid in the y direction.
  - pixel_mask (Float[Array, " H W"]) – Pixel response mask for sensor modeling.
  - propagation_distance_2 (float) – Distance from diffuser to sensor.
  - magnification (scalar_integer) – Downsampling magnification factor.
  - alpha_object (float) – Object update mixing parameter.
  - gamma_object (float) – Object update step size.
  - alpha_surface (float) – Surface update mixing parameter.
  - gamma_surface (float) – Surface update step size.
  - wavelength (float) – Wavelength of light in meters.
  - dx (float) – Pixel spacing in meters.
- Returns:
Updated object and surface state after sequential processing.
- Return type:
tuple of (Complex[Array, " H W"], Complex[Array, " H W"])
Notes
Algorithm:
Compute number of positions
Define position body
Apply fori_loop over positions
Return updated state
- janssen.invert.single_pie_vmap(object_prop_ft: Complex[Array, 'H W'], surface_pattern: Complex[Array, 'H W'], image_data: Float[Array, 'P H W'], positions: Float[Array, 'P 2'], frequency_x_grid: Float[Array, 'H W'], frequency_y_grid: Float[Array, 'H W'], pixel_mask: Float[Array, 'H W'], propagation_distance_2: float | Float[Array, ''], magnification: int | Int[Array, ''], alpha_object: float | Float[Array, ''], gamma_object: float | Float[Array, ''], alpha_surface: float | Float[Array, ''], gamma_surface: float | Float[Array, ''], wavelength: float | Float[Array, ''], dx: float | Float[Array, '']) -> tuple[Complex[Array, 'H W'], Complex[Array, 'H W']]
Parallel processing over positions using vmap for faster but approximate PIE.
All positions use the same initial state, then updates are averaged.
- Parameters:
  - object_prop_ft (Complex[Array, " H W"]) – Current object wavefront in the Fourier domain.
  - surface_pattern (Complex[Array, " H W"]) – Current surface pattern.
  - image_data (Float[Array, " P H W"]) – Measurement data for all positions.
  - positions (Float[Array, " P 2"]) – Position coordinates for all measurements.
  - frequency_x_grid (Float[Array, " H W"]) – Frequency grid in the x direction.
  - frequency_y_grid (Float[Array, " H W"]) – Frequency grid in the y direction.
  - pixel_mask (Float[Array, " H W"]) – Pixel response mask for sensor modeling.
  - propagation_distance_2 (float) – Distance from diffuser to sensor.
  - magnification (scalar_integer) – Downsampling magnification factor.
  - alpha_object (float) – Object update mixing parameter.
  - gamma_object (float) – Object update step size.
  - alpha_surface (float) – Surface update mixing parameter.
  - gamma_surface (float) – Surface update step size.
  - wavelength (float) – Wavelength of light in meters.
  - dx (float) – Pixel spacing in meters.
- Returns:
Updated object and surface state after parallel processing and averaging.
- Return type:
tuple of (Complex[Array, " H W"], Complex[Array, " H W"])
Notes
Algorithm:
Apply vmap over all positions using the same initial state
Compute average of all object updates
Compute average of all surface updates
Return averaged states
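The practical difference between single_pie_sequential and single_pie_vmap is whether each position sees the previous position's update. The toy NumPy sketch below makes that concrete; `toy_update` is a hypothetical per-position rule (step a fraction gamma toward a target) standing in for single_pie_iteration, and Python loops stand in for jax.lax.fori_loop and jax.vmap:

```python
import numpy as np


def toy_update(state, target, gamma=0.5):
    """Hypothetical per-position update: step a fraction gamma toward target."""
    return state + gamma * (target - state)


def sweep_sequential(state, targets, gamma=0.5):
    # Each position starts from the state left by the previous position
    for target in targets:
        state = toy_update(state, target, gamma)
    return state


def sweep_averaged(state, targets, gamma=0.5):
    # Every position starts from the same state; the updates are then averaged
    updates = np.stack([toy_update(state, target, gamma) for target in targets])
    return updates.mean(axis=0)
```

Starting from 0.0 with targets [1.0, 3.0], the sequential sweep reaches 1.75 while the averaged sweep reaches 1.0: the averaged variant parallelizes well and is order-independent, but takes a more conservative step per sweep.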
- janssen.invert.create_loss_function(forward_function: Callable[[...], Array], experimental_data: Array, loss_type: str = 'mae') -> Callable[[...], Float[Array, '']]
Create a JIT-compatible loss function.
This function returns a new function that computes the loss between the output of a forward model and experimental data. The returned function is JIT-compatible and can be used with various optimization algorithms.
- Parameters:
  - forward_function (Callable[..., Array]) – The forward model function (e.g., stem_4d).
  - experimental_data (Array) – The experimental data to compare against.
  - loss_type (str, optional) – The type of loss to use. Options are "mae" (Mean Absolute Error), "mse" (Mean Squared Error), or "rmse" (Root Mean Squared Error), by default "mae".
- Returns:
loss_fn – A JIT-compatible function that computes the loss given the model parameters and any additional arguments required by the forward function.
- Return type:
Callable[[PyTree, ...], Float[Array, " "]]
Notes
Define internal loss functions (mae_loss, mse_loss, rmse_loss).
Select the appropriate loss function based on loss_type.
- Create a JIT-compiled function that:
Computes the forward model output.
Calculates the difference between model and experimental data.
Applies the selected loss function.
Return the compiled loss function.
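The factory described in the Notes can be sketched in NumPy as a closure over the data and the selected reduction. `make_loss_function` is a hypothetical name; janssen's version additionally compiles the result with jax.jit, and the use of np.abs so that complex residuals are handled is an assumption of this sketch:

```python
from typing import Callable

import numpy as np


def make_loss_function(
    forward_function: Callable[..., np.ndarray],
    experimental_data: np.ndarray,
    loss_type: str = "mae",
) -> Callable[..., float]:
    """Sketch of a loss factory; np.abs keeps the losses real for complex residuals."""
    reducers = {
        "mae": lambda d: np.mean(np.abs(d)),
        "mse": lambda d: np.mean(np.abs(d) ** 2),
        "rmse": lambda d: np.sqrt(np.mean(np.abs(d) ** 2)),
    }
    if loss_type not in reducers:
        raise ValueError(f"unknown loss_type {loss_type!r}; expected one of {sorted(reducers)}")
    reduce = reducers[loss_type]

    def loss_fn(params, *args):
        # Forward model output minus the data, reduced by the selected loss
        return reduce(forward_function(params, *args) - experimental_data)

    return loss_fn
```

Selecting the reduction once, outside loss_fn, keeps the returned function free of string comparisons, which is what makes the JAX version straightforward to trace and compile.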
- janssen.invert.init_adagrad(shape: tuple) -> OptimizerState
Initialize Adagrad optimizer state.
- Parameters:
  - shape (Tuple) – Shape of the parameters to be optimized.
- Returns:
state – Initialized Adagrad optimizer state with zero accumulated gradients.
- Return type:
OptimizerState
- janssen.invert.init_adam(shape: tuple) -> OptimizerState
Initialize Adam optimizer state.
- Parameters:
  - shape (Tuple) – Shape of the parameters to be optimized.
- Returns:
state – Initialized Adam optimizer state with zero moments and step = 0.
- Return type:
OptimizerState
- janssen.invert.init_rmsprop(shape: tuple) -> OptimizerState
Initialize RMSprop optimizer state.
- Parameters:
  - shape (Tuple) – Shape of the parameters to be optimized.
- Returns:
state – Initialized RMSprop optimizer state with a zero moving average.
- Return type:
OptimizerState
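All three initializers start from zero statistics. The NumPy sketch below shows an Adam state of that kind and one standard Adam update consuming it, with the usual defaults b1=0.9, b2=0.999, eps=1e-8. `SketchOptimizerState`, `init_adam_sketch` and `adam_step` are hypothetical stand-ins, not janssen's OptimizerState or its update routine. Adagrad and RMSprop would instead keep a single zero-initialized accumulator (a running sum or a moving average of squared gradients, respectively):

```python
from typing import NamedTuple

import numpy as np


class SketchOptimizerState(NamedTuple):
    """Hypothetical stand-in for janssen's OptimizerState (Adam variant)."""

    m: np.ndarray  # first-moment estimate
    v: np.ndarray  # second-moment estimate
    step: int


def init_adam_sketch(shape: tuple) -> SketchOptimizerState:
    # Zero moments and step 0, as described for init_adam
    return SketchOptimizerState(np.zeros(shape), np.zeros(shape), 0)


def adam_step(params, grads, state, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    """One standard Adam update; returns (new_params, new_state)."""
    step = state.step + 1
    m = b1 * state.m + (1.0 - b1) * grads
    v = b2 * state.v + (1.0 - b2) * grads**2
    # Bias correction compensates for initializing the moments at zero
    m_hat = m / (1.0 - b1**step)
    v_hat = v / (1.0 - b2**step)
    new_params = params - lr * m_hat / (np.sqrt(v_hat) + eps)
    return new_params, SketchOptimizerState(m, v, step)
```

Because of the bias correction, the very first step moves each parameter by roughly lr in the direction opposite its gradient, regardless of the gradient's magnitude.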
- janssen.invert.simple_microscope_ptychography(experimental_data: MicroscopeData, guess_sample: SampleFunction, guess_lightwave: OpticalWavefront, params: PtychographyParams, save_every: int | Int[Array, ''] | None = 10, loss_type: str | None = 'mse', optimizer_name: str | None = 'adam', zoom_factor_bounds: tuple[float | Float[Array, ''], float | Float[Array, '']] | None = None, aperture_diameter_bounds: tuple[float | Float[Array, ''], float | Float[Array, '']] | None = None, travel_distance_bounds: tuple[float | Float[Array, ''], float | Float[Array, '']] | None = None, aperture_center_bounds: tuple[Float[Array, '2'], Float[Array, '2']] | None = None) -> tuple[tuple[SampleFunction, OpticalWavefront, float | Float[Array, ''], float | Float[Array, ''], Float[Array, '2'] | None, float | Float[Array, '']], tuple[Complex[Array, 'H W S'], Complex[Array, 'H W S'], Float[Array, 'S'], Float[Array, 'S'], Float[Array, '2 S'], Float[Array, 'S']]]
Solve the optical ptychography inverse problem.
Experimental diffraction patterns are used to reconstruct the sample, the lightwave, and the optical system parameters.
- Parameters:
  - experimental_data (MicroscopeData) – The experimental diffraction patterns collected at different positions.
  - guess_sample (SampleFunction) – Initial guess for the sample properties.
  - guess_lightwave (OpticalWavefront) – Initial guess for the lightwave.
  - params (PtychographyParams) – Ptychography parameters, including:
    - zoom_factor: Optical zoom factor for magnification
    - aperture_diameter: Diameter of the aperture in meters
    - travel_distance: Light propagation distance in meters
    - aperture_center: Center position of the aperture (x, y)
    - camera_pixel_size: Camera pixel size in meters
    - learning_rate: Learning rate for optimization
    - num_iterations: Number of optimization iterations
  - save_every (scalar_integer, optional) – Save intermediate results every n iterations. Default is 10.
  - loss_type (str, optional) – Type of loss function to use. Default is "mse".
  - optimizer_name (str, optional) – Name of the optimizer to use. Default is "adam".
  - zoom_factor_bounds (Tuple[scalar_float, scalar_float], optional) – Lower and upper bounds for zoom factor optimization.
  - aperture_diameter_bounds (Tuple[scalar_float, scalar_float], optional) – Lower and upper bounds for aperture diameter optimization.
  - travel_distance_bounds (Tuple[scalar_float, scalar_float], optional) – Lower and upper bounds for travel distance optimization.
  - aperture_center_bounds (Tuple[Float[Array, " 2"], Float[Array, " 2"]], optional) – Lower and upper bounds for aperture center optimization.
- Returns:
Tuple containing:
  - Final results tuple:
    - final_sample (SampleFunction) – Optimized sample properties.
    - final_lightwave (OpticalWavefront) – Optimized lightwave.
    - final_zoom_factor (scalar_float) – Optimized zoom factor.
    - final_aperture_diameter (scalar_float) – Optimized aperture diameter.
    - final_aperture_center (Float[Array, " 2"] or None) – Optimized aperture center.
    - final_travel_distance (scalar_float) – Optimized travel distance.
  - Intermediate results tuple, where S indexes the snapshots saved every save_every iterations:
    - intermediate_samples (Complex[Array, " H W S"]) – Intermediate samples during optimization.
    - intermediate_lightwaves (Complex[Array, " H W S"]) – Intermediate lightwaves during optimization.
    - intermediate_zoom_factors (Float[Array, " S"]) – Intermediate zoom factors during optimization.
    - intermediate_aperture_diameters (Float[Array, " S"]) – Intermediate aperture diameters during optimization.
    - intermediate_aperture_centers (Float[Array, " 2 S"]) – Intermediate aperture centers during optimization.
    - intermediate_travel_distances (Float[Array, " S"]) – Intermediate travel distances during optimization.
- Return type:
Tuple[Tuple[...], Tuple[...]]
janssen.lenses
Lens implementations and optical calculations.
Extended Summary
Comprehensive lens modeling and optical propagation algorithms for simulating light propagation through various optical elements. Includes implementations of common lens types and propagation methods based on wave optics.
Submodules
- lens_elements
Lens elements for optical simulations
- lens_prop
Lens propagation functions
Routine Listings
- create_lens_phase : function
Create phase profile for a lens based on its parameters
- double_concave_lens : function
Create parameters for a double concave lens
- double_convex_lens : function
Create parameters for a double convex lens
- lens_focal_length : function
Calculate focal length from lens parameters
- lens_thickness_profile : function
Calculate thickness profile of a lens
- meniscus_lens : function
Create parameters for a meniscus lens
- plano_concave_lens : function
Create parameters for a plano-concave lens
- plano_convex_lens : function
Create parameters for a plano-convex lens
- propagate_through_lens : function
Propagate optical wavefront through a lens
- angular_spectrum_prop : function
Angular spectrum propagation method
- digital_zoom : function
Digital zoom transformation for optical fields
- fraunhofer_prop : function
Fraunhofer (far-field) propagation
- fresnel_prop : function
Fresnel (near-field) propagation
- lens_propagation : function
General lens-based propagation
- optical_zoom : function
Optical zoom transformation
Notes
All propagation functions are JAX-compatible and support automatic differentiation. The lens functions can model both ideal and realistic optical elements with aberrations.
- janssen.lenses.create_lens_phase(xx: Float[Array, 'hh ww'], yy: Float[Array, 'hh ww'], params: LensParams, wavelength: float | Float[Array, '']) -> tuple[Float[Array, 'hh ww'], Float[Array, 'hh ww']]
Create the phase profile and transmission mask for a lens.
- Parameters:
  - xx (Float[Array, " hh ww"]) – X coordinates grid.
  - yy (Float[Array, " hh ww"]) – Y coordinates grid.
  - params (LensParams) – Lens parameters.
  - wavelength (float) – Wavelength of light.
- Returns:
  - phase_profile (Float[Array, " hh ww"]) – Phase profile of the lens.
  - transmission (Float[Array, " hh ww"]) – Transmission mask of the lens.
- Return type:
tuple[Float[Array, 'hh ww'], Float[Array, 'hh ww']]
Notes
Calculate radial coordinates.
Calculate thickness profile.
Calculate phase profile.
Create transmission mask.
Return phase and transmission.
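A common way to turn a thickness profile into a phase is the thin-element approximation, phi(r) = (2 pi / lambda) (n - 1) t(r), with a circular pupil as the transmission mask. The NumPy sketch below assumes exactly that; janssen's phase reference, sign convention and mask edge handling may differ, and `lens_phase_from_thickness` is a hypothetical helper that takes a precomputed thickness map and refractive index instead of LensParams:

```python
import numpy as np


def lens_phase_from_thickness(xx, yy, thickness, n, diameter, wavelength):
    """Thin-element phase delay and binary circular pupil for a lens (sketch)."""
    r = np.hypot(xx, yy)  # radial coordinate from the optical axis
    # Extra optical path of glass relative to air, converted to phase
    phase_profile = 2.0 * np.pi / wavelength * (n - 1.0) * thickness
    transmission = (r <= diameter / 2.0).astype(float)
    return phase_profile, transmission
```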
- janssen.lenses.double_concave_lens(focal_length: float | Float[Array, ''], diameter: float | Float[Array, ''], n: float | Float[Array, ''], center_thickness: float | Float[Array, ''], r_ratio: float | Float[Array, ''] | None = 1.0) -> LensParams
Create parameters for a double concave lens.
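- Parameters:
  - focal_length (float) – Desired focal length in meters.
  - diameter (float) – Lens diameter in meters.
  - n (float) – Refractive index of lens material.
  - center_thickness (float) – Center thickness in meters.
  - r_ratio (float, optional) – Ratio of R2 to R1, by default 1.0.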
- Returns:
params – Lens parameters.
- Return type:
LensParams
Notes
Calculate R1 using lensmaker’s equation.
Calculate R2 using r_ratio.
Create and return LensParams.
- janssen.lenses.double_convex_lens(focal_length: float | Float[Array, ''], diameter: float | Float[Array, ''], n: float | Float[Array, ''], center_thickness: float | Float[Array, ''], r_ratio: float | Float[Array, ''] | None = 1.0) -> LensParams
Create parameters for a double convex lens.
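- Parameters:
  - focal_length (float) – Desired focal length in meters.
  - diameter (float) – Lens diameter in meters.
  - n (float) – Refractive index of lens material.
  - center_thickness (float) – Center thickness in meters.
  - r_ratio (float, optional) – Ratio of R2 to R1, by default 1.0.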
- Returns:
params – Lens parameters.
- Return type:
LensParams
Notes
Calculate R1 using the lensmaker’s equation.
Calculate R2 using r_ratio.
Create and return LensParams.
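Taking the sign convention stated for lens_focal_length (both radii positive for convex surfaces) and the thin-lens form 1/f = (n - 1)(1/R1 + 1/R2) with R2 = r_ratio * R1, R1 can be solved for directly. This sketch assumes the thin-lens form (center_thickness is ignored), which the source does not confirm for double_convex_lens; `double_convex_radii` is a hypothetical helper:

```python
def double_convex_radii(focal_length: float, n: float, r_ratio: float = 1.0) -> tuple[float, float]:
    """Solve 1/f = (n - 1) * (1/R1 + 1/R2) with R2 = r_ratio * R1 (sketch).

    Thin-lens approximation; both radii are positive for a biconvex lens.
    """
    r1 = focal_length * (n - 1.0) * (1.0 + 1.0 / r_ratio)
    r2 = r_ratio * r1
    return r1, r2
```

For the symmetric default r_ratio = 1 this reduces to the familiar R1 = R2 = 2f(n - 1).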
- janssen.lenses.lens_focal_length(n: float | Float[Array, ''], r1: int | float | complex | Num[Array, ''], r2: int | float | complex | Num[Array, '']) -> float | Float[Array, '']
Calculate the focal length of a lens using the lensmaker’s equation.
- Parameters:
  - n (float) – Refractive index of the lens material.
  - r1 (scalar_numeric) – Radius of curvature of the first surface (positive for convex).
  - r2 (scalar_numeric) – Radius of curvature of the second surface (positive for convex).
- Returns:
f – Focal length of the lens.
- Return type:
scalar_float
Notes
Apply the lensmaker’s equation.
Return the calculated focal length.
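lens_focal_length takes no thickness argument, which suggests the thin-lens form of the lensmaker's equation. Assuming that form and the per-surface convention stated above (positive radius for a convex surface), a plain-Python sketch is below; `thin_lens_focal_length` is a hypothetical name:

```python
import math


def thin_lens_focal_length(n: float, r1: float, r2: float) -> float:
    """Thin-lens lensmaker's equation with convex radii taken as positive (sketch).

    Pass math.inf for a flat surface, since 1 / inf evaluates to 0.0.
    Equal and opposite radii give zero power and raise ZeroDivisionError here.
    """
    power = (n - 1.0) * (1.0 / r1 + 1.0 / r2)
    return 1.0 / power
```

With n = 1.5, a symmetric biconvex lens with both radii 0.1 m, a plano-convex lens with R1 = 0.05 m, and a meniscus lens with R1 = 0.05 m and R2 = -0.1 m give focal lengths of 0.1 m, 0.1 m and 0.2 m respectively.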
- janssen.lenses.lens_thickness_profile(r: Float[Array, 'H W'], r1: float | Float[Array, ''], r2: float | Float[Array, ''], center_thickness: float | Float[Array, ''], diameter: float | Float[Array, '']) -> Float[Array, 'H W']
Calculate the thickness profile of a lens.
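- Parameters:
  - r (Float[Array, " H W"]) – Radial distance from the optical axis in meters.
  - r1 (float) – Radius of curvature of the first surface.
  - r2 (float) – Radius of curvature of the second surface.
  - center_thickness (float) – Center thickness in meters.
  - diameter (float) – Lens diameter in meters.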
- Returns:
thickness – Thickness profile of the lens.
- Return type:
Float[Array," H W"]
Notes
Calculate surface sag for both surfaces, only where the aperture mask is set and r is finite.
Combine sags with center thickness.
Return thickness profile.
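The sag-based construction in the Notes can be sketched in NumPy for spherical surfaces. The sketch assumes the per-surface convention used for lens_focal_length (positive radius for a convex surface, np.inf for a flat one), zero thickness outside the aperture, and diameter / 2 no larger than either finite radius; these are assumptions, and `surface_sag` and `thickness_profile` are hypothetical names:

```python
import numpy as np


def surface_sag(radius, r):
    """Sag of a spherical surface; a positive (convex) radius removes material."""
    if np.isinf(radius):
        return np.zeros_like(r)  # flat surface
    return np.sign(radius) * (abs(radius) - np.sqrt(radius**2 - r**2))


def thickness_profile(r, r1, r2, center_thickness, diameter):
    """Lens thickness vs. radial distance r (sketch); zero outside the aperture."""
    inside = r <= diameter / 2.0
    # Evaluate sags only inside the aperture so sqrt never sees a negative argument
    r_in = np.where(inside, r, 0.0)
    thickness = center_thickness - surface_sag(r1, r_in) - surface_sag(r2, r_in)
    return np.where(inside, thickness, 0.0)
```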
- janssen.lenses.meniscus_lens(focal_length: float | Float[Array, ''], diameter: float | Float[Array, ''], n: float | Float[Array, ''], center_thickness: float | Float[Array, ''], r_ratio: float | Float[Array, ''], convex_first: bool | Bool[Array, ''] | None = True) -> LensParams
Create parameters for a meniscus (concavo-convex) lens.
For a meniscus lens, one surface is convex (positive R) and one is concave (negative R).
- Parameters:
  - focal_length (float) – Desired focal length in meters.
  - diameter (float) – Lens diameter in meters.
  - n (float) – Refractive index of lens material.
  - center_thickness (float) – Center thickness in meters.
  - r_ratio (float) – Absolute ratio of R2/R1.
  - convex_first (scalar_bool, optional) – If True, the first surface is convex, by default True.
- Returns:
params – Lens parameters.
- Return type:
LensParams
Notes
Calculate magnitude of R1 using lensmaker’s equation.
Calculate R2 magnitude using r_ratio.
Assign correct signs based on convex_first.
Create and return LensParams.
- janssen.lenses.plano_concave_lens(focal_length: float | Float[Array, ''], diameter: float | Float[Array, ''], n: float | Float[Array, ''], center_thickness: float | Float[Array, ''], concave_first: bool | Bool[Array, ''] | None = True) -> LensParams
Create parameters for a plano-concave lens.
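- Parameters:
  - focal_length (float) – Desired focal length in meters.
  - diameter (float) – Lens diameter in meters.
  - n (float) – Refractive index of lens material.
  - center_thickness (float) – Center thickness in meters.
  - concave_first (scalar_bool, optional) – If True, the first surface is concave and the second is flat, by default True.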
- Returns:
params – Lens parameters.
- Return type:
LensParams
Notes
Calculate R for curved surface.
Set other R to infinity (flat surface).
Create and return LensParams.
- janssen.lenses.plano_convex_lens(focal_length: float | Float[Array, ''], diameter: float | Float[Array, ''], n: float | Float[Array, ''], center_thickness: float | Float[Array, ''], convex_first: bool | Bool[Array, ''] | None = True) -> LensParams
Create parameters for a plano-convex lens.
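- Parameters:
  - focal_length (float) – Desired focal length in meters.
  - diameter (float) – Lens diameter in meters.
  - n (float) – Refractive index of lens material.
  - center_thickness (float) – Center thickness in meters.
  - convex_first (scalar_bool, optional) – If True, the first surface is convex and the second is flat, by default True.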
- Returns:
params – Lens parameters.
- Return type:
LensParams
Notes
Calculate R for curved surface.
Set other R to infinity (flat surface).
Create and return LensParams.
- janssen.lenses.propagate_through_lens(field: Complex[Array, 'hh ww'], phase_profile: Float[Array, 'hh ww'], transmission: Float[Array, 'hh ww']) -> Complex[Array, 'hh ww']
Propagate a field through a lens.
- Parameters:
  - field (Complex[Array, " hh ww"]) – Input complex field.
  - phase_profile (Float[Array, " hh ww"]) – Phase profile of the lens.
  - transmission (Float[Array, " hh ww"]) – Transmission mask of the lens.
- Returns:
output_field – Field after passing through the lens.
- Return type:
Complex[Array," hh ww"]
Notes
Apply transmission mask.
Add phase profile.
Return modified field.
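The three steps amount to a single element-wise product. A NumPy sketch, assuming the exp(+i phi) sign convention (some texts use exp(-i phi)); `apply_thin_element` is a hypothetical name:

```python
import numpy as np


def apply_thin_element(field, phase_profile, transmission):
    """Apply an amplitude mask and a phase delay to a complex field (sketch)."""
    return field * transmission * np.exp(1j * phase_profile)
```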
- janssen.lenses.angular_spectrum_prop(incoming: OpticalWavefront, z_move: int | float | complex | Num[Array, ''], refractive_index: int | float | complex | Num[Array, ''] | None = 1.0) -> OpticalWavefront
Propagate a complex field using the angular spectrum method.
- Parameters:
  - incoming (OpticalWavefront) – PyTree with the following attributes:
    - field (Complex[Array, " hh ww"]) – Input complex field
    - wavelength (Float[Array, " "]) – Wavelength of light in meters
    - dx (Float[Array, " "]) – Grid spacing in meters
    - z_position (Float[Array, " "]) – Wavefront position in meters
  - z_move (scalar_numeric) – Propagation distance in meters. This is in free space.
  - refractive_index (Optional[scalar_numeric], optional) – Index of refraction of the medium. Default is 1.0 (vacuum).
- Returns:
Propagated wave front
- Return type:
OpticalWavefront
Notes
Algorithm:
Get the shape of the input field
Calculate the wavenumber
Compute the path length
Create spatial frequency coordinates
Compute the squared spatial frequencies
Angular spectrum transfer function
Ensure evanescent waves are properly handled
Fourier transform of the input field
Apply the transfer function in the Fourier domain
Inverse Fourier transform to get the propagated field
Return the propagated field
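The listed steps can be sketched in NumPy on a bare array instead of an OpticalWavefront. The sketch assumes square pixels of size dx, uses wavelength / refractive_index as the wavelength in the medium, and simply zeroes evanescent components, which is one common way of handling them (janssen's choice is not stated). `angular_spectrum_sketch` is a hypothetical name:

```python
import numpy as np


def angular_spectrum_sketch(field, wavelength, dx, z, refractive_index=1.0):
    """Angular spectrum propagation of a 2D complex field by distance z (sketch)."""
    ny, nx = field.shape
    k = 2.0 * np.pi * refractive_index / wavelength  # wavenumber in the medium
    fx = np.fft.fftfreq(nx, d=dx)
    fy = np.fft.fftfreq(ny, d=dx)
    fxx, fyy = np.meshgrid(fx, fy)  # shape (ny, nx)
    # 1 - (lambda_medium * f)^2; non-positive values mark evanescent components
    arg = 1.0 - (wavelength / refractive_index) ** 2 * (fxx**2 + fyy**2)
    propagating = arg > 0
    kz = k * np.sqrt(np.where(propagating, arg, 0.0))
    transfer = np.where(propagating, np.exp(1j * kz * z), 0.0)
    # Filter the angular spectrum and transform back to real space
    return np.fft.ifft2(np.fft.fft2(field) * transfer)
```

A uniform plane wave only acquires the global phase exp(i k z), and z = 0 returns the input field, which makes two convenient checks for any implementation.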
- janssen.lenses.digital_zoom(wavefront: OpticalWavefront, zoom_factor: int | float | complex | Num[Array, '']) -> OpticalWavefront
Zoom an optical wavefront by a specified factor.
The returned wavefront has the same array shape as the original; only the sampled region and the pixel size change.
- Parameters:
  - wavefront (OpticalWavefront) – Incoming optical wavefront.
  - zoom_factor (scalar_numeric) – Zoom factor (greater than 1 to zoom in, less than 1 to zoom out).
- Returns:
zoomed_wavefront – Zoomed optical wavefront of the same spatial dimensions.
- Return type:
OpticalWavefront
Notes
Algorithm:
For zoom in (zoom_factor >= 1.0):
- Calculate the crop fraction (1 / zoom_factor) to determine the central region to extract.
- Create interpolation coordinates for the zoomed region, centered on the image.
- Use scipy.ndimage.map_coordinates with bilinear interpolation to sample the field.
- Return the zoomed field with adjusted pixel size (dx / zoom_factor).
For zoom out (zoom_factor < 1.0):
- Calculate the shrink fraction (zoom_factor) to determine the final image size.
- Create a coordinate mapping from the full image to the shrunken region.
- Use scipy.ndimage.map_coordinates to interpolate the original field.
- Apply a mask to zero out regions outside the shrunken area (padding effect).
- Return the zoomed field with adjusted pixel size (dx / zoom_factor).
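Both branches can be sketched with one coordinate mapping: each output pixel samples the source at a position scaled by 1 / zoom_factor about the array center. The sketch assumes bilinear interpolation (order=1), zero fill outside the source array (which plays the role of the zoom-out mask), and real and imaginary parts interpolated separately; it returns a bare array plus the new pixel size rather than an OpticalWavefront, and `digital_zoom_sketch` is hypothetical:

```python
import numpy as np
from scipy.ndimage import map_coordinates


def digital_zoom_sketch(field, dx, zoom_factor):
    """Resample a complex field about its center; shape is preserved (sketch)."""
    ny, nx = field.shape
    cy, cx = (ny - 1) / 2.0, (nx - 1) / 2.0
    yy, xx = np.meshgrid(np.arange(ny), np.arange(nx), indexing="ij")
    # Output pixel (i, j) reads the source at center + offset / zoom_factor
    coords = np.stack([cy + (yy - cy) / zoom_factor, cx + (xx - cx) / zoom_factor])

    def sample(part):
        # Bilinear interpolation; points outside the source are filled with 0
        return map_coordinates(np.ascontiguousarray(part), coords, order=1, mode="constant", cval=0.0)

    zoomed = sample(field.real) + 1j * sample(field.imag)
    return zoomed, dx / zoom_factor
```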
- janssen.lenses.fraunhofer_prop(incoming: OpticalWavefront, z_move: float | Float[Array, ''], refractive_index: float | Float[Array, ''] | None = 1.0) -> OpticalWavefront
Propagate a complex field using the Fraunhofer approximation.
- Parameters:
  - incoming (OpticalWavefront) – PyTree with the following attributes:
    - field (Complex[Array, " hh ww"]) – Input complex field
    - wavelength (Float[Array, " "]) – Wavelength of light in meters
    - dx (Float[Array, " "]) – Grid spacing in meters
    - z_position (Float[Array, " "]) – Wavefront position in meters
  - z_move (float) – Propagation distance in meters. This is in free space.
  - refractive_index (float, optional) – Index of refraction of the medium. Default is 1.0 (vacuum).
- Returns:
Propagated wave front
- Return type:
OpticalWavefront
Notes
Algorithm:
Get the shape of the input field
Calculate the spatial frequency coordinates
Create the meshgrid of spatial frequencies
Compute the transfer function for Fraunhofer propagation
Compute the Fourier transform of the input field
Apply the transfer function in the Fourier domain
Inverse Fourier transform to get the propagated field
Return the propagated field
- janssen.lenses.fresnel_prop(incoming: OpticalWavefront, z_move: int | float | complex | Num[Array, ''], refractive_index: int | float | complex | Num[Array, ''] | None = 1.0) -> OpticalWavefront
Propagate a complex field using the Fresnel approximation.
- Parameters:
  - incoming (OpticalWavefront) – PyTree with the following attributes:
    - field (Complex[Array, " hh ww"]) – Input complex field
    - wavelength (Float[Array, " "]) – Wavelength of light in meters
    - dx (Float[Array, " "]) – Grid spacing in meters
    - z_position (Float[Array, " "]) – Wavefront position in meters
  - z_move (scalar_numeric) – Propagation distance in meters. This is in free space.
  - refractive_index (Optional[scalar_numeric], optional) – Index of refraction of the medium. Default is 1.0 (vacuum).
- Returns:
Propagated wave front
- Return type:
OpticalWavefront
Notes
Algorithm:
Calculate the wavenumber
Create spatial coordinates
Quadratic phase factor for Fresnel approximation (pre-free-space propagation)
Apply quadratic phase to the input field
Compute Fourier transform of the input field
Compute spatial frequency coordinates
Transfer function for Fresnel propagation
Apply the transfer function in the Fourier domain
Inverse Fourier transform to get the propagated field
Final quadratic phase factor (post-free-space propagation)
Apply final quadratic phase factor
Return the propagated field
- janssen.lenses.lens_propagation(incoming: OpticalWavefront, lens: LensParams) -> OpticalWavefront
Propagate an optical wavefront through a lens.
The lens is modeled as a thin lens with a given focal length and diameter.
- Parameters:
  - incoming (OpticalWavefront) – The incoming optical wavefront.
  - lens (LensParams) – The lens parameters, including focal length and diameter.
- Returns:
The propagated optical wavefront after passing through the lens
- Return type:
OpticalWavefront
Notes
Algorithm:
Create a meshgrid of coordinates based on the incoming wavefront’s shape and pixel size.
Calculate the phase profile and transmission function of the lens.
Apply the phase screen to the incoming wavefront’s field.
Return the new optical wavefront with the updated field, wavelength, and pixel size.
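In the paraxial approximation, an ideal thin lens of focal length f and diameter D is often written as the phase screen P(x, y) exp(-i k (x^2 + y^2) / (2 f)), where P is a circular pupil of diameter D. The NumPy sketch below assumes that textbook model, a centered grid with square pixels, and the exp(-i ...) sign for a converging lens; janssen may instead build the screen from its thickness profile via create_lens_phase. `thin_lens_screen` is a hypothetical name:

```python
import numpy as np


def thin_lens_screen(field, dx, wavelength, focal_length, diameter):
    """Apply an ideal paraxial thin-lens phase screen with a circular pupil (sketch)."""
    ny, nx = field.shape
    # Centered coordinate grid in meters (pixel nx // 2 sits on the axis)
    x = (np.arange(nx) - nx // 2) * dx
    y = (np.arange(ny) - ny // 2) * dx
    xx, yy = np.meshgrid(x, y)
    r2 = xx**2 + yy**2
    k = 2.0 * np.pi / wavelength
    pupil = r2 <= (diameter / 2.0) ** 2
    phase = -k * r2 / (2.0 * focal_length)  # converging for focal_length > 0
    return field * pupil * np.exp(1j * phase)
```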
- janssen.lenses.optical_zoom(wavefront: OpticalWavefront, zoom_factor: int | float | complex | Num[Array, '']) -> OpticalWavefront
Modify the calibration of an optical wavefront without changing its field values.
- Parameters:
  - wavefront (OpticalWavefront) – Incoming optical wavefront.
  - zoom_factor (scalar_numeric) – Zoom factor (greater than 1 to zoom in, less than 1 to zoom out).
- Returns:
Zoomed optical wavefront of the same spatial dimensions.
- Return type:
OpticalWavefront
janssen.simul
Differentiable optical simulation toolkit.
Extended Summary
Comprehensive optical simulation framework for modeling light propagation through various optical elements. All components are differentiable and optimized for JAX transformations, enabling gradient-based optimization of optical systems.
Submodules
- apertures
Aperture functions for optical microscopy
- elements
Optical element transformations
- microscope
Microscopy simulation pipelines
- helper
Helper functions for optical propagation
Routine Listings
- annular_aperture : function
Create an annular (ring-shaped) aperture
- circular_aperture : function
Create a circular aperture
- gaussian_apodizer : function
Apply Gaussian apodization to a field
- gaussian_apodizer_elliptical : function
Apply elliptical Gaussian apodization
- rectangular_aperture : function
Create a rectangular aperture
- supergaussian_apodizer : function
Apply super-Gaussian apodization
- supergaussian_apodizer_elliptical : function
Apply elliptical super-Gaussian apodization
- variable_transmission_aperture : function
Create aperture with variable transmission
- amplitude_grating_binary : function
Create binary amplitude grating
- apply_phase_mask : function
Apply a phase mask to a field
- apply_phase_mask_fn : function
Apply a phase mask function
- beam_splitter : function
Model beam splitter operation
- half_waveplate : function
Half-wave plate transformation
- mirror_reflection : function
Model mirror reflection
- nd_filter : function
Neutral density filter
- phase_grating_blazed_elliptical : function
Elliptical blazed phase grating
- phase_grating_sawtooth : function
Sawtooth phase grating
- phase_grating_sine : function
Sinusoidal phase grating
- polarizer_jones : function
Jones matrix for polarizer
- prism_phase_ramp : function
Phase ramp from prism
- quarter_waveplate : function
Quarter-wave plate transformation
- waveplate_jones : function
General waveplate Jones matrix
- add_phase_screen : function
Add phase screen to field
- create_spatial_grid : function
Create computational spatial grid
- field_intensity : function
Calculate field intensity
- normalize_field : function
Normalize optical field
- scale_pixel : function
Scale pixel size in field
- linear_interaction : function
Linear light-matter interaction
- simple_diffractogram : function
Generate diffraction pattern
- simple_microscope : function
Simple microscope forward model
Notes
All simulation functions support automatic differentiation and can be composed to model complex optical systems. The toolkit is optimized for both forward simulation and inverse problems in optics.
- janssen.simul.annular_aperture(incoming: OpticalWavefront, inner_diameter: float | Float[Array, ''], outer_diameter: float | Float[Array, ''], center: Float[Array, '2'] | None = Array([0., 0.], dtype=float64), transmittivity: float | Float[Array, ''] | None = 1.0) → OpticalWavefront
Apply an annular (ring) aperture with inner and outer diameters.
- Parameters:
incoming (OpticalWavefront) – Input wavefront PyTree.
inner_diameter (float) – Inner blocked diameter in meters.
outer_diameter (float) – Outer clear aperture diameter in meters.
center (Optional[Float[Array," 2"]], optional) – Ring center [x0, y0] in meters, by default [0, 0].
transmittivity (Optional[scalar_float], optional) – Uniform transmittivity in the ring (0..1), by default 1.0.
- Returns:
apertured – Wavefront after applying the annular aperture.
- Return type:
OpticalWavefront
Notes
Build centered (x, y) grids in meters.
Compute radial distance from center.
Create mask for inner_radius < r <= outer_radius.
Multiply by transmittivity (clipped), apply, and return.
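The Notes above can be sketched as a standalone mask builder in plain `jax.numpy`. This is a minimal illustration under stated assumptions, not the library's code. `annular_mask` is a hypothetical helper, and the grid convention (square pixels of pitch `dx`, with pixel index `n // 2` at the origin) is assumed.

```python
import jax.numpy as jnp


def annular_mask(n, dx, inner_diameter, outer_diameter,
                 center=(0.0, 0.0), transmittivity=1.0):
    """Hypothetical helper: annular transmission mask on an n x n grid."""
    # Assumed centered grid in meters: index n // 2 sits at x = 0 (and y = 0).
    coords = (jnp.arange(n) - n // 2) * dx
    xx, yy = jnp.meshgrid(coords, coords)  # xx varies along columns, yy along rows
    # Radial distance from the requested ring center.
    r = jnp.hypot(xx - center[0], yy - center[1])
    # Open only where inner_radius < r <= outer_radius.
    inside = (r > inner_diameter / 2) & (r <= outer_diameter / 2)
    return inside * jnp.clip(transmittivity, 0.0, 1.0)


mask = annular_mask(n=64, dx=1e-6, inner_diameter=10e-6, outer_diameter=40e-6)
field = jnp.ones((64, 64), dtype=jnp.complex64)
apertured = field * mask  # blocks r <= 5 um and everything beyond r = 20 um
```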
- janssen.simul.circular_aperture(incoming: OpticalWavefront, diameter: float | Float[Array, ''], center: Float[Array, '2'] | None = Array([0., 0.], dtype=float64), transmittivity: float | Float[Array, ''] | None = 1.0) → OpticalWavefront
Apply a circular aperture to the incoming wavefront.
The aperture is defined by its physical diameter and (optional) center.
- Parameters:
incoming (OpticalWavefront) – Input wavefront PyTree.
diameter (float) – Aperture diameter in meters.
center (Optional[Float[Array," 2"]], optional) – Physical center [x0, y0] of the aperture in meters, by default [0, 0].
transmittivity (Optional[scalar_float], optional) – Uniform transmittivity inside the aperture (0..1), by default 1.0.
- Returns:
apertured – Wavefront after applying the circular aperture.
- Return type:
OpticalWavefront
Notes
Build centered (x, y) grids in meters.
Compute radial distance from the specified center.
Create a binary mask for r <= diameter/2.
Multiply by transmittivity (clipped to [0, 1]).
Apply to the complex field and return.
- janssen.simul.gaussian_apodizer(incoming: OpticalWavefront, sigma: float | Float[Array, ''], center: Float[Array, '2'] | None = Array([0., 0.], dtype=float64), peak_transmittivity: float | Float[Array, ''] | None = 1.0) → OpticalWavefront
Apply a Gaussian apodizer (smooth transmission mask) to the wavefront.
- Parameters:
incoming (OpticalWavefront) – Input optical wavefront.
sigma (float) – Gaussian width parameter in meters.
center (Optional[Float[Array," 2"]], optional) – Physical center [x0, y0] of the Gaussian in meters, by default [0, 0].
peak_transmittivity (Optional[scalar_float], optional) – Maximum transmission at the Gaussian center, by default 1.0.
- Returns:
apertured – Wavefront after applying Gaussian apodization.
- Return type:
OpticalWavefront
Notes
Build centered (x, y) grids.
Compute squared radial distance from center.
Evaluate Gaussian exp(-r^2 / (2*sigma^2)).
Scale by peak transmittivity, clip to [0,1].
Multiply with incoming field and return.
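Here is a minimal sketch of the Gaussian profile from the Notes. It assumes the same centered-grid convention (pixel `n // 2` at the origin). `gaussian_apodizer_mask` is a hypothetical helper that returns only the real transmission mask, which is then multiplied into the field.

```python
import jax.numpy as jnp


def gaussian_apodizer_mask(n, dx, sigma, center=(0.0, 0.0), peak_transmittivity=1.0):
    """Hypothetical helper: Gaussian transmission profile on an n x n grid."""
    coords = (jnp.arange(n) - n // 2) * dx  # assumed centered grid in meters
    xx, yy = jnp.meshgrid(coords, coords)
    r2 = (xx - center[0]) ** 2 + (yy - center[1]) ** 2  # squared radial distance
    profile = peak_transmittivity * jnp.exp(-r2 / (2.0 * sigma**2))
    return jnp.clip(profile, 0.0, 1.0)  # keep transmission physical


mask = gaussian_apodizer_mask(n=64, dx=1e-6, sigma=5e-6)
apodized = jnp.ones((64, 64), dtype=jnp.complex64) * mask
```

At a radius equal to `sigma`, the transmission falls to exp(-0.5), about 0.61 of the peak.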
- janssen.simul.gaussian_apodizer_elliptical(incoming: OpticalWavefront, sigma_x: float | Float[Array, ''], sigma_y: float | Float[Array, ''], theta: float | Float[Array, ''] | None = 0.0, center: Float[Array, '2'] | None = Array([0., 0.], dtype=float64), peak_transmittivity: float | Float[Array, ''] | None = 1.0) → OpticalWavefront
Apply an elliptical Gaussian apodizer to the wavefront, with optional rotation by an angle theta.
- Parameters:
incoming (OpticalWavefront) – Input optical wavefront.
sigma_x (float) – Gaussian width along the x’-axis (meters) after rotation by theta.
sigma_y (float) – Gaussian width along the y’-axis (meters) after rotation by theta.
theta (Optional[scalar_float], optional) – Rotation angle in radians (counter-clockwise), by default 0.0.
center (Optional[Float[Array," 2"]], optional) – Physical center [x0, y0] in meters, by default [0, 0].
peak_transmittivity (Optional[scalar_float], optional) – Maximum transmission at the center, by default 1.0.
- Returns:
apertured – Wavefront after applying elliptical Gaussian apodization.
- Return type:
OpticalWavefront
Notes
Build centered (x, y) grids.
Translate by center, rotate by theta → (x’, y’).
Evaluate exp(-0.5 * ( (x’/sigma_x)^2 + (y’/sigma_y)^2 )).
Scale by peak_transmittivity, clip to [0, 1].
Multiply with incoming field and return.
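The translate-then-rotate step is the only new ingredient relative to the circular Gaussian. The sketch below assumes the coordinate rotation x’ = cos(theta) x + sin(theta) y, y’ = -sin(theta) x + cos(theta) y. Under this convention the x’ axis points at angle theta, counter-clockwise from x. The library's exact sign convention is not stated here, and the helper names are hypothetical.

```python
import jax.numpy as jnp


def rotated_coords(xx, yy, theta, center=(0.0, 0.0)):
    """Translate by center, then express (x, y) in axes rotated CCW by theta."""
    x_c, y_c = xx - center[0], yy - center[1]
    x_rot = jnp.cos(theta) * x_c + jnp.sin(theta) * y_c
    y_rot = -jnp.sin(theta) * x_c + jnp.cos(theta) * y_c
    return x_rot, y_rot


def elliptical_gaussian_mask(xx, yy, sigma_x, sigma_y, theta=0.0,
                             center=(0.0, 0.0), peak_transmittivity=1.0):
    """Hypothetical helper: exp(-0.5 * ((x'/sigma_x)^2 + (y'/sigma_y)^2)), clipped."""
    x_rot, y_rot = rotated_coords(xx, yy, theta, center)
    profile = jnp.exp(-0.5 * ((x_rot / sigma_x) ** 2 + (y_rot / sigma_y) ** 2))
    return jnp.clip(peak_transmittivity * profile, 0.0, 1.0)


coords = (jnp.arange(64) - 32) * 1e-6  # assumed centered grid, 1 um pixels
xx, yy = jnp.meshgrid(coords, coords)
# Rotating by pi/2 makes the wide axis (sigma_x) point along y.
mask = elliptical_gaussian_mask(xx, yy, sigma_x=8e-6, sigma_y=2e-6, theta=jnp.pi / 2)
```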
- janssen.simul.rectangular_aperture(incoming: OpticalWavefront, width: float | Float[Array, ''], height: float | Float[Array, ''], center: Float[Array, '2'] | None = Array([0., 0.], dtype=float64), transmittivity: float | Float[Array, ''] | None = 1.0) → OpticalWavefront
Apply an axis-aligned rectangular aperture to the incoming wavefront.
- Parameters:
incoming (OpticalWavefront) – Input wavefront PyTree.
width (float) – Rectangle width along x in meters.
height (float) – Rectangle height along y in meters.
center (Optional[Float[Array," 2"]], optional) – Rectangle center [x0, y0] in meters, by default [0, 0].
transmittivity (Optional[scalar_float], optional) – Uniform transmittivity inside the rectangle (0..1), by default 1.0.
- Returns:
apertured – Wavefront after applying the rectangular aperture.
- Return type:
OpticalWavefront
Notes
Build centered (x, y) grids in meters.
Compute half-width/half-height and an inside-rectangle mask.
Multiply by transmittivity (clipped).
Apply to the complex field and return.
- janssen.simul.supergaussian_apodizer(incoming: OpticalWavefront, sigma: float | Float[Array, ''], m: int | float | complex | Num[Array, ''], center: Float[Array, '2'] | None = Array([0., 0.], dtype=float64), peak_transmittivity: float | Float[Array, ''] | None = 1.0) → OpticalWavefront
Apply a super-Gaussian apodizer to the wavefront.
Transmission profile: exp(- (r^2 / sigma^2)^m ).
- Parameters:
incoming (OpticalWavefront) – Input optical wavefront.
sigma (float) – Width parameter in meters (sets the roll-off scale).
m (scalar_numeric) – Super-Gaussian order (m=1 → Gaussian, m>1 → flatter top).
center (Optional[Float[Array," 2"]], optional) – Physical center [x0, y0] of the profile in meters, by default [0, 0].
peak_transmittivity (Optional[scalar_float], optional) – Maximum transmission at the center, by default 1.0.
- Returns:
apertured – Wavefront after applying super-Gaussian apodization.
- Return type:
OpticalWavefront
Notes
Build centered (x, y) grids.
Compute squared radial distance from center.
Evaluate exp(- (r^2 / sigma^2)^m ).
Scale by peak transmittivity, clip to [0,1].
Multiply with incoming field and return.
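The effect of the order `m` is easiest to see on a 1D radial cut. This sketch evaluates the documented profile exp(-(r^2 / sigma^2)^m) using a hypothetical helper, `supergaussian_profile`, with radii in meters. For m = 1 this profile is exp(-r^2 / sigma^2). Its `sigma` therefore corresponds to √2 times the `sigma` of `gaussian_apodizer`, whose profile is exp(-r^2 / (2*sigma^2)).

```python
import jax.numpy as jnp


def supergaussian_profile(r2, sigma, m, peak_transmittivity=1.0):
    """Hypothetical helper: exp(-(r^2 / sigma^2)^m), scaled and clipped to [0, 1]."""
    profile = peak_transmittivity * jnp.exp(-((r2 / sigma**2) ** m))
    return jnp.clip(profile, 0.0, 1.0)


sigma = 8e-6
r = jnp.array([0.0, 4e-6, 8e-6, 12e-6])  # radial samples in meters
gaussian_like = supergaussian_profile(r**2, sigma, m=1)
flat_top = supergaussian_profile(r**2, sigma, m=4)
```

Every order passes through exp(-1) at r = sigma. Larger m keeps the center closer to full transmission and makes the roll-off steeper.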
- janssen.simul.supergaussian_apodizer_elliptical(incoming: OpticalWavefront, sigma_x: float | Float[Array, ''], sigma_y: float | Float[Array, ''], m: int | float | complex | Num[Array, ''], theta: float | Float[Array, ''] | None = 0.0, center: Float[Array, '2'] | None = Array([0., 0.], dtype=float64), peak_transmittivity: float | Float[Array, ''] | None = 1.0) → OpticalWavefront
Apply an elliptical super-Gaussian apodizer with optional rotation.
Transmission profile: exp( - ( (x’/sigma_x)^2 + (y’/sigma_y)^2 )^m ).
- Parameters:
incoming (OpticalWavefront) – Input optical wavefront.
sigma_x (float) – Width along x’ (meters) after rotation by theta.
sigma_y (float) – Width along y’ (meters) after rotation by theta.
m (scalar_numeric) – Super-Gaussian order (m=1 → Gaussian; m>1 → flatter top, sharper edges).
theta (Optional[scalar_float], optional) – Rotation angle in radians (counter-clockwise), by default 0.0.
center (Optional[Float[Array," 2"]], optional) – Physical center [x0, y0] in meters, by default [0, 0].
peak_transmittivity (Optional[scalar_float], optional) – Maximum transmission at the center, by default 1.0.
- Returns:
apertured – Wavefront after applying elliptical super-Gaussian apodization.
- Return type:
OpticalWavefront
Notes
Build centered (x, y) grids.
Translate by center, rotate by theta → (x’, y’).
Evaluate exp( - ( (x’/sigma_x)^2 + (y’/sigma_y)^2 )^m ).
Scale by peak_transmittivity, clip to [0, 1].
Multiply with incoming field and return.
- janssen.simul.variable_transmission_aperture(incoming: OpticalWavefront, transmission: float | Float[Array, ''] | Float[Array, '...']) → OpticalWavefront
Apply an arbitrary (spatially varying) transmission to the wavefront.
- Parameters:
incoming (OpticalWavefront) – Input wavefront PyTree.
transmission (Union[scalar_float,Float[Array," H W"]]) – Precomputed transmission map (0..1) with shape “H W”, or a scalar attenuation factor for uniform transmission.
- Returns:
apertured – Wavefront after applying the transmission.
- Return type:
OpticalWavefront
Examples
Uniform attenuation:
>>> wf2 = variable_transmission_aperture(wf, 0.5) # 50% transmission
Spatially varying transmission:
>>> tmap = create_transmission_map(...)  # Shape (H, W)
>>> wf2 = variable_transmission_aperture(wf, tmap)
Notes
For scalar transmission: applies uniform attenuation.
For array transmission: applies spatially varying transmission map.
Transmission values are clipped to [0, 1].
This function is fully JAX-compatible and uses jax.lax.cond.
- janssen.simul.amplitude_grating_binary(incoming: OpticalWavefront, period: float | Float[Array, ''], duty_cycle: float | Float[Array, ''] | None = 0.5, theta: float | Float[Array, ''] | None = 0.0, trans_high: float | Float[Array, ''] | None = 1.0, trans_low: float | Float[Array, ''] | None = 0.0) → OpticalWavefront
Binary amplitude grating with given duty cycle.
- Parameters:
incoming (OpticalWavefront) – Input field.
period (float) – Period in meters.
duty_cycle (float, optional) – Fraction of period in ‘high’ state (0..1), by default 0.5.
theta (float, optional) – Orientation (radians), by default 0.0.
trans_high (float, optional) – Amplitude transmittance for ‘high’ bars, by default 1.0.
trans_low (float, optional) – Amplitude transmittance for ‘low’ bars, by default 0.0.
- Returns:
Field after amplitude modulation.
- Return type:
OpticalWavefront
Notes
Compute u along grating direction.
Map u modulo period → binary mask via duty cycle.
Apply amplitude levels to field.
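Here is a minimal sketch of the three Notes steps. It assumes u = x cos(theta) + y sin(theta), and that the 'high' bars occupy the first `duty_cycle` fraction of each period starting at u = 0. Both are illustrative choices, and `binary_grating_mask` is a hypothetical name.

```python
import jax.numpy as jnp


def binary_grating_mask(xx, yy, period, duty_cycle=0.5, theta=0.0,
                        trans_high=1.0, trans_low=0.0):
    """Hypothetical helper: binary amplitude transmission of a rotated grating."""
    # Coordinate along the grating vector (bars run perpendicular to it).
    u = xx * jnp.cos(theta) + yy * jnp.sin(theta)
    # Fractional position inside each period; jnp.mod stays non-negative for u < 0.
    frac = jnp.mod(u, period) / period
    return jnp.where(frac < duty_cycle, trans_high, trans_low)


coords = (jnp.arange(64) - 32) * 1e-6  # assumed centered grid, 1 um pixels
xx, yy = jnp.meshgrid(coords, coords)
mask = binary_grating_mask(xx, yy, period=10e-6)
field_out = jnp.ones((64, 64), dtype=jnp.complex64) * mask
```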
- janssen.simul.apply_phase_mask(incoming: OpticalWavefront, phase_map: Float[Array, 'H W']) → OpticalWavefront
Apply an arbitrary phase mask (e.g., SLM, turbulence screen).
Field_out = field_in * exp(i * phase_map).
- Parameters:
incoming (OpticalWavefront) – Input field.
phase_map (Float[Array," H W"]) – Phase in radians, same spatial shape as field.
- Returns:
Field with added phase.
- Return type:
OpticalWavefront
- janssen.simul.apply_phase_mask_fn(incoming: OpticalWavefront, phase_fn: Callable[[Float[Array, 'H W'], Float[Array, 'H W']], Float[Array, 'H W']]) → OpticalWavefront
Build and apply a phase mask from a callable phase_fn(xx, yy).
- Parameters:
incoming (OpticalWavefront) – Input field.
phase_fn (callable) – Function producing a phase map (radians) given centered grids xx, yy (meters).
- Returns:
Field with added phase.
- Return type:
OpticalWavefront
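Because `phase_fn` receives centered coordinate grids, any analytic phase can be plugged in. The sketch below mirrors that pattern with a hypothetical helper, `apply_phase_from_fn`, and an illustrative paraxial thin-lens phase -k (x^2 + y^2) / (2f). The centered-grid convention and the lens parameters are assumptions made for the example.

```python
import jax.numpy as jnp


def apply_phase_from_fn(field, dx, phase_fn):
    """Hypothetical helper: evaluate phase_fn on centered grids and apply it."""
    n_rows, n_cols = field.shape
    # Assumed centered grids in meters (index n // 2 on the optical axis).
    x = (jnp.arange(n_cols) - n_cols // 2) * dx
    y = (jnp.arange(n_rows) - n_rows // 2) * dx
    xx, yy = jnp.meshgrid(x, y)
    return field * jnp.exp(1j * phase_fn(xx, yy))


wavelength = 633e-9
focal_length = 0.1
k = 2 * jnp.pi / wavelength


def thin_lens_phase(xx, yy):
    """Illustrative paraxial thin-lens phase -k (x^2 + y^2) / (2 f), in radians."""
    return -k * (xx**2 + yy**2) / (2 * focal_length)


field = jnp.ones((64, 64), dtype=jnp.complex64)
lensed = apply_phase_from_fn(field, dx=5e-6, phase_fn=thin_lens_phase)
```

The amplitude is untouched and only the phase changes, which is what distinguishes a phase mask from an aperture.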
- janssen.simul.beam_splitter(incoming: OpticalWavefront, t2: float | Float[Array, ''] | None = 0.5, r2: float | Float[Array, ''] | None = 0.5, normalize: bool | None = True) → tuple[OpticalWavefront, OpticalWavefront]
Split an input field into transmitted and reflected components.
- Parameters:
incoming (OpticalWavefront) – Input wavefront (scalar field).
t2 (float, optional) – Transmission amplitude coefficient t, by default 0.5.
r2 (float, optional) – Reflection amplitude coefficient r, by default 0.5. With normalize=True, the defaults give a 50/50 split.
normalize (bool, optional) – If True, scale (t, r) so that |t|^2 + |r|^2 = 1, by default True.
- Returns:
wf_T (OpticalWavefront) – Transmitted arm (t * field).
wf_R (OpticalWavefront) – Reflected arm (r * field).
- Return type:
tuple[OpticalWavefront, OpticalWavefront]
Notes
Optionally renormalize (t, r).
Multiply field by t and r.
Return two wavefronts sharing same metadata.
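The normalization rule is easy to check numerically. This sketch treats `t` and `r` as amplitude coefficients, an assumption consistent with the Notes, which multiply the field by t and r. `split_beam` is a hypothetical helper that works on a bare complex array rather than an OpticalWavefront.

```python
import jax.numpy as jnp


def split_beam(field, t=0.5, r=0.5, normalize=True):
    """Hypothetical helper: split a complex field with amplitude coefficients t, r."""
    if normalize:
        # Rescale so that |t|^2 + |r|^2 = 1 (lossless splitter).
        norm = jnp.sqrt(jnp.abs(t) ** 2 + jnp.abs(r) ** 2)
        t, r = t / norm, r / norm
    return t * field, r * field


field = jnp.ones((8, 8), dtype=jnp.complex64)
transmitted, reflected = split_beam(field)  # 0.5, 0.5 -> 1/sqrt(2) each
power_in = jnp.sum(jnp.abs(field) ** 2)
power_out = jnp.sum(jnp.abs(transmitted) ** 2) + jnp.sum(jnp.abs(reflected) ** 2)
```

Passing a complex coefficient, for example `r=1j`, adds a quarter-wave phase to the reflected arm. This is a common lossless convention. The float-typed signature does not say whether `beam_splitter` itself accepts complex `t2`/`r2`.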
- janssen.simul.half_waveplate(incoming: OpticalWavefront, theta: float | Float[Array, ''] | None = 0.0) → OpticalWavefront
Apply a half-wave plate (δ = π) with fast-axis angle theta.
- Parameters:
incoming (OpticalWavefront) – Vector field Complex[H, W, 2] (Jones: ex, ey).
theta (float, optional) – Fast-axis angle in radians (CCW from x), by default 0.0.
- Returns:
hw_wavefront – Retarded field after half-wave plate.
- Return type:
OpticalWavefront
Notes
Call waveplate_jones with delta = π.
- janssen.simul.mirror_reflection(incoming: OpticalWavefront, flip_x: bool | None = True, flip_y: bool | None = False, add_pi_phase: bool | None = True, conjugate: bool | None = True) → OpticalWavefront
Mirror reflection: coordinate flips with optional π-phase and conjugation.
- Parameters:
incoming (OpticalWavefront) – Input wavefront.
flip_x (bool, optional) – Flip along x-axis (columns), by default True.
flip_y (bool, optional) – Flip along y-axis (rows), by default False.
add_pi_phase (bool, optional) – Multiply by exp(i*pi) = -1 to simulate phase inversion on reflection, by default True.
conjugate (bool, optional) – Conjugate the complex field, useful when reversing propagation direction, by default True.
- Returns:
Reflected wavefront.
- Return type:
OpticalWavefront
Notes
Flip axes as requested (jnp.flip).
Optional complex conjugation.
Optional -1 factor for π phase.
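Here is a minimal sketch of the three Notes steps on a bare complex array. `mirror_sketch` is a hypothetical helper, and x is assumed to index columns, as in the parameter descriptions. With the default flags, the operation is an involution: applying it twice returns the original field.

```python
import jax.numpy as jnp


def mirror_sketch(field, flip_x=True, flip_y=False, add_pi_phase=True, conjugate=True):
    """Hypothetical helper: coordinate flips, optional conjugation and pi phase."""
    out = field
    if flip_x:
        out = jnp.flip(out, axis=1)  # x runs along columns
    if flip_y:
        out = jnp.flip(out, axis=0)  # y runs along rows
    if conjugate:
        out = jnp.conj(out)
    if add_pi_phase:
        out = -out  # exp(i*pi) = -1
    return out


field = jnp.arange(6).reshape(2, 3) * (1 + 1j)
reflected = mirror_sketch(field)
```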
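- janssen.simul.nd_filter(incoming: OpticalWavefront, optical_density: float | Float[Array, ''] | None = 0.0, transmittance: float | Float[Array, ''] | None = -1.0) → OpticalWavefront
Neutral density (ND) filter as a uniform amplitude attenuator.
- Parameters:
incoming (OpticalWavefront) – Input wavefront.
optical_density (float, optional) – Optical density (OD) of the filter, by default 0.0.
transmittance (float, optional) – Intensity transmittance T, used instead of optical_density when provided, by default -1.0.
- Returns: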
nd_wavefront – Attenuated wavefront.
- Return type:
OpticalWavefront
Notes
Determine intensity T from OD or provided T.
Amplitude factor a = sqrt(T).
Multiply field by a and return.
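Optical density is defined through the intensity transmittance, T = 10^(-OD), and the field amplitude scales as sqrt(T). The sketch assumes that a negative `transmittance` (the signature default is -1.0) means "derive T from the optical density". That reading of the sentinel, and the name `nd_attenuate`, are assumptions.

```python
import jax.numpy as jnp


def nd_attenuate(field, optical_density=0.0, transmittance=-1.0):
    """Hypothetical helper: uniform ND attenuation of a complex field."""
    # Assumption: a negative transmittance means "use T = 10^(-OD)".
    t_intensity = jnp.where(transmittance >= 0.0, transmittance,
                            10.0 ** (-optical_density))
    amplitude = jnp.sqrt(jnp.clip(t_intensity, 0.0, 1.0))  # a = sqrt(T)
    return field * amplitude


field = jnp.ones((4, 4), dtype=jnp.complex64)
od_one = nd_attenuate(field, optical_density=1.0)  # keeps 10% of the power
quarter_power = nd_attenuate(field, transmittance=0.25)  # amplitude 0.5
```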
- janssen.simul.phase_grating_blazed_elliptical(incoming: OpticalWavefront, period_x: float | Float[Array, ''], period_y: float | Float[Array, ''], theta: float | Float[Array, ''] | None = 0.0, depth: float | Float[Array, ''] | None = 6.283185307179586, two_dim: bool | None = False) → OpticalWavefront
Orientation-aware elliptical blazed grating.
Supports anisotropic periods along rotated axes (x’, y’) and optional 2D blaze.
- Parameters:
incoming (OpticalWavefront) – Input scalar wavefront.
period_x (float) – Blaze period along x’ in meters (after rotation by theta).
period_y (float) – Blaze period along y’ in meters (after rotation by theta).
theta (float, optional) – Grating orientation angle in radians (CCW from x), by default 0.0.
depth (float, optional) – Peak-to-peak phase depth in radians, by default 2π.
two_dim (bool, optional) – If False (default), apply a 1D blaze along x’ only. If True, create a 2D blazed lattice using both x’ and y’.
- Returns:
phase_grating_wavefront – Field after applying the elliptical blazed phase.
- Return type:
OpticalWavefront
Notes
Build centered grids xx, yy (meters) and rotate → (x’, y’).
- Compute fractional coordinates fu = frac(x’/period_x) and fv = frac(y’/period_y).
- If two_dim is True, phase = depth * frac(fu + fv).
- Otherwise, phase = depth * fu.
Multiply by exp(i * phase) and return.
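The fractional-coordinate construction can be written directly with `jnp.mod(value, 1.0)` standing in for frac(value). This sketch assumes the rotation convention x’ = x cos(theta) + y sin(theta), y’ = -x sin(theta) + y cos(theta). `blazed_phase` is a hypothetical helper that returns the phase map only.

```python
import jax.numpy as jnp


def blazed_phase(xx, yy, period_x, period_y, theta=0.0,
                 depth=2 * jnp.pi, two_dim=False):
    """Hypothetical helper: phase map of an elliptical blazed grating."""
    # Rotate the centered grid into the grating axes (x', y').
    x_rot = xx * jnp.cos(theta) + yy * jnp.sin(theta)
    y_rot = -xx * jnp.sin(theta) + yy * jnp.cos(theta)
    # frac(.) as mod 1, which stays in [0, 1) for negative inputs.
    fu = jnp.mod(x_rot / period_x, 1.0)
    fv = jnp.mod(y_rot / period_y, 1.0)
    return depth * (jnp.mod(fu + fv, 1.0) if two_dim else fu)


coords = (jnp.arange(64) - 32) * 1e-6  # assumed centered grid, 1 um pixels
xx, yy = jnp.meshgrid(coords, coords)
phase_1d = blazed_phase(xx, yy, period_x=8e-6, period_y=8e-6)
phase_2d = blazed_phase(xx, yy, period_x=8e-6, period_y=8e-6, two_dim=True)
field_out = jnp.ones((64, 64), dtype=jnp.complex64) * jnp.exp(1j * phase_1d)
```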
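- janssen.simul.phase_grating_sawtooth(incoming: OpticalWavefront, period: float | Float[Array, ''], depth: float | Float[Array, ''], theta: float | Float[Array, ''] = 0.0) → OpticalWavefront
Sawtooth phase grating with peak-to-peak depth (radians).
- Parameters:
incoming (OpticalWavefront) – Input scalar wavefront.
period (float) – Grating period in meters.
depth (float) – Peak-to-peak phase depth in radians.
theta (float, optional) – Grating orientation angle in radians (CCW from x), by default 0.0.
- Returns: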
Field after blazed phase modulation.
- Return type:
OpticalWavefront
Notes
Compute fractional coordinate within each period.
Form the sawtooth phase in [0, depth); it is not shifted to zero mean.
Apply phase with exp(i*phase).
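- janssen.simul.phase_grating_sine(incoming: OpticalWavefront, period: float | Float[Array, ''], depth: float | Float[Array, ''], theta: float | Float[Array, ''] | None = 0.0) → OpticalWavefront
Sinusoidal phase grating.
Phase = depth * sin(2π * u / period), where u is the coordinate along the grating direction.
- Parameters:
incoming (OpticalWavefront) – Input scalar wavefront.
period (float) – Grating period in meters.
depth (float) – Phase modulation amplitude in radians (the depth factor in the formula above).
theta (float, optional) – Grating orientation angle in radians (CCW from x), by default 0.0.
- Returns: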
Field after phase modulation.
- Return type:
OpticalWavefront
- janssen.simul.polarizer_jones(incoming: OpticalWavefront, theta: float | Float[Array, ''] = 0.0) → OpticalWavefront
Linear polarizer at angle theta (radians, CCW from x-axis).
Applied to a 2-component Jones field (ex, ey) stored in the last dimension.
- Parameters:
incoming (OpticalWavefront) – Input vector field; shape must be Complex[H, W, 2].
theta (float, optional) – Transmission axis angle (radians), by default 0.0.
- Returns:
Polarized field with same shape.
- Return type:
OpticalWavefront
Notes
Jones matrix: P = R(-θ) @ [[1, 0],[0, 0]] @ R(θ).
Apply P to [ex, ey] at each pixel.
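The Jones-matrix expression above can be checked with a small sketch, assuming R(θ) = [[cos θ, sin θ], [-sin θ, cos θ]]. Under that convention, P reduces to the projector [[cos²θ, cos θ sin θ], [cos θ sin θ, sin²θ]] onto the transmission axis. `rotation` and `polarize` are hypothetical helper names.

```python
import jax.numpy as jnp


def rotation(theta):
    """R(theta) = [[cos, sin], [-sin, cos]] (assumed convention)."""
    c, s = jnp.cos(theta), jnp.sin(theta)
    return jnp.array([[c, s], [-s, c]])


def polarize(field, theta=0.0):
    """Hypothetical helper: apply P = R(-theta) @ diag(1, 0) @ R(theta) per pixel."""
    projector = rotation(-theta) @ jnp.diag(jnp.array([1.0, 0.0])) @ rotation(theta)
    # field has shape (H, W, 2); contract the Jones index at every pixel.
    return jnp.einsum("ij,hwj->hwi", projector, field)


x_pol = jnp.zeros((4, 4, 2), dtype=jnp.complex64).at[..., 0].set(1.0)
at_45 = polarize(x_pol, theta=jnp.pi / 4)  # Malus's law: half the intensity passes
at_90 = polarize(x_pol, theta=jnp.pi / 2)  # crossed polarizer blocks the field
```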
- janssen.simul.prism_phase_ramp(incoming: OpticalWavefront, deflect_x: float | Float[Array, ''] | None = 0.0, deflect_y: float | Float[Array, ''] | None = 0.0, use_small_angle: bool | None = True) → OpticalWavefront
Apply a linear phase ramp to simulate a prism-induced beam deviation.
- Parameters:
incoming (OpticalWavefront) – Input scalar wavefront.
deflect_x (float, optional) – Deflection along +x. If use_small_angle is True, interpreted as angle (rad). Otherwise interpreted as spatial frequency kx [rad/m], by default 0.0.
deflect_y (float, optional) – Deflection along +y (angle or ky), by default 0.0.
use_small_angle (bool, optional) – If True, convert small angles to kx, ky via k*sin(angle) ~ k*angle, by default True.
- Returns:
Wavefront with added linear phase.
- Return type:
OpticalWavefront
Notes
Build xx, yy grids (m).
Compute kx, ky from deflections.
Phase = kx*xx + ky*yy; multiply by exp(i*phase).
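The small-angle conversion kx ≈ k * angle means the phase advances by kx * dx from one pixel to the next. The sketch below checks exactly that, using a hypothetical helper, `prism_ramp`, and assuming centered grids. For a 1 mrad tilt at 500 nm with 5 µm pixels, the step is 2π * 0.01 ≈ 0.0628 rad.

```python
import jax.numpy as jnp


def prism_ramp(field, dx, wavelength, deflect_x=0.0, deflect_y=0.0,
               use_small_angle=True):
    """Hypothetical helper: multiply a field by exp(i (kx*x + ky*y))."""
    n_rows, n_cols = field.shape
    xx, yy = jnp.meshgrid((jnp.arange(n_cols) - n_cols // 2) * dx,
                          (jnp.arange(n_rows) - n_rows // 2) * dx)
    k = 2 * jnp.pi / wavelength
    if use_small_angle:
        kx, ky = k * deflect_x, k * deflect_y  # k*sin(a) ~ k*a for small a
    else:
        kx, ky = deflect_x, deflect_y  # already spatial frequencies in rad/m
    return field * jnp.exp(1j * (kx * xx + ky * yy))


field = jnp.ones((64, 64), dtype=jnp.complex64)
tilted = prism_ramp(field, dx=5e-6, wavelength=500e-9, deflect_x=1e-3)
step = jnp.angle(tilted[32, 33] / tilted[32, 32])  # phase advance per pixel
```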
- janssen.simul.quarter_waveplate(incoming: OpticalWavefront, theta: float | Float[Array, ''] | None = 0.0) → OpticalWavefront
Apply a quarter-wave plate (δ = π/2) with fast-axis angle theta.
- Parameters:
incoming (OpticalWavefront) – Vector field Complex[H, W, 2] (Jones: ex, ey).
theta (float, optional) – Fast-axis angle in radians (CCW from x), by default 0.0.
- Returns:
qw_wavefront – Retarded field after quarter-wave plate.
- Return type:
OpticalWavefront
Notes
Call waveplate_jones with delta = π/2.
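- janssen.simul.waveplate_jones(incoming: OpticalWavefront, delta: float | Float[Array, ''], theta: float | Float[Array, ''] = 0.0) → OpticalWavefront
Waveplate/retarder with retardance delta and fast-axis angle theta.
Special cases: quarter-wave (delta=π/2), half-wave (delta=π).
- Parameters:
incoming (OpticalWavefront) – Vector field Complex[H, W, 2] (Jones: ex, ey).
delta (float) – Retardance in radians.
theta (float, optional) – Fast-axis angle in radians (CCW from x), by default 0.0.
- Returns: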
jones_wavefront – Retarded field with same shape.
- Return type:
OpticalWavefront
Notes
Jones matrix: J = R(-θ) @ diag(1, e^{iδ}) @ R(θ).
Apply J to [ex, ey] per pixel.
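Here is a minimal sketch of the general retarder, assuming R(θ) = [[cos θ, sin θ], [-sin θ, cos θ]]. `waveplate` is a hypothetical helper. It reproduces the two special cases:

- A quarter-wave plate at 45° turns x-polarized light into circular polarization (equal magnitudes, 90° relative phase).
- A half-wave plate at 45° rotates x into y.

```python
import jax.numpy as jnp


def waveplate(field, delta, theta=0.0):
    """Hypothetical helper: J = R(-theta) @ diag(1, exp(i delta)) @ R(theta) per pixel."""
    c, s = jnp.cos(theta), jnp.sin(theta)
    rot = jnp.array([[c, s], [-s, c]])       # R(theta)
    rot_back = jnp.array([[c, -s], [s, c]])  # R(-theta)
    retarder = jnp.diag(jnp.array([1.0, jnp.exp(1j * delta)]))
    jones = rot_back @ retarder @ rot
    # field has shape (H, W, 2); apply the 2x2 matrix at every pixel.
    return jnp.einsum("ij,hwj->hwi", jones, field)


x_pol = jnp.zeros((4, 4, 2), dtype=jnp.complex64).at[..., 0].set(1.0)
qwp = waveplate(x_pol, delta=jnp.pi / 2, theta=jnp.pi / 4)  # linear -> circular
hwp = waveplate(x_pol, delta=jnp.pi, theta=jnp.pi / 4)      # x -> y
```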
- janssen.simul.add_phase_screen(field: Num[Array, 'hh ww'], phase: Float[Array, 'hh ww']) → Complex[Array, 'H W']
Add a phase screen to a complex field.
- Parameters:
field (Num[Array," hh ww"]) – Input complex field.
phase (Float[Array," hh ww"]) – Phase screen to add.
- Returns:
screened_field – Field with phase screen added.
- Return type:
Complex[Array," hh ww"]
Notes
Multiply the input field by exp(i * phase), the complex exponential of the phase screen.
Return the screened field.
- janssen.simul.create_spatial_grid(diameter: Num[Array, ''], num_points: Int[Array, '']) → tuple[Float[Array, 'nn nn'], Float[Array, 'nn nn']]
Create a 2D spatial grid for optical propagation.
- Parameters:
diameter (Num[Array," "]) – Physical size of the grid in meters.
num_points (Int[Array," "]) – Number of points in each dimension.
- Returns:
xx (Float[Array," nn nn"]) – X coordinate grid in meters.
yy (Float[Array," nn nn"]) – Y coordinate grid in meters.
- Return type:
tuple[Float[Array, 'nn nn'], Float[Array, 'nn nn']]
Notes
Create a linear space of points along the x-axis.
Create a linear space of points along the y-axis.
Create a meshgrid of spatial coordinates.
Return the meshgrid.
- janssen.simul.field_intensity(field: Complex[Array, 'hh ww']) → Float[Array, 'hh ww']
Calculate intensity from complex field.
- Parameters:
field (Complex[Array," hh ww"]) – Input complex field.
- Returns:
intensity – Intensity of the field.
- Return type:
Float[Array," hh ww"]
Notes
Calculate the intensity as the square of the absolute value of the field.
Return the intensity.
- janssen.simul.normalize_field(field: Complex[Array, 'hh ww']) → Complex[Array, 'hh ww']
Normalize complex field to unit power.
- Parameters:
field (Complex[Array," hh ww"]) – Input complex field.
- Returns:
normalized_field – Normalized complex field.
- Return type:
Complex[Array," hh ww"]
Notes
Calculate the power of the field as the sum of the squared absolute values of the field.
Normalize the field by dividing by the square root of the power.
Return the normalized field.
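The two Notes steps amount to dividing by the square root of the total power. Here is a minimal sketch using a hypothetical helper name. Note that in this sketch an all-zero field yields NaN (0/0); the Notes do not say how the library handles that case.

```python
import jax.numpy as jnp


def normalize_to_unit_power(field):
    """Hypothetical helper: scale a complex field so that sum(|field|^2) == 1."""
    power = jnp.sum(jnp.abs(field) ** 2)  # total power
    return field / jnp.sqrt(power)        # phase is preserved, only amplitude scales


field = (3.0 + 4.0j) * jnp.ones((16, 16), dtype=jnp.complex64)
normalized = normalize_to_unit_power(field)  # each pixel becomes (3 + 4j) / 80
```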
- janssen.simul.scale_pixel(wavefront: OpticalWavefront, new_dx: float | Float[Array, '']) → OpticalWavefront
Rescale OpticalWavefront pixel size while keeping array shape fixed.
JAX-compatible (jit/vmap-safe). Crops or pads to preserve shape.
- Parameters:
wavefront (OpticalWavefront) – OpticalWavefront to be resized.
new_dx (float) – New pixel size (meters).
- Returns:
scaled_wavefront – Resized OpticalWavefront with updated pixel size and resized field, which is of the same size as the original field.
- Return type:
OpticalWavefront
Notes
If the new pixel size is smaller than the old one, the new FOV is also smaller for the same field size. We therefore first find the new, smaller FOV and crop to it at the current pixel size. Then we resize the cropped field to the new pixel size, so the size of the field stays the same. Here the order is crop, then resize.
If the new pixel size is larger than the old one, the new FOV of the final field is also larger, so the field is padded rather than cropped to keep the array size fixed.
Return the resized OpticalWavefront.
- janssen.simul.linear_interaction(sample: SampleFunction, light: OpticalWavefront) → OpticalWavefront
Propagate an optical wavefront through a sample.
The sample is modeled as a complex function that modifies the incoming wavefront through a linear interaction.
- Parameters:
sample (SampleFunction) – The sample function representing the optical properties of the sample.
light (OpticalWavefront) – The incoming optical wavefront.
- Returns:
The propagated optical wavefront after passing through the sample
- Return type:
OpticalWavefront
- janssen.simul.simple_diffractogram(sample_cut: SampleFunction, lightwave: OpticalWavefront, zoom_factor: float | Float[Array, ''], aperture_diameter: float | Float[Array, ''], travel_distance: float | Float[Array, ''], camera_pixel_size: float | Float[Array, ''], aperture_center: Float[Array, '2'] | None = None) → Diffractogram
Calculate the diffractogram of a sample using a simple model.
The lightwave interacts with the sample linearly, and is then zoomed optically. Following this it interacts with a circular aperture before propagating to the camera plane. The camera image is then scaled to the pixel size of the camera. The diffractogram is created from the camera image.
- Parameters:
sample_cut (SampleFunction) – The sample function representing the optical properties of the sample.
lightwave (OpticalWavefront) – The incoming optical wavefront.
zoom_factor (float) – The zoom factor for the optical system.
aperture_diameter (float) – The diameter of the aperture in meters.
travel_distance (float) – The distance traveled by the light in meters.
camera_pixel_size (float) – The pixel size of the camera in meters.
aperture_center (Optional[Float[Array," 2"]], optional) – The center of the aperture in pixels.
- Returns:
The calculated diffractogram of the sample
- Return type:
Diffractogram
Notes
Algorithm:
Propagate the lightwave through the sample using linear interaction
Apply optical zoom to the wavefront
Apply a circular aperture to the zoomed wavefront
Propagate the wavefront to the camera plane using Fraunhofer propagation
Scale the pixel size of the camera image
Calculate the field intensity of the camera image
Create a diffractogram from the camera image
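The Fraunhofer step in the algorithm above amounts to a centered Fourier transform of the apertured field. The sketch below illustrates that single step in plain `jax.numpy` under these assumptions:

- The field is square.
- The FFT uses orthonormal scaling, so total power is conserved.
- The detector pitch is the standard far-field value λz/(N·dx).

The library's propagator may use different phase prefactors or normalization, and `fraunhofer_intensity` is a hypothetical name.

```python
import jax.numpy as jnp


def fraunhofer_intensity(field, wavelength, dx, distance):
    """Hypothetical helper: far-field intensity of a square field via a centered FFT."""
    n = field.shape[0]
    # Put the optical axis at index 0 for the FFT, then move zero frequency to n // 2.
    far = jnp.fft.fftshift(jnp.fft.fft2(jnp.fft.ifftshift(field), norm="ortho"))
    # Standard Fraunhofer sampling: detector pitch = wavelength * z / (n * dx).
    dx_out = wavelength * distance / (n * dx)
    return jnp.abs(far) ** 2, dx_out


n, dx = 64, 10e-6
coords = (jnp.arange(n) - n // 2) * dx
xx, yy = jnp.meshgrid(coords, coords)
pupil = (jnp.hypot(xx, yy) <= 100e-6).astype(jnp.complex64)  # circular aperture
intensity, dx_out = fraunhofer_intensity(pupil, wavelength=500e-9, dx=dx, distance=0.1)
```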
- janssen.simul.simple_microscope(sample: SampleFunction, positions: Num[Array, 'n 2'], lightwave: OpticalWavefront, zoom_factor: float | Float[Array, ''], aperture_diameter: float | Float[Array, ''], travel_distance: float | Float[Array, ''], camera_pixel_size: float | Float[Array, ''], aperture_center: Float[Array, '2'] | None = None) → MicroscopeData
Calculate the 3D stack of diffractograms for the entire scan.
This cuts the sample and then generates a diffractogram with the desired camera pixel size at every scan position, all in parallel.
- Parameters:
sample (SampleFunction) – The sample function representing the optical properties of the sample.
positions (Num[Array," n 2"]) – The positions in the sample plane where the diffractograms are calculated.
lightwave (OpticalWavefront) – The incoming optical wavefront.
zoom_factor (float) – The zoom factor for the optical system.
aperture_diameter (float) – The diameter of the aperture in meters.
travel_distance (float) – The distance traveled by the light in meters.
camera_pixel_size (float) – The pixel size of the camera in meters.
aperture_center (Optional[Float[Array," 2"]], optional) – The center of the aperture in pixels.
- Returns:
The calculated diffractograms of the sample at the specified positions
- Return type:
MicroscopeData
Notes
Algorithm:
Get the size of the lightwave field
Calculate the pixel positions in the sample plane
For each position, cut out the sample and calculate the diffractogram
Combine the diffractograms into a single MicroscopeData object
Return the MicroscopeData object
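The "for each position" step is a natural fit for `jax.vmap`: cut a patch with `jax.lax.dynamic_slice`, multiply by the probe, and Fourier transform. This sketch is a simplified stand-in for the documented pipeline, not its implementation. It assumes integer top-left pixel positions and omits the zoom, aperture, and camera-pixel rescaling steps. `cut_patch` and `diffraction_stack` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def cut_patch(sample, top_left, size):
    """Hypothetical helper: extract a (size, size) window at integer pixel top_left."""
    return jax.lax.dynamic_slice(sample, (top_left[0], top_left[1]), (size, size))


def diffraction_stack(sample, probe, positions_px):
    """Hypothetical helper: far-field intensity for every scan position, in parallel."""
    size = probe.shape[0]  # static Python int, as dynamic_slice requires

    def one_position(pos):
        exit_wave = cut_patch(sample, pos, size) * probe  # linear interaction
        far = jnp.fft.fftshift(jnp.fft.fft2(exit_wave, norm="ortho"))
        return jnp.abs(far) ** 2

    return jax.vmap(one_position)(positions_px)


sample = jnp.ones((128, 128), dtype=jnp.complex64)
probe = jnp.ones((32, 32), dtype=jnp.complex64)
positions_px = jnp.array([[0, 0], [10, 20], [96, 96]])
stack = diffraction_stack(sample, probe, positions_px)  # shape (3, 32, 32)
```

The patch size must be static, which is why it comes from `probe.shape`. The start indices, by contrast, may be traced values inside `vmap`. Note that `jax.lax.dynamic_slice` clamps out-of-range start indices instead of raising an error.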
janssen.utils
Common utility functions used throughout the code.
Extended Summary
Core utilities for the janssen package including type definitions, factory functions, and decorators for type checking and validation. Provides the foundation for type-safe JAX programming with PyTrees.
Submodules
- factory
Factory functions for creating data structures
- types
Type definitions and PyTrees
Routine Listings
- Diffractogram : PyTree
PyTree for storing diffraction patterns
- GridParams : PyTree
PyTree for computational grid parameters
- LensParams : PyTree
PyTree for lens optical parameters
- MicroscopeData : PyTree
PyTree for microscopy data
- OpticalWavefront : PyTree
PyTree for optical wavefront representation
- OptimizerState : PyTree
PyTree for optimizer state tracking
- PtychographyParams : PyTree
PyTree for ptychography reconstruction parameters
- SampleFunction : PyTree
PyTree for sample representation
- make_diffractogram : function
Factory function for Diffractogram creation
- make_grid_params : function
Factory function for GridParams creation
- make_lens_params : function
Factory function for LensParams creation
- make_microscope_data : function
Factory function for MicroscopeData creation
- make_optical_wavefront : function
Factory function for OpticalWavefront creation
- make_optimizer_state : function
Factory function for OptimizerState creation
- make_ptychography_params : function
Factory function for PtychographyParams creation
- make_sample_function : function
Factory function for SampleFunction creation
- non_jax_number : TypeAlias
Type alias for Python numeric types
- scalar_bool : TypeAlias
Type alias for scalar boolean values
- scalar_complex : TypeAlias
Type alias for scalar complex values
- scalar_float : TypeAlias
Type alias for scalar float values
- scalar_integer : TypeAlias
Type alias for scalar integer values
- scalar_numeric : TypeAlias
Type alias for any scalar numeric value
Notes
Always use factory functions for creating PyTree instances to ensure proper type checking and validation. All PyTrees are registered with JAX and support automatic differentiation.
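- janssen.utils.make_diffractogram(image: Float[Array, 'hh ww'], wavelength: float | Float[Array, ''], dx: float | Float[Array, '']) → Diffractogram
JAX-safe factory function for Diffractogram with data validation.
- Parameters:
image (Float[Array," hh ww"]) – 2D diffraction pattern (intensity values).
wavelength (float) – Wavelength of the optical wavefront in meters.
dx (float) – Spatial sampling interval (grid spacing) in meters.
- Returns:
validated_diffractogram – Validated diffractogram instance
- Return type:
Diffractogram
- Raises: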
ValueError – If data is invalid or parameters are out of valid ranges
Notes
Algorithm:
Convert inputs to JAX arrays
- Validate image array:
Check it’s 2D
Ensure all values are finite and non-negative
- Validate parameters:
Check wavelength is positive
Check dx is positive
Create and return Diffractogram instance
- janssen.utils.make_grid_params(xx: Float[Array, 'hh ww'], yy: Float[Array, 'hh ww'], phase_profile: Float[Array, 'hh ww'], transmission: Float[Array, 'hh ww']) → GridParams
JAX-safe factory function for GridParams with data validation.
- Parameters:
xx (Float[Array," hh ww"]) – Spatial grid in the x-direction.
yy (Float[Array," hh ww"]) – Spatial grid in the y-direction.
phase_profile (Float[Array," hh ww"]) – Phase profile of the optical field.
transmission (Float[Array," hh ww"]) – Transmission profile of the optical field.
- Returns:
validated_grid_params – Validated grid parameters instance
- Return type:
GridParams
- Raises:
ValueError – If array shapes are inconsistent or data is invalid
Notes
Algorithm:
Convert inputs to JAX arrays
- Validate array shapes:
Check all arrays are 2D
Check all arrays have the same shape
- Validate data:
Ensure transmission values are between 0 and 1
Ensure phase values are finite
Ensure grid coordinates are finite
Create and return GridParams instance
- janssen.utils.make_lens_params(focal_length: float | Float[Array, ''], diameter: float | Float[Array, ''], n: float | Float[Array, ''], center_thickness: float | Float[Array, ''], r1: float | Float[Array, ''], r2: float | Float[Array, '']) → LensParams
JAX-safe factory function for LensParams with data validation.
- Parameters:
focal_length (float) – Focal length of the lens in meters.
diameter (float) – Diameter of the lens in meters.
n (float) – Refractive index of the lens material.
center_thickness (float) – Thickness at the center of the lens in meters.
r1 (float) – Radius of curvature of the first surface in meters (positive for convex).
r2 (float) – Radius of curvature of the second surface in meters (positive for convex).
- Returns:
validated_lens_params – Validated lens parameters instance
- Return type:
LensParams
- Raises:
ValueError – If parameters are invalid or out of valid ranges
Notes
Algorithm:
Convert inputs to JAX arrays
- Validate parameters:
Check focal_length is positive
Check diameter is positive
Check refractive index is positive
Check center_thickness is positive
Check radii are finite
Create and return LensParams instance
- janssen.utils.make_microscope_data(image_data: Float[Array, 'pp hh ww'] | Float[Array, 'xx yy hh ww'], positions: Num[Array, 'pp 2'], wavelength: float | Float[Array, ''], dx: float | Float[Array, '']) → MicroscopeData
JAX-safe factory function for MicroscopeData with data validation.
- Parameters:
image_data (Union[Float[Array," pp hh ww"],Float[Array," xx yy hh ww"]]) – 3D or 4D image data representing the optical field.
positions (Num[Array," pp 2"]) – Positions of the images during collection.
wavelength (float) – Wavelength of the optical wavefront in meters.
dx (float) – Spatial sampling interval (grid spacing) in meters.
- Returns:
validated_microscope_data – Validated microscope data instance
- Return type:
MicroscopeData
- Raises:
ValueError – If data is invalid or parameters are out of valid ranges
Notes
Algorithm:
Convert inputs to JAX arrays
- Validate image_data:
Check it’s 3D or 4D
Ensure all values are finite and non-negative
- Validate positions:
Check it’s 2D with shape (pp, 2)
Ensure all values are finite
- Validate parameters:
Check wavelength is positive
Check dx is positive
- Validate consistency:
Check that the number of positions (pp) matches between image_data and positions
Create and return MicroscopeData instance
- janssen.utils.make_optical_wavefront(field: Complex[Array, 'hh ww'] | Complex[Array, 'hh ww 2'], wavelength: float | Float[Array, ''], dx: float | Float[Array, ''], z_position: float | Float[Array, ''], polarization: bool | Bool[Array, ''] | None = False) OpticalWavefront[source]¶
JAX-safe factory function for OpticalWavefront with data validation.
- Parameters:
field (Union[Complex[Array, " hh ww"], Complex[Array, " hh ww 2"]]) – Complex amplitude of the optical field. Should be 2D for scalar fields, or 3D with a last dimension of 2 for polarized fields.
wavelength (float) – Wavelength of the optical wavefront in meters
dx (float) – Spatial sampling interval (grid spacing) in meters
z_position (float) – Axial position of the wavefront along the propagation direction, in meters
polarization (scalar_bool, optional) – Whether the field is polarized (True for a 3D field, False for a 2D field). Default is False.
- Returns:
validated_optical_wavefront – Validated optical wavefront instance
- Return type:
OpticalWavefront
- Raises:
ValueError – If data is invalid or parameters are out of valid ranges
Notes
Algorithm:
- Convert inputs to JAX arrays
- Validate the field array:
  - Check it is 2D (hh, ww), or 3D (hh, ww, 2) for a polarized field
  - Ensure all values are finite
- Validate parameters:
  - Check wavelength is positive
  - Check dx is positive
  - Check z_position is finite
- Create and return an OpticalWavefront instance
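A hedged NumPy sketch of these checks that accepts both the scalar `(hh, ww)` and the polarized `(hh, ww, 2)` layouts described under Parameters. Tying the expected number of dimensions to the `polarization` flag is an assumption, and `validate_wavefront` is a hypothetical name.

```python
import numpy as np

def validate_wavefront(field, wavelength, dx, z_position, polarization=False):
    """Hypothetical eager sketch; scalar (hh, ww) or polarized (hh, ww, 2) fields."""
    field = np.asarray(field, dtype=np.complex128)
    # Assumed rule: the polarization flag decides the expected layout
    expected_ndim = 3 if polarization else 2
    if field.ndim != expected_ndim or (polarization and field.shape[-1] != 2):
        raise ValueError("field shape does not match the polarization flag")
    # np.isfinite on complex input checks both real and imaginary parts
    if not np.all(np.isfinite(field)):
        raise ValueError("field must be finite")
    if not (wavelength > 0 and dx > 0):
        raise ValueError("wavelength and dx must be positive")
    if not np.isfinite(z_position):
        raise ValueError("z_position must be finite")
    return field
```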
- janssen.utils.make_optimizer_state(shape: tuple, m: Complex[Array, '...'] | complex | Complex[Array, ''] | None = 1j, v: Float[Array, '...'] | float | Float[Array, ''] | None = 0.0, step: int | Int[Array, ''] | None = 0) → OptimizerState
JAX-safe factory function for OptimizerState with data validation.
- Parameters:
shape (tuple) – Shape of the parameters to be optimized
m (Optional[Complex[Array, "..."]], optional) – First moment estimate. If None, initialized to zeros with the given shape. Default is 1j.
v (Optional[Float[Array, "..."]], optional) – Second moment estimate. If None, initialized to zeros with the given shape. Default is 0.0.
step (Optional[scalar_integer], optional) – Step count. Default is 0.
- Returns:
validated_optimizer_state – Validated optimizer state instance
- Return type:
OptimizerState
- Raises:
ValueError – If arrays have incompatible shapes with the given shape parameter
Notes
Algorithm:
- If m is None, initialize it with complex zeros of the given shape
- If v is None, initialize it with real zeros of the given shape
- If step is None, initialize it to 0
- Convert all inputs to JAX arrays with appropriate dtypes
- Validate that the arrays have compatible shapes
- Create and return an OptimizerState instance
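The initialization steps can be sketched in NumPy. `init_optimizer_state` is a hypothetical name, and reading "compatible shapes" as "broadcastable to `shape`" is an assumption; `np.broadcast_to` raises `ValueError` when that fails, which matches the documented exception type.

```python
import numpy as np

def init_optimizer_state(shape, m=None, v=None, step=None):
    """Hypothetical sketch of the documented initialization steps."""
    # Complex first moment, real second moment, integer step counter
    m = np.zeros(shape, dtype=np.complex128) if m is None else np.asarray(m, dtype=np.complex128)
    v = np.zeros(shape, dtype=np.float64) if v is None else np.asarray(v, dtype=np.float64)
    step = np.asarray(0 if step is None else step, dtype=np.int64)
    # Assumed meaning of "compatible": broadcastable to `shape`
    m = np.broadcast_to(m, shape)
    v = np.broadcast_to(v, shape)
    return m, v, step
```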
- janssen.utils.make_ptychography_params(zoom_factor: float | Float[Array, ''], aperture_diameter: float | Float[Array, ''], travel_distance: float | Float[Array, ''], aperture_center: Float[Array, '2'], camera_pixel_size: float | Float[Array, ''], learning_rate: float | Float[Array, ''], num_iterations: int | Int[Array, '']) → PtychographyParams
Create a PtychographyParams PyTree with validated parameters.
- Parameters:
zoom_factor (float) – Optical zoom factor for magnification (must be positive)
aperture_diameter (float) – Diameter of the aperture in meters (must be positive)
travel_distance (float) – Light propagation distance in meters (must be positive)
aperture_center (Float[Array, " 2"]) – Center position of the aperture (x, y) in meters
camera_pixel_size (float) – Camera pixel size in meters (must be positive)
learning_rate (float) – Learning rate for optimization (must be positive)
num_iterations (scalar_integer) – Number of optimization iterations (must be positive)
- Returns:
Validated ptychography parameters as a PyTree
- Return type:
PtychographyParams
Notes
This function performs runtime validation to ensure all parameters are properly formatted and within valid ranges before creating the PtychographyParams PyTree.
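An illustrative NumPy version of the range checks described above. `check_ptychography_params` is hypothetical, and the `(2,)` shape check on `aperture_center` is inferred from its documented type `Float[Array, " 2"]` rather than from a stated rule.

```python
import numpy as np

# Parameters documented as "must be positive"
POSITIVE_FIELDS = ("zoom_factor", "aperture_diameter", "travel_distance",
                   "camera_pixel_size", "learning_rate", "num_iterations")

def check_ptychography_params(**params):
    """Hypothetical runtime checks mirroring the documented constraints."""
    for name in POSITIVE_FIELDS:
        if not np.all(np.asarray(params[name]) > 0):
            raise ValueError(f"{name} must be positive")
    # aperture_center holds an (x, y) pair in meters
    center = np.asarray(params["aperture_center"], dtype=np.float64)
    if center.shape != (2,):
        raise ValueError("aperture_center must have shape (2,)")
    return params
```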
- janssen.utils.make_sample_function(sample: Complex[Array, 'hh ww'], dx: float | Float[Array, '']) → SampleFunction
JAX-safe factory function for SampleFunction with data validation.
- Parameters:
sample (Complex[Array, " hh ww"]) – The sample function
dx (float) – Spatial sampling interval (grid spacing) in meters
- Returns:
validated_sample_function – Validated sample function instance
- Return type:
SampleFunction
- Raises:
ValueError – If data is invalid or parameters are out of valid ranges
Notes
Algorithm:
- Convert inputs to JAX arrays
- Validate the sample array:
  - Check it is 2D
  - Ensure all values are finite
- Validate parameters:
  - Check dx is positive
- Create and return a SampleFunction instance
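A minimal NumPy sketch of these checks, using a pure-phase sample as input. The helper name `validate_sample` and the random phase map are hypothetical illustrations.

```python
import numpy as np

def validate_sample(sample, dx):
    """Hypothetical eager version of the documented SampleFunction checks."""
    sample = np.asarray(sample, dtype=np.complex128)
    if sample.ndim != 2:
        raise ValueError("sample must be 2D with shape (hh, ww)")
    if not np.all(np.isfinite(sample)):
        raise ValueError("sample must contain only finite values")
    if not dx > 0:
        raise ValueError("dx must be positive")
    return sample

# Example input: a weak pure-phase object exp(i * phi) on a 64 x 64 grid
phi = 0.1 * np.random.default_rng(0).standard_normal((64, 64))
sample = validate_sample(np.exp(1j * phi), dx=1e-6)
```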
- class janssen.utils.Diffractogram(image: Float[Array, 'hh ww'], wavelength: Float[Array, ''], dx: Float[Array, ''])
Bases: NamedTuple
PyTree structure for representing a single diffractogram.
- image
Image data.
- Type:
Float[Array, " hh ww"]
- wavelength
Wavelength of the optical wavefront in meters.
- Type:
Float[Array, " "]
- dx
Spatial sampling interval (grid spacing) in meters.
- Type:
Float[Array, " "]
- image: Float[Array, 'hh ww']
Alias for field number 0
- wavelength: Float[Array, '']
Alias for field number 1
- dx: Float[Array, '']
Alias for field number 2
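Because `Diffractogram` is a `NamedTuple`, each "Alias for field number N" entry means the attribute name and the positional index refer to the same value. The stand-in class below reproduces that behavior with plain NumPy types; it is an illustration, not the library class.

```python
from typing import NamedTuple

import numpy as np

class Diffractogram(NamedTuple):
    """Stand-in with the same field order as janssen.utils.Diffractogram."""
    image: np.ndarray
    wavelength: float
    dx: float

d = Diffractogram(image=np.zeros((64, 64)), wavelength=633e-9, dx=5e-6)
# Field number 0 is `image`, 1 is `wavelength`, 2 is `dx`
assert d[0] is d.image and d[1] == d.wavelength and d[2] == d.dx
```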
- class janssen.utils.GridParams(xx: Float[Array, 'hh ww'], yy: Float[Array, 'hh ww'], phase_profile: Float[Array, 'hh ww'], transmission: Float[Array, 'hh ww'])
Bases: NamedTuple
PyTree structure for computational grid parameters.
- xx
Spatial grid in the x-direction
- Type:
Float[Array, " hh ww"]
- yy
Spatial grid in the y-direction
- Type:
Float[Array, " hh ww"]
- phase_profile
Phase profile of the optical field
- Type:
Float[Array, " hh ww"]
- transmission
Transmission profile of the optical field
- Type:
Float[Array, " hh ww"]
Notes
This class is registered as a PyTree node, making it compatible with JAX transformations such as jit, grad, and vmap. The auxiliary data in tree_flatten is None because all relevant data is stored in JAX arrays.
- xx: Float[Array, 'hh ww']
Alias for field number 0
- yy: Float[Array, 'hh ww']
Alias for field number 1
- phase_profile: Float[Array, 'hh ww']
Alias for field number 2
- transmission: Float[Array, 'hh ww']
Alias for field number 3
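One way to populate the four `GridParams` fields is sketched below. The centered-grid helper `make_grid`, the thin-lens phase −π(x² + y²)/(λf), and the circular aperture are illustrative assumptions about how such fields might be filled, not janssen routines.

```python
import numpy as np

def make_grid(hh, ww, dx):
    """Hypothetical helper: centered coordinate grids in meters, shape (hh, ww)."""
    x = (np.arange(ww) - ww // 2) * dx
    y = (np.arange(hh) - hh // 2) * dx
    # indexing="xy" returns arrays of shape (len(y), len(x)) = (hh, ww)
    xx, yy = np.meshgrid(x, y, indexing="xy")
    return xx, yy

xx, yy = make_grid(128, 128, 1e-6)
# Illustrative thin-lens phase and circular aperture (assumed values)
wavelength, focal_length, radius = 633e-9, 5e-3, 50e-6
phase_profile = -np.pi * (xx**2 + yy**2) / (wavelength * focal_length)
transmission = (xx**2 + yy**2 <= radius**2).astype(np.float64)
```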
- class janssen.utils.LensParams(focal_length: Float[Array, ''], diameter: Float[Array, ''], n: Float[Array, ''], center_thickness: Float[Array, ''], r1: Float[Array, ''], r2: Float[Array, ''])
Bases: NamedTuple
PyTree structure for lens parameters.
- focal_length
Focal length of the lens in meters
- Type:
Float[Array, " "]
- diameter
Diameter of the lens in meters
- Type:
Float[Array, " "]
- n
Refractive index of the lens material
- Type:
Float[Array, " "]
- center_thickness
Thickness at the center of the lens in meters
- Type:
Float[Array, " "]
- r1
Radius of curvature of the first surface in meters (positive for convex)
- Type:
Float[Array, " "]
- r2
Radius of curvature of the second surface in meters (positive for convex)
- Type:
Float[Array, " "]
- focal_length: Float[Array, '']
Alias for field number 0
- diameter: Float[Array, '']
Alias for field number 1
- n: Float[Array, '']
Alias for field number 2
- center_thickness: Float[Array, '']
Alias for field number 3
- r1: Float[Array, '']
Alias for field number 4
- r2: Float[Array, '']
Alias for field number 5
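The fields `n`, `center_thickness`, `r1`, and `r2` determine a focal length through the thick-lens lensmaker's equation. The sketch below assumes the sign convention documented above (each radius positive when its surface is convex) and a lens in air; whether janssen derives `focal_length` this way is not stated in this reference, and `thick_lens_focal_length` is a hypothetical helper.

```python
def thick_lens_focal_length(n, center_thickness, r1, r2):
    """Lensmaker's equation for a lens in air.

    Assumes each radius is positive when its surface is convex, so a
    biconvex lens has r1 > 0 and r2 > 0:
        1/f = (n - 1) * (1/r1 + 1/r2 - (n - 1) * d / (n * r1 * r2))
    """
    d = center_thickness
    power = (n - 1.0) * (1.0 / r1 + 1.0 / r2 - (n - 1.0) * d / (n * r1 * r2))
    return 1.0 / power

# Thin symmetric biconvex lens: f = r / (2 (n - 1)), i.e. 0.1 m for r = 0.1 m, n = 1.5
print(thick_lens_focal_length(1.5, 0.0, 0.1, 0.1))
```

A nonzero center thickness lowers the optical power of a biconvex lens, so the focal length grows slightly compared with the thin-lens value.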
- class janssen.utils.MicroscopeData(image_data: Float[Array, 'pp hh ww'] | Float[Array, 'xx yy hh ww'], positions: Num[Array, 'pp 2'], wavelength: Float[Array, ''], dx: Float[Array, ''])
Bases: NamedTuple
PyTree structure for representing a 3D or 4D microscope image.
- image_data
3D or 4D image data representing the optical field.
- Type:
Float[Array, " pp hh ww"] | Float[Array, " xx yy hh ww"]
- positions
Positions of the images during collection.
- Type:
Num[Array, " pp 2"]
- wavelength
Wavelength of the optical wavefront in meters.
- Type:
Float[Array, " "]
- dx
Spatial sampling interval (grid spacing) in meters.
- Type:
Float[Array, " "]
- image_data: Float[Array, 'pp hh ww'] | Float[Array, 'xx yy hh ww']
Alias for field number 0
- positions: Num[Array, 'pp 2']
Alias for field number 1
- wavelength: Float[Array, '']
Alias for field number 2
- dx: Float[Array, '']
Alias for field number 3
- class janssen.utils.OpticalWavefront(field: Complex[Array, 'hh ww'] | Complex[Array, 'hh ww 2'], wavelength: Float[Array, ''], dx: Float[Array, ''], z_position: Float[Array, ''], polarization: Bool[Array, ''])
Bases: NamedTuple
PyTree structure for representing an optical wavefront.
- field
Complex amplitude of the optical field. Can be scalar (hh, ww) or polarized with two components (hh, ww, 2).
- Type:
Union[Complex[Array, " hh ww"], Complex[Array, " hh ww 2"]]
- wavelength
Wavelength of the optical wavefront in meters.
- Type:
Float[Array, " "]
- dx
Spatial sampling interval (grid spacing) in meters.
- Type:
Float[Array, " "]
- z_position
Axial position of the wavefront along the propagation direction, in meters.
- Type:
Float[Array, " "]
- polarization
Whether the field is polarized (True for a 3D field, False for a 2D field).
- Type:
Bool[Array, " "]
- field: Complex[Array, 'hh ww'] | Complex[Array, 'hh ww 2']
Alias for field number 0
- wavelength: Float[Array, '']
Alias for field number 1
- dx: Float[Array, '']
Alias for field number 2
- z_position: Float[Array, '']
Alias for field number 3
- polarization: Bool[Array, '']
Alias for field number 4
- class janssen.utils.OptimizerState(m: Complex[Array, '...'], v: Float[Array, '...'], step: Int[Array, ''])
Bases: NamedTuple
PyTree structure for maintaining optimizer state.
- m
First moment estimate (for Adam-like optimizers)
- Type:
Complex[Array, "..."]
- v
Second moment estimate (for Adam-like optimizers)
- Type:
Float[Array, "..."]
- step
Step count
- Type:
Int[Array, " "]
- m: Complex[Array, '...']
Alias for field number 0
- v: Float[Array, '...']
Alias for field number 1
- step: Int[Array, '']
Alias for field number 2
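To show how the three fields are typically consumed, here is one generic Adam update on complex parameters, in the common variant where the second moment tracks |g|² so that `v` stays real while `m` stays complex, matching the dtypes above. This is textbook Adam written in NumPy, not janssen's optimizer code; `adam_step` and its hyperparameter defaults are assumptions.

```python
import numpy as np

def adam_step(params, grad, m, v, step, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One generic Adam update for complex parameters (not janssen's code)."""
    step = step + 1
    # Complex first moment and real second moment of |grad|**2
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * np.abs(grad) ** 2
    # Bias correction uses the incremented step count
    m_hat = m / (1 - beta1 ** step)
    v_hat = v / (1 - beta2 ** step)
    params = params - lr * m_hat / (np.sqrt(v_hat) + eps)
    return params, m, v, step
```

On the first step the bias-corrected update reduces to roughly `lr * grad / |grad|`, so every entry with a nonzero gradient moves by about `lr`.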
- class janssen.utils.PtychographyParams(zoom_factor: Float[Array, ''], aperture_diameter: Float[Array, ''], travel_distance: Float[Array, ''], aperture_center: Float[Array, '2'], camera_pixel_size: Float[Array, ''], learning_rate: Float[Array, ''], num_iterations: Int[Array, ''])
Bases: NamedTuple
PyTree structure for ptychography reconstruction parameters.
- zoom_factor
Optical zoom factor for magnification
- Type:
Float[Array, " "]
- aperture_diameter
Diameter of the aperture in meters
- Type:
Float[Array, " "]
- travel_distance
Light propagation distance in meters
- Type:
Float[Array, " "]
- aperture_center
Center position of the aperture (x, y) in meters
- Type:
Float[Array, " 2"]
- camera_pixel_size
Camera pixel size in meters (typically fixed)
- Type:
Float[Array, " "]
- learning_rate
Learning rate for optimization
- Type:
Float[Array, " "]
- num_iterations
Number of optimization iterations
- Type:
Int[Array, " "]
Notes
This class encapsulates all the optical and optimization parameters used in ptychographic reconstruction. It is registered as a PyTree node to enable JAX transformations and gradient-based optimization of these parameters.
- zoom_factor: Float[Array, '']
Alias for field number 0
- aperture_diameter: Float[Array, '']
Alias for field number 1
- travel_distance: Float[Array, '']
Alias for field number 2
- aperture_center: Float[Array, '2']
Alias for field number 3
- camera_pixel_size: Float[Array, '']
Alias for field number 4
- learning_rate: Float[Array, '']
Alias for field number 5
- num_iterations: Int[Array, '']
Alias for field number 6
- class janssen.utils.SampleFunction(sample: Complex[Array, 'hh ww'], dx: Float[Array, ''])
Bases: NamedTuple
PyTree structure for representing a sample function.
- sample
The sample function.
- Type:
Complex[Array, " hh ww"]
- dx
Spatial sampling interval (grid spacing) in meters.
- Type:
Float[Array, " "]
- sample: Complex[Array, 'hh ww']
Alias for field number 0
- dx: Float[Array, '']
Alias for field number 1