Source code for janssen.simul.microscope

"""Codes for optical propagation through lenses and optical elements.

Extended Summary
----------------
Microscope forward models for simulating image formation in optical
microscopy. Includes functions for computing diffraction patterns and
modeling light-sample interactions.

Routine Listings
----------------
lens_propagation : function
    Propagates an optical wavefront through a lens
linear_interaction : function
    Propagates an optical wavefront through a sample using linear interaction
simple_diffractogram : function
    Calculates the diffractogram of a sample using a simple model
simple_microscope : function
    Calculates 3D diffractograms at all pixel positions in parallel

Notes
-----
These functions provide complete forward models for optical microscopy
and are designed for use in inverse problems and ptychography reconstruction.
"""

import jax
import jax.numpy as jnp
from beartype import beartype
from beartype.typing import Optional, Tuple
from jaxtyping import Array, Complex, Float, Int, Num, jaxtyped

from janssen.lenses.lens_prop import fraunhofer_prop, optical_zoom
from janssen.utils import (
    Diffractogram,
    MicroscopeData,
    OpticalWavefront,
    SampleFunction,
    make_diffractogram,
    make_microscope_data,
    make_optical_wavefront,
    make_sample_function,
    scalar_float,
    scalar_numeric,
)

from .apertures import circular_aperture
from .helper import field_intensity, scale_pixel

jax.config.update("jax_enable_x64", True)


[docs] @jaxtyped(typechecker=beartype) def linear_interaction( sample: SampleFunction, light: OpticalWavefront, ) -> OpticalWavefront: """Propagate an optical wavefront through a sample. The sample is modeled as a complex function that modifies the incoming wavefront. Using linear interaction. Parameters ---------- sample : SampleFunction The sample function representing the optical properties of the sample light : OpticalWavefront The incoming optical wavefront Returns ------- OpticalWavefront The propagated optical wavefront after passing through the sample """ new_field: Complex[Array, " H W"] = sample.sample * light.field interacted: OpticalWavefront = make_optical_wavefront( field=new_field, wavelength=light.wavelength, dx=light.dx, z_position=light.z_position, ) return interacted
[docs] @jaxtyped(typechecker=beartype) def simple_diffractogram( sample_cut: SampleFunction, lightwave: OpticalWavefront, zoom_factor: scalar_float, aperture_diameter: scalar_float, travel_distance: scalar_float, camera_pixel_size: scalar_float, aperture_center: Optional[Float[Array, " 2"]] = None, ) -> Diffractogram: """Calculate the diffractogram of a sample using a simple model. The lightwave interacts with the sample linearly, and is then zoomed optically. Following this it interacts with a circular aperture before propagating to the camera plane. The camera image is then scaled to the pixel size of the camera. The diffractogram is created from the camera image. Parameters ---------- sample_cut : SampleFunction The sample function representing the optical properties of the sample lightwave : OpticalWavefront The incoming optical wavefront zoom_factor : scalar_float The zoom factor for the optical system aperture_diameter : scalar_float The diameter of the aperture in meters travel_distance : scalar_float The distance traveled by the light in meters camera_pixel_size : scalar_float The pixel size of the camera in meters aperture_center : Optional[Float[Array, " 2"]], optional The center of the aperture in pixels Returns ------- Diffractogram The calculated diffractogram of the sample Notes ----- Algorithm: - Propagate the lightwave through the sample using linear interaction - Apply optical zoom to the wavefront - Apply a circular aperture to the zoomed wavefront - Propagate the wavefront to the camera plane using Fraunhofer propagation - Scale the pixel size of the camera image - Calculate the field intensity of the camera image - Create a diffractogram from the camera image """ at_sample_plane: OpticalWavefront = linear_interaction( sample=sample_cut, light=lightwave, ) zoomed_wave: OpticalWavefront = optical_zoom(at_sample_plane, zoom_factor) after_aperture: OpticalWavefront = circular_aperture( zoomed_wave, aperture_diameter, aperture_center ) at_camera: OpticalWavefront = fraunhofer_prop( after_aperture, travel_distance ) at_camera_scaled: OpticalWavefront = scale_pixel( at_camera, camera_pixel_size, ) scaled_camera_image: Float[Array, " H W"] = field_intensity( at_camera_scaled.field ) diffractogram: Diffractogram = make_diffractogram( image=scaled_camera_image, wavelength=at_camera_scaled.wavelength, dx=at_camera_scaled.dx, ) return diffractogram
[docs] @jaxtyped(typechecker=beartype) def simple_microscope( sample: SampleFunction, positions: Num[Array, " n 2"], lightwave: OpticalWavefront, zoom_factor: scalar_float, aperture_diameter: scalar_float, travel_distance: scalar_float, camera_pixel_size: scalar_float, aperture_center: Optional[Float[Array, " 2"]] = None, ) -> MicroscopeData: """Calculate the 3D diffractograms of the entire imaging. This cuts the sample, and then generates a diffractogram with the desired camera pixel size - all done in parallel. Done at every pixel positions. Parameters ---------- sample : SampleFunction The sample function representing the optical properties of the sample positions : Num[Array, " n 2"] The positions in the sample plane where the diffractograms are calculated. lightwave : OpticalWavefront The incoming optical wavefront zoom_factor : scalar_float The zoom factor for the optical system aperture_diameter : scalar_float The diameter of the aperture in meters travel_distance : scalar_float The distance traveled by the light in meters camera_pixel_size : scalar_float The pixel size of the camera in meters aperture_center : Optional[Float[Array, " 2"]], optional The center of the aperture in pixels Returns ------- MicroscopeData The calculated diffractograms of the sample at the specified positions Notes ----- Algorithm: - Get the size of the lightwave field - Calculate the pixel positions in the sample plane - For each position, cut out the sample and calculate the diffractogram - Combine the diffractograms into a single MicroscopeData object - Return the MicroscopeData object """ interaction_size: Tuple[int, int] = lightwave.field.shape pixel_positions: Float[Array, " n 2"] = positions / lightwave.dx def diffractogram_at_position( sample: SampleFunction, this_position: Num[Array, " 2"] ) -> Diffractogram: x: scalar_numeric y: scalar_numeric hh: int ww: int hh, ww = interaction_size x, y = this_position start_cut_x: Int[Array, " "] = jnp.floor( x - (0.5 * interaction_size[1]) ).astype(int) start_cut_y: Int[Array, " "] = jnp.floor( y - (0.5 * interaction_size[0]) ).astype(int) cutout_sample: Complex[Array, " hh ww"] = jax.lax.dynamic_slice( sample.sample, (start_cut_y, start_cut_x), (interaction_size[0], interaction_size[1]), ) this_sample: SampleFunction = make_sample_function( sample=cutout_sample, dx=sample.dx, ) this_diffractogram: Diffractogram = simple_diffractogram( sample_cut=this_sample, lightwave=lightwave, zoom_factor=zoom_factor, aperture_diameter=aperture_diameter, travel_distance=travel_distance, camera_pixel_size=camera_pixel_size, aperture_center=aperture_center, ) return this_diffractogram.image diffraction_images: Float[Array, " n hh ww"] = jax.vmap( diffractogram_at_position, in_axes=(None, 0) )(sample, pixel_positions) combined_data: MicroscopeData = make_microscope_data( image_data=diffraction_images, positions=positions, wavelength=lightwave.wavelength, dx=lightwave.dx, ) return combined_data