Source code for janssen.invert.engine

"""Engine for optical ptychography.

Extended Summary
----------------
Main ePIE reconstruction algorithm for optical ptychography.
Parallel processing over positions using vmap for faster but approximate
PIE.
Sequential processing over positions for refinement.
Updates object wavefront using rPIE algorithm.
Updates surface pattern using modified PIE.
Applies coherent transfer function in Fourier domain.
Applies position shift using phase multiplication.
Computes sensor plane intensity with pixel response.
Creates frequency grids for Fourier transforms.

Routine Listings
----------------
epie_optical : function
    Main ePIE reconstruction algorithm for optical ptychography
single_pie_iteration : function
    Single iteration of the ePIE algorithm
single_pie_vmap
    Parallel processing over positions using vmap for faster but
    approximate PIE.
single_pie_sequential : function
    Sequential processing over positions for refinement.
_update_object_wavefront : function, internal
    Updates object wavefront using rPIE algorithm
_update_surface_pattern : function, internal
    Updates surface pattern using modified PIE
_apply_coherent_transfer_function : function, internal
    Applies coherent transfer function in Fourier domain
_apply_position_shift : function, internal
    Applies position shift using phase multiplication
_compute_sensor_intensity : function, internal
    Computes sensor plane intensity with pixel response
_create_frequency_grids : function, internal
    Creates frequency grids for Fourier transforms
"""

import jax
import jax.numpy as jnp
from beartype import beartype
from beartype.typing import Callable, Optional, Tuple
from jaxtyping import Array, Complex, Float, jaxtyped

from janssen.prop import angular_spectrum_prop
from janssen.types import (
    MicroscopeData,
    OpticalWavefront,
    SampleFunction,
    ScalarFloat,
    ScalarInteger,
    make_optical_wavefront,
    make_sample_function,
)


[docs] @jax.jit @jaxtyped(typechecker=beartype) def epie_optical( microscope_data: MicroscopeData, initial_object: OpticalWavefront, initial_surface: SampleFunction, pixel_mask: Float[Array, " H W"], propagation_distance_1: ScalarFloat, propagation_distance_2: ScalarFloat, magnification: ScalarInteger, vmap_iterations: Optional[ScalarInteger] = 0, alpha_object: Optional[ScalarFloat] = 0.1, gamma_object: Optional[ScalarFloat] = 0.5, alpha_surface: Optional[ScalarFloat] = 0.1, gamma_surface: Optional[ScalarFloat] = 0.5, num_loops: Optional[ScalarInteger] = 10, ) -> tuple[OpticalWavefront, SampleFunction]: """Reconstruct ptychography using the extended PIE algorithm. Parameters ---------- microscope_data : MicroscopeData Measured intensity data with positions. initial_object : OpticalWavefront Initial guess for object wavefront. initial_surface : SampleFunction Initial guess for surface pattern. pixel_mask : Float[Array, " H W"] Pixel response mask for modeling sensor characteristics. propagation_distance_1 : ScalarFloat Distance from object to diffuser plane in meters. propagation_distance_2 : ScalarFloat Distance from diffuser to sensor plane in meters. magnification : ScalarInteger Magnification factor for downsampling. vmap_iterations : ScalarInteger, optional Number of initial iterations to run in vmap mode for rapid convergence. If 0, use sequential mode for all iterations. If > 0, use vmap for first N iterations, then switch to sequential. Default is 0. alpha_object : ScalarFloat, optional Object update mixing parameter. Default is 0.1. gamma_object : ScalarFloat, optional Object update step size. Default is 0.5. alpha_surface : ScalarFloat, optional Surface update mixing parameter. Default is 0.1. gamma_surface : ScalarFloat, optional Surface update step size. Default is 0.5. num_loops : ScalarInteger, optional Number of iteration loops. Default is 10. Returns ------- tuple of (OpticalWavefront, SampleFunction) - recovered_object : OpticalWavefront Reconstructed object wavefront. - recovered_surface : SampleFunction Reconstructed surface pattern. Notes ----- Algorithm: - Compute image data - Compute positions - Compute frequency grids - Compute object recovery propagation field - Compute surface pattern - Define loop body - Apply fori_loop over loops - Compute final object field - Compute final object wavefront - Compute final surface pattern - Return final object and surface """ image_data: Float[Array, " P H W"] = microscope_data.image_data positions: Float[Array, " P 2"] = microscope_data.positions frequency_x_grid: Float[Array, " H W"] frequency_y_grid: Float[Array, " H W"] frequency_x_grid, frequency_y_grid = _create_frequency_grids( initial_object.field, initial_object.dx ) object_recovery_prop_ft: Complex[Array, " H W"] = angular_spectrum_prop( initial_object, propagation_distance_1 ).field surface_pattern: Complex[Array, " H W"] = initial_surface.sample def loop_body( loop_idx: ScalarInteger, state: tuple[Complex[Array, " H W"], Complex[Array, " H W"]], ) -> tuple[Complex[Array, " H W"], Complex[Array, " H W"]]: object_prop_ft: Complex[Array, " H W"] surface_pattern_current: Complex[Array, " H W"] object_prop_ft, surface_pattern_current = state use_vmap: bool = loop_idx < vmap_iterations position_processor: Callable = jax.lax.cond( use_vmap, lambda: single_pie_vmap, lambda: single_pie_sequential ) updated_state: tuple[ Complex[Array, " H W"], Complex[Array, " H W"] ] = position_processor( object_prop_ft, surface_pattern_current, image_data, positions, frequency_x_grid, frequency_y_grid, pixel_mask, propagation_distance_2, magnification, alpha_object, gamma_object, alpha_surface, gamma_surface, initial_object.wavelength, initial_object.dx, ) return updated_state final_object_ft: Complex[Array, " H W"] final_surface: Complex[Array, " H W"] final_object_ft, final_surface = jax.lax.fori_loop( 0, num_loops, loop_body, (object_recovery_prop_ft, surface_pattern) ) final_object_field: Complex[Array, " H W"] = angular_spectrum_prop( make_optical_wavefront( final_object_ft, initial_object.wavelength, initial_object.dx, initial_object.z_position + propagation_distance_1, ), -propagation_distance_1, ).field recovered_object: OpticalWavefront = make_optical_wavefront( final_object_field, initial_object.wavelength, initial_object.dx, initial_object.z_position, ) recovered_surface: SampleFunction = make_sample_function( final_surface, initial_surface.dx ) return recovered_object, recovered_surface
[docs] @jaxtyped(typechecker=beartype) def single_pie_iteration( object_prop_ft: Complex[Array, " H W"], surface_pattern: Complex[Array, " H W"], measurement: Float[Array, " H W"], position: Float[Array, " 2"], frequency_x_grid: Float[Array, " H W"], frequency_y_grid: Float[Array, " H W"], pixel_mask: Float[Array, " H W"], propagation_distance_2: ScalarFloat, magnification: ScalarInteger, alpha_object: ScalarFloat, gamma_object: ScalarFloat, alpha_surface: ScalarFloat, gamma_surface: ScalarFloat, wavelength: ScalarFloat, dx: ScalarFloat, ) -> tuple[Complex[Array, " H W"], Complex[Array, " H W"]]: """Single iteration of the extended PIE algorithm. Parameters ---------- object_prop_ft : Complex[Array, " H W"] Object wavefront at diffuser plane in Fourier domain. surface_pattern : Complex[Array, " H W"] Surface pattern function. measurement : Float[Array, " H W"] Measured intensity at current position. position : Float[Array, " 2"] Current scanning position [x, y]. frequency_x_grid : Float[Array, " H W"] Frequency grid in x direction. frequency_y_grid : Float[Array, " H W"] Frequency grid in y direction. pixel_mask : Float[Array, " H W"] Pixel response mask for sensor modeling. propagation_distance_2 : ScalarFloat Distance from diffuser to sensor. magnification : ScalarInteger Downsampling magnification factor. alpha_object : ScalarFloat Object update mixing parameter. gamma_object : ScalarFloat Object update step size. alpha_surface : ScalarFloat Surface update mixing parameter. gamma_surface : ScalarFloat Surface update step size. wavelength : ScalarFloat Wavelength of light in meters. dx : ScalarFloat Pixel spacing in meters. Returns ------- tuple of (Complex[Array, " H W"], Complex[Array, " H W"]) - updated_object_ft : Complex[Array, " H W"] Updated object wavefront in Fourier domain. - updated_surface : Complex[Array, " H W"] Updated surface pattern. Notes ----- Algorithm: - Compute object shifted - Compute surface plane - Compute surface propagation kernel - Compute sensor plane - Compute sensor intensity - Compute ratio map - Compute ratio map upsampled - Compute sensor plane new - Compute sensor plane new in Fourier domain - Compute CTF conjugate - Compute CTF maximum squared - Compute surface propagation kernel - Compute updated surface pattern - Compute updated object wavefront - Compute updated object wavefront in Fourier domain - Return updated object and surface """ object_shifted: Complex[Array, " H W"] = _apply_position_shift( object_prop_ft, position, frequency_x_grid, frequency_y_grid ) surface_plane: Complex[Array, " H W"] = object_shifted * surface_pattern surface_prop_ft: Complex[Array, " H W"] = jnp.fft.fftshift( jnp.fft.fft2(surface_plane) ) * _get_propagation_kernel( surface_plane.shape, propagation_distance_2, wavelength, dx ) sensor_plane_ft: Complex[Array, " H W"] = ( _apply_coherent_transfer_function(surface_prop_ft) ) sensor_plane: Complex[Array, " H W"] = jnp.fft.ifft2( jnp.fft.ifftshift(sensor_plane_ft) ) sensor_intensity: Float[Array, " H W"] = _compute_sensor_intensity( sensor_plane, pixel_mask, magnification ) ratio_map: Float[Array, " H W"] = jnp.sqrt(measurement) / ( jnp.sqrt(sensor_intensity) + 1e-10 ) ratio_map_upsampled: Float[Array, " H W"] = jnp.repeat( jnp.repeat(ratio_map, magnification, axis=0), magnification, axis=1 ) sensor_plane_new: Complex[Array, " H W"] = ( ratio_map_upsampled * sensor_plane ) sensor_plane_new_ft: Complex[Array, " H W"] = jnp.fft.fftshift( jnp.fft.fft2(sensor_plane_new) ) ctf_conj: Complex[Array, " H W"] = jnp.conj(_get_ctf()) ctf_max_squared: Float[Array, " "] = jnp.max(jnp.abs(_get_ctf()) ** 2) surface_prop_ft_updated: Complex[Array, " H W"] = surface_prop_ft + ( ctf_conj * (sensor_plane_new_ft - sensor_plane_ft) / ctf_max_squared ) surface_plane_new: Complex[Array, " H W"] = jnp.fft.ifft2( jnp.fft.ifftshift( surface_prop_ft_updated * _get_propagation_kernel( surface_prop_ft_updated.shape, -propagation_distance_2, wavelength, dx, ) ) ) updated_surface: Complex[Array, " H W"] = _update_surface_pattern( surface_pattern, object_shifted, surface_plane, surface_plane_new, alpha_surface, gamma_surface, ) object_shifted_updated: Complex[Array, " H W"] = _update_object_wavefront( object_shifted, surface_pattern, surface_plane, surface_plane_new, alpha_object, gamma_object, ) updated_object_ft: Complex[Array, " H W"] = _apply_position_shift( jnp.fft.fftshift(jnp.fft.fft2(object_shifted_updated)), -position, frequency_x_grid, frequency_y_grid, ) return updated_object_ft, updated_surface
[docs] @jaxtyped(typechecker=beartype) def single_pie_sequential( object_prop_ft: Complex[Array, " H W"], surface_pattern: Complex[Array, " H W"], image_data: Float[Array, " P H W"], positions: Float[Array, " P 2"], frequency_x_grid: Float[Array, " H W"], frequency_y_grid: Float[Array, " H W"], pixel_mask: Float[Array, " H W"], propagation_distance_2: ScalarFloat, magnification: ScalarInteger, alpha_object: ScalarFloat, gamma_object: ScalarFloat, alpha_surface: ScalarFloat, gamma_surface: ScalarFloat, wavelength: ScalarFloat, dx: ScalarFloat, ) -> tuple[Complex[Array, " H W"], Complex[Array, " H W"]]: """Sequential processing over positions using fori_loop for proper PIE convergence. Parameters ---------- object_prop_ft : Complex[Array, " H W"] Current object wavefront in Fourier domain. surface_pattern : Complex[Array, " H W"] Current surface pattern. image_data : Float[Array, " P H W"] Measurement data for all positions. positions : Float[Array, " P 2"] Position coordinates for all measurements. frequency_x_grid : Float[Array, " H W"] Frequency grid in x direction. frequency_y_grid : Float[Array, " H W"] Frequency grid in y direction. pixel_mask : Float[Array, " H W"] Pixel response mask for sensor modeling. propagation_distance_2 : ScalarFloat Distance from diffuser to sensor. magnification : ScalarInteger Downsampling magnification factor. alpha_object : ScalarFloat Object update mixing parameter. gamma_object : ScalarFloat Object update step size. alpha_surface : ScalarFloat Surface update mixing parameter. gamma_surface : ScalarFloat Surface update step size. wavelength : ScalarFloat Wavelength of light in meters. dx : ScalarFloat Pixel spacing in meters. Returns ------- tuple of (Complex[Array, " H W"], Complex[Array, " H W"]) Updated object and surface state after sequential processing. Notes ----- Algorithm: - Compute number of positions - Define position body - Apply fori_loop over positions - Return updated state """ num_positions: ScalarInteger = image_data.shape[0] def position_body( pos_idx: ScalarInteger, state: tuple[Complex[Array, " H W"], Complex[Array, " H W"]], ) -> tuple[Complex[Array, " H W"], Complex[Array, " H W"]]: return single_pie_iteration( state[0], state[1], image_data[pos_idx], positions[pos_idx], frequency_x_grid, frequency_y_grid, pixel_mask, propagation_distance_2, magnification, alpha_object, gamma_object, alpha_surface, gamma_surface, wavelength, dx, ) updated_state: tuple[Complex[Array, " H W"], Complex[Array, " H W"]] = ( jax.lax.fori_loop( 0, num_positions, position_body, (object_prop_ft, surface_pattern) ) ) return updated_state
[docs] @jaxtyped(typechecker=beartype) def single_pie_vmap( object_prop_ft: Complex[Array, " H W"], surface_pattern: Complex[Array, " H W"], image_data: Float[Array, " P H W"], positions: Float[Array, " P 2"], frequency_x_grid: Float[Array, " H W"], frequency_y_grid: Float[Array, " H W"], pixel_mask: Float[Array, " H W"], propagation_distance_2: ScalarFloat, magnification: ScalarInteger, alpha_object: ScalarFloat, gamma_object: ScalarFloat, alpha_surface: ScalarFloat, gamma_surface: ScalarFloat, wavelength: ScalarFloat, dx: ScalarFloat, ) -> tuple[Complex[Array, " H W"], Complex[Array, " H W"]]: """Parallel processing over positions using vmap for faster but approximate PIE. All positions use the same initial state, then updates are averaged. Parameters ---------- object_prop_ft : Complex[Array, " H W"] Current object wavefront in Fourier domain. surface_pattern : Complex[Array, " H W"] Current surface pattern. image_data : Float[Array, " P H W"] Measurement data for all positions. positions : Float[Array, " P 2"] Position coordinates for all measurements. frequency_x_grid : Float[Array, " H W"] Frequency grid in x direction. frequency_y_grid : Float[Array, " H W"] Frequency grid in y direction. pixel_mask : Float[Array, " H W"] Pixel response mask for sensor modeling. propagation_distance_2 : ScalarFloat Distance from diffuser to sensor. magnification : ScalarInteger Downsampling magnification factor. alpha_object : ScalarFloat Object update mixing parameter. gamma_object : ScalarFloat Object update step size. alpha_surface : ScalarFloat Surface update mixing parameter. gamma_surface : ScalarFloat Surface update step size. wavelength : ScalarFloat Wavelength of light in meters. dx : ScalarFloat Pixel spacing in meters. Returns ------- tuple of (Complex[Array, " H W"], Complex[Array, " H W"]) Updated object and surface state after parallel processing and averaging. Notes ----- Algorithm: - Apply vmap over all positions using same initial state - Compute average of all object updates - Compute average of all surface updates - Return averaged states """ vmapped_iteration = jax.vmap( single_pie_iteration, in_axes=( None, None, 0, 0, None, None, None, None, None, None, None, None, None, None, None, ), out_axes=0, ) batch_object_updates: Complex[Array, " P H W"] batch_surface_updates: Complex[Array, " P H W"] batch_object_updates, batch_surface_updates = vmapped_iteration( object_prop_ft, surface_pattern, image_data, positions, frequency_x_grid, frequency_y_grid, pixel_mask, propagation_distance_2, magnification, alpha_object, gamma_object, alpha_surface, gamma_surface, wavelength, dx, ) averaged_object: Complex[Array, " H W"] = jnp.mean( batch_object_updates, axis=0 ) averaged_surface: Complex[Array, " H W"] = jnp.mean( batch_surface_updates, axis=0 ) return averaged_object, averaged_surface
@jaxtyped(typechecker=beartype) def _update_object_wavefront( object_shift: Complex[Array, " H W"], surface_pattern: Complex[Array, " H W"], surface_plane: Complex[Array, " H W"], surface_plane_new: Complex[Array, " H W"], alpha_object: ScalarFloat, gamma_object: ScalarFloat, ) -> Complex[Array, " H W"]: """Update object wavefront using rPIE algorithm. Parameters ---------- object_shift : Complex[Array, " H W"] Current shifted object wavefront. surface_pattern : Complex[Array, " H W"] Surface pattern function. surface_plane : Complex[Array, " H W"] Current exit wave at surface. surface_plane_new : Complex[Array, " H W"] Updated exit wave at surface. alpha_object : ScalarFloat Mixing parameter for denominator. gamma_object : ScalarFloat Update step size. Returns ------- Complex[Array, " H W"] Updated object wavefront. Notes ----- Algorithm: - Compute surface conjugate - Compute difference between current and updated surface plane - Compute surface absolute squared - Compute surface maximum squared - Compute denominator - Compute update term - Compute updated object wavefront """ surface_conj: Complex[Array, " H W"] = jnp.conj(surface_pattern) difference: Complex[Array, " H W"] = surface_plane_new - surface_plane surface_abs_squared: Float[Array, " H W"] = jnp.abs(surface_pattern) ** 2 surface_max_squared: Float[Array, " "] = jnp.max(surface_abs_squared) denominator: Float[Array, " H W"] = ( alpha_object * surface_max_squared + (1 - alpha_object) * surface_abs_squared ) update_term: Complex[Array, " H W"] = ( gamma_object * surface_conj * difference / (denominator + 1e-10) ) updated_object: Complex[Array, " H W"] = object_shift + update_term return updated_object @jaxtyped(typechecker=beartype) def _update_surface_pattern( surface_pattern: Complex[Array, " H W"], object_shift: Complex[Array, " H W"], surface_plane: Complex[Array, " H W"], surface_plane_new: Complex[Array, " H W"], alpha_surface: ScalarFloat, gamma_surface: ScalarFloat, ) -> Complex[Array, " H W"]: """Update surface pattern using modified PIE algorithm. Parameters ---------- surface_pattern : Complex[Array, " H W"] Current surface pattern. object_shift : Complex[Array, " H W"] Shifted object wavefront. surface_plane : Complex[Array, " H W"] Current exit wave at surface. surface_plane_new : Complex[Array, " H W"] Updated exit wave at surface. alpha_surface : ScalarFloat Mixing parameter for denominator. gamma_surface : ScalarFloat Update step size. Returns ------- Complex[Array, " H W"] Updated surface pattern. Notes ----- Algorithm: - Compute object conjugate - Compute difference between current and updated surface plane - Compute object absolute squared - Compute object maximum squared - Compute denominator - Compute update term - Compute updated surface pattern """ object_conj: Complex[Array, " H W"] = jnp.conj(object_shift) difference: Complex[Array, " H W"] = surface_plane_new - surface_plane object_abs_squared: Float[Array, " H W"] = jnp.abs(object_shift) ** 2 object_max_squared: Float[Array, " "] = jnp.max(object_abs_squared) denominator: Float[Array, " H W"] = ( alpha_surface * object_max_squared + (1 - alpha_surface) * object_abs_squared ) update_term: Complex[Array, " H W"] = ( gamma_surface * object_conj * difference / (denominator + 1e-10) ) updated_surface: Complex[Array, " H W"] = surface_pattern + update_term return updated_surface @jaxtyped(typechecker=beartype) def _apply_coherent_transfer_function( field_ft: Complex[Array, " H W"], ) -> Complex[Array, " H W"]: """Apply coherent transfer function in Fourier domain. Parameters ---------- field_ft : Complex[Array, " H W"] Field in Fourier domain. Returns ------- Complex[Array, " H W"] Field after CTF application. Notes ----- Algorithm: - Compute CTF - Apply CTF to field - Return result """ ctf: Complex[Array, " H W"] = _get_ctf() result: Complex[Array, " H W"] = ctf * field_ft return result @jaxtyped(typechecker=beartype) def _apply_position_shift( field_ft: Complex[Array, " H W"], position: Float[Array, " 2"], frequency_x_grid: Float[Array, " H W"], frequency_y_grid: Float[Array, " H W"], ) -> Complex[Array, " H W"]: """Apply position shift using phase multiplication in Fourier domain. Implements MATLAB: Hs = exp(-1j*2*pi.*(FX.*xlocation(tt)/imSize + FY.*ylocation(tt)/imSize)) Parameters ---------- field_ft : Complex[Array, " H W"] Field in Fourier domain. position : Float[Array, " 2"] Position shift [x, y]. frequency_x_grid : Float[Array, " H W"] Frequency grid in x direction. frequency_y_grid : Float[Array, " H W"] Frequency grid in y direction. Returns ------- Complex[Array, " H W"] Position-shifted field in real space. Notes ----- Algorithm: - Compute image size - Compute phase factor - Compute position-shifted field in Fourier domain - Compute position-shifted field in real space - Return position-shifted field """ image_size: ScalarInteger = field_ft.shape[0] phase_factor: Complex[Array, " H W"] = jnp.exp( -1j * 2 * jnp.pi * ( frequency_x_grid * position[0] / image_size + frequency_y_grid * position[1] / image_size ) ) shifted_field_ft: Complex[Array, " H W"] = field_ft * phase_factor shifted_field: Complex[Array, " H W"] = jnp.fft.ifft2( jnp.fft.ifftshift(shifted_field_ft) ) return shifted_field @jaxtyped(typechecker=beartype) def _get_propagation_kernel( field_shape: Tuple[int, int], distance: ScalarFloat, wavelength: ScalarFloat, dx: ScalarFloat, ) -> Complex[Array, " H W"]: """Return propagation kernel H_d for free space propagation. Parameters ---------- field_shape : Tuple[int, int] Shape of the field (height, width). distance : ScalarFloat Propagation distance in meters. wavelength : ScalarFloat Wavelength in meters. dx : ScalarFloat Pixel spacing in meters. Returns ------- Complex[Array, " H W"] Propagation kernel for angular spectrum method. Notes ----- Algorithm: - Compute height and width of field - Compute frequency grid in x direction - Compute frequency grid in y direction - Compute frequency grid in x and y directions - Compute frequency squared - Compute k_0 - Compute k_z - Compute propagation kernel - Return propagation kernel """ height: int width: int height, width = field_shape frequency_x: Float[Array, " W"] = jnp.fft.fftfreq(width, dx) frequency_y: Float[Array, " H"] = jnp.fft.fftfreq(height, dx) frequency_x_grid: Float[Array, " H W"] frequency_y_grid: Float[Array, " H W"] frequency_x_grid, frequency_y_grid = jnp.meshgrid(frequency_x, frequency_y) frequency_squared: Float[Array, " H W"] = ( frequency_x_grid**2 + frequency_y_grid**2 ) k_0: ScalarFloat = 2 * jnp.pi / wavelength k_z: Complex[Array, " H W"] = jnp.sqrt( k_0**2 - (2 * jnp.pi) ** 2 * frequency_squared + 0j ) kernel: Complex[Array, " H W"] = jnp.exp(1j * k_z * distance) return kernel @jaxtyped(typechecker=beartype) def _compute_sensor_intensity( sensor_field: Complex[Array, " H W"], pixel_mask: Float[Array, " H W"], magnification: ScalarInteger, ) -> Float[Array, " H W"]: """Compute sensor plane intensity with pixel response mask and downsampling. Implements the MATLAB: conv2(pixelMask.*abs(sensorPlane).^2, ones(mag,mag)) Parameters ---------- sensor_field : Complex[Array, " H W"] Complex field at sensor plane. pixel_mask : Float[Array, " H W"] Pixel response mask modeling sensor characteristics. magnification : ScalarInteger Downsampling magnification factor. Returns ------- Float[Array, " H W"] Downsampled intensity pattern. Notes ----- Algorithm: - Compute intensity - Apply pixel mask - Compute kernel - Convolve intensity with kernel - Downsample intensity """ intensity: Float[Array, " H W"] = jnp.abs(sensor_field) ** 2 masked_intensity: Float[Array, " H W"] = pixel_mask * intensity kernel: Float[Array, " mag mag"] = jnp.ones((magnification, magnification)) convolved_intensity: Float[Array, " H W"] = jax.scipy.signal.convolve2d( masked_intensity, kernel, mode="same" ) downsampled_intensity: Float[Array, " H_new W_new"] = convolved_intensity[ magnification - 1 :: magnification, magnification - 1 :: magnification ] return downsampled_intensity @jaxtyped(typechecker=beartype) def _create_frequency_grids( field: Complex[Array, " H W"], dx: ScalarFloat ) -> Tuple[Float[Array, " H W"], Float[Array, " H W"]]: """Create frequency grids for Fourier transforms. Parameters ---------- field : Complex[Array, " H W"] Input field to determine grid size. dx : ScalarFloat Spatial sampling interval. Returns ------- tuple of (Float[Array, " H W"], Float[Array, " H W"]) - frequency_x_grid : Float[Array, " H W"] Frequency grid in x direction. - frequency_y_grid : Float[Array, " H W"] Frequency grid in y direction. Notes ----- Algorithm: - Compute height and width of field - Compute frequency grid in x direction - Compute frequency grid in y direction - Compute frequency grid in x and y directions - Return frequency grids """ height: ScalarInteger = field.shape[0] width: ScalarInteger = field.shape[1] frequency_x: Float[Array, " W"] = jnp.fft.fftfreq(width, dx) frequency_y: Float[Array, " H"] = jnp.fft.fftfreq(height, dx) frequency_x_grid: Float[Array, " H W"] frequency_y_grid: Float[Array, " H W"] frequency_x_grid, frequency_y_grid = jnp.meshgrid(frequency_x, frequency_y) return frequency_x_grid, frequency_y_grid @jaxtyped(typechecker=beartype) def _get_ctf( field_shape: Optional[Tuple[int, int]] = (256, 256), ) -> Complex[Array, " H W"]: """Return a placeholder coherent transfer function. Parameters ---------- field_shape : Tuple[int, int], optional Shape of the field (height, width). Default is (256, 256). Returns ------- Complex[Array, " H W"] Coherent transfer function. """ ctf: Complex[Array, " H W"] = jnp.ones(field_shape, dtype=complex) return ctf