Source code for janssen.simul.elements

"""Optical element implementations.

Extended Summary
----------------
Optical elements and components for building complex optical systems.
Includes gratings, waveplates, polarizers, beam splitters, and other
common optical components used in microscopy and optical systems.

Routine Listings
----------------
prism_phase_ramp : function
    Applies a linear phase ramp to simulate beam deviation/dispersion
beam_splitter : function
    Splits a field into transmitted and reflected arms with given (t, r)
mirror_reflection : function
    Applies mirror reflection(s): coordinate flip(s), optional conjugation,
    π phase
phase_grating_sine : function
    Sinusoidal phase grating
amplitude_grating_binary : function
    Binary amplitude grating with duty cycle
phase_grating_sawtooth : function
    Blazed (sawtooth) phase grating.
apply_phase_mask : function
    Applies an arbitrary phase mask (SLM / phase screen).
apply_phase_mask_fn : function
    Builds a phase mask from a callable f(xx, yy) and applies it.
polarizer_jones
    Linear polarizer at angle theta (Jones matrix) for 2-component fields.
waveplate_jones : function
    Waveplate (retarder) with retardance delta and fast axis angle theta.
nd_filter : function
    Neutral density filter with optical density (OD) or direct transmittance.
quarter_waveplate : function
    Quarter-waveplate with fast axis angle theta.
half_waveplate : function
    Half-waveplate with fast axis angle theta.
phase_grating_blazed_elliptical : function
    Elliptical blazed phase grating with period_x, period_y, theta,
    depth, and two_dim
_xy_grids : function, internal
    Build centered (x, y) grids.
_rotate_coords : function, internal
    Rotate coordinates by an angle theta.

Notes
-----
All optical elements are implemented as pure JAX functions and support
automatic differentiation. Elements can be composed to create complex
optical systems. Polarization-sensitive elements use Jones calculus for
vectorial field calculations.
    Rotates coordinates by an angle theta.
"""

import jax
import jax.numpy as jnp
from beartype import beartype
from beartype.typing import Callable, Optional, Tuple
from jaxtyping import Array, Bool, Complex, Float, Num, jaxtyped

from janssen.utils import (
    OpticalWavefront,
    make_optical_wavefront,
    scalar_float,
)

from .helper import add_phase_screen

jax.config.update("jax_enable_x64", True)


def _xy_grids(
    nx: int, ny: int, dx: float
) -> Tuple[Float[Array, " ny nx"], Float[Array, " ny nx"]]:
    """
    Build centered (x, y) grids.

    Parameters
    ----------
    nx : int
        Number of pixels along x.
    ny : int
        Number of pixels along y.
    dx : float
        Pixel size in meters.

    Returns
    -------
    xx : Float[Array, " ny nx"]
        Grid of x coordinates.
    yy : Float[Array, " ny nx"]
        Grid of y coordinates.
    """
    x: Float[Array, " nx"] = jnp.arange(-nx // 2, nx // 2) * dx
    y: Float[Array, " ny"] = jnp.arange(-ny // 2, ny // 2) * dx
    xx: Float[Array, " ny nx"]
    yy: Float[Array, " ny nx"]
    xx, yy = jnp.meshgrid(x, y)
    return (xx, yy)


def _rotate_coords(
    xx: Num[Array, " ..."], yy: Num[Array, " ..."], theta: scalar_float
) -> Tuple[Num[Array, " ..."], Num[Array, " ..."]]:
    """
    Rotate coordinates by an angle theta.

    Parameters
    ----------
    xx : Num[Array, " ..."]
        Grid of x coordinates.
    yy : Num[Array, " ..."]
        Grid of y coordinates.
    theta : scalar_float
        Angle of rotation in radians.

    Returns
    -------
    uu : Num[Array, " ..."]
        Rotated x coordinates.
    vv : Num[Array, " ..."]
        Rotated y coordinates.

    Notes
    -----
    - Rotates coordinates by an angle theta.
    - Uses cosine and sine to compute the rotation matrix.
    - Returns the rotated coordinates.
    """
    ct: Float[Array, " "] = jnp.cos(theta)
    st: Float[Array, " "] = jnp.sin(theta)
    uu: Num[Array, " ..."] = (ct * xx) + (st * yy)
    vv: Num[Array, " ..."] = (ct * yy) - (st * xx)
    return (uu, vv)


[docs] @jaxtyped(typechecker=beartype) def prism_phase_ramp( incoming: OpticalWavefront, deflect_x: Optional[scalar_float] = 0.0, deflect_y: Optional[scalar_float] = 0.0, use_small_angle: Optional[bool] = True, ) -> OpticalWavefront: """ Apply a linear phase ramp to simulate a prism-induced beam deviation. Parameters ---------- incoming : OpticalWavefront Input scalar wavefront. deflect_x : scalar_float, optional Deflection along +x. If `use_small_angle` is True, interpreted as angle (rad). Otherwise interpreted as spatial frequency kx [rad/m], by default 0.0. deflect_y : scalar_float, optional Deflection along +y (angle or ky), by default 0.0. use_small_angle : bool, optional If True, convert small angles to kx, ky via k*sin(angle) ~ k*angle. Default True. Returns ------- OpticalWavefront Wavefront with added linear phase. Notes ----- - Build xx, yy grids (m). - Compute kx, ky from deflections. - Phase = kx*xx + ky*yy; multiply by exp(i*phase). """ ny: int nx: int ny, nx = incoming.field.shape[:2] xx: Num[Array, " ny nx"] yy: Num[Array, " ny nx"] xx, yy = _xy_grids(nx, ny, float(incoming.dx)) k: scalar_float = (2.0 * jnp.pi) / incoming.wavelength kx: scalar_float ky: scalar_float kx, ky = jax.lax.cond( use_small_angle, lambda: (k * deflect_x, k * deflect_y), lambda: (deflect_x, deflect_y), ) phase = (kx * xx) + (ky * yy) field_out = add_phase_screen(incoming.field, phase) return make_optical_wavefront( field=field_out, wavelength=incoming.wavelength, dx=incoming.dx, z_position=incoming.z_position, )
[docs] @jaxtyped(typechecker=beartype) def beam_splitter( incoming: OpticalWavefront, t2: Optional[scalar_float] = 0.5, r2: Optional[scalar_float] = 0.5, normalize: Optional[bool] = True, ) -> Tuple[OpticalWavefront, OpticalWavefront]: """ Split an input field into transmitted and reflected components. Parameters ---------- incoming : OpticalWavefront Input wavefront (scalar field). t2 : scalar_float, optional Complex transmission amplitude, by default jnp.sqrt(0.5). r2 : scalar_float, optional Complex reflection amplitude. Default 1j * jnp.sqrt(0.5) for 50/50 convention. normalize : bool, optional If True, scale (t, r) so that |t|^2 + |r|^2 = 1, by default True. Returns ------- wf_T : OpticalWavefront Transmitted arm (t * field). wf_R : OpticalWavefront Reflected arm (r * field). Notes ----- - Optionally renormalize (t, r). - Multiply field by t and r. - Return two wavefronts sharing same metadata. """ t_val: Complex[Array, " "] = jnp.sqrt( jnp.asarray(t2, dtype=jnp.complex128) ) r_val: Complex[Array, " "] = 1j * jnp.sqrt( jnp.asarray(r2, dtype=jnp.complex128) ) def normalize_values() -> Tuple[Complex[Array, " "], Complex[Array, " "]]: power: Float[Array, " "] = (jnp.abs(t_val) ** 2) + ( jnp.abs(r_val) ** 2 ) sqrt_power: Float[Array, " "] = jnp.sqrt(jnp.maximum(power, 1e-20)) t_norm: Complex[Array, " "] = t_val / sqrt_power r_norm: Complex[Array, " "] = r_val / sqrt_power return (t_norm, r_norm) t_val, r_val = jax.lax.cond( normalize, normalize_values, lambda: (t_val, r_val) ) wf_t = make_optical_wavefront( field=incoming.field * t_val, wavelength=incoming.wavelength, dx=incoming.dx, z_position=incoming.z_position, ) wf_r = make_optical_wavefront( field=incoming.field * r_val, wavelength=incoming.wavelength, dx=incoming.dx, z_position=incoming.z_position, ) return (wf_t, wf_r)
[docs] @jaxtyped(typechecker=beartype) def mirror_reflection( incoming: OpticalWavefront, flip_x: Optional[bool] = True, flip_y: Optional[bool] = False, add_pi_phase: Optional[bool] = True, conjugate: Optional[bool] = True, ) -> OpticalWavefront: """ Mirror reflection: coordinate flips with optional π-phase and conjugation. Parameters ---------- incoming : OpticalWavefront Input wavefront. flip_x : bool, optional Flip along x-axis (columns), by default True. flip_y : bool, optional Flip along y-axis (rows), by default False. add_pi_phase : bool, optional Multiply by exp(i*pi) = -1 to simulate phase inversion on reflection. Default True. conjugate : bool, optional Conjugate the complex field, useful when reversing propagation direction. Default is True. Returns ------- OpticalWavefront Reflected wavefront. Notes ----- - Flip axes as requested (jnp.flip). - Optional complex conjugation. - Optional -1 factor for π phase. """ field = incoming.field field = jax.lax.cond( flip_x, lambda f: jnp.flip(f, axis=-1), lambda f: f, field ) field = jax.lax.cond( flip_y, lambda f: jnp.flip(f, axis=-2), lambda f: f, field ) field = jax.lax.cond( conjugate, lambda f: jnp.conjugate(f), lambda f: f, field ) field = jax.lax.cond(add_pi_phase, lambda f: -f, lambda f: f, field) return make_optical_wavefront( field=field, wavelength=incoming.wavelength, dx=incoming.dx, z_position=incoming.z_position, )
[docs] @jaxtyped(typechecker=beartype) def phase_grating_sine( incoming: OpticalWavefront, period: scalar_float, depth: scalar_float, theta: Optional[scalar_float] = 0.0, ) -> OpticalWavefront: """ Sinusoidal phase grating. Phase = depth * sin(2π * u / period), where u is the coordinate along the grating direction. Parameters ---------- incoming : OpticalWavefront Input field. period : scalar_float Grating period in meters. depth : scalar_float Phase modulation depth in radians. theta : scalar_float, optional Grating orientation (radians, CCW from x), by default 0.0. Returns ------- OpticalWavefront Field after phase modulation. """ ny: int nx: int ny, nx = incoming.field.shape[:2] xx: Float[Array, " ny nx"] yy: Float[Array, " ny nx"] xx, yy = _xy_grids(nx, ny, float(incoming.dx)) uu: Num[Array, " ny nx"] uu, _ = _rotate_coords(xx, yy, theta) phase: Float[Array, " ny nx"] phase = depth * jnp.sin(2.0 * jnp.pi * uu / period) field_out: Complex[Array, " ny nx"] = add_phase_screen( incoming.field, phase ) return make_optical_wavefront( field=field_out, wavelength=incoming.wavelength, dx=incoming.dx, z_position=incoming.z_position, )
[docs] @jaxtyped(typechecker=beartype) def amplitude_grating_binary( incoming: OpticalWavefront, period: scalar_float, duty_cycle: Optional[scalar_float] = 0.5, theta: Optional[scalar_float] = 0.0, trans_high: Optional[scalar_float] = 1.0, trans_low: Optional[scalar_float] = 0.0, ) -> OpticalWavefront: """ Binary amplitude grating with given duty cycle. Parameters ---------- incoming : OpticalWavefront Input field. period : scalar_float Period in meters. duty_cycle : scalar_float, optional Fraction of period in 'high' state (0..1), by default 0.5. theta : scalar_float, optional Orientation (radians), by default 0.0. trans_high : scalar_float, optional Amplitude transmittance for 'high' bars, by default 1.0. trans_low : scalar_float, optional Amplitude transmittance for 'low' bars, by default 0.0. Returns ------- OpticalWavefront Field after amplitude modulation. Notes ----- - Compute u along grating direction. - Map u modulo period → binary mask via duty cycle. - Apply amplitude levels to field. """ ny: int nx: int ny, nx = incoming.field.shape[:2] xx: Float[Array, " ny nx"] yy: Float[Array, " ny nx"] xx, yy = _xy_grids(nx, ny, float(incoming.dx)) uu: Num[Array, " ny nx"] uu, _ = _rotate_coords(xx, yy, theta) duty: Float[Array, " "] = jnp.clip(duty_cycle, 0.0, 1.0) frac: Num[Array, " ny nx"] = (uu / period) - jnp.floor(uu / period) mask_high: Bool[Array, " ny nx"] = frac < duty tmap: Float[Array, " ny nx"] = jnp.where( mask_high, trans_high, trans_low ).astype(jnp.float64) field_out: Complex[Array, " ny nx"] = incoming.field * tmap return make_optical_wavefront( field=field_out, wavelength=incoming.wavelength, dx=incoming.dx, z_position=incoming.z_position, )
[docs] @jaxtyped(typechecker=beartype) def phase_grating_sawtooth( incoming: OpticalWavefront, period: scalar_float, depth: scalar_float, theta: scalar_float = 0.0, ) -> OpticalWavefront: """ Sawtooth phase grating with peak-to-peak depth (radians). Parameters ---------- incoming : OpticalWavefront Input field. period : scalar_float Grating period in meters. depth : scalar_float Phase depth over one period in radians. theta : scalar_float, optional Orientation (radians), by default 0.0. Returns ------- OpticalWavefront Field after blazed phase modulation. Notes ----- - Compute fractional coordinate within each period. - Sawtooth phase in [0, depth) → shift to mean-zero if desired (kept at [0, depth)). - Apply phase with exp(i*phase). """ ny, nx = incoming.field.shape[:2] xx, yy = _xy_grids(nx, ny, float(incoming.dx)) uu, _ = _rotate_coords(xx, yy, theta) frac = (uu / period) - jnp.floor(uu / period) # [0,1) phase = depth * frac field_out = add_phase_screen(incoming.field, phase) return make_optical_wavefront( field=field_out, wavelength=incoming.wavelength, dx=incoming.dx, z_position=incoming.z_position, )
[docs] @jaxtyped(typechecker=beartype) def apply_phase_mask( incoming: OpticalWavefront, phase_map: Float[Array, " H W"], ) -> OpticalWavefront: """ Apply an arbitrary phase mask (e.g., SLM, turbulence screen). Field_out = field_in * exp(i * phase_map). Parameters ---------- incoming : OpticalWavefront Input field. phase_map : Float[Array, " H W"] Phase in radians, same spatial shape as field. Returns ------- OpticalWavefront Field with added phase. """ field_out = add_phase_screen(incoming.field, phase_map) return make_optical_wavefront( field=field_out, wavelength=incoming.wavelength, dx=incoming.dx, z_position=incoming.z_position, )
[docs] @jaxtyped(typechecker=beartype) def apply_phase_mask_fn( incoming: OpticalWavefront, phase_fn: Callable[ [Float[Array, " H W"], Float[Array, " H W"]], Float[Array, " H W"] ], ) -> OpticalWavefront: """ Build and apply a phase mask from a callable `phase_fn(xx, yy)`. Parameters ---------- incoming : OpticalWavefront Input field. phase_fn : callable Function producing a phase map (radians) given centered grids xx, yy (meters). Returns ------- OpticalWavefront Field with added phase. """ ny, nx = incoming.field.shape[:2] xx, yy = _xy_grids(nx, ny, float(incoming.dx)) phase_map = phase_fn(xx, yy) field_out = add_phase_screen(incoming.field, phase_map) return make_optical_wavefront( field=field_out, wavelength=incoming.wavelength, dx=incoming.dx, z_position=incoming.z_position, )
[docs] @jaxtyped(typechecker=beartype) def polarizer_jones( incoming: OpticalWavefront, theta: scalar_float = 0.0, ) -> OpticalWavefront: """ Linear polarizer at angle `theta` (radians, CCW from x-axis). Applied to a 2-component Jones field (ex, ey) stored in the last dimension. Parameters ---------- incoming : OpticalWavefront Field shape must be Complex[H, W, 2]. theta : scalar_float, optional Transmission axis angle (radians), by default 0.0. Returns ------- OpticalWavefront Polarized field with same shape. Notes ----- - Jones matrix: P = R(-θ) @ [[1, 0],[0, 0]] @ R(θ). - Apply P to [ex, ey] at each pixel. """ field = incoming.field ct = jnp.cos(theta) st = jnp.sin(theta) ex, ey = field[..., 0], field[..., 1] e_par = ex * ct + ey * st ex_out = e_par * ct ey_out = e_par * st field_out = jnp.stack([ex_out, ey_out], axis=-1) return make_optical_wavefront( field=field_out, wavelength=incoming.wavelength, dx=incoming.dx, z_position=incoming.z_position, polarization=True, )
[docs] @jaxtyped(typechecker=beartype) def waveplate_jones( incoming: OpticalWavefront, delta: scalar_float, theta: scalar_float = 0.0, ) -> OpticalWavefront: """ Waveplate/retarder with retardance `delta` and fast-axis angle `theta`. Special cases: quarter-wave (delta=π/2), half-wave (delta=π). Parameters ---------- incoming : OpticalWavefront Field shape must be Complex[H, W, 2]. delta : scalar_float Phase delay between fast and slow axes in radians. theta : scalar_float, optional Fast-axis angle (radians, CCW from x), by default 0.0. Returns ------- jones_wavefront : OpticalWavefront Retarded field with same shape. Notes ----- - Jones matrix: J = R(-θ) @ diag(1, e^{iδ}) @ R(θ). - Apply J to [ex, ey] per pixel. """ field: Complex[Array, " hh ww 2"] = incoming.field ct: Float[Array, " "] = jnp.cos(theta) st: Float[Array, " "] = jnp.sin(theta) e: Complex[Array, " "] = jnp.exp(1j * delta) ex: Complex[Array, " hh ww"] ey: Complex[Array, " hh ww"] ex, ey = field[..., 0], field[..., 1] a: Complex[Array, " hh ww"] = (ct * ct) + (e * st * st) b: Complex[Array, " hh ww"] = (1.0 - e) * ct * st c: Complex[Array, " hh ww"] = b d: Complex[Array, " hh ww"] = (st * st) + (e * ct * ct) ex_out: Complex[Array, " hh ww"] = (a * ex) + (b * ey) ey_out: Complex[Array, " hh ww"] = (c * ex) + (d * ey) field_out: Complex[Array, " hh ww 2"] = jnp.stack( [ex_out, ey_out], axis=-1 ) jones_wavefront: OpticalWavefront = make_optical_wavefront( field=field_out, wavelength=incoming.wavelength, dx=incoming.dx, z_position=incoming.z_position, polarization=True, ) return jones_wavefront
[docs] @jaxtyped(typechecker=beartype) def nd_filter( incoming: OpticalWavefront, optical_density: Optional[scalar_float] = 0.0, transmittance: Optional[scalar_float] = -1.0, ) -> OpticalWavefront: """ Neutral density (ND) filter as a uniform amplitude attenuator. Parameters ---------- incoming : OpticalWavefront Input field. optical_density : scalar_float, optional OD; intensity transmittance T = 10^(-OD). If given, overrides `transmittance`. Default is 0.0. transmittance : scalar_float, optional Intensity transmittance T in [0, 1]. Used if `optical_density` is 0. Returns ------- nd_wavefront : OpticalWavefront Attenuated wavefront. Notes ----- - Determine intensity T from OD or provided T. - Amplitude factor a = sqrt(T). - Multiply field by a and return. """ tt = jax.lax.cond( optical_density != 0, lambda: jnp.power(10.0, -jnp.asarray(optical_density)), lambda: jnp.clip(jnp.asarray(transmittance), 0.0, 1.0), ) a = jnp.sqrt(tt).astype(incoming.field.real.dtype) field_out = incoming.field * a nd_wavefront: OpticalWavefront = make_optical_wavefront( field=field_out, wavelength=incoming.wavelength, dx=incoming.dx, z_position=incoming.z_position, ) return nd_wavefront
[docs] @jaxtyped(typechecker=beartype) def quarter_waveplate( incoming: OpticalWavefront, theta: Optional[scalar_float] = 0.0, ) -> OpticalWavefront: """ Apply a quarter-wave plate (δ = π/2) with fast-axis angle `theta`. Parameters ---------- incoming : OpticalWavefront Vector field Complex[H, W, 2] (Jones: ex, ey). theta : scalar_float, optional Fast-axis angle in radians (CCW from x), by default 0.0. Returns ------- qw_wavefront : OpticalWavefront Retarded field after quarter-wave plate. Notes ----- Call `waveplate_jones` with delta = π/2. """ qw_wavefront: OpticalWavefront = waveplate_jones( incoming, delta=jnp.pi / 2.0, theta=theta ) return qw_wavefront
[docs] @jaxtyped(typechecker=beartype) def half_waveplate( incoming: OpticalWavefront, theta: Optional[scalar_float] = 0.0, ) -> OpticalWavefront: """ Apply a half-wave plate (δ = π) with fast-axis angle `theta`. Parameters ---------- incoming : OpticalWavefront Vector field Complex[H, W, 2] (Jones: ex, ey). theta : scalar_float, optional Fast-axis angle in radians (CCW from x), by default 0.0. Returns ------- hw_wavefront : OpticalWavefront Retarded field after half-wave plate. Notes ----- Call `waveplate_jones` with delta = π. """ hw_wavefront: OpticalWavefront = waveplate_jones( incoming, delta=jnp.pi, theta=theta ) return hw_wavefront
[docs] @jaxtyped(typechecker=beartype) def phase_grating_blazed_elliptical( incoming: OpticalWavefront, period_x: scalar_float, period_y: scalar_float, theta: Optional[scalar_float] = 0.0, depth: Optional[scalar_float] = 2.0 * jnp.pi, two_dim: Optional[bool] = False, ) -> OpticalWavefront: r""" Orientation-aware elliptical blazed grating. Supports anisotropic periods along rotated axes (x', y') and optional 2D blaze. Parameters ---------- incoming : OpticalWavefront Input scalar wavefront. period_x : scalar_float Blaze period along x' in meters (after rotation by `theta`). period_y : scalar_float Blaze period along y' in meters (after rotation by `theta`). theta : scalar_float, optional Grating orientation angle in radians (CCW from x), by default 0.0. depth : scalar_float, optional Peak-to-peak phase depth in radians, by default 2π. two_dim : bool, optional If False (default), apply a 1D blaze along x' only. If True, create a 2D blazed lattice using both x' and y'. Returns ------- phase_grating_wavefront : OpticalWavefront Field after applying the elliptical blazed phase. Notes ----- - Build centered grids xx, yy (meters) and rotate → (x', y'). - Compute fractional coordinates ..math:: fu = frac(x'/period_x) fv = frac(y'/period_y) - if `two_dim` is True ..math:: phase = depth * frac(fu + fv) else, ..math:: phase = depth * fu - Multiply by exp(i * phase) and return. """ ny, nx = incoming.field.shape[:2] xx, yy = _xy_grids(nx, ny, float(incoming.dx)) uu, vv = _rotate_coords(xx, yy, theta) eps = 1e-30 px = jnp.where(jnp.abs(period_x) < eps, eps, period_x) py = jnp.where(jnp.abs(period_y) < eps, eps, period_y) fu = (uu / px) - jnp.floor(uu / px) # in [0,1) fv = (vv / py) - jnp.floor(vv / py) # in [0,1) phase = depth * ((fu + fv) - jnp.floor(fu + fv)) if two_dim else depth * fu field_out = add_phase_screen(incoming.field, phase) phase_grating_wavefront: OpticalWavefront = make_optical_wavefront( field=field_out, wavelength=incoming.wavelength, dx=incoming.dx, z_position=incoming.z_position, ) return phase_grating_wavefront