Source code for janssen.prop.free_space_prop

"""Free-space and thin-lens propagation functions.

Extended Summary
----------------
Optical field propagation methods based on scalar diffraction theory.
Implements various propagation algorithms including angular spectrum,
Fresnel, and Fraunhofer propagation methods for simulating light
propagation in optical systems.

Routine Listings
----------------
angular_spectrum_prop : function
    Propagates a complex optical field using the angular spectrum method
    without making any paraxial approximations.
correct_propagator : function
    Automatically selects the most appropriate propagation method.
digital_zoom : function
    Zooms an optical wavefront by a specified factor.
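fraunhofer_prop : function
    Propagates a complex optical field using the Fraunhofer
    approximation.
fraunhofer_prop_scaled : function
    Propagates a complex optical field using the Fraunhofer
    approximation, resampled to a requested output pixel size.
fresnel_prop : function
    Propagates a complex optical field using the Fresnel approximation.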
lens_propagation : function
    Propagates an optical wavefront through a lens.
optical_zoom : function
    Modifies the calibration of an optical wavefront without changing
    its field.

Notes
-----
All free-space propagators are FFT-based for efficiency. The
appropriate method depends on the Fresnel number F = a^2/(lambda*z) and
on grid sampling; correct_propagator makes this choice automatically.
"""

import jax
import jax.numpy as jnp
from beartype import beartype
from beartype.typing import Optional
from jaxtyping import Array, Bool, Complex, Float, Integer, jaxtyped

from janssen.lenses import create_lens_phase
from janssen.types import (
    LensParams,
    OpticalWavefront,
    ScalarFloat,
    ScalarInteger,
    ScalarNumeric,
    make_optical_wavefront,
)


@jaxtyped(typechecker=beartype)
def angular_spectrum_prop(
    incoming: OpticalWavefront,
    z_move: ScalarNumeric,
    refractive_index: Optional[ScalarNumeric] = 1.0,
) -> OpticalWavefront:
    """Propagate a complex field using the angular spectrum method.

    Parameters
    ----------
    incoming : OpticalWavefront
        PyTree with the following parameters:

        field : Complex[Array, " hh ww"]
            Input complex field
        wavelength : Float[Array, " "]
            Vacuum wavelength of light in meters
        dx : Float[Array, " "]
            Grid spacing in meters
        z_position : Float[Array, " "]
            Wave front position in meters
    z_move : ScalarNumeric
        Geometric propagation distance in meters through a homogeneous
        medium.
    refractive_index : Optional[ScalarNumeric], optional
        Index of refraction of the medium. Default is 1.0 (vacuum).

    Returns
    -------
    propagated : OpticalWavefront
        Propagated wave front, with z_position advanced by z_move.

    Notes
    -----
    The angular spectrum method is an exact solution to the Helmholtz
    equation for propagation in homogeneous media. It decomposes the
    field into plane waves, propagates each component, and reconstructs
    the field. Using the wavelength in the medium, lambda = lambda_0 / n,
    and the wavenumber k = 2*pi/lambda, the transfer function is

        H(fx, fy) = exp(i*k*z*sqrt(1 - (lambda*fx)^2 - (lambda*fy)^2))

    For spatial frequencies where (lambda*fx)^2 + (lambda*fy)^2 > 1 the
    waves are evanescent and are set to zero. In those bins the
    square-root argument is replaced by a finite placeholder before
    masking, because the square root of a negative real number is NaN
    and NaN * 0 is still NaN. Evanescent frequencies appear on the grid
    only when dx < lambda / sqrt(2) (first in the corners of the
    frequency grid).

    This method makes no paraxial approximation; its accuracy is limited
    by sampling instead. In the paraxial regime the transfer-function
    phase is adequately sampled when dx >= lambda * |z| / L, where
    L = N * dx is the grid width, so the method is best suited to short
    propagation distances.

    Algorithm:

    1. Compute spatial frequency grids fx, fy using FFT frequencies
    2. Build the transfer function
       H = exp(i*k*z*sqrt(1 - lambda^2*(fx^2+fy^2)))
    3. Create the propagating-wave mask fx^2 + fy^2 <= 1/lambda^2
    4. FFT the input field, multiply by the masked transfer function,
       and inverse FFT
    5. Return the propagated wavefront with updated z_position
    """
    ny: ScalarInteger = incoming.field.shape[0]
    nx: ScalarInteger = incoming.field.shape[1]
    # Wavelength and wavenumber inside the medium.
    wavelength_medium: Float[Array, " "] = (
        incoming.wavelength / refractive_index
    )
    wavenumber: Float[Array, " "] = 2 * jnp.pi / wavelength_medium
    fx: Float[Array, " ww"] = jnp.fft.fftfreq(nx, d=incoming.dx)
    fy: Float[Array, " hh"] = jnp.fft.fftfreq(ny, d=incoming.dx)
    fx_mesh: Float[Array, " hh ww"]
    fy_mesh: Float[Array, " hh ww"]
    fx_mesh, fy_mesh = jnp.meshgrid(fx, fy)
    fsq_mesh: Float[Array, " hh ww"] = (fx_mesh**2) + (fy_mesh**2)
    kz_arg: Float[Array, " hh ww"] = 1 - (wavelength_medium**2) * fsq_mesh
    # True for propagating plane waves, False for evanescent ones.
    evanescent_mask: Bool[Array, " hh ww"] = kz_arg >= 0
    # Use a harmless placeholder in evanescent bins so that neither the
    # values nor the gradients contain NaN; those bins are zeroed below.
    safe_kz_arg: Float[Array, " hh ww"] = jnp.where(
        evanescent_mask, kz_arg, 1.0
    )
    asp_transfer: Complex[Array, " hh ww"] = jnp.exp(
        1j * wavenumber * z_move * jnp.sqrt(safe_kz_arg),
    )
    h_mask: Complex[Array, " hh ww"] = asp_transfer * evanescent_mask
    field_ft: Complex[Array, " hh ww"] = jnp.fft.fft2(incoming.field)
    propagated_ft: Complex[Array, " hh ww"] = field_ft * h_mask
    propagated_field: Complex[Array, " hh ww"] = jnp.fft.ifft2(propagated_ft)
    propagated: OpticalWavefront = make_optical_wavefront(
        field=propagated_field,
        wavelength=incoming.wavelength,
        dx=incoming.dx,
        z_position=incoming.z_position + z_move,
    )
    return propagated
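The NaN issue in the evanescent region can be reproduced outside JAX. The sketch below is a NumPy stand-in for angular_spectrum_prop (vacuum only; the helper name `angular_spectrum` is hypothetical and not part of janssen). It substitutes a placeholder argument in evanescent bins before the square root and then checks three things: the output stays finite when the grid contains evanescent frequencies, removing those components can only lower the total power, and propagating by +z and then -z is an identity when every bin propagates.

```python
import numpy as np


def angular_spectrum(field, wavelength, dx, z):
    """Angular-spectrum propagation in vacuum (NumPy sketch)."""
    ny, nx = field.shape
    fx = np.fft.fftfreq(nx, d=dx)
    fy = np.fft.fftfreq(ny, d=dx)
    fx_mesh, fy_mesh = np.meshgrid(fx, fy)
    kz_arg = 1.0 - wavelength**2 * (fx_mesh**2 + fy_mesh**2)
    propagating = kz_arg >= 0.0
    # Evanescent bins get a harmless placeholder before the square root;
    # masking alone would not help because NaN * 0 is still NaN.
    kz_norm = np.sqrt(np.where(propagating, kz_arg, 1.0))
    transfer = np.exp(2j * np.pi / wavelength * z * kz_norm) * propagating
    return np.fft.ifft2(np.fft.fft2(field) * transfer)


rng = np.random.default_rng(0)
field = rng.standard_normal((64, 64)) + 1j * rng.standard_normal((64, 64))
wavelength = 0.5e-6

# dx = 0.2 um < wavelength / sqrt(2), so the grid contains evanescent bins.
fine = angular_spectrum(field, wavelength, dx=0.2e-6, z=5e-6)

# dx = 1 um > wavelength / sqrt(2): every bin propagates, so +z then -z
# should return the original field.
coarse = angular_spectrum(field, wavelength, dx=1e-6, z=1e-3)
round_trip = angular_spectrum(coarse, wavelength, dx=1e-6, z=-1e-3)
```

Dropping the `np.where` and taking `np.sqrt(kz_arg)` directly would fill the whole `fine` result with NaN after the inverse FFT.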

@jaxtyped(typechecker=beartype)
def fresnel_prop(
    incoming: OpticalWavefront,
    z_move: ScalarNumeric,
    refractive_index: Optional[ScalarNumeric] = 1.0,
) -> OpticalWavefront:
    """Propagate a complex field using the Fresnel approximation.

    Parameters
    ----------
    incoming : OpticalWavefront
        PyTree with the following parameters:

        field : Complex[Array, " hh ww"]
            Input complex field
        wavelength : Float[Array, " "]
            Vacuum wavelength of light in meters
        dx : Float[Array, " "]
            Grid spacing in meters
        z_position : Float[Array, " "]
            Wave front position in meters
    z_move : ScalarNumeric
        Geometric propagation distance in meters through a homogeneous
        medium.
    refractive_index : Optional[ScalarNumeric], optional
        Index of refraction of the medium. Default is 1.0 (vacuum).

    Returns
    -------
    propagated : OpticalWavefront
        Propagated wave front, with z_position advanced by z_move.

    Notes
    -----
    The Fresnel approximation is a paraxial approximation to scalar
    diffraction theory. It assumes that propagation angles are small,
    which allows the angular spectrum transfer function to be simplified
    with the Taylor expansion

        sqrt(1 - lambda^2*(fx^2+fy^2)) ≈ 1 - lambda^2*(fx^2+fy^2)/2

    where lambda = lambda_0 / n is the wavelength in the medium. This
    leads to the Fresnel transfer function

        H(fx, fy) = exp(-i*pi*lambda*z*(fx^2 + fy^2))

    The output field is multiplied by the global phase factor
    exp(i*k*z), with k = 2*pi/lambda, which represents the on-axis
    propagation phase.

    Fresnel propagation is appropriate at intermediate Fresnel numbers
    F = a^2/(λz), where a is the characteristic aperture size. For
    F << 1, use fraunhofer_prop instead. As with any transfer-function
    propagator, the quadratic phase is adequately sampled when
    dx >= lambda * |z| / L, where L = N * dx is the grid width.

    Algorithm:

    1. Compute spatial frequency grids fx, fy using FFT frequencies
    2. Build the Fresnel transfer function
       H = exp(-i*pi*lambda*z*(fx^2+fy^2))
    3. FFT the input field, multiply by the transfer function, and
       inverse FFT
    4. Multiply the result by the global phase exp(i*k*z)
    5. Return the propagated wavefront with updated z_position
    """
    ny: ScalarInteger = incoming.field.shape[0]
    nx: ScalarInteger = incoming.field.shape[1]
    # Wavelength and wavenumber inside the medium.
    wavelength_medium: Float[Array, " "] = (
        incoming.wavelength / refractive_index
    )
    k: Float[Array, " "] = (2 * jnp.pi) / wavelength_medium
    fx: Float[Array, " ww"] = jnp.fft.fftfreq(nx, d=incoming.dx)
    fy: Float[Array, " hh"] = jnp.fft.fftfreq(ny, d=incoming.dx)
    fx_mesh: Float[Array, " hh ww"]
    fy_mesh: Float[Array, " hh ww"]
    fx_mesh, fy_mesh = jnp.meshgrid(fx, fy)
    transfer_phase: Float[Array, " hh ww"] = (
        -jnp.pi * wavelength_medium * z_move * (fx_mesh**2 + fy_mesh**2)
    )
    transfer_function: Complex[Array, " hh ww"] = jnp.exp(1j * transfer_phase)
    field_ft: Complex[Array, " hh ww"] = jnp.fft.fft2(incoming.field)
    propagated_ft: Complex[Array, " hh ww"] = field_ft * transfer_function
    propagated_field: Complex[Array, " hh ww"] = jnp.fft.ifft2(propagated_ft)
    global_phase: Complex[Array, " "] = jnp.exp(1j * k * z_move)
    final_propagated_field: Complex[Array, " hh ww"] = (
        global_phase * propagated_field
    )
    propagated: OpticalWavefront = make_optical_wavefront(
        field=final_propagated_field,
        wavelength=incoming.wavelength,
        dx=incoming.dx,
        z_position=incoming.z_position + z_move,
    )
    return propagated
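A Gaussian beam makes a convenient sanity check for the Fresnel transfer function, because its on-axis amplitude after a distance z is known analytically: w0/w(z) = 1/sqrt(1 + (z/z_R)^2), with Rayleigh range z_R = π w0²/λ. The NumPy sketch below uses a hypothetical helper `fresnel` that mirrors fresnel_prop, with λ replaced by λ/n inside a medium. The grid parameters are chosen to satisfy dx ≥ λz/L, so the transfer function is adequately sampled. The sketch also checks that propagating a distance z inside index n gives the same magnitudes as vacuum propagation over z/n.

```python
import numpy as np


def fresnel(field, wavelength, dx, z, refractive_index=1.0):
    """Fresnel transfer-function propagation (NumPy sketch)."""
    ny, nx = field.shape
    wl = wavelength / refractive_index  # wavelength inside the medium
    fx = np.fft.fftfreq(nx, d=dx)
    fy = np.fft.fftfreq(ny, d=dx)
    fx_mesh, fy_mesh = np.meshgrid(fx, fy)
    transfer = np.exp(-1j * np.pi * wl * z * (fx_mesh**2 + fy_mesh**2))
    global_phase = np.exp(2j * np.pi / wl * z)  # on-axis phase exp(i k z)
    return global_phase * np.fft.ifft2(np.fft.fft2(field) * transfer)


n, dx, wavelength, w0, z = 128, 2e-6, 0.5e-6, 30e-6, 1e-3
coords = (np.arange(n) - n // 2) * dx
x, y = np.meshgrid(coords, coords)
gaussian = np.exp(-(x**2 + y**2) / w0**2)

out = fresnel(gaussian, wavelength, dx, z)

# Analytic on-axis amplitude of a Gaussian beam: w0 / w(z).
z_rayleigh = np.pi * w0**2 / wavelength
expected_peak = 1.0 / np.sqrt(1.0 + (z / z_rayleigh) ** 2)

# A distance z inside index n behaves like z / n in vacuum (magnitudes only;
# the global phase differs).
in_glass = fresnel(gaussian, wavelength, dx, z, refractive_index=1.5)
vacuum_equivalent = fresnel(gaussian, wavelength, dx, z / 1.5)
```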

@jaxtyped(typechecker=beartype)
def fraunhofer_prop(
    incoming: OpticalWavefront,
    z_move: ScalarNumeric,
    refractive_index: ScalarNumeric = 1.0,
) -> OpticalWavefront:
    """Propagate a complex field using the Fraunhofer approximation.

    Parameters
    ----------
    incoming : OpticalWavefront
        PyTree with the following parameters:

        field : Complex[Array, " hh ww"]
            Input complex field
        wavelength : Float[Array, " "]
            Vacuum wavelength of light in meters
        dx : Float[Array, " "]
            Grid spacing in meters
        z_position : Float[Array, " "]
            Wave front position in meters
    z_move : ScalarNumeric
        Geometric propagation distance in meters through a homogeneous
        medium.
    refractive_index : ScalarNumeric, optional
        Index of refraction of the medium. Default is 1.0 (vacuum).

    Returns
    -------
    propagated : OpticalWavefront
        Propagated wave front, with z_position advanced by z_move. The
        output pixel size follows Fraunhofer scaling:
        dx_out = (wavelength / n) * z / (N * dx_in)

    Notes
    -----
    The Fraunhofer approximation describes far-field diffraction, in
    which the diffraction pattern is proportional to the Fourier
    transform of the aperture function. It is the limiting case of
    Fresnel diffraction for very large propagation distances.

    With lambda = lambda_0 / n the wavelength in the medium and
    k = 2*pi/lambda, the full Fraunhofer diffraction integral gives

        U(x',y') = exp(i*k*z)/(i*lambda*z) * exp(i*k*(x'^2+y'^2)/(2*z))
                   * FT{U(x,y)} * dx^2

    where FT denotes the Fourier transform and the output coordinates
    are related to spatial frequencies by x' = lambda*z*fx and
    y' = lambda*z*fy.

    The quadratic phase term exp(i*k*(x'^2+y'^2)/(2*z)) represents the
    spherical wavefront curvature in the output plane and is essential
    for coherent imaging and phase-sensitive applications.

    The output pixel size follows the Fraunhofer scaling relation
    dx_out = lambda * z / (N * dx_in), where N is the array size.
    Because a single dx describes both axes, the grid is assumed to be
    square (ny == nx).

    The Fraunhofer approximation is valid when the Fresnel number
    F = a^2/(λz) is small (typically F < 1), where a is the
    characteristic aperture size. For large Fresnel numbers, use
    fresnel_prop or angular_spectrum_prop.

    Algorithm
    ---------
    1. Compute the centered FFT of the input field using
       ifftshift/fft2/fftshift
    2. Compute the output pixel size dx_out = lambda*z/(N*dx_in)
    3. Create the output coordinate grid (origin at index N // 2) and
       compute the quadratic phase
    4. Apply the global phase factor exp(i*k*z)
    5. Apply the amplitude scaling 1/(i*lambda*z) and area element dx^2
    6. Multiply by the quadratic phase term
    7. Return the propagated wavefront with new dx and updated
       z_position
    """
    ny: ScalarInteger = incoming.field.shape[0]
    nx: ScalarInteger = incoming.field.shape[1]
    # Wavelength and wavenumber inside the medium.
    wavelength_medium: Float[Array, " "] = (
        incoming.wavelength / refractive_index
    )
    k: Float[Array, " "] = 2 * jnp.pi / wavelength_medium
    dx_out: Float[Array, " "] = (
        wavelength_medium * z_move / (nx * incoming.dx)
    )
    field_ft: Complex[Array, " hh ww"] = jnp.fft.fftshift(
        jnp.fft.fft2(jnp.fft.ifftshift(incoming.field))
    )
    # After fftshift the zero frequency sits at index n // 2.
    x_out: Float[Array, " ww"] = (jnp.arange(nx) - nx // 2) * dx_out
    y_out: Float[Array, " hh"] = (jnp.arange(ny) - ny // 2) * dx_out
    x_mesh: Float[Array, " hh ww"]
    y_mesh: Float[Array, " hh ww"]
    x_mesh, y_mesh = jnp.meshgrid(x_out, y_out)
    quadratic_phase: Complex[Array, " hh ww"] = jnp.exp(
        1j * k * (x_mesh**2 + y_mesh**2) / (2 * z_move)
    )
    global_phase: Complex[Array, " "] = jnp.exp(1j * k * z_move)
    amplitude_scale_factor: Complex[Array, " "] = 1 / (
        1j * wavelength_medium * z_move
    )
    propagated_field: Complex[Array, " hh ww"] = (
        global_phase
        * amplitude_scale_factor
        * quadratic_phase
        * field_ft
        * (incoming.dx**2)
    )
    propagated: OpticalWavefront = make_optical_wavefront(
        field=propagated_field,
        wavelength=incoming.wavelength,
        dx=dx_out,
        z_position=incoming.z_position + z_move,
    )
    return propagated
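The Fraunhofer scaling can be checked numerically. In the NumPy sketch below, the hypothetical helper `fraunhofer` assumes a square grid and n = 1. The 1/(λz) amplitude factor and the dx² area element together keep the discrete power Σ|U|²dx² identical in the input and output planes. The on-axis value equals the aperture area divided by λz.

```python
import numpy as np


def fraunhofer(field, wavelength, dx, z):
    """Fraunhofer propagation of a square N x N field (NumPy sketch)."""
    n = field.shape[0]
    k = 2 * np.pi / wavelength
    dx_out = wavelength * z / (n * dx)
    field_ft = np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(field)))
    coords = (np.arange(n) - n // 2) * dx_out  # zero frequency at n // 2
    x, y = np.meshgrid(coords, coords)
    quadratic_phase = np.exp(1j * k * (x**2 + y**2) / (2 * z))
    scale = np.exp(1j * k * z) / (1j * wavelength * z)
    return scale * quadratic_phase * field_ft * dx**2, dx_out


n, dx, wavelength, z = 128, 10e-6, 0.5e-6, 1.0
aperture = np.zeros((n, n), dtype=complex)
aperture[56:72, 56:72] = 1.0  # 16 x 16 pixel square opening

far_field, dx_out = fraunhofer(aperture, wavelength, dx, z)
power_in = np.sum(np.abs(aperture) ** 2) * dx**2
power_out = np.sum(np.abs(far_field) ** 2) * dx_out**2
```

Here dx_out = 0.5e-6 * 1.0 / (128 * 10e-6) ≈ 0.39 mm. Power conservation follows from Parseval's theorem for the N × N DFT.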

@jaxtyped(typechecker=beartype)
def fraunhofer_prop_scaled(
    incoming: OpticalWavefront,
    z_move: ScalarNumeric,
    output_dx: ScalarFloat,
    refractive_index: ScalarNumeric = 1.0,
) -> OpticalWavefront:
    """Propagate using Fraunhofer with output at specified pixel size.

    Performs Fraunhofer propagation with the output sampled at the
    desired pixel size. The output array has the same shape as the input.

    Parameters
    ----------
    incoming : OpticalWavefront
        Input optical wavefront.
    z_move : ScalarNumeric
        Geometric propagation distance in meters.
    output_dx : ScalarFloat
        Desired output pixel size in meters.
    refractive_index : ScalarNumeric, optional
        Index of refraction. Default is 1.0 (vacuum).

    Returns
    -------
    propagated : OpticalWavefront
        Propagated wavefront with specified output pixel size.
        Output array shape equals input array shape.

    Notes
    -----
    The standard Fraunhofer relation is

        U(x') = C * FT{U(x)} evaluated at fx = x'/(lambda*z)

    where C includes the phase and amplitude factors and lambda is the
    wavelength in the medium. For output pixel n, centered on index
    N // 2, we want x'_n = (n - N//2) * output_dx, which corresponds to
    the spatial frequency fx_n = x'_n / (lambda*z).

    The centered DFT samples the frequencies

        fx_fft[n] = (n - N//2) / (N * dx_in)

    whereas we want

        fx_out[n] = (n - N//2) * output_dx / (lambda*z)

    The ratio is

        fx_out / fx_fft = output_dx * N * dx_in / (lambda*z)
                        = output_dx / dx_fraunhofer

    so output pixel n reads the FFT at the fractional index
    (n - N//2) * output_dx / dx_fraunhofer + N//2. This implementation
    evaluates those positions by bilinear interpolation of the FFT
    (jax.scipy.ndimage.map_coordinates with order=1), with zeros outside
    the FFT support. This is an approximation; a chirp-z transform would
    instead evaluate the DFT exactly at the requested frequencies. The
    grid is assumed to be square (ny == nx).
    """
    ny: ScalarInteger = incoming.field.shape[0]
    nx: ScalarInteger = incoming.field.shape[1]
    # Wavelength and wavenumber inside the medium.
    wavelength_medium: Float[Array, " "] = (
        incoming.wavelength / refractive_index
    )
    k: Float[Array, " "] = 2 * jnp.pi / wavelength_medium
    dx_fraunhofer: Float[Array, " "] = (
        wavelength_medium * z_move / (nx * incoming.dx)
    )
    fft_zoom_scale: Float[Array, " "] = output_dx / dx_fraunhofer
    field_ft: Complex[Array, " hh ww"] = jnp.fft.fftshift(
        jnp.fft.fft2(jnp.fft.ifftshift(incoming.field))
    )
    # Scale about the zero-frequency index n // 2 of the centered FFT.
    center_y: float = float(ny // 2)
    center_x: float = float(nx // 2)
    # Match the field precision instead of forcing float64, which JAX
    # truncates (with a warning) unless x64 mode is enabled.
    coord_dtype = field_ft.real.dtype
    out_y: Float[Array, " hh"] = jnp.arange(ny, dtype=coord_dtype)
    out_x: Float[Array, " ww"] = jnp.arange(nx, dtype=coord_dtype)
    in_y: Float[Array, " hh"] = (out_y - center_y) * fft_zoom_scale + center_y
    in_x: Float[Array, " ww"] = (out_x - center_x) * fft_zoom_scale + center_x
    in_y_mesh: Float[Array, " hh ww"]
    in_x_mesh: Float[Array, " hh ww"]
    in_y_mesh, in_x_mesh = jnp.meshgrid(in_y, in_x, indexing="ij")
    scaled_ft_real: Float[Array, " hh ww"] = jax.scipy.ndimage.map_coordinates(
        field_ft.real,
        [in_y_mesh, in_x_mesh],
        order=1,
        mode="constant",
        cval=0.0,
    )
    scaled_ft_imag: Float[Array, " hh ww"] = jax.scipy.ndimage.map_coordinates(
        field_ft.imag,
        [in_y_mesh, in_x_mesh],
        order=1,
        mode="constant",
        cval=0.0,
    )
    scaled_ft: Complex[Array, " hh ww"] = scaled_ft_real + 1j * scaled_ft_imag
    x_out: Float[Array, " ww"] = (jnp.arange(nx) - nx // 2) * output_dx
    y_out: Float[Array, " hh"] = (jnp.arange(ny) - ny // 2) * output_dx
    x_mesh: Float[Array, " hh ww"]
    y_mesh: Float[Array, " hh ww"]
    x_mesh, y_mesh = jnp.meshgrid(x_out, y_out)
    quadratic_phase: Complex[Array, " hh ww"] = jnp.exp(
        1j * k * (x_mesh**2 + y_mesh**2) / (2 * z_move)
    )
    global_phase: Complex[Array, " "] = jnp.exp(1j * k * z_move)
    amplitude_scale_factor: Complex[Array, " "] = 1 / (
        1j * wavelength_medium * z_move
    )
    propagated_field: Complex[Array, " hh ww"] = (
        global_phase
        * amplitude_scale_factor
        * quadratic_phase
        * scaled_ft
        * (incoming.dx**2)
    )
    propagated: OpticalWavefront = make_optical_wavefront(
        field=propagated_field,
        wavelength=incoming.wavelength,
        dx=output_dx,
        z_position=incoming.z_position + z_move,
    )
    return propagated
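The resampling center matters because, after fftshift, the zero frequency of an N-point DFT sits at index N // 2, not at (N - 1) / 2. The 1-D NumPy sketch below uses a hypothetical helper `resample_centered`, with np.interp standing in for map_coordinates at order=1, to scale spectrum indices about a given center. Scaling about N // 2 keeps the DC term fixed and, for an integer scale, reads existing bins exactly. Scaling about (N - 1) / 2 moves every sample by half a bin for even N.

```python
import numpy as np


def resample_centered(spectrum, scale, center):
    """Linearly resample a 1-D fftshift-ed spectrum about `center`.

    Output sample m reads input position (m - center) * scale + center;
    positions outside the array are filled with zeros.
    """
    n = spectrum.size
    out_idx = np.arange(n, dtype=float)
    in_idx = (out_idx - center) * scale + center
    src = np.arange(n, dtype=float)
    real = np.interp(in_idx, src, spectrum.real, left=0.0, right=0.0)
    imag = np.interp(in_idx, src, spectrum.imag, left=0.0, right=0.0)
    return real + 1j * imag


rng = np.random.default_rng(1)
n = 64
signal = rng.standard_normal(n) + 1j * rng.standard_normal(n)
spectrum = np.fft.fftshift(np.fft.fft(np.fft.ifftshift(signal)))

# In fraunhofer_prop_scaled the scale is output_dx / dx_fraunhofer.
scale = 2.0
good = resample_centered(spectrum, scale, center=n // 2)
shifted = resample_centered(spectrum, scale, center=(n - 1) / 2)
```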

@jaxtyped(typechecker=beartype)
def digital_zoom(
    wavefront: OpticalWavefront,
    zoom_factor: ScalarNumeric,
) -> OpticalWavefront:
    """Zoom an optical wavefront by a specified factor.

    The returned field always has the same array shape as the input;
    only the pixel size changes (dx_new = dx / zoom_factor).

    Parameters
    ----------
    wavefront : OpticalWavefront
        Incoming optical wavefront.
    zoom_factor : ScalarNumeric
        Zoom factor (greater than 1 to zoom in, less than 1 to zoom out).

    Returns
    -------
    zoomed_wavefront : OpticalWavefront
        Zoomed optical wavefront with the same array shape.

    Notes
    -----
    Algorithm:

    For zoom in (zoom_factor >= 1.0):

    - Calculate the crop fraction (1 / zoom_factor) to determine the
      central region to extract
    - Create interpolation coordinates for the zoomed region, centered
      on the image and spaced exactly 1 / zoom_factor input pixels apart
    - Use jax.scipy.ndimage.map_coordinates with bilinear interpolation
      to sample the real and imaginary parts of the field
    - Return the zoomed field with adjusted pixel size (dx / zoom_factor)

    For zoom out (zoom_factor < 1.0):

    - Calculate the shrunken image size (zoom_factor times the original)
    - Map each output pixel inside the shrunken region back to the full
      input image
    - Use jax.scipy.ndimage.map_coordinates to interpolate the original
      field
    - Apply a mask that zeroes everything outside the shrunken region
      (padding effect)
    - Return the zoomed field with adjusted pixel size (dx / zoom_factor)
    """
    epsilon: float = 1e-10
    zoom_factor: Float[Array, " "] = jnp.maximum(zoom_factor, epsilon)
    hh: int
    ww: int
    hh, ww = wavefront.field.shape

    def zoom_in_fn() -> Complex[Array, " hh ww"]:
        crop_fraction: Float[Array, " "] = 1.0 / zoom_factor
        center_y: float = (hh - 1) / 2
        center_x: float = (ww - 1) / 2
        # Span (n - 1) * crop_fraction input pixels so that neighbouring
        # samples are exactly crop_fraction apart; at zoom_factor == 1 this
        # reproduces the input grid 0, 1, ..., n - 1.
        half_crop_y: Float[Array, " "] = ((hh - 1) * crop_fraction) / 2
        half_crop_x: Float[Array, " "] = ((ww - 1) * crop_fraction) / 2
        y_interp: Float[Array, " hh"] = jnp.linspace(
            center_y - half_crop_y, center_y + half_crop_y, hh
        )
        x_interp: Float[Array, " ww"] = jnp.linspace(
            center_x - half_crop_x, center_x + half_crop_x, ww
        )
        y_grid: Float[Array, " hh ww"]
        x_grid: Float[Array, " hh ww"]
        y_grid, x_grid = jnp.meshgrid(y_interp, x_interp, indexing="ij")
        zoomed: Complex[Array, " hh ww"] = jax.scipy.ndimage.map_coordinates(
            wavefront.field.real,
            [y_grid, x_grid],
            order=1,
            mode="constant",
            cval=0.0,
        ) + 1j * jax.scipy.ndimage.map_coordinates(
            wavefront.field.imag,
            [y_grid, x_grid],
            order=1,
            mode="constant",
            cval=0.0,
        )
        return zoomed

    def zoom_out_fn() -> Complex[Array, " hh ww"]:
        shrink_fraction: Float[Array, " "] = zoom_factor
        shrunk_h: Integer[Array, " "] = jnp.round(hh * shrink_fraction).astype(
            jnp.int32
        )
        shrunk_w: Integer[Array, " "] = jnp.round(ww * shrink_fraction).astype(
            jnp.int32
        )
        shrunk_h = jnp.minimum(shrunk_h, hh)
        shrunk_w = jnp.minimum(shrunk_w, ww)
        center_y: float = (hh - 1) / 2
        center_x: float = (ww - 1) / 2
        half_shrunk_y: Float[Array, " "] = shrunk_h / 2
        half_shrunk_x: Float[Array, " "] = shrunk_w / 2
        y_coords: Float[Array, " hh"] = jnp.linspace(0, hh - 1, hh)
        x_coords: Float[Array, " ww"] = jnp.linspace(0, ww - 1, ww)

        def get_interp_coord(
            coord: Float[Array, " hh ww"],
            center: float,
            half_size: Float[Array, " "],
            full_size: int,
        ) -> Float[Array, " hh ww"]:
            """Map a coordinate in the shrunken region to the full image."""
            norm_coord: Float[Array, " hh ww"] = (
                coord - (center - half_size)
            ) / (2 * half_size)
            return norm_coord * (full_size - 1)

        y_grid: Float[Array, " hh ww"]
        x_grid: Float[Array, " hh ww"]
        y_grid, x_grid = jnp.meshgrid(y_coords, x_coords, indexing="ij")
        mask: Bool[Array, " hh ww"] = (
            jnp.abs(y_grid - center_y) <= half_shrunk_y
        ) & (jnp.abs(x_grid - center_x) <= half_shrunk_x)
        y_interp: Float[Array, " hh ww"] = get_interp_coord(
            y_grid, center_y, half_shrunk_y, hh
        )
        x_interp: Float[Array, " hh ww"] = get_interp_coord(
            x_grid, center_x, half_shrunk_x, ww
        )
        zoomed_real: Float[Array, " hh ww"] = jax.scipy.ndimage.map_coordinates(
            wavefront.field.real,
            [y_interp, x_interp],
            order=1,
            mode="constant",
            cval=0.0,
        )
        zoomed_imag: Float[Array, " hh ww"] = jax.scipy.ndimage.map_coordinates(
            wavefront.field.imag,
            [y_interp, x_interp],
            order=1,
            mode="constant",
            cval=0.0,
        )
        zoomed: Complex[Array, " hh ww"] = (
            zoomed_real + 1j * zoomed_imag
        ) * mask
        return zoomed

    zoomed_field: Complex[Array, " hh ww"] = jax.lax.cond(
        zoom_factor >= 1.0,
        zoom_in_fn,
        zoom_out_fn,
    )
    zoomed_wavefront: OpticalWavefront = make_optical_wavefront(
        field=zoomed_field,
        wavelength=wavefront.wavelength,
        dx=wavefront.dx / zoom_factor,
        z_position=wavefront.z_position,
    )
    return zoomed_wavefront
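The zoom-in coordinate mapping can be checked in one dimension. The sketch below defines hypothetical helpers and uses NumPy's np.interp as a stand-in for map_coordinates with order=1. Spanning (n - 1) / zoom_factor input pixels around the center (n - 1) / 2 places samples exactly 1 / zoom_factor apart and gives an identity map at zoom_factor = 1. Spanning n / zoom_factor pixels instead stretches and shifts the field even when no zoom is requested.

```python
import numpy as np


def zoom_in_coords(n, zoom_factor, span):
    """Input-pixel positions sampled by a centered 1-D zoom-in."""
    center = (n - 1) / 2
    half = span / zoom_factor / 2
    return np.linspace(center - half, center + half, n)


def zoom_in_1d(field, zoom_factor):
    """Zoom a 1-D complex field by linear interpolation, zero outside."""
    n = field.size
    coords = zoom_in_coords(n, zoom_factor, span=n - 1)
    src = np.arange(n, dtype=float)
    real = np.interp(coords, src, field.real, left=0.0, right=0.0)
    imag = np.interp(coords, src, field.imag, left=0.0, right=0.0)
    return real + 1j * imag


n = 8
fixed = zoom_in_coords(n, 1.0, span=n - 1)  # spans (n - 1) pixels
stretched = zoom_in_coords(n, 1.0, span=n)  # spans n pixels
field = np.exp(1j * np.linspace(0.0, np.pi, n))
```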

@jaxtyped(typechecker=beartype)
def optical_zoom(
    wavefront: OpticalWavefront,
    zoom_factor: ScalarNumeric,
) -> OpticalWavefront:
    """Modify the calibration of an optical wavefront without changing its field.

    The field samples are left untouched and only the pixel size is
    rescaled, dx_new = dx * zoom_factor. This models ideal magnification:
    the same samples then cover a region zoom_factor times larger.

    Parameters
    ----------
    wavefront : OpticalWavefront
        Incoming optical wavefront.
    zoom_factor : ScalarNumeric
        Magnification applied to the pixel size (greater than 1 to zoom
        in, less than 1 to zoom out).

    Returns
    -------
    zoomed_wavefront : OpticalWavefront
        Wavefront with the same field array and rescaled pixel size.
    """
    new_dx: Float[Array, " "] = wavefront.dx * zoom_factor
    zoomed_wavefront: OpticalWavefront = make_optical_wavefront(
        field=wavefront.field,
        wavelength=wavefront.wavelength,
        dx=new_dx,
        z_position=wavefront.z_position,
    )
    return zoomed_wavefront

@jaxtyped(typechecker=beartype)
def lens_propagation(
    incoming: OpticalWavefront, lens: LensParams
) -> OpticalWavefront:
    """Propagate an optical wavefront through a lens.

    The lens is modeled as a thin lens with a given focal length and
    diameter.

    Parameters
    ----------
    incoming : OpticalWavefront
        The incoming optical wavefront
    lens : LensParams
        The lens parameters including focal length and diameter

    Returns
    -------
    outgoing : OpticalWavefront
        The optical wavefront immediately after the lens

    Notes
    -----
    Algorithm:

    - Create a meshgrid of coordinates from the incoming wavefront's
      shape and pixel size, with the origin at index n // 2 along each
      axis (for both even and odd sizes).
    - Calculate the phase profile and transmission function of the lens.
    - Apply the phase screen and transmission to the incoming field.
    - Return a new optical wavefront with the transmitted field; the
      wavelength, pixel size, and z_position are unchanged.
    """
    hh: int
    ww: int
    hh, ww = incoming.field.shape
    # arange(n) - n // 2 is centered for odd n as well; the previous
    # linspace(-n // 2, n // 2 - 1, n) form was off by one pixel for odd n
    # because -n // 2 rounds toward negative infinity.
    xline: Float[Array, " ww"] = (jnp.arange(ww) - ww // 2) * incoming.dx
    yline: Float[Array, " hh"] = (jnp.arange(hh) - hh // 2) * incoming.dx
    xarr: Float[Array, " hh ww"]
    yarr: Float[Array, " hh ww"]
    xarr, yarr = jnp.meshgrid(xline, yline)
    phase_profile: Float[Array, " hh ww"]
    transmission: Float[Array, " hh ww"]
    phase_profile, transmission = create_lens_phase(
        xarr, yarr, lens, incoming.wavelength
    )
    transmitted_field: Complex[Array, " hh ww"] = (
        incoming.field * transmission * jnp.exp(1j * phase_profile)
    )
    outgoing: OpticalWavefront = make_optical_wavefront(
        field=transmitted_field,
        wavelength=incoming.wavelength,
        dx=incoming.dx,
        z_position=incoming.z_position,
    )
    return outgoing
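The centering of the lens coordinate grid matters for odd array sizes, because Python's floor division rounds -n // 2 toward negative infinity. The NumPy sketch below compares the `np.arange(n) - n // 2` form with the linspace form. It then evaluates the textbook paraxial thin-lens phase -k(x² + y²)/(2f) on the centered grid. That phase formula and the helper names are assumptions for illustration; this page does not show what create_lens_phase actually computes.

```python
import numpy as np


def centered_coords(n, dx):
    """Pixel coordinates with the origin at index n // 2."""
    return (np.arange(n) - n // 2) * dx


def linspace_coords(n, dx):
    """The linspace(-n // 2, n // 2 - 1, n) form, for comparison."""
    return np.linspace(-n // 2, n // 2 - 1, n) * dx


def thin_lens_phase(x, y, focal_length, wavelength):
    """Textbook paraxial thin-lens phase -k (x^2 + y^2) / (2 f).

    Assumption: illustrative only, not necessarily create_lens_phase.
    """
    k = 2 * np.pi / wavelength
    return -k * (x**2 + y**2) / (2 * focal_length)


dx = 1e-6
odd = centered_coords(5, dx)
odd_linspace = linspace_coords(5, dx)  # -5 // 2 == -3, so this is off-center
x, y = np.meshgrid(odd, odd)
phase = thin_lens_phase(x, y, focal_length=0.1, wavelength=0.5e-6)
```

For even sizes the two forms agree, so the change only affects odd grids.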

@jaxtyped(typechecker=beartype)
def correct_propagator(
    incoming: OpticalWavefront,
    z_move: ScalarNumeric,
    refractive_index: Optional[ScalarNumeric] = 1.0,
) -> OpticalWavefront:
    """Automatically select and apply the most appropriate propagator.

    This function selects a propagation method based on the Fresnel
    number and a sampling criterion. It uses:

    - Angular spectrum for short distances where the transfer function
      is adequately sampled (no paraxial approximation)
    - Fresnel propagation for intermediate distances
    - Fraunhofer propagation for far-field distances

    Parameters
    ----------
    incoming : OpticalWavefront
        PyTree with the following parameters:

        field : Complex[Array, " hh ww"]
            Input complex field
        wavelength : Float[Array, " "]
            Vacuum wavelength of light in meters
        dx : Float[Array, " "]
            Grid spacing in meters
        z_position : Float[Array, " "]
            Wave front position in meters
    z_move : ScalarNumeric
        Geometric propagation distance in meters through a homogeneous
        medium.
    refractive_index : Optional[ScalarNumeric], optional
        Index of refraction of the medium. Default is 1.0 (vacuum).

    Returns
    -------
    propagated : OpticalWavefront
        Propagated wave front using the most appropriate method

    Notes
    -----
    Implementation:

    1. Get the field dimensions (ny, nx)
    2. Calculate the field intensity distribution
    3. Create coordinate arrays centered at the field center
    4. Calculate the intensity-weighted RMS width in x and y
    5. Use twice the larger RMS width as the characteristic aperture
       size a
    6. Use the wavelength in the medium, λ = λ_0 / n
    7. Calculate the Fresnel number F = a²/(λz)
    8. Check the angular spectrum sampling criterion dx >= λ|z|/L,
       where L is the field size
    9. Use nested jax.lax.cond to select the propagator:

       - If F > 1.0 AND the angular spectrum is valid: angular spectrum
       - Else if F > 0.1: Fresnel propagation
       - Else: Fraunhofer propagation (far field)

    Selection criteria:

    - Angular spectrum: F > 1 and dx >= λ|z|/L (most accurate, no
      paraxial approximation)
    - Fresnel: 0.1 < F ≤ 1, or F > 1 when the angular spectrum sampling
      criterion fails
    - Fraunhofer: F ≤ 0.1 (far field)

    The angular spectrum method is preferred when applicable as it makes
    no paraxial approximations.
    """
    fresnel_number_threshold: ScalarFloat = 0.1
    ny: ScalarInteger = incoming.field.shape[0]
    nx: ScalarInteger = incoming.field.shape[1]
    field_intensity: Float[Array, " hh ww"] = jnp.abs(incoming.field) ** 2
    total_intensity: Float[Array, " "] = jnp.sum(field_intensity)
    y_coords: Float[Array, " hh"] = (jnp.arange(ny) - ny / 2) * incoming.dx
    x_coords: Float[Array, " ww"] = (jnp.arange(nx) - nx / 2) * incoming.dx
    y_mesh: Float[Array, " hh ww"]
    x_mesh: Float[Array, " hh ww"]
    y_mesh, x_mesh = jnp.meshgrid(y_coords, x_coords, indexing="ij")
    x_rms: Float[Array, " "] = jnp.sqrt(
        jnp.sum(field_intensity * x_mesh**2) / (total_intensity + 1e-10)
    )
    y_rms: Float[Array, " "] = jnp.sqrt(
        jnp.sum(field_intensity * y_mesh**2) / (total_intensity + 1e-10)
    )
    aperture_size: Float[Array, " "] = jnp.maximum(x_rms, y_rms) * 2
    # Wavelength inside the medium.
    wavelength_medium: Float[Array, " "] = (
        incoming.wavelength / refractive_index
    )
    fresnel_number: Float[Array, " "] = aperture_size**2 / (
        wavelength_medium * jnp.abs(z_move)
    )
    field_size: Float[Array, " "] = jnp.maximum(
        nx * incoming.dx, ny * incoming.dx
    )
    # Transfer-function sampling criterion: the propagation phase is
    # adequately sampled when dx >= lambda * |z| / L, i.e. for short
    # distances.
    angular_spectrum_valid: Bool[Array, " "] = (
        incoming.dx >= wavelength_medium * jnp.abs(z_move) / field_size
    )

    def use_angular_spectrum() -> OpticalWavefront:
        return angular_spectrum_prop(incoming, z_move, refractive_index)

    def use_fresnel() -> OpticalWavefront:
        return fresnel_prop(incoming, z_move, refractive_index)

    def use_fraunhofer() -> OpticalWavefront:
        return fraunhofer_prop(incoming, z_move, refractive_index)

    def select_fresnel_or_fraunhofer() -> OpticalWavefront:
        return jax.lax.cond(
            fresnel_number > fresnel_number_threshold,
            use_fresnel,
            use_fraunhofer,
        )

    propagated: OpticalWavefront = jax.lax.cond(
        (fresnel_number > 1.0) & angular_spectrum_valid,
        use_angular_spectrum,
        select_fresnel_or_fraunhofer,
    )
    return propagated
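The selection rule can be exercised without JAX. The plain NumPy function below, `choose_propagator` (hypothetical, not part of janssen), mirrors the decision logic in correct_propagator and returns the name of the chosen method. It uses the sampling criterion dx ≥ λ|z|/L. For a 32 × 32 pixel square aperture on a 128 × 128 grid, it picks the angular spectrum method at 1 mm. At 5 cm it falls back to Fresnel, because F is still above 1 but the transfer function would be undersampled. At 10 m it reaches Fraunhofer.

```python
import numpy as np


def choose_propagator(field, wavelength, dx, z, refractive_index=1.0):
    """Mirror of the correct_propagator decision, returning a method name."""
    ny, nx = field.shape
    intensity = np.abs(field) ** 2
    total = intensity.sum()
    y = (np.arange(ny) - ny / 2) * dx
    x = (np.arange(nx) - nx / 2) * dx
    yy, xx = np.meshgrid(y, x, indexing="ij")
    x_rms = np.sqrt(np.sum(intensity * xx**2) / (total + 1e-10))
    y_rms = np.sqrt(np.sum(intensity * yy**2) / (total + 1e-10))
    aperture = 2 * max(x_rms, y_rms)  # characteristic aperture size a
    wavelength_medium = wavelength / refractive_index
    fresnel_number = aperture**2 / (wavelength_medium * abs(z))
    field_size = max(nx, ny) * dx
    # Transfer-function sampling condition: dx >= lambda |z| / L.
    sampling_ok = dx >= wavelength_medium * abs(z) / field_size
    if fresnel_number > 1.0 and sampling_ok:
        return "angular_spectrum"
    if fresnel_number > 0.1:
        return "fresnel"
    return "fraunhofer"


n, dx, wavelength = 128, 10e-6, 0.5e-6
aperture = np.zeros((n, n))
aperture[48:80, 48:80] = 1.0  # 32 x 32 pixel square, F is about 68 at 1 mm
choices = {
    z: choose_propagator(aperture, wavelength, dx, z)
    for z in (1e-3, 0.05, 0.1, 10.0)
}
```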