janssen.scopes
Microscope implementations and forward models.
Extended Summary
Complete forward models for optical microscopy including diffraction patterns, light-sample interactions, and multi-position imaging.
Routine Listings
diffractogram_noscale() – Calculates the diffractogram without scaling to the camera pixel size.
linear_interaction() – Propagates an optical wavefront through a sample using linear interaction.
simple_diffractogram() – Calculates the diffractogram using a simple model.
simple_microscope() – Calculates 3D diffractograms at all pixel positions in parallel.
Notes
These functions provide complete forward models for optical microscopy and are designed for use in inverse problems and ptychography reconstruction. All functions are JAX-compatible and support automatic differentiation.
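To illustrate what automatic differentiation through a forward model buys in an inverse problem, here is a minimal sketch in plain JAX. It does not call janssen itself: `toy_forward`, `loss`, and the phase-only sample parameterization are hypothetical stand-ins, and the far field is approximated by a single FFT.

```python
import jax
import jax.numpy as jnp


def toy_forward(phase, probe):
    """Hypothetical forward model: phase-only sample, multiply, FFT, intensity."""
    sample = jnp.exp(1j * phase)            # assumed phase-only sample
    exit_wave = probe * sample              # assumed multiplicative interaction
    far_field = jnp.fft.fft2(exit_wave, norm="ortho")
    # real**2 + imag**2 avoids the undefined gradient of abs() at zero
    return far_field.real ** 2 + far_field.imag ** 2


def loss(phase, probe, measured):
    """Mean squared error between simulated and measured intensities."""
    return jnp.mean((toy_forward(phase, probe) - measured) ** 2)


# Gradient of the loss with respect to the real-valued phase map.
grad_fn = jax.jit(jax.grad(loss))

key = jax.random.PRNGKey(0)
true_phase = 0.3 * jax.random.normal(key, (16, 16))
probe = jnp.ones((16, 16), dtype=jnp.complex64)
measured = toy_forward(true_phase, probe)

g_start = grad_fn(jnp.zeros((16, 16)), probe, measured)
g_truth = grad_fn(true_phase, probe, measured)
```

The gradient has the same shape as the phase map and vanishes at the true phase, which is the property a gradient-based reconstruction relies on.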
- janssen.scopes.diffractogram_noscale(sample_cut: SampleFunction, lightwave: OpticalWavefront, zoom_factor: float | Float[Array, ''], aperture_diameter: float | Float[Array, ''], travel_distance: float | Float[Array, ''], aperture_center: Float[Array, '2'] | None = None) → OpticalWavefront
Calculate the optical wavefront at the camera plane using a simple model, without rescaling to the camera pixel size.
The lightwave interacts with the sample linearly and is then zoomed optically. It then passes through a circular aperture before propagating to the camera plane. Unlike simple_diffractogram(), this function does not rescale the result to the camera pixel size; it returns the wavefront at the camera plane.
- Parameters:
sample_cut (SampleFunction) – The sample function representing the optical properties of the sample
lightwave (OpticalWavefront) – The incoming optical wavefront
zoom_factor (float) – The zoom factor for the optical system
aperture_diameter (float) – The diameter of the aperture in meters
travel_distance (float) – The distance traveled by the light in meters
aperture_center (Optional[Float[Array, "2"]], optional) – The center of the aperture in meters
- Returns:
at_camera – The calculated optical wavefront at the camera plane.
- Return type:
OpticalWavefront
Notes
Algorithm:
Propagate the lightwave through the sample using linear interaction
Apply optical zoom to the wavefront
Apply a circular aperture to the zoomed wavefront
Propagate the wavefront to the camera plane using Fraunhofer propagation
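The four steps above can be sketched with plain JAX arrays. This is an illustration under stated assumptions, not the janssen implementation: the interaction is taken to be an elementwise product, the zoom is modeled as a rescaling of the pixel size, and the Fraunhofer step is a single centered FFT that omits the distance-dependent scaling and phase factors. All function names are hypothetical.

```python
import jax.numpy as jnp


def circular_aperture(shape, dx, diameter, center=(0.0, 0.0)):
    """Binary circular mask on a grid with pixel size dx (meters)."""
    ny, nx = shape
    y = (jnp.arange(ny) - ny // 2) * dx
    x = (jnp.arange(nx) - nx // 2) * dx
    yy, xx = jnp.meshgrid(y, x, indexing="ij")
    r2 = (xx - center[0]) ** 2 + (yy - center[1]) ** 2
    return (r2 <= (diameter / 2) ** 2).astype(jnp.float32)


def fraunhofer(field):
    """Centered, unitary FFT as a stand-in for far-field propagation."""
    return jnp.fft.fftshift(jnp.fft.fft2(jnp.fft.ifftshift(field), norm="ortho"))


def toy_noscale(sample, field, dx, zoom_factor, aperture_diameter):
    exit_wave = field * sample        # 1. linear interaction (assumed product)
    dx_zoomed = dx * zoom_factor      # 2. zoom as a pixel-size rescale (assumption)
    mask = circular_aperture(exit_wave.shape, dx_zoomed, aperture_diameter)
    apertured = exit_wave * mask      # 3. circular aperture
    return fraunhofer(apertured)      # 4. propagate to the camera plane


field = jnp.ones((32, 32), dtype=jnp.complex64)
sample = jnp.ones((32, 32), dtype=jnp.complex64)
at_camera_open = toy_noscale(sample, field, 1e-6, 2.0, aperture_diameter=1.0)
at_camera_small = toy_noscale(sample, field, 1e-6, 2.0, aperture_diameter=20e-6)
```

Because the FFT is unitary here, the energy at the camera equals the energy that passes the aperture, so a smaller aperture yields less total energy.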
- janssen.scopes.linear_interaction(sample: SampleFunction, light: OpticalWavefront) → OpticalWavefront
Propagate an optical wavefront through a sample.
The sample is modeled as a complex function that modifies the incoming wavefront through a linear interaction.
- Parameters:
sample (SampleFunction) – The sample function representing the optical properties of the sample
light (OpticalWavefront) – The incoming optical wavefront
- Returns:
The propagated optical wavefront after passing through the sample
- Return type:
OpticalWavefront
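A common reading of "linear interaction" is an elementwise product between the complex sample function and the incoming field. The sketch below assumes that reading and uses bare arrays rather than the SampleFunction and OpticalWavefront types; `toy_linear_interaction` is a hypothetical name.

```python
import jax.numpy as jnp


def toy_linear_interaction(sample, field):
    """Exit wave as the elementwise product of sample and incoming field."""
    return sample * field


field = jnp.ones((4, 4), dtype=jnp.complex64)

# A phase-only sample shifts the phase but leaves the intensity unchanged.
phase_only = jnp.exp(1j * jnp.full((4, 4), 0.5))
i_phase = jnp.abs(toy_linear_interaction(phase_only, field)) ** 2

# An absorbing sample with amplitude 0.5 scales the intensity by 0.25.
absorbing = jnp.full((4, 4), 0.5, dtype=jnp.complex64)
i_abs = jnp.abs(toy_linear_interaction(absorbing, field)) ** 2
```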
- janssen.scopes.simple_diffractogram(sample_cut: SampleFunction, lightwave: OpticalWavefront, zoom_factor: float | Float[Array, ''], aperture_diameter: float | Float[Array, ''], travel_distance: float | Float[Array, ''], camera_pixel_size: float | Float[Array, ''], aperture_center: Float[Array, '2'] | None = None) → Diffractogram
Calculate the diffractogram of a sample using a simple model.
The lightwave interacts with the sample linearly, and is then zoomed optically. Following this it interacts with a circular aperture before propagating to the camera plane using scaled Fraunhofer propagation to match the detector pixel size.
- Parameters:
sample_cut (SampleFunction) – The sample function representing the optical properties of the sample
lightwave (OpticalWavefront) – The incoming optical wavefront
zoom_factor (float) – The zoom factor for the optical system
aperture_diameter (float) – The diameter of the aperture in meters
travel_distance (float) – The distance traveled by the light in meters
camera_pixel_size (float) – The pixel size of the detector/camera in meters
aperture_center (Optional[Float[Array, "2"]], optional) – The center of the aperture in meters
- Returns:
The calculated diffractogram of the sample. Output shape matches the input lightwave field shape.
- Return type:
Diffractogram
Notes
Algorithm:
Propagate the lightwave through the sample using linear interaction
Apply optical zoom to the wavefront
Apply a circular aperture to the zoomed wavefront
Propagate to the camera plane using scaled Fraunhofer propagation which outputs at the specified camera pixel size
Calculate the field intensity of the camera image
Create a diffractogram from the camera image
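The last steps can be made concrete with two small helpers. A single-FFT Fraunhofer propagator produces output samples spaced by λz / (N·dx), which generally differs from the detector pitch; this is presumably why the scaled propagation step takes camera_pixel_size as an input. The helper names and numeric values below are hypothetical and only illustrate the relation and the intensity step.

```python
import jax.numpy as jnp


def native_fraunhofer_pitch(wavelength, distance, n_pixels, dx):
    """Output sample spacing of a single-FFT Fraunhofer propagator (meters)."""
    return wavelength * distance / (n_pixels * dx)


def intensity(field):
    """Detected intensity is the squared modulus of the complex field."""
    return field.real ** 2 + field.imag ** 2


# Illustrative values only: 500 nm light, 10 cm travel, 256 px at 1 um.
pitch = native_fraunhofer_pitch(500e-9, 0.1, 256, 1e-6)
camera_pixel_size = 5e-6  # hypothetical detector pitch
ratio = pitch / camera_pixel_size

field = jnp.array([[3.0 + 4.0j, 1.0j], [0.0 + 0.0j, 2.0 + 0.0j]])
image = intensity(field)
```

With these numbers the native pitch is far coarser than the detector pitch, so an unscaled result would not line up with the camera pixels.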
- janssen.scopes.simple_microscope(sample: SampleFunction, positions: Num[Array, 'n 2'], lightwave: OpticalWavefront, zoom_factor: float | Float[Array, ''], aperture_diameter: float | Float[Array, ''], travel_distance: float | Float[Array, ''], camera_pixel_size: float | Float[Array, ''], aperture_center: Float[Array, '2'] | None = None) → MicroscopeData
Calculate the 3D stack of diffractograms for the entire scan.
This cuts the sample at every scan position and generates a diffractogram with the desired camera pixel size for each cut, in parallel across all available GPUs.
The function automatically detects available GPUs and distributes computation across devices using JAX’s modern sharding API. For ptychography experiments with hundreds of scan positions, this provides significant speedup on multi-GPU systems.
Implementation Logic
The function orchestrates parallel diffractogram computation through a four-stage pipeline:
1. Position Preparation:
- Converts physical positions to pixel coordinates via positions/dx
- Detects available devices (GPUs/TPUs) via jax.devices()
- Pads position array to the next multiple of the device count for even distribution
- Example: 400 positions on 7 GPUs → pad to 406 (7 × 58)
2. Data Sharding (Automatic Multi-GPU Distribution):
- Creates Mesh with all devices along "data" axis
- Defines NamedSharding(mesh, P("data", None)) for position dimension
- Uses jax.device_put to explicitly distribute positions across devices
- Each device receives n_padded // n_devices positions
- Critical for memory: sharding prevents OOM by distributing both input positions and output diffractograms
3. Parallel Computation:
- jax.vmap over positions dimension with in_axes=(None, 0) broadcasts sample to all positions
- For each position: dynamic_slice extracts sample cutout, simple_diffractogram computes forward model
- JAX's SPMD compiler ensures each device processes only its shard
- Outputs remain sharded: Float[Array, "padded hh ww"] distributed across devices
4. Result Assembly:
- Trims padded diffractograms back to original count via [:n]
- Creates MicroscopeData with trimmed images and original positions
- JAX automatically manages cross-device communication
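The pipeline above can be sketched end to end with public JAX APIs. This is a simplified illustration, not the janssen source: `toy_forward` replaces simple_diffractogram with a crop followed by an FFT intensity, rounding to the nearest pixel and edge padding are assumptions, and all function names are hypothetical. On a single-device machine the mesh has one entry and no padding is added.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def to_pixels(positions, dx):
    """Physical positions (meters) to integer pixel coordinates (rounding assumed)."""
    return jnp.round(positions / dx).astype(jnp.int32)


def pad_to_multiple(positions, n_devices):
    """Pad the leading axis up to the next multiple of n_devices."""
    n = positions.shape[0]
    n_padded = -(-n // n_devices) * n_devices  # ceiling division
    # Repeat the last row so the padded entries are still valid positions.
    return jnp.pad(positions, ((0, n_padded - n), (0, 0)), mode="edge")


def toy_forward(sample, corner, size):
    """Stand-in forward model: crop a patch, then take the FFT intensity."""
    patch = jax.lax.dynamic_slice(sample, (corner[0], corner[1]), (size, size))
    return jnp.abs(jnp.fft.fftshift(jnp.fft.fft2(patch))) ** 2


def sharded_scan(sample, pixel_positions, size):
    n = pixel_positions.shape[0]
    devices = np.array(jax.devices())
    mesh = Mesh(devices, axis_names=("data",))
    sharding = NamedSharding(mesh, P("data", None))  # shard positions only
    padded = pad_to_multiple(pixel_positions, devices.size)
    padded = jax.device_put(padded, sharding)        # distribute before vmap
    batched = jax.jit(
        jax.vmap(lambda s, c: toy_forward(s, c, size), in_axes=(None, 0))
    )
    return batched(sample, padded)[:n]               # trim the padding


sample = jnp.ones((16, 16), dtype=jnp.complex64)
positions_m = jnp.array([[0.0, 0.0], [2e-6, 3e-6], [4e-6, 4e-6]])
pixel_positions = to_pixels(positions_m, 1e-6)
stack = sharded_scan(sample, pixel_positions, size=8)
```

Padding by repeating the last position keeps every padded slice inside the sample, so the extra computations are wasted work rather than a source of out-of-bounds cutouts.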
- Parameters:
sample (SampleFunction) – The sample function representing the optical properties of the sample
positions (Num[Array, "n 2"]) – The positions in the sample plane where the diffractograms are calculated. Physical coordinates in meters.
lightwave (OpticalWavefront) – The incoming optical wavefront
zoom_factor (Union[float, Float[Array, '']]) – The zoom factor for the optical system
aperture_diameter (Union[float, Float[Array, '']]) – The diameter of the aperture in meters
travel_distance (Union[float, Float[Array, '']]) – The distance traveled by the light in meters
camera_pixel_size (Union[float, Float[Array, '']]) – The pixel size of the detector/camera in meters
aperture_center (Optional[Float[Array, "2"]], optional, default: None) – The center of the aperture in meters
- Returns:
The calculated diffractograms of the sample at the specified positions. Output diffractogram shape matches the lightwave field shape.
- Return type:
MicroscopeData
Notes
Multi-GPU Performance:
Sharding overhead is negligible (<100ms) compared to computation time (minutes for hundreds of positions)
Speedup is near-linear with device count for large n (>100 positions)
Memory per device: ~(n_total / n_devices) × diffractogram_size
Padding waste: at most (n_devices - 1) extra computations, typically <2% overhead
Memory Characteristics:
For n=400 positions, 256×256 diffractograms, float32:
- Single GPU: 400 × 256 × 256 × 4 bytes ≈ 100 MB (tractable)
- Output sharding prevents accumulation on single device
- Forward model peak memory: ~2-3× output size during vmap execution
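The memory and padding figures quoted in these notes follow from simple arithmetic, reproduced here as a small script. The device count of 7 is taken from the padding example earlier in this entry; nothing here queries real hardware.

```python
import math

n_positions, height, width = 400, 256, 256
bytes_per_value = 4  # float32
n_devices = 7        # hypothetical device count, as in the padding example

bytes_per_frame = height * width * bytes_per_value
total_bytes = n_positions * bytes_per_frame

# Positions are padded so that every device holds the same number of frames.
per_device = math.ceil(n_positions / n_devices)
n_padded = per_device * n_devices
bytes_per_device = per_device * bytes_per_frame
padding_overhead = (n_padded - n_positions) / n_positions

print(f"total output:     {total_bytes / 1e6:.1f} MB")
print(f"per device:       {bytes_per_device / 1e6:.1f} MB for {per_device} frames")
print(f"padding overhead: {padding_overhead:.1%}")
```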
When Sharding Helps:
Multi-GPU systems: Always beneficial for n > 50
Single GPU: No overhead, padding waste is minimal
CPU: Sharding across multiple CPUs provides moderate speedup
Design Decisions:
Padding strategy ensures compatibility with any device count (avoids “not divisible by” errors)
NamedSharding chosen over legacy PositionalSharding for JAX 0.7.1+ compatibility
P("data", None) shards only the position dimension and broadcasts the sample (the sample is read-only and fits in device memory)
Explicit jax.device_put ensures sharding happens before vmap, not during (avoids runtime sharding overhead)