Package Organization

Overview

Janssen is a focused library for optical microscopy and ptychography, split from the original Ptyrodactyl project. The package is organized into four main modules: utils for common utilities, simul for forward models, lenses for lens implementations, and invert for reconstruction algorithms. It follows a clean, hierarchical structure optimized for optical microscopy applications.

Module Structure

janssen.utils

Common utilities and shared data structures used throughout the package.

Key Components:

  • Data Types & Structures

    • Type definitions and common data structures

    • Decorators for JAX transformations

    • Shared utility functions

janssen.simul

The forward simulation module for optical microscopy, providing differentiable implementations of optical elements and propagation.

Key Components:

  • Optical Elements

    • apertures.py: Circular, rectangular, and custom aperture functions

    • elements.py: Optical element transformations (beam splitters, waveplates)

    • lenses.py: Lens implementations and phase transformations

    • lens_optics.py: Physical lens calculations (thickness, phase profiles)

  • Propagation & Simulation

    • microscope.py: Microscopy simulation pipelines

    • helper.py: Helper functions for optical propagation

    • Fresnel and angular spectrum propagation methods

    • Wavefront manipulation utilities

janssen.lenses

Dedicated module for lens implementations and optical calculations.

Key Components:

  • Lens phase transformations

  • Physical lens calculations

  • Aberration modeling

  • Optical element interfaces

janssen.invert

The reconstruction module containing phase retrieval algorithms and optimization routines.

Key Components:

  • Phase Retrieval Algorithms

    • ptychography.py: Ptychographic reconstruction algorithms

    • engine.py: Core reconstruction engine

    • Single-slice and multi-slice ptychography

    • Position-corrected algorithms

    • Multi-modal probe reconstruction

  • Optimization

    • optimizers.py: Complex-valued optimizers with Wirtinger derivatives

    • loss_functions.py: Loss functions for phase retrieval

    • ADAM, AdaGrad, RMSProp, and SGD implementations

    • Learning rate scheduling

Design Principles

1. JAX-First Architecture

All functions are designed to be:

  • Differentiable: Full support for jax.grad

  • JIT-compilable: Optimized with jax.jit

  • Vectorizable: Compatible with jax.vmap

  • Device-agnostic: Run on CPU, GPU, or TPU

2. Type Safety

  • Comprehensive type hints using jaxtyping

  • Runtime type checking with beartype

  • Clear array shape specifications

3. Functional Programming

  • Pure functions without side effects

  • Immutable data structures

  • Composable operations

4. Optical Focus

Optimized specifically for optical microscopy:

  • Wavelength-dependent calculations

  • Complex wavefront representations

  • Physical optics simulations

File Organization

The package structure is organized for clarity and maintainability:

src/janssen/
├── __init__.py           # Top-level exports
├── utils/
│   ├── __init__.py       # Utils module exports
│   ├── types.py          # Shared type definitions
│   └── decorators.py     # JAX decorators and utilities
├── simul/
│   ├── __init__.py       # Simulation module exports
│   ├── apertures.py      # Aperture functions
│   ├── elements.py       # Optical elements
│   ├── helper.py         # Utility functions
│   └── microscope.py     # Microscopy simulations
├── lenses/
│   ├── __init__.py       # Lenses module exports
│   ├── lens_optics.py    # Lens calculations
│   └── lenses.py         # Lens implementations
└── invert/
    ├── __init__.py       # Invert module exports
    ├── engine.py         # Reconstruction engine
    ├── ptychography.py   # Ptychographic algorithms
    ├── optimizers.py     # Optimization routines
    └── loss_functions.py # Loss function definitions

Import Patterns

Public API Usage

Users should import from the three main modules:

# Import from main modules
from janssen.simul import microscope, apertures, elements
from janssen.lenses import lenses, lens_optics
from janssen.invert import ptychography, optimizers
from janssen.utils import types, decorators

# Import entire modules
import janssen.simul as sim
import janssen.lenses as lens
import janssen.invert as inv
import janssen.utils as utils

Internal Implementation

The __init__.py files handle internal imports and expose a clean API:

# simul/__init__.py example
from .apertures import (
    circular_aperture, rectangular_aperture
)
from .elements import (
    apply_aperture, apply_beam_splitter
)
from .microscope import (
    simple_microscope, scanning_microscope
)
# ... etc

Best Practices

1. Use JAX Transformations

Leverage JAX’s powerful transformations:

# JIT compilation for performance
@jax.jit
def simulate(wavefront, sample):
    return microscope.forward_model(wavefront, sample)

2. Automatic Differentiation

# Automatic differentiation for optimization
grad_fn = jax.grad(loss_function)

# Vectorization
batched_simulate = jax.vmap(simulate, in_axes=(0, None))
batched_stem = jax.vmap(stem_4d, in_axes=(None, None, 0))

3. Type Annotations

Use type hints for clarity:

from jaxtyping import Float, Complex

def propagate_wavefront(
    field: Complex[Array, "H W"],
    distance: float,
    wavelength: float
) -> Complex[Array, "H W"]:
    return fresnel_propagate(field, distance, wavelength)

4. Composable Operations

Build complex operations from simple functions:

# Compose multiple operations
def full_reconstruction(raw_data, initial_guess):
    # Apply forward model
    simulated = simul.microscope(initial_guess, probe)
    
    # Reconstruct using ptychography
    result = invert.ptychography(raw_data, initial_guess)
    
    return result

Performance Considerations

Memory Management

  • Use jax.checkpoint for memory-intensive operations

  • Leverage jax.lax.scan for sequential operations

  • Prefer in-place updates with .at[].set() for large arrays

Parallelization

  • Use jax.pmap for data parallelism

  • Implement sharding strategies for large datasets

  • Utilize device mesh for distributed computing

Optimization

  • JIT-compile hot paths

  • Batch operations with vmap

  • Use appropriate precision (float32 vs float64)

Extension Points

The package is designed to be extensible:

  1. Custom Loss Functions: Implement new loss functions following the pattern in invert.loss_functions

  2. New Optimizers: Add optimizers with Wirtinger derivative support

  3. Additional Reconstructions: Build on base reconstruction algorithms in invert.ptychography

  4. Custom Optical Elements: Add new elements in simul.elements

  5. Custom Workflows: Combine existing functions for specific use cases

Dependencies

Core Dependencies

  • JAX: Automatic differentiation and JIT compilation

  • NumPy: Array operations (via JAX)

  • jaxtyping: Type annotations for JAX arrays

  • beartype: Runtime type checking

Optional Dependencies

  • matplotlib: Visualization (for examples)

  • scipy: Additional scientific computing tools

  • h5py: HDF5 file I/O

Future Directions

The package architecture supports future extensions:

  • Advanced ptychographic reconstruction algorithms

  • GPU-optimized optical propagation kernels

  • Real-time microscopy processing pipelines

  • Integration with experimental microscopy data formats

  • Machine learning-enhanced phase retrieval

  • Adaptive optics simulations

  • Coherent diffractive imaging techniques