janssen.optics
Differentiable optical simulation toolkit.
Extended Summary
Comprehensive optical simulation framework for modeling light propagation through various optical elements. All components are differentiable and optimized for JAX transformations, enabling gradient-based optimization of optical systems.
Routine Listings
add_phase_screen() – Add phase screen to field.
amplitude_grating_binary() – Create binary amplitude grating.
annular_aperture() – Create an annular (ring-shaped) aperture.
apply_aberration() – Apply aberration to optical wavefront.
apply_phase_mask() – Apply a phase mask to a field.
apply_phase_mask_fn() – Apply a phase mask function.
astigmatism() – Generate astigmatism aberration (Z5, Z6).
beam_splitter() – Model beam splitter operation.
circular_aperture() – Create a circular aperture.
coma() – Generate coma aberration (Z7, Z8).
compute_phase_from_coeffs() – Compute phase map from Zernike coefficients.
create_spatial_grid() – Create computational spatial grid.
defocus() – Generate defocus aberration (Z4).
factorial() – JAX-compatible factorial computation.
field_intensity() – Calculate field intensity.
gaussian_apodizer() – Apply Gaussian apodization to a field.
gaussian_apodizer_elliptical() – Apply elliptical Gaussian apodization.
generate_aberration_nm() – Generate aberration phase map from (n, m) coefficients.
generate_aberration_noll() – Generate aberration phase map from Noll coefficients.
half_waveplate() – Half-wave plate transformation.
mirror_reflection() – Model mirror reflection.
nd_filter() – Neutral density filter.
nm_to_noll() – Convert (n, m) indices to Noll index.
noll_to_nm() – Convert Noll index to (n, m) indices.
normalize_field() – Normalize optical field.
phase_grating_blazed_elliptical() – Elliptical blazed phase grating.
phase_grating_sawtooth() – Sawtooth phase grating.
phase_grating_sine() – Sinusoidal phase grating.
phase_rms() – Compute RMS of phase within the unit pupil.
polarizer_jones() – Jones matrix for polarizer.
prism_phase_ramp() – Phase ramp from prism.
quarter_waveplate() – Quarter-wave plate transformation.
rectangular_aperture() – Create a rectangular aperture.
scale_pixel() – Scale pixel size in field.
sellmeier() – Sellmeier equation for refractive index.
spherical_aberration() – Generate spherical aberration (Z11).
supergaussian_apodizer() – Apply super-Gaussian apodization.
supergaussian_apodizer_elliptical() – Apply elliptical super-Gaussian apodization.
trefoil() – Generate trefoil aberration (Z9, Z10).
variable_transmission_aperture() – Create aperture with variable transmission.
waveplate_jones() – General waveplate Jones matrix.
zernike_polynomial() – Generate a single Zernike polynomial.
zernike_radial() – Radial component of Zernike polynomial.
Notes
All simulation functions support automatic differentiation and can be composed to model complex optical systems. The toolkit is optimized for both forward simulation and inverse problems in optics.
- janssen.optics.annular_aperture(incoming: OpticalWavefront, inner_diameter: float | Float[Array, ''], outer_diameter: float | Float[Array, ''], center: float | Float[Array, ''] | Float[Array, '2'] = 0.0, transmittivity: float | Float[Array, ''] | None = 1.0) OpticalWavefront[source]¶
Apply an annular (ring) aperture with inner and outer diameters.
- Parameters:
incoming (OpticalWavefront) – Input wavefront PyTree.
inner_diameter (float) – Inner blocked diameter in meters.
outer_diameter (float) – Outer clear aperture diameter in meters.
center (Float[Array," 2"], optional) – Ring center [x0, y0] in meters, by default [0, 0].
transmittivity (Optional[ScalarFloat], optional) – Uniform transmittivity in the ring (0..1), by default 1.0.
- Returns:
apertured – Wavefront after applying the annular aperture.
- Return type:
OpticalWavefront
Notes
Build centered (x, y) grids in meters.
Compute radial distance from center.
Create mask for inner_radius < r <= outer_radius.
Multiply by transmittivity (clipped), apply, and return.
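The steps above can be sketched on plain arrays. This is a minimal illustration, not the library's implementation: `annular_mask` is a hypothetical helper, and the field is a bare complex array rather than an OpticalWavefront.

```python
import jax.numpy as jnp

def annular_mask(xx, yy, inner_diameter, outer_diameter,
                 center=(0.0, 0.0), transmittivity=1.0):
    """Hypothetical helper: ring-shaped amplitude mask on grids in meters."""
    # Radial distance from the requested center.
    r = jnp.sqrt((xx - center[0]) ** 2 + (yy - center[1]) ** 2)
    # Ring condition: inner_radius < r <= outer_radius.
    inside = (r > inner_diameter / 2) & (r <= outer_diameter / 2)
    # Clip the transmittivity to [0, 1] before applying it.
    return jnp.clip(transmittivity, 0.0, 1.0) * inside.astype(xx.dtype)

# Centered 2 mm x 2 mm grid with 65 samples per side (illustrative values).
x = jnp.linspace(-1e-3, 1e-3, 65)
xx, yy = jnp.meshgrid(x, x)
field = jnp.ones_like(xx, dtype=jnp.complex64)
out = field * annular_mask(xx, yy, inner_diameter=0.5e-3, outer_diameter=1.5e-3)
```

Setting the inner diameter to zero in this sketch blocks only the single point at r = 0, because the inner bound is a strict inequality; everywhere else it behaves like a circular aperture mask (r <= diameter/2).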
- janssen.optics.circular_aperture(incoming: OpticalWavefront, diameter: float | Float[Array, ''], center: float | Float[Array, ''] | Float[Array, '2'] = 0.0, transmittivity: float | Float[Array, ''] | None = 1.0) OpticalWavefront[source]¶
Apply a circular aperture to the incoming wavefront.
The aperture is defined by its physical diameter and (optional) center.
- Parameters:
incoming (OpticalWavefront) – Input wavefront PyTree.
diameter (float) – Aperture diameter in meters.
center (Float[Array," 2"], optional) – Physical center [x0, y0] of the aperture in meters, by default [0, 0].
transmittivity (Optional[ScalarFloat], optional) – Uniform transmittivity inside the aperture (0..1), by default 1.0.
- Returns:
apertured – Wavefront after applying the circular aperture.
- Return type:
OpticalWavefront
Notes
Build centered (x, y) grids in meters.
Compute radial distance from the specified center.
Create a binary mask for r <= diameter/2.
Multiply by transmittivity (clipped to [0, 1]).
Apply to the complex field and return.
- janssen.optics.gaussian_apodizer(incoming: OpticalWavefront, sigma: float | Float[Array, ''], center: float | Float[Array, ''] | Float[Array, '2'] = 0.0, peak_transmittivity: float | Float[Array, ''] | None = 1.0) OpticalWavefront[source]¶
Apply a Gaussian apodizer (smooth transmission mask) to the wavefront.
- Parameters:
incoming (OpticalWavefront) – Input optical wavefront.
sigma (float) – Gaussian width parameter in meters.
center (Float[Array," 2"], optional) – Physical center [x0, y0] of the Gaussian in meters, by default [0, 0].
peak_transmittivity (Optional[ScalarFloat], optional) – Maximum transmission at the Gaussian center, by default 1.0.
- Returns:
apertured – Wavefront after applying Gaussian apodization.
- Return type:
OpticalWavefront
Notes
Build centered (x, y) grids.
Compute squared radial distance from center.
Evaluate Gaussian exp(-r^2 / (2*sigma^2)).
Scale by peak transmittivity, clip to [0,1].
Multiply with incoming field and return.
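As a rough sketch of the steps above, the transmission profile can be evaluated directly on coordinate grids. `gaussian_transmission` is a hypothetical name and the grid values are illustrative; the library function itself operates on an OpticalWavefront.

```python
import jax.numpy as jnp

def gaussian_transmission(xx, yy, sigma, center=(0.0, 0.0), peak=1.0):
    """Hypothetical helper: exp(-r^2 / (2 sigma^2)) scaled and clipped to [0, 1]."""
    r2 = (xx - center[0]) ** 2 + (yy - center[1]) ** 2
    return jnp.clip(peak * jnp.exp(-r2 / (2.0 * sigma**2)), 0.0, 1.0)

x = jnp.linspace(-1e-3, 1e-3, 65)
xx, yy = jnp.meshgrid(x, x)
t = gaussian_transmission(xx, yy, sigma=0.5e-3)
# One sigma from the center the amplitude transmission is exp(-1/2).
```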
- janssen.optics.gaussian_apodizer_elliptical(incoming: OpticalWavefront, sigma_x: float | Float[Array, ''], sigma_y: float | Float[Array, ''], theta: float | Float[Array, ''] | None = 0.0, center: float | Float[Array, ''] | Float[Array, '2'] = 0.0, peak_transmittivity: float | Float[Array, ''] | None = 1.0) OpticalWavefront[source]¶
Apply an elliptical Gaussian apodizer to the wavefront, optionally rotated by an angle theta.
- Parameters:
incoming (OpticalWavefront) – Input optical wavefront.
sigma_x (float) – Gaussian width along the x’-axis (meters) after rotation by theta.
sigma_y (float) – Gaussian width along the y’-axis (meters) after rotation by theta.
theta (Optional[ScalarFloat], optional) – Rotation angle in radians (counter-clockwise), by default 0.0.
center (Float[Array," 2"], optional) – Physical center [x0, y0] in meters, by default [0, 0].
peak_transmittivity (Optional[ScalarFloat], optional) – Maximum transmission at the center, by default 1.0.
- Returns:
apertured – Wavefront after applying elliptical Gaussian apodization.
- Return type:
OpticalWavefront
See also
gaussian_apodizer – Apply a Gaussian apodizer (smooth transmission mask) to the wavefront.
supergaussian_apodizer – Apply a super-Gaussian apodizer (smooth transmission mask) to the wavefront.
Notes
Build centered (x, y) grids.
Translate by center, rotate by theta → (x’, y’).
Evaluate exp(-0.5 * ( (x’/sigma_x)^2 + (y’/sigma_y)^2 )).
Scale by peak_transmittivity, clip to [0, 1].
Multiply with incoming field and return.
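The translate-then-rotate step is the part most easily gotten wrong, so here is a small sketch of it. The rotation sign convention (x’ = x cos θ + y sin θ, y’ = −x sin θ + y cos θ) is an assumption made for illustration, and `elliptical_gaussian` is a hypothetical helper name.

```python
import jax.numpy as jnp

def elliptical_gaussian(xx, yy, sigma_x, sigma_y, theta=0.0,
                        center=(0.0, 0.0), peak=1.0):
    """Hypothetical helper: rotated elliptical Gaussian transmission."""
    dx, dy = xx - center[0], yy - center[1]
    c, s = jnp.cos(theta), jnp.sin(theta)
    # Assumed convention: coordinates expressed in axes rotated CCW by theta.
    xp = c * dx + s * dy
    yp = -s * dx + c * dy
    t = peak * jnp.exp(-0.5 * ((xp / sigma_x) ** 2 + (yp / sigma_y) ** 2))
    return jnp.clip(t, 0.0, 1.0)

x = jnp.linspace(-1e-3, 1e-3, 65)
xx, yy = jnp.meshgrid(x, x)
t0 = elliptical_gaussian(xx, yy, 0.2e-3, 0.6e-3)
t90 = elliptical_gaussian(xx, yy, 0.2e-3, 0.6e-3, theta=jnp.pi / 2)
# Rotating by 90 degrees swaps the narrow and wide directions.
```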
- janssen.optics.rectangular_aperture(incoming: OpticalWavefront, width: float | Float[Array, ''], height: float | Float[Array, ''], center: float | Float[Array, ''] | Float[Array, '2'] = 0.0, transmittivity: float | Float[Array, ''] | None = 1.0) OpticalWavefront[source]¶
Apply an axis-aligned rectangular aperture to the incoming wavefront.
- Parameters:
incoming (OpticalWavefront) – Input wavefront PyTree.
width (float) – Rectangle width along x in meters.
height (float) – Rectangle height along y in meters.
center (Float[Array," 2"], optional) – Rectangle center [x0, y0] in meters, by default [0, 0].
transmittivity (Optional[ScalarFloat], optional) – Uniform transmittivity inside the rectangle (0..1), by default 1.0.
- Returns:
apertured – Wavefront after applying the rectangular aperture.
- Return type:
OpticalWavefront
Notes
Build centered (x, y) grids in meters.
Compute half-width/half-height and an inside-rectangle mask.
Multiply by transmittivity (clipped).
Apply to the complex field and return.
- janssen.optics.supergaussian_apodizer(incoming: OpticalWavefront, sigma: float | Float[Array, ''], m: int | float | complex | Num[Array, ''], center: float | Float[Array, ''] | Float[Array, '2'] = 0.0, peak_transmittivity: float | Float[Array, ''] | None = 1.0) OpticalWavefront[source]¶
Apply a super-Gaussian apodizer to the wavefront.
Transmission profile: exp(- (r^2 / sigma^2)^m ).
- Parameters:
incoming (OpticalWavefront) – Input optical wavefront.
sigma (float) – Width parameter in meters (sets the roll-off scale).
m (numeric) – Super-Gaussian order (m=1 → Gaussian, m>1 → flatter top). Note that m=1 gives exp(-r^2 / sigma^2), without the factor of 2 in the exponent used by gaussian_apodizer.
center (Float[Array," 2"], optional) – Physical center [x0, y0] of the profile in meters, by default [0, 0].
peak_transmittivity (Optional[ScalarFloat], optional) – Maximum transmission at the center, by default 1.0.
- Returns:
apertured – Wavefront after applying super-Gaussian apodization.
- Return type:
OpticalWavefront
Notes
Build centered (x, y) grids.
Compute squared radial distance from center.
Evaluate exp(- (r^2 / sigma^2)^m ).
Scale by peak transmittivity, clip to [0,1].
Multiply with incoming field and return.
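A short sketch of the profile above, written on plain grids with a hypothetical helper name. It also illustrates one consequence of the stated formula: with m = 1 the profile is exp(-r^2 / sigma^2), which matches the Gaussian form exp(-r^2 / (2 s^2)) only for s = sigma / sqrt(2).

```python
import jax.numpy as jnp

def supergaussian_transmission(xx, yy, sigma, m, center=(0.0, 0.0), peak=1.0):
    """Hypothetical helper: exp(-(r^2 / sigma^2)^m), scaled and clipped."""
    r2 = (xx - center[0]) ** 2 + (yy - center[1]) ** 2
    t = peak * jnp.exp(-((r2 / sigma**2) ** m))
    return jnp.clip(t, 0.0, 1.0)

x = jnp.linspace(-1e-3, 1e-3, 65)
xx, yy = jnp.meshgrid(x, x)
sigma = 0.5e-3
t1 = supergaussian_transmission(xx, yy, sigma, m=1)
t4 = supergaussian_transmission(xx, yy, sigma, m=4)

# Equivalent plain Gaussian for m = 1 uses width sigma / sqrt(2).
gauss = jnp.exp(-(xx**2 + yy**2) / (2.0 * (sigma / jnp.sqrt(2.0)) ** 2))
```

For any order m the transmission at r = sigma is exp(-1); raising m only flattens the top and steepens the edge around that radius.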
- janssen.optics.supergaussian_apodizer_elliptical(incoming: OpticalWavefront, sigma_x: float | Float[Array, ''], sigma_y: float | Float[Array, ''], m: int | float | complex | Num[Array, ''], theta: float | Float[Array, ''] | None = 0.0, center: float | Float[Array, ''] | Float[Array, '2'] = 0.0, peak_transmittivity: float | Float[Array, ''] | None = 1.0) OpticalWavefront[source]¶
Apply an elliptical super-Gaussian apodizer with optional rotation.
Transmission profile: exp( - ( (x’/sigma_x)^2 + (y’/sigma_y)^2 )^m ).
- Parameters:
incoming (OpticalWavefront) – Input optical wavefront.
sigma_x (float) – Width along x’ (meters) after rotation by theta.
sigma_y (float) – Width along y’ (meters) after rotation by theta.
m (numeric) – Super-Gaussian order (m=1 → Gaussian; m>1 → flatter top, sharper edges).
theta (Optional[ScalarFloat], optional) – Rotation angle in radians (counter-clockwise), by default 0.0.
center (Float[Array," 2"], optional) – Physical center [x0, y0] in meters, by default [0, 0].
peak_transmittivity (Optional[ScalarFloat], optional) – Maximum transmission at the center, by default 1.0.
- Returns:
apertured – Wavefront after applying elliptical super-Gaussian apodization.
- Return type:
OpticalWavefront
Notes
Build centered (x, y) grids.
Translate by center, rotate by theta → (x’, y’).
Evaluate exp( - ( (x’/sigma_x)^2 + (y’/sigma_y)^2 )^m ).
Scale by peak_transmittivity, clip to [0, 1].
Multiply with incoming field and return.
- janssen.optics.variable_transmission_aperture(incoming: OpticalWavefront, transmission: float | Float[Array, ''] | Float[Array, '...']) OpticalWavefront[source]¶
Apply an arbitrary (spatially varying) transmission to the wavefront.
- Parameters:
incoming (OpticalWavefront) – Input wavefront PyTree.
transmission (Union[ScalarFloat,Float[Array," H W"]]) – Precomputed transmission map (0..1) with shape “H W”, or a scalar attenuation factor for uniform transmission.
- Returns:
transmitted – Wavefront after applying the transmission.
- Return type:
OpticalWavefront
Examples
Uniform attenuation:
>>> wf2 = variable_transmission_aperture(wf, 0.5) # 50% trans
Spatially varying transmission:
>>> tmap = create_transmission_map(...) # Shape (H, W)
>>> wf2 = variable_transmission_aperture(wf, tmap)
Notes
For scalar transmission: applies uniform attenuation.
For array transmission: applies spatially varying transmission map.
Transmission values are clipped to [0, 1].
This function is fully JAX-compatible and uses jax.lax.cond.
- janssen.optics.amplitude_grating_binary(incoming: OpticalWavefront, period: float | Float[Array, ''], duty_cycle: float | Float[Array, ''] | None = 0.5, theta: float | Float[Array, ''] | None = 0.0, trans_high: float | Float[Array, ''] | None = 1.0, trans_low: float | Float[Array, ''] | None = 0.0) OpticalWavefront[source]¶
Binary amplitude grating with given duty cycle.
- Parameters:
incoming (OpticalWavefront) – Input field.
period (float) – Period in meters.
duty_cycle (float, optional) – Fraction of period in ‘high’ state (0..1), by default 0.5.
theta (float, optional) – Orientation (radians), by default 0.0.
trans_high (float, optional) – Amplitude transmittance for ‘high’ bars, by default 1.0.
trans_low (float, optional) – Amplitude transmittance for ‘low’ bars, by default 0.0.
- Returns:
Field after amplitude modulation.
- Return type:
OpticalWavefront
Notes
Compute u along grating direction.
Map u modulo period → binary mask via duty cycle.
Apply amplitude levels to field.
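A minimal sketch of these three steps, assuming the grating coordinate is u = x cos θ + y sin θ (the exact convention used by the library is not stated here) and using a hypothetical helper name:

```python
import jax.numpy as jnp

def binary_grating_mask(xx, yy, period, duty_cycle=0.5, theta=0.0,
                        trans_high=1.0, trans_low=0.0):
    """Hypothetical helper: two-level amplitude mask along direction theta."""
    # Coordinate along the assumed grating direction.
    u = xx * jnp.cos(theta) + yy * jnp.sin(theta)
    # Fractional position within the current period, in [0, 1).
    frac = jnp.mod(u, period) / period
    # 'High' for the first duty_cycle fraction of each period, 'low' otherwise.
    return jnp.where(frac < duty_cycle, trans_high, trans_low)

x = jnp.linspace(-1e-3, 1e-3, 65)
xx, yy = jnp.meshgrid(x, x)
mask = binary_grating_mask(xx, yy, period=2e-4)
```

Multiplying a complex field by `mask` applies the amplitude levels; `jnp.mod` returns a non-negative remainder for a positive period, so negative coordinates fold into the same [0, 1) range.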
- janssen.optics.apply_phase_mask(incoming: OpticalWavefront, phase_map: Float[Array, 'hh ww']) OpticalWavefront[source]¶
Apply an arbitrary phase mask (e.g., SLM, turbulence screen).
Field_out = field_in * exp(i * phase_map).
- Parameters:
incoming (OpticalWavefront) – Input field.
phase_map (Float[Array," hh ww"]) – Phase in radians, same spatial shape as field.
- Returns:
masked_wavefront – Field with added phase.
- Return type:
OpticalWavefront
- janssen.optics.apply_phase_mask_fn(incoming: OpticalWavefront, phase_fn: Callable[[Float[Array, 'hh ww'], Float[Array, 'hh ww']], Float[Array, 'hh ww']]) OpticalWavefront[source]¶
Build and apply a phase mask from a callable phase_fn(xx, yy).
- Parameters:
incoming (OpticalWavefront) – Input field.
phase_fn (callable) – Function producing a phase map (radians) given centered grids xx, yy (meters).
- Returns:
masked_wavefront – Field with added phase.
- Return type:
OpticalWavefront
- janssen.optics.beam_splitter(incoming: OpticalWavefront, t2: float | Float[Array, ''] | None = 0.5, r2: float | Float[Array, ''] | None = 0.5, normalize: bool | Bool[Array, ''] | None = True) tuple[OpticalWavefront, OpticalWavefront][source]¶
Split an input field into transmitted and reflected components.
- Parameters:
incoming (OpticalWavefront) – Input wavefront (scalar field).
t2 (float, optional) – Complex transmission amplitude, by default jnp.sqrt(0.5).
r2 (float, optional) – Complex reflection amplitude. Default 1j * jnp.sqrt(0.5) for 50/50 convention.
normalize (bool, optional) – If True, scale (t, r) so that |t|^2 + |r|^2 = 1, by default True.
- Return type:
tuple[OpticalWavefront,OpticalWavefront]
- Returns:
wf_T (OpticalWavefront) – Transmitted arm (t * field).
wf_R (OpticalWavefront) – Reflected arm (r * field).
Notes
Optionally renormalize (t, r).
Multiply field by t and r.
Return two wavefronts sharing same metadata.
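The renormalize-and-multiply logic can be sketched as follows. This operates on a bare complex array with a hypothetical function name, and it uses the amplitude defaults quoted in the parameter descriptions (sqrt(0.5) and 1j * sqrt(0.5)).

```python
import jax.numpy as jnp

def split_field(field, t=jnp.sqrt(0.5), r=1j * jnp.sqrt(0.5), normalize=True):
    """Hypothetical helper: return (transmitted, reflected) copies of a field."""
    if normalize:  # plain Python bool here; a traced flag would need lax.cond
        norm = jnp.sqrt(jnp.abs(t) ** 2 + jnp.abs(r) ** 2)
        t, r = t / norm, r / norm
    return t * field, r * field

field = jnp.ones((8, 8), dtype=jnp.complex64)
# Unnormalized amplitudes (3, 4) are rescaled to (0.6, 0.8).
wf_t, wf_r = split_field(field, t=3.0, r=4.0)
total = jnp.sum(jnp.abs(wf_t) ** 2) + jnp.sum(jnp.abs(wf_r) ** 2)
```

With normalization enabled, the summed power of the two arms equals the input power, which is what the |t|^2 + |r|^2 = 1 condition guarantees.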
- janssen.optics.half_waveplate(incoming: OpticalWavefront, theta: float | Float[Array, ''] | None = 0.0) OpticalWavefront[source]¶
Apply a half-wave plate (δ = π) with fast-axis angle theta.
- Parameters:
incoming (OpticalWavefront) – Vector field Complex[H, W, 2] (Jones: ex, ey).
theta (float, optional) – Fast-axis angle in radians (CCW from x), by default 0.0.
- Returns:
hw_wavefront – Retarded field after half-wave plate.
- Return type:
OpticalWavefront
Notes
Call waveplate_jones with delta = π.
- janssen.optics.mirror_reflection(incoming: OpticalWavefront, flip_x: bool | Bool[Array, ''] | None = True, flip_y: bool | Bool[Array, ''] | None = False, add_pi_phase: bool | Bool[Array, ''] | None = True, conjugate: bool | Bool[Array, ''] | None = True) OpticalWavefront[source]¶
Mirror reflection: coordinate flips with optional π-phase and conjugation.
- Parameters:
incoming (OpticalWavefront) – Input wavefront.
flip_x (bool, optional) – Flip along x-axis (columns), by default True.
flip_y (bool, optional) – Flip along y-axis (rows), by default False.
add_pi_phase (bool, optional) – Multiply by exp(i*pi) = -1 to simulate phase inversion on reflection, by default True.
conjugate (bool, optional) – Conjugate the complex field, useful when reversing propagation direction, by default True.
- Returns:
Reflected wavefront.
- Return type:
OpticalWavefront
Notes
Flip axes as requested (jnp.flip).
Optional complex conjugation.
Optional -1 factor for π phase.
- janssen.optics.nd_filter(incoming: OpticalWavefront, optical_density: float | Float[Array, ''] | None = 0.0, transmittance: float | Float[Array, ''] | None = -1.0) OpticalWavefront[source]¶
Neutral density (ND) filter as a uniform amplitude attenuator.
- Parameters:
- Returns:
nd_wavefront – Attenuated wavefront.
- Return type:
OpticalWavefront
Notes
Determine intensity T from OD or provided T.
Amplitude factor a = sqrt(T).
Multiply field by a and return.
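A small sketch of the attenuation arithmetic, using the standard definition of optical density, T = 10^(-OD). Treating a negative transmittance (such as the -1.0 default in the signature) as "not provided" is an assumption made here for illustration, and `nd_amplitude` is a hypothetical name.

```python
import jax.numpy as jnp

def nd_amplitude(optical_density=0.0, transmittance=-1.0):
    """Hypothetical helper: amplitude factor a = sqrt(T) for an ND filter."""
    # Assumption: a negative transmittance means "use the optical density".
    t_intensity = jnp.where(transmittance >= 0.0,
                            transmittance,
                            10.0 ** (-optical_density))
    # Intensity transmittance T maps to amplitude factor sqrt(T).
    return jnp.sqrt(jnp.clip(t_intensity, 0.0, 1.0))

a_od2 = nd_amplitude(optical_density=2.0)   # T = 1%, amplitude factor 0.1
a_t25 = nd_amplitude(transmittance=0.25)    # T = 25%, amplitude factor 0.5
```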
- janssen.optics.phase_grating_blazed_elliptical(incoming: OpticalWavefront, period_x: float | Float[Array, ''], period_y: float | Float[Array, ''], theta: float | Float[Array, ''] | None = 0.0, depth: float | Float[Array, ''] | None = 6.283185307179586, two_dim: bool | Bool[Array, ''] | None = False) OpticalWavefront[source]¶
Orientation-aware elliptical blazed grating.
Supports anisotropic periods along rotated axes (x’, y’) and optional 2D blaze.
- Parameters:
incoming (OpticalWavefront) – Input scalar wavefront.
period_x (float) – Blaze period along x’ in meters (after rotation by theta).
period_y (float) – Blaze period along y’ in meters (after rotation by theta).
theta (float, optional) – Grating orientation angle in radians (CCW from x), by default 0.0.
depth (float, optional) – Peak-to-peak phase depth in radians, by default 2π.
two_dim (bool, optional) – If False (default), apply a 1D blaze along x’ only. If True, create a 2D blazed lattice using both x’ and y’.
- Returns:
phase_grating_wavefront – Field after applying the elliptical blazed phase.
- Return type:
OpticalWavefront
Notes
Build centered grids xx, yy (meters) and rotate → (x’, y’).
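Compute fractional coordinates:
\[fu = \operatorname{frac}(x'/\mathrm{period\_x}), \qquad fv = \operatorname{frac}(y'/\mathrm{period\_y})\]
If two_dim is True:
\[\mathrm{phase} = \mathrm{depth} \cdot \operatorname{frac}(fu + fv)\]
Otherwise:
\[\mathrm{phase} = \mathrm{depth} \cdot fu\]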
Multiply by exp(i * phase) and return.
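The fractional-coordinate construction can be sketched like this. The rotation convention and the helper name `blazed_phase` are assumptions for illustration; frac(·) is implemented as a modulo by 1.

```python
import jax.numpy as jnp

def blazed_phase(xx, yy, period_x, period_y, theta=0.0,
                 depth=2.0 * jnp.pi, two_dim=False):
    """Hypothetical helper: blazed phase map in radians."""
    c, s = jnp.cos(theta), jnp.sin(theta)
    xp = c * xx + s * yy            # assumed rotated coordinate x'
    yp = -s * xx + c * yy           # assumed rotated coordinate y'
    fu = jnp.mod(xp / period_x, 1.0)
    fv = jnp.mod(yp / period_y, 1.0)
    frac = jnp.mod(fu + fv, 1.0) if two_dim else fu
    return depth * frac

x = jnp.linspace(-1e-3, 1e-3, 65)
xx, yy = jnp.meshgrid(x, x)
phase = blazed_phase(xx, yy, period_x=2e-4, period_y=4e-4)
field_out = jnp.ones_like(xx, dtype=jnp.complex64) * jnp.exp(1j * phase)
```

Because the element only adds phase, the field magnitude is unchanged; the 1D case ramps linearly from 0 to depth over each period along x’.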
- janssen.optics.phase_grating_sawtooth(incoming: OpticalWavefront, period: float | Float[Array, ''], depth: float | Float[Array, ''], theta: float | Float[Array, ''] = 0.0) OpticalWavefront[source]¶
Sawtooth phase grating with peak-to-peak depth (radians).
- Parameters:
- Returns:
grating – Field after blazed phase modulation.
- Return type:
OpticalWavefront
Notes
Compute fractional coordinate within each period.
Sawtooth phase in [0, depth) → shift to mean-zero if desired (kept at [0, depth)).
Apply phase with exp(i*phase).
- janssen.optics.phase_grating_sine(incoming: OpticalWavefront, period: float | Float[Array, ''], depth: float | Float[Array, ''], theta: float | Float[Array, ''] | None = 0.0) OpticalWavefront[source]¶
Sinusoidal phase grating.
Phase = depth * sin(2π * u / period), where u is the coordinate along the grating direction.
- Parameters:
- Returns:
Field after phase modulation.
- Return type:
OpticalWavefront
- janssen.optics.polarizer_jones(incoming: OpticalWavefront, theta: float | Float[Array, ''] = 0.0) OpticalWavefront[source]¶
Linear polarizer at angle theta (radians, CCW from x-axis).
Applied to a 2-component Jones field (ex, ey) stored in the last dimension.
- Parameters:
incoming (OpticalWavefront) – Field shape must be Complex[H, W, 2].
theta (float, optional) – Transmission axis angle (radians), by default 0.0.
- Returns:
Polarized field with same shape.
- Return type:
OpticalWavefront
Notes
Jones matrix: P = R(-θ) @ [[1, 0],[0, 0]] @ R(θ).
Apply P to [ex, ey] at each pixel.
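The Jones-matrix product above can be checked numerically. This sketch assumes R(θ) = [[cos θ, sin θ], [−sin θ, cos θ]], under which P reduces to the usual projector onto the transmission axis; the helper names are hypothetical.

```python
import jax.numpy as jnp

def rot(theta):
    """Assumed rotation convention R(theta) = [[c, s], [-s, c]]."""
    c, s = jnp.cos(theta), jnp.sin(theta)
    return jnp.array([[c, s], [-s, c]])

def polarizer_matrix(theta):
    """P = R(-theta) @ diag(1, 0) @ R(theta)."""
    return rot(-theta) @ jnp.array([[1.0, 0.0], [0.0, 0.0]]) @ rot(theta)

def apply_jones(jones, field):
    """Apply a 2x2 Jones matrix to the last axis of a (H, W, 2) field."""
    return jnp.einsum("ij,hwj->hwi", jones, field)

# x-polarized input field of unit amplitude.
field = jnp.zeros((4, 4, 2), dtype=jnp.complex64).at[..., 0].set(1.0)
out45 = apply_jones(polarizer_matrix(jnp.pi / 4), field)
out90 = apply_jones(polarizer_matrix(jnp.pi / 2), field)
intensity45 = jnp.sum(jnp.abs(out45) ** 2, axis=-1)
```

The 45° case passes half the intensity and the crossed 90° case passes essentially none, consistent with Malus's law (cos² of the angle between input polarization and transmission axis).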
- janssen.optics.prism_phase_ramp(incoming: OpticalWavefront, deflect_x: float | Float[Array, ''] | None = 0.0, deflect_y: float | Float[Array, ''] | None = 0.0, use_small_angle: bool | Bool[Array, ''] | None = True) OpticalWavefront[source]¶
Apply a linear phase ramp to simulate a prism-induced beam deviation.
- Parameters:
incoming (OpticalWavefront) – Input scalar wavefront.
deflect_x (float, optional) – Deflection along +x. If use_small_angle is True, interpreted as angle (rad). Otherwise interpreted as spatial frequency kx [rad/m], by default 0.0.
deflect_y (float, optional) – Deflection along +y (angle or ky), by default 0.0.
use_small_angle (bool, optional) – If True, convert small angles to kx, ky via k*sin(angle) ~ k*angle, by default True.
- Returns:
output_phasefront – Wavefront with added linear phase.
- Return type:
OpticalWavefront
Notes
Build xx, yy grids (m).
Compute kx, ky from deflections.
Phase = kx*xx + ky*yy; multiply by exp(i*phase).
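A brief sketch of the phase-ramp computation. It assumes k = 2π/λ with the wavelength passed in explicitly (in the library it would presumably come from the wavefront), and `prism_phase` is a hypothetical helper.

```python
import jax.numpy as jnp

def prism_phase(xx, yy, wavelength, deflect_x=0.0, deflect_y=0.0,
                use_small_angle=True):
    """Hypothetical helper: linear phase kx*xx + ky*yy in radians."""
    k = 2.0 * jnp.pi / wavelength
    # Small-angle mode: k * sin(angle) ~ k * angle. Otherwise the inputs
    # are taken directly as spatial frequencies in rad/m.
    kx = k * deflect_x if use_small_angle else deflect_x
    ky = k * deflect_y if use_small_angle else deflect_y
    return kx * xx + ky * yy

x = jnp.linspace(-1e-3, 1e-3, 65)
xx, yy = jnp.meshgrid(x, x)
phase = prism_phase(xx, yy, wavelength=633e-9, deflect_x=1e-3)
field_out = jnp.ones_like(xx, dtype=jnp.complex64) * jnp.exp(1j * phase)
phase_raw = prism_phase(xx, yy, 633e-9, deflect_x=1000.0, use_small_angle=False)
```

For a 1 mrad deflection at 633 nm, kx is about 9926 rad/m, so the phase advances by roughly 4.96 rad over 0.5 mm.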
- janssen.optics.quarter_waveplate(incoming: OpticalWavefront, theta: float | Float[Array, ''] | None = 0.0) OpticalWavefront[source]¶
Apply a quarter-wave plate (δ = π/2) with fast-axis angle theta.
- Parameters:
incoming (OpticalWavefront) – Vector field Complex[H, W, 2] (Jones: ex, ey).
theta (float, optional) – Fast-axis angle in radians (CCW from x), by default 0.0.
- Returns:
qw_wavefront – Retarded field after quarter-wave plate.
- Return type:
OpticalWavefront
Notes
Call waveplate_jones with delta = π/2.
- janssen.optics.waveplate_jones(incoming: OpticalWavefront, delta: float | Float[Array, ''], theta: float | Float[Array, ''] = 0.0) OpticalWavefront[source]¶
Waveplate/retarder with retardance delta and fast-axis angle theta.
Special cases: quarter-wave (delta=π/2), half-wave (delta=π).
- Parameters:
- Returns:
jones_wavefront – Retarded field with same shape.
- Return type:
OpticalWavefront
Notes
Jones matrix: J = R(-θ) @ diag(1, e^{iδ}) @ R(θ).
Apply J to [ex, ey] per pixel.
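To make the matrix above concrete, the sketch below builds J for a half-wave and a quarter-wave plate at 45° and applies it to x-polarized light. It assumes R(θ) = [[cos θ, sin θ], [−sin θ, cos θ]]; the function names are hypothetical.

```python
import jax.numpy as jnp

def rot(theta):
    """Assumed rotation convention R(theta) = [[c, s], [-s, c]]."""
    c, s = jnp.cos(theta), jnp.sin(theta)
    return jnp.array([[c, s], [-s, c]])

def waveplate_matrix(delta, theta=0.0):
    """J = R(-theta) @ diag(1, exp(i*delta)) @ R(theta)."""
    retarder = jnp.diag(jnp.array([1.0, jnp.exp(1j * delta)]))
    return rot(-theta) @ retarder @ rot(theta)

ex_in = jnp.array([1.0, 0.0])  # x-polarized Jones vector
half = waveplate_matrix(jnp.pi, jnp.pi / 4) @ ex_in         # delta = pi
quarter = waveplate_matrix(jnp.pi / 2, jnp.pi / 4) @ ex_in  # delta = pi/2
```

The half-wave plate at 45° rotates x-polarized light into y-polarized light, while the quarter-wave plate produces equal-magnitude components with a 90° relative phase (circular polarization). Both matrices are unitary, so total intensity is preserved.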
- janssen.optics.add_phase_screen(field: Num[Array, 'hh ww'], phase: Float[Array, 'hh ww']) Complex[Array, 'H W'][source]¶
Add a phase screen to a complex field.
- Parameters:
field (Num[Array," hh ww"]) – Input complex field.
phase (Float[Array," hh ww"]) – Phase screen to add.
- Returns:
screened_field – Field with phase screen added.
- Return type:
Complex[Array," hh ww"]
Notes
Multiply the input field by the exponential of the phase screen.
Return the screened field.
- janssen.optics.create_spatial_grid(diameter: int | float | complex | Num[Array, ''] | Num[Array, '2'], num_points: int | Int[Array, ''] | Int[Array, '2']) tuple[Float[Array, 'hh ww'], Float[Array, 'hh ww']][source]¶
Create a 2D spatial grid for optical propagation.
- Parameters:
diameter (ScalarNumeric | Num[Array," 2"]) – Physical size of the grid in meters. Can be a scalar (square grid) or array of shape (2,) with [diameter_x, diameter_y] for rectangular grids.
num_points (ScalarInteger | Int[Array," 2"]) – Number of points in each dimension. Can be a scalar (square grid) or array of shape (2,) with [num_points_x, num_points_y] for rectangular grids.
- Return type:
tuple[Float[Array, 'hh ww'],Float[Array, 'hh ww']]
- Returns:
xx (Float[Array," hh ww"]) – X coordinate grid in meters.
yy (Float[Array," hh ww"]) – Y coordinate grid in meters.
Notes
Create a linear space of points along the x-axis.
Create a linear space of points along the y-axis.
Create a meshgrid of spatial coordinates.
Return the meshgrid.
Supports both square and non-square grids without if-else statements.
Examples
Square grid:
>>> xx, yy = create_spatial_grid(1e-3, 256)
Rectangular grid:
>>> grid_size = jnp.asarray((256, 512), dtype=jnp.int32)
>>> xx, yy = create_spatial_grid(jnp.array([1e-3, 2e-3]), grid_size)
- janssen.optics.field_intensity(field: Complex[Array, 'hh ww']) Float[Array, 'hh ww'][source]¶
Calculate intensity from complex field.
- Parameters:
field (Complex[Array," hh ww"]) – Input complex field.
- Returns:
intensity – Intensity of the field.
- Return type:
Float[Array," hh ww"]
Notes
Calculate the intensity as the square of the absolute value of the field.
Return the intensity.
- janssen.optics.normalize_field(field: Complex[Array, 'hh ww']) Complex[Array, 'hh ww'][source]¶
Normalize complex field to unit power.
- Parameters:
field (Complex[Array," hh ww"]) – Input complex field.
- Returns:
normalized_field – Normalized complex field.
- Return type:
Complex[Array," hh ww"]
Notes
Calculate the power of the field as the sum of the square of the absolute value of the field.
Normalize the field by dividing by the square root of the power.
Return the normalized field.
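The normalization described above amounts to two lines of array code. The sketch below uses a hypothetical function name and an arbitrary test field:

```python
import jax.numpy as jnp

def normalize_to_unit_power(field):
    """Hypothetical helper: divide by sqrt of the summed intensity."""
    power = jnp.sum(jnp.abs(field) ** 2)
    return field / jnp.sqrt(power)

field = (1.0 + 2.0j) * jnp.ones((16, 16), dtype=jnp.complex64)
normalized = normalize_to_unit_power(field)
total_power = jnp.sum(jnp.abs(normalized) ** 2)
```

After normalization the summed intensity is 1 while the phase of every sample is unchanged. An all-zero input would divide by zero in this sketch; how the library handles that case is not stated here.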
- janssen.optics.scale_pixel(wavefront: OpticalWavefront, new_dx: float | Float[Array, '']) OpticalWavefront[source]¶
Rescale OpticalWavefront pixel size while keeping array shape fixed.
JAX-compatible (jit/vmap-safe). Crops or pads to preserve shape.
- Parameters:
wavefront (OpticalWavefront) – OpticalWavefront to be resized.
new_dx (float) – New pixel size (meters).
- Returns:
scaled_wavefront – Resized OpticalWavefront with updated pixel size and resized field, which is of the same size as the original field.
- Return type:
OpticalWavefront
Notes
If the new pixel size is smaller than the old one, the new FOV is also smaller for the same array shape. The field is first cropped to the new, smaller FOV at the current pixel size, then resized to the new pixel size so that the array shape is unchanged. The order is crop, then resize.
If the new pixel size is larger than the old one, the new FOV is larger too, so the field is padded rather than cropped to preserve the array shape.
Return the resized OpticalWavefront.
- janssen.optics.sellmeier(wavelength_nm: int | float | complex | Num[Array, ''] | Num[Array, 'nn'], sellmeier_b: Num[Array, '3'], sellmeier_c: Num[Array, '3']) Float[Array, ''] | Float[Array, 'nn'][source]¶
Calculate refractive index using the Sellmeier equation.
- Parameters:
wavelength_nm (ScalarNumeric | Num[Array," nn"]) – Wavelength in nanometers. Can be scalar or array.
sellmeier_b (Num[Array," 3"]) – Sellmeier B coefficients [B1, B2, B3].
sellmeier_c (Num[Array," 3"]) – Sellmeier C coefficients [C1, C2, C3] in micrometers squared.
- Returns:
n – Refractive index. Shape matches input wavelength.
- Return type:
Union[Float[Array," "],Float[Array," nn"]]
Notes
The Sellmeier equation relates refractive index to wavelength:
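\[n^2(\lambda) = 1 + \sum_{i=1}^{3} \frac{B_i\lambda^2}{\lambda^2 - C_i}\]
where \(\lambda\) is in micrometers and \(C_i\) are in micrometers squared. Because wavelength_nm is given in nanometers, a conversion to micrometers (division by 1000) is implied before the formula is evaluated.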
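As a worked instance of the equation, the sketch below evaluates it for the widely published Schott N-BK7 coefficients, which are used here purely as illustrative input. `sellmeier_index` is a hypothetical name, and the nanometer-to-micrometer conversion is done explicitly.

```python
import jax.numpy as jnp

def sellmeier_index(wavelength_nm, b, c):
    """Hypothetical helper: n(lambda) from three-term Sellmeier coefficients."""
    lam2 = (jnp.asarray(wavelength_nm) / 1000.0) ** 2  # nm -> um, then squared
    lam2 = lam2[..., None]                             # broadcast over 3 terms
    n_squared = 1.0 + jnp.sum(b * lam2 / (lam2 - c), axis=-1)
    return jnp.sqrt(n_squared)

# Illustrative coefficients (N-BK7 catalog values); C_i in um^2.
b_bk7 = jnp.array([1.03961212, 0.231792344, 1.01046945])
c_bk7 = jnp.array([0.00600069867, 0.0200179144, 103.560653])

n_d = sellmeier_index(587.6, b_bk7, c_bk7)                      # scalar input
n_fc = sellmeier_index(jnp.array([486.1, 656.3]), b_bk7, c_bk7)  # array input
```

The scalar call gives an index close to 1.5168 at 587.6 nm, and the array call shows normal dispersion (higher index at the shorter wavelength).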
- janssen.optics.apply_aberration(incoming: OpticalWavefront, coefficients: Float[Array, 'N'], pupil_radius: float | Float[Array, '']) OpticalWavefront[source]¶
Apply Zernike aberrations to an optical wavefront.
- Parameters:
incoming (OpticalWavefront) – Input wavefront
coefficients (Float[Array," N"]) – Noll-indexed Zernike coefficients in waves (index i = Noll index i+1)
pupil_radius (float) – Pupil radius in meters
- Returns:
wavefront_out – Aberrated wavefront
- Return type:
OpticalWavefront
- janssen.optics.astigmatism(xx: Float[Array, 'H W'], yy: Float[Array, 'H W'], amplitude_0: float | Float[Array, ''], amplitude_45: float | Float[Array, ''], pupil_radius: float | Float[Array, '']) Float[Array, 'H W'][source]¶
Generate astigmatism aberration (Z5 and Z6 in Noll notation).
- Parameters:
xx (Float[Array," H W"]) – X coordinate grid in meters
yy (Float[Array," H W"]) – Y coordinate grid in meters
amplitude_0 (float) – Vertical/horizontal astigmatism amplitude in waves (Z6)
amplitude_45 (float) – Oblique astigmatism amplitude in waves (Z5)
pupil_radius (float) – Pupil radius in meters
- Returns:
phase – Astigmatism phase map in radians
- Return type:
Float[Array," H W"]
- janssen.optics.coma(xx: Float[Array, 'H W'], yy: Float[Array, 'H W'], amplitude_x: float | Float[Array, ''], amplitude_y: float | Float[Array, ''], pupil_radius: float | Float[Array, '']) Float[Array, 'H W'][source]¶
Generate coma aberration (Z7 and Z8 in Noll notation).
- Parameters:
- Returns:
phase – Coma phase map in radians
- Return type:
Float[Array," H W"]
- janssen.optics.compute_phase_from_coeffs(rho: Float[Array, '*batch'], theta: Float[Array, '*batch'], coefficients: Float[Array, 'N'], start_noll: int = 4) Float[Array, '*batch'][source]¶
Compute phase map from Zernike coefficients.
Generates a phase aberration map by summing normalized Zernike polynomials weighted by the provided coefficients. The coefficients are mapped to consecutive Noll indices starting from start_noll.
- Parameters:
rho (Float[Array," *batch"]) – Normalized radial coordinate (0 to 1)
theta (Float[Array," *batch"]) – Azimuthal angle in radians
coefficients (Float[Array," N"]) – Zernike coefficients in waves. Element i corresponds to Noll index (start_noll + i).
start_noll (int, optional) – Starting Noll index for the coefficients, by default 4 (defocus). Common choices: 1 (piston), 4 (defocus, skipping tip/tilt).
- Returns:
Phase map in radians
- Return type:
Float[Array," *batch"]
Notes
- The phase is computed as:
phase = 2 * pi * sum_i(coefficients[i] * Z_{start_noll + i})
where Z_j is the normalized Zernike polynomial for Noll index j. The output is in radians, with coefficients interpreted as waves.
Examples
>>> # Compute phase for defocus through spherical aberration (j=4 to j=11)
>>> coeffs = jnp.array([0.5, 0.1, -0.2, 0.0, 0.0, 0.0, 0.0, 0.3])
>>> phase = compute_phase_from_coeffs(rho, theta, coeffs, start_noll=4)
- janssen.optics.defocus(xx: Float[Array, 'hh ww'], yy: Float[Array, 'hh ww'], amplitude: float | Float[Array, ''], pupil_radius: float | Float[Array, '']) Float[Array, 'hh ww'][source]¶
Generate defocus aberration (Z4 in Noll notation).
- Parameters:
- Returns:
phase – Defocus phase map in radians
- Return type:
Float[Array," hh ww"]
- janssen.optics.factorial(n: Int[Array, '']) Int[Array, ''][source]¶
JAX-compatible factorial computation.
- Parameters:
n (Int[Array," "]) – Non-negative integer
- Returns:
n! (n factorial)
- Return type:
Int[Array," "]
- janssen.optics.generate_aberration_nm(xx: Float[Array, 'H W'], yy: Float[Array, 'H W'], n_indices: Int[Array, 'N'], m_indices: Int[Array, 'N'], coefficients: Float[Array, 'N'], pupil_radius: float | Float[Array, '']) Float[Array, 'H W'][source]¶
Generate aberration from (n,m) indices and coefficients.
- Parameters:
xx (Float[Array," H W"]) – X coordinate grid in meters
yy (Float[Array," H W"]) – Y coordinate grid in meters
n_indices (Int[Array," N"]) – Array of radial orders
m_indices (Int[Array," N"]) – Array of azimuthal frequencies
coefficients (Float[Array," N"]) – Zernike coefficients in waves
pupil_radius (float) – Pupil radius in meters
- Returns:
phase_radians – Phase aberration map in radians
- Return type:
Float[Array," H W"]
Notes
This version is fully JAX-compatible and can be JIT-compiled. Uses jax.lax.scan for efficient accumulation with traced-compatible Zernike polynomial computation.
- janssen.optics.generate_aberration_noll(xx: Float[Array, 'hh ww'], yy: Float[Array, 'hh ww'], coefficients: Float[Array, 'nn'], pupil_radius: float | Float[Array, '']) Float[Array, 'hh ww'][source]¶
Generate aberration from Noll-indexed coefficients.
- Parameters:
xx (Float[Array," hh ww"]) – X coordinate grid in meters
yy (Float[Array," hh ww"]) – Y coordinate grid in meters
coefficients (Float[Array," nn"]) – Zernike coefficients in waves, indexed by Noll index. Element 0 corresponds to j=1 (piston), element 1 to j=2, etc.
pupil_radius (float) – Pupil radius in meters
- Returns:
phase – Phase aberration map in radians
- Return type:
Float[Array," hh ww"]
Notes
Converts Noll indices to (n,m) pairs and calls generate_aberration_nm. Uses vectorized JAX operations for the Noll-to-nm conversion. Sign convention: j even -> m >= 0 (cosine), j odd -> m <= 0 (sine).
The radial order n is computed from n(n+1)/2 < j <= (n+1)(n+2)/2. The position k within row n determines |m|, which follows the pattern: 0,2,2,4,4,… for n even and 1,1,3,3,5,5,… for n odd.
- janssen.optics.nm_to_noll(n: int, m: int) int[source]¶
Convert (n, m) indices to Noll index.
- Parameters:
n (int) – Radial order
m (int) – Azimuthal frequency
- Returns:
Noll index (starting from 1)
- Return type:
int
Notes
Sign convention: j even -> m >= 0 (cosine), j odd -> m <= 0 (sine).
The first Noll index for row n is j_base = n(n+1)/2 + 1.
For m=0, the position k within the row is 0. For m!=0, find the pair of k values for the given |m|, then select based on the sign of m and the parity requirement.
For n even: |m| values are 0,2,4,…; group index g = |m|/2; k_first = 2g-1 for g>0, else 0. For n odd: |m| values are 1,3,5,…; group index g = (|m|-1)/2; k_first = 2g.
The final k is chosen such that m > 0 yields an even j, and m < 0 yields an odd j.
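The steps above can be sketched in plain Python. This is only an illustration of the documented rule, not the library's implementation: the name `nm_to_noll_sketch` is hypothetical, and the sketch assumes a valid (n, m) pair without checking it.

```python
def nm_to_noll_sketch(n: int, m: int) -> int:
    """Map (n, m) to a Noll index j using the rule described in the notes."""
    j_base = n * (n + 1) // 2 + 1   # first Noll index in row n
    if m == 0:
        return j_base               # m = 0 sits at position k = 0
    if n % 2 == 0:
        g = abs(m) // 2             # |m| = 2, 4, ... gives g = 1, 2, ...
        k_first = 2 * g - 1
    else:
        g = (abs(m) - 1) // 2       # |m| = 1, 3, ... gives g = 0, 1, ...
        k_first = 2 * g
    j = j_base + k_first            # first of the two candidates for this |m|
    wants_even = m > 0              # m > 0 must land on an even j
    if (j % 2 == 0) != wants_even:
        j += 1                      # otherwise take the second candidate
    return j


print(nm_to_noll_sketch(2, 0))    # defocus: 4
print(nm_to_noll_sketch(3, 1))    # coma with m > 0: 8
print(nm_to_noll_sketch(4, 0))    # primary spherical: 11
```

The two candidates for a given |m| are always consecutive integers, so exactly one of them has the parity required by the sign of m.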
- janssen.optics.noll_to_nm(j: int | Int[Array, '']) → tuple[int, int]
Convert Noll index to (n, m) indices.
- Parameters:
j (int) – Noll index (starting from 1)
- Returns:
(n, m) – Radial order and azimuthal frequency
- Return type:
tuple[int, int]
Notes
Uses the standard Noll ordering where j=1 corresponds to piston (n=0, m=0). Sign convention: j even -> m >= 0 (cosine), j odd -> m <= 0 (sine).
The radial order n is found from the cumulative count relation: n(n+1)/2 < j <= (n+1)(n+2)/2.
Within each row n, the position k (0-indexed) determines |m|. For n even: |m| follows pattern 0,2,2,4,4,… For n odd: |m| follows pattern 1,1,3,3,5,5,…
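The conversion described above can be written out as a short plain-Python sketch. The name `noll_to_nm_sketch` is hypothetical and the loop is chosen for readability; the library itself is documented as JAX-compatible, so its implementation presumably differs.

```python
def noll_to_nm_sketch(j: int) -> tuple[int, int]:
    """Map a Noll index j (starting from 1) to (n, m)."""
    # Find the row n satisfying n(n+1)/2 < j <= (n+1)(n+2)/2.
    n = 0
    while (n + 1) * (n + 2) // 2 < j:
        n += 1
    k = j - n * (n + 1) // 2 - 1          # 0-indexed position within row n
    if n % 2 == 0:
        abs_m = 2 * ((k + 1) // 2)        # pattern 0, 2, 2, 4, 4, ...
    else:
        abs_m = 2 * (k // 2) + 1          # pattern 1, 1, 3, 3, 5, 5, ...
    m = abs_m if j % 2 == 0 else -abs_m   # even j -> cosine, odd j -> sine
    return n, m


print([noll_to_nm_sketch(j) for j in range(1, 7)])
# [(0, 0), (1, 1), (1, -1), (2, 0), (2, -2), (2, 2)]
```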
- janssen.optics.phase_rms(rho: Float[Array, '*batch'], theta: Float[Array, '*batch'], coefficients: Float[Array, 'N'], start_noll: int = 4) → Float[Array, '']
Compute RMS of phase within the unit pupil.
Calculates the root-mean-square of the phase aberration within the region where rho <= 1.0 (the unit pupil).
- Parameters:
rho (Float[Array," *batch"]) – Normalized radial coordinate (0 to 1)
theta (Float[Array," *batch"]) – Azimuthal angle in radians
coefficients (Float[Array," N"]) – Zernike coefficients in waves. Element i corresponds to Noll index (start_noll + i).
start_noll (int, optional) – Starting Noll index for the coefficients, by default 4 (defocus).
- Returns:
RMS phase value in radians
- Return type:
Float[Array," "]
Notes
- The RMS is computed as:
RMS = sqrt(mean((phase - mean(phase))^2))
where the mean is taken only over pixels within the unit pupil (rho <= 1). The piston (mean phase) is subtracted before computing RMS.
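The formula can be illustrated on a handful of samples. The sketch below is plain Python with a hypothetical helper `pupil_rms`; it takes an already-computed phase rather than Zernike coefficients, so it shows only the masking and piston-removal step, not the full `phase_rms` signature.

```python
import math


def pupil_rms(phase, rho):
    """RMS of the piston-removed phase over samples with rho <= 1."""
    inside = [p for p, r in zip(phase, rho) if r <= 1.0]   # keep pupil samples
    mean = sum(inside) / len(inside)                        # piston term
    return math.sqrt(sum((p - mean) ** 2 for p in inside) / len(inside))


phase = [1.0, 2.0, 3.0, 100.0]   # last sample lies outside the pupil
rho = [0.1, 0.5, 0.9, 1.2]
print(pupil_rms(phase, rho))     # sqrt(2/3), about 0.816
```

Filtering by list comprehension changes the number of samples, which boolean indexing cannot do under `jax.jit`; a JIT-compatible version would more likely weight each sample by a 0/1 mask.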
Examples
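>>> import jax.numpy as jnp
>>> from janssen.optics import phase_rms
>>> # Compute RMS for a set of aberration coefficients (Noll indices 4 through 11).
>>> # rho and theta are assumed to be precomputed normalized polar coordinate arrays.
>>> coeffs = jnp.array([0.5, 0.1, -0.2, 0.0, 0.0, 0.0, 0.0, 0.3])
>>> rms = phase_rms(rho, theta, coeffs, start_noll=4)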
- janssen.optics.spherical_aberration(xx: Float[Array, 'H W'], yy: Float[Array, 'H W'], amplitude: float | Float[Array, ''], pupil_radius: float | Float[Array, '']) → Float[Array, 'H W']
Generate primary spherical aberration (Z11 in Noll notation).
- Parameters:
- Returns:
phase – Spherical aberration phase map in radians
- Return type:
Float[Array," H W"]
- janssen.optics.trefoil(xx: Float[Array, 'H W'], yy: Float[Array, 'H W'], amplitude_0: float | Float[Array, ''], amplitude_30: float | Float[Array, ''], pupil_radius: float | Float[Array, '']) → Float[Array, 'H W']
Generate trefoil aberration (Z9 and Z10 in Noll notation).
- Parameters:
- Returns:
trefoil_wavefront – Trefoil phase map in radians
- Return type:
Float[Array," H W"]
Notes
The phase map combines two Zernike polynomials: Z9 (vertical trefoil) and Z10 (oblique trefoil).
- janssen.optics.zernike_polynomial(rho: Float[Array, '*batch'], theta: Float[Array, '*batch'], n: int, m: int, normalize: bool = True) → Float[Array, '*batch']
Generate a single Zernike polynomial.
- Parameters:
rho (Float[Array," *batch"]) – Normalized radial coordinate (0 to 1)
theta (Float[Array," *batch"]) – Azimuthal angle in radians
n (int) – Radial order (n >= 0)
m (int) – Azimuthal frequency (|m| <= n, n-|m| must be even)
normalize (bool, optional) – Whether to normalize for unit RMS over unit circle, by default True
- Returns:
Zernike polynomial Z_n^m(rho, theta)
- Return type:
Float[Array," *batch"]
Notes
The polynomial is zero outside the unit circle (rho > 1). Normalization follows the convention where RMS over unit circle = 1. Angular part uses cosine for m>0, sine for m<0, and 1 for m=0. Normalization factor is sqrt(n+1) for m=0 and sqrt(2*(n+1)) for m≠0.
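The normalization claim can be checked numerically. The sketch below is plain Python with hypothetical helper names; it hard-codes the well-known radial parts of defocus (2ρ² − 1) and coma (3ρ³ − 2ρ), assumes the sine branch uses sin(|m|θ), and integrates over the unit disk with a simple midpoint rule.

```python
import math


def norm_factor(n: int, m: int) -> float:
    """sqrt(n+1) for m = 0, sqrt(2(n+1)) otherwise, as stated in the notes."""
    return math.sqrt(n + 1) if m == 0 else math.sqrt(2 * (n + 1))


def angular(m: int, theta: float) -> float:
    """Cosine for m > 0, sine for m < 0 (assumed sin(|m| theta)), 1 for m = 0."""
    if m > 0:
        return math.cos(m * theta)
    if m < 0:
        return math.sin(-m * theta)
    return 1.0


def defocus(rho: float, theta: float) -> float:
    """Normalized Z_2^0; zero outside the unit circle."""
    if rho > 1.0:
        return 0.0
    return norm_factor(2, 0) * (2 * rho**2 - 1) * angular(0, theta)


def coma(rho: float, theta: float) -> float:
    """Normalized Z_3^1; zero outside the unit circle."""
    if rho > 1.0:
        return 0.0
    return norm_factor(3, 1) * (3 * rho**3 - 2 * rho) * angular(1, theta)


def rms_over_unit_disk(func, n_r: int = 200, n_t: int = 64) -> float:
    """Midpoint-rule RMS in polar coordinates, weighting each sample by rho."""
    total = 0.0
    weight = 0.0
    for i in range(n_r):
        rho = (i + 0.5) / n_r
        for k in range(n_t):
            theta = 2 * math.pi * (k + 0.5) / n_t
            total += rho * func(rho, theta) ** 2
            weight += rho
    return math.sqrt(total / weight)


# Both values should be close to 1.0.
print(rms_over_unit_disk(defocus), rms_over_unit_disk(coma))
```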
- janssen.optics.zernike_radial(rho: Float[Array, '*batch'], n: int, m: int) → Float[Array, '*batch']
Compute the radial component of Zernike polynomial.
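- Parameters:
rho (Float[Array," *batch"]) – Normalized radial coordinate (0 to 1)
n (int) – Radial order
m (int) – Azimuthal frequency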
- Returns:
Radial polynomial R_n^|m|(rho)
- Return type:
Float[Array," *batch"]
Notes
Uses JAX-compatible validation that returns zeros for invalid (n,m) combinations where n-|m| is odd. Computes the radial polynomial using the standard formula with factorials for valid combinations. Uses jax.lax.scan for efficient accumulation of terms.
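As a rough illustration, and assuming "the standard formula" means the usual finite sum R_n^|m|(ρ) = Σ_s (−1)^s (n−s)! / [s! ((n+|m|)/2 − s)! ((n−|m|)/2 − s)!] ρ^(n−2s), a scalar plain-Python version might look like the following. The name `zernike_radial_sketch` is hypothetical, a Python loop stands in for jax.lax.scan, and returning zero for |m| > n is an assumption of this sketch rather than documented behaviour.

```python
from math import factorial


def zernike_radial_sketch(rho: float, n: int, m: int) -> float:
    """Radial polynomial R_n^|m|(rho) from the factorial sum."""
    m = abs(m)
    # Documented: zero when n - |m| is odd. Assumed here: zero when |m| > n.
    if m > n or (n - m) % 2 != 0:
        return 0.0
    total = 0.0
    for s in range((n - m) // 2 + 1):
        numerator = (-1) ** s * factorial(n - s)
        denominator = (
            factorial(s)
            * factorial((n + m) // 2 - s)
            * factorial((n - m) // 2 - s)
        )
        total += numerator / denominator * rho ** (n - 2 * s)
    return total


print(zernike_radial_sketch(0.5, 2, 0))   # 2*rho^2 - 1 at rho = 0.5: -0.5
print(zernike_radial_sketch(0.5, 3, 1))   # 3*rho^3 - 2*rho at rho = 0.5: -0.625
```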