Simple Microscope Simulation - Spherical Inclusions¶
This notebook demonstrates simulating microscope imaging of a sample with multiple spherical inclusions using the simple_microscope function.
Overview¶
We create a 2D sample with 200 randomly placed spherical inclusions:
Random positions across the field of view
Random radii between 50 and 500 pixels (25-250 µm)
Random refractive index contrast
Projected to 2D transmission function for imaging
Imports¶
In [1]:
import janssen as jns
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import cmocean.cm as cmo
from matplotlib_scalebar.scalebar import ScaleBar
from matplotlib.patches import Circle
In [2]:
jns.__version__
Out [2]:
'2025.10.6'
In [3]:
%load_ext autoreload
%autoreload 2
Define Simulation Parameters¶
Same grid as USAF sample: 4096x4096 pixels at 0.5 µm pixel size.
In [4]:
pixel_size = 0.5e-6 # 0.5 microns
num_pixels = 4096 # Same as USAF sample
wavelength = 633e-9 # 633 nm (HeNe laser)
# Sphere parameters
num_spheres = 200
min_radius_pixels = 50 # Minimum radius in pixels
max_radius_pixels = 500 # Maximum radius in pixels
# Background refractive index (real-valued, so no absorption); sphere indices are set in the next cell
n_background = 1.0 + 0.0j # Air/vacuum background
# Random seed for reproducibility
np.random.seed(42)
print(f"Pixel size: {pixel_size * 1e6:.1f} microns")
print(f"Grid size: {num_pixels} x {num_pixels} pixels")
print(f"Field of view: {pixel_size * num_pixels * 1e3:.2f} mm")
print(f"Wavelength: {wavelength * 1e9:.0f} nm")
print(f"Number of spheres: {num_spheres}")
print(
f"Radius range: {min_radius_pixels}-{max_radius_pixels} pixels ({min_radius_pixels * pixel_size * 1e6:.0f}-{max_radius_pixels * pixel_size * 1e6:.0f} µm)"
)
Pixel size: 0.5 microns
Grid size: 4096 x 4096 pixels
Field of view: 2.05 mm
Wavelength: 633 nm
Number of spheres: 200
Radius range: 50-500 pixels (25-250 µm)
1. Generate Random Sphere Parameters¶
Create random positions, radii, and refractive indices for all spheres.
In [5]:
# Generate uniformly distributed sphere positions using grid + jitter
# Create a grid covering the full image (spheres can be cut off at edges)
grid_size = int(
np.ceil(np.sqrt(num_spheres))
) # e.g., 15x15 grid for 200 spheres
# No margin - spheres can extend to edges and be cut off
grid_spacing_y = num_pixels / grid_size
grid_spacing_x = num_pixels / grid_size
grid_y, grid_x = np.meshgrid(
np.linspace(
grid_spacing_y / 2, num_pixels - grid_spacing_y / 2, grid_size
),
np.linspace(
grid_spacing_x / 2, num_pixels - grid_spacing_x / 2, grid_size
),
)
grid_centers = np.stack([grid_y.ravel(), grid_x.ravel()], axis=1)
# Randomly select num_spheres positions from grid and add jitter
selected_indices = np.random.choice(
len(grid_centers), num_spheres, replace=False
)
sphere_centers = grid_centers[selected_indices]
# Add jitter (up to 25% of grid spacing for more uniform look)
jitter_amount = 0.25 * min(grid_spacing_y, grid_spacing_x)
sphere_centers_y = sphere_centers[:, 0] + np.random.uniform(
-jitter_amount, jitter_amount, num_spheres
)
sphere_centers_x = sphere_centers[:, 1] + np.random.uniform(
-jitter_amount, jitter_amount, num_spheres
)
# Radii: uniform distribution between min and max
sphere_radii_pixels = np.random.uniform(
min_radius_pixels, max_radius_pixels, num_spheres
)
# Refractive index - unique for each sphere, clustered values
# Bigger spheres are more transparent (smaller contrast), smaller spheres less transparent
# Normalize radius to [0, 1] range
radius_normalized = (sphere_radii_pixels - min_radius_pixels) / (
max_radius_pixels - min_radius_pixels
)
# Base refractive index with small random variation for each sphere
# Real part: smaller spheres have higher contrast (n ~ 1.008), bigger spheres lower (n ~ 1.003)
# Add small unique variation to each
base_n_real = 1.003 + 0.005 * (
1 - radius_normalized
) # Inversely proportional to size
sphere_n_real = base_n_real + np.random.uniform(
-0.0005, 0.0005, num_spheres
) # Small unique variation
# Imaginary part: smaller spheres have more absorption, bigger spheres less
# Smaller spheres: κ ~ 0.0003, bigger spheres: κ ~ 0.00005
base_n_imag = 0.00005 + 0.00025 * (
1 - radius_normalized
) # Inversely proportional to size
sphere_n_imag = base_n_imag + np.random.uniform(
-0.00002, 0.00002, num_spheres
) # Small unique variation
sphere_n_imag = np.maximum(sphere_n_imag, 0) # Ensure non-negative
sphere_n = sphere_n_real + 1j * sphere_n_imag
print(f"Generated {num_spheres} spheres with uniform distribution")
print(
f"Grid: {grid_size}x{grid_size} = {grid_size**2} positions, selected {num_spheres}"
)
print(f"Grid spacing: {grid_spacing_y:.0f} x {grid_spacing_x:.0f} pixels")
print(
f"Center Y range: {sphere_centers_y.min():.0f} to {sphere_centers_y.max():.0f} pixels"
)
print(
f"Center X range: {sphere_centers_x.min():.0f} to {sphere_centers_x.max():.0f} pixels"
)
print(
f"Radius range: {sphere_radii_pixels.min():.0f} to {sphere_radii_pixels.max():.0f} pixels"
)
print(
f"Refractive index (real) range: {sphere_n_real.min():.5f} to {sphere_n_real.max():.5f}"
)
print(
f"Absorption (imag) range: {sphere_n_imag.min():.6f} to {sphere_n_imag.max():.6f}"
)
print(
f"Contrast (n-1) range: {(sphere_n_real-1).min():.5f} to {(sphere_n_real-1).max():.5f}"
)
Generated 200 spheres with uniform distribution
Grid: 15x15 = 225 positions, selected 200
Grid spacing: 273 x 273 pixels
Center Y range: 74 to 4025 pixels
Center X range: 96 to 4025 pixels
Radius range: 51 to 499 pixels
Refractive index (real) range: 1.00266 to 1.00821
Absorption (imag) range: 0.000035 to 0.000315
Contrast (n-1) range: 0.00266 to 0.00821
2. Create 2D Sample with Spherical Projections¶
For each sphere, we compute its 2D projection, i.e. the geometric path length (chord length) through the sphere along the optical axis. For a sphere of radius R centered at the origin, the path length at position (x, y) is \(L(x,y) = 2\sqrt{R^2 - x^2 - y^2}\) for \(x^2 + y^2 < R^2\), and 0 otherwise.
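As a rough worked example of the formula above, the sketch below evaluates the transmission through the center of a single sphere using the same model as the next cell, T = exp(i k Δn L). The radius and Δn are illustrative picks from the parameter ranges used in this notebook (the largest radius and the contrast assigned to large spheres), not outputs of the simulation.

```python
import cmath
import math

wavelength = 633e-9            # m, same as the notebook
pixel_size = 0.5e-6            # m per pixel
k = 2 * math.pi / wavelength   # wavenumber in rad/m

# Illustrative sphere (assumed values taken from the notebook's parameter ranges)
radius_px = 500
delta_n = complex(0.003, 0.00005)   # n_sphere - n_background

# Chord length through the sphere center: L(0, 0) = 2R
path_length_m = 2 * radius_px * pixel_size

# Complex transmission at the center of the sphere
transmission = cmath.exp(1j * k * delta_n * path_length_m)

unwrapped_phase = k * delta_n.real * path_length_m   # accumulated phase in rad
amplitude = abs(transmission)                        # equals exp(-k * kappa * L)

print(f"Path length: {path_length_m * 1e6:.0f} µm")
print(f"Unwrapped phase: {unwrapped_phase:.2f} rad "
      f"({unwrapped_phase / (2 * math.pi):.2f} waves)")
print(f"Amplitude: {amplitude:.3f}")
```

For a sphere this large the accumulated phase exceeds two full waves, so the angle of the transmission function is a wrapped value. This is why the phase of the full sample can span the whole range from -π to π even though each Δn is small.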
In [6]:
# Create 2D sample with projected spheres using vmap
# Create coordinate grids
y_coords = jnp.arange(num_pixels)
x_coords = jnp.arange(num_pixels)
yy, xx = jnp.meshgrid(y_coords, x_coords, indexing="ij")
# Wave number
k = 2 * jnp.pi / wavelength
# Convert sphere parameters to JAX arrays
centers_y = jnp.array(sphere_centers_y)
centers_x = jnp.array(sphere_centers_x)
radii = jnp.array(sphere_radii_pixels)
n_spheres_arr = jnp.array(sphere_n)
def compute_sphere_transmission(cy, cx, radius, n_sphere):
"""Compute transmission contribution from a single sphere."""
# Distance from sphere center (in pixels)
dist_sq = (yy - cy) ** 2 + (xx - cx) ** 2
# Path length through sphere (in meters)
# L = 2 * sqrt(R^2 - r^2) for r < R, else 0
path_length_pixels = 2 * jnp.sqrt(jnp.maximum(radius**2 - dist_sq, 0))
path_length_meters = path_length_pixels * pixel_size
# Phase and amplitude from this sphere
delta_n = n_sphere - n_background
sphere_transmission = jnp.exp(1j * k * delta_n * path_length_meters)
return sphere_transmission
# vmap over all spheres and multiply contributions together
all_transmissions = jax.vmap(compute_sphere_transmission)(
centers_y, centers_x, radii, n_spheres_arr
)
# Product of all sphere transmissions (shape: num_pixels x num_pixels)
sample_transmission = jnp.prod(all_transmissions, axis=0)
print(f"Sample created using vmap!")
print(
f"Amplitude range: {jnp.abs(sample_transmission).min():.4f} to {jnp.abs(sample_transmission).max():.4f}"
)
print(
f"Phase range: {jnp.angle(sample_transmission).min():.4f} to {jnp.angle(sample_transmission).max():.4f} rad"
)
WARNING:2025-12-16 20:46:10,697:jax._src.xla_bridge:864: An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
Sample created using vmap!
Amplitude range: 0.1303 to 1.0000
Phase range: -3.1416 to 3.1416 rad
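The vmap call above builds one full-size transmission array per sphere before the product is taken. For 200 spheres on a 4096 x 4096 grid that intermediate holds 200 × 4096² complex values, roughly 27 GB in complex64 or 54 GB in complex128. Because a product of exponentials equals the exponential of the summed exponents, the same result can be built by accumulating the exponent one sphere at a time. The sketch below checks that equivalence with NumPy on a small grid; the helper name and the toy sphere values are hypothetical and not part of janssen.

```python
import numpy as np

def thickness_map(yy, xx, cy, cx, radius):
    """Projected chord length (in pixels) through one sphere."""
    dist_sq = (yy - cy) ** 2 + (xx - cx) ** 2
    return 2.0 * np.sqrt(np.maximum(radius**2 - dist_sq, 0.0))

# Toy setup (hypothetical values, small grid so it runs quickly)
n = 64
pixel_size = 0.5e-6
k = 2 * np.pi / 633e-9
yy, xx = np.meshgrid(np.arange(n), np.arange(n), indexing="ij")
centers_y = np.array([20.0, 40.0, 32.0])
centers_x = np.array([20.0, 45.0, 30.0])
radii = np.array([10.0, 15.0, 8.0])
delta_n = np.array([0.005 + 2e-4j, 0.003 + 5e-5j, 0.008 + 3e-4j])

# Approach 1: one transmission array per sphere, then a product
stack = np.stack([
    np.exp(1j * k * dn * thickness_map(yy, xx, cy, cx, r) * pixel_size)
    for cy, cx, r, dn in zip(centers_y, centers_x, radii, delta_n)
])
product_result = np.prod(stack, axis=0)

# Approach 2: accumulate the complex exponent, exponentiate once
exponent = np.zeros((n, n), dtype=np.complex128)
for cy, cx, r, dn in zip(centers_y, centers_x, radii, delta_n):
    exponent += 1j * k * dn * thickness_map(yy, xx, cy, cx, r) * pixel_size
summed_result = np.exp(exponent)

print("Max abs difference:", np.max(np.abs(product_result - summed_result)))
```

In JAX the same accumulation could be written with jax.lax.fori_loop or jax.lax.scan over the sphere parameters, which keeps only one grid-sized array in memory at a time.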
In [7]:
# Create sample function
sphere_sample = jns.utils.make_sample_function(
sample=sample_transmission,
dx=pixel_size,
)
print(f"Sample shape: {sphere_sample.sample.shape}")
print(f"Sample dx: {sphere_sample.dx * 1e6:.2f} microns")
Sample shape: (4096, 4096)
Sample dx: 0.50 microns
In [8]:
# Visualize the sample
fig, axes = plt.subplots(1, 2, figsize=(14, 6))
amp = jnp.abs(sphere_sample.sample)
phase = jnp.angle(sphere_sample.sample)
# Amplitude
im0 = axes[0].imshow(amp, cmap=cmo.gray)
axes[0].set_title("Amplitude (Transmission)")
scalebar = ScaleBar(sphere_sample.dx, "m", length_fraction=0.25, color="black")
axes[0].add_artist(scalebar)
axes[0].axis("off")
plt.colorbar(im0, ax=axes[0], label="Transmission")
# Phase
im1 = axes[1].imshow(phase, cmap=cmo.phase, vmin=-jnp.pi, vmax=jnp.pi)
axes[1].set_title("Phase")
scalebar = ScaleBar(sphere_sample.dx, "m", length_fraction=0.25, color="black")
axes[1].add_artist(scalebar)
axes[1].axis("off")
plt.colorbar(im1, ax=axes[1], label="Phase (rad)")
plt.suptitle(f"Sample with {num_spheres} Spherical Inclusions", fontsize=14)
plt.tight_layout()
plt.show()
3. Create Illumination Wavefront¶
In [9]:
illumination_size = 256 # Same as USAF notebook
lightwave = jns.models.plane_wave(
wavelength=wavelength,
dx=pixel_size,
grid_size=(illumination_size, illumination_size),
amplitude=1.0,
)
print(f"Illumination field shape: {lightwave.field.shape}")
print(f"Illumination wavelength: {lightwave.wavelength * 1e9:.0f} nm")
print(f"Illumination dx: {lightwave.dx * 1e6:.2f} microns")
print(f"Illumination FOV: {illumination_size * pixel_size * 1e6:.0f} microns")
Illumination field shape: (256, 256)
Illumination wavelength: 633 nm
Illumination dx: 0.50 microns
Illumination FOV: 128 microns
4. Set Microscope Parameters¶
In [10]:
# Microscope parameters
zoom_factor = 10.0 # 10x magnification
aperture_diameter = 1e-3 # 1 mm aperture
travel_distance = 0.1 # 100 mm to camera
detector_pixel_size = jnp.array(16e-6) # 16 micron camera pixels
print(f"Zoom factor: {zoom_factor}x")
print(f"Aperture diameter: {aperture_diameter * 1e3:.1f} mm")
print(f"Travel distance: {travel_distance * 1e3:.0f} mm")
print(f"Detector pixel size: {detector_pixel_size * 1e6:.1f} µm")
Zoom factor: 10.0x
Aperture diameter: 1.0 mm
Travel distance: 100 mm
Detector pixel size: 16.0 µm
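A few derived quantities help when reading the plots in the next section. The sketch below is plain arithmetic on the parameters above. It assumes that the zoom simply scales the pixel size by zoom_factor (as the dx values printed in Step 2 indicate) and uses the textbook Fresnel number a²/(λz) with a the aperture radius.

```python
wavelength = 633e-9         # m
pixel_size = 0.5e-6         # m, sample-plane pixel size
illumination_size = 256     # pixels
zoom_factor = 10.0
aperture_diameter = 1e-3    # m
travel_distance = 0.1       # m

# Pixel size and field width after magnification
zoomed_dx = pixel_size * zoom_factor
zoomed_field_width = illumination_size * zoomed_dx

# Fresnel number for the aperture at the camera distance
aperture_radius = aperture_diameter / 2
fresnel_number = aperture_radius**2 / (wavelength * travel_distance)

print(f"Zoomed dx: {zoomed_dx * 1e6:.1f} µm")
print(f"Zoomed field width: {zoomed_field_width * 1e3:.2f} mm")
print(f"Aperture diameter: {aperture_diameter * 1e3:.2f} mm")
print(f"Fresnel number: {fresnel_number:.2f}")
```

With these values the magnified field (1.28 mm wide) is larger than the 1 mm aperture, so the aperture clips the outer part of the field, and the Fresnel number comes out near 4. The far-field approximation is normally associated with Fresnel numbers well below 1, so it may be worth checking how fraunhofer_prop_scaled treats this regime; its internals are not shown in this notebook.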
5. Step-by-Step Diffractogram Formation¶
Let’s visualize each step in the formation of a diffractogram:
Linear Interaction - Light interacts with the sample
Optical Zoom - Magnification by the objective lens
Circular Aperture - Limits the numerical aperture
Fraunhofer Propagation - Far-field propagation to the camera
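The four steps above can be sketched in plain NumPy as follows. This is an illustrative stand-in, not janssen's implementation: it assumes the zoom only rescales the pixel size, the aperture is a hard-edged centered disc, and Fraunhofer propagation is a single centered FFT (constant phase factors and the resampling to the detector pixel size are omitted). All function and variable names are hypothetical.

```python
import numpy as np

def toy_diffractogram(sample, probe, dx, zoom_factor, aperture_diameter):
    """Toy forward model: interaction -> zoom -> aperture -> far field."""
    # 1. Linear interaction: the exit wave is the elementwise product
    exit_wave = sample * probe

    # 2. Optical zoom: same samples, larger pixel size
    dx_zoomed = dx * zoom_factor

    # 3. Circular aperture: zero the field outside a centered disc
    ny, nx = exit_wave.shape
    y = (np.arange(ny) - ny / 2) * dx_zoomed
    x = (np.arange(nx) - nx / 2) * dx_zoomed
    yy, xx = np.meshgrid(y, x, indexing="ij")
    mask = (yy**2 + xx**2) <= (aperture_diameter / 2) ** 2
    apertured = exit_wave * mask

    # 4. Far-field propagation: centered, unitary 2D FFT
    far_field = np.fft.fftshift(
        np.fft.fft2(np.fft.ifftshift(apertured), norm="ortho")
    )
    return np.abs(far_field) ** 2, mask

# Toy usage with a weak random phase object and a unit-amplitude plane wave
n = 256
rng = np.random.default_rng(0)
sample = np.exp(1j * 0.1 * rng.standard_normal((n, n)))
probe = np.ones((n, n), dtype=np.complex128)
intensity, mask = toy_diffractogram(
    sample, probe, dx=0.5e-6, zoom_factor=10.0, aperture_diameter=1e-3
)

print("Intensity shape:", intensity.shape)
print("Fraction of pixels inside aperture:", mask.mean())
```

Because the FFT here is unitary and the toy object has unit amplitude, the total intensity at the camera equals the number of pixels inside the aperture, which is a convenient energy-conservation check for this kind of model.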
In [11]:
# Cut sample at center for step-by-step visualization
center_pixel = num_pixels // 2
half_size = illumination_size // 2
sample_cut = sphere_sample.sample[
center_pixel - half_size : center_pixel + half_size,
center_pixel - half_size : center_pixel + half_size,
]
sample_region = jns.utils.make_sample_function(
sample=sample_cut,
dx=pixel_size,
)
print(f"Sample region shape: {sample_region.sample.shape}")
Sample region shape: (256, 256)
In [12]:
# Step 1: Linear Interaction - Light through sample
after_sample = jns.scopes.linear_interaction(
sample=sample_region,
light=lightwave,
)
print(f"After sample field shape: {after_sample.field.shape}")
print(f"After sample dx: {after_sample.dx * 1e6:.2f} microns")
# Visualize
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
im0 = axes[0].imshow(jnp.abs(sample_region.sample), cmap=cmo.gray)
axes[0].set_title("Sample Region")
scalebar = ScaleBar(sample_region.dx, "m", length_fraction=0.25, color="black")
axes[0].add_artist(scalebar)
axes[0].axis("off")
plt.colorbar(im0, ax=axes[0])
im1 = axes[1].imshow(jnp.abs(after_sample.field) ** 2, cmap=cmo.gray)
axes[1].set_title("Field Intensity After Sample")
scalebar = ScaleBar(after_sample.dx, "m", length_fraction=0.25, color="black")
axes[1].add_artist(scalebar)
axes[1].axis("off")
plt.colorbar(im1, ax=axes[1])
im2 = axes[2].imshow(
jnp.angle(after_sample.field), cmap=cmo.phase, vmin=-jnp.pi, vmax=jnp.pi
)
axes[2].set_title("Field Phase After Sample")
scalebar = ScaleBar(after_sample.dx, "m", length_fraction=0.25, color="black")
axes[2].add_artist(scalebar)
axes[2].axis("off")
plt.colorbar(im2, ax=axes[2])
plt.suptitle("Step 1: Linear Interaction", fontsize=14)
plt.tight_layout()
plt.show()
After sample field shape: (256, 256)
After sample dx: 0.50 microns
In [13]:
# Step 2: Optical Zoom - Magnification
zoomed_wave = jns.prop.optical_zoom(after_sample, zoom_factor)
print(f"Before zoom dx: {after_sample.dx * 1e6:.2f} microns")
print(f"After zoom dx: {zoomed_wave.dx * 1e6:.2f} microns")
print(f"Magnification achieved: {zoomed_wave.dx / after_sample.dx:.1f}x")
# Visualize
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
im0 = axes[0].imshow(jnp.abs(after_sample.field) ** 2, cmap=cmo.gray)
axes[0].set_title(f"Before Zoom (dx={after_sample.dx*1e6:.2f} µm)")
scalebar = ScaleBar(after_sample.dx, "m", length_fraction=0.25, color="black")
axes[0].add_artist(scalebar)
axes[0].axis("off")
plt.colorbar(im0, ax=axes[0])
im1 = axes[1].imshow(jnp.abs(zoomed_wave.field) ** 2, cmap=cmo.gray)
axes[1].set_title(f"After Zoom (dx={zoomed_wave.dx*1e6:.2f} µm)")
scalebar = ScaleBar(zoomed_wave.dx, "m", length_fraction=0.25, color="black")
axes[1].add_artist(scalebar)
axes[1].axis("off")
plt.colorbar(im1, ax=axes[1])
plt.suptitle("Step 2: Optical Zoom (Magnification)", fontsize=14)
plt.tight_layout()
plt.show()
Before zoom dx: 0.50 microns
After zoom dx: 5.00 microns
Magnification achieved: 10.0x
In [14]:
# Step 3: Circular Aperture - NA Limit
after_aperture = jns.optics.circular_aperture(
zoomed_wave,
diameter=aperture_diameter,
)
print(f"Aperture diameter: {aperture_diameter * 1e3:.1f} mm")
# Visualize
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
im0 = axes[0].imshow(jnp.abs(zoomed_wave.field) ** 2, cmap=cmo.gray)
axes[0].set_title("Before Aperture")
scalebar = ScaleBar(zoomed_wave.dx, "m", length_fraction=0.25, color="black")
axes[0].add_artist(scalebar)
axes[0].axis("off")
plt.colorbar(im0, ax=axes[0])
im1 = axes[1].imshow(jnp.abs(after_aperture.field) ** 2, cmap=cmo.gray)
axes[1].set_title("After Circular Aperture")
scalebar = ScaleBar(
after_aperture.dx, "m", length_fraction=0.25, color="black"
)
axes[1].add_artist(scalebar)
axes[1].axis("off")
plt.colorbar(im1, ax=axes[1])
plt.suptitle("Step 3: Circular Aperture", fontsize=14)
plt.tight_layout()
plt.show()
Aperture diameter: 1.0 mm
In [15]:
# Step 4: Fraunhofer Propagation - To Camera Plane
at_camera = jns.prop.fraunhofer_prop_scaled(
after_aperture, travel_distance, output_dx=detector_pixel_size
)
print(f"Propagation distance: {travel_distance * 1e3:.0f} mm")
print(f"Camera plane dx: {at_camera.dx * 1e6:.2f} microns")
# Visualize
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
im0 = axes[0].imshow(
jns.optics.field_intensity(at_camera.field), cmap=cmo.haline
)
axes[0].set_title("Intensity at Camera (Linear)")
scalebar = ScaleBar(at_camera.dx, "m", length_fraction=0.25, color="black")
axes[0].add_artist(scalebar)
axes[0].axis("off")
plt.colorbar(im0, ax=axes[0])
im1 = axes[1].imshow(
jnp.log10(1 + jns.optics.field_intensity(at_camera.field)), cmap=cmo.haline
)
axes[1].set_title("Intensity at Camera (Log)")
scalebar = ScaleBar(at_camera.dx, "m", length_fraction=0.25, color="black")
axes[1].add_artist(scalebar)
axes[1].axis("off")
plt.colorbar(im1, ax=axes[1])
im2 = axes[2].imshow(
jnp.angle(at_camera.field), cmap=cmo.phase, vmin=-jnp.pi, vmax=jnp.pi
)
axes[2].set_title("Phase at Camera")
scalebar = ScaleBar(at_camera.dx, "m", length_fraction=0.25, color="black")
axes[2].add_artist(scalebar)
axes[2].axis("off")
plt.colorbar(im2, ax=axes[2])
plt.suptitle("Step 4: Fraunhofer Propagation to Camera", fontsize=14)
plt.tight_layout()
plt.show()
Propagation distance: 100 mm
Camera plane dx: 16.00 microns
In [16]:
# Sanity check: recompute the camera-plane field, then take its inverse FFT
at_camera = jns.prop.fraunhofer_prop_scaled(
after_aperture, travel_distance, output_dx=detector_pixel_size
)
at_camera_inv = jnp.fft.ifftshift(jnp.fft.ifft2(at_camera.field))
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
im0 = axes[0].imshow(
jns.optics.field_intensity(at_camera_inv), cmap=cmo.haline
)
axes[0].set_title("Inverse FFT of Camera Field - Intensity (Linear)")
scalebar = ScaleBar(at_camera.dx, "m", length_fraction=0.25, color="black")
axes[0].add_artist(scalebar)
axes[0].axis("off")
plt.colorbar(im0, ax=axes[0])
im1 = axes[1].imshow(
jnp.log10(1 + jns.optics.field_intensity(at_camera_inv)), cmap=cmo.haline
)
axes[1].set_title("Inverse FFT of Camera Field - Intensity (Log)")
scalebar = ScaleBar(at_camera.dx, "m", length_fraction=0.25, color="black")
axes[1].add_artist(scalebar)
axes[1].axis("off")
plt.colorbar(im1, ax=axes[1])
im2 = axes[2].imshow(
jnp.angle(at_camera_inv), cmap=cmo.phase, vmin=-jnp.pi, vmax=jnp.pi
)
axes[2].set_title("Inverse FFT of Camera Field - Phase")
scalebar = ScaleBar(at_camera.dx, "m", length_fraction=0.25, color="black")
axes[2].add_artist(scalebar)
axes[2].axis("off")
plt.colorbar(im2, ax=axes[2])
plt.tight_layout()
plt.show()
6. Compare with simple_diffractogram¶
Verify that the step-by-step approach matches the simple_diffractogram function.
In [17]:
# Generate single diffractogram using the combined function
diffractogram = jns.scopes.simple_diffractogram(
sample_cut=sample_region,
lightwave=lightwave,
zoom_factor=zoom_factor,
aperture_diameter=aperture_diameter,
travel_distance=travel_distance,
camera_pixel_size=detector_pixel_size,
)
print(f"Diffractogram shape: {diffractogram.image.shape}")
print(f"Diffractogram dx: {diffractogram.dx * 1e6:.2f} µm")
Diffractogram shape: (256, 256)
Diffractogram dx: 16.00 µm
In [18]:
# Compare manual pipeline with simple_diffractogram
fig, axes = plt.subplots(1, 2, figsize=(14, 6))
# Manual pipeline result
im0 = axes[0].imshow(
jns.optics.field_intensity(at_camera.field), cmap=cmo.haline
)
axes[0].set_title("Manual Pipeline (Step by Step)")
scalebar = ScaleBar(at_camera.dx, "m", length_fraction=0.25, color="black")
axes[0].add_artist(scalebar)
axes[0].axis("off")
plt.colorbar(im0, ax=axes[0])
# Combined function result
im1 = axes[1].imshow(diffractogram.image, cmap=cmo.haline)
axes[1].set_title("simple_diffractogram Result")
scalebar = ScaleBar(diffractogram.dx, "m", length_fraction=0.25, color="black")
axes[1].add_artist(scalebar)
axes[1].axis("off")
plt.colorbar(im1, ax=axes[1])
plt.tight_layout()
plt.show()
# Verify they match
print(
f"Max difference: {jnp.max(jnp.abs(jns.optics.field_intensity(at_camera.field) - diffractogram.image)):.2e}"
)
Max difference: 0.00e+00
7. Full Microscope Simulation - Scanning¶
Create scan positions and run the full microscope simulation.
In [19]:
# Create scan positions centered on a region with spheres
scan_step = 8e-6 # 8 micron step size
scan_pixel = scan_step / sphere_sample.dx
# Center of the sample
scope_center = jnp.array(
[num_pixels // 2, num_pixels // 2]
) # (x, y) in pixels
num_scan_x = 25
num_scan_y = 25
xx, yy = jnp.meshgrid(
jnp.arange(num_scan_x) * scan_pixel - (num_scan_x - 1) * scan_pixel / 2,
jnp.arange(num_scan_y) * scan_pixel - (num_scan_y - 1) * scan_pixel / 2,
)
x_positions = xx + scope_center[0]
y_positions = yy + scope_center[1]
positions = jnp.stack([x_positions.ravel(), y_positions.ravel()], axis=1)
print(f"Scan step: {scan_step * 1e6:.0f} µm ({scan_pixel:.1f} pixels)")
print(f"Number of scan positions: {len(positions)}")
print(f"Scan grid: {num_scan_x} x {num_scan_y}")
print(
f"Total scan area: {(num_scan_x-1) * scan_step * 1e6:.0f} x {(num_scan_y-1) * scan_step * 1e6:.0f} µm"
)
Scan step: 8 µm (16.0 pixels)
Number of scan positions: 625
Scan grid: 25 x 25
Total scan area: 192 x 192 µm
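Ptychography relies on neighboring illumination positions overlapping. As a quick check on the scan defined above, the sketch below estimates the linear overlap between adjacent positions and the total extent of the sample that the scan touches. It treats the full 256-pixel illumination window as the probe footprint, which is an assumption; the effective footprint after the aperture may differ.

```python
pixel_size = 0.5e-6        # m, sample-plane pixel size
illumination_size = 256    # pixels
scan_step = 8e-6           # m
num_scan = 25              # positions per axis

# Probe footprint, assumed to be the full illumination window
probe_width = illumination_size * pixel_size

# Fraction of the probe width shared by two adjacent positions
linear_overlap = 1.0 - scan_step / probe_width

# Distance between first and last position, and total region touched
scan_extent = (num_scan - 1) * scan_step
covered_extent = scan_extent + probe_width
covered_pixels = covered_extent / pixel_size

print(f"Probe width: {probe_width * 1e6:.0f} µm")
print(f"Linear overlap between neighbors: {linear_overlap * 100:.2f} %")
print(f"Scan extent: {scan_extent * 1e6:.0f} µm")
print(f"Covered extent: {covered_extent * 1e6:.0f} µm ({covered_pixels:.0f} pixels)")
```

Under that assumption, adjacent positions share well over 90 % of the probe width, so this scan is densely sampled.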
In [20]:
# Visualize scan positions on sample
fig, ax = plt.subplots(1, 1, figsize=(10, 10))
im = ax.imshow(jnp.abs(sphere_sample.sample), cmap=cmo.gray)
ax.set_title("Sample with Scan Positions")
scalebar = ScaleBar(sphere_sample.dx, "m", length_fraction=0.25, color="black")
ax.add_artist(scalebar)
# Add scan positions as colored dots
scatter = ax.scatter(
positions[:, 0],
positions[:, 1],
c=jnp.arange(len(positions)),
cmap="coolwarm",
s=10,
alpha=0.7,
marker="o",
)
plt.colorbar(scatter, ax=ax, label="Scan position index")
ax.axis("off")
plt.tight_layout()
plt.show()
In [21]:
# Run simple_microscope with all scan positions
positions_meters = positions * sphere_sample.dx
microscope_data = jns.scopes.simple_microscope(
sample=sphere_sample,
positions=positions_meters,
lightwave=lightwave,
zoom_factor=zoom_factor,
aperture_diameter=aperture_diameter,
travel_distance=travel_distance,
camera_pixel_size=detector_pixel_size,
)
print(f"Microscope data shape: {microscope_data.image_data.shape}")
print(f"Number of diffractograms: {microscope_data.image_data.shape[0]}")
print(f"Diffractogram size: {microscope_data.image_data.shape[1:]}")
print(f"Camera pixel size: {microscope_data.dx * 1e6:.2f} µm")
Microscope data shape: (625, 256, 256)
Number of diffractograms: 625
Diffractogram size: (256, 256)
Camera pixel size: 16.00 µm
In [22]:
# Visualize a subset of diffractograms (9 evenly spaced)
fig, axes = plt.subplots(3, 3, figsize=(12, 12))
indices = jnp.linspace(0, len(positions) - 1, 9).astype(int)
for i, ax in enumerate(axes.flat):
idx = int(indices[i])
im = ax.imshow(
jnp.log10(microscope_data.image_data[idx] + 1), cmap=cmo.haline
)
pos = positions[idx]
ax.set_title(f"Pos {idx}: ({pos[0]:.0f}, {pos[1]:.0f}) px")
scalebar = ScaleBar(
microscope_data.dx, "m", length_fraction=0.25, color="black"
)
ax.add_artist(scalebar)
ax.axis("off")
plt.suptitle(
"Selected Diffractograms from Spheres Sample (Log Scale)", fontsize=14
)
plt.tight_layout()
plt.show()
In [23]:
fig, axes = plt.subplots(1, 3, figsize=(18, 6))
axes[0].imshow(jnp.log10(1 + microscope_data.image_data[64]), cmap=cmo.haline)
axes[1].imshow(jnp.log10(1 + microscope_data.image_data[65]), cmap=cmo.haline)
axes[2].imshow(jnp.log10(1 + microscope_data.image_data[66]), cmap=cmo.haline)
plt.show()
print("Diffractogram at index 64 sum:", microscope_data.image_data[64].sum())
print("Diffractogram at index 65 sum:", microscope_data.image_data[65].sum())
print(
"Are they identical?",
jnp.allclose(
microscope_data.image_data[64], microscope_data.image_data[65]
),
)
Diffractogram at index 64 sum: 422.4509078399198
Diffractogram at index 65 sum: 455.16915398686547
Are they identical? False
In [24]:
print("Experimental data shape:", microscope_data.image_data.shape)
print("Experimental positions shape:", microscope_data.positions.shape)
print("First few positions (meters):", microscope_data.positions[:5])
print(
"Position range X:",
microscope_data.positions[:, 0].min(),
"to",
microscope_data.positions[:, 0].max(),
)
print(
"Position range Y:",
microscope_data.positions[:, 1].min(),
"to",
microscope_data.positions[:, 1].max(),
)
Experimental data shape: (625, 256, 256)
Experimental positions shape: (625, 2)
First few positions (meters): [[0.000928 0.000928]
[0.000936 0.000928]
[0.000944 0.000928]
[0.000952 0.000928]
[0.00096 0.000928]]
Position range X: 0.000928 to 0.00112
Position range Y: 0.000928 to 0.00112
8. Ptychographic Reconstruction¶
Now we’ll use the ptychography algorithm to reconstruct both the sample and probe from the diffractogram data.
In [25]:
# Create ptychography parameters (optimization only)
ptycho_params = jns.utils.make_ptychography_params(
camera_pixel_size=detector_pixel_size,
num_iterations=40,
learning_rate=1e-4,
loss_type=0, # 0=mse, 1=mae, 2=poisson
optimizer_type=0, # 0=adam, 1=adagrad, 2=rmsprop, 3=sgd
)
print(f"Camera pixel size: {ptycho_params.camera_pixel_size * 1e6:.1f} µm")
print(f"Learning rate: {ptycho_params.learning_rate}")
print(f"Num iterations: {ptycho_params.num_iterations}")
print(f"Loss type: {ptycho_params.loss_type} (0=mse)")
print(f"Optimizer type: {ptycho_params.optimizer_type} (0=adam)")
Camera pixel size: 16.0 µm
Learning rate: 0.0001
Num iterations: 40
Loss type: 0 (0=mse)
Optimizer type: 0 (0=adam)
In [26]:
# Initialize reconstruction by running microscope model in reverse
initial_reconstruction = jns.invert.init_simple_microscope(
experimental_data=microscope_data,
probe_lightwave=lightwave,
zoom_factor=zoom_factor,
aperture_diameter=aperture_diameter,
travel_distance=travel_distance,
camera_pixel_size=detector_pixel_size,
)
print(f"Initialized reconstruction!")
print(f"Initial sample shape: {initial_reconstruction.sample.sample.shape}")
print(
f"Translated positions shape: {initial_reconstruction.translated_positions.shape}"
)
print(f"Initial MSE loss: {initial_reconstruction.losses[0, 1]:.6f}")
Initialized reconstruction!
Initial sample shape: (896, 896)
Translated positions shape: (625, 2)
Initial MSE loss: 0.234304
In [27]:
# Visualize initial reconstruction (sanity check before optimization)
fig, axes = plt.subplots(1, 2, figsize=(14, 6))
init_amp = jnp.abs(initial_reconstruction.sample.sample)
init_phase = jnp.angle(initial_reconstruction.sample.sample)
# Amplitude
im0 = axes[0].imshow(init_amp, cmap=cmo.gray)
axes[0].set_title("Initial Sample - Amplitude")
scalebar = ScaleBar(
initial_reconstruction.sample.dx, "m", length_fraction=0.25, color="black"
)
axes[0].add_artist(scalebar)
axes[0].axis("off")
plt.colorbar(im0, ax=axes[0], label="Amplitude")
# Phase
im1 = axes[1].imshow(init_phase, cmap=cmo.phase, vmin=-jnp.pi, vmax=jnp.pi)
axes[1].set_title("Initial Sample - Phase")
scalebar = ScaleBar(
initial_reconstruction.sample.dx, "m", length_fraction=0.25, color="black"
)
axes[1].add_artist(scalebar)
axes[1].axis("off")
plt.colorbar(im1, ax=axes[1], label="Phase (rad)")
plt.suptitle(
f"Initial Reconstruction (MSE: {initial_reconstruction.losses[0, 1]:.4f})",
fontsize=14,
)
plt.tight_layout()
plt.show()
In [28]:
center = jnp.astype(0.5 * jnp.asarray(init_amp.shape), jnp.int32)
pad = jnp.astype(
jnp.amin(
initial_reconstruction.translated_positions
/ initial_reconstruction.sample.dx
),
jnp.int32,
)
print(
f"The amplitude at the center of the initial reconstruction is {init_amp[center[0]+2, center[1]+2]:.2e}"
)
print(f"The amplitude at the top left corner is {init_amp[pad, pad]:.2e}")
The amplitude at the center of the initial reconstruction is 6.49e-01
The amplitude at the top left corner is 6.10e-01
In [29]:
# Run ptychographic reconstruction
reconstruction = jns.invert.simple_microscope_ptychography(
experimental_data=microscope_data,
reconstruction=initial_reconstruction,
params=ptycho_params,
)
print(f"Reconstruction complete!")
print(f"Reconstructed sample shape: {reconstruction.sample.sample.shape}")
print(f"Reconstructed probe shape: {reconstruction.lightwave.field.shape}")
print(f"Final loss: {reconstruction.losses[-1, 1]:.6f}")
Reconstruction complete!
Reconstructed sample shape: (896, 896)
Reconstructed probe shape: (256, 256)
Final loss: 0.234056
Visualize Reconstruction Results¶
In [30]:
reconstruction.sample.sample.shape
Out [30]:
(896, 896)
In [31]:
# Visualize reconstructed sample
fig, axes = plt.subplots(1, 2, figsize=(14, 6))
recon_amp = jnp.abs(reconstruction.sample.sample)
recon_phase = jnp.angle(reconstruction.sample.sample)
# Amplitude
im0 = axes[0].imshow(
recon_amp, vmin=recon_amp.min(), vmax=recon_amp.max(), cmap=cmo.gray
)
axes[0].set_title("Reconstructed Sample - Amplitude")
scalebar = ScaleBar(
reconstruction.sample.dx, "m", length_fraction=0.25, color="black"
)
axes[0].add_artist(scalebar)
axes[0].axis("off")
plt.colorbar(im0, ax=axes[0], label="Amplitude")
# Phase
im1 = axes[1].imshow(recon_phase, cmap=cmo.phase, vmin=-jnp.pi, vmax=jnp.pi)
axes[1].set_title("Reconstructed Sample - Phase")
scalebar = ScaleBar(
reconstruction.sample.dx, "m", length_fraction=0.25, color="black"
)
axes[1].add_artist(scalebar)
axes[1].axis("off")
plt.colorbar(im1, ax=axes[1], label="Phase (rad)")
plt.suptitle("Ptychographic Reconstruction - Sample", fontsize=14)
plt.tight_layout()
plt.show()
In [32]:
# Visualize reconstructed probe
fig, axes = plt.subplots(1, 2, figsize=(14, 6))
probe_amp = jnp.abs(reconstruction.lightwave.field)
probe_phase = jnp.angle(reconstruction.lightwave.field)
# Amplitude
im0 = axes[0].imshow(probe_amp, cmap=cmo.haline)
axes[0].set_title("Reconstructed Probe - Amplitude")
scalebar = ScaleBar(
reconstruction.lightwave.dx, "m", length_fraction=0.25, color="black"
)
axes[0].add_artist(scalebar)
axes[0].axis("off")
plt.colorbar(im0, ax=axes[0], label="Amplitude")
# Phase
im1 = axes[1].imshow(probe_phase, cmap=cmo.phase, vmin=-jnp.pi, vmax=jnp.pi)
axes[1].set_title("Reconstructed Probe - Phase")
scalebar = ScaleBar(
reconstruction.lightwave.dx, "m", length_fraction=0.25, color="black"
)
axes[1].add_artist(scalebar)
axes[1].axis("off")
plt.colorbar(im1, ax=axes[1], label="Phase (rad)")
plt.suptitle("Ptychographic Reconstruction - Probe/Lightwave", fontsize=14)
plt.tight_layout()
plt.show()
In [33]:
# Visualize evolution of loss
plt.plot(reconstruction.losses[:, 0], reconstruction.losses[:, 1])
plt.xlabel("Iteration")
plt.ylabel("Loss")
plt.title("Ptychography Convergence")
Out [33]:
Text(0.5, 1.0, 'Ptychography Convergence')
In [34]:
# Print final optimized parameters
print("Final Optimized Parameters:")
print(f" Zoom factor: {reconstruction.zoom_factor:.4f} (true: {zoom_factor})")
print(
f" Aperture diameter: {reconstruction.aperture_diameter * 1e3:.4f} mm (true: {aperture_diameter * 1e3:.1f} mm)"
)
print(
f" Travel distance: {reconstruction.travel_distance * 1e3:.4f} mm (true: {travel_distance * 1e3:.0f} mm)"
)
if reconstruction.aperture_center is not None:
print(f" Aperture center: {reconstruction.aperture_center}")
Final Optimized Parameters:
Zoom factor: 10.0000 (true: 10.0)
Aperture diameter: 1.0000 mm (true: 1.0 mm)
Travel distance: 100.0000 mm (true: 100 mm)
Aperture center: [0. 0.]
In [35]:
print(
"Sample amplitude range:",
jnp.abs(reconstruction.sample.sample).min(),
jnp.abs(reconstruction.sample.sample).max(),
)
print(
"Sample phase range:",
jnp.angle(reconstruction.sample.sample).min(),
jnp.angle(reconstruction.sample.sample).max(),
)
print(
"Probe amplitude range:",
jnp.abs(reconstruction.lightwave.field).min(),
jnp.abs(reconstruction.lightwave.field).max(),
)
print(
"Probe phase range:",
jnp.angle(reconstruction.lightwave.field).min(),
jnp.angle(reconstruction.lightwave.field).max(),
)
Sample amplitude range: 0.0 7.887381589954676
Sample phase range: -3.141566730876553 3.1415664897289357
Probe amplitude range: 1.0 1.0
Probe phase range: 0.0 0.0
9. ePIE Reconstruction¶
Now let’s run the extended PIE (ePIE) algorithm on the same data. ePIE often converges faster than gradient-based optimization for ptychography problems, so it is a useful point of comparison for the reconstruction above.
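For orientation, the sketch below shows the ePIE update for a single scan position as it is commonly described in the literature: the exit wave is propagated to the detector, its magnitude is replaced with the measured one, and the difference is fed back into the object and the probe. This is a minimal NumPy illustration for a plain far-field geometry, not the code inside janssen. The function name is hypothetical, and alpha and beta simply mirror the parameter names used in the next cell.

```python
import numpy as np

def epie_update(obj_patch, probe, measured_intensity, alpha=1.0, beta=0.0):
    """One ePIE update for a single scan position (far-field geometry)."""
    exit_wave = probe * obj_patch
    far_field = np.fft.fft2(exit_wave, norm="ortho")

    # Modulus constraint: keep the phase, replace the magnitude with the measurement
    corrected = np.sqrt(measured_intensity) * np.exp(1j * np.angle(far_field))
    new_exit_wave = np.fft.ifft2(corrected, norm="ortho")
    diff = new_exit_wave - exit_wave

    # Object and probe updates, each weighted by the conjugate of the other factor
    new_obj = obj_patch + alpha * np.conj(probe) / np.max(np.abs(probe)) ** 2 * diff
    new_probe = probe + beta * np.conj(obj_patch) / np.max(np.abs(obj_patch)) ** 2 * diff
    return new_obj, new_probe

def magnitude_error(obj_patch, probe, measured_intensity):
    """Mismatch between simulated and measured far-field magnitudes."""
    simulated = np.abs(np.fft.fft2(probe * obj_patch, norm="ortho"))
    return np.linalg.norm(simulated - np.sqrt(measured_intensity))

# Toy check: a pure phase object under a unit-amplitude probe
rng = np.random.default_rng(0)
n = 32
true_obj = np.exp(1j * rng.uniform(-0.5, 0.5, (n, n)))
probe = np.ones((n, n), dtype=np.complex128)
measured = np.abs(np.fft.fft2(probe * true_obj, norm="ortho")) ** 2

guess = np.ones((n, n), dtype=np.complex128)
error_before = magnitude_error(guess, probe, measured)
new_obj, new_probe = epie_update(guess, probe, measured, alpha=1.0, beta=0.0)
error_after = magnitude_error(new_obj, probe, measured)

print(f"Magnitude error before update: {error_before:.3e}")
print(f"Magnitude error after update:  {error_after:.3e}")
```

With beta=0 the probe is left untouched, which matches the frozen-probe setting used below. In a full reconstruction this update is repeated over all scan positions, and the overlap between positions is what constrains the object phase.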
In [36]:
# ePIE parameters - using dedicated EpieParams type
# effective_dx: desired reconstruction pixel size (user choice)
# This determines the resolution of the reconstructed sample
epie_params = jns.utils.make_epie_params(
effective_dx=0.5e-6, # 0.5 µm pixels (user choice for resolution)
num_iterations=60, # ePIE iterations (sweeps over all positions)
alpha=1.0, # Object update step size (0.5-1.0)
beta=0.0, # Probe update step size (0 = freeze probe)
padding=32, # Extra padding around scan region
)
print(f"ePIE Parameters:")
print(f" Effective dx: {float(epie_params.effective_dx) * 1e6:.2f} µm")
print(f" Num iterations: {int(epie_params.num_iterations)}")
print(f" Alpha (object step): {float(epie_params.alpha)}")
print(f" Beta (probe step): {float(epie_params.beta)}")
print(f" Padding: {int(epie_params.padding)} pixels")
ePIE Parameters:
Effective dx: 0.50 µm
Num iterations: 60
Alpha (object step): 1.0
Beta (probe step): 0.0
Padding: 32 pixels
In [37]:
# Initialize ePIE data directly using init_simple_epie
# This preprocesses the data: rescales camera images, creates probe with aperture,
# and converts positions to pixels centered at (0,0)
epie_data = jns.invert.init_simple_epie(
experimental_data=microscope_data,
effective_dx=float(epie_params.effective_dx),
wavelength=wavelength,
zoom_factor=zoom_factor,
aperture_diameter=aperture_diameter,
travel_distance=travel_distance,
camera_pixel_size=detector_pixel_size,
padding=int(epie_params.padding),
)
print(f"ePIE Data initialized:")
print(f" Sample shape: {epie_data.sample.shape}")
print(f" Probe shape: {epie_data.probe.shape}")
print(f" Diffraction patterns shape: {epie_data.diffraction_patterns.shape}")
print(f" Positions shape: {epie_data.positions.shape}")
print(f" Effective dx: {float(epie_data.effective_dx) * 1e6:.2f} µm")
print(
f" Positions centered at: ({float(jnp.mean(epie_data.positions[:, 0])):.1f}, {float(jnp.mean(epie_data.positions[:, 1])):.1f}) px"
)
ePIE Data initialized:
Sample shape: (649, 649)
Probe shape: (649, 649)
Diffraction patterns shape: (625, 649, 649)
Positions shape: (625, 2)
Effective dx: 0.50 µm
Positions centered at: (0.0, 0.0) px
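The position preprocessing described above can be pictured as a unit conversion followed by a recentering. The sketch below assumes positions are given in meters and that centering subtracts the mean. `positions_to_centered_pixels` and the 5 µm scan step are hypothetical, and `init_simple_epie` may differ in detail.

```python
import numpy as np


def positions_to_centered_pixels(positions_m, effective_dx):
    """Convert scan positions from meters to pixels centered at (0, 0)."""
    pixels = np.asarray(positions_m, dtype=float) / effective_dx
    # Subtract the mean so the scan is centered on the origin (assumption)
    return pixels - pixels.mean(axis=0, keepdims=True)


# Hypothetical 25 x 25 raster scan (625 positions) with a 5 µm step
step = 5e-6
axis = np.arange(25) * step
xx, yy = np.meshgrid(axis, axis)
positions_m = np.stack([xx.ravel(), yy.ravel()], axis=1)

positions_px = positions_to_centered_pixels(positions_m, effective_dx=0.5e-6)
print(positions_px.shape)
print(positions_px.mean(axis=0))
```

The printed mean should be zero up to floating-point rounding, matching the "Positions centered at" check in the cell above.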
In [38]:
# Run ePIE core algorithm (pure pixel-space computation)
from janssen.invert.ptychography import _sm_epie_core
iterations = jnp.arange(int(epie_params.num_iterations), dtype=jnp.int64)
epie_result = _sm_epie_core(
epie_data=epie_data,
iterations=iterations,
alpha=float(epie_params.alpha),
beta=float(epie_params.beta),
)
epie_result.sample.block_until_ready()
print("ePIE Reconstruction complete!")
print(f"Final sample shape: {epie_result.sample.shape}")
print(f"Final probe shape: {epie_result.probe.shape}")
ePIE Reconstruction complete!
Final sample shape: (649, 649)
Final probe shape: (649, 649)
In [39]:
# Visualize ePIE reconstructed sample
fig, axes = plt.subplots(1, 2, figsize=(14, 6))
pad = 70 # Trim 70 px from each edge before display
epie_amp = jnp.abs(epie_result.sample[pad:-pad, pad:-pad])
epie_phase = jnp.angle(epie_result.sample[pad:-pad, pad:-pad])
# Amplitude
im0 = axes[0].imshow(
epie_amp, vmin=epie_amp.min(), vmax=epie_amp.max(), cmap=cmo.gray
)
axes[0].set_title("ePIE Reconstructed Sample - Amplitude")
scalebar = ScaleBar(
float(epie_params.effective_dx), "m", length_fraction=0.25, color="black"
)
axes[0].add_artist(scalebar)
axes[0].axis("off")
plt.colorbar(im0, ax=axes[0], label="Amplitude")
# Phase
im1 = axes[1].imshow(
epie_phase, cmap=cmo.phase, vmin=epie_phase.min(), vmax=epie_phase.max()
)
axes[1].set_title("ePIE Reconstructed Sample - Phase")
scalebar = ScaleBar(
float(epie_params.effective_dx), "m", length_fraction=0.25, color="black"
)
axes[1].add_artist(scalebar)
axes[1].axis("off")
plt.colorbar(im1, ax=axes[1], label="Phase (rad)")
plt.suptitle("ePIE Ptychographic Reconstruction - Sample", fontsize=14)
plt.tight_layout()
plt.show()
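A ptychographic reconstruction is only determined up to a constant global phase, so the absolute values on the phase colorbar carry no physical meaning. A minimal sketch of removing that offset before plotting follows. `remove_global_phase` is a hypothetical helper, not part of janssen, and it uses NumPy (the same calls exist under `jnp`).

```python
import numpy as np


def remove_global_phase(field):
    """Rotate a complex field so that its mean value has zero phase."""
    offset = np.angle(np.mean(field))
    return field * np.exp(-1j * offset)


# Synthetic example: a weak phase object with an arbitrary 0.7 rad offset
rng = np.random.default_rng(0)
weak_phase = 0.05 * rng.standard_normal((16, 16))
field = np.exp(1j * (weak_phase + 0.7))

aligned = remove_global_phase(field)
```

The amplitude is untouched, and only a constant is subtracted from the phase. This estimate becomes unreliable when the mean of the field is close to zero, for example with strongly wrapped phase.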
Compare Gradient-Based vs ePIE Reconstruction
In [40]:
# Convergence plot for gradient-based method only
# Note: ePIE doesn't track per-iteration losses in this implementation
fig, ax = plt.subplots(1, 1, figsize=(10, 6))
# Gradient-based (Adam)
ax.plot(
reconstruction.losses[:, 0],
reconstruction.losses[:, 1],
"b-",
linewidth=2,
label=f"Adam ({int(ptycho_params.num_iterations)} iters)",
)
ax.set_xlabel("Iteration", fontsize=12)
ax.set_ylabel("MSE Loss", fontsize=12)
ax.set_title("Reconstruction Convergence: Gradient-Based (Adam)", fontsize=14)
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)
ax.set_yscale("log")
plt.tight_layout()
plt.show()
print(
f"\nFinal Gradient-Based (Adam) Loss: {reconstruction.losses[-1, 1]:.6f}"
)
Final Gradient-Based (Adam) Loss: 0.234056
In [41]:
# Side-by-side comparison of reconstructed samples
fig, axes = plt.subplots(2, 2, figsize=(14, 12))
crop = slice(192, -192)
# Gradient-based amplitude
im00 = axes[0, 0].imshow(
jnp.abs(reconstruction.sample.sample)[crop, crop], cmap=cmo.gray
)
axes[0, 0].set_title("Gradient (Adam) - Amplitude")
scalebar = ScaleBar(
reconstruction.sample.dx, "m", length_fraction=0.25, color="black"
)
axes[0, 0].add_artist(scalebar)
axes[0, 0].axis("off")
plt.colorbar(im00, ax=axes[0, 0])
# ePIE amplitude - note: epie_result.sample is the raw array, use epie_params.effective_dx for scale
im01 = axes[0, 1].imshow(jnp.abs(epie_result.sample), cmap=cmo.gray)
axes[0, 1].set_title("ePIE - Amplitude")
scalebar = ScaleBar(
float(epie_params.effective_dx), "m", length_fraction=0.25, color="black"
)
axes[0, 1].add_artist(scalebar)
axes[0, 1].axis("off")
plt.colorbar(im01, ax=axes[0, 1])
# Gradient-based phase
im10 = axes[1, 0].imshow(
jnp.angle(reconstruction.sample.sample)[crop, crop],
cmap=cmo.phase,
vmin=-jnp.pi,
vmax=jnp.pi,
)
axes[1, 0].set_title("Gradient (Adam) - Phase")
scalebar = ScaleBar(
reconstruction.sample.dx, "m", length_fraction=0.25, color="black"
)
axes[1, 0].add_artist(scalebar)
axes[1, 0].axis("off")
plt.colorbar(im10, ax=axes[1, 0])
# ePIE phase
im11 = axes[1, 1].imshow(
jnp.angle(epie_result.sample), cmap=cmo.phase, vmin=-jnp.pi, vmax=jnp.pi
)
axes[1, 1].set_title("ePIE - Phase")
scalebar = ScaleBar(
float(epie_params.effective_dx), "m", length_fraction=0.25, color="black"
)
axes[1, 1].add_artist(scalebar)
axes[1, 1].axis("off")
plt.colorbar(im11, ax=axes[1, 1])
plt.suptitle(
"Sample Reconstruction Comparison: Gradient-Based vs ePIE", fontsize=14
)
plt.tight_layout()
plt.show()
print(f"\nePIE sample shape: {epie_result.sample.shape}")
print(f"Adam sample shape: {reconstruction.sample.sample.shape}")
ePIE sample shape: (649, 649)
Adam sample shape: (896, 896)
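Because the two reconstructions have different array shapes (649 x 649 versus 896 x 896 above), a like-for-like comparison needs both cropped to the same region. The sketch below center-crops two arrays to a common size. It assumes both share the same pixel size and are centered on the same point, which is not verified here, and `center_crop` is a hypothetical helper.

```python
import numpy as np


def center_crop(arr, size):
    """Return the central size x size window of a 2D array."""
    rows, cols = arr.shape
    if size > min(rows, cols):
        raise ValueError("crop size exceeds array dimensions")
    r0 = (rows - size) // 2
    c0 = (cols - size) // 2
    return arr[r0:r0 + size, c0:c0 + size]


# Stand-ins with the shapes printed above
epie_like = np.zeros((649, 649), dtype=complex)
adam_like = np.zeros((896, 896), dtype=complex)

common = 512  # 896 - 2 * 192, matching the crop used for the Adam result
epie_crop = center_crop(epie_like, common)
adam_crop = center_crop(adam_like, common)
```

If the pixel sizes differ, one of the arrays would need resampling before cropping.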
In [42]:
# Compare reconstructed probes
fig, axes = plt.subplots(2, 2, figsize=(14, 12))
# Gradient-based probe amplitude
im00 = axes[0, 0].imshow(
jnp.abs(reconstruction.lightwave.field), cmap=cmo.haline
)
axes[0, 0].set_title("Gradient (Adam) - Probe Amplitude")
scalebar = ScaleBar(
reconstruction.lightwave.dx, "m", length_fraction=0.25, color="black"
)
axes[0, 0].add_artist(scalebar)
axes[0, 0].axis("off")
plt.colorbar(im00, ax=axes[0, 0])
# ePIE probe amplitude - note: epie_result.probe is the raw array
im01 = axes[0, 1].imshow(jnp.abs(epie_result.probe), cmap=cmo.haline)
axes[0, 1].set_title("ePIE - Probe Amplitude")
scalebar = ScaleBar(
float(epie_params.effective_dx), "m", length_fraction=0.25, color="black"
)
axes[0, 1].add_artist(scalebar)
axes[0, 1].axis("off")
plt.colorbar(im01, ax=axes[0, 1])
# Gradient-based probe phase
im10 = axes[1, 0].imshow(
jnp.angle(reconstruction.lightwave.field),
cmap=cmo.phase,
vmin=-jnp.pi,
vmax=jnp.pi,
)
axes[1, 0].set_title("Gradient (Adam) - Probe Phase")
scalebar = ScaleBar(
reconstruction.lightwave.dx, "m", length_fraction=0.25, color="black"
)
axes[1, 0].add_artist(scalebar)
axes[1, 0].axis("off")
plt.colorbar(im10, ax=axes[1, 0])
# ePIE probe phase
im11 = axes[1, 1].imshow(
jnp.angle(epie_result.probe), cmap=cmo.phase, vmin=-jnp.pi, vmax=jnp.pi
)
axes[1, 1].set_title("ePIE - Probe Phase")
scalebar = ScaleBar(
float(epie_params.effective_dx), "m", length_fraction=0.25, color="black"
)
axes[1, 1].add_artist(scalebar)
axes[1, 1].axis("off")
plt.colorbar(im11, ax=axes[1, 1])
plt.suptitle(
"Probe Reconstruction Comparison: Gradient-Based vs ePIE", fontsize=14
)
plt.tight_layout()
plt.show()
print(f"\nePIE probe shape: {epie_result.probe.shape}")
print(f"Adam probe shape: {reconstruction.lightwave.field.shape}")
print(
f"\nNote: ePIE probe update controlled by beta={float(epie_params.beta)} (0=frozen)."
)
ePIE probe shape: (649, 649)
Adam probe shape: (256, 256)
Note: ePIE probe update controlled by beta=0.0 (0=frozen).