"""Helper functions for optical simulations.
Extended Summary
----------------
Utility functions for creating computational grids, manipulating optical
fields, and performing common operations in optical simulations.
Routine Listings
----------------
sellmeier : function
Calculates refractive index from the Sellmeier equation
create_spatial_grid : function
Creates a 2D spatial grid for optical propagation
normalize_field : function
Normalizes a complex field to unit power
add_phase_screen : function
Adds a phase screen to a complex field
field_intensity : function
Calculates intensity from a complex field
scale_pixel : function
Rescales OpticalWavefront pixel size while keeping array shape fixed
Notes
-----
These helper functions provide common operations needed in optical
simulations and are optimized for use with JAX transformations.
"""
import jax
import jax.numpy as jnp
from beartype import beartype
from beartype.typing import Tuple, Union
from jaxtyping import Array, Complex, Float, Int, Num, jaxtyped
from janssen.types import (
OpticalWavefront,
ScalarFloat,
ScalarInteger,
ScalarNumeric,
make_optical_wavefront,
)
[docs]
@jaxtyped(typechecker=beartype)
def sellmeier(
    wavelength_nm: ScalarNumeric | Num[Array, " nn"],
    sellmeier_b: Num[Array, " 3"],
    sellmeier_c: Num[Array, " 3"],
) -> Union[Float[Array, " "], Float[Array, " nn"]]:
    r"""Calculate refractive index using the Sellmeier equation.

    Parameters
    ----------
    wavelength_nm : ScalarNumeric | Num[Array, " nn"]
        Wavelength in nanometers. Can be a Python scalar, a 0-d array or a
        1-D array.
    sellmeier_b : Num[Array, " 3"]
        Sellmeier B coefficients [B1, B2, B3].
    sellmeier_c : Num[Array, " 3"]
        Sellmeier C coefficients [C1, C2, C3] in micrometers squared.

    Returns
    -------
    n : Union[Float[Array, " "], Float[Array, " nn"]]
        Refractive index. Shape matches input wavelength.

    Notes
    -----
    The Sellmeier equation relates refractive index to wavelength:

    .. math::
        n^2(\lambda) = 1 + \sum_{i=1}^{3} \frac{B_i\lambda^2}{\lambda^2 - C_i}

    where :math:`\lambda` is in micrometers and :math:`C_i` are in
    micrometers squared. The wavelength is converted with ``jnp.asarray``
    first, so plain Python scalars work as well as arrays.
    """
    # jnp.asarray is required: a Python float divided and squared stays a
    # Python float, which cannot be indexed with [..., None] below.
    wavelength_um_sq: Float[Array, " nn"] | Float[Array, " "] = (
        jnp.asarray(wavelength_nm) / 1000.0
    ) ** 2
    # A trailing axis lets every wavelength broadcast against the three
    # (B_i, C_i) terms; summing that axis gives n^2 - 1.
    terms: Num[Array, "... 3"] = (
        sellmeier_b * wavelength_um_sq[..., None]
    ) / (wavelength_um_sq[..., None] - sellmeier_c)
    n_squared: Float[Array, " nn"] | Float[Array, " "] = 1.0 + jnp.sum(
        terms, axis=-1
    )
    refractive_index: Float[Array, " nn"] | Float[Array, " "] = jnp.sqrt(
        n_squared
    )
    return refractive_index
[docs]
@jaxtyped(typechecker=beartype)
def create_spatial_grid(
    diameter: ScalarNumeric | Num[Array, " 2"],
    num_points: ScalarInteger | Int[Array, " 2"],
) -> Tuple[Float[Array, " hh ww"], Float[Array, " hh ww"]]:
    """
    Build centred 2D coordinate grids for optical propagation.

    Parameters
    ----------
    diameter : ScalarNumeric | Num[Array, " 2"]
        Full physical extent of the grid in meters. A scalar gives the same
        extent on both axes; a length-2 array is read as
        ``[diameter_x, diameter_y]``.
    num_points : ScalarInteger | Int[Array, " 2"]
        Number of samples per axis. A scalar gives a square grid; a
        length-2 array is read as ``[num_points_x, num_points_y]``.

    Returns
    -------
    xx : Float[Array, " hh ww"]
        X coordinate of every sample in meters (varies along axis 1).
    yy : Float[Array, " hh ww"]
        Y coordinate of every sample in meters (varies along axis 0).

    Notes
    -----
    - Scalars are broadcast to length 2, so square and rectangular grids
      share one code path with no branching.
    - Each axis is sampled with ``jnp.linspace`` from ``-d / 2`` to
      ``d / 2`` inclusive.
    - The axes are combined with ``jnp.meshgrid`` (default ``"xy"``
      indexing), so both outputs have shape
      ``(num_points_y, num_points_x)``.

    Examples
    --------
    Square grid:

    >>> xx, yy = create_spatial_grid(1e-3, 256)

    Rectangular grid:

    >>> grid_size = jnp.asarray((256, 512), dtype=jnp.int32)
    >>> xx, yy = create_spatial_grid(jnp.array([1e-3, 2e-3]), grid_size)
    """
    extents = jnp.broadcast_to(jnp.atleast_1d(diameter), (2,))
    counts = jnp.broadcast_to(jnp.atleast_1d(num_points), (2,))

    half_x = extents[0] / 2
    half_y = extents[1] / 2
    x_samples: Float[Array, " ww"] = jnp.linspace(-half_x, half_x, counts[0])
    y_samples: Float[Array, " hh"] = jnp.linspace(-half_y, half_y, counts[1])

    grid_x, grid_y = jnp.meshgrid(x_samples, y_samples)
    return grid_x, grid_y
[docs]
@jaxtyped(typechecker=beartype)
def normalize_field(
    field: Complex[Array, " hh ww"],
) -> Complex[Array, " hh ww"]:
    """
    Normalize complex field to unit power.

    Parameters
    ----------
    field : Complex[Array, " hh ww"]
        Input complex field.

    Returns
    -------
    normalized_field : Complex[Array, " hh ww"]
        Field scaled so that ``sum(|normalized_field| ** 2) == 1``. An
        all-zero field has no power to normalize and is returned unchanged
        (all zeros) instead of NaNs.

    Notes
    -----
    - Calculate the power of the field as the sum of the squared absolute
      value of the field.
    - Divide the field by the square root of the power, substituting 1 for
      a zero power.
    - Return the normalized field.
    """
    power: Float[Array, " "] = jnp.sum(jnp.abs(field) ** 2)
    # Replace a zero power with 1 before the sqrt/division. A zero field then
    # stays zero instead of becoming 0/0 = NaN, and gradients stay finite
    # because sqrt is never evaluated at 0.
    safe_power: Float[Array, " "] = jnp.where(power > 0, power, 1.0)
    normalized_field: Complex[Array, " hh ww"] = field / jnp.sqrt(safe_power)
    return normalized_field
[docs]
@jaxtyped(typechecker=beartype)
def add_phase_screen(
    field: Num[Array, " hh ww"],
    phase: Float[Array, " hh ww"],
) -> Complex[Array, " hh ww"]:
    """
    Add a phase screen to a complex field.

    Parameters
    ----------
    field : Num[Array, " hh ww"]
        Input field. Real-valued inputs are promoted to complex by the
        multiplication with the complex phase factor.
    phase : Float[Array, " hh ww"]
        Phase screen to add, in radians, with the same shape as ``field``.

    Returns
    -------
    screened_field : Complex[Array, " hh ww"]
        ``field * exp(1j * phase)``, with the same shape as ``field``.

    Notes
    -----
    - Multiply the input field by the exponential of the phase screen.
    - Return the screened field.
    - The return annotation reuses the ``hh ww`` dimension names, so the
      runtime shape check ties the output shape to the inputs.
    """
    screened_field: Complex[Array, " hh ww"] = field * jnp.exp(1j * phase)
    return screened_field
[docs]
@jaxtyped(typechecker=beartype)
def field_intensity(field: Complex[Array, " hh ww"]) -> Float[Array, " hh ww"]:
    """
    Return the intensity map of a complex field.

    Parameters
    ----------
    field : Complex[Array, " hh ww"]
        Complex field samples.

    Returns
    -------
    intensity : Float[Array, " hh ww"]
        Squared modulus ``|field| ** 2`` at every sample, with the same
        shape as ``field``.

    Notes
    -----
    - Take the elementwise modulus of the field, then square it.
    """
    magnitude: Float[Array, " hh ww"] = jnp.abs(field)
    return magnitude**2
[docs]
@jaxtyped(typechecker=beartype)
def scale_pixel(
    wavefront: OpticalWavefront,
    new_dx: ScalarFloat,
) -> OpticalWavefront:
    """
    Rescale OpticalWavefront pixel size while keeping array shape fixed.

    JAX-compatible (jit/vmap-safe): the output shape never depends on the
    pixel sizes, so ``new_dx`` and ``wavefront.dx`` may be traced values.

    Parameters
    ----------
    wavefront : OpticalWavefront
        OpticalWavefront to be resized.
    new_dx : ScalarFloat
        New pixel size (meters).

    Returns
    -------
    scaled_wavefront : OpticalWavefront
        OpticalWavefront with pixel size ``new_dx`` and a resampled field of
        the same shape as the original field. Wavelength and z position are
        copied unchanged.

    Notes
    -----
    The field is resampled about the array centre with
    ``jax.image.scale_and_translate`` (linear kernel, antialiased). Each
    output pixel (pitch ``new_dx``) is mapped onto the physical position it
    represents on the input grid (pitch ``old_dx``).

    - If the new pixel size is smaller than the old one, the new field of
      view (``hh * new_dx`` by ``ww * new_dx``) is smaller. The result is a
      magnified central region of the input.
    - If the new pixel size is larger than the old one, the new field of
      view is larger. The whole input shrinks towards the centre, and output
      pixels that fall outside the original field of view are filled with
      the smallest field magnitude, ``min(|field|)``.

    Array shapes in JAX must be static. Cropping or padding by a number of
    pixels computed from ``dx`` cannot be traced (e.g. inside
    ``jax.lax.cond`` or ``jax.jit``), which is why a fixed-shape resampling
    is used.
    """
    field: Complex[Array, " hh ww"] = wavefront.field
    old_dx: ScalarFloat = wavefront.dx
    hh: int
    ww: int
    hh, ww = field.shape

    # Output pixels per input pixel: > 1 magnifies (smaller new pixels,
    # smaller FOV), < 1 shrinks (larger new pixels, larger FOV).
    ratio: Float[Array, " "] = jnp.asarray(old_dx / new_dx)
    # With half-centred pixels, an input coordinate x lands at
    # x * ratio + translation. Choosing translation = (n / 2) * (1 - ratio)
    # keeps the array centre n / 2 fixed.
    translation_h: Float[Array, " "] = (hh / 2.0) * (1.0 - ratio)
    translation_w: Float[Array, " "] = (ww / 2.0) * (1.0 - ratio)

    resampled: Complex[Array, " hh ww"] = jax.image.scale_and_translate(
        field,
        (hh, ww),
        (0, 1),
        jnp.stack([ratio, ratio]),
        jnp.stack([translation_h, translation_w]),
        method="linear",
        antialias=True,
    )

    def _inside_input_fov(size: int, translation: Float[Array, " "]):
        """Boolean mask of output samples whose location maps inside the input."""
        # Same sample-position formula and bounds that jax.image uses when it
        # zeroes the weights of samples lying completely outside the input.
        inv_ratio = 1.0 / ratio
        sample = (
            (jnp.arange(size, dtype=inv_ratio.dtype) + 0.5) * inv_ratio
            - translation * inv_ratio
            - 0.5
        )
        return (sample >= -0.5) & (sample <= size - 0.5)

    inside = (
        _inside_input_fov(hh, translation_h)[:, None]
        & _inside_input_fov(ww, translation_w)[None, :]
    )
    # Only reachable when ratio < 1. Samples outside the original FOV are
    # filled with the smallest field magnitude rather than the zeros that
    # scale_and_translate produces there.
    pad_value: Float[Array, " "] = jnp.min(jnp.abs(field))
    resized_field: Complex[Array, " hh ww"] = jnp.where(
        inside, resampled, pad_value
    )

    scaled_wavefront: OpticalWavefront = make_optical_wavefront(
        field=resized_field,
        dx=new_dx,
        wavelength=wavefront.wavelength,
        z_position=wavefront.z_position,
    )
    return scaled_wavefront