"""Coherence types for partially coherent field representation.
Extended Summary
----------------
This module provides PyTree data structures for representing partially
coherent optical fields. It supports both spatial coherence (extended
sources, van Cittert-Zernike theorem) and temporal coherence (finite
bandwidth, chromatic effects).
The key insight is that any partially coherent field can be decomposed
into orthogonal coherent modes (Mercer's theorem), enabling efficient
simulation by propagating a finite number of modes and summing intensities.
Routine Listings
----------------
CoherentModeSet : NamedTuple
PyTree for coherent mode decomposition of partially coherent fields
PolychromaticWavefront : NamedTuple
PyTree structure for polychromatic/broadband field representation
MutualIntensity : NamedTuple
PyTree structure for full mutual intensity J(r1, r2) representation
MixedStatePtychoData : NamedTuple
PyTree for mixed-state ptychography reconstruction state
make_coherent_mode_set : function
Factory function to create validated CoherentModeSet instances
make_polychromatic_wavefront : function
Factory function to create validated PolychromaticWavefront instances
make_mutual_intensity : function
Factory function to create validated MutualIntensity instances
make_mixed_state_ptycho_data : function
Factory function to create validated MixedStatePtychoData instances
Notes
-----
For practical simulations, prefer CoherentModeSet over MutualIntensity:
- CoherentModeSet: O(M × N²) memory for M modes on N×N grid
- MutualIntensity: O(N⁴) memory - use only for small grids or demonstrations
The total intensity from coherent modes is:
I(r) = Σₙ wₙ |φₙ(r)|²
where wₙ are the modal weights (eigenvalues) and φₙ are the modes.
References
----------
1. Mandel, L. & Wolf, E. "Optical Coherence and Quantum Optics" (1995)
2. Starikov, A. & Wolf, E. "Coherent-mode representation of Gaussian
Schell-model sources" JOSA A (1982)
3. Thibault, P. & Menzel, A. "Reconstructing state mixtures from
diffraction measurements" Nature (2013)
"""
import jax.numpy as jnp
from beartype import beartype
from beartype.typing import NamedTuple, Tuple, Union
from jax import lax
from jax.tree_util import register_pytree_node_class
from jaxtyping import Array, Bool, Complex, Float, jaxtyped
from .common_types import ScalarNumeric
@register_pytree_node_class
class CoherentModeSet(NamedTuple):
    """PyTree for coherent mode decomposition of partially coherent fields.

    A partially coherent field can be represented as a weighted sum of
    orthogonal coherent modes. The total intensity is the incoherent sum:

        I(r) = Σₙ weights[n] × |modes[n](r)|²

    This representation is memory-efficient O(M × N²) compared to the full
    mutual intensity O(N⁴), and naturally parallelizable with vmap.

    Attributes
    ----------
    modes : Union[Complex[Array, " num_modes hh ww"],
                  Complex[Array, " num_modes hh ww 2"]]
        Complex amplitude of coherent modes. Can be scalar (M, H, W) or
        polarized with Jones vectors (M, H, W, 2).
    weights : Float[Array, " num_modes"]
        Modal weights (eigenvalues from Mercer decomposition).
        Must be non-negative and sum to 1 for a normalized representation.
    wavelength : Float[Array, " "]
        Wavelength of the optical field in meters.
    dx : Float[Array, " "]
        Spatial sampling interval (grid spacing) in meters.
    z_position : Float[Array, " "]
        Axial position of the modes along the propagation direction in meters.
    polarization : Bool[Array, " "]
        Whether the modes are polarized (True for 4D modes, False for 3D).
    intensity : Float[Array, " hh ww"]
        Stored total intensity, i.e. the weighted incoherent sum of the mode
        intensities (for Jones-vector modes both components are summed).
        This is a stored value: it is NOT recomputed if ``modes`` or
        ``weights`` are swapped (e.g. via ``_replace``), so keep it in sync
        when building instances by hand.
    effective_mode_count : Float[Array, " "]
        Stored participation ratio N_eff (see Notes). Like ``intensity``, it
        is not recomputed automatically.

    Notes
    -----
    The effective number of modes (participation ratio) quantifies partial
    coherence:

        N_eff = (Σ wₙ)² / Σ(wₙ²)

    N_eff = 1 indicates full coherence (single mode dominates).
    Larger N_eff indicates greater partial coherence.

    For a Gaussian Schell-model source, the eigenvalues and modes have
    analytical forms in terms of Hermite-Gaussian functions.
    """

    modes: Union[
        Complex[Array, " num_modes hh ww"],
        Complex[Array, " num_modes hh ww 2"],
    ]
    weights: Float[Array, " num_modes"]
    wavelength: Float[Array, " "]
    dx: Float[Array, " "]
    z_position: Float[Array, " "]
    polarization: Bool[Array, " "]
    intensity: Float[Array, " hh ww"]
    effective_mode_count: Float[Array, " "]

    def tree_flatten(
        self,
    ) -> Tuple[
        Tuple[
            Union[
                Complex[Array, " num_modes hh ww"],
                Complex[Array, " num_modes hh ww 2"],
            ],
            Float[Array, " num_modes"],
            Float[Array, " "],
            Float[Array, " "],
            Float[Array, " "],
            Bool[Array, " "],
            Float[Array, " hh ww"],
            Float[Array, " "],
        ],
        None,
    ]:
        """Flatten the CoherentModeSet into its leaves (field order).

        Every field is a dynamic leaf; there is no static auxiliary data.
        """
        return (
            (
                self.modes,
                self.weights,
                self.wavelength,
                self.dx,
                self.z_position,
                self.polarization,
                self.intensity,
                self.effective_mode_count,
            ),
            None,
        )

    @classmethod
    def tree_unflatten(
        cls,
        _aux_data: None,
        children: Tuple[
            Union[
                Complex[Array, " num_modes hh ww"],
                Complex[Array, " num_modes hh ww 2"],
            ],
            Float[Array, " num_modes"],
            Float[Array, " "],
            Float[Array, " "],
            Float[Array, " "],
            Bool[Array, " "],
            Float[Array, " hh ww"],
            Float[Array, " "],
        ],
    ) -> "CoherentModeSet":
        """Rebuild a CoherentModeSet from leaves in ``tree_flatten`` order."""
        return cls(*children)
@register_pytree_node_class
class PolychromaticWavefront(NamedTuple):
    """PyTree structure for polychromatic/broadband field representation.

    Represents a field with finite spectral bandwidth for temporal coherence
    and chromatic simulations. The total intensity is the spectrally-weighted
    incoherent sum:

        I(r) = Σᵢ spectral_weights[i] × |fields[i](r)|²

    Attributes
    ----------
    fields : Union[Complex[Array, " num_wavelengths hh ww"],
                   Complex[Array, " num_wavelengths hh ww 2"]]
        Complex amplitude at each wavelength. Can be scalar (Nλ, H, W) or
        polarized with Jones vectors (Nλ, H, W, 2).
    wavelengths : Float[Array, " num_wavelengths"]
        Wavelength sample points in meters.
    spectral_weights : Float[Array, " num_wavelengths"]
        Spectral weights S(λ). Must be non-negative; they sum to 1 for a
        normalized representation.
    dx : Float[Array, " "]
        Spatial sampling interval (grid spacing) in meters.
    z_position : Float[Array, " "]
        Axial position along the propagation direction in meters.
    polarization : Bool[Array, " "]
        Whether the fields are polarized (True for 4D fields, False for 3D).
    intensity : Float[Array, " hh ww"]
        Stored spectrally-weighted total intensity (for Jones-vector fields
        both components are summed). Not recomputed automatically if
        ``fields`` or ``spectral_weights`` are replaced.
    center_wavelength : Float[Array, " "]
        Stored spectrally-weighted center wavelength in meters. Not
        recomputed automatically.

    Notes
    -----
    The coherence length scales as:

        Lc ≈ λ₀² / Δλ

    where λ₀ is the center wavelength and Δλ is the FWHM bandwidth. The
    exact order-unity prefactor depends on the spectral line shape and on
    the chosen definition of coherence length.

    For accurate chromatic simulations, ensure sufficient wavelength sampling:

    - Use at least 5-11 wavelengths spanning ±2σ of the spectrum
    - More samples needed for broader spectra or longer propagation distances
    """

    fields: Union[
        Complex[Array, " num_wavelengths hh ww"],
        Complex[Array, " num_wavelengths hh ww 2"],
    ]
    wavelengths: Float[Array, " num_wavelengths"]
    spectral_weights: Float[Array, " num_wavelengths"]
    dx: Float[Array, " "]
    z_position: Float[Array, " "]
    polarization: Bool[Array, " "]
    intensity: Float[Array, " hh ww"]
    center_wavelength: Float[Array, " "]

    def tree_flatten(
        self,
    ) -> Tuple[
        Tuple[
            Union[
                Complex[Array, " num_wavelengths hh ww"],
                Complex[Array, " num_wavelengths hh ww 2"],
            ],
            Float[Array, " num_wavelengths"],
            Float[Array, " num_wavelengths"],
            Float[Array, " "],
            Float[Array, " "],
            Bool[Array, " "],
            Float[Array, " hh ww"],
            Float[Array, " "],
        ],
        None,
    ]:
        """Flatten PolychromaticWavefront into its leaves (field order).

        Every field is a dynamic leaf; there is no static auxiliary data.
        """
        return (
            (
                self.fields,
                self.wavelengths,
                self.spectral_weights,
                self.dx,
                self.z_position,
                self.polarization,
                self.intensity,
                self.center_wavelength,
            ),
            None,
        )

    @classmethod
    def tree_unflatten(
        cls,
        _aux_data: None,
        children: Tuple[
            Union[
                Complex[Array, " num_wavelengths hh ww"],
                Complex[Array, " num_wavelengths hh ww 2"],
            ],
            Float[Array, " num_wavelengths"],
            Float[Array, " num_wavelengths"],
            Float[Array, " "],
            Float[Array, " "],
            Bool[Array, " "],
            Float[Array, " hh ww"],
            Float[Array, " "],
        ],
    ) -> "PolychromaticWavefront":
        """Rebuild a PolychromaticWavefront from ``tree_flatten`` leaves."""
        return cls(*children)
@register_pytree_node_class
class MutualIntensity(NamedTuple):
    """PyTree structure for full mutual intensity J(r₁, r₂) representation.

    The mutual intensity describes spatial coherence of a quasi-monochromatic
    field:

        J(r₁, r₂) = ⟨E*(r₁) E(r₂)⟩

    The complex degree of coherence normalizes this:

        μ(r₁, r₂) = J(r₁, r₂) / √(I(r₁) I(r₂))

    where |μ| = 1 indicates full coherence and |μ| = 0 indicates incoherence.

    Attributes
    ----------
    j_matrix : Complex[Array, " hh ww hh ww"]
        Full mutual intensity matrix J(r₁, r₂).
        j_matrix[i, j, k, l] = J(r[i,j], r[k,l])
    wavelength : Float[Array, " "]
        Wavelength of the optical field in meters.
    dx : Float[Array, " "]
        Spatial sampling interval (grid spacing) in meters.
    z_position : Float[Array, " "]
        Axial position along the propagation direction in meters.
    intensity : Float[Array, " hh ww"]
        Stored intensity I(r) = J(r, r), i.e. the (real part of the)
        diagonal j_matrix[i, j, i, j]. Not recomputed automatically if
        ``j_matrix`` is replaced.

    Warnings
    --------
    Memory scales as O(N⁴) for an N×N grid. At 16 bytes per complex128
    element, a 256×256 grid requires ~69 GB (64 GiB). Use only for small
    grids (≤64×64, ~268 MB) or theoretical demonstrations. For practical
    simulations, use CoherentModeSet instead.

    Notes
    -----
    The mutual intensity can be decomposed into coherent modes via eigenvalue
    decomposition:

        J(r₁, r₂) = Σₙ λₙ φₙ*(r₁) φₙ(r₂)

    where λₙ are eigenvalues and φₙ are orthonormal eigenfunctions.
    """

    j_matrix: Complex[Array, " hh ww hh ww"]
    wavelength: Float[Array, " "]
    dx: Float[Array, " "]
    z_position: Float[Array, " "]
    intensity: Float[Array, " hh ww"]

    def tree_flatten(
        self,
    ) -> Tuple[
        Tuple[
            Complex[Array, " hh ww hh ww"],
            Float[Array, " "],
            Float[Array, " "],
            Float[Array, " "],
            Float[Array, " hh ww"],
        ],
        None,
    ]:
        """Flatten the MutualIntensity into its leaves (field order).

        Every field is a dynamic leaf; there is no static auxiliary data.
        """
        return (
            (
                self.j_matrix,
                self.wavelength,
                self.dx,
                self.z_position,
                self.intensity,
            ),
            None,
        )

    @classmethod
    def tree_unflatten(
        cls,
        _aux_data: None,
        children: Tuple[
            Complex[Array, " hh ww hh ww"],
            Float[Array, " "],
            Float[Array, " "],
            Float[Array, " "],
            Float[Array, " hh ww"],
        ],
    ) -> "MutualIntensity":
        """Rebuild a MutualIntensity from leaves in ``tree_flatten`` order."""
        return cls(*children)
@jaxtyped(typechecker=beartype)
def make_coherent_mode_set(
    modes: Union[
        Complex[Array, " num_modes hh ww"],
        Complex[Array, " num_modes hh ww 2"],
    ],
    weights: Float[Array, " num_modes"],
    wavelength: ScalarNumeric,
    dx: ScalarNumeric,
    z_position: ScalarNumeric = 0.0,
    polarization: Union[bool, Bool[Array, " "]] = False,
    normalize_weights: bool = True,
) -> CoherentModeSet:
    """Create a validated CoherentModeSet instance.

    Factory function that validates inputs and creates a CoherentModeSet
    PyTree suitable for partially coherent field simulations. The total
    intensity and the effective mode count are computed once here and
    stored on the returned instance.

    Parameters
    ----------
    modes : Union[Complex[Array, " num_modes hh ww"],
                  Complex[Array, " num_modes hh ww 2"]]
        Complex amplitude of coherent modes. Shape (M, H, W) for scalar
        or (M, H, W, 2) for polarized fields.
    weights : Float[Array, " num_modes"]
        Modal weights (eigenvalues). Negative entries are clipped to 0.
    wavelength : ScalarNumeric
        Wavelength of the optical field in meters. Must be positive.
    dx : ScalarNumeric
        Spatial sampling interval in meters. Must be positive.
    z_position : ScalarNumeric, optional
        Axial position in meters. Default is 0.0.
    polarization : Union[bool, Bool[Array, " "]], optional
        Accepted for backward compatibility only. The stored flag is always
        inferred from ``modes`` (True for 4D Jones-vector modes, False for
        3D scalar modes) so that it can never contradict the data layout.
        Default is False.
    normalize_weights : bool, optional
        If True, normalize weights to sum to 1. Normalization is skipped
        when the clipped weights sum to ≤ 1e-12. Default is True.

    Returns
    -------
    CoherentModeSet
        Validated coherent mode set instance.

    Raises
    ------
    ValueError
        If ``modes`` is neither 3D nor 4D, if 4D modes lack a trailing Jones
        axis of length 2, or if ``weights`` does not have exactly one entry
        per mode. Only static shapes are inspected, so this also works
        under ``jax.jit``.

    Notes
    -----
    Non-positive ``wavelength`` or ``dx`` cannot raise (the values may be
    traced); they are passed through unchanged but with their gradient
    stopped.
    """
    non_polar_dim: int = 3
    polar_dim: int = 4
    jones_components: int = 2
    modes = jnp.asarray(modes, dtype=jnp.complex128)
    weights = jnp.asarray(weights, dtype=jnp.float64)

    # Shapes are static even under jit, so structural errors raise eagerly.
    if modes.ndim not in (non_polar_dim, polar_dim):
        raise ValueError(
            "modes must have shape (M, H, W) or (M, H, W, 2); "
            f"got shape {modes.shape}"
        )
    is_polarized: bool = modes.ndim == polar_dim
    if is_polarized and modes.shape[-1] != jones_components:
        raise ValueError(
            "polarized modes must have a trailing Jones axis of length 2; "
            f"got shape {modes.shape}"
        )
    num_modes: int = modes.shape[0]
    if weights.shape != (num_modes,):
        raise ValueError(
            f"weights must have shape ({num_modes},) to match modes; "
            f"got shape {weights.shape}"
        )
    # The flag is derived from the data layout; a user-supplied True for
    # scalar modes previously stored a contradictory flag and silently
    # stopped gradients through ``modes``.
    del polarization

    wavelength_arr: Float[Array, " "] = jnp.asarray(
        wavelength, dtype=jnp.float64
    )
    dx_arr: Float[Array, " "] = jnp.asarray(dx, dtype=jnp.float64)
    z_position_arr: Float[Array, " "] = jnp.asarray(
        z_position, dtype=jnp.float64
    )
    polarization_arr: Bool[Array, " "] = jnp.asarray(
        is_polarized, dtype=jnp.bool_
    )

    def _pass_if(pred, value):
        """Return ``value``; stop its gradient when ``pred`` is False."""
        return lax.cond(
            pred, lambda: value, lambda: lax.stop_gradient(value)
        )

    validated_wavelength: Float[Array, " "] = _pass_if(
        wavelength_arr > 0, wavelength_arr
    )
    validated_dx: Float[Array, " "] = _pass_if(dx_arr > 0, dx_arr)

    # Weights must be non-negative; clip instead of raising (values may be
    # traced).
    clipped_weights: Float[Array, " num_modes"] = jnp.maximum(weights, 0.0)
    weights_sum: Float[Array, " "] = jnp.sum(clipped_weights)
    # Guard against dividing by a (near-)zero total.
    validated_weights: Float[Array, " num_modes"] = lax.cond(
        jnp.logical_and(normalize_weights, weights_sum > 1e-12),
        lambda: clipped_weights / weights_sum,
        lambda: clipped_weights,
    )

    abs_squared: Float[Array, "..."] = jnp.abs(modes) ** 2
    # For Jones-vector modes, per-mode intensity is |Ex|² + |Ey|².
    mode_intensities: Float[Array, " num_modes hh ww"] = (
        jnp.sum(abs_squared, axis=-1) if is_polarized else abs_squared
    )
    intensity: Float[Array, " hh ww"] = jnp.sum(
        validated_weights[:, jnp.newaxis, jnp.newaxis] * mode_intensities,
        axis=0,
    )
    # Participation ratio; the epsilon keeps all-zero weights finite.
    effective_mode_count: Float[Array, " "] = jnp.sum(
        validated_weights
    ) ** 2 / (jnp.sum(validated_weights**2) + 1e-12)

    return CoherentModeSet(
        modes=modes,
        weights=validated_weights,
        wavelength=validated_wavelength,
        dx=validated_dx,
        z_position=z_position_arr,
        polarization=polarization_arr,
        intensity=intensity,
        effective_mode_count=effective_mode_count,
    )
@jaxtyped(typechecker=beartype)
def make_polychromatic_wavefront(
    fields: Union[
        Complex[Array, " num_wavelengths hh ww"],
        Complex[Array, " num_wavelengths hh ww 2"],
    ],
    wavelengths: Float[Array, " num_wavelengths"],
    spectral_weights: Float[Array, " num_wavelengths"],
    dx: ScalarNumeric,
    z_position: ScalarNumeric = 0.0,
    polarization: Union[bool, Bool[Array, " "]] = False,
    normalize_weights: bool = True,
) -> PolychromaticWavefront:
    """Create a validated PolychromaticWavefront instance.

    Factory function that validates inputs and creates a PolychromaticWavefront
    PyTree suitable for chromatic/temporal coherence simulations. The total
    intensity and the center wavelength are computed once here and stored on
    the returned instance.

    Parameters
    ----------
    fields : Union[Complex[Array, " num_wavelengths hh ww"],
                   Complex[Array, " num_wavelengths hh ww 2"]]
        Complex amplitude at each wavelength. Shape (Nλ, H, W) for scalar
        or (Nλ, H, W, 2) for polarized fields.
    wavelengths : Float[Array, " num_wavelengths"]
        Wavelength sample points in meters. Must be positive.
    spectral_weights : Float[Array, " num_wavelengths"]
        Spectral weights S(λ). Negative entries are clipped to 0.
    dx : ScalarNumeric
        Spatial sampling interval in meters. Must be positive.
    z_position : ScalarNumeric, optional
        Axial position in meters. Default is 0.0.
    polarization : Union[bool, Bool[Array, " "]], optional
        Accepted for backward compatibility only. The stored flag is always
        inferred from ``fields`` (True for 4D Jones-vector fields, False for
        3D scalar fields). Default is False.
    normalize_weights : bool, optional
        If True, normalize spectral_weights to sum to 1. Normalization is
        skipped when the clipped weights sum to ≤ 1e-12. Default is True.

    Returns
    -------
    PolychromaticWavefront
        Validated polychromatic wavefront instance. Its
        ``center_wavelength`` is the weighted mean Σ Sᵢλᵢ / Σ Sᵢ, which is
        correct whether or not the weights were normalized; it falls back
        to the unweighted mean of ``wavelengths`` when Σ Sᵢ ≤ 1e-12.

    Raises
    ------
    ValueError
        If ``fields`` is neither 3D nor 4D, if 4D fields lack a trailing
        Jones axis of length 2, or if ``wavelengths``/``spectral_weights``
        do not have one entry per field. Only static shapes are inspected,
        so this also works under ``jax.jit``.

    Notes
    -----
    Non-positive ``wavelengths`` or ``dx`` cannot raise (the values may be
    traced); they are passed through unchanged but with their gradient
    stopped.
    """
    non_polar_dim: int = 3
    polar_dim: int = 4
    jones_components: int = 2
    fields = jnp.asarray(fields, dtype=jnp.complex128)
    wavelengths = jnp.asarray(wavelengths, dtype=jnp.float64)
    spectral_weights = jnp.asarray(spectral_weights, dtype=jnp.float64)

    # Shapes are static even under jit, so structural errors raise eagerly.
    if fields.ndim not in (non_polar_dim, polar_dim):
        raise ValueError(
            "fields must have shape (Nλ, H, W) or (Nλ, H, W, 2); "
            f"got shape {fields.shape}"
        )
    is_polarized: bool = fields.ndim == polar_dim
    if is_polarized and fields.shape[-1] != jones_components:
        raise ValueError(
            "polarized fields must have a trailing Jones axis of length 2; "
            f"got shape {fields.shape}"
        )
    num_wl: int = fields.shape[0]
    if wavelengths.shape != (num_wl,):
        raise ValueError(
            f"wavelengths must have shape ({num_wl},) to match fields; "
            f"got shape {wavelengths.shape}"
        )
    if spectral_weights.shape != (num_wl,):
        raise ValueError(
            f"spectral_weights must have shape ({num_wl},) to match fields; "
            f"got shape {spectral_weights.shape}"
        )
    # Derived from the data layout; a user-supplied True for scalar fields
    # previously stored a contradictory flag and stopped field gradients.
    del polarization

    dx_arr: Float[Array, " "] = jnp.asarray(dx, dtype=jnp.float64)
    z_position_arr: Float[Array, " "] = jnp.asarray(
        z_position, dtype=jnp.float64
    )
    polarization_arr: Bool[Array, " "] = jnp.asarray(
        is_polarized, dtype=jnp.bool_
    )

    def _pass_if(pred, value):
        """Return ``value``; stop its gradient when ``pred`` is False."""
        return lax.cond(
            pred, lambda: value, lambda: lax.stop_gradient(value)
        )

    validated_wavelengths: Float[Array, " num_wavelengths"] = _pass_if(
        jnp.all(wavelengths > 0), wavelengths
    )
    validated_dx: Float[Array, " "] = _pass_if(dx_arr > 0, dx_arr)

    # Weights must be non-negative; clip instead of raising (values may be
    # traced).
    clipped_weights: Float[Array, " num_wavelengths"] = jnp.maximum(
        spectral_weights, 0.0
    )
    clipped_sum: Float[Array, " "] = jnp.sum(clipped_weights)
    validated_spectral_weights: Float[Array, " num_wavelengths"] = lax.cond(
        jnp.logical_and(normalize_weights, clipped_sum > 1e-12),
        lambda: clipped_weights / clipped_sum,
        lambda: clipped_weights,
    )

    abs_squared: Float[Array, "..."] = jnp.abs(fields) ** 2
    # For Jones-vector fields, per-wavelength intensity is |Ex|² + |Ey|².
    field_intensities: Float[Array, " num_wavelengths hh ww"] = (
        jnp.sum(abs_squared, axis=-1) if is_polarized else abs_squared
    )
    intensity: Float[Array, " hh ww"] = jnp.sum(
        validated_spectral_weights[:, jnp.newaxis, jnp.newaxis]
        * field_intensities,
        axis=0,
    )

    # Weighted mean wavelength. Dividing by the weight total keeps this
    # correct when normalize_weights=False (previously Σ Sᵢλᵢ alone was
    # returned, which is scaled by Σ Sᵢ).
    weight_total: Float[Array, " "] = jnp.sum(validated_spectral_weights)
    center_wavelength: Float[Array, " "] = lax.cond(
        weight_total > 1e-12,
        lambda: jnp.sum(validated_spectral_weights * validated_wavelengths)
        / weight_total,
        lambda: jnp.mean(validated_wavelengths),
    )

    return PolychromaticWavefront(
        fields=fields,
        wavelengths=validated_wavelengths,
        spectral_weights=validated_spectral_weights,
        dx=validated_dx,
        z_position=z_position_arr,
        polarization=polarization_arr,
        intensity=intensity,
        center_wavelength=center_wavelength,
    )
@jaxtyped(typechecker=beartype)
def make_mutual_intensity(
    j_matrix: Complex[Array, " hh ww hh ww"],
    wavelength: ScalarNumeric,
    dx: ScalarNumeric,
    z_position: ScalarNumeric = 0.0,
) -> MutualIntensity:
    """Create a validated MutualIntensity instance.

    Factory function that validates inputs and creates a MutualIntensity
    PyTree. Use with caution due to O(N⁴) memory scaling. The intensity
    I(r) = J(r, r) is extracted from the diagonal and stored on the result.

    Parameters
    ----------
    j_matrix : Complex[Array, " hh ww hh ww"]
        Full mutual intensity matrix J(r₁, r₂).
    wavelength : ScalarNumeric
        Wavelength in meters. Must be positive.
    dx : ScalarNumeric
        Spatial sampling interval in meters. Must be positive.
    z_position : ScalarNumeric, optional
        Axial position in meters. Default is 0.0.

    Returns
    -------
    MutualIntensity
        Validated mutual intensity instance.

    Raises
    ------
    ValueError
        If ``j_matrix`` is not 4D with shape (H, W, H, W). Only the static
        shape is inspected, so this also works under ``jax.jit``.

    Warnings
    --------
    Memory scales as O(N⁴). The matrix is stored as complex128 (16 bytes
    per element): 64×64 grid ≈ 268 MB, 128×128 grid ≈ 4.3 GB,
    256×256 grid ≈ 69 GB. Consider using CoherentModeSet for larger grids.

    Notes
    -----
    Non-positive ``wavelength`` or ``dx`` cannot raise (the values may be
    traced); they are passed through unchanged but with their gradient
    stopped.
    """
    j_matrix = jnp.asarray(j_matrix, dtype=jnp.complex128)
    # Shapes are static even under jit, so structural errors raise eagerly.
    if j_matrix.ndim != 4 or j_matrix.shape[:2] != j_matrix.shape[2:]:
        raise ValueError(
            "j_matrix must have shape (H, W, H, W); "
            f"got shape {j_matrix.shape}"
        )

    wavelength_arr: Float[Array, " "] = jnp.asarray(
        wavelength, dtype=jnp.float64
    )
    dx_arr: Float[Array, " "] = jnp.asarray(dx, dtype=jnp.float64)
    z_position_arr: Float[Array, " "] = jnp.asarray(
        z_position, dtype=jnp.float64
    )

    def _pass_if(pred, value):
        """Return ``value``; stop its gradient when ``pred`` is False."""
        return lax.cond(
            pred, lambda: value, lambda: lax.stop_gradient(value)
        )

    validated_wavelength: Float[Array, " "] = _pass_if(
        wavelength_arr > 0, wavelength_arr
    )
    validated_dx: Float[Array, " "] = _pass_if(dx_arr > 0, dx_arr)

    # View J as an (H·W) × (H·W) matrix; its diagonal, reshaped back to
    # (H, W), is j_matrix[i, j, i, j] = J(r, r).
    hh, ww = j_matrix.shape[:2]
    intensity: Float[Array, " hh ww"] = jnp.real(
        jnp.diagonal(j_matrix.reshape(hh * ww, hh * ww))
    ).reshape(hh, ww)

    return MutualIntensity(
        j_matrix=j_matrix,
        wavelength=validated_wavelength,
        dx=validated_dx,
        z_position=z_position_arr,
        intensity=intensity,
    )
@register_pytree_node_class
class MixedStatePtychoData(NamedTuple):
    """PyTree structure for mixed-state ptychography reconstruction.

    Extends standard ptychography data to support partially coherent
    illumination via coherent mode decomposition.

    Attributes
    ----------
    diffraction_patterns : Float[Array, " N H W"]
        Measured diffraction intensities at each scan position.
    probe_modes : CoherentModeSet
        Coherent mode decomposition of the partially coherent probe. It is a
        nested PyTree, so its leaves are flattened along with this one.
    sample : Complex[Array, " Hs Ws"]
        Object transmission function estimate.
    positions : Float[Array, " N 2"]
        Scan positions in pixels.
    wavelength : Float[Array, " "]
        Wavelength in meters.
    dx : Float[Array, " "]
        Pixel spacing in meters.

    Notes
    -----
    The forward model is:

        I_i = Σₙ wₙ |FFT(probeₙ × shift(sample, r_i))|²

    Gradients flow through probe_modes.modes, probe_modes.weights,
    and sample, enabling joint optimization of all parameters.
    """

    diffraction_patterns: Float[Array, " N H W"]
    probe_modes: CoherentModeSet
    sample: Complex[Array, " Hs Ws"]
    positions: Float[Array, " N 2"]
    wavelength: Float[Array, " "]
    dx: Float[Array, " "]

    def tree_flatten(
        self,
    ) -> Tuple[
        Tuple[
            Float[Array, " N H W"],
            CoherentModeSet,
            Complex[Array, " Hs Ws"],
            Float[Array, " N 2"],
            Float[Array, " "],
            Float[Array, " "],
        ],
        None,
    ]:
        """Flatten for JAX pytree compatibility (field order, no aux data)."""
        children = (
            self.diffraction_patterns,
            self.probe_modes,
            self.sample,
            self.positions,
            self.wavelength,
            self.dx,
        )
        return (children, None)

    @classmethod
    def tree_unflatten(
        cls,
        _aux_data: None,
        children: Tuple[
            Float[Array, " N H W"],
            CoherentModeSet,
            Complex[Array, " Hs Ws"],
            Float[Array, " N 2"],
            Float[Array, " "],
            Float[Array, " "],
        ],
    ) -> "MixedStatePtychoData":
        """Rebuild from children in ``tree_flatten`` order."""
        return cls(*children)
@jaxtyped(typechecker=beartype)
def make_mixed_state_ptycho_data(
    diffraction_patterns: Float[Array, " N H W"],
    probe_modes: CoherentModeSet,
    sample: Complex[Array, " Hs Ws"],
    positions: Float[Array, " N 2"],
    wavelength: ScalarNumeric,
    dx: ScalarNumeric,
) -> MixedStatePtychoData:
    """Create validated MixedStatePtychoData.

    Factory function that validates inputs and creates a MixedStatePtychoData
    PyTree suitable for mixed-state ptychography reconstruction.

    Parameters
    ----------
    diffraction_patterns : Float[Array, " N H W"]
        Measured diffraction patterns. Should be non-negative.
    probe_modes : CoherentModeSet
        Partially coherent probe as coherent modes (stored unchanged).
    sample : Complex[Array, " Hs Ws"]
        Initial object estimate.
    positions : Float[Array, " N 2"]
        Scan positions (x, y) in pixels, one row per diffraction pattern.
    wavelength : ScalarNumeric
        Wavelength in meters. Must be positive.
    dx : ScalarNumeric
        Pixel size in meters. Must be positive.

    Returns
    -------
    data : MixedStatePtychoData
        Validated data structure.

    Raises
    ------
    ValueError
        If ``diffraction_patterns`` is not 3D, or if ``positions`` is not of
        shape (N, 2) with N equal to the number of diffraction patterns.
        Only static shapes are inspected, so this also works under
        ``jax.jit``.

    Notes
    -----
    Negative diffraction intensities and non-positive ``wavelength``/``dx``
    cannot raise (the values may be traced); they are passed through
    unchanged but with their gradient stopped.
    """
    expected_dp_ndim: int = 3
    expected_pos_cols: int = 2
    diffraction_patterns_arr: Float[Array, " N H W"] = jnp.asarray(
        diffraction_patterns, dtype=jnp.float64
    )
    sample_arr: Complex[Array, " Hs Ws"] = jnp.asarray(
        sample, dtype=jnp.complex128
    )
    positions_arr: Float[Array, " N 2"] = jnp.asarray(
        positions, dtype=jnp.float64
    )
    wavelength_arr: Float[Array, " "] = jnp.asarray(
        wavelength, dtype=jnp.float64
    )
    dx_arr: Float[Array, " "] = jnp.asarray(dx, dtype=jnp.float64)

    # Shapes are static even under jit, so structural errors raise eagerly.
    if diffraction_patterns_arr.ndim != expected_dp_ndim:
        raise ValueError(
            "diffraction_patterns must have shape (N, H, W); "
            f"got shape {diffraction_patterns_arr.shape}"
        )
    num_positions: int = diffraction_patterns_arr.shape[0]
    if positions_arr.shape != (num_positions, expected_pos_cols):
        raise ValueError(
            f"positions must have shape ({num_positions}, 2); "
            f"got shape {positions_arr.shape}"
        )

    def _pass_if(pred, value):
        """Return ``value``; stop its gradient when ``pred`` is False."""
        return lax.cond(
            pred, lambda: value, lambda: lax.stop_gradient(value)
        )

    validated_dp: Float[Array, " N H W"] = _pass_if(
        jnp.all(diffraction_patterns_arr >= 0), diffraction_patterns_arr
    )
    validated_wavelength: Float[Array, " "] = _pass_if(
        wavelength_arr > 0, wavelength_arr
    )
    validated_dx: Float[Array, " "] = _pass_if(dx_arr > 0, dx_arr)

    return MixedStatePtychoData(
        diffraction_patterns=validated_dp,
        probe_modes=probe_modes,
        sample=sample_arr,
        positions=positions_arr,
        wavelength=validated_wavelength,
        dx=validated_dx,
    )