"""Mathematical utility functions for complex-valued computations.

Extended Summary
----------------
Core mathematical utilities for the janssen package, including
Wirtinger calculus for complex-valued optimization, FFT-based
shifting, and other mathematical operations commonly used in
optical simulations.

Routine Listings
----------------
flatten_params : function
    Flatten complex arrays to real parameter vector for optimization
unflatten_params : function
    Unflatten real parameter vector back to complex arrays
fourier_shift : function
    Shift a 2D field using FFT-based sub-pixel shifting
wirtinger_grad : function
    Compute the Wirtinger gradient of a complex-valued function

Notes
-----
All functions are JAX-compatible and support automatic differentiation.
"""
import jax
import jax.numpy as jnp
from beartype import beartype
from beartype.typing import Any, Callable, Optional, Sequence, Tuple, Union
from jaxtyping import Array, Complex, Float, jaxtyped
from janssen.types import ScalarFloat, ScalarInteger
[docs]
@jax.jit
@jaxtyped(typechecker=beartype)
def flatten_params(
    sample: Complex[Array, " Hs Ws"],
    probe: Complex[Array, " Hp Wp"],
) -> Float[Array, " n"]:
    """Pack two complex 2D arrays into one real 1D parameter vector.

    Each array is split into its real and imaginary parts, every part
    is flattened, and the four pieces are concatenated. The two inputs
    may have different shapes.

    Parameters
    ----------
    sample : Complex[Array, " Hs Ws"]
        First complex array (e.g., object transmission function).
    probe : Complex[Array, " Hp Wp"]
        Second complex array (e.g., probe wavefront).

    Returns
    -------
    params : Float[Array, " n"]
        Real vector of length n = 2*(Hs*Ws + Hp*Wp), laid out as
        [sample_real, sample_imag, probe_real, probe_imag].

    Examples
    --------
    >>> sample = jnp.ones((512, 512), dtype=jnp.complex128)
    >>> probe = jnp.ones((64, 64), dtype=jnp.complex128) * (1+1j)
    >>> params = flatten_params(sample, probe)
    >>> params.shape
    (532480,)  # 2*(512*512 + 64*64)

    See Also
    --------
    unflatten_params : Inverse operation to reconstruct complex arrays
    """
    # Order matters: unflatten_params relies on this exact layout.
    sample_re: Float[Array, " ns"] = jnp.real(sample).reshape(-1)
    sample_im: Float[Array, " ns"] = jnp.imag(sample).reshape(-1)
    probe_re: Float[Array, " np"] = jnp.real(probe).reshape(-1)
    probe_im: Float[Array, " np"] = jnp.imag(probe).reshape(-1)
    return jnp.concatenate((sample_re, sample_im, probe_re, probe_im))
[docs]
def unflatten_params(
    params: Float[Array, " n"],
    sample_shape: Tuple[int, int],
    probe_shape: Tuple[int, int],
) -> Tuple[Complex[Array, " Hs Ws"], Complex[Array, " Hp Wp"]]:
    """Rebuild two complex 2D arrays from a flat real parameter vector.

    The vector is read as four consecutive segments,
    [sample_real, sample_imag, probe_real, probe_imag], each reshaped
    to the corresponding 2D shape and recombined as ``real + 1j*imag``.

    Parameters
    ----------
    params : Float[Array, " n"]
        Flattened real parameter vector with n = 2*(Hs*Ws + Hp*Wp).
    sample_shape : Tuple[int, int]
        Shape (Hs, Ws) of the sample array.
    probe_shape : Tuple[int, int]
        Shape (Hp, Wp) of the probe array.

    Returns
    -------
    sample : Complex[Array, " Hs Ws"]
        Complex array rebuilt from the first 2*Hs*Ws elements.
    probe : Complex[Array, " Hp Wp"]
        Complex array rebuilt from the remaining elements.

    Notes
    -----
    The slice bounds are computed from the shape arguments, so when
    this function is called inside a JIT-compiled function the shapes
    must be static (e.g., via ``static_argnums``).

    Examples
    --------
    >>> params = jnp.arange(532480, dtype=jnp.float64)
    >>> sample, probe = unflatten_params(params, (512, 512), (64, 64))
    >>> sample.shape, probe.shape
    ((512, 512), (64, 64))

    See Also
    --------
    flatten_params : Forward operation to create flattened vector
    """
    n_sample = sample_shape[0] * sample_shape[1]
    n_probe = probe_shape[0] * probe_shape[1]

    # Segment boundaries inside the flat vector.
    probe_start = 2 * n_sample
    probe_mid = probe_start + n_probe

    sample_re = params[:n_sample].reshape(sample_shape)
    sample_im = params[n_sample:probe_start].reshape(sample_shape)
    probe_re = params[probe_start:probe_mid].reshape(probe_shape)
    # Open-ended slice: everything after the probe's real segment.
    probe_im = params[probe_mid:].reshape(probe_shape)

    return sample_re + 1j * sample_im, probe_re + 1j * probe_im
[docs]
@jaxtyped(typechecker=beartype)
def fourier_shift(
    field: Complex[Array, " H W"],
    shift_x: ScalarFloat,
    shift_y: ScalarFloat,
) -> Complex[Array, " H W"]:
    r"""Circularly shift a 2D field by a (possibly fractional) pixel offset.

    The field is transformed with a 2D FFT, multiplied by a linear
    phase ramp, and transformed back. Because the ramp is continuous
    in the shift values, non-integer shifts are supported.

    Parameters
    ----------
    field : Complex[Array, " H W"]
        Input 2D complex field.
    shift_x : ScalarFloat
        Shift along the column axis, in pixels. Positive values move
        content toward increasing column index.
    shift_y : ScalarFloat
        Shift along the row axis, in pixels. Positive values move
        content toward increasing row index.

    Returns
    -------
    shifted : Complex[Array, " H W"]
        Shifted field with the same shape as the input.

    Notes
    -----
    The spectrum is multiplied by the phase ramp

    .. math::
        F_{shifted}(f_x, f_y) = F(f_x, f_y) \cdot
        \exp(-2\pi i (f_x \cdot \Delta x + f_y \cdot \Delta y))

    with frequencies taken from ``jnp.fft.fftfreq``. The shift is
    circular (content wraps around the edges), and a shift of (0, 0)
    multiplies the spectrum by one.

    Examples
    --------
    >>> import jax.numpy as jnp
    >>> field = jnp.ones((64, 64), dtype=jnp.complex128)
    >>> shifted = fourier_shift(field, 10.5, -5.25)  # Sub-pixel shift
    """
    n_rows, n_cols = field.shape

    # Keep the frequency axes 1D and let broadcasting form the (H, W)
    # grid: columns vary along the last axis, rows along the first.
    col_freq = jnp.fft.fftfreq(n_cols)[jnp.newaxis, :]
    row_freq = jnp.fft.fftfreq(n_rows)[:, jnp.newaxis]

    ramp_arg = col_freq * shift_x + row_freq * shift_y
    ramp = jnp.exp(-2j * jnp.pi * ramp_arg)

    spectrum = jnp.fft.fft2(field)
    return jnp.fft.ifft2(spectrum * ramp)
[docs]
@jaxtyped(typechecker=beartype)
def wirtinger_grad(
    func2diff: Callable[..., Float[Array, " ..."]],
    argnums: Optional[Union[int, Sequence[int]]] = 0,
) -> Callable[
    ..., Union[Complex[Array, " ..."], Tuple[Complex[Array, " ..."], ...]]
]:
    r"""Compute the Wirtinger gradient of a complex-valued function.

    Returns a new function that evaluates the Wirtinger derivative of
    ``func2diff`` with respect to the selected argument(s), using

    .. math::
        \frac{\partial f}{\partial z} = \frac{1}{2} \left(
        \frac{\partial f}{\partial x} - i \frac{\partial f}{\partial y}
        \right)

    where :math:`z = x + i y`. Writing :math:`f = u + i v`, this is
    :math:`\frac{1}{2}[(u_x + v_y) + i (v_x - u_y)]`.

    Parameters
    ----------
    func2diff : Callable[..., Float[Array, " ..."]]
        Scalar-valued function to differentiate. Its output may be
        real or complex.
    argnums : Union[int, Sequence[int]], optional
        Position(s) of the argument(s) to differentiate with respect
        to. Negative positions count from the end. ``None`` is treated
        as 0. Default is 0.

    Returns
    -------
    grad_f : Callable
        Function with the same positional arguments as ``func2diff``.
        It returns a single array if ``argnums`` is an int (or None),
        or a tuple of arrays if ``argnums`` is a sequence.

    Notes
    -----
    Each argument is split into real and imaginary parts and
    ``jax.grad`` is taken with respect to BOTH parts of every selected
    argument; the imaginary part of a non-complex argument is a zero
    placeholder that does not influence the function, so its partial
    derivative is zero.

    For a real-valued loss, the steepest-descent direction in the
    complex plane is the negative of
    :math:`\partial f / \partial \bar{z}`, which is the complex
    conjugate of the value returned here.

    Examples
    --------
    >>> import jax.numpy as jnp
    >>> def loss(z):
    ...     return jnp.sum(jnp.abs(z)**2)
    >>> grad_fn = wirtinger_grad(loss)
    >>> z = jnp.array([1+2j, 3+4j])
    >>> grad_fn(z)  # equals jnp.conj(z) for this loss
    """
    single: bool = argnums is None or isinstance(argnums, int)
    requested: Tuple[int, ...]
    if single:
        requested = (0 if argnums is None else argnums,)
    else:
        requested = tuple(argnums)

    def grad_f(
        *args: Any,
    ) -> Union[Complex[Array, " ..."], Tuple[Complex[Array, " ..."], ...]]:
        n: int = len(args)
        is_complex: Tuple[bool, ...] = tuple(
            bool(jnp.iscomplexobj(arg)) for arg in args
        )

        # Split every argument into (real, imag); non-complex arguments
        # pass through unchanged with a zero imaginary placeholder.
        real_parts: Tuple[Any, ...] = tuple(
            jnp.real(arg) if cplx else arg
            for arg, cplx in zip(args, is_complex)
        )
        imag_parts: Tuple[Any, ...] = tuple(
            jnp.imag(arg) if cplx else jnp.zeros_like(arg)
            for arg, cplx in zip(args, is_complex)
        )

        def recombine(*parts: Any) -> Tuple[Any, ...]:
            # parts = (re_0..re_{n-1}, im_0..im_{n-1})
            return tuple(
                re + 1j * im if cplx else re
                for re, im, cplx in zip(parts[:n], parts[n:], is_complex)
            )

        def f_real(*parts: Any) -> Float[Array, " ..."]:
            return jnp.real(func2diff(*recombine(*parts)))

        def f_imag(*parts: Any) -> Float[Array, " ..."]:
            return jnp.imag(func2diff(*recombine(*parts)))

        # Normalise negative positions so that `n + k` below addresses
        # the matching imaginary part in the split argument list.
        positions: Tuple[int, ...] = tuple(
            k + n if k < 0 else k for k in requested
        )
        m: int = len(positions)
        # Differentiate w.r.t. the real parts first, then the imaginary
        # parts, of each selected argument.
        wrt: Tuple[int, ...] = positions + tuple(n + k for k in positions)

        grads_u = jax.grad(f_real, argnums=wrt)(*real_parts, *imag_parts)
        grads_v = jax.grad(f_imag, argnums=wrt)(*real_parts, *imag_parts)

        # df/dz = 0.5 * ((u_x + v_y) + i * (v_x - u_y))
        results = tuple(
            0.5 * ((u_x + v_y) + 1j * (v_x - u_y))
            for u_x, u_y, v_x, v_y in zip(
                grads_u[:m], grads_u[m:], grads_v[:m], grads_v[m:]
            )
        )
        return results[0] if single else results

    return grad_f