Source code for janssen.optics.apertures

"""Aperture functions for optical simulations.

Extended Summary
----------------
Optical aperture and apodization functions for controlling the amplitude
and phase of optical wavefronts. Includes both hard apertures and smooth
apodization functions commonly used in optical systems.

Routine Listings
----------------
circular_aperture : function
    Applies a circular aperture (optionally offset) with uniform
    transmittivity.
rectangular_aperture : function
    Applies an axis-aligned rectangular aperture with uniform
    transmittivity.
annular_aperture : function
    Applies a concentric ring (donut) aperture between inner/outer
    diameters.
variable_transmission_aperture : function
    Applies an arbitrary transmission mask (array or callable),
    including common apodizers such as Gaussian or super-Gaussian
gaussian_apodizer : function
    Applies a Gaussian apodizer (smooth transmission mask) to the
    wavefront.
supergaussian_apodizer : function
    Applies a super-Gaussian apodizer (smooth transmission mask) to
    wavefront.
gaussian_apodizer_elliptical : function
    Applies an elliptical Gaussian apodizer to the wavefront
supergaussian_apodizer_elliptical : function
    Applies an elliptical super-Gaussian apodizer to the wavefront
_arrayed_grids : function, internal
    Creates coordinate grids without array creation.

Notes
-----
All aperture functions are compatible with JAX transformations and
support automatic differentiation. The apertures can be combined to
create complex pupil functions for optical systems.
"""

import jax
import jax.numpy as jnp
from beartype import beartype
from beartype.typing import Optional, Tuple, Union
from jaxtyping import Array, Bool, Float, Num, jaxtyped

from janssen.types import (
    OpticalWavefront,
    ScalarFloat,
    ScalarNumeric,
    make_optical_wavefront,
)


def _arrayed_grids(
    x0: Num[Array, " hh ww"],
    y0: Num[Array, " hh ww"],
    dx: Union[ScalarFloat, Num[Array, " 2"]],
) -> Tuple[Float[Array, " hh ww"], Float[Array, " hh ww"]]:
    """Create coordinate grids without array creation.

    Parameters
    ----------
    x0 : Num[Array, " hh ww"]
        Zero-valued input array for x coordinates.
    y0 : Num[Array, " hh ww"]
        Zero-valued input array for y coordinates.
    dx : Union[ScalarFloat, Num[Array, " 2"]]
        Grid spacing in meters. Can be scalar or 2-element array [dx, dy].

    Returns
    -------
    xx : Float[Array, " hh ww"]
        X coordinate grid in meters.
    yy : Float[Array, " hh ww"]
        Y coordinate grid in meters.
    """
    hh: int
    ww: int
    hh, ww = x0.shape
    dx_arr: Union[Num[Array, " "], Num[Array, " 2"]] = jnp.asarray(dx)
    dx_arr = jnp.atleast_1d(dx_arr)
    expected_ndim: int = 2
    dx_2elem = jnp.where(
        dx_arr.size >= expected_ndim,
        dx_arr[:2],
        jnp.array([dx_arr[0], dx_arr[0]]),
    )
    dx_val: Num[Array, " "] = dx_2elem[0]
    dy_val: Num[Array, " "] = dx_2elem[1]

    def x_line(
        arr: Num[Array, " hh ww"], spacing: Num[Array, " "]
    ) -> Num[Array, " hh ww"]:
        arr_x: Num[Array, " ww"] = jnp.arange(-ww // 2, ww // 2) * spacing
        arr_full: Num[Array, " hh ww"] = arr + jnp.repeat(arr_x, hh).reshape(
            hh, ww, order="F"
        )
        return arr_full

    def y_line(
        arr: Num[Array, " hh ww"], spacing: Num[Array, " "]
    ) -> Num[Array, " hh ww"]:
        arr_y: Num[Array, " hh"] = jnp.arange(-hh // 2, hh // 2) * spacing
        arr_full: Num[Array, " hh ww"] = arr + jnp.repeat(arr_y, ww).reshape(
            hh, ww, order="C"
        )
        return arr_full

    xx: Num[Array, " hh ww"] = x_line(x0, dx_val)
    yy: Num[Array, " hh ww"] = y_line(y0, dy_val)
    return (xx, yy)


[docs] @jaxtyped(typechecker=beartype) def circular_aperture( incoming: OpticalWavefront, diameter: ScalarFloat, center: Union[ScalarFloat, Float[Array, " 2"]] = 0.0, transmittivity: Optional[ScalarFloat] = 1.0, ) -> OpticalWavefront: """ Apply a circular aperture to the incoming wavefront. The aperture is defined by its physical diameter and (optional) center. Parameters ---------- incoming : OpticalWavefront Input wavefront PyTree. diameter : ScalarFloat Aperture diameter in meters. center : Float[Array, " 2"], optional Physical center [x0, y0] of the aperture in meters, by default [0, 0]. transmittivity : Optional[ScalarFloat], optional Uniform transmittivity inside the aperture (0..1), by default 1.0. Returns ------- apertured : OpticalWavefront Wavefront after applying the circular aperture. Notes ----- - Build centered (x, y) grids in meters. - Compute radial distance from the specified center. - Create a binary mask for r <= diameter/2. - Multiply by transmittivity (clipped to [0, 1]). - Apply to the complex field and return. """ center_array: Float[Array, " 2"] = jnp.atleast_2d( jnp.asarray(center, dtype=jnp.float64) ).ravel()[:2] arr_zeros: Float[Array, " hh ww"] = jnp.zeros_like( incoming.field, dtype=float ) xx: Float[Array, " hh ww"] yy: Float[Array, " hh ww"] xx, yy = _arrayed_grids(arr_zeros, arr_zeros, incoming.dx) x0: Float[Array, " "] y0: Float[Array, " "] x0, y0 = center_array[0], center_array[1] r: Float[Array, " hh ww"] = jnp.sqrt(((xx - x0) ** 2) + ((yy - y0) ** 2)) inside: Bool[Array, " hh ww"] = r <= (diameter / 2.0) t: Float[Array, " "] = jnp.clip( jnp.asarray(transmittivity, dtype=float), 0.0, 1.0 ) transmission: Float[Array, " hh ww"] = inside.astype(float) * t apertured: OpticalWavefront = make_optical_wavefront( field=incoming.field * transmission, wavelength=incoming.wavelength, dx=incoming.dx, z_position=incoming.z_position, ) return apertured
[docs] @jaxtyped(typechecker=beartype) def rectangular_aperture( incoming: OpticalWavefront, width: ScalarFloat, height: ScalarFloat, center: Union[ScalarFloat, Float[Array, " 2"]] = 0.0, transmittivity: Optional[ScalarFloat] = 1.0, ) -> OpticalWavefront: """ Apply an axis-aligned rectangular aperture to the incoming wavefront. Parameters ---------- incoming : OpticalWavefront Input wavefront PyTree. width : ScalarFloat Rectangle width along x in meters. height : ScalarFloat Rectangle height along y in meters. center : Float[Array, " 2"], optional Rectangle center [x0, y0] in meters, by default [0, 0]. transmittivity : Optional[ScalarFloat], optional Uniform transmittivity inside the rectangle (0..1), by default 1.0. Returns ------- apertured : OpticalWavefront Wavefront after applying the rectangular aperture. Notes ----- - Build centered (x, y) grids in meters. - Compute half-width/half-height and an inside-rectangle mask. - Multiply by transmittivity (clipped). - Apply to the complex field and return. """ center_array: Float[Array, " 2"] = jnp.atleast_2d( jnp.asarray(center, dtype=jnp.float64) ).ravel()[:2] arr_zeros: Float[Array, " hh ww"] = jnp.zeros_like( incoming.field, dtype=float ) xx: Float[Array, " hh ww"] yy: Float[Array, " hh ww"] xx, yy = _arrayed_grids(arr_zeros, arr_zeros, incoming.dx) x0: Float[Array, " "] y0: Float[Array, " "] x0, y0 = center_array[0], center_array[1] hx: Float[Array, " "] = width / 2.0 hy: Float[Array, " "] = height / 2.0 inside_x: Bool[Array, " hh ww"] = ((x0 - hx) <= xx) & ((x0 + hx) >= xx) inside_y: Bool[Array, " hh ww"] = ((y0 - hy) <= yy) & ((y0 + hy) >= yy) inside: Bool[Array, " hh ww"] = inside_x & inside_y t: Float[Array, " "] = jnp.clip( jnp.asarray(transmittivity, dtype=float), 0.0, 1.0 ) transmission: Float[Array, " hh ww"] = inside.astype(float) * t apertured: OpticalWavefront = make_optical_wavefront( field=incoming.field * transmission, wavelength=incoming.wavelength, dx=incoming.dx, z_position=incoming.z_position, ) return apertured
[docs] @jaxtyped(typechecker=beartype) def annular_aperture( incoming: OpticalWavefront, inner_diameter: ScalarFloat, outer_diameter: ScalarFloat, center: Union[ScalarFloat, Float[Array, " 2"]] = 0.0, transmittivity: Optional[ScalarFloat] = 1.0, ) -> OpticalWavefront: """ Apply an annular (ring) aperture with inner and outer diameters. Parameters ---------- incoming : OpticalWavefront Input wavefront PyTree. inner_diameter : ScalarFloat Inner blocked diameter in meters. outer_diameter : ScalarFloat Outer clear aperture diameter in meters. center : Float[Array, " 2"], optional Ring center [x0, y0] in meters, by default [0, 0]. transmittivity : Optional[ScalarFloat], optional Uniform transmittivity in the ring (0..1), by default 1.0. Returns ------- apertured : OpticalWavefront Wavefront after applying the annular aperture. Notes ----- - Build centered (x, y) grids in meters. - Compute radial distance from center. - Create mask for inner_radius < r <= outer_radius. - Multiply by transmittivity (clipped), apply, and return. """ center_array: Float[Array, " 2"] = jnp.atleast_2d( jnp.asarray(center, dtype=jnp.float64) ).ravel()[:2] arr_zeros: Float[Array, " hh ww"] = jnp.zeros_like( incoming.field, dtype=float ) xx: Float[Array, " hh ww"] yy: Float[Array, " hh ww"] xx, yy = _arrayed_grids(arr_zeros, arr_zeros, incoming.dx) x0: Float[Array, " "] y0: Float[Array, " "] x0, y0 = center_array[0], center_array[1] r: Float[Array, " hh ww"] = jnp.sqrt((xx - x0) ** 2 + (yy - y0) ** 2) r_in: Float[Array, " "] = inner_diameter / 2.0 r_out: Float[Array, " "] = outer_diameter / 2.0 ring: Bool[Array, " hh ww"] = (r > r_in) & (r <= r_out) t: Float[Array, " "] = jnp.clip( jnp.asarray(transmittivity, dtype=float), 0.0, 1.0 ) transmission: Float[Array, " hh ww"] = ring.astype(float) * t apertured: OpticalWavefront = make_optical_wavefront( field=incoming.field * transmission, wavelength=incoming.wavelength, dx=incoming.dx, z_position=incoming.z_position, ) return apertured
[docs] @jaxtyped(typechecker=beartype) def variable_transmission_aperture( incoming: OpticalWavefront, transmission: Union[ScalarFloat, Float[Array, " ..."]], ) -> OpticalWavefront: """ Apply an arbitrary (spatially varying) transmission to the wavefront. Parameters ---------- incoming : OpticalWavefront Input wavefront PyTree. transmission : Union[ScalarFloat, Float[Array, " H W"]] Precomputed transmission map (0..1) with shape "H W", or a scalar attenuation factor for uniform transmission. Returns ------- transmitted : OpticalWavefront Wavefront after applying the transmission. Examples -------- Uniform attenuation:: >>> wf2 = variable_transmission_aperture(wf, 0.5) # 50% trans Spatially varying transmission:: >>> tmap = create_transmission_map(...) # Shape (H, W) >>> wf2 = variable_transmission_aperture(wf, tmap) Notes ----- - For scalar transmission: applies uniform attenuation. - For array transmission: applies spatially varying transmission map. - Transmission values are clipped to [0, 1]. - This function is fully JAX-compatible and uses jax.lax.cond. """ trans: Float[Array, " ..."] = jnp.asarray(transmission, dtype=float) def apply_scalar_transmission() -> OpticalWavefront: t: Float[Array, " hh ww"] = jnp.clip(trans, 0.0, 1.0) return make_optical_wavefront( field=incoming.field * t, wavelength=incoming.wavelength, dx=incoming.dx, z_position=incoming.z_position, ) def apply_array_transmission() -> OpticalWavefront: tmap: Float[Array, " hh ww"] = jnp.clip(trans, 0.0, 1.0) return make_optical_wavefront( field=incoming.field * tmap, wavelength=incoming.wavelength, dx=incoming.dx, z_position=incoming.z_position, ) transmitted: OpticalWavefront = jax.lax.cond( trans.ndim == 0, apply_scalar_transmission, apply_array_transmission ) return transmitted
[docs] @jaxtyped(typechecker=beartype) def gaussian_apodizer( incoming: OpticalWavefront, sigma: ScalarFloat, center: Union[ScalarFloat, Float[Array, " 2"]] = 0.0, peak_transmittivity: Optional[ScalarFloat] = 1.0, ) -> OpticalWavefront: """ Apply a Gaussian apodizer (smooth transmission mask) to the wavefront. Parameters ---------- incoming : OpticalWavefront Input optical wavefront. sigma : ScalarFloat Gaussian width parameter in meters. center : Float[Array, " 2"], optional Physical center [x0, y0] of the Gaussian in meters, by default [0, 0]. peak_transmittivity : Optional[ScalarFloat], optional Maximum transmission at the Gaussian center, by default 1.0. Returns ------- apertured : OpticalWavefront Wavefront after applying Gaussian apodization. Notes ----- - Build centered (x, y) grids. - Compute squared radial distance from center. - Evaluate Gaussian exp(-r^2 / (2*sigma^2)). - Scale by peak transmittivity, clip to [0,1]. - Multiply with incoming field and return. """ center_array: Float[Array, " 2"] = jnp.atleast_2d( jnp.asarray(center, dtype=jnp.float64) ).ravel()[:2] arr_zeros: Float[Array, " hh ww"] = jnp.zeros_like( incoming.field, dtype=float ) xx: Float[Array, " hh ww"] yy: Float[Array, " hh ww"] xx, yy = _arrayed_grids(arr_zeros, arr_zeros, incoming.dx) x0: Float[Array, " "] y0: Float[Array, " "] x0, y0 = center_array[0], center_array[1] r2: Float[Array, " hh ww"] = ((xx - x0) ** 2) + ((yy - y0) ** 2) gauss: Float[Array, " hh ww"] = jnp.exp(-r2 / (2.0 * sigma**2)) tmap: Float[Array, " hh ww"] = jnp.clip( gauss * peak_transmittivity, 0.0, 1.0 ) apertured: OpticalWavefront = make_optical_wavefront( field=incoming.field * tmap, wavelength=incoming.wavelength, dx=incoming.dx, z_position=incoming.z_position, ) return apertured
[docs] @jaxtyped(typechecker=beartype) def supergaussian_apodizer( incoming: OpticalWavefront, sigma: ScalarFloat, m: ScalarNumeric, center: Union[ScalarFloat, Float[Array, " 2"]] = 0.0, peak_transmittivity: Optional[ScalarFloat] = 1.0, ) -> OpticalWavefront: """ Apply a super-Gaussian apodizer to the wavefront. Transmission profile: exp(- (r^2 / sigma^2)^m ). Parameters ---------- incoming : OpticalWavefront Input optical wavefront. sigma : ScalarFloat Width parameter in meters (sets the roll-off scale). m : ScalarNumeric Super-Gaussian order (m=1 → Gaussian, m>1 → flatter top). center : Float[Array, " 2"], optional Physical center [x0, y0] of the profile, by default [0, 0]. peak_transmittivity : Optional[ScalarFloat], optional Maximum transmission at the center, by default 1.0. Returns ------- apertured : OpticalWavefront Wavefront after applying super-Gaussian apodization. Notes ----- - Build centered (x, y) grids. - Compute squared radial distance from center. - Evaluate exp(- (r^2 / sigma^2)^m ). - Scale by peak transmittivity, clip to [0,1]. - Multiply with incoming field and return. """ arr_zeros: Float[Array, " hh ww"] = jnp.zeros_like( incoming.field, dtype=float ) xx: Float[Array, " hh ww"] yy: Float[Array, " hh ww"] xx, yy = _arrayed_grids(arr_zeros, arr_zeros, incoming.dx) center_arr = jnp.broadcast_to(jnp.asarray(center), (2,)) x0, y0 = center_arr[0], center_arr[1] r2: Float[Array, " hh ww"] = (xx - x0) ** 2 + (yy - y0) ** 2 super_gauss: Float[Array, " hh ww"] = jnp.exp(-((r2 / (sigma**2)) ** m)) tmap: Float[Array, " hh ww"] = jnp.clip( super_gauss * peak_transmittivity, 0.0, 1.0 ) apertured: OpticalWavefront = make_optical_wavefront( field=incoming.field * tmap, wavelength=incoming.wavelength, dx=incoming.dx, z_position=incoming.z_position, ) return apertured
[docs] @jaxtyped(typechecker=beartype) def gaussian_apodizer_elliptical( incoming: OpticalWavefront, sigma_x: ScalarFloat, sigma_y: ScalarFloat, theta: Optional[ScalarFloat] = 0.0, center: Union[ScalarFloat, Float[Array, " 2"]] = 0.0, peak_transmittivity: Optional[ScalarFloat] = 1.0, ) -> OpticalWavefront: """ Apply an elliptical Gaussian apodizer to the wavefront. With optional rotation, through an angle `theta`. Parameters ---------- incoming : OpticalWavefront Input optical wavefront. sigma_x : ScalarFloat Gaussian width along the x'-axis (meters) after rotation by `theta`. sigma_y : ScalarFloat Gaussian width along the y'-axis (meters) after rotation by `theta`. theta : Optional[ScalarFloat], optional Rotation angle in radians (counter-clockwise), by default 0.0. center : Float[Array, " 2"], optional Physical center [x0, y0] in meters, by default [0, 0]. peak_transmittivity : Optional[ScalarFloat], optional Maximum transmission at the center, by default 1.0. Returns ------- apertured : OpticalWavefront Wavefront after applying elliptical Gaussian apodization. See Also -------- gaussian_apodizer : Apply a Gaussian apodizer (smooth transmission mask) to the wavefront. supergaussian_apodizer : Apply a super-Gaussian apodizer (smooth transmission mask) to the wavefront. Notes ----- - Build centered (x, y) grids. - Translate by `center`, rotate by `theta` → (x', y'). - Evaluate exp(-0.5 * ( (x'/sigma_x)^2 + (y'/sigma_y)^2 )). - Scale by `peak_transmittivity`, clip to [0, 1]. - Multiply with incoming field and return. """ center_array: Float[Array, " 2"] = jnp.atleast_2d( jnp.asarray(center, dtype=jnp.float64) ).ravel()[:2] arr_zeros: Float[Array, " hh ww"] = jnp.zeros_like( incoming.field, dtype=float ) xx: Float[Array, " hh ww"] yy: Float[Array, " hh ww"] xx, yy = _arrayed_grids(arr_zeros, arr_zeros, incoming.dx) x0: Float[Array, " "] y0: Float[Array, " "] x0, y0 = center_array[0], center_array[1] xc: Float[Array, " hh ww"] = xx - x0 yc: Float[Array, " hh ww"] = yy - y0 ct: Float[Array, " "] = jnp.cos(theta) st: Float[Array, " "] = jnp.sin(theta) xp: Float[Array, " hh ww"] = (ct * xc) + (st * yc) yp: Float[Array, " hh ww"] = (ct * yc) - (st * xc) arg: Float[Array, " hh ww"] = ((xp / sigma_x) ** 2) + ((yp / sigma_y) ** 2) gauss: Float[Array, " hh ww"] = jnp.exp(-0.5 * arg) tmap: Float[Array, " hh ww"] = jnp.clip( gauss * peak_transmittivity, 0.0, 1.0 ) apertured: OpticalWavefront = make_optical_wavefront( field=incoming.field * tmap, wavelength=incoming.wavelength, dx=incoming.dx, z_position=incoming.z_position, ) return apertured
[docs] @jaxtyped(typechecker=beartype) def supergaussian_apodizer_elliptical( incoming: OpticalWavefront, sigma_x: ScalarFloat, sigma_y: ScalarFloat, m: ScalarNumeric, theta: Optional[ScalarFloat] = 0.0, center: Union[ScalarFloat, Float[Array, " 2"]] = 0.0, peak_transmittivity: Optional[ScalarFloat] = 1.0, ) -> OpticalWavefront: """ Apply an elliptical super-Gaussian apodizer with optional rotation. Transmission profile: exp( - ( (x'/sigma_x)^2 + (y'/sigma_y)^2 )^m ). Parameters ---------- incoming : OpticalWavefront Input optical wavefront. sigma_x : ScalarFloat Width along x' (meters) after rotation by `theta`. sigma_y : ScalarFloat Width along y' (meters) after rotation by `theta`. m : ScalarNumeric Super-Gaussian order (m=1 → Gaussian; m>1 → flatter top, sharper edges). theta : Optional[ScalarFloat], optional Rotation angle in radians (counter-clockwise), by default 0.0. center : Float[Array, " 2"], optional Physical center [x0, y0] in meters, by default [0, 0]. peak_transmittivity : Optional[ScalarFloat], optional Maximum transmission at the center, by default 1.0. Returns ------- apertured : OpticalWavefront Wavefront after applying elliptical super-Gaussian apodization. Notes ----- - Build centered (x, y) grids. - Translate by `center`, rotate by `theta` → (x', y'). - Evaluate exp( - ( (x'/sigma_x)^2 + (y'/sigma_y)^2 )^m ). - Scale by `peak_transmittivity`, clip to [0, 1]. - Multiply with incoming field and return. """ center_array: Float[Array, " 2"] = jnp.atleast_2d( jnp.asarray(center, dtype=jnp.float64) ).ravel()[:2] arr_zeros: Float[Array, " hh ww"] = jnp.zeros_like( incoming.field, dtype=float ) xx: Float[Array, " hh ww"] yy: Float[Array, " hh ww"] xx, yy = _arrayed_grids(arr_zeros, arr_zeros, incoming.dx) x0: Float[Array, " "] y0: Float[Array, " "] x0, y0 = center_array[0], center_array[1] xc: Float[Array, " hh ww"] = xx - x0 yc: Float[Array, " hh ww"] = yy - y0 ct: Float[Array, " "] = jnp.cos(theta) st: Float[Array, " "] = jnp.sin(theta) xp: Float[Array, " hh ww"] = (ct * xc) + (st * yc) yp: Float[Array, " hh ww"] = (ct * yc) - (st * xc) base: Float[Array, " hh ww"] = ((xp / sigma_x) ** 2) + ( (yp / sigma_y) ** 2 ) super_gauss: Float[Array, " hh ww"] = jnp.exp(-(base**m)) tmap: Float[Array, " hh ww"] = jnp.clip( super_gauss * peak_transmittivity, 0.0, 1.0 ) apertured: OpticalWavefront = make_optical_wavefront( field=incoming.field * tmap, wavelength=incoming.wavelength, dx=incoming.dx, z_position=incoming.z_position, ) return apertured