Source code for janssen.lenses.lens_elements

"""Lens elements for optical simulations.

Extended Summary
----------------
Physical modeling of various optical lens types including spherical lenses,
plano lenses, and meniscus lenses. Provides functions for calculating lens
properties and propagating optical fields through lens elements.

Routine Listings
----------------
lens_thickness_profile : function
    Calculates the thickness profile of a lens
lens_focal_length : function
    Calculates the focal length of a lens using the lensmaker's equation
create_lens_phase : function
    Creates the phase profile and transmission mask for a lens
propagate_through_lens : function
    Propagates a field through a lens
double_convex_lens : function
    Creates parameters for a double convex lens
double_concave_lens : function
    Creates parameters for a double concave lens
plano_convex_lens : function
    Creates parameters for a plano-convex lens
plano_concave_lens : function
    Creates parameters for a plano-concave lens
meniscus_lens : function
    Creates parameters for a meniscus (concavo-convex) lens

Notes
-----
All lens functions use the thin lens approximation when appropriate and
support JAX transformations. Phase profiles are calculated based on the
optical path difference through the lens material.
"""

import jax
import jax.numpy as jnp
from beartype import beartype
from beartype.typing import Optional, Tuple
from jaxtyping import Array, Bool, Complex, Float, jaxtyped

from janssen.utils import (
    LensParams,
    make_lens_params,
    scalar_bool,
    scalar_float,
    scalar_numeric,
)

# add_phase_screen is intentionally NOT imported here: that import was
# circular, so propagate_through_lens applies the phase screen inline.

# Turn on 64-bit floats/complex numbers in JAX. This runs at import time and
# changes a process-wide JAX setting, so it also affects any other code that
# imports this module.
jax.config.update("jax_enable_x64", True)


def _spherical_sag(
    r: Float[Array, " H W"],
    radius: scalar_float,
    in_aperture: Bool[Array, " H W"],
) -> Float[Array, " H W"]:
    """
    Signed sag of a single spherical surface inside the aperture.

    Parameters
    ----------
    r : Float[Array, " H W"]
        Radial distance from the optical axis.
    radius : scalar_float
        Radius of curvature: positive for a convex surface, negative for a
        concave surface, ``inf`` for a flat surface.
    in_aperture : Bool[Array, " H W"]
        Mask of sample points inside the clear aperture.

    Returns
    -------
    sag : Float[Array, " H W"]
        ``R - sign(R) * sqrt(R**2 - r**2)``: positive for a convex
        surface, negative for a concave one, and zero for a flat surface
        or outside the aperture.
    """
    curved = jnp.isfinite(radius)
    # Replace flat (infinite) radii by a finite placeholder so the branch that
    # jnp.where discards never evaluates inf - inf.
    radius_safe = jnp.where(curved, radius, 1.0)
    # Points beyond |R| clamp the radicand to zero (sag saturates at R).
    depth = radius_safe - jnp.sign(radius_safe) * jnp.sqrt(
        jnp.maximum(radius_safe**2 - r**2, 0.0)
    )
    return jnp.where(in_aperture & curved, depth, 0.0)


@jaxtyped(typechecker=beartype)
def lens_thickness_profile(
    r: Float[Array, " H W"],
    r1: scalar_float,
    r2: scalar_float,
    center_thickness: scalar_float,
    diameter: scalar_float,
) -> Float[Array, " H W"]:
    """
    Calculate the thickness profile of a lens.

    Both radii follow this module's sign convention: a radius is positive
    when the surface is convex and negative when it is concave, for the
    first and the second surface alike; ``inf`` denotes a flat surface.

    Parameters
    ----------
    r : Float[Array, " H W"]
        Radial distance from the optical axis.
    r1 : scalar_float
        Radius of curvature of the first surface.
    r2 : scalar_float
        Radius of curvature of the second surface.
    center_thickness : scalar_float
        Thickness at the center of the lens.
    diameter : scalar_float
        Diameter of the lens.

    Returns
    -------
    thickness : Float[Array, " H W"]
        Thickness profile of the lens; zero outside the aperture.

    Notes
    -----
    - Compute the signed sag of each surface inside the aperture.
    - Convex sags (positive) thin the lens away from the axis, concave
      sags (negative) thicken it: ``t = center_thickness - sag1 - sag2``.
    - No clipping is applied, so the result can be negative near the rim
      if ``center_thickness`` is too small for the chosen curvatures.
    """
    in_aperture = r <= diameter / 2
    sag1: Float[Array, " H W"] = _spherical_sag(r, r1, in_aperture)
    sag2: Float[Array, " H W"] = _spherical_sag(r, r2, in_aperture)
    thickness: Float[Array, " H W"] = jnp.where(
        in_aperture,
        center_thickness - sag1 - sag2,
        0.0,
    )
    return thickness
@jaxtyped(typechecker=beartype)
def lens_focal_length(
    n: scalar_float,
    r1: scalar_numeric,
    r2: scalar_numeric,
) -> scalar_float:
    """
    Calculate the focal length of a lens using the lensmaker's equation.

    Thin-lens form in this module's sign convention (a radius is positive
    for a convex surface and negative for a concave one, on both sides)::

        1 / f = (n - 1) * (1 / r1 + 1 / r2)

    Parameters
    ----------
    n : scalar_float
        Refractive index of the lens material.
    r1 : scalar_numeric
        Radius of curvature of the first surface (positive for convex).
    r2 : scalar_numeric
        Radius of curvature of the second surface (positive for convex).

    Returns
    -------
    f : scalar_float
        Focal length of the lens (positive for converging lenses).

    Notes
    -----
    - Symmetric lenses (``r1 == r2``) use the closed form
      ``r1 / (2 * (n - 1))``; for a flat window (both radii ``inf``)
      this yields ``inf``.
    - Otherwise, if ``|1/r1 + 1/r2|`` is below 1e-10, the curvature sum
      is replaced by +1e-10 so the result stays finite.
    """
    epsilon: float = 1e-10
    curvature = 1.0 / r1 + 1.0 / r2
    curvature_safe = jnp.where(jnp.abs(curvature) < epsilon, epsilon, curvature)
    general_f: Float[Array, " "] = jnp.asarray(
        1.0 / ((n - 1.0) * curvature_safe)
    )
    # Exact closed form for symmetric lenses; also keeps inf for flat windows.
    symmetric_f: Float[Array, " "] = jnp.asarray(r1 / (2 * (n - 1)))
    f: Float[Array, " "] = jnp.where(jnp.equal(r1, r2), symmetric_f, general_f)
    return f
@jaxtyped(typechecker=beartype)
def create_lens_phase(
    xx: Float[Array, " hh ww"],
    yy: Float[Array, " hh ww"],
    params: LensParams,
    wavelength: scalar_float,
) -> Tuple[Float[Array, " hh ww"], Float[Array, " hh ww"]]:
    """
    Build the phase screen and aperture mask of a lens on a sampling grid.

    Parameters
    ----------
    xx : Float[Array, " hh ww"]
        X coordinate of every grid sample.
    yy : Float[Array, " hh ww"]
        Y coordinate of every grid sample.
    params : LensParams
        Lens description (radii, center thickness, diameter, index).
    wavelength : scalar_float
        Wavelength of the illuminating light.

    Returns
    -------
    phase_profile : Float[Array, " hh ww"]
        ``(2 pi / wavelength) * (n - 1) * thickness`` at every sample.
    transmission : Float[Array, " hh ww"]
        1.0 inside the lens aperture (``r <= diameter / 2``), 0.0 outside.

    Notes
    -----
    - Radial distance is taken from the optical axis at ``xx = yy = 0``.
    - Thickness comes from ``lens_thickness_profile``.
    """
    radial: Float[Array, " hh ww"] = jnp.sqrt(xx**2 + yy**2)
    wavenumber: Float[Array, " "] = jnp.asarray(2 * jnp.pi / wavelength)
    glass: Float[Array, " hh ww"] = lens_thickness_profile(
        radial,
        params.r1,
        params.r2,
        params.center_thickness,
        params.diameter,
    )
    phase_profile: Float[Array, " hh ww"] = wavenumber * (params.n - 1) * glass
    inside: Bool[Array, " hh ww"] = radial <= params.diameter / 2
    transmission: Float[Array, " hh ww"] = inside.astype(float)
    return phase_profile, transmission
@jaxtyped(typechecker=beartype)
def propagate_through_lens(
    field: Complex[Array, " hh ww"],
    phase_profile: Float[Array, " hh ww"],
    transmission: Float[Array, " hh ww"],
) -> Complex[Array, " hh ww"]:
    """
    Pass a complex field through a lens modeled as a thin phase screen.

    Parameters
    ----------
    field : Complex[Array, " hh ww"]
        Incoming complex field.
    phase_profile : Float[Array, " hh ww"]
        Phase (radians) added by the lens at every sample.
    transmission : Float[Array, " hh ww"]
        Amplitude mask applied to the field.

    Returns
    -------
    output_field : Complex[Array, " hh ww"]
        ``field * transmission * exp(1j * phase_profile)``.

    Notes
    -----
    The phase screen is applied here directly instead of through a shared
    helper, because importing that helper would be circular.
    """
    apertured = field * transmission
    phasor = jnp.exp(1j * phase_profile)
    output_field: Complex[Array, " hh ww"] = apertured * phasor
    return output_field
@jaxtyped(typechecker=beartype)
def double_convex_lens(
    focal_length: scalar_float,
    diameter: scalar_float,
    n: scalar_float,
    center_thickness: scalar_float,
    r_ratio: Optional[scalar_float] = 1.0,
) -> LensParams:
    """
    Create parameters for a double convex lens.

    Parameters
    ----------
    focal_length : scalar_float
        Desired focal length.
    diameter : scalar_float
        Lens diameter.
    n : scalar_float
        Refractive index.
    center_thickness : scalar_float
        Center thickness.
    r_ratio : scalar_float, optional
        Ratio of r2/r1, by default 1.0 for a symmetric lens. ``None`` is
        treated as 1.0.

    Returns
    -------
    params : LensParams
        Lens parameters.

    Notes
    -----
    With both radii positive (convex) and ``r2 = r_ratio * r1``, the
    lensmaker's equation ``1/f = (n - 1) (1/r1 + 1/r2)`` gives
    ``r1 = f (n - 1) (1 + r_ratio) / r_ratio``. For a symmetric lens this
    is the familiar ``R = 2 f (n - 1)``.
    """
    ratio = 1.0 if r_ratio is None else r_ratio
    r1: Float[Array, " "] = jnp.asarray(
        focal_length * (n - 1) * (1 + ratio) / ratio
    )
    r2: Float[Array, " "] = jnp.asarray(r1 * ratio)
    params: LensParams = make_lens_params(
        focal_length=focal_length,
        diameter=diameter,
        n=n,
        center_thickness=center_thickness,
        r1=r1,
        r2=r2,
    )
    return params
@jaxtyped(typechecker=beartype)
def double_concave_lens(
    focal_length: scalar_float,
    diameter: scalar_float,
    n: scalar_float,
    center_thickness: scalar_float,
    r_ratio: Optional[scalar_float] = 1.0,
) -> LensParams:
    """
    Create parameters for a double concave lens.

    Parameters
    ----------
    focal_length : scalar_float
        Desired focal length. Only its magnitude sets the curvatures; both
        radii are always made negative (concave), so the lens diverges. The
        value is stored unchanged in the returned parameters.
    diameter : scalar_float
        Lens diameter.
    n : scalar_float
        Refractive index.
    center_thickness : scalar_float
        Center thickness.
    r_ratio : scalar_float, optional
        Ratio of R2/R1, by default 1.0 for a symmetric lens. ``None`` is
        treated as 1.0.

    Returns
    -------
    params : LensParams
        Lens parameters.

    Notes
    -----
    With ``|r2| = r_ratio * |r1|``, the lensmaker's equation
    ``1/f = (n - 1) (1/r1 + 1/r2)`` gives
    ``|r1| = |f| (n - 1) (1 + r_ratio) / r_ratio``; both radii then get a
    negative sign.
    """
    ratio = 1.0 if r_ratio is None else r_ratio
    r1_mag: Float[Array, " "] = jnp.abs(
        jnp.asarray(focal_length * (n - 1) * (1 + ratio) / ratio)
    )
    r2_mag: Float[Array, " "] = jnp.abs(r1_mag * ratio)
    params: LensParams = make_lens_params(
        focal_length=focal_length,
        diameter=diameter,
        n=n,
        center_thickness=center_thickness,
        r1=-r1_mag,
        r2=-r2_mag,
    )
    return params
@jaxtyped(typechecker=beartype)
def plano_convex_lens(
    focal_length: scalar_float,
    diameter: scalar_float,
    n: scalar_float,
    center_thickness: scalar_float,
    convex_first: Optional[scalar_bool] = True,
) -> LensParams:
    """
    Build the parameters of a lens with one convex and one flat surface.

    Parameters
    ----------
    focal_length : scalar_float
        Desired focal length.
    diameter : scalar_float
        Lens diameter.
    n : scalar_float
        Refractive index.
    center_thickness : scalar_float
        Center thickness.
    convex_first : scalar_bool, optional
        When True (default) the curved surface is the first one;
        otherwise it is the second one.

    Returns
    -------
    params : LensParams
        Lens parameters.

    Notes
    -----
    The curved surface gets radius ``focal_length * (n - 1)``; the flat
    surface is encoded as an infinite radius.
    """
    curved_side_first: Bool[Array, " "] = jnp.asarray(convex_first)
    curved_radius: Float[Array, " "] = jnp.asarray(focal_length * (n - 1))
    flat_radius = jnp.inf
    return make_lens_params(
        focal_length=focal_length,
        diameter=diameter,
        n=n,
        center_thickness=center_thickness,
        r1=jnp.where(curved_side_first, curved_radius, flat_radius),
        r2=jnp.where(curved_side_first, flat_radius, curved_radius),
    )
@jaxtyped(typechecker=beartype)
def plano_concave_lens(
    focal_length: scalar_float,
    diameter: scalar_float,
    n: scalar_float,
    center_thickness: scalar_float,
    concave_first: Optional[scalar_bool] = True,
) -> LensParams:
    """
    Build the parameters of a lens with one concave and one flat surface.

    Parameters
    ----------
    focal_length : scalar_float
        Desired focal length; only its magnitude sets the curvature.
    diameter : scalar_float
        Lens diameter.
    n : scalar_float
        Refractive index.
    center_thickness : scalar_float
        Center thickness.
    concave_first : scalar_bool, optional
        When True (default) the curved surface is the first one;
        otherwise it is the second one.

    Returns
    -------
    params : LensParams
        Lens parameters.

    Notes
    -----
    The curved surface gets radius ``-|focal_length * (n - 1)|`` (negative,
    i.e. concave); the flat surface is encoded as an infinite radius.
    """
    curved_side_first: Bool[Array, " "] = jnp.asarray(concave_first)
    curved_radius: Float[Array, " "] = -jnp.abs(
        jnp.asarray(focal_length * (n - 1))
    )
    flat_radius = jnp.inf
    return make_lens_params(
        focal_length=focal_length,
        diameter=diameter,
        n=n,
        center_thickness=center_thickness,
        r1=jnp.where(curved_side_first, curved_radius, flat_radius),
        r2=jnp.where(curved_side_first, flat_radius, curved_radius),
    )
@jaxtyped(typechecker=beartype)
def meniscus_lens(
    focal_length: scalar_float,
    diameter: scalar_float,
    n: scalar_float,
    center_thickness: scalar_float,
    r_ratio: scalar_float,
    convex_first: Optional[scalar_bool] = True,
) -> LensParams:
    """
    Create parameters for a meniscus (concavo-convex) lens.

    For a meniscus lens, one surface is convex (positive R) and one is
    concave (negative R).

    Parameters
    ----------
    focal_length : scalar_float
        Desired focal length in meters; only its magnitude is matched.
    diameter : scalar_float
        Lens diameter in meters.
    n : scalar_float
        Refractive index of lens material.
    center_thickness : scalar_float
        Center thickness in meters.
    r_ratio : scalar_float
        Absolute ratio of R2/R1.
    convex_first : scalar_bool, optional
        If True, first surface is convex, by default True.

    Returns
    -------
    params : LensParams
        Lens parameters.

    Notes
    -----
    With ``a = |R1|`` and ``|R2| = r_ratio * a``, the lensmaker's equation
    ``1/f = (n - 1) (1/R1 + 1/R2)`` gives:

    - convex first  (R1 = +a, R2 = -r_ratio a): ``1/f = (n-1)(1 - 1/r_ratio)/a``
    - concave first (R1 = -a, R2 = +r_ratio a): ``1/f = (n-1)(1/r_ratio - 1)/a``

    In both cases ``a = |f (n - 1) (r_ratio - 1) / r_ratio|``.

    The sign of the resulting power is fixed by the geometry. With
    ``convex_first=True`` the lens converges only if ``r_ratio > 1``; with
    ``convex_first=False`` it converges only if ``r_ratio < 1``. Otherwise
    the lens has focal length ``-|focal_length|``.

    ``r_ratio == 1`` gives zero radii (a degenerate zero-power meniscus).
    """
    first_is_convex: Bool[Array, " "] = jnp.asarray(convex_first)
    r1_mag: Float[Array, " "] = jnp.abs(
        jnp.asarray(focal_length * (n - 1) * (r_ratio - 1) / r_ratio)
    )
    r2_mag: Float[Array, " "] = jnp.abs(r1_mag * r_ratio)
    r1: Float[Array, " "] = jnp.where(first_is_convex, r1_mag, -r1_mag)
    r2: Float[Array, " "] = jnp.where(first_is_convex, -r2_mag, r2_mag)
    params: LensParams = make_lens_params(
        focal_length=focal_length,
        diameter=diameter,
        n=n,
        center_thickness=center_thickness,
        r1=r1,
        r2=r2,
    )
    return params