"""Optical beam generation functions.
Extended Summary
----------------
Factory functions for creating optical beams with physically meaningful
parameters. These functions generate OpticalWavefront PyTrees with the
correct amplitude and phase profiles for various beam types commonly
used in optical microscopy and imaging systems.
Routine Listings
----------------
plane_wave : function
Creates a uniform plane wave with optional tilt
sinusoidal_wave : function
Creates a sinusoidal interference pattern
collimated_gaussian : function
Creates a collimated Gaussian beam with flat phase
converging_gaussian : function
Creates a Gaussian beam converging to a focus
diverging_gaussian : function
Creates a Gaussian beam diverging from a virtual source
gaussian_beam : function
Creates a Gaussian beam from complex beam parameter q
bessel_beam : function
Creates a Bessel beam with specified cone angle
laguerre_gaussian : function
Creates Laguerre-Gaussian modes (includes vortex beams)
hermite_gaussian : function
Creates Hermite-Gaussian modes
propagate_beam : function
Generates a beam at multiple z positions as a PropagatingWavefront
Notes
-----
All beam generators return OpticalWavefront PyTrees that are compatible
with JAX transformations. The phase profiles encode the propagation
behavior of the beam - converging beams have quadratic phase that causes
focusing, while collimated beams have flat phase.
The key insight is that intensity alone does not determine beam behavior.
Two beams with identical intensity profiles but different phase profiles
will evolve completely differently upon propagation.
"""
import jax
import jax.numpy as jnp
from beartype import beartype
from beartype.typing import Optional, Tuple, Union
from jaxtyping import Array, Complex, Float, Int, jaxtyped
from janssen.optics import create_spatial_grid
from janssen.utils import bessel_j0
from janssen.types import (
OpticalWavefront,
PropagatingWavefront,
ScalarFloat,
ScalarInteger,
make_optical_wavefront,
)
from janssen.types.factory import make_propagating_wavefront
[docs]
@jaxtyped(typechecker=beartype)
def plane_wave(
    wavelength: ScalarFloat,
    dx: ScalarFloat,
    grid_size: Union[Int[Array, " 2"], Tuple[ScalarInteger, ScalarInteger]],
    amplitude: ScalarFloat = 1.0,
    tilt_x: ScalarFloat = 0.0,
    tilt_y: ScalarFloat = 0.0,
    z_position: ScalarFloat = 0.0,
) -> OpticalWavefront:
    r"""Create a uniform plane wave with optional tilt.

    Parameters
    ----------
    wavelength : ScalarFloat
        Wavelength of light in meters.
    dx : ScalarFloat
        Spatial sampling interval (pixel size) in meters.
    grid_size : Union[Int[Array, " 2"], Tuple[ScalarInteger, ScalarInteger]]
        Shape of the computational grid as (height, width), i.e. (ny, nx).
    amplitude : ScalarFloat, optional
        Amplitude of the plane wave, by default 1.0.
    tilt_x : ScalarFloat, optional
        Tilt angle of the propagation direction along x, in radians,
        by default 0.0.
    tilt_y : ScalarFloat, optional
        Tilt angle of the propagation direction along y, in radians,
        by default 0.0.
    z_position : ScalarFloat, optional
        Initial z position of the wavefront in meters, by default 0.0.

    Returns
    -------
    wavefront : OpticalWavefront
        Plane wave OpticalWavefront PyTree.

    Notes
    -----
    A plane wave has uniform amplitude and a linear phase ramp (flat when
    both tilts are zero):

    .. math::

        E(x, y) = A \exp\left(i k (x \sin\theta_x + y \sin\theta_y)\right)

    with :math:`k = 2\pi / \lambda`. The exact :math:`\sin\theta` is used,
    so the transverse wavenumber never exceeds :math:`k`; for small angles
    this reduces to the familiar :math:`k\theta` ramp.
    """
    grid_size = jnp.asarray(grid_size, dtype=jnp.int32)
    ny: ScalarInteger = grid_size[0]
    nx: ScalarInteger = grid_size[1]
    # Physical extents and sample counts are passed in (x, y) order.
    diameter: Float[Array, " 2"] = jnp.asarray(
        [nx * dx, ny * dx], dtype=jnp.float64
    )
    num_points: Int[Array, " 2"] = jnp.asarray([nx, ny], dtype=jnp.int32)
    xx: Float[Array, " ny nx"]
    yy: Float[Array, " ny nx"]
    xx, yy = create_spatial_grid(diameter, num_points)
    k: Float[Array, " "] = 2.0 * jnp.pi / jnp.asarray(wavelength)
    # Transverse wavenumbers k*sin(theta). Using sin rather than the
    # small-angle theta keeps |kx|, |ky| <= k for any tilt.
    kx: Float[Array, " "] = k * jnp.sin(jnp.asarray(tilt_x))
    ky: Float[Array, " "] = k * jnp.sin(jnp.asarray(tilt_y))
    phase: Float[Array, " ny nx"] = kx * xx + ky * yy
    field: Complex[Array, " ny nx"] = jnp.asarray(amplitude) * jnp.exp(
        1j * phase
    )
    wavefront: OpticalWavefront = make_optical_wavefront(
        field=field,
        wavelength=wavelength,
        dx=dx,
        z_position=z_position,
    )
    return wavefront
[docs]
@jaxtyped(typechecker=beartype)
def sinusoidal_wave(
    wavelength: ScalarFloat,
    dx: ScalarFloat,
    grid_size: Union[Int[Array, " 2"], Tuple[ScalarInteger, ScalarInteger]],
    period: ScalarFloat,
    direction: ScalarFloat = 0.0,
    amplitude: ScalarFloat = 1.0,
    z_position: ScalarFloat = 0.0,
) -> OpticalWavefront:
    r"""Create a sinusoidal interference pattern.

    Parameters
    ----------
    wavelength : ScalarFloat
        Wavelength of light in meters.
    dx : ScalarFloat
        Spatial sampling interval (pixel size) in meters.
    grid_size : Union[Int[Array, " 2"], Tuple[ScalarInteger, ScalarInteger]]
        Shape of the computational grid as (height, width), i.e. (ny, nx).
    period : ScalarFloat
        Spatial period of the field oscillation in meters.
    direction : ScalarFloat, optional
        Angle, in radians, of the direction (cos θ, sin θ) along which the
        pattern varies. 0 gives fringes that vary along x (stripes parallel
        to the y axis); pi/2 gives fringes that vary along y. By default 0.0.
    amplitude : ScalarFloat, optional
        Peak amplitude of the field, by default 1.0.
    z_position : ScalarFloat, optional
        Initial z position of the wavefront in meters, by default 0.0.

    Returns
    -------
    wavefront : OpticalWavefront
        Sinusoidal wave OpticalWavefront PyTree.

    Notes
    -----
    The generated field is real-valued (flat phase apart from sign flips):

    .. math::

        E(x, y) = A \cos\left(\frac{2\pi}{T}(x \cos\theta + y \sin\theta)
        \right)

    where :math:`T` is the spatial period and :math:`\theta` is the
    direction angle. The intensity :math:`|E|^2 = A^2 \cos^2(\cdot)`
    therefore repeats with period :math:`T/2`.

    This pattern is the superposition of two symmetrically tilted plane
    waves. It is useful for testing optical systems, modelling gratings and
    studying diffraction.
    """
    grid_size = jnp.asarray(grid_size, dtype=jnp.int32)
    ny: ScalarInteger = grid_size[0]
    nx: ScalarInteger = grid_size[1]
    # Physical extents and sample counts are passed in (x, y) order.
    diameter: Float[Array, " 2"] = jnp.asarray(
        [nx * dx, ny * dx], dtype=jnp.float64
    )
    num_points: Int[Array, " 2"] = jnp.asarray([nx, ny], dtype=jnp.int32)
    xx: Float[Array, " ny nx"]
    yy: Float[Array, " ny nx"]
    xx, yy = create_spatial_grid(diameter, num_points)
    period_arr: Float[Array, " "] = jnp.asarray(period, dtype=jnp.float64)
    theta: Float[Array, " "] = jnp.asarray(direction, dtype=jnp.float64)
    # Coordinate along the fringe-normal direction (cos theta, sin theta).
    spatial_coord: Float[Array, " ny nx"] = xx * jnp.cos(theta) + yy * jnp.sin(
        theta
    )
    sinusoid: Float[Array, " ny nx"] = jnp.cos(
        2.0 * jnp.pi * spatial_coord / period_arr
    )
    # Promote the real pattern to complex directly instead of multiplying
    # by an all-ones complex array.
    field: Complex[Array, " ny nx"] = (
        jnp.asarray(amplitude, dtype=jnp.float64) * sinusoid
    ).astype(jnp.complex128)
    wavefront: OpticalWavefront = make_optical_wavefront(
        field=field,
        wavelength=wavelength,
        dx=dx,
        z_position=z_position,
    )
    return wavefront
[docs]
@jaxtyped(typechecker=beartype)
def collimated_gaussian(
    wavelength: ScalarFloat,
    waist: ScalarFloat,
    dx: ScalarFloat,
    grid_size: Union[Int[Array, " 2"], Tuple[ScalarInteger, ScalarInteger]],
    center: Optional[Tuple[ScalarFloat, ScalarFloat]] = (0.0, 0.0),
    amplitude: Optional[ScalarFloat] = 1.0,
    z_position: Optional[ScalarFloat] = 0.0,
) -> OpticalWavefront:
    r"""Create a collimated Gaussian beam with flat phase.

    Parameters
    ----------
    wavelength : ScalarFloat
        Wavelength of light in meters.
    waist : ScalarFloat
        Beam waist (1/e² intensity radius) in meters.
    dx : ScalarFloat
        Spatial sampling interval (pixel size) in meters.
    grid_size : Union[Int[Array, " 2"], Tuple[ScalarInteger, ScalarInteger]]
        Shape of the computational grid as (height, width), i.e. (ny, nx).
    center : Tuple[ScalarFloat, ScalarFloat], optional
        Center position (x0, y0) in meters, by default (0.0, 0.0).
        None is treated as the default.
    amplitude : ScalarFloat, optional
        Peak amplitude at beam center, by default 1.0. None is treated as
        the default.
    z_position : ScalarFloat, optional
        Initial z position of the wavefront in meters, by default 0.0.
        None is treated as the default.

    Returns
    -------
    wavefront : OpticalWavefront
        Collimated Gaussian beam OpticalWavefront PyTree.

    Notes
    -----
    A collimated Gaussian beam has a Gaussian amplitude profile and flat
    (constant) phase across the beam:

    .. math::

        E(x, y) = A \exp\left(-\frac{(x-x_0)^2 + (y-y_0)^2}{w^2}\right)

    where :math:`w` is the beam waist (1/e² intensity radius).

    This represents a beam at its waist, where the wavefront is planar.
    Upon propagation, the beam expands due to diffraction.
    """
    # The Optional annotations admit None; map it to the documented defaults
    # instead of failing on center[0] / jnp.asarray(None).
    if center is None:
        center = (0.0, 0.0)
    if amplitude is None:
        amplitude = 1.0
    if z_position is None:
        z_position = 0.0
    grid_size = jnp.asarray(grid_size, dtype=jnp.int32)
    ny: ScalarInteger = grid_size[0]
    nx: ScalarInteger = grid_size[1]
    # Physical extents and sample counts are passed in (x, y) order.
    diameter: Float[Array, " 2"] = jnp.asarray(
        [nx * dx, ny * dx], dtype=jnp.float64
    )
    num_points: Int[Array, " 2"] = jnp.asarray([nx, ny], dtype=jnp.int32)
    xx: Float[Array, " ny nx"]
    yy: Float[Array, " ny nx"]
    xx, yy = create_spatial_grid(diameter, num_points)
    x0: Float[Array, " "] = jnp.asarray(center[0], dtype=jnp.float64)
    y0: Float[Array, " "] = jnp.asarray(center[1], dtype=jnp.float64)
    w: Float[Array, " "] = jnp.asarray(waist, dtype=jnp.float64)
    r2: Float[Array, " ny nx"] = (xx - x0) ** 2 + (yy - y0) ** 2
    envelope: Float[Array, " ny nx"] = jnp.asarray(
        amplitude, dtype=jnp.float64
    ) * jnp.exp(-r2 / (w**2))
    # Flat phase: the field is just the real envelope promoted to complex.
    field: Complex[Array, " ny nx"] = envelope.astype(jnp.complex128)
    wavefront: OpticalWavefront = make_optical_wavefront(
        field=field,
        wavelength=wavelength,
        dx=dx,
        z_position=z_position,
    )
    return wavefront
[docs]
@jaxtyped(typechecker=beartype)
def converging_gaussian(
    wavelength: ScalarFloat,
    waist: ScalarFloat,
    focus_distance: ScalarFloat,
    dx: ScalarFloat,
    grid_size: Union[Int[Array, " 2"], Tuple[ScalarInteger, ScalarInteger]],
    center: Optional[Tuple[ScalarFloat, ScalarFloat]] = (0.0, 0.0),
    amplitude: Optional[ScalarFloat] = 1.0,
    z_position: Optional[ScalarFloat] = 0.0,
) -> OpticalWavefront:
    r"""Create a Gaussian beam converging to a focus.

    Parameters
    ----------
    wavelength : ScalarFloat
        Wavelength of light in meters.
    waist : ScalarFloat
        Current beam radius (1/e² intensity radius) in meters. This is the
        size at the current plane, not at the focus.
    focus_distance : ScalarFloat
        Distance to the focus in meters (positive = focus downstream).
        Magnitudes below 1e-15 are replaced by 1e-15 to avoid division by
        zero.
    dx : ScalarFloat
        Spatial sampling interval (pixel size) in meters.
    grid_size : Union[Int[Array, " 2"], Tuple[ScalarInteger, ScalarInteger]]
        Shape of the computational grid as (height, width), i.e. (ny, nx).
    center : Tuple[ScalarFloat, ScalarFloat], optional
        Center position (x0, y0) in meters, by default (0.0, 0.0).
        None is treated as the default.
    amplitude : ScalarFloat, optional
        Peak amplitude at beam center, by default 1.0. None is treated as
        the default.
    z_position : ScalarFloat, optional
        Initial z position of the wavefront in meters, by default 0.0.
        None is treated as the default.

    Returns
    -------
    wavefront : OpticalWavefront
        Converging Gaussian beam OpticalWavefront PyTree.

    Notes
    -----
    A converging Gaussian beam has a Gaussian amplitude profile with a
    spherical (quadratic) converging phase:

    .. math::

        E(x, y) = A \exp\left(-\frac{r^2}{w^2}\right)
                  \exp\left(-i \frac{k r^2}{2 f}\right)

    where :math:`f` is the focus distance. The negative sign marks
    convergence (the wavefront curves inward toward the optical axis); the
    radius of curvature equals :math:`f` for a beam that comes to a focus
    a distance :math:`f` downstream.

    Upon propagation, the beam shrinks until the focus and then expands as
    a diverging beam.
    """
    # The Optional annotations admit None; map it to the documented defaults
    # instead of failing on center[0] / jnp.asarray(None).
    if center is None:
        center = (0.0, 0.0)
    if amplitude is None:
        amplitude = 1.0
    if z_position is None:
        z_position = 0.0
    grid_size = jnp.asarray(grid_size, dtype=jnp.int32)
    ny: ScalarInteger = grid_size[0]
    nx: ScalarInteger = grid_size[1]
    # Physical extents and sample counts are passed in (x, y) order.
    diameter: Float[Array, " 2"] = jnp.asarray(
        [nx * dx, ny * dx], dtype=jnp.float64
    )
    num_points: Int[Array, " 2"] = jnp.asarray([nx, ny], dtype=jnp.int32)
    xx: Float[Array, " ny nx"]
    yy: Float[Array, " ny nx"]
    xx, yy = create_spatial_grid(diameter, num_points)
    x0: Float[Array, " "] = jnp.asarray(center[0], dtype=jnp.float64)
    y0: Float[Array, " "] = jnp.asarray(center[1], dtype=jnp.float64)
    w: Float[Array, " "] = jnp.asarray(waist, dtype=jnp.float64)
    f: Float[Array, " "] = jnp.asarray(focus_distance, dtype=jnp.float64)
    k: Float[Array, " "] = 2.0 * jnp.pi / jnp.asarray(wavelength)
    r2: Float[Array, " ny nx"] = (xx - x0) ** 2 + (yy - y0) ** 2
    gaussian_amplitude: Float[Array, " ny nx"] = jnp.asarray(
        amplitude, dtype=jnp.float64
    ) * jnp.exp(-r2 / (w**2))
    # Guard against division by zero for a focus placed at this plane.
    safe_floor: float = 1e-15
    f_safe: Float[Array, " "] = jnp.where(
        jnp.abs(f) < safe_floor, safe_floor, f
    )
    converging_phase: Float[Array, " ny nx"] = -k * r2 / (2.0 * f_safe)
    field: Complex[Array, " ny nx"] = gaussian_amplitude * jnp.exp(
        1j * converging_phase
    )
    wavefront: OpticalWavefront = make_optical_wavefront(
        field=field,
        wavelength=wavelength,
        dx=dx,
        z_position=z_position,
    )
    return wavefront
[docs]
@jaxtyped(typechecker=beartype)
def diverging_gaussian(
    wavelength: ScalarFloat,
    waist: ScalarFloat,
    source_distance: ScalarFloat,
    dx: ScalarFloat,
    grid_size: Union[Int[Array, " 2"], Tuple[ScalarInteger, ScalarInteger]],
    center: Optional[Tuple[ScalarFloat, ScalarFloat]] = (0.0, 0.0),
    amplitude: Optional[ScalarFloat] = 1.0,
    z_position: Optional[ScalarFloat] = 0.0,
) -> OpticalWavefront:
    r"""Create a Gaussian beam diverging from a virtual source.

    Parameters
    ----------
    wavelength : ScalarFloat
        Wavelength of light in meters.
    waist : ScalarFloat
        Current beam radius (1/e² intensity radius) in meters, measured at
        the current plane.
    source_distance : ScalarFloat
        Distance from the virtual source point in meters
        (positive = source was upstream). Magnitudes below 1e-15 are
        replaced by 1e-15 to avoid division by zero.
    dx : ScalarFloat
        Spatial sampling interval (pixel size) in meters.
    grid_size : Union[Int[Array, " 2"], Tuple[ScalarInteger, ScalarInteger]]
        Shape of the computational grid as (height, width), i.e. (ny, nx).
    center : Tuple[ScalarFloat, ScalarFloat], optional
        Center position (x0, y0) in meters, by default (0.0, 0.0).
        None is treated as the default.
    amplitude : ScalarFloat, optional
        Peak amplitude at beam center, by default 1.0. None is treated as
        the default.
    z_position : ScalarFloat, optional
        Initial z position of the wavefront in meters, by default 0.0.
        None is treated as the default.

    Returns
    -------
    wavefront : OpticalWavefront
        Diverging Gaussian beam OpticalWavefront PyTree.

    Notes
    -----
    A diverging Gaussian beam has a Gaussian amplitude profile with a
    spherical (quadratic) diverging phase:

    .. math::

        E(x, y) = A \exp\left(-\frac{r^2}{w^2}\right)
                  \exp\left(+i \frac{k r^2}{2 R}\right)

    where :math:`R` is the radius of curvature (positive for diverging,
    equal to the source distance).

    This represents a beam that originated from a point source a distance
    ``source_distance`` upstream. Upon propagation, the beam keeps
    expanding.
    """
    # The Optional annotations admit None; map it to the documented defaults
    # instead of failing on center[0] / jnp.asarray(None).
    if center is None:
        center = (0.0, 0.0)
    if amplitude is None:
        amplitude = 1.0
    if z_position is None:
        z_position = 0.0
    grid_size = jnp.asarray(grid_size, dtype=jnp.int32)
    ny: ScalarInteger = grid_size[0]
    nx: ScalarInteger = grid_size[1]
    # Physical extents and sample counts are passed in (x, y) order.
    diameter: Float[Array, " 2"] = jnp.asarray(
        [nx * dx, ny * dx], dtype=jnp.float64
    )
    num_points: Int[Array, " 2"] = jnp.asarray([nx, ny], dtype=jnp.int32)
    xx: Float[Array, " ny nx"]
    yy: Float[Array, " ny nx"]
    xx, yy = create_spatial_grid(diameter, num_points)
    x0: Float[Array, " "] = jnp.asarray(center[0], dtype=jnp.float64)
    y0: Float[Array, " "] = jnp.asarray(center[1], dtype=jnp.float64)
    w: Float[Array, " "] = jnp.asarray(waist, dtype=jnp.float64)
    rr: Float[Array, " "] = jnp.asarray(source_distance, dtype=jnp.float64)
    k: Float[Array, " "] = 2.0 * jnp.pi / jnp.asarray(wavelength)
    r2: Float[Array, " ny nx"] = (xx - x0) ** 2 + (yy - y0) ** 2
    gaussian_amplitude: Float[Array, " ny nx"] = jnp.asarray(
        amplitude, dtype=jnp.float64
    ) * jnp.exp(-r2 / (w**2))
    # Guard against division by zero for a source placed at this plane.
    safe_floor: float = 1e-15
    rr_safe: Float[Array, " "] = jnp.where(
        jnp.abs(rr) < safe_floor, safe_floor, rr
    )
    diverging_phase: Float[Array, " ny nx"] = k * r2 / (2.0 * rr_safe)
    field: Complex[Array, " ny nx"] = gaussian_amplitude * jnp.exp(
        1j * diverging_phase
    )
    wavefront: OpticalWavefront = make_optical_wavefront(
        field=field,
        wavelength=wavelength,
        dx=dx,
        z_position=z_position,
    )
    return wavefront
[docs]
@jaxtyped(typechecker=beartype)
def gaussian_beam(
    wavelength: ScalarFloat,
    waist_0: ScalarFloat,
    z_from_waist: ScalarFloat,
    dx: ScalarFloat,
    grid_size: Union[Int[Array, " 2"], Tuple[ScalarInteger, ScalarInteger]],
    center: Optional[Tuple[ScalarFloat, ScalarFloat]] = (0.0, 0.0),
    amplitude: Optional[ScalarFloat] = 1.0,
    z_position: Optional[ScalarFloat] = 0.0,
    include_gouy_phase: Optional[bool] = True,
) -> OpticalWavefront:
    r"""Create a Gaussian beam at an arbitrary distance from its waist.

    This is the most general Gaussian beam generator. It uses the full
    Gaussian beam propagation formulas, including the Gouy phase.

    Parameters
    ----------
    wavelength : ScalarFloat
        Wavelength of light in meters.
    waist_0 : ScalarFloat
        Beam waist at the waist position (minimum spot size) in meters.
    z_from_waist : ScalarFloat
        Distance from the beam waist in meters.
        Positive = downstream from waist (diverging).
        Negative = upstream from waist (converging toward waist).
    dx : ScalarFloat
        Spatial sampling interval (pixel size) in meters.
    grid_size : Union[Int[Array, " 2"], Tuple[ScalarInteger, ScalarInteger]]
        Shape of the computational grid as (height, width), i.e. (ny, nx).
    center : Tuple[ScalarFloat, ScalarFloat], optional
        Center position (x0, y0) in meters, by default (0.0, 0.0).
        None is treated as the default.
    amplitude : ScalarFloat, optional
        Peak amplitude at the beam waist, by default 1.0. Away from the
        waist the peak is scaled by w0 / w(z), so total power is independent
        of z. None is treated as the default.
    z_position : ScalarFloat, optional
        Initial z position of the wavefront in meters, by default 0.0.
        None is treated as the default.
    include_gouy_phase : bool, optional
        Whether to include the Gouy phase shift, by default True.
        None is treated as the default.

    Returns
    -------
    wavefront : OpticalWavefront
        Gaussian beam OpticalWavefront PyTree.

    Notes
    -----
    The generated transverse field at distance :math:`z` from the waist is

    .. math::

        E(r, z) = A \frac{w_0}{w(z)} \exp\left(-\frac{r^2}{w(z)^2}\right)
                  \exp\left(i\frac{k r^2}{2 R(z)} - i\zeta(z)\right)

    where:

    - :math:`w(z) = w_0 \sqrt{1 + (z/z_R)^2}` is the beam radius
    - :math:`1/R(z) = z / (z^2 + z_R^2)` is the wavefront curvature
    - :math:`\zeta(z) = \arctan(z/z_R)` is the Gouy phase
    - :math:`z_R = \pi w_0^2 / \lambda` is the Rayleigh range

    The curvature sign matches ``diverging_gaussian`` (:math:`+k r^2/2R`)
    and ``converging_gaussian`` (:math:`-k r^2/2f`). For
    ``z_from_waist > 0`` the beam carries the diverging sign; for
    ``z_from_waist < 0`` it carries the converging sign. The Gouy term
    carries the opposite sign to the curvature term, consistent with that
    convention. The common on-axis propagation phase :math:`kz` is not
    applied.

    At the waist (z=0) the beam has minimum size and flat phase.
    """
    # The Optional annotations admit None; map it to the documented defaults.
    if center is None:
        center = (0.0, 0.0)
    if amplitude is None:
        amplitude = 1.0
    if z_position is None:
        z_position = 0.0
    if include_gouy_phase is None:
        include_gouy_phase = True
    grid_size = jnp.asarray(grid_size, dtype=jnp.int32)
    ny: ScalarInteger = grid_size[0]
    nx: ScalarInteger = grid_size[1]
    # Physical extents and sample counts are passed in (x, y) order.
    diameter: Float[Array, " 2"] = jnp.asarray(
        [nx * dx, ny * dx], dtype=jnp.float64
    )
    num_points: Int[Array, " 2"] = jnp.asarray([nx, ny], dtype=jnp.int32)
    xx: Float[Array, " ny nx"]
    yy: Float[Array, " ny nx"]
    xx, yy = create_spatial_grid(diameter, num_points)
    x0: Float[Array, " "] = jnp.asarray(center[0], dtype=jnp.float64)
    y0: Float[Array, " "] = jnp.asarray(center[1], dtype=jnp.float64)
    w0: Float[Array, " "] = jnp.asarray(waist_0, dtype=jnp.float64)
    z: Float[Array, " "] = jnp.asarray(z_from_waist, dtype=jnp.float64)
    lam: Float[Array, " "] = jnp.asarray(wavelength, dtype=jnp.float64)
    k: Float[Array, " "] = 2.0 * jnp.pi / lam
    rayleigh_range: Float[Array, " "] = jnp.pi * w0**2 / lam
    beam_radius_at_z: Float[Array, " "] = w0 * jnp.sqrt(
        1.0 + (z / rayleigh_range) ** 2
    )
    # 1/R(z) = z / (z^2 + z_R^2) is finite for every z and exactly zero at
    # the waist, so no small-z guard (and no magic threshold) is needed.
    inverse_curvature: Float[Array, " "] = z / (z**2 + rayleigh_range**2)
    gouy_phase: Float[Array, " "] = jnp.arctan2(z, rayleigh_range)
    radial_distance_squared: Float[Array, " ny nx"] = (xx - x0) ** 2 + (
        yy - y0
    ) ** 2
    # The w0 / w(z) factor conserves total power as the beam spreads.
    envelope: Float[Array, " ny nx"] = (
        jnp.asarray(amplitude, dtype=jnp.float64)
        * (w0 / beam_radius_at_z)
        * jnp.exp(-radial_distance_squared / (beam_radius_at_z**2))
    )
    # Positive for z > 0 (diverging, as in diverging_gaussian), negative for
    # z < 0 (converging, as in converging_gaussian).
    curvature_phase: Float[Array, " ny nx"] = (
        0.5 * k * radial_distance_squared * inverse_curvature
    )
    gouy_term: Float[Array, " "] = jnp.where(
        include_gouy_phase, -gouy_phase, 0.0
    )
    total_phase: Float[Array, " ny nx"] = curvature_phase + gouy_term
    field: Complex[Array, " ny nx"] = envelope * jnp.exp(1j * total_phase)
    wavefront: OpticalWavefront = make_optical_wavefront(
        field=field,
        wavelength=wavelength,
        dx=dx,
        z_position=z_position,
    )
    return wavefront
[docs]
@jaxtyped(typechecker=beartype)
def bessel_beam(
    wavelength: ScalarFloat,
    cone_angle: ScalarFloat,
    dx: ScalarFloat,
    grid_size: Union[Int[Array, " 2"], Tuple[ScalarInteger, ScalarInteger]],
    amplitude: Optional[ScalarFloat] = 1.0,
    z_position: Optional[ScalarFloat] = 0.0,
) -> OpticalWavefront:
    r"""Create a Bessel beam with specified cone angle.

    Parameters
    ----------
    wavelength : ScalarFloat
        Wavelength of light in meters.
    cone_angle : ScalarFloat
        Cone half-angle in radians. Determines the transverse wave vector
        component: k_r = k * sin(cone_angle).
    dx : ScalarFloat
        Spatial sampling interval (pixel size) in meters.
    grid_size : Union[Int[Array, " 2"], Tuple[ScalarInteger, ScalarInteger]]
        Shape of the computational grid as (height, width), i.e. (ny, nx).
    amplitude : ScalarFloat, optional
        Peak amplitude at beam center, by default 1.0. None is treated as
        the default.
    z_position : ScalarFloat, optional
        Initial z position of the wavefront in meters, by default 0.0.
        None is treated as the default.

    Returns
    -------
    wavefront : OpticalWavefront
        Bessel beam OpticalWavefront PyTree.

    Notes
    -----
    An ideal Bessel beam has a transverse profile given by the zeroth-order
    Bessel function:

    .. math::

        E(r) = A J_0(k_r r)

    where :math:`k_r = k \sin(\theta)` is the transverse wave vector and
    :math:`\theta` is the cone half-angle. The radius :math:`r` is measured
    from the grid origin (x = y = 0); no off-axis center is supported.

    Bessel beams are "non-diffracting": their transverse profile stays
    constant upon propagation in the ideal, infinitely extended case. In
    practice, physical Bessel beams have finite extent and eventually
    diffract.

    The central lobe radius (first zero) is approximately
    :math:`r_0 \approx 2.405 / k_r`.
    """
    # The Optional annotations admit None; map it to the documented defaults.
    if amplitude is None:
        amplitude = 1.0
    if z_position is None:
        z_position = 0.0
    grid_size = jnp.asarray(grid_size, dtype=jnp.int32)
    ny: ScalarInteger = grid_size[0]
    nx: ScalarInteger = grid_size[1]
    # Physical extents and sample counts are passed in (x, y) order.
    diameter: Float[Array, " 2"] = jnp.asarray(
        [nx * dx, ny * dx], dtype=jnp.float64
    )
    num_points: Int[Array, " 2"] = jnp.asarray([nx, ny], dtype=jnp.int32)
    xx: Float[Array, " ny nx"]
    yy: Float[Array, " ny nx"]
    xx, yy = create_spatial_grid(diameter, num_points)
    k: Float[Array, " "] = 2.0 * jnp.pi / jnp.asarray(wavelength)
    transverse_wave_vector: Float[Array, " "] = k * jnp.sin(
        jnp.asarray(cone_angle)
    )
    radial_distance: Float[Array, " ny nx"] = jnp.sqrt(xx**2 + yy**2)
    profile: Float[Array, " ny nx"] = jnp.asarray(
        amplitude, dtype=jnp.float64
    ) * bessel_j0(transverse_wave_vector * radial_distance)
    # Flat phase: promote the real profile to complex directly.
    field: Complex[Array, " ny nx"] = profile.astype(jnp.complex128)
    wavefront: OpticalWavefront = make_optical_wavefront(
        field=field,
        wavelength=wavelength,
        dx=dx,
        z_position=z_position,
    )
    return wavefront
[docs]
@jaxtyped(typechecker=beartype)
def laguerre_gaussian(
    wavelength: ScalarFloat,
    waist: ScalarFloat,
    p: ScalarInteger,
    l: ScalarInteger,
    dx: ScalarFloat,
    grid_size: Union[Int[Array, " 2"], Tuple[ScalarInteger, ScalarInteger]],
    center: Optional[Tuple[ScalarFloat, ScalarFloat]] = (0.0, 0.0),
    amplitude: Optional[ScalarFloat] = 1.0,
    z_position: Optional[ScalarFloat] = 0.0,
) -> OpticalWavefront:
    r"""Create a Laguerre-Gaussian mode at the beam waist.

    Parameters
    ----------
    wavelength : ScalarFloat
        Wavelength of light in meters.
    waist : ScalarFloat
        Beam waist (1/e² intensity radius of fundamental mode) in meters.
    p : ScalarInteger
        Radial mode index (number of radial nodes). Must satisfy p >= 0.
    l : ScalarInteger
        Azimuthal mode index (topological charge for vortex beams).
        Can be positive or negative.
    dx : ScalarFloat
        Spatial sampling interval (pixel size) in meters.
    grid_size : Union[Int[Array, " 2"], Tuple[ScalarInteger, ScalarInteger]]
        Shape of the computational grid as (height, width), i.e. (ny, nx).
    center : Tuple[ScalarFloat, ScalarFloat], optional
        Center position (x0, y0) in meters, by default (0.0, 0.0).
        None is treated as the default.
    amplitude : ScalarFloat, optional
        Scale factor applied to the mode profile, by default 1.0. No
        mode-dependent normalization is applied. For l = 0 it equals the
        on-axis field value; for l != 0 the peak of |E| generally differs
        from it. None is treated as the default.
    z_position : ScalarFloat, optional
        Initial z position of the wavefront in meters, by default 0.0.
        None is treated as the default.

    Returns
    -------
    wavefront : OpticalWavefront
        Laguerre-Gaussian mode OpticalWavefront PyTree.

    Notes
    -----
    The Laguerre-Gaussian modes at the waist are generated as

    .. math::

        E_{p,l}(r, \phi) = A \left(\frac{r\sqrt{2}}{w}\right)^{|l|}
            L_p^{|l|}\left(\frac{2r^2}{w^2}\right)
            \exp\left(-\frac{r^2}{w^2}\right)
            \exp(i l \phi)

    where :math:`L_p^{|l|}` is the generalized Laguerre polynomial.

    Special cases:

    - (p=0, l=0): fundamental Gaussian mode
    - (p=0, l≠0): optical vortex with topological charge l (dark core)
    - (p>0, l=0): p radial intensity nodes, i.e. p + 1 bright rings
      counting the central spot
    """
    # The Optional annotations admit None; map it to the documented defaults.
    if center is None:
        center = (0.0, 0.0)
    if amplitude is None:
        amplitude = 1.0
    if z_position is None:
        z_position = 0.0
    grid_size = jnp.asarray(grid_size, dtype=jnp.int32)
    ny: ScalarInteger = grid_size[0]
    nx: ScalarInteger = grid_size[1]
    # Physical extents and sample counts are passed in (x, y) order.
    diameter: Float[Array, " 2"] = jnp.asarray(
        [nx * dx, ny * dx], dtype=jnp.float64
    )
    num_points: Int[Array, " 2"] = jnp.asarray([nx, ny], dtype=jnp.int32)
    xx: Float[Array, " ny nx"]
    yy: Float[Array, " ny nx"]
    xx, yy = create_spatial_grid(diameter, num_points)
    x0: Float[Array, " "] = jnp.asarray(center[0], dtype=jnp.float64)
    y0: Float[Array, " "] = jnp.asarray(center[1], dtype=jnp.float64)
    w: Float[Array, " "] = jnp.asarray(waist, dtype=jnp.float64)
    x_shifted: Float[Array, " ny nx"] = xx - x0
    y_shifted: Float[Array, " ny nx"] = yy - y0
    r: Float[Array, " ny nx"] = jnp.sqrt(x_shifted**2 + y_shifted**2)
    phi: Float[Array, " ny nx"] = jnp.arctan2(y_shifted, x_shifted)
    rho: Float[Array, " ny nx"] = r * jnp.sqrt(2.0) / w
    rho2: Float[Array, " ny nx"] = 2.0 * r**2 / (w**2)
    abs_l: Int[Array, " "] = jnp.abs(jnp.asarray(l, dtype=jnp.int32))
    p_int: Int[Array, " "] = jnp.asarray(p, dtype=jnp.int32)

    def _laguerre_polynomial(
        order: Int[Array, " "],
        alpha: Int[Array, " "],
        x: Float[Array, " ny nx"],
    ) -> Float[Array, " ny nx"]:
        """Evaluate the generalized Laguerre polynomial L_order^alpha(x).

        Uses the three-term recurrence

            L_j = ((2j - 1 + alpha - x) L_{j-1} - (j - 1 + alpha) L_{j-2}) / j

        seeded with L_0 = 1 and L_{-1} = 0. With that seed, order 0 runs no
        iterations and order 1 is produced by the first step, so no special
        casing is needed. Assumes order >= 0.
        """
        alpha_f = jnp.float64(alpha)

        def _step(
            j: int,
            carry: Tuple[Float[Array, " ny nx"], Float[Array, " ny nx"]],
        ) -> Tuple[Float[Array, " ny nx"], Float[Array, " ny nx"]]:
            prev, prev2 = carry
            j_f = jnp.float64(j)
            nxt = (
                (2.0 * j_f - 1.0 + alpha_f - x) * prev
                - (j_f - 1.0 + alpha_f) * prev2
            ) / j_f
            return (nxt, prev)

        seed = (jnp.ones_like(x), jnp.zeros_like(x))
        return jax.lax.fori_loop(1, order + 1, _step, seed)[0]

    L_pl: Float[Array, " ny nx"] = _laguerre_polynomial(p_int, abs_l, rho2)
    radial_amplitude: Float[Array, " ny nx"] = (
        (rho ** jnp.float64(abs_l)) * jnp.exp(-(r**2) / (w**2)) * L_pl
    )
    l_float: Float[Array, " "] = jnp.float64(l)
    # Helical phase exp(i*l*phi) carries the orbital angular momentum.
    azimuthal_phase: Complex[Array, " ny nx"] = jnp.exp(1j * l_float * phi)
    field: Complex[Array, " ny nx"] = (
        jnp.asarray(amplitude, dtype=jnp.float64)
        * radial_amplitude
        * azimuthal_phase
    )
    wavefront: OpticalWavefront = make_optical_wavefront(
        field=field,
        wavelength=wavelength,
        dx=dx,
        z_position=z_position,
    )
    return wavefront
[docs]
@jaxtyped(typechecker=beartype)
def hermite_gaussian(
    wavelength: ScalarFloat,
    waist: ScalarFloat,
    n: ScalarInteger,
    m: ScalarInteger,
    dx: ScalarFloat,
    grid_size: Union[Int[Array, " 2"], Tuple[ScalarInteger, ScalarInteger]],
    center: Optional[Tuple[ScalarFloat, ScalarFloat]] = (0.0, 0.0),
    amplitude: Optional[ScalarFloat] = 1.0,
    z_position: Optional[ScalarFloat] = 0.0,
) -> OpticalWavefront:
    r"""Create a Hermite-Gaussian mode at the beam waist.

    Parameters
    ----------
    wavelength : ScalarFloat
        Wavelength of light in meters.
    waist : ScalarFloat
        Beam waist (1/e² intensity radius of fundamental mode) in meters.
    n : ScalarInteger
        Mode index in x direction (number of nodes along x), n >= 0.
    m : ScalarInteger
        Mode index in y direction (number of nodes along y), m >= 0.
    dx : ScalarFloat
        Spatial sampling interval (pixel size) in meters.
    grid_size : Union[Int[Array, " 2"], Tuple[ScalarInteger, ScalarInteger]]
        Shape of the computational grid as (height, width), i.e. (ny, nx).
    center : Tuple[ScalarFloat, ScalarFloat], optional
        Center position (x0, y0) in meters, by default (0.0, 0.0).
        None is treated as the default.
    amplitude : ScalarFloat, optional
        Scale factor applied to the unnormalized mode profile, by default
        1.0. It equals the peak amplitude only for the fundamental
        (n = m = 0) mode; higher-order modes are not renormalized.
        None is treated as the default.
    z_position : ScalarFloat, optional
        Initial z position of the wavefront in meters, by default 0.0.
        None is treated as the default.

    Returns
    -------
    wavefront : OpticalWavefront
        Hermite-Gaussian mode OpticalWavefront PyTree.

    Notes
    -----
    The Hermite-Gaussian modes at the waist are generated as

    .. math::

        E_{n,m}(x, y) = A H_n\left(\frac{x\sqrt{2}}{w}\right)
                        H_m\left(\frac{y\sqrt{2}}{w}\right)
                        \exp\left(-\frac{x^2 + y^2}{w^2}\right)

    where :math:`H_n` is the physicist's Hermite polynomial.

    Special cases:

    - (n=0, m=0): fundamental Gaussian mode (TEM00)
    - (n=1, m=0): TEM10 mode with one node along x
    - (n=0, m=1): TEM01 mode with one node along y
    """
    # The Optional annotations admit None; map it to the documented defaults.
    if center is None:
        center = (0.0, 0.0)
    if amplitude is None:
        amplitude = 1.0
    if z_position is None:
        z_position = 0.0
    grid_size = jnp.asarray(grid_size, dtype=jnp.int32)
    ny_grid: Int[Array, " "] = grid_size[0]
    nx_grid: Int[Array, " "] = grid_size[1]
    # Physical extents and sample counts are passed in (x, y) order.
    diameter: Float[Array, " 2"] = jnp.asarray(
        [nx_grid * dx, ny_grid * dx], dtype=jnp.float64
    )
    num_points: Int[Array, " 2"] = jnp.asarray(
        [nx_grid, ny_grid], dtype=jnp.int32
    )
    xx: Float[Array, " ny_grid nx_grid"]
    yy: Float[Array, " ny_grid nx_grid"]
    xx, yy = create_spatial_grid(diameter, num_points)
    x0: Float[Array, " "] = jnp.asarray(center[0], dtype=jnp.float64)
    y0: Float[Array, " "] = jnp.asarray(center[1], dtype=jnp.float64)
    w: Float[Array, " "] = jnp.asarray(waist, dtype=jnp.float64)
    x_norm: Float[Array, " ny_grid nx_grid"] = (xx - x0) * jnp.sqrt(2.0) / w
    y_norm: Float[Array, " ny_grid nx_grid"] = (yy - y0) * jnp.sqrt(2.0) / w
    n_int: Int[Array, " "] = jnp.asarray(n, dtype=jnp.int32)
    m_int: Int[Array, " "] = jnp.asarray(m, dtype=jnp.int32)

    def _hermite_polynomial(
        order: Int[Array, " "], x: Float[Array, " ny_grid nx_grid"]
    ) -> Float[Array, " ny_grid nx_grid"]:
        """Evaluate the physicist's Hermite polynomial H_order(x).

        Uses the recurrence H_j = 2x H_{j-1} - 2(j-1) H_{j-2}, seeded with
        H_0 = 1 and H_{-1} = 0. With that seed, order 0 runs no iterations
        and order 1 (= 2x) is produced by the first step. Assumes
        order >= 0.
        """

        def _step(
            j: int,
            carry: Tuple[
                Float[Array, " ny_grid nx_grid"],
                Float[Array, " ny_grid nx_grid"],
            ],
        ) -> Tuple[
            Float[Array, " ny_grid nx_grid"], Float[Array, " ny_grid nx_grid"]
        ]:
            prev, prev2 = carry
            nxt = 2.0 * x * prev - 2.0 * (j - 1) * prev2
            return (nxt, prev)

        seed = (jnp.ones_like(x), jnp.zeros_like(x))
        return jax.lax.fori_loop(1, order + 1, _step, seed)[0]

    hh_n: Float[Array, " ny_grid nx_grid"] = _hermite_polynomial(n_int, x_norm)
    hh_m: Float[Array, " ny_grid nx_grid"] = _hermite_polynomial(m_int, y_norm)
    r2: Float[Array, " ny_grid nx_grid"] = (xx - x0) ** 2 + (yy - y0) ** 2
    gaussian: Float[Array, " ny_grid nx_grid"] = jnp.exp(-r2 / (w**2))
    profile: Float[Array, " ny_grid nx_grid"] = (
        jnp.asarray(amplitude, dtype=jnp.float64) * hh_n * hh_m * gaussian
    )
    # Flat phase at the waist: promote the real profile to complex directly.
    field: Complex[Array, " ny_grid nx_grid"] = profile.astype(jnp.complex128)
    wavefront: OpticalWavefront = make_optical_wavefront(
        field=field,
        wavelength=wavelength,
        dx=dx,
        z_position=z_position,
    )
    return wavefront
[docs]
@jaxtyped(typechecker=beartype)
def propagate_beam(
    beam_type: str,
    z_positions: Float[Array, " zz"],
    wavelength: ScalarFloat,
    dx: ScalarFloat,
    grid_size: Union[Int[Array, " 2"], Tuple[ScalarInteger, ScalarInteger]],
    waist: ScalarFloat = 1e-3,
    amplitude: ScalarFloat = 1.0,
    center: Tuple[ScalarFloat, ScalarFloat] = (0.0, 0.0),
    tilt_x: ScalarFloat = 0.0,
    tilt_y: ScalarFloat = 0.0,
    focus_distance: ScalarFloat = 1.0,
    source_distance: ScalarFloat = 1.0,
    waist_0: ScalarFloat = 1e-3,
    include_gouy_phase: bool = True,
    cone_angle: ScalarFloat = 0.01,
    period: ScalarFloat = 1e-4,
    direction: ScalarFloat = 0.0,
    p: ScalarInteger = 0,
    l: ScalarInteger = 0,
    n: ScalarInteger = 0,
    m: ScalarInteger = 0,
) -> PropagatingWavefront:
    """Evaluate one beam type at a stack of z planes.

    Dispatches to one of this module's beam factories. The factory is
    vmapped over ``z_positions``, and the resulting stack of fields is
    packed into a single PropagatingWavefront.

    Parameters
    ----------
    beam_type : str
        Name of the beam factory to use. Accepted values:

        - "plane_wave": uniform plane wave, optionally tilted
        - "sinusoidal_wave": sinusoidal interference pattern
        - "collimated_gaussian": Gaussian beam with flat phase
        - "converging_gaussian": Gaussian beam converging to a focus
        - "diverging_gaussian": Gaussian beam diverging from a source
        - "gaussian_beam": general Gaussian beam at a distance from waist
        - "bessel_beam": Bessel beam with a given cone angle
        - "laguerre_gaussian": Laguerre-Gaussian modes
        - "hermite_gaussian": Hermite-Gaussian modes
    z_positions : Float[Array, " zz"]
        Planes (meters) at which the beam is evaluated. Every beam type
        receives each value as ``z_position``. "gaussian_beam" also
        receives it as ``z_from_waist``.
    wavelength : ScalarFloat
        Wavelength of light in meters.
    dx : ScalarFloat
        Pixel size (spatial sampling interval) in meters.
    grid_size : Union[Int[Array, " 2"], Tuple[ScalarInteger, ScalarInteger]]
        Computational grid shape as (height, width).
    waist : ScalarFloat
        Beam waist (1/e² intensity radius) in meters. Forwarded to
        collimated_gaussian, converging_gaussian, diverging_gaussian,
        laguerre_gaussian and hermite_gaussian. Default is 1e-3.
    amplitude : ScalarFloat
        Peak amplitude, forwarded to every beam type. Default is 1.0.
    center : Tuple[ScalarFloat, ScalarFloat]
        Beam center (x0, y0) in meters. Forwarded to the Gaussian-family
        beams, laguerre_gaussian and hermite_gaussian. It is not passed to
        plane_wave, sinusoidal_wave or bessel_beam. Default is (0.0, 0.0).
    tilt_x : ScalarFloat
        x-axis tilt in radians for plane_wave. Default is 0.0.
    tilt_y : ScalarFloat
        y-axis tilt in radians for plane_wave. Default is 0.0.
    focus_distance : ScalarFloat
        Focus distance (meters) for converging_gaussian. Default is 1.0.
    source_distance : ScalarFloat
        Source distance (meters) for diverging_gaussian. Default is 1.0.
    waist_0 : ScalarFloat
        Waist radius (meters) at the waist plane for gaussian_beam.
        Default is 1e-3.
    include_gouy_phase : bool
        Whether gaussian_beam adds the Gouy phase. Default is True.
    cone_angle : ScalarFloat
        Cone half-angle in radians for bessel_beam. Default is 0.01.
    period : ScalarFloat
        Spatial period (meters) for sinusoidal_wave. Default is 1e-4.
    direction : ScalarFloat
        Fringe direction angle (radians) for sinusoidal_wave. Default
        is 0.0.
    p : ScalarInteger
        Radial index for laguerre_gaussian. Default is 0.
    l : ScalarInteger
        Azimuthal index for laguerre_gaussian. Default is 0.
    n : ScalarInteger
        x-direction index for hermite_gaussian. Default is 0.
    m : ScalarInteger
        y-direction index for hermite_gaussian. Default is 0.

    Returns
    -------
    propagating_wavefront : PropagatingWavefront
        Stack of fields, one per entry of ``z_positions``, with
        ``polarization=False``.

    Raises
    ------
    ValueError
        If ``beam_type`` is not one of the names listed above.

    Notes
    -----
    All planes are produced in a single ``jax.vmap`` call.

    Only "gaussian_beam" makes the field profile depend on z, because
    the z value is used as the distance from the waist. For every other
    beam type, z only sets the ``z_position`` metadata, and the field
    profile is the same at every plane.
    """
    z_values = jnp.asarray(z_positions, dtype=jnp.float64)

    # Arguments that every factory accepts, irrespective of beam type.
    shared_kwargs = {
        "wavelength": wavelength,
        "dx": dx,
        "grid_size": grid_size,
        "amplitude": amplitude,
    }
    # Arguments common to the waist-parameterised, centred beam shapes.
    shaped_kwargs = {"waist": waist, "center": center}

    # beam name -> (factory, extra keyword arguments beyond shared_kwargs)
    beam_table = {
        "plane_wave": (plane_wave, {"tilt_x": tilt_x, "tilt_y": tilt_y}),
        "sinusoidal_wave": (
            sinusoidal_wave,
            {"period": period, "direction": direction},
        ),
        "collimated_gaussian": (collimated_gaussian, dict(shaped_kwargs)),
        "converging_gaussian": (
            converging_gaussian,
            {**shaped_kwargs, "focus_distance": focus_distance},
        ),
        "diverging_gaussian": (
            diverging_gaussian,
            {**shaped_kwargs, "source_distance": source_distance},
        ),
        "gaussian_beam": (
            gaussian_beam,
            {
                "waist_0": waist_0,
                "center": center,
                "include_gouy_phase": include_gouy_phase,
            },
        ),
        "bessel_beam": (bessel_beam, {"cone_angle": cone_angle}),
        "laguerre_gaussian": (
            laguerre_gaussian,
            {**shaped_kwargs, "p": p, "l": l},
        ),
        "hermite_gaussian": (
            hermite_gaussian,
            {**shaped_kwargs, "n": n, "m": m},
        ),
    }

    if beam_type not in beam_table:
        raise ValueError(
            f"Unknown beam_type: {beam_type}. Must be one of: "
            "plane_wave, sinusoidal_wave, collimated_gaussian, "
            "converging_gaussian, diverging_gaussian, gaussian_beam, "
            "bessel_beam, laguerre_gaussian, hermite_gaussian"
        )
    factory, extra_kwargs = beam_table[beam_type]

    def _field_at(z: Float[Array, " "]) -> Complex[Array, " hh ww"]:
        """Build the chosen beam at a single plane ``z`` and return its field."""
        per_plane = {"z_position": z}
        # Only gaussian_beam evolves with z: it needs the distance from waist.
        if beam_type == "gaussian_beam":
            per_plane["z_from_waist"] = z
        return factory(**shared_kwargs, **extra_kwargs, **per_plane).field

    stacked: Complex[Array, " zz hh ww"] = jax.vmap(_field_at)(z_values)
    return make_propagating_wavefront(
        field=stacked,
        wavelength=wavelength,
        dx=dx,
        z_positions=z_values,
        polarization=False,
    )