r"""High-NA vector focusing using Richards-Wolf diffraction integrals.
Extended Summary
----------------
This module implements the Richards-Wolf vector diffraction theory for
computing the focal field of high numerical aperture (NA) optical systems.
Unlike scalar diffraction theory, vector theory correctly accounts for the
polarization rotation that occurs when light is focused by a high-NA lens.
At high NA (> 0.7), three key effects become significant:
1. **Depolarization**: Linearly polarized input develops cross-polarized (Ey)
and longitudinal (Ez) components at focus, elongating the focal spot along
the polarization direction.
2. **Radial focusing enhancement**: Radially polarized beams create strong
Ez at focus, producing a tighter focal spot than linearly polarized light.
3. **Azimuthal donut**: Azimuthally polarized beams produce no Ez and create
a dark-center "donut" focal spot.
These effects are invisible to scalar diffraction theory (Fresnel, Fraunhofer,
Angular Spectrum) and require full vector treatment.
Routine Listings
----------------
high_na_focus : function
Compute focal field using Richards-Wolf vector diffraction integrals
debye_wolf_focus : function
Alternative interface using Debye-Wolf formulation
compute_focal_volume : function
Compute 3D focal volume at multiple z planes
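aplanatic_apodization : function
Apply √cos(θ) apodization for aplanatic lens systems
scalar_focus_for_comparison : function
Compute a polarization-blind scalar focal field for comparison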
Notes
-----
The Richards-Wolf integrals express the focal field as:

.. math::

    \vec{E}(\rho_f, \phi_f, z_f) = -\frac{i k f}{2\pi} \int_0^{\theta_{max}}
    \int_0^{2\pi} \sqrt{\cos\theta} \, \mathbf{P}(\theta, \phi)
    \cdot \vec{E}_{pupil}(\theta, \phi) \,
    e^{ikz_f\cos\theta} \, e^{ik\rho_f\sin\theta\cos(\phi - \phi_f)}
    \sin\theta \, d\phi \, d\theta

where P(θ,φ) is the polarization rotation matrix that accounts for how the
electric field vector rotates as light refracts through the lens,
k = 2πn/λ is the wavenumber in the focal medium of refractive index n,
f is the focal length, and sin(θ_max) = NA/n.
References
----------
.. [1] Richards, B., & Wolf, E. (1959). "Electromagnetic diffraction in
optical systems, II. Structure of the image field in an aplanatic
system". Proc. R. Soc. Lond. A, 253(1274), 358-379.
.. [2] Youngworth, K. S., & Brown, T. G. (2000). "Focusing of high
numerical aperture cylindrical-vector beams". Opt. Express, 7(2),
77-87.
.. [3] Novotny, L., & Hecht, B. (2012). "Principles of Nano-Optics",
2nd ed. Cambridge University Press. Chapter 3.
"""
import jax
import jax.numpy as jnp
from beartype import beartype
from beartype.typing import Tuple, Union
from jax import lax
from jaxtyping import Array, Bool, Complex, Float, Int, jaxtyped
from janssen.optics import create_spatial_grid
from janssen.types import (
OpticalWavefront,
ScalarFloat,
VectorWavefront3D,
make_optical_wavefront,
make_vector_wavefront_3d,
)
# Small positive floor that keeps divisors and square-root arguments
# away from zero.
SAFE_DIVIDE_FLOOR: float = 1.0e-15
# ndim of a polarized field array laid out as (H, W, components).
POLARIZED_FIELD_NDIM: int = 3
# Number of Jones-vector components ([Ex, Ey]).
JONES_VECTOR_DIM: int = 2
@jaxtyped(typechecker=beartype)
def _create_pupil_coordinates(
    grid_size: Tuple[int, int],
    dx: ScalarFloat,
    na: ScalarFloat,
) -> Tuple[
    Float[Array, " ny nx"],
    Float[Array, " ny nx"],
    Float[Array, " ny nx"],
    Float[Array, " ny nx"],
    Float[Array, " ny nx"],
    Bool[Array, " ny nx"],
]:
    """Create pupil plane coordinates and angular mappings.

    Parameters
    ----------
    grid_size : Tuple[int, int]
        Grid dimensions (ny, nx).
    dx : ScalarFloat
        Pupil plane pixel spacing in meters.
    na : ScalarFloat
        Sine of the convergence angle reached at the aperture edge; the
        mapping is sin(θ) = na * rho_norm. For a lens focusing into a
        medium of refractive index n, pass NA / n.

    Returns
    -------
    rho_norm : Float[Array, " ny nx"]
        Normalized radial coordinate (0 at center, 1 at the aperture edge).
    phi : Float[Array, " ny nx"]
        Azimuthal angle in pupil plane.
    sin_theta : Float[Array, " ny nx"]
        sin(θ) where θ is the convergence angle, clipped to
        [0, 1 - SAFE_DIVIDE_FLOOR].
    cos_theta : Float[Array, " ny nx"]
        cos(θ) where θ is the convergence angle.
    theta : Float[Array, " ny nx"]
        Convergence angle in radians.
    pupil_mask : Bool[Array, " ny nx"]
        True inside the circular aperture (rho_norm <= 1).

    Notes
    -----
    The aperture is the circle inscribed in the sampling grid. Its radius
    is the smaller of the largest |x| and the largest |y| grid coordinate.
    Normalizing by the largest radius (the grid corner) would make
    ``rho_norm <= 1`` hold everywhere. The mask would then never remove
    anything, and the effective aperture would be the square grid,
    reaching ``na`` only at its corners.

    Outside the aperture sin(θ) is clipped. Those samples are only
    meaningful after multiplication by ``pupil_mask``.
    """
    ny, nx = grid_size
    diameter: Float[Array, " 2"] = jnp.asarray(
        [nx * dx, ny * dx], dtype=jnp.float64
    )
    num_points: Int[Array, " 2"] = jnp.asarray([nx, ny], dtype=jnp.int32)
    xx: Float[Array, " ny nx"]
    yy: Float[Array, " ny nx"]
    # Assumes create_spatial_grid returns coordinates centered on the
    # optical axis -- TODO confirm against its implementation.
    xx, yy = create_spatial_grid(diameter, num_points)
    rho: Float[Array, " ny nx"] = jnp.sqrt(xx**2 + yy**2)
    # Inscribed circle: the largest centered disc that fits in the grid.
    pupil_radius: Float[Array, " "] = jnp.minimum(
        jnp.max(jnp.abs(xx)), jnp.max(jnp.abs(yy))
    )
    rho_norm: Float[Array, " ny nx"] = rho / (pupil_radius + SAFE_DIVIDE_FLOOR)
    # Cap sin(θ) below 1 so cos(θ) and arcsin stay finite for samples
    # outside the aperture (where na * rho_norm can exceed 1).
    sin_theta: Float[Array, " ny nx"] = jnp.clip(
        na * rho_norm, 0.0, 1.0 - SAFE_DIVIDE_FLOOR
    )
    cos_theta: Float[Array, " ny nx"] = jnp.sqrt(1.0 - sin_theta**2)
    theta: Float[Array, " ny nx"] = jnp.arcsin(sin_theta)
    phi: Float[Array, " ny nx"] = jnp.arctan2(yy, xx)
    pupil_mask: Bool[Array, " ny nx"] = rho_norm <= 1.0
    return rho_norm, phi, sin_theta, cos_theta, theta, pupil_mask
@jaxtyped(typechecker=beartype)
def aplanatic_apodization(
    cos_theta: Float[Array, " ny nx"],
) -> Float[Array, " ny nx"]:
    """Apply aplanatic lens apodization factor.

    For an aplanatic (sine-condition satisfying) lens, the amplitude
    apodization is √cos(θ) where θ is the convergence angle.

    Parameters
    ----------
    cos_theta : Float[Array, " ny nx"]
        Cosine of convergence angle at each pupil point.

    Returns
    -------
    apodization : Float[Array, " ny nx"]
        Apodization factor √max(cos(θ), SAFE_DIVIDE_FLOOR).

    Notes
    -----
    The √cos(θ) factor arises from energy conservation when mapping
    a uniform pupil plane wave to converging spherical wavefronts.
    This is the standard apodization for microscope objectives.
    """
    # Floor cos(θ) so the square root stays finite and strictly positive
    # even where θ approaches 90° (or for non-physical negative inputs).
    apodization: Float[Array, " ny nx"] = jnp.sqrt(
        jnp.maximum(cos_theta, SAFE_DIVIDE_FLOOR)
    )
    return apodization
@jaxtyped(typechecker=beartype)
def _polarization_rotation_matrix(
    sin_theta: Float[Array, " ny nx"],
    cos_theta: Float[Array, " ny nx"],
    phi: Float[Array, " ny nx"],
) -> Tuple[
    Float[Array, " ny nx"],
    Float[Array, " ny nx"],
    Float[Array, " ny nx"],
    Float[Array, " ny nx"],
    Float[Array, " ny nx"],
    Float[Array, " ny nx"],
]:
    """Return the 3x2 matrix mapping pupil Jones vectors to focal fields.

    Refraction at a high-NA lens tilts each ray by θ. The electric field
    must stay transverse to the tilted wavevector, so part of the
    transverse field is rotated into the other transverse axis and into z.

    Parameters
    ----------
    sin_theta : Float[Array, " ny nx"]
        Sine of the convergence angle θ.
    cos_theta : Float[Array, " ny nx"]
        Cosine of the convergence angle θ.
    phi : Float[Array, " ny nx"]
        Pupil azimuth φ.

    Returns
    -------
    p_xx, p_xy, p_yx, p_yy, p_zx, p_zy : Float[Array, " ny nx"]
        Matrix entries, ordered so that::

            [Ex_focal]   [p_xx p_xy]
            [Ey_focal] = [p_yx p_yy] [Ex_pupil]
            [Ez_focal]   [p_zx p_zy] [Ey_pupil]

    Notes
    -----
    The matrix is the composition of a rotation by -φ into the meridional
    plane, a tilt by θ, and a rotation by +φ back out (Novotny & Hecht,
    Chapter 3)::

        p_xx = cos(θ)cos²(φ) + sin²(φ)
        p_xy = p_yx = (cos(θ) - 1)sin(φ)cos(φ)
        p_yy = cos(θ)sin²(φ) + cos²(φ)
        p_zx = -sin(θ)cos(φ)
        p_zy = -sin(θ)sin(φ)
    """
    c_phi = jnp.cos(phi)
    s_phi = jnp.sin(phi)

    # The transverse block is symmetric, so one array serves both
    # off-diagonal slots.
    cross_term = (cos_theta - 1.0) * (s_phi * c_phi)
    diag_x = cos_theta * c_phi**2 + s_phi**2
    diag_y = cos_theta * s_phi**2 + c_phi**2

    # Longitudinal (Ez) coupling grows with sin(θ).
    axial_from_x = -sin_theta * c_phi
    axial_from_y = -sin_theta * s_phi

    return diag_x, cross_term, cross_term, diag_y, axial_from_x, axial_from_y
@jaxtyped(typechecker=beartype)
def _apply_polarization_rotation(
    ex_pupil: Complex[Array, " ny nx"],
    ey_pupil: Complex[Array, " ny nx"],
    p_xx: Float[Array, " ny nx"],
    p_xy: Float[Array, " ny nx"],
    p_yx: Float[Array, " ny nx"],
    p_yy: Float[Array, " ny nx"],
    p_zx: Float[Array, " ny nx"],
    p_zy: Float[Array, " ny nx"],
) -> Tuple[
    Complex[Array, " ny nx"],
    Complex[Array, " ny nx"],
    Complex[Array, " ny nx"],
]:
    """Multiply the pupil Jones field by the 3x2 rotation matrix, pixelwise.

    Parameters
    ----------
    ex_pupil, ey_pupil : Complex[Array, " ny nx"]
        Transverse pupil field components.
    p_xx, p_xy, p_yx, p_yy, p_zx, p_zy : Float[Array, " ny nx"]
        Rotation matrix entries (row index = output axis, column index =
        input axis).

    Returns
    -------
    ex_rot, ey_rot, ez_rot : Complex[Array, " ny nx"]
        Field components after rotation, one per output axis.
    """
    matrix_rows = ((p_xx, p_xy), (p_yx, p_yy), (p_zx, p_zy))
    ex_rot, ey_rot, ez_rot = (
        from_x * ex_pupil + from_y * ey_pupil for from_x, from_y in matrix_rows
    )
    return ex_rot, ey_rot, ez_rot
@jaxtyped(typechecker=beartype)
def _compute_defocus_phase(
    cos_theta: Float[Array, " ny nx"],
    z_focus: ScalarFloat,
    wavenumber: ScalarFloat,
) -> Complex[Array, " ny nx"]:
    """Return the axial propagation factor for a plane displaced from focus.

    Parameters
    ----------
    cos_theta : Float[Array, " ny nx"]
        Cosine of the convergence angle at each pupil sample.
    z_focus : ScalarFloat
        Axial offset from the focal plane (z = 0), in meters.
    wavenumber : ScalarFloat
        Wavenumber k = 2π/λ.

    Returns
    -------
    defocus_phase : Complex[Array, " ny nx"]
        exp(i k z cos(θ)) at every pupil sample.
    """
    phase_argument = 1j * wavenumber * z_focus * cos_theta
    return jnp.exp(phase_argument)
@jaxtyped(typechecker=beartype)
def high_na_focus(
    pupil_field: OpticalWavefront,
    na: ScalarFloat,
    focal_length: ScalarFloat,
    z_focus: ScalarFloat = 0.0,
    output_dx: Union[ScalarFloat, None] = None,
    output_grid_size: Union[Tuple[int, int], Int[Array, " 2"], None] = None,
    refractive_index: ScalarFloat = 1.0,
    include_aplanatic_factor: bool = True,
) -> VectorWavefront3D:
    """Compute focal field using Richards-Wolf vector diffraction integrals.

    This is the main entry point for high-NA vector focusing simulations.
    It takes a polarized pupil field and computes the full vector field
    (Ex, Ey, Ez) on a transverse plane at or near the focus.

    Parameters
    ----------
    pupil_field : OpticalWavefront
        Input field in the pupil plane. Must be polarized with shape
        (H, W, 2) containing [Ex, Ey] Jones components.
    na : ScalarFloat
        Numerical aperture of the focusing lens (NA = n sin θ_max).
    focal_length : ScalarFloat
        Focal length of the lens in meters. It only scales the field
        amplitude, because the pupil is normalized so that its edge maps
        to NA.
    z_focus : ScalarFloat, optional
        Axial position relative to geometric focus in meters.
        z_focus = 0 gives the focal plane. Default is 0.0.
    output_dx : ScalarFloat, optional
        Output pixel size in the focal plane in meters. If None, uses
        λ/(8·NA), i.e. four samples per diffraction-limited spacing
        λ/(2·NA).
    output_grid_size : Tuple[int, int] or Int[Array, " 2"], optional
        Output grid size (ny, nx). If None, uses the pupil grid size.
        Must be concrete (not a traced value).
    refractive_index : ScalarFloat, optional
        Refractive index n of the focal medium. Default is 1.0 (air).
        Pupil samples map to sin(θ) = (NA / n) · ρ_norm.
    include_aplanatic_factor : bool, optional
        Whether to include √cos(θ) apodization. Default is True.

    Returns
    -------
    focal_field : VectorWavefront3D
        Vector field of shape (out_ny, out_nx, 3) containing [Ex, Ey, Ez],
        sampled with pixel size ``output_dx``. Pixel
        (out_ny // 2, out_nx // 2) lies on the optical axis.

    Notes
    -----
    The algorithm:

    1. Map pupil pixels to convergence angles, sin(θ) = (NA/n)·ρ_norm.
    2. Apply the pupil mask and, optionally, the aplanatic factor √cos(θ).
    3. Rotate the Jones vector into (Ex', Ey', Ez') with P(θ, φ).
    4. Multiply by the defocus phase exp(i k z cos(θ)), k = 2πn/λ.
    5. Change variables from (θ, φ) to the transverse wavevector
       (kx, ky): sin(θ) dθ dφ = dkx dky / (k² cos(θ)). The Cartesian
       pupil samples therefore carry a 1/cos(θ) weight.
    6. Evaluate the remaining 2D Fourier integral with a separable matrix
       Fourier transform. Unlike a plain FFT, which ties the focal
       sampling to the pupil sampling (about λ/(2·NA) or coarser), this
       allows any output pixel size and grid size.
    7. Package the result as VectorWavefront3D.

    Examples
    --------
    >>> import jax.numpy as jnp
    >>> from janssen.models import radially_polarized_beam
    >>> from janssen.prop import high_na_focus
    >>>
    >>> # Create radially polarized beam in pupil
    >>> pupil = radially_polarized_beam(
    ...     wavelength=633e-9,
    ...     dx=10e-6,
    ...     grid_size=(256, 256),
    ...     beam_radius=1e-3,
    ... )
    >>>
    >>> # Focus with high-NA lens
    >>> focal = high_na_focus(
    ...     pupil_field=pupil,
    ...     na=0.9,
    ...     focal_length=3e-3,
    ... )
    >>>
    >>> print(f"Ez peak: {jnp.max(jnp.abs(focal.ez)**2):.3e}")
    """
    ny, nx = pupil_field.field.shape[:2]
    wavelength: Float[Array, " "] = pupil_field.wavelength
    wavenumber: Float[Array, " "] = (
        2.0 * jnp.pi * refractive_index / wavelength
    )

    # Output sampling is static configuration, so resolve it with Python
    # branching. lax.cond would trace both branches, evaluating
    # jnp.asarray(None) or int() on a tracer.
    if output_grid_size is None:
        out_ny, out_nx = ny, nx
    else:
        out_ny = int(output_grid_size[0])
        out_nx = int(output_grid_size[1])
    if output_dx is None:
        focal_dx = wavelength / (8.0 * na)
    else:
        focal_dx = jnp.asarray(output_dx, dtype=jnp.float64)

    # NA = n sin(θ_max), so the aperture edge corresponds to sin(θ) = NA/n.
    rho_norm, phi, sin_theta, cos_theta, _, pupil_mask = (
        _create_pupil_coordinates(
            grid_size=(ny, nx),
            dx=pupil_field.dx,
            na=na / refractive_index,
        )
    )

    ex_pupil: Complex[Array, " ny nx"] = pupil_field.field[:, :, 0] * pupil_mask
    ey_pupil: Complex[Array, " ny nx"] = pupil_field.field[:, :, 1] * pupil_mask
    if include_aplanatic_factor:
        apod: Float[Array, " ny nx"] = aplanatic_apodization(cos_theta)
        ex_pupil = ex_pupil * apod
        ey_pupil = ey_pupil * apod

    ex_rot, ey_rot, ez_rot = _apply_polarization_rotation(
        ex_pupil,
        ey_pupil,
        *_polarization_rotation_matrix(sin_theta, cos_theta, phi),
    )

    # Defocus phase times the (θ, φ) -> (kx, ky) Jacobian 1/cos(θ). The
    # floor keeps masked-out samples with clipped sin(θ) finite (they are
    # already zero).
    weight: Complex[Array, " ny nx"] = _compute_defocus_phase(
        cos_theta, z_focus, wavenumber
    ) / jnp.maximum(cos_theta, SAFE_DIVIDE_FLOOR)

    # rho_norm * cos(phi) equals x / R (R = aperture radius), so
    # kx = k sin(θ) cos(φ) = 2π NA x / (λ R), and likewise for ky.
    # NOTE(review): assumes create_spatial_grid yields a uniform meshgrid
    # with x varying along columns and y along rows.
    u_cols: Float[Array, " nx"] = jnp.mean(rho_norm * jnp.cos(phi), axis=0)
    v_rows: Float[Array, " ny"] = jnp.mean(rho_norm * jnp.sin(phi), axis=1)
    kx: Float[Array, " nx"] = 2.0 * jnp.pi * na * u_cols / wavelength
    ky: Float[Array, " ny"] = 2.0 * jnp.pi * na * v_rows / wavelength
    dkx = jnp.abs(kx[-1] - kx[0]) / max(nx - 1, 1)
    dky = jnp.abs(ky[-1] - ky[0]) / max(ny - 1, 1)

    # Output coordinates, with the optical axis on pixel N // 2.
    x_out = (jnp.arange(out_nx) - out_nx // 2) * focal_dx
    y_out = (jnp.arange(out_ny) - out_ny // 2) * focal_dx
    # Separable kernels exp(i kx x) and exp(i ky y) with the integral's
    # own (+i) sign, so the result is not mirrored.
    kernel_x = jnp.exp(1j * jnp.outer(x_out, kx))  # (out_nx, nx)
    kernel_y = jnp.exp(1j * jnp.outer(y_out, ky))  # (out_ny, ny)

    # -i k f / (2π) times the measure dkx dky / k².
    scale_factor = (
        -1j * focal_length * dkx * dky / (2.0 * jnp.pi * wavenumber)
    )

    field_3d: Complex[Array, " oy ox 3"] = jnp.stack(
        [
            scale_factor * (kernel_y @ (component * weight) @ kernel_x.T)
            for component in (ex_rot, ey_rot, ez_rot)
        ],
        axis=-1,
    )
    focal_field: VectorWavefront3D = make_vector_wavefront_3d(
        field=field_3d,
        wavelength=wavelength,
        dx=focal_dx,
        z_position=z_focus,
    )
    return focal_field
@jaxtyped(typechecker=beartype)
def debye_wolf_focus(
    pupil_field: OpticalWavefront,
    na: ScalarFloat,
    focal_length: ScalarFloat,
    z_focus: ScalarFloat = 0.0,
    refractive_index: ScalarFloat = 1.0,
) -> VectorWavefront3D:
    """Compute focal field using Debye-Wolf formulation.

    This is an alias for `high_na_focus` using the alternative naming
    convention from the Debye approximation literature. The output
    sampling and the aplanatic factor are left at `high_na_focus`'s
    defaults.

    Parameters
    ----------
    pupil_field : OpticalWavefront
        Input polarized field in pupil plane.
    na : ScalarFloat
        Numerical aperture.
    focal_length : ScalarFloat
        Focal length in meters.
    z_focus : ScalarFloat, optional
        Axial position relative to focus. Default is 0.0.
    refractive_index : ScalarFloat, optional
        Refractive index of focal medium. Default is 1.0.

    Returns
    -------
    focal_field : VectorWavefront3D
        Vector field at focal plane.

    See Also
    --------
    high_na_focus : Main implementation with additional options.
    """
    return high_na_focus(
        pupil_field=pupil_field,
        na=na,
        focal_length=focal_length,
        z_focus=z_focus,
        refractive_index=refractive_index,
    )
@jaxtyped(typechecker=beartype)
def compute_focal_volume(
    pupil_field: OpticalWavefront,
    na: ScalarFloat,
    focal_length: ScalarFloat,
    z_positions: Float[Array, " nz"],
    refractive_index: ScalarFloat = 1.0,
) -> Tuple[Complex[Array, " nz ny nx 3"], Float[Array, " "]]:
    """Compute 3D focal volume at multiple z planes.

    Runs `high_na_focus` for every axial position in a single
    ``jax.vmap`` call. The pixel size is read from the same batched
    result rather than from an extra focus computation.

    Parameters
    ----------
    pupil_field : OpticalWavefront
        Input polarized field in pupil plane.
    na : ScalarFloat
        Numerical aperture.
    focal_length : ScalarFloat
        Focal length in meters.
    z_positions : Float[Array, " nz"]
        Array of axial positions relative to focus. Must be non-empty.
    refractive_index : ScalarFloat, optional
        Refractive index of focal medium. Default is 1.0.

    Returns
    -------
    focal_volume : Complex[Array, " nz ny nx 3"]
        3D vector field volume with [Ex, Ey, Ez] at each z.
    dx_focal : Float[Array, " "]
        Transverse pixel size reported for the first z plane.

    Examples
    --------
    >>> z_range = jnp.linspace(-2e-6, 2e-6, 41)  # ±2 μm
    >>> volume, dx = compute_focal_volume(
    ...     pupil, na=0.9, focal_length=3e-3, z_positions=z_range
    ... )
    >>> # volume has shape (41, ny, nx, 3)
    """

    def focus_at_z(z: Float[Array, " "]):
        result = high_na_focus(
            pupil_field=pupil_field,
            na=na,
            focal_length=focal_length,
            z_focus=z,
            refractive_index=refractive_index,
        )
        return result.field, result.dx

    focal_volume: Complex[Array, " nz ny nx 3"]
    focal_volume, dx_per_plane = jax.vmap(focus_at_z)(z_positions)
    # Unbatched outputs are broadcast along the vmapped axis; take the
    # first plane's value (indexing fails for an empty z_positions).
    dx_focal: Float[Array, " "] = dx_per_plane[0]
    return focal_volume, dx_focal
@jaxtyped(typechecker=beartype)
def scalar_focus_for_comparison(
    pupil_field: OpticalWavefront,
    focal_length: ScalarFloat,
    z_focus: ScalarFloat = 0.0,
) -> OpticalWavefront:
    """Compute scalar focal field for comparison with vector result.

    This function ignores polarization and computes a simple Fourier
    transform focal field, demonstrating what scalar theory predicts.
    Useful for comparing with vector results to highlight the differences.

    Parameters
    ----------
    pupil_field : OpticalWavefront
        Input field. A polarized field (ndim == 3, [..., 0] = Ex,
        [..., 1] = Ey) is collapsed to one scalar whose amplitude is
        sqrt(|Ex|² + |Ey|²) and whose phase is the phase of Ex. Any other
        field is used as-is.
    focal_length : ScalarFloat
        Focal length in meters.
    z_focus : ScalarFloat, optional
        Stored as the output's ``z_position``. No defocus phase is
        applied. Default is 0.0.

    Returns
    -------
    scalar_focal : OpticalWavefront
        Scalar focal field (Ez effects not modeled) with pixel size
        λ·f / (nx·dx).

    Notes
    -----
    The scalar approximation:

    - Ignores polarization rotation (no Ez generation)
    - Predicts identical PSF for all input polarizations
    - Fails at high NA where vector effects are significant
    """
    field = pupil_field.field
    nx = field.shape[1]
    wavelength = pupil_field.wavelength

    # ndim is static, so branch in Python. lax.cond would trace both
    # branches, which return different shapes (and the polarized branch
    # cannot index a 2-D field).
    if field.ndim == POLARIZED_FIELD_NDIM:
        amplitude: Float[Array, " ny nx"] = jnp.sqrt(
            jnp.abs(field[:, :, 0]) ** 2 + jnp.abs(field[:, :, 1]) ** 2
        )
        # NOTE(review): the phase comes from Ex only, so a sign flip in Ex
        # (e.g. radial polarization) survives into the scalar pupil.
        phase: Float[Array, " ny nx"] = jnp.angle(field[:, :, 0])
        scalar_pupil: Complex[Array, " ny nx"] = amplitude * jnp.exp(
            1j * phase
        )
    else:
        scalar_pupil = field

    focal_field: Complex[Array, " ny nx"] = jnp.fft.fftshift(
        jnp.fft.fft2(jnp.fft.ifftshift(scalar_pupil))
    )
    focal_dx: Float[Array, " "] = (
        wavelength * focal_length / (nx * pupil_field.dx)
    )
    scalar_focal: OpticalWavefront = make_optical_wavefront(
        field=focal_field,
        wavelength=wavelength,
        dx=focal_dx,
        z_position=z_focus,
        polarization=False,
    )
    return scalar_focal