Package Organization
Overview
Janssen is a focused library for optical microscopy and ptychography, split from the original Ptyrodactyl project. The package is organized into four main modules: utils for common utilities, simul for forward models, lenses for lens implementations, and invert for reconstruction algorithms. It follows a clean, hierarchical structure optimized for optical microscopy applications.
Module Structure
janssen.utils
Common utilities and shared data structures used throughout the package.
Key Components:
Data Types & Structures
Type definitions and common data structures
Decorators for JAX transformations
Shared utility functions
janssen.simul
The forward simulation module for optical microscopy, providing differentiable implementations of optical elements and propagation.
Key Components:
Optical Elements
apertures.py: Circular, rectangular, and custom aperture functions
elements.py: Optical element transformations (beam splitters, waveplates)
Lens phase transformations and physical lens calculations (thickness, phase profiles) live in janssen.lenses (lenses.py, lens_optics.py)
Propagation & Simulation
microscope.py: Microscopy simulation pipelines
helper.py: Helper functions for optical propagation
Fresnel and angular spectrum propagation methods
Wavefront manipulation utilities
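To make the angular spectrum method concrete, here is a minimal standalone sketch in JAX of an aperture followed by free-space propagation. The helper names `circular_aperture_mask` and `angular_spectrum_propagate` are hypothetical, not janssen.simul's API; the sketch assumes a square grid, a single length unit (metres) for all quantities, and simply discards evanescent components.

```python
import jax.numpy as jnp


def circular_aperture_mask(n, dx, radius):
    """Binary circular pupil on an n x n grid with pixel pitch dx (hypothetical helper)."""
    coords = (jnp.arange(n) - n / 2) * dx
    xx, yy = jnp.meshgrid(coords, coords)
    return (xx**2 + yy**2 <= radius**2).astype(jnp.float32)


def angular_spectrum_propagate(field, dx, wavelength, distance):
    """Propagate a square complex field over `distance` (hypothetical helper).

    All lengths share one unit. Evanescent components are discarded.
    """
    n = field.shape[0]
    # Spatial frequencies in cycles per unit length, in jnp.fft ordering
    fx = jnp.fft.fftfreq(n, d=dx)
    fxx, fyy = jnp.meshgrid(fx, fx)
    # (kz / 2pi)^2; negative values correspond to evanescent waves
    kz_sq = (1.0 / wavelength) ** 2 - fxx**2 - fyy**2
    propagating = kz_sq > 0
    kz = 2 * jnp.pi * jnp.sqrt(jnp.where(propagating, kz_sq, 0.0))
    # Transfer function H = exp(i kz z) on propagating components, 0 elsewhere
    transfer = jnp.where(propagating, jnp.exp(1j * kz * distance), 0.0)
    return jnp.fft.ifft2(jnp.fft.fft2(field) * transfer)


n, dx, wavelength = 256, 1e-6, 500e-9  # 1 um pixels, 500 nm light
field = circular_aperture_mask(n, dx, radius=20e-6).astype(jnp.complex64)
out = angular_spectrum_propagate(field, dx, wavelength, distance=1e-3)
```

Because every sampled frequency propagates at this pixel pitch and wavelength, the transfer function has unit modulus and total power is conserved; the Fresnel method differs only in using the paraxial approximation of kz.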
janssen.lenses
Dedicated module for lens implementations and optical calculations.
Key Components:
Lens phase transformations
Physical lens calculations
Aberration modeling
Optical element interfaces
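A common way to model these pieces is a thin lens whose focal length comes from the lensmaker's equation 1/f = (n - 1)(1/R1 - 1/R2) and whose effect on a wavefront is the quadratic phase factor exp(-i k r² / (2f)). The sketch below shows both; the function names are hypothetical, not janssen.lenses' API, and the sign assumes the common convention in which a wave travelling in +z carries exp(+ikz), so a converging lens has f > 0.

```python
import jax.numpy as jnp


def lensmaker_focal_length(n_index, r1, r2):
    """Thin-lens focal length from 1/f = (n - 1)(1/R1 - 1/R2) (hypothetical helper)."""
    return 1.0 / ((n_index - 1.0) * (1.0 / r1 - 1.0 / r2))


def thin_lens_phase(n, dx, wavelength, focal_length):
    """Transmission exp(-i k r^2 / (2 f)) of an ideal thin lens (hypothetical helper)."""
    k = 2 * jnp.pi / wavelength
    coords = (jnp.arange(n) - n / 2) * dx
    xx, yy = jnp.meshgrid(coords, coords)
    return jnp.exp(-1j * k * (xx**2 + yy**2) / (2 * focal_length))


# Symmetric biconvex lens (R1 > 0, R2 < 0 in the usual sign convention), n = 1.5
f = lensmaker_focal_length(1.5, 0.1, -0.1)
lens = thin_lens_phase(64, 1e-6, 500e-9, f)
```

Multiplying a field by `lens` and then propagating it a distance f brings a plane wave to a focus, which is the basic building block of the microscope pipelines in janssen.simul.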
janssen.invert
The reconstruction module containing phase retrieval algorithms and optimization routines.
Key Components:
Phase Retrieval Algorithms
ptychography.py: Ptychographic reconstruction algorithms
engine.py: Core reconstruction engine
Single-slice and multi-slice ptychography
Position-corrected algorithms
Multi-modal probe reconstruction
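Single-slice ptychographic reconstruction inverts a forward model in which each measurement is the far-field intensity I_j = |F{P · O_j}|² of the probe P multiplied by the object patch O_j illuminated at scan position j. A minimal sketch of that forward model, assuming integer-pixel scan positions and a Fraunhofer (FFT) detector; `ptycho_forward` and its signature are hypothetical, not janssen.invert's API.

```python
import jax
import jax.numpy as jnp


def ptycho_forward(obj, probe, positions):
    """Far-field diffraction intensities, one per scan position (hypothetical helper).

    obj:       complex object transmission, shape (H, W)
    probe:     complex probe, shape (h, w)
    positions: integer top-left corners of each illuminated patch, shape (J, 2)
    returns:   intensities, shape (J, h, w)
    """
    h, w = probe.shape

    def one_position(pos):
        # Object patch under the probe (dynamic_slice clamps out-of-range starts)
        patch = jax.lax.dynamic_slice(obj, (pos[0], pos[1]), (h, w))
        exit_wave = probe * patch
        # Far field = 2D FFT of the exit wave; the detector records |.|^2
        return jnp.abs(jnp.fft.fft2(exit_wave, norm="ortho")) ** 2

    # Evaluate all scan positions in one vectorized call
    return jax.vmap(one_position)(positions)


obj = jnp.ones((64, 64), dtype=jnp.complex64)
probe = jnp.ones((16, 16), dtype=jnp.complex64)
positions = jnp.array([[0, 0], [8, 8], [16, 24]])
intensities = ptycho_forward(obj, probe, positions)
```

Since the model is differentiable end to end, a reconstruction can compare `intensities` with measured patterns and update the object and probe by gradient descent.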
Optimization
optimizers.py: Complex-valued optimizers with Wirtinger derivatives
loss_functions.py: Loss functions for phase retrieval
Adam, AdaGrad, RMSProp, and SGD implementations
Learning rate scheduling
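JAX can differentiate a real-valued loss with respect to complex parameters directly. For such a loss, `jax.grad` returns the complex conjugate of the steepest-ascent direction 2∂L/∂z*, so a Wirtinger-style update conjugates the gradient before stepping. The sketch below is plain gradient descent with a fixed step size, not one of janssen's optimizers; it recovers a complex vector from its Fourier transform, and the names `loss` and `wirtinger_step` are hypothetical.

```python
import jax
import jax.numpy as jnp


def loss(z, data):
    # Real-valued least-squares misfit in the (unitary) Fourier domain;
    # re^2 + im^2 keeps the loss smooth even where the residual is zero
    residual = jnp.fft.fft(z, norm="ortho") - data
    return jnp.sum(jnp.real(residual) ** 2 + jnp.imag(residual) ** 2)


@jax.jit
def wirtinger_step(z, data, lr):
    g = jax.grad(loss)(z, data)
    # For a real loss of complex z, jax.grad returns conj(2 dL/dz*);
    # conjugating recovers the steepest-ascent direction 2 dL/dz*
    return z - lr * jnp.conj(g)


k1, k2 = jax.random.split(jax.random.PRNGKey(0))
true_z = jax.random.normal(k1, (32,)) + 1j * jax.random.normal(k2, (32,))
data = jnp.fft.fft(true_z, norm="ortho")

z = jnp.zeros(32, dtype=jnp.complex64)
for _ in range(50):
    z = wirtinger_step(z, data, 0.25)
```

Here the unitary FFT makes the loss equal to ||z - true_z||², so each step with lr = 0.25 halves the error; adaptive optimizers such as Adam replace the fixed `lr` with per-parameter scaling.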
Design Principles
1. JAX-First Architecture
All functions are designed to be:
Differentiable: Full support for jax.grad
JIT-compilable: Optimized with jax.jit
Vectorizable: Compatible with jax.vmap
Device-agnostic: Run on CPU, GPU, or TPU
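All four properties can be exercised on a single pure function. In this sketch the toy model `farfield_power` (a hypothetical name, not part of janssen) returns the power in the zero-frequency far-field bin of a pure phase screen; it is differentiated, compiled, and batched with standard JAX transformations, and the same code runs unchanged on whichever backend `jax.devices()` reports.

```python
import jax
import jax.numpy as jnp


def farfield_power(phase):
    """Power in the zero-frequency far-field bin of a pure phase screen."""
    field = jnp.exp(1j * phase)
    spectrum = jnp.fft.fft2(field, norm="ortho")
    return jnp.abs(spectrum[0, 0]) ** 2


grad_fn = jax.jit(jax.grad(farfield_power))  # differentiable and JIT-compiled
batched = jax.jit(jax.vmap(farfield_power))  # vectorized over a leading batch axis

g = grad_fn(jnp.zeros((32, 32)))         # gradient, same shape as the phase
p = batched(jnp.zeros((4, 32, 32)))      # one power value per phase screen
```

A flat phase puts all power in the zero-frequency bin, which is a maximum, so the gradient there is zero.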
2. Type Safety
Comprehensive type hints using jaxtyping
Runtime type checking with beartype
Clear array shape specifications
3. Functional Programming
Pure functions without side effects
Immutable data structures
Composable operations
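These three properties can be seen in a few lines of JAX. In this sketch (helper names are hypothetical), `zero_dc` returns an updated copy instead of mutating its input, and the two pure functions compose into a pipeline by ordinary nesting.

```python
import jax.numpy as jnp


def zero_dc(spectrum):
    """Pure function: returns a new array and leaves its argument untouched."""
    # JAX arrays are immutable; .at[...].set(...) returns an updated copy
    return spectrum.at[0, 0].set(0.0)


def normalize(field):
    """Scale a field to unit total power."""
    return field / jnp.sqrt(jnp.sum(jnp.abs(field) ** 2))


def pipeline(x):
    # Pure functions compose by ordinary nesting
    return normalize(zero_dc(x))


spectrum = jnp.ones((4, 4), dtype=jnp.complex64)
result = pipeline(spectrum)
```

Because neither step has side effects, `pipeline` itself can be passed to `jax.jit`, `jax.grad`, or `jax.vmap` without changes.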
4. Optical Focus
Optimized specifically for optical microscopy:
Wavelength-dependent calculations
Complex wavefront representations
Physical optics simulations
File Organization
The package structure is organized for clarity and maintainability:
```text
src/janssen/
├── __init__.py            # Top-level exports
├── utils/
│   ├── __init__.py        # Utils module exports
│   ├── types.py           # Shared type definitions
│   └── decorators.py      # JAX decorators and utilities
├── simul/
│   ├── __init__.py        # Simulation module exports
│   ├── apertures.py       # Aperture functions
│   ├── elements.py        # Optical elements
│   ├── helper.py          # Utility functions
│   └── microscope.py      # Microscopy simulations
├── lenses/
│   ├── __init__.py        # Lenses module exports
│   ├── lens_optics.py     # Lens calculations
│   └── lenses.py          # Lens implementations
└── invert/
    ├── __init__.py        # Invert module exports
    ├── engine.py          # Reconstruction engine
    ├── ptychography.py    # Ptychographic algorithms
    ├── optimizers.py      # Optimization routines
    └── loss_functions.py  # Loss function definitions
```
Import Patterns
Public API Usage
Users should import from the four main modules:
```python
# Import from main modules
from janssen.simul import microscope, apertures, elements
from janssen.lenses import lenses, lens_optics
from janssen.invert import ptychography, optimizers
from janssen.utils import types, decorators

# Import entire modules
import janssen.simul as sim
import janssen.lenses as lens
import janssen.invert as inv
import janssen.utils as utils
```
Internal Implementation
The __init__.py files handle internal imports and expose a clean API:
```python
# simul/__init__.py example
from .apertures import (
    circular_aperture,
    rectangular_aperture,
)
from .elements import (
    apply_aperture,
    apply_beam_splitter,
)
from .microscope import (
    simple_microscope,
    scanning_microscope,
)
# ... etc
```
Best Practices
1. Use JAX Transformations
Leverage JAX’s powerful transformations:
```python
import jax
from janssen.simul import microscope


# JIT compilation for performance
@jax.jit
def simulate(wavefront, sample):
    return microscope.forward_model(wavefront, sample)
```
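2. Automatic Differentiation and Vectorization
```python
import jax

# Automatic differentiation for optimization
# (loss_function: any real-valued scalar loss, e.g. from janssen.invert.loss_functions)
grad_fn = jax.grad(loss_function)

# Vectorization: map simulate over a batch of wavefronts (axis 0),
# sharing a single sample across the batch
batched_simulate = jax.vmap(simulate, in_axes=(0, None))
# Map stem_4d over its third argument only; the first two are shared
batched_stem = jax.vmap(stem_4d, in_axes=(None, None, 0))
```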
3. Type Annotations
Use type hints for clarity:
```python
from jaxtyping import Array, Complex


def propagate_wavefront(
    field: Complex[Array, "H W"],
    distance: float,
    wavelength: float,
) -> Complex[Array, "H W"]:
    # fresnel_propagate: a Fresnel propagation routine such as those in
    # janssen.simul (import not shown)
    return fresnel_propagate(field, distance, wavelength)
```
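4. Composable Operations
Build complex operations from simple functions:
```python
import janssen.invert as inv
import janssen.simul as sim


# Compose multiple operations (schematic: substitute the forward-model and
# reconstruction functions actually exported by these modules)
def full_reconstruction(raw_data, initial_guess, probe):
    # Apply forward model to the starting estimate
    simulated = sim.microscope(initial_guess, probe)
    # Reconstruct using ptychography, starting from the same estimate
    result = inv.ptychography(raw_data, initial_guess)
    return result, simulated
```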
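Performance Considerations
Memory Management
Use jax.checkpoint for memory-intensive operations
Leverage jax.lax.scan for sequential operations
Prefer functional updates with .at[].set() for large arrays; JAX arrays are immutable, but XLA can perform these updates in place inside jit-compiled code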
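Multi-slice propagation is a natural fit for these tools: `jax.lax.scan` runs the slice loop without unrolling it at compile time, and wrapping the step in `jax.checkpoint` recomputes per-slice intermediates during the backward pass instead of storing them. A minimal sketch, assuming pure phase slices, equal slice spacing, and a paraxial Fresnel transfer function with the constant exp(ikΔz) phase dropped; the function names are hypothetical.

```python
import jax
import jax.numpy as jnp


def fresnel_transfer(n, dx, wavelength, dz):
    # Paraxial transfer function exp(-i pi lambda dz (fx^2 + fy^2))
    fx = jnp.fft.fftfreq(n, d=dx)
    fxx, fyy = jnp.meshgrid(fx, fx)
    return jnp.exp(-1j * jnp.pi * wavelength * dz * (fxx**2 + fyy**2))


def multislice(probe, slice_phases, transfer):
    """Pass `probe` through each phase slice, propagating between slices."""

    @jax.checkpoint  # recompute this step on the backward pass instead of storing it
    def step(field, phase):
        field = field * jnp.exp(1j * phase)  # thin-slice transmission
        field = jnp.fft.ifft2(jnp.fft.fft2(field) * transfer)  # propagate dz
        return field, None

    exit_wave, _ = jax.lax.scan(step, probe, slice_phases)
    return exit_wave


n = 64
transfer = fresnel_transfer(n, dx=1e-6, wavelength=500e-9, dz=5e-6)
probe = jnp.ones((n, n), dtype=jnp.complex64)
phases = jnp.zeros((10, n, n))  # 10 slices
exit_wave = multislice(probe, phases, transfer)

# Gradients flow through the scan; checkpointing bounds the stored activations
g = jax.grad(lambda ph: jnp.sum(jnp.abs(multislice(probe, ph, transfer)) ** 2))(phases)
```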
Parallelization
Use jax.pmap for data parallelism
Implement sharding strategies for large datasets
Utilize device mesh for distributed computing
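A minimal sketch of a sharding strategy over a device mesh using `jax.sharding`: a batch of fields is split along its leading axis across whatever devices are available (a single CPU device also works), and a jit-compiled function then runs where the data lives. The batch size is chosen as a multiple of the device count, an assumption of this sketch, and `farfield_intensity` is a hypothetical name.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1D device mesh with a single named axis, "batch"
mesh = Mesh(np.array(jax.devices()), axis_names=("batch",))
batch_sharding = NamedSharding(mesh, P("batch"))


@jax.jit
def farfield_intensity(fields):
    # Per-field far-field intensity; batch entries are independent work
    return jnp.abs(jnp.fft.fft2(fields, norm="ortho")) ** 2


n_dev = len(jax.devices())
fields = jnp.ones((4 * n_dev, 32, 32), dtype=jnp.complex64)
sharded = jax.device_put(fields, batch_sharding)  # split the batch across devices
out = farfield_intensity(sharded)  # the compiled computation follows the input sharding
```

This jit-plus-sharding style expresses the same data parallelism as `jax.pmap` while letting one program scale from a laptop to a multi-device mesh.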
Optimization
JIT-compile hot paths
Batch operations with vmap
Use appropriate precision (float32 vs float64); JAX defaults to float32, and float64 requires enabling the jax_enable_x64 flag
Extension Points
The package is designed to be extensible:
Custom Loss Functions: Implement new loss functions following the pattern in invert.loss_functions
New Optimizers: Add optimizers with Wirtinger derivative support
Additional Reconstructions: Build on base reconstruction algorithms in invert.ptychography
Custom Optical Elements: Add new elements in simul.elements
Custom Workflows: Combine existing functions for specific use cases
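As an illustration of a custom loss, the amplitude misfit widely used in phase retrieval compares the modulus of the modeled far field with the square root of the measured intensities. The name `amplitude_loss` and its signature are hypothetical; the conventions actually used in invert.loss_functions may differ.

```python
import jax
import jax.numpy as jnp


def amplitude_loss(exit_wave, measured_intensity, eps=1e-12):
    """Mean squared misfit between modeled and measured far-field amplitudes."""
    spectrum = jnp.fft.fft2(exit_wave, norm="ortho")
    # re^2 + im^2 and eps inside the sqrt keep the gradient finite
    # at pixels where the modeled field is exactly zero
    model_amplitude = jnp.sqrt(jnp.real(spectrum) ** 2 + jnp.imag(spectrum) ** 2 + eps)
    measured_amplitude = jnp.sqrt(measured_intensity + eps)
    return jnp.mean((model_amplitude - measured_amplitude) ** 2)


key = jax.random.PRNGKey(0)
wave = jnp.exp(1j * jax.random.uniform(key, (16, 16)))
data = jnp.abs(jnp.fft.fft2(wave, norm="ortho")) ** 2
value = amplitude_loss(wave, data)  # close to zero for the true exit wave
grad = jax.grad(amplitude_loss)(wave * 0.5, data)  # complex gradient for a wrong guess
```

Since the loss is a real-valued function of a complex exit wave, its gradient plugs directly into the Wirtinger-style optimizers described above.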
Dependencies
Core Dependencies
JAX: Automatic differentiation and JIT compilation
NumPy: Array operations (via JAX)
jaxtyping: Type annotations for JAX arrays
beartype: Runtime type checking
Optional Dependencies
matplotlib: Visualization (for examples)
scipy: Additional scientific computing tools
h5py: HDF5 file I/O
Future Directions
The package architecture supports future extensions:
Advanced ptychographic reconstruction algorithms
GPU-optimized optical propagation kernels
Real-time microscopy processing pipelines
Integration with experimental microscopy data formats
Machine learning-enhanced phase retrieval
Adaptive optics simulations
Coherent diffractive imaging techniques