Source code for janssen.optics.zernike

"""Zernike polynomial functions for optical aberration modeling.

Extended Summary
----------------
This module provides functions for generating Zernike polynomials and
creating optical aberrations based on them. Zernike polynomials form a
complete orthogonal basis over the unit circle and are widely used in
optics to describe wavefront aberrations.

The module supports:
- Individual Zernike polynomial generation (Noll and OSA/ANSI indexing)
- Common aberration types (defocus, astigmatism, coma, spherical, etc.)
- Wavefront aberration synthesis from Zernike coefficients
- Conversion between different indexing conventions

Routine Listings
----------------
zernike_polynomial : function
    Generate a single Zernike polynomial
zernike_radial : function
    Radial component of Zernike polynomial
zernike_even : function
    Generate even (cosine) Zernike polynomial
zernike_odd : function
    Generate odd (sine) Zernike polynomial
zernike_nm : function
    Generate Zernike polynomial from (n,m) indices
zernike_noll : function
    Generate Zernike polynomial from Noll index
factorial : function
    JAX-compatible factorial computation
noll_to_nm : function
    Convert Noll index to (n, m) indices
nm_to_noll : function
    Convert (n, m) indices to Noll index
generate_aberration_nm : function
    Generate aberration phase map from (n,m) indices and coefficients
generate_aberration_noll : function
    Generate aberration phase map from Noll-indexed coefficients
compute_phase_from_coeffs : function
    Compute phase map from Zernike coefficients with start index
phase_rms : function
    Compute RMS of phase within the unit pupil
defocus : function
    Generate defocus aberration (Z4)
astigmatism : function
    Generate astigmatism aberration (Z5, Z6)
coma : function
    Generate coma aberration (Z7, Z8)
spherical_aberration : function
    Generate spherical aberration (Z11)
trefoil : function
    Generate trefoil aberration (Z9, Z10)
apply_aberration : function
    Apply aberration to optical wavefront

Notes
-----
Zernike polynomials are defined on the unit circle with normalization
such that the RMS value over the unit circle equals 1. The polynomials
use the Noll indexing convention by default, which starts at j=1 for
piston. OSA/ANSI indexing is also supported.

References
----------
.. [1] Noll, R. J. (1976). "Zernike polynomials and atmospheric turbulence".
       JOSA, 66(3), 207-211.
.. [2] Born, M., & Wolf, E. (1999). Principles of optics (7th ed.).
       Cambridge University Press.
"""

import jax
import jax.numpy as jnp
from beartype import beartype
from beartype.typing import Tuple
from jaxtyping import Array, Float, Int, jaxtyped

from janssen.types import (
    OpticalWavefront,
    ScalarFloat,
    ScalarInteger,
    make_optical_wavefront,
)

from .helper import add_phase_screen


[docs] @jaxtyped(typechecker=beartype) def factorial(n: Int[Array, " "]) -> Int[Array, " "]: """JAX-compatible factorial computation. Parameters ---------- n : Int[Array, " "] Non-negative integer Returns ------- Int[Array, " "] n! (n factorial) """ gammaln_result: Float[Array, " "] = jax.scipy.special.gammaln(n + 1) exp_result: Float[Array, " "] = jnp.exp(gammaln_result) rounded: Float[Array, " "] = jnp.round(exp_result) result: Int[Array, " "] = rounded.astype(jnp.int64) return result
[docs] @jaxtyped(typechecker=beartype) def noll_to_nm(j: ScalarInteger) -> Tuple[int, int]: """Convert Noll index to (n, m) indices. Parameters ---------- j : int Noll index (starting from 1) Returns ------- n : int Radial order m : int Azimuthal frequency (signed) Notes ----- Uses the standard Noll ordering where j=1 corresponds to piston (n=0, m=0). Sign convention: j even -> m >= 0 (cosine), j odd -> m <= 0 (sine). The radial order n is found from the cumulative count relation: n(n+1)/2 < j <= (n+1)(n+2)/2. Within each row n, the position k (0-indexed) determines |m|. For n even: |m| follows pattern 0,2,2,4,4,... For n odd: |m| follows pattern 1,1,3,3,5,5,... """ n_float: Float[Array, " "] = (-1 + jnp.sqrt(1 + 8 * j)) / 2 n: int = int(jnp.ceil(n_float)) - 1 j_start: int = n * (n + 1) // 2 + 1 k: int = j - j_start m_abs_even_n: int = 2 * ((k + 1) // 2) m_abs_odd_n: int = 2 * (k // 2) + 1 m_abs: int = int(jnp.where(n % 2 == 0, m_abs_even_n, m_abs_odd_n)) m_positive: int = m_abs m_negative: int = -m_abs m_with_sign: int = int(jnp.where(j % 2 == 0, m_positive, m_negative)) m: int = int(jnp.where(m_abs == 0, 0, m_with_sign)) return n, m
[docs] @jaxtyped(typechecker=beartype) def nm_to_noll(n: int, m: int) -> int: """Convert (n, m) indices to Noll index. Parameters ---------- n : int Radial order (n >= 0) m : int Azimuthal frequency (|m| <= n, n-|m| must be even) Returns ------- int Noll index (starting from 1) Notes ----- Sign convention: j even -> m >= 0 (cosine), j odd -> m <= 0 (sine). The first Noll index for row n is j_base = n(n+1)/2 + 1. For m=0, the position k within the row is 0. For m!=0, find the pair of k values for the given |m|, then select based on the sign of m and the parity requirement. For n even: |m| values are 0,2,4,...; group index g = |m|/2; k_first = 2g-1 for g>0, else 0. For n odd: |m| values are 1,3,5,...; group index g = (|m|-1)/2; k_first = 2g. The final k is chosen such that m > 0 yields an even j, and m < 0 yields an odd j. """ j_base: int = n * (n + 1) // 2 + 1 m_abs: int = abs(m) g_even_n: int = m_abs // 2 k_first_even_n: int = int(jnp.where(g_even_n > 0, 2 * g_even_n - 1, 0)) g_odd_n: int = (m_abs - 1) // 2 k_first_odd_n: int = 2 * g_odd_n k_first: int = int(jnp.where(n % 2 == 0, k_first_even_n, k_first_odd_n)) j_first: int = j_base + k_first j_first_is_even: int = 1 - (j_first % 2) k_for_pos: int = int(jnp.where(j_first_is_even, k_first, k_first + 1)) k_for_neg: int = int(jnp.where(j_first_is_even, k_first + 1, k_first)) k: int = int( jnp.where(m_abs == 0, 0, jnp.where(m > 0, k_for_pos, k_for_neg)) ) j: int = j_base + k return j
[docs] @jaxtyped(typechecker=beartype) def zernike_radial( rho: Float[Array, " *batch"], n: int, m: int, ) -> Float[Array, " *batch"]: """Compute the radial component of Zernike polynomial. Parameters ---------- rho : Float[Array, " *batch"] Normalized radial coordinate (0 to 1) n : int Radial order m : int Azimuthal frequency (absolute value used) Returns ------- Float[Array, " *batch"] Radial polynomial R_n^|m|(rho) Notes ----- Uses JAX-compatible validation that returns zeros for invalid (n,m) combinations where n-|m| is odd. Computes the radial polynomial using the standard formula with factorials for valid combinations. Uses jax.lax.scan for efficient accumulation of terms. """ m_abs: int = abs(m) valid: bool = (n - m_abs) % 2 == 0 def scan_fn( carry: Float[Array, " *batch"], s: Int[Array, " "] ) -> Tuple[Float[Array, " *batch"], None]: sign: Float[Array, " "] = (-1.0) ** s num: Int[Array, " "] = factorial(jnp.array(n - s)) denom_s: Int[Array, " "] = factorial(s) denom_n_plus: Int[Array, " "] = factorial( jnp.array((n + m_abs) // 2 - s) ) denom_n_minus: Int[Array, " "] = factorial( jnp.array((n - m_abs) // 2 - s) ) denom: Int[Array, " "] = denom_s * denom_n_plus * denom_n_minus coeff: Float[Array, " "] = sign * num / denom power_term: Float[Array, " *batch"] = rho ** (n - 2 * s) updated_result: Float[Array, " *batch"] = carry + coeff * power_term return updated_result, None initial_result: Float[Array, " *batch"] = jnp.zeros_like(rho) s_values: Int[Array, " S"] = jnp.arange((n - m_abs) // 2 + 1) result: Float[Array, " *batch"] result, _ = jax.lax.scan(scan_fn, initial_result, s_values) return jnp.where(valid, result, jnp.zeros_like(rho))
[docs] @jaxtyped(typechecker=beartype) def zernike_polynomial( rho: Float[Array, " *batch"], theta: Float[Array, " *batch"], n: int, m: int, normalize: bool = True, ) -> Float[Array, " *batch"]: """Generate a single Zernike polynomial. Parameters ---------- rho : Float[Array, " *batch"] Normalized radial coordinate (0 to 1) theta : Float[Array, " *batch"] Azimuthal angle in radians n : int Radial order (n >= 0) m : int Azimuthal frequency (|m| <= n, n-|m| must be even) normalize : bool, optional Whether to normalize for unit RMS over unit circle, by default True Returns ------- Float[Array, " *batch"] Zernike polynomial Z_n^m(rho, theta) Notes ----- The polynomial is zero outside the unit circle (rho > 1). Normalization follows the convention where RMS over unit circle = 1. Angular part uses cosine for m>0, sine for m<0, and 1 for m=0. Normalization factor is sqrt(n+1) for m=0 and sqrt(2*(n+1)) for m≠0. """ r: Float[Array, " *batch"] = zernike_radial(rho, n, abs(m)) m_abs: int = abs(m) angular_cos: Float[Array, " *batch"] = jnp.cos(m_abs * theta) angular_sin: Float[Array, " *batch"] = jnp.sin(m_abs * theta) angular_ones: Float[Array, " *batch"] = jnp.ones_like(theta) angular: Float[Array, " *batch"] = jnp.where( m > 0, angular_cos, jnp.where(m < 0, angular_sin, angular_ones) ) norm_m0: Float[Array, " "] = jnp.sqrt(n + 1) norm_m_nonzero: Float[Array, " "] = jnp.sqrt(2 * (n + 1)) norm: Float[Array, " "] = jnp.where( normalize, jnp.where(m == 0, norm_m0, norm_m_nonzero), 1.0 ) mask: Float[Array, " *batch"] = rho <= 1.0 return norm * r * angular * mask
@jaxtyped(typechecker=beartype) def zernike_even( rho: Float[Array, " *batch"], theta: Float[Array, " *batch"], n: int, m: int, normalize: bool = True, ) -> Float[Array, " *batch"]: """Generate even (cosine) Zernike polynomial. Parameters ---------- rho : Float[Array, " *batch"] Normalized radial coordinate (0 to 1) theta : Float[Array, " *batch"] Azimuthal angle in radians n : int Radial order (n >= 0) m : int Azimuthal frequency (|m| <= n, n-|m| must be even) normalize : bool, optional Whether to normalize for unit RMS over unit circle, by default True Returns ------- even_polynomial : Float[Array, " *batch"] Even Zernike polynomial using cosine for angular part Notes ----- This function always uses cosine for the angular component, suitable for symmetric aberrations. Angular part uses cos(|m|*theta) for m≠0, and 1 for m=0. Normalization factor is sqrt(n+1) for m=0 and sqrt(2*(n+1)) for m≠0. Returns zero outside the unit circle (rho > 1). """ r: Float[Array, " *batch"] = zernike_radial(rho, n, abs(m)) m_abs: int = abs(m) cos_term: Float[Array, " *batch"] = jnp.cos(m_abs * theta) ones_term: Float[Array, " *batch"] = jnp.ones_like(theta) angular: Float[Array, " *batch"] = jnp.where(m != 0, cos_term, ones_term) norm_m0: Float[Array, " "] = jnp.sqrt(n + 1) norm_m_nonzero: Float[Array, " "] = jnp.sqrt(2 * (n + 1)) norm: Float[Array, " "] = jnp.where( normalize, jnp.where(m == 0, norm_m0, norm_m_nonzero), 1.0 ) mask: Float[Array, " *batch"] = rho <= 1.0 even_polynomial: Float[Array, " *batch"] = norm * r * angular * mask return even_polynomial @jaxtyped(typechecker=beartype) def zernike_odd( rho: Float[Array, " *batch"], theta: Float[Array, " *batch"], n: int, m: int, normalize: bool = True, ) -> Float[Array, " *batch"]: """Generate odd (sine) Zernike polynomial. Parameters ---------- rho : Float[Array, " *batch"] Normalized radial coordinate (0 to 1) theta : Float[Array, " *batch"] Azimuthal angle in radians n : int Radial order (n >= 0) m : int Azimuthal frequency (|m| <= n, n-|m| must be even, m != 0) normalize : bool, optional Whether to normalize for unit RMS over unit circle, by default True Returns ------- odd_polynomial : Float[Array, " *batch"] Odd Zernike polynomial using sine for angular part Notes ----- This function always uses sine for the angular component, suitable for antisymmetric aberrations. Returns zero if m=0. Angular part uses sin(|m|*theta) for all m values. Normalization factor is sqrt(2*(n+1)) when normalize=True. Returns zero outside the unit circle (rho > 1) and for m=0. """ is_m_zero: bool = m == 0 r: Float[Array, " *batch"] = zernike_radial(rho, n, abs(m)) m_abs: int = abs(m) angular: Float[Array, " *batch"] = jnp.sin(m_abs * theta) norm_value: Float[Array, " "] = jnp.sqrt(2 * (n + 1)) norm: Float[Array, " "] = jnp.where(normalize, norm_value, 1.0) mask: Float[Array, " *batch"] = rho <= 1.0 polynomial: Float[Array, " *batch"] = norm * r * angular * mask zeros: Float[Array, " *batch"] = jnp.zeros_like(rho) odd_polynomial: Float[Array, " *batch"] = jnp.where( is_m_zero, zeros, polynomial ) return odd_polynomial @jaxtyped(typechecker=beartype) def zernike_nm( rho: Float[Array, " *batch"], theta: Float[Array, " *batch"], n: int, m: int, normalize: bool = True, ) -> Float[Array, " *batch"]: """Generate Zernike polynomial based on (n,m) indices. Parameters ---------- rho : Float[Array, " *batch"] Normalized radial coordinate (0 to 1) theta : Float[Array, " *batch"] Azimuthal angle in radians n : int Radial order (n >= 0) m : int Azimuthal frequency (|m| <= n, n-|m| must be even) normalize : bool, optional Whether to normalize for unit RMS over unit circle, by default True Returns ------- Float[Array, " *batch"] Zernike polynomial Z_n^m(rho, theta) Notes ----- Determines whether to use even (cosine) or odd (sine) Zernike polynomial based on the sign of m. For m>=0, uses even (cosine) form. For m<0, uses odd (sine) form. """ is_even: bool = m >= 0 even_result: Float[Array, " *batch"] = zernike_even( rho, theta, n, abs(m), normalize ) odd_result: Float[Array, " *batch"] = zernike_odd( rho, theta, n, abs(m), normalize ) result: Float[Array, " *batch"] = jnp.where( is_even, even_result, odd_result ) return result @jaxtyped(typechecker=beartype) def zernike_noll( rho: Float[Array, " *batch"], theta: Float[Array, " *batch"], j: int, normalize: bool = True, ) -> Float[Array, " *batch"]: """Generate Zernike polynomial based on Noll index. Parameters ---------- rho : Float[Array, " *batch"] Normalized radial coordinate (0 to 1) theta : Float[Array, " *batch"] Azimuthal angle in radians j : int Noll index (starting from 1) normalize : bool, optional Whether to normalize for unit RMS over unit circle, by default True Returns ------- Float[Array, " *batch"] Zernike polynomial for Noll index j Notes ----- Converts Noll index to (n,m) pair and calls zernike_nm. The Noll indexing convention assigns j=1 to piston (n=0, m=0). """ n, m = noll_to_nm(j) result: Float[Array, " *batch"] = zernike_nm(rho, theta, n, m, normalize) return result def _zernike_radial_traced( rho: Float[Array, " *batch"], n: Int[Array, " "], m_abs: Int[Array, " "], max_n: int = 20, ) -> Float[Array, " *batch"]: """Traced-compatible radial Zernike polynomial. Supports traced n and m values by using fixed maximum loop bounds. This is necessary for use inside jax.lax.scan where n and m are traced. Parameters ---------- rho : Float[Array, " *batch"] Normalized radial coordinate (0 to 1) n : Int[Array, " "] Radial order (traced JAX array) m_abs : Int[Array, " "] Absolute value of azimuthal frequency (traced JAX array) max_n : int, optional Maximum radial order to support, by default 20. The loop iterates max_n // 2 + 1 times regardless of actual n value. Returns ------- Float[Array, " *batch"] Radial polynomial R_n^|m|(rho) Notes ----- Validity check: (n - m_abs) must be even for valid Zernike polynomials. Invalid combinations return zeros. The number of terms in the sum is (n - m_abs) // 2 + 1. Terms where s >= num_terms are masked out during accumulation. Uses jax.scipy.special.gammaln for stable factorial computation with traced values. """ valid = ((n - m_abs) % 2) == 0 num_terms = (n - m_abs) // 2 + 1 def body_fn( s: Int[Array, " "], carry: Float[Array, " *batch"] ) -> Float[Array, " *batch"]: sign = (-1.0) ** s num = jnp.exp(jax.scipy.special.gammaln(n - s + 1)) denom_s = jnp.exp(jax.scipy.special.gammaln(s + 1)) denom_n_plus = jnp.exp( jax.scipy.special.gammaln((n + m_abs) // 2 - s + 1) ) denom_n_minus = jnp.exp( jax.scipy.special.gammaln((n - m_abs) // 2 - s + 1) ) denom = denom_s * denom_n_plus * denom_n_minus coeff = sign * num / denom power_term = rho ** (n - 2 * s) term = coeff * power_term term = jnp.where(s < num_terms, term, 0.0) return carry + term result = jax.lax.fori_loop(0, max_n // 2 + 1, body_fn, jnp.zeros_like(rho)) return jnp.where(valid, result, jnp.zeros_like(rho)) def _zernike_polynomial_traced( rho: Float[Array, " *batch"], theta: Float[Array, " *batch"], n: Int[Array, " "], m: Int[Array, " "], normalize: bool = True, max_n: int = 20, ) -> Float[Array, " *batch"]: """Traced-compatible Zernike polynomial. Supports traced n and m values for use inside jax.lax.scan. Parameters ---------- rho : Float[Array, " *batch"] Normalized radial coordinate (0 to 1) theta : Float[Array, " *batch"] Azimuthal angle in radians n : Int[Array, " "] Radial order (traced JAX array) m : Int[Array, " "] Azimuthal frequency (traced JAX array, signed) normalize : bool, optional Whether to normalize for unit RMS over unit circle, by default True max_n : int, optional Maximum radial order to support, by default 20 Returns ------- Float[Array, " *batch"] Zernike polynomial Z_n^m(rho, theta) Notes ----- Angular part uses cosine for m > 0, sine for m < 0, and 1 for m = 0. Normalization factor is sqrt(n+1) for m=0 and sqrt(2*(n+1)) for m≠0. Returns zero outside the unit circle (rho > 1). """ m_abs = jnp.abs(m) r = _zernike_radial_traced(rho, n, m_abs, max_n) angular_cos = jnp.cos(m_abs * theta) angular_sin = jnp.sin(m_abs * theta) angular_ones = jnp.ones_like(theta) angular = jnp.where( m > 0, angular_cos, jnp.where(m < 0, angular_sin, angular_ones) ) norm_m0 = jnp.sqrt(n + 1.0) norm_m_nonzero = jnp.sqrt(2.0 * (n + 1.0)) norm = jnp.where( normalize, jnp.where(m == 0, norm_m0, norm_m_nonzero), 1.0 ) mask = rho <= 1.0 return norm * r * angular * mask
[docs] @jaxtyped(typechecker=beartype) def generate_aberration_nm( xx: Float[Array, " H W"], yy: Float[Array, " H W"], n_indices: Int[Array, " N"], m_indices: Int[Array, " N"], coefficients: Float[Array, " N"], pupil_radius: ScalarFloat, ) -> Float[Array, " H W"]: """Generate aberration from (n,m) indices and coefficients. Parameters ---------- xx : Float[Array, " H W"] X coordinate grid in meters yy : Float[Array, " H W"] Y coordinate grid in meters n_indices : Int[Array, " N"] Array of radial orders m_indices : Int[Array, " N"] Array of azimuthal frequencies coefficients : Float[Array, " N"] Zernike coefficients in waves pupil_radius : ScalarFloat Pupil radius in meters Returns ------- phase_radians : Float[Array, " H W"] Phase aberration map in radians Notes ----- This version is fully JAX-compatible and can be JIT-compiled. Uses jax.lax.scan for efficient accumulation with traced-compatible Zernike polynomial computation. """ rho: Float[Array, " H W"] = jnp.sqrt(xx**2 + yy**2) / pupil_radius theta: Float[Array, " H W"] = jnp.arctan2(yy, xx) def scan_fn( phase_acc: Float[Array, " H W"], inputs: Tuple[Int[Array, " "], Int[Array, " "], Float[Array, " "]], ) -> Tuple[Float[Array, " H W"], None]: n, m, coeff = inputs z: Float[Array, " H W"] = _zernike_polynomial_traced( rho, theta, n, m, normalize=True ) updated_phase: Float[Array, " H W"] = phase_acc + coeff * z return updated_phase, None initial_phase: Float[Array, " H W"] = jnp.zeros_like(xx) inputs: Tuple[Int[Array, " N"], Int[Array, " N"], Float[Array, " N"]] = ( n_indices, m_indices, coefficients, ) phase: Float[Array, " H W"] phase, _ = jax.lax.scan(scan_fn, initial_phase, inputs) phase_radians: Float[Array, " H W"] = 2 * jnp.pi * phase return phase_radians
[docs] @jaxtyped(typechecker=beartype) def generate_aberration_noll( xx: Float[Array, " hh ww"], yy: Float[Array, " hh ww"], coefficients: Float[Array, " nn"], pupil_radius: ScalarFloat, ) -> Float[Array, " hh ww"]: """Generate aberration from Noll-indexed coefficients. Parameters ---------- xx : Float[Array, " hh ww"] X coordinate grid in meters yy : Float[Array, " hh ww"] Y coordinate grid in meters coefficients : Float[Array, " nn"] Zernike coefficients in waves, indexed by Noll index. Element 0 corresponds to j=1 (piston), element 1 to j=2, etc. pupil_radius : ScalarFloat Pupil radius in meters Returns ------- phase : Float[Array, " hh ww"] Phase aberration map in radians Notes ----- Converts Noll indices to (n,m) pairs and calls generate_aberration_nm. Uses vectorized JAX operations for the Noll-to-nm conversion. Sign convention: j even -> m >= 0 (cosine), j odd -> m <= 0 (sine). The radial order n is computed from n(n+1)/2 < j <= (n+1)(n+2)/2. The position k within row n determines |m|, which follows the pattern: 0,2,2,4,4,... for n even and 1,1,3,3,5,5,... for n odd. """ num_coeffs: int = coefficients.shape[0] j_indices: Int[Array, " nn"] = jnp.arange( 1, num_coeffs + 1, dtype=jnp.int32 ) n_float: Float[Array, " nn"] = (-1 + jnp.sqrt(1 + 8 * j_indices)) / 2 n_indices: Int[Array, " nn"] = (jnp.ceil(n_float) - 1).astype(jnp.int32) j_start: Int[Array, " nn"] = n_indices * (n_indices + 1) // 2 + 1 k: Int[Array, " nn"] = j_indices - j_start m_abs_even_n: Int[Array, " nn"] = 2 * ((k + 1) // 2) m_abs_odd_n: Int[Array, " nn"] = 2 * (k // 2) + 1 m_abs: Int[Array, " nn"] = jnp.where( n_indices % 2 == 0, m_abs_even_n, m_abs_odd_n ) m_positive: Int[Array, " nn"] = m_abs m_negative: Int[Array, " nn"] = -m_abs m_with_sign: Int[Array, " nn"] = jnp.where( j_indices % 2 == 0, m_positive, m_negative ) m_indices: Int[Array, " nn"] = jnp.where(m_abs == 0, 0, m_with_sign) phase: Float[Array, " hh ww"] = generate_aberration_nm( xx, yy, n_indices, m_indices, coefficients, pupil_radius ) return phase
[docs] @jaxtyped(typechecker=beartype) def compute_phase_from_coeffs( rho: Float[Array, " *batch"], theta: Float[Array, " *batch"], coefficients: Float[Array, " N"], start_noll: int = 4, ) -> Float[Array, " *batch"]: """Compute phase map from Zernike coefficients. Generates a phase aberration map by summing normalized Zernike polynomials weighted by the provided coefficients. The coefficients are mapped to consecutive Noll indices starting from `start_noll`. Parameters ---------- rho : Float[Array, " *batch"] Normalized radial coordinate (0 to 1) theta : Float[Array, " *batch"] Azimuthal angle in radians coefficients : Float[Array, " N"] Zernike coefficients in waves. Element i corresponds to Noll index (start_noll + i). start_noll : int, optional Starting Noll index for the coefficients, by default 4 (defocus). Common choices: 1 (piston), 4 (defocus, skipping tip/tilt). Returns ------- Float[Array, " *batch"] Phase map in radians Notes ----- The phase is computed as: phase = 2 * pi * sum_i(coefficients[i] * Z_{start_noll + i}) where Z_j is the normalized Zernike polynomial for Noll index j. The output is in radians, with coefficients interpreted as waves. Examples -------- >>> # Compute phase for defocus through spherical aberration (j=4 to j=11) >>> coeffs = jnp.array([0.5, 0.1, -0.2, 0.0, 0.0, 0.0, 0.0, 0.3]) >>> phase = compute_phase_from_coeffs(rho, theta, coeffs, start_noll=4) """ num_coeffs: int = coefficients.shape[0] noll_indices: Int[Array, " N"] = jnp.arange( start_noll, start_noll + num_coeffs, dtype=jnp.int32 ) n_float: Float[Array, " N"] = (-1 + jnp.sqrt(1 + 8 * noll_indices)) / 2 n_indices: Int[Array, " N"] = (jnp.ceil(n_float) - 1).astype(jnp.int32) j_start: Int[Array, " N"] = n_indices * (n_indices + 1) // 2 + 1 k: Int[Array, " N"] = noll_indices - j_start m_abs_even_n: Int[Array, " N"] = 2 * ((k + 1) // 2) m_abs_odd_n: Int[Array, " N"] = 2 * (k // 2) + 1 m_abs: Int[Array, " N"] = jnp.where( n_indices % 2 == 0, m_abs_even_n, m_abs_odd_n ) m_positive: Int[Array, " N"] = m_abs m_negative: Int[Array, " N"] = -m_abs m_with_sign: Int[Array, " N"] = jnp.where( noll_indices % 2 == 0, m_positive, m_negative ) m_indices: Int[Array, " N"] = jnp.where(m_abs == 0, 0, m_with_sign) def scan_fn( phase_acc: Float[Array, " *batch"], inputs: Tuple[Int[Array, " "], Int[Array, " "], Float[Array, " "]], ) -> Tuple[Float[Array, " *batch"], None]: n, m, coeff = inputs z: Float[Array, " *batch"] = _zernike_polynomial_traced( rho, theta, n, m, normalize=True ) updated_phase: Float[Array, " *batch"] = phase_acc + coeff * z return updated_phase, None initial_phase: Float[Array, " *batch"] = jnp.zeros_like(rho) inputs: Tuple[Int[Array, " N"], Int[Array, " N"], Float[Array, " N"]] = ( n_indices, m_indices, coefficients, ) phase: Float[Array, " *batch"] phase, _ = jax.lax.scan(scan_fn, initial_phase, inputs) phase_radians: Float[Array, " *batch"] = 2 * jnp.pi * phase return phase_radians
[docs] @jaxtyped(typechecker=beartype) def phase_rms( rho: Float[Array, " *batch"], theta: Float[Array, " *batch"], coefficients: Float[Array, " N"], start_noll: int = 4, ) -> Float[Array, " "]: """Compute RMS of phase within the unit pupil. Calculates the root-mean-square of the phase aberration within the region where rho <= 1.0 (the unit pupil). Parameters ---------- rho : Float[Array, " *batch"] Normalized radial coordinate (0 to 1) theta : Float[Array, " *batch"] Azimuthal angle in radians coefficients : Float[Array, " N"] Zernike coefficients in waves. Element i corresponds to Noll index (start_noll + i). start_noll : int, optional Starting Noll index for the coefficients, by default 4 (defocus). Returns ------- Float[Array, " "] RMS phase value in radians Notes ----- The RMS is computed as: RMS = sqrt(mean((phase - mean(phase))^2)) where the mean is taken only over pixels within the unit pupil (rho <= 1). The piston (mean phase) is subtracted before computing RMS. Examples -------- >>> # Compute RMS for a set of aberration coefficients >>> coeffs = jnp.array([0.5, 0.1, -0.2, 0.0, 0.0, 0.0, 0.0, 0.3]) >>> rms = phase_rms(rho, theta, coeffs, start_noll=4) """ phase: Float[Array, " *batch"] = compute_phase_from_coeffs( rho, theta, coefficients, start_noll ) mask: Float[Array, " *batch"] = rho <= 1.0 phase_in_pupil: Float[Array, " *batch"] = jnp.where(mask, phase, 0.0) n_pixels: Float[Array, " "] = jnp.sum(mask) mean_phase: Float[Array, " "] = jnp.sum(phase_in_pupil) / n_pixels variance: Float[Array, " "] = ( jnp.sum(jnp.where(mask, (phase - mean_phase) ** 2, 0.0)) / n_pixels ) rms: Float[Array, " "] = jnp.sqrt(variance) return rms
[docs] @jaxtyped(typechecker=beartype) def defocus( xx: Float[Array, " hh ww"], yy: Float[Array, " hh ww"], amplitude: ScalarFloat, pupil_radius: ScalarFloat, ) -> Float[Array, " hh ww"]: """Generate defocus aberration (Z4 in Noll notation). Parameters ---------- xx : Float[Array, " hh ww"] X coordinate grid in meters yy : Float[Array, " hh ww"] Y coordinate grid in meters amplitude : ScalarFloat Defocus amplitude in waves pupil_radius : ScalarFloat Pupil radius in meters Returns ------- phase : Float[Array, " hh ww"] Defocus phase map in radians """ coefficients: Float[Array, " 4"] = jnp.zeros(4) coefficients = coefficients.at[3].set(amplitude) phase: Float[Array, " hh ww"] = generate_aberration_noll( xx, yy, coefficients, pupil_radius ) return phase
[docs] @jaxtyped(typechecker=beartype) def astigmatism( xx: Float[Array, " H W"], yy: Float[Array, " H W"], amplitude_0: ScalarFloat, amplitude_45: ScalarFloat, pupil_radius: ScalarFloat, ) -> Float[Array, " H W"]: """Generate astigmatism aberration (Z5 and Z6 in Noll notation). Parameters ---------- xx : Float[Array, " H W"] X coordinate grid in meters yy : Float[Array, " H W"] Y coordinate grid in meters amplitude_0 : ScalarFloat Vertical/horizontal astigmatism amplitude in waves (Z6) amplitude_45 : ScalarFloat Oblique astigmatism amplitude in waves (Z5) pupil_radius : ScalarFloat Pupil radius in meters Returns ------- phase : Float[Array, " H W"] Astigmatism phase map in radians """ coefficients: Float[Array, " 6"] = jnp.zeros(6) coefficients = coefficients.at[4].set(amplitude_45) coefficients = coefficients.at[5].set(amplitude_0) phase: Float[Array, " H W"] = generate_aberration_noll( xx, yy, coefficients, pupil_radius ) return phase
[docs] @jaxtyped(typechecker=beartype) def coma( xx: Float[Array, " H W"], yy: Float[Array, " H W"], amplitude_x: ScalarFloat, amplitude_y: ScalarFloat, pupil_radius: ScalarFloat, ) -> Float[Array, " H W"]: """Generate coma aberration (Z7 and Z8 in Noll notation). Parameters ---------- xx : Float[Array, " H W"] X coordinate grid in meters yy : Float[Array, " H W"] Y coordinate grid in meters amplitude_x : ScalarFloat Horizontal coma amplitude in waves (Z8) amplitude_y : ScalarFloat Vertical coma amplitude in waves (Z7) pupil_radius : ScalarFloat Pupil radius in meters Returns ------- phase : Float[Array, " H W"] Coma phase map in radians """ coefficients: Float[Array, " 8"] = jnp.zeros(8) coefficients = coefficients.at[6].set(amplitude_y) coefficients = coefficients.at[7].set(amplitude_x) phase: Float[Array, " H W"] = generate_aberration_noll( xx, yy, coefficients, pupil_radius ) return phase
[docs] @jaxtyped(typechecker=beartype) def spherical_aberration( xx: Float[Array, " H W"], yy: Float[Array, " H W"], amplitude: ScalarFloat, pupil_radius: ScalarFloat, ) -> Float[Array, " H W"]: """Generate primary spherical aberration (Z11 in Noll notation). Parameters ---------- xx : Float[Array, " H W"] X coordinate grid in meters yy : Float[Array, " H W"] Y coordinate grid in meters amplitude : ScalarFloat Spherical aberration amplitude in waves pupil_radius : ScalarFloat Pupil radius in meters Returns ------- phase : Float[Array, " H W"] Spherical aberration phase map in radians """ coefficients: Float[Array, " 11"] = jnp.zeros(11) coefficients = coefficients.at[10].set(amplitude) phase: Float[Array, " H W"] = generate_aberration_noll( xx, yy, coefficients, pupil_radius ) return phase
[docs] @jaxtyped(typechecker=beartype) def trefoil( xx: Float[Array, " H W"], yy: Float[Array, " H W"], amplitude_0: ScalarFloat, amplitude_30: ScalarFloat, pupil_radius: ScalarFloat, ) -> Float[Array, " H W"]: """Generate trefoil aberration (Z9 and Z10 in Noll notation). Parameters ---------- xx : Float[Array, " H W"] X coordinate grid in meters yy : Float[Array, " H W"] Y coordinate grid in meters amplitude_0 : ScalarFloat Vertical trefoil amplitude in waves (Z10) amplitude_30 : ScalarFloat Oblique trefoil amplitude in waves (Z9) pupil_radius : ScalarFloat Pupil radius in meters Returns ------- trefoil_wavefront : Float[Array, " H W"] Trefoil phase map in radians Notes ----- This function generates a trefoil aberration phase map in radians. The trefoil aberration is a combination of two Zernike polynomials: Z9 and Z10. The Z9 polynomial is the vertical trefoil aberration and the Z10 polynomial is the oblique trefoil aberration. """ coefficients: Float[Array, " 10"] = jnp.zeros(10) coefficients = coefficients.at[8].set(amplitude_30) coefficients = coefficients.at[9].set(amplitude_0) trefoil_wavefront: Float[Array, " H W"] = generate_aberration_noll( xx, yy, coefficients, pupil_radius ) return trefoil_wavefront
[docs] @jaxtyped(typechecker=beartype) def apply_aberration( incoming: OpticalWavefront, coefficients: Float[Array, " N"], pupil_radius: ScalarFloat, ) -> OpticalWavefront: """Apply Zernike aberrations to an optical wavefront. Parameters ---------- incoming : OpticalWavefront Input wavefront coefficients : Float[Array, " N"] Noll-indexed Zernike coefficients in waves (index i = Noll index i+1) pupil_radius : ScalarFloat Pupil radius in meters Returns ------- wavefront_out : OpticalWavefront Aberrated wavefront """ h: int w: int h, w = incoming.field.shape[:2] x: Float[Array, " W"] = jnp.arange(-w // 2, w // 2) * incoming.dx y: Float[Array, " H"] = jnp.arange(-h // 2, h // 2) * incoming.dx xx: Float[Array, " H W"] yy: Float[Array, " H W"] xx, yy = jnp.meshgrid(x, y) phase: Float[Array, " H W"] = generate_aberration_noll( xx, yy, coefficients, pupil_radius ) field_out: Float[Array, " H W"] = add_phase_screen(incoming.field, phase) wavefront_out: OpticalWavefront = make_optical_wavefront( field=field_out, wavelength=incoming.wavelength, dx=incoming.dx, z_position=incoming.z_position, ) return wavefront_out