janssen.scopes

Microscope implementations and forward models.

Extended Summary

Complete forward models for optical microscopy including diffraction patterns, light-sample interactions, and multi-position imaging.

Routine Listings

diffractogram_noscale()

Calculates the diffractogram without scaling camera pixel size.

linear_interaction()

Propagates optical wavefront through sample using linear interaction.

simple_diffractogram()

Calculates the diffractogram using a simple model.

simple_microscope()

Calculates 3D diffractograms at all pixel positions in parallel.

Notes

These functions provide complete forward models for optical microscopy and are designed for use in inverse problems and ptychography reconstruction. All functions are JAX-compatible and support automatic differentiation.
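To illustrate what automatic differentiation offers for an inverse problem, the sketch below builds a toy forward model from plain `jax.numpy` operations and differentiates a data-misfit loss with respect to a sample phase map. It does not use the janssen API: the names `toy_forward` and `misfit`, the array sizes, the defocused Gaussian probe, and the FFT stand-in for propagation are all assumptions made for this example.

```python
import jax
import jax.numpy as jnp


def toy_forward(phase, probe):
    """Toy forward model: phase object times probe, then far-field intensity."""
    exit_wave = jnp.exp(1j * phase) * probe             # element-wise interaction
    far_field = jnp.fft.fft2(exit_wave, norm="ortho")   # FFT stand-in for propagation
    return far_field.real**2 + far_field.imag**2        # detector records intensity


def misfit(phase, probe, measured):
    """Sum of squared differences between simulated and measured intensity."""
    return jnp.sum((toy_forward(phase, probe) - measured) ** 2)


# Hypothetical probe: Gaussian amplitude with a quadratic (defocus-like) phase.
coords = jnp.arange(16) - 8.0
yy, xx = jnp.meshgrid(coords, coords, indexing="ij")
r2 = xx**2 + yy**2
probe = jnp.exp(-r2 / 32.0) * jnp.exp(1j * 0.05 * r2)

# Synthetic "measurement" generated from a known phase map.
true_phase = 0.1 * jax.random.normal(jax.random.PRNGKey(0), (16, 16))
measured = toy_forward(true_phase, probe)

# Loss and its gradient with respect to the current phase estimate.
guess = jnp.zeros((16, 16))
loss, grad = jax.value_and_grad(misfit)(guess, probe, measured)
```

The gradient has the same shape as the phase estimate, so it can be passed directly to a gradient-based optimizer.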

janssen.scopes.diffractogram_noscale(sample_cut: SampleFunction, lightwave: OpticalWavefront, zoom_factor: float | Float[Array, ''], aperture_diameter: float | Float[Array, ''], travel_distance: float | Float[Array, ''], aperture_center: Float[Array, '2'] | None = None) OpticalWavefront[source]

Calculate the wavefront at the camera plane without scaling to the camera pixel size.

The lightwave interacts with the sample linearly and is then zoomed optically. It then passes through a circular aperture and propagates to the camera plane. Unlike simple_diffractogram(), the result is not rescaled to a camera pixel size, and the function returns the optical wavefront at the camera rather than a diffractogram.

Parameters:
  • sample_cut (SampleFunction) – The sample function representing the optical properties of the sample

  • lightwave (OpticalWavefront) – The incoming optical wavefront

  • zoom_factor (float) – The zoom factor for the optical system

  • aperture_diameter (float) – The diameter of the aperture in meters

  • travel_distance (float) – The distance traveled by the light in meters

  • aperture_center (Optional[Float[Array, "2"]], optional) – The center of the aperture in meters

Returns:

at_camera – The calculated optical wavefront at the camera plane.

Return type:

OpticalWavefront

Notes

Algorithm:

  • Propagate the lightwave through the sample using linear interaction

  • Apply optical zoom to the wavefront

  • Apply a circular aperture to the zoomed wavefront

  • Propagate the wavefront to the camera plane using Fraunhofer propagation
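The four steps above can be sketched with plain arrays. This is a toy stand-in, not the janssen implementation: `circular_aperture` and `toy_noscale` are hypothetical names, the zoom is modeled only as a rescaling of the pixel size, the aperture center is assumed to be given as (x, y), and the travel distance is left out because in a bare FFT-based Fraunhofer step it only sets the output scale and a phase factor.

```python
import jax.numpy as jnp


def circular_aperture(shape, dx, diameter, center=(0.0, 0.0)):
    """Binary mask that is 1 inside a circle of the given diameter (meters)."""
    ny, nx = shape
    y = (jnp.arange(ny) - ny // 2) * dx
    x = (jnp.arange(nx) - nx // 2) * dx
    yy, xx = jnp.meshgrid(y, x, indexing="ij")
    r = jnp.hypot(xx - center[0], yy - center[1])
    return (r <= diameter / 2).astype(jnp.float32)


def toy_noscale(sample, field, dx, zoom_factor, aperture_diameter):
    """Toy version of the four steps; returns the far field and zoomed pixel size."""
    exit_wave = sample * field                       # 1. linear interaction
    dx_zoomed = dx * zoom_factor                     # 2. zoom as a pixel-size rescale
    mask = circular_aperture(field.shape, dx_zoomed, aperture_diameter)
    apertured = exit_wave * mask                     # 3. circular aperture
    far_field = jnp.fft.fftshift(                    # 4. Fraunhofer step: centered 2D FFT
        jnp.fft.fft2(jnp.fft.ifftshift(apertured), norm="ortho")
    )
    return far_field, dx_zoomed


# Uniform sample and plane-wave illumination on a 64 x 64 grid of 1 um pixels.
field = jnp.ones((64, 64), dtype=jnp.complex64)
sample = jnp.ones((64, 64), dtype=jnp.complex64)
far_field, dx_zoomed = toy_noscale(
    sample, field, dx=1e-6, zoom_factor=2.0, aperture_diameter=40e-6
)
```

With a uniform sample the far field is the diffraction pattern of the aperture alone, so its magnitude peaks at the center of the grid.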

janssen.scopes.linear_interaction(sample: SampleFunction, light: OpticalWavefront) OpticalWavefront[source]

Propagate an optical wavefront through a sample.

The sample is modeled as a complex-valued function that modifies the incoming wavefront through a linear interaction.

Parameters:
  • sample (SampleFunction) – The sample function representing the optical properties of the sample

  • light (OpticalWavefront) – The incoming optical wavefront

Returns:

The propagated optical wavefront after passing through the sample

Return type:

OpticalWavefront
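A minimal sketch of the idea, assuming that "linear interaction" means an element-wise product of a complex transmission function with the incoming field (the usual thin-sample approximation; the source does not spell this out). Plain arrays stand in for SampleFunction and OpticalWavefront, and all values are illustrative.

```python
import jax.numpy as jnp

# Hypothetical thin sample: 50% amplitude transmission, quarter-wave phase delay.
transmission = 0.5 * jnp.exp(1j * jnp.pi / 2) * jnp.ones((4, 4))

# Incoming plane wave of unit amplitude.
incoming = jnp.ones((4, 4), dtype=jnp.complex64)

# Element-wise product: the output scales linearly with the input field.
outgoing = transmission * incoming

amplitude = jnp.abs(outgoing)   # attenuated amplitude
phase = jnp.angle(outgoing)     # phase delay imprinted by the sample
```

Under this assumption, doubling the incoming field doubles the outgoing field, which is what makes the interaction linear.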

janssen.scopes.simple_diffractogram(sample_cut: SampleFunction, lightwave: OpticalWavefront, zoom_factor: float | Float[Array, ''], aperture_diameter: float | Float[Array, ''], travel_distance: float | Float[Array, ''], camera_pixel_size: float | Float[Array, ''], aperture_center: Float[Array, '2'] | None = None) Diffractogram[source]

Calculate the diffractogram of a sample using a simple model.

The lightwave interacts with the sample linearly, and is then zoomed optically. Following this it interacts with a circular aperture before propagating to the camera plane using scaled Fraunhofer propagation to match the detector pixel size.

Parameters:
  • sample_cut (SampleFunction) – The sample function representing the optical properties of the sample

  • lightwave (OpticalWavefront) – The incoming optical wavefront

  • zoom_factor (float) – The zoom factor for the optical system

  • aperture_diameter (float) – The diameter of the aperture in meters

  • travel_distance (float) – The distance traveled by the light in meters

  • camera_pixel_size (float) – The pixel size of the detector/camera in meters

  • aperture_center (Optional[Float[Array, "2"]], optional) – The center of the aperture in meters

Returns:

The calculated diffractogram of the sample. Output shape matches the input lightwave field shape.

Return type:

Diffractogram

Notes

Algorithm:

  • Propagate the lightwave through the sample using linear interaction

  • Apply optical zoom to the wavefront

  • Apply a circular aperture to the zoomed wavefront

  • Propagate to the camera plane using scaled Fraunhofer propagation which outputs at the specified camera pixel size

  • Calculate the field intensity of the camera image

  • Create a diffractogram from the camera image
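For orientation on why a scaled propagation step is useful: a plain FFT-based Fraunhofer step produces an output grid whose pixel size is fixed by the wavelength, the distance, and the input grid, namely λz / (N Δx). The sketch below evaluates that relation for illustrative numbers. The function name and all values are hypothetical, and nothing here describes how janssen implements its scaled propagation.

```python
def fraunhofer_pixel_size(wavelength, distance, n_pixels, dx_in):
    """Output pixel size (meters) of a plain FFT-based Fraunhofer step."""
    return wavelength * distance / (n_pixels * dx_in)


# Illustrative setup: green laser, 10 cm travel, 256 input pixels of 10 um.
native = fraunhofer_pixel_size(
    wavelength=532e-9, distance=0.1, n_pixels=256, dx_in=10e-6
)

# Hypothetical detector pixel size; it generally differs from the native grid.
camera_pixel_size = 6.5e-6
ratio = native / camera_pixel_size
```

Here the native grid is about 20.8 µm per pixel, roughly three times coarser than the assumed 6.5 µm detector pixel, so the field would have to be resampled or propagated with a scaled transform to land on the detector grid.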

janssen.scopes.simple_microscope(sample: SampleFunction, positions: Num[Array, 'n 2'], lightwave: OpticalWavefront, zoom_factor: float | Float[Array, ''], aperture_diameter: float | Float[Array, ''], travel_distance: float | Float[Array, ''], camera_pixel_size: float | Float[Array, ''], aperture_center: Float[Array, '2'] | None = None) MicroscopeData[source]

Calculate the 3D stack of diffractograms for an entire scan.

This cuts the sample at every scan position and generates a diffractogram with the desired camera pixel size for each cutout. All positions are computed in parallel across all available GPUs.

The function automatically detects available GPUs and distributes computation across devices using JAX’s modern sharding API. For ptychography experiments with hundreds of scan positions, this provides significant speedup on multi-GPU systems.

Implementation Logic

The function orchestrates parallel diffractogram computation through a four-stage pipeline:

  1. Position Preparation:

    • Converts physical positions to pixel coordinates via positions/dx

    • Detects available devices (GPUs/TPUs) via jax.devices()

    • Pads position array to nearest multiple of device count for even distribution

    • Example: 400 positions on 7 GPUs → pad to 406 (7 × 58)

  2. Data Sharding (Automatic Multi-GPU Distribution):

    • Creates Mesh with all devices along "data" axis

    • Defines NamedSharding(P("data", None)) for position dimension

    • Uses jax.device_put to explicitly distribute positions across devices

    • Each device receives n_positions // n_devices positions

    • Critical for memory: sharding prevents OOM by distributing both input positions and output diffractograms

  3. Parallel Computation:

    • jax.vmap over positions dimension with in_axes=(None, 0) broadcasts sample to all positions

    • For each position: dynamic_slice extracts sample cutout, simple_diffractogram computes forward model

    • JAX’s SPMD compiler ensures each device processes only its shard

    • Outputs remain sharded: Float[Array, "padded hh ww"] distributed across devices

  4. Result Assembly:

    • Trims padded diffractograms back to original count via [:n]

    • Creates MicroscopeData with trimmed images and original positions

    • JAX automatically manages cross-device communication
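The four stages can be sketched end to end with plain arrays. This is a hedged illustration, not the janssen source: `toy_diffractogram` and `scan` are hypothetical names, the per-position forward model is reduced to a cutout followed by |FFT|², positions are treated as (row, column), and padding with position (0, 0) is an assumption, since the text does not say how the padded entries are filled.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

CUT = 8  # side length of each sample cutout in pixels (illustrative)


def toy_diffractogram(sample, pos):
    """Cut a CUT x CUT window at integer pixel position pos; return |FFT|^2."""
    cut = jax.lax.dynamic_slice(sample, (pos[0], pos[1]), (CUT, CUT))
    far = jnp.fft.fft2(cut, norm="ortho")
    return far.real**2 + far.imag**2


def scan(sample, positions_m, dx):
    n = positions_m.shape[0]
    devices = jax.devices()
    n_dev = len(devices)

    # 1. Position preparation: meters -> pixels, then pad to a multiple of n_dev.
    pix = jnp.round(positions_m / dx).astype(jnp.int32)
    n_pad = (-n) % n_dev
    pix = jnp.concatenate([pix, jnp.zeros((n_pad, 2), jnp.int32)], axis=0)

    # 2. Data sharding: split the position axis across all devices.
    mesh = Mesh(np.array(devices), ("data",))
    pix = jax.device_put(pix, NamedSharding(mesh, P("data", None)))

    # 3. Parallel computation: broadcast the sample, map over positions.
    batched = jax.jit(jax.vmap(toy_diffractogram, in_axes=(None, 0)))
    stack = batched(sample, pix)

    # 4. Result assembly: drop the padded entries.
    return stack[:n]


# Unit-modulus toy sample and three scan positions given in meters.
sample = jnp.exp(1j * jnp.linspace(0.0, 1.0, 32 * 32).reshape(32, 32))
positions_m = jnp.array([[0.0, 0.0], [4e-6, 8e-6], [10e-6, 2e-6]])
images = scan(sample, positions_m, dx=1e-6)  # shape (3, 8, 8)
```

On a single device the mesh has one entry and the sharding is a no-op, so the same code path runs unchanged on a laptop CPU and on a multi-GPU node.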

Parameters:
  • sample (SampleFunction) – The sample function representing the optical properties of the sample

  • positions (Num[Array, 'n 2']) – The positions in the sample plane where the diffractograms are calculated. Physical coordinates in meters.

  • lightwave (OpticalWavefront) – The incoming optical wavefront

  • zoom_factor (float | Float[Array, '']) – The zoom factor for the optical system

  • aperture_diameter (float | Float[Array, '']) – The diameter of the aperture in meters

  • travel_distance (float | Float[Array, '']) – The distance traveled by the light in meters

  • camera_pixel_size (float | Float[Array, '']) – The pixel size of the detector/camera in meters

  • aperture_center (Float[Array, '2'] | None, optional, default: None) – The center of the aperture in meters

Returns:

The calculated diffractograms of the sample at the specified positions. Output diffractogram shape matches the lightwave field shape.

Return type:

MicroscopeData

Notes

Multi-GPU Performance:

  • Sharding overhead is negligible (<100ms) compared to computation time (minutes for hundreds of positions)

  • Speedup is near-linear with device count for large n (>100 positions)

  • Memory per device: ~(n_total / n_devices) × diffractogram_size

  • Padding waste: at most (n_devices - 1) extra computations, typically <2% overhead
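The padding bound can be checked with a few lines of arithmetic. The helper name `padded_count` is hypothetical; the numbers reuse the 400-position, 7-device example from the implementation notes.

```python
import math


def padded_count(n_positions, n_devices):
    """Smallest multiple of n_devices that is >= n_positions."""
    return math.ceil(n_positions / n_devices) * n_devices


n, n_dev = 400, 7
padded = padded_count(n, n_dev)   # 406 = 7 x 58
extra = padded - n                # 6 wasted evaluations, at most n_dev - 1
overhead = extra / n              # 0.015, i.e. 1.5% extra work
```

The relative overhead shrinks as the number of positions grows, because the number of extra evaluations never exceeds n_devices - 1.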

Memory Characteristics:

For n=400 positions, 256×256 diffractograms, float32:

  • Single GPU: 400 × 256 × 256 × 4 bytes ≈ 105 MB (100 MiB), which is tractable

  • Output sharding prevents accumulation on a single device

  • Forward model peak memory: ~2-3× output size during vmap execution
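The output-size estimate can be reproduced as follows. The helper name `output_bytes` and the four-device split are illustrative assumptions; the 2-3× peak factor is taken from the note above, not measured.

```python
def output_bytes(n_positions, height, width, bytes_per_value=4):
    """Size in bytes of a stack of real-valued diffractograms (float32 by default)."""
    return n_positions * height * width * bytes_per_value


total = output_bytes(400, 256, 256)   # 104_857_600 bytes
total_mib = total / 2**20             # 100.0 MiB

# Hypothetical 4-device system: each device holds a quarter of the output.
per_device_mib = total_mib / 4        # 25.0 MiB

# Rough peak during vmap execution, using the 2-3x factor quoted above.
peak_low_mib, peak_high_mib = 2 * total_mib, 3 * total_mib
```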

When Sharding Helps:

  • Multi-GPU systems: Always beneficial for n > 50

  • Single GPU: No overhead, padding waste is minimal

  • CPU: Sharding across multiple CPUs provides moderate speedup

Design Decisions:

  • Padding strategy ensures compatibility with any device count (avoids “not divisible by” errors)

  • NamedSharding chosen over legacy PositionalSharding for JAX 0.7.1+ compatibility

  • P("data", None) shards only position dimension, broadcasts sample (sample is read-only and fits in device memory)

  • Explicit jax.device_put ensures sharding happens before vmap, not during (avoids runtime sharding overhead)