# Source code for janssen.invert.optimizers

"""
Complex-valued optimizers with Wirtinger derivatives for ptychography.

Extended Summary
----------------
This module implements complex-valued optimization algorithms including Adam,
Adagrad, and RMSprop using Wirtinger calculus. It also provides learning rate
schedulers for training optimization. All functions are JAX-compatible and
support automatic differentiation.

Routine Listings
----------------
LRSchedulerState : class
    State maintained by learning rate schedulers
Optimizer : class
    Optimizer configuration with init and update functions
create_cosine_scheduler : function, scheduler
    Creates a cosine learning rate scheduler with smooth decay
create_step_scheduler : function, scheduler
    Creates a step decay scheduler with periodic learning rate drops
create_warmup_cosine_scheduler : function, scheduler
    Creates a scheduler with linear warmup followed by cosine decay
init_scheduler_state : function, scheduler
    Initialize scheduler state with given learning rate
wirtinger_grad : function
    Compute the Wirtinger gradient of a complex-valued function
complex_adam : function, optimizer
    Complex-valued Adam optimizer based on Wirtinger derivatives
complex_adagrad : function, optimizer
    Complex-valued Adagrad optimizer based on Wirtinger derivatives
complex_rmsprop : function, optimizer
    Complex-valued RMSprop optimizer based on Wirtinger derivatives
init_adam : function, initializer
    Initialize Adam optimizer state
init_adagrad : function, initializer
    Initialize Adagrad optimizer state
init_rmsprop : function, initializer
    Initialize RMSprop optimizer state
adam_update : function, updater
    Update parameters using Adam optimizer with Wirtinger derivatives
adagrad_update : function, updater
    Update parameters using Adagrad optimizer with Wirtinger derivatives
rmsprop_update : function, updater
    Update parameters using RMSprop optimizer with Wirtinger derivatives

Notes
-----
All optimizers use Wirtinger calculus for proper handling of complex-valued
parameters. The Wirtinger derivative is defined as ∂f/∂z = ½(∂f/∂x - i∂f/∂y).
All functions are designed to work with JAX transformations including jit,
grad, and vmap.
"""

import jax
import jax.numpy as jnp
from beartype import beartype
from beartype.typing import (
    Any,
    Callable,
    NamedTuple,
    Optional,
    Sequence,
    Tuple,
    Union,
)
from jaxtyping import Array, Complex, Float, jaxtyped

from janssen.utils import OptimizerState, make_optimizer_state


class LRSchedulerState(NamedTuple):
    """Snapshot of a learning-rate scheduler's progress.

    Instances are immutable; every scheduler call returns a fresh state
    with ``step`` advanced by one.

    Attributes
    ----------
    step : int
        Number of scheduler calls performed so far
    learning_rate : float
        Most recently computed learning rate (equal to ``initial_lr``
        before the first call)
    initial_lr : float
        Base learning rate that every schedule scales from
    """

    # Count of scheduler invocations so far.
    step: int
    # Learning rate produced by the latest scheduler call.
    learning_rate: float
    # Reference learning rate; never modified by the schedulers.
    initial_lr: float


# Scheduler callable: takes the current LRSchedulerState and returns the
# learning rate for this step together with the advanced state.
SchedulerFn = Callable[
    [LRSchedulerState],
    Tuple[float, LRSchedulerState],
]


def create_cosine_scheduler(
    total_steps: int,
    final_lr_factor: Optional[float] = 0.01,
) -> SchedulerFn:
    """Create a cosine learning rate scheduler.

    This scheduler implements a cosine annealing schedule that smoothly
    decreases the learning rate from the initial value to a final value
    over the specified number of steps.

    Parameters
    ----------
    total_steps : int
        Total number of optimization steps over which the decay happens.
        Values below 1 are treated as 1, so the decay completes after a
        single step instead of dividing by zero.
    final_lr_factor : float, optional
        Final learning rate as a fraction of initial learning rate.
        ``None`` is treated as the default. Default is 0.01.

    Returns
    -------
    scheduler_fn : SchedulerFn
        A jitted function that takes the current scheduler state and
        returns the new learning rate and updated state.

    Notes
    -----
    Algorithm:
    - Calculate progress as min(step / total_steps, 1.0)
    - Compute cosine decay factor using 0.5 * (1 + cos(π * progress))
    - Calculate new learning rate using linear interpolation between
      ``initial_lr`` and ``final_lr_factor * initial_lr``
    - Update scheduler state with new step and learning rate
    - Return new learning rate and updated state

    The learning rate is computed from ``state.step`` *before* it is
    incremented, so the first call (step 0) returns ``initial_lr`` up to
    floating-point rounding.
    """
    # The annotation allows None; without this, ``None + float`` would fail
    # while tracing the jitted closure.
    if final_lr_factor is None:
        final_lr_factor = 0.01
    # A non-positive horizon would divide by zero (0 / 0 -> NaN at step 0).
    decay_steps = max(total_steps, 1)

    @jax.jit
    def scheduler_fn(
        state: LRSchedulerState,
    ) -> Tuple[float, LRSchedulerState]:
        progress = jnp.minimum(state.step / decay_steps, 1.0)
        cosine_decay = 0.5 * (1 + jnp.cos(jnp.pi * progress))
        lr = state.initial_lr * (
            final_lr_factor + (1 - final_lr_factor) * cosine_decay
        )
        new_state = LRSchedulerState(
            step=state.step + 1, learning_rate=lr, initial_lr=state.initial_lr
        )
        return lr, new_state

    return scheduler_fn


def create_step_scheduler(step_size: int, gamma: float = 0.1) -> SchedulerFn:
    """Create a step decay scheduler.

    The returned scheduler produces a staircase schedule: the initial
    learning rate is multiplied by ``gamma`` once for every full block of
    ``step_size`` steps that has elapsed.

    Parameters
    ----------
    step_size : int
        Number of steps between learning rate drops
    gamma : float
        Multiplicative factor applied at each drop.
        Default is 0.1.

    Returns
    -------
    scheduler_fn : SchedulerFn
        A jitted function mapping the current scheduler state to the
        learning rate for this step and the advanced state.

    Notes
    -----
    For a state with step ``s`` the learning rate is
    ``initial_lr * gamma ** (s // step_size)``. The returned state has
    ``step = s + 1`` and ``learning_rate`` set to that value.
    """

    def _staircase(
        state: LRSchedulerState,
    ) -> Tuple[float, LRSchedulerState]:
        drops_so_far = state.step // step_size
        current_lr = state.initial_lr * gamma**drops_so_far
        advanced = LRSchedulerState(
            state.step + 1, current_lr, state.initial_lr
        )
        return current_lr, advanced

    return jax.jit(_staircase)


def create_warmup_cosine_scheduler(
    total_steps: int,
    warmup_steps: int,
    final_lr_factor: float = 0.01,
) -> SchedulerFn:
    """Create a scheduler with linear warmup followed by cosine decay.

    This scheduler combines a linear warmup phase with a cosine annealing
    decay. During warmup, the learning rate increases linearly from 0 to
    the initial value. After warmup, it follows a cosine decay schedule.

    Parameters
    ----------
    total_steps : int
        Total number of optimization steps (warmup plus decay)
    warmup_steps : int
        Number of warmup steps. 0 disables warmup.
    final_lr_factor : float
        Final learning rate as a fraction of initial learning rate.
        Default is 0.01.

    Returns
    -------
    scheduler_fn : SchedulerFn
        A jitted function that takes the current scheduler state and
        returns the new learning rate and updated state.

    Notes
    -----
    Algorithm:
    - During warmup phase (step < warmup_steps):
        - lr = initial_lr * step / warmup_steps
    - During decay phase (step >= warmup_steps):
        - cosine decay from initial_lr to final_lr_factor * initial_lr
          over ``total_steps - warmup_steps`` steps
    - Choose appropriate learning rate based on current step
    - Update scheduler state with new step and learning rate
    - Return new learning rate and updated state

    If ``total_steps <= warmup_steps``, the decay phase is treated as
    lasting a single step, so the learning rate stays finite.
    """
    # Clamp both denominators to >= 1. A zero-length decay phase
    # (total_steps == warmup_steps) would otherwise evaluate 0 / 0 = NaN at
    # step == warmup_steps, and a negative one would make the "decay" rise.
    warmup_denom = max(warmup_steps, 1)
    decay_denom = max(total_steps - warmup_steps, 1)

    @jax.jit
    def scheduler_fn(
        state: LRSchedulerState,
    ) -> Tuple[float, LRSchedulerState]:
        warmup_progress = jnp.minimum(state.step / warmup_denom, 1.0)
        warmup_lr = state.initial_lr * warmup_progress
        decay_progress = (
            jnp.maximum(0.0, state.step - warmup_steps) / decay_denom
        )
        decay_progress = jnp.minimum(decay_progress, 1.0)
        cosine_decay = 0.5 * (1 + jnp.cos(jnp.pi * decay_progress))
        decay_lr = state.initial_lr * (
            final_lr_factor + (1 - final_lr_factor) * cosine_decay
        )
        lr = jnp.where(state.step < warmup_steps, warmup_lr, decay_lr)
        new_state = LRSchedulerState(
            step=state.step + 1, learning_rate=lr, initial_lr=state.initial_lr
        )
        return lr, new_state

    return scheduler_fn


def init_scheduler_state(initial_lr: float) -> LRSchedulerState:
    """Build the starting state for any of the learning-rate schedulers.

    Parameters
    ----------
    initial_lr : float
        Base learning rate the schedule will scale from

    Returns
    -------
    state : LRSchedulerState
        State with ``step`` 0 and both ``learning_rate`` and
        ``initial_lr`` set to ``initial_lr``
    """
    # Fields are positional: (step, learning_rate, initial_lr).
    return LRSchedulerState(0, initial_lr, initial_lr)


class Optimizer(NamedTuple):
    """Bundle of the two callables that make up an optimizer.

    Attributes
    ----------
    init : Callable
        Creates the optimizer state
    update : Callable
        Performs one parameter update using the optimizer
    """

    # State factory.
    init: Callable
    # Single-step update function.
    update: Callable


@jaxtyped(typechecker=beartype)
def wirtinger_grad(
    func2diff: Callable[..., Float[Array, " ..."]],
    argnums: Optional[Union[int, Sequence[int]]] = 0,
) -> Callable[
    ..., Union[Complex[Array, " ..."], Tuple[Complex[Array, " ..."], ...]]
]:
    r"""Compute the Wirtinger gradient of a complex-valued function.

    This function returns a new function that computes the Wirtinger gradient
    of the input function f with respect to the specified argument(s),
    using the definition (with :math:`z = x + iy`):

    .. math::
        \frac{\partial f}{\partial z} = \frac{1}{2} \left( \frac{\partial f}
        {\partial x} - i \frac{\partial f}{\partial y} \right)

    Both partial derivatives are taken of the full output
    :math:`f = u + iv`, i.e. :math:`\partial f/\partial x = u_x + i v_x` and
    :math:`\partial f/\partial y = u_y + i v_y`.

    Parameters
    ----------
    func2diff : Callable[..., Float[Array, " ..."]]
        Function to differentiate. Its output may be real or complex but
        must be a scalar, since its real and imaginary parts are each
        differentiated with ``jax.grad``.
    argnums : Union[int, Sequence[int]], optional
        Position(s) of the argument(s) to differentiate with respect to.
        Negative values count from the end. Default is 0.

    Returns
    -------
    grad_f : Callable[..., Complex[Array, " ..."] | Tuple[Complex[Array, " ..."], ...]]
        A function taking the same positional arguments as ``func2diff``
        and returning the Wirtinger gradient for each requested argument:
        a single array when ``argnums`` is an int, otherwise a tuple.

    Raises
    ------
    ValueError
        Raised by ``grad_f`` when an entry of ``argnums`` does not index one
        of the positional arguments it was called with.

    Notes
    -----
    Non-complex arguments are treated as having no imaginary part, so their
    gradient is :math:`\tfrac{1}{2}\,\partial f/\partial x`.

    For a real-valued loss :math:`L`, the steepest-descent direction is
    :math:`-\overline{\partial L/\partial z} = -\partial L/\partial \bar z`;
    gradient-descent callers should conjugate this result. For example,
    :math:`L = |z|^2` gives :math:`\partial L/\partial z = \bar z`, and the
    holomorphic :math:`f = z^2` gives :math:`2z`.
    """

    def grad_f(
        *args: Any,
    ) -> Union[Complex[Array, " ..."], Tuple[Complex[Array, " ..."], ...]]:
        n = len(args)
        single = isinstance(argnums, int)
        targets = (argnums,) if single else tuple(argnums)

        # Map every requested position into [0, n); the split argument list
        # below has 2n entries, so an unchecked index >= n would silently
        # select an imaginary-part slot.
        positions = []
        for target in targets:
            if not -n <= target < n:
                raise ValueError(
                    f"argnums entry {target} is out of range for "
                    f"{n} positional argument(s)"
                )
            positions.append(target % n)
        k = len(positions)

        def split_complex(args: Any) -> Tuple[Any, ...]:
            # Real parts occupy slots 0..n-1, imaginary parts slots n..2n-1.
            return tuple(
                jnp.real(arg) if jnp.iscomplexobj(arg) else arg for arg in args
            ) + tuple(
                jnp.imag(arg) if jnp.iscomplexobj(arg) else jnp.zeros_like(arg)
                for arg in args
            )

        def combine_complex(r: Any, i: Any) -> Tuple[Any, ...]:
            return tuple(
                rr + 1j * ii if jnp.iscomplexobj(arg) else rr
                for rr, ii, arg in zip(r, i, args, strict=False)
            )

        split_args = split_complex(args)

        def f_real(*split_args: Any) -> Float[Array, " ..."]:
            return jnp.real(
                func2diff(*combine_complex(split_args[:n], split_args[n:]))
            )

        def f_imag(*split_args: Any) -> Float[Array, " ..."]:
            return jnp.imag(
                func2diff(*combine_complex(split_args[:n], split_args[n:]))
            )

        # Differentiate w.r.t. both the real (x) and imaginary (y) slot of
        # every requested argument: first k entries are d/dx, next k d/dy.
        diff_argnums = tuple(positions) + tuple(p + n for p in positions)
        du = jax.grad(f_real, argnums=diff_argnums)(*split_args)
        dv = jax.grad(f_imag, argnums=diff_argnums)(*split_args)

        results = []
        for j in range(k):
            df_dx = du[j] + 1j * dv[j]
            df_dy = du[j + k] + 1j * dv[j + k]
            results.append(0.5 * (df_dx - 1j * df_dy))

        if single:
            return results[0]
        return tuple(results)

    return grad_f


@jaxtyped(typechecker=beartype)
def complex_adam(
    params: Complex[Array, " ..."],
    grads: Complex[Array, " ..."],
    state: Tuple[Complex[Array, " ..."], Complex[Array, " ..."], int],
    learning_rate: float = 0.001,
    beta1: float = 0.9,
    beta2: float = 0.999,
    eps: float = 1e-8,
) -> Tuple[
    Complex[Array, " ..."],
    Tuple[Complex[Array, " ..."], Complex[Array, " ..."], int],
]:
    """Complex-valued Adam optimizer based on Wirtinger derivatives.

    Performs a single Adam step on complex parameters. The second moment
    tracks the squared modulus of the gradient, so the per-element step
    size is real while the step direction follows the complex gradient.

    Parameters
    ----------
    params : Complex[Array, " ..."]
        Parameters to update
    grads : Complex[Array, " ..."]
        Complex-valued gradients computed using Wirtinger derivatives
    state : Tuple[Complex[Array, " ..."], Complex[Array, " ..."], int]
        ``(first_moment, second_moment, timestep)`` from the previous step
    learning_rate : float, optional
        Step size. Default is 0.001.
    beta1 : float, optional
        Decay rate of the first-moment average. Default is 0.9.
    beta2 : float, optional
        Decay rate of the second-moment average. Default is 0.999.
    eps : float, optional
        Added to the denominator for numerical stability.
        Default is 1e-8.

    Returns
    -------
    new_params : Complex[Array, " ..."]
        Parameters after the step
    new_state : Tuple[Complex[Array, " ..."], Complex[Array, " ..."], int]
        ``(first_moment, second_moment, timestep + 1)``

    Notes
    -----
    With ``t`` the incremented timestep and ``g`` the gradient:
    ``m = β₁ m + (1 - β₁) g``, ``v = β₂ v + (1 - β₂) |g|²``,
    ``params -= lr * (m / (1 - β₁^t)) / (sqrt(v / (1 - β₂^t)) + ε)``.
    """
    first_moment, second_moment, step = state
    step += 1
    first_moment = beta1 * first_moment + (1 - beta1) * grads
    second_moment = beta2 * second_moment + (1 - beta2) * jnp.abs(grads) ** 2
    # Standard Adam bias correction of both running averages.
    m_corrected = first_moment / (1 - beta1**step)
    v_corrected = second_moment / (1 - beta2**step)
    scaled_step = learning_rate * m_corrected / (jnp.sqrt(v_corrected) + eps)
    return params - scaled_step, (first_moment, second_moment, step)


@jaxtyped(typechecker=beartype)
def complex_adagrad(
    params: Complex[Array, " ..."],
    grads: Complex[Array, " ..."],
    state: Complex[Array, " ..."],
    learning_rate: float = 0.01,
    eps: float = 1e-8,
) -> Tuple[Complex[Array, " ..."], Complex[Array, " ..."]]:
    """Complex-valued Adagrad optimizer based on Wirtinger derivatives.

    Performs a single Adagrad step: the squared modulus of every gradient
    is summed into ``state``, and each element's step is scaled by the
    inverse square root of that running sum.

    Parameters
    ----------
    params : Complex[Array, " ..."]
        Parameters to update
    grads : Complex[Array, " ..."]
        Complex-valued gradients computed using Wirtinger derivatives
    state : Complex[Array, " ..."]
        Running sum of squared gradient magnitudes
    learning_rate : float, optional
        Base step size. Default is 0.01.
    eps : float, optional
        Added to the denominator for numerical stability.
        Default is 1e-8.

    Returns
    -------
    new_params : Complex[Array, " ..."]
        Parameters after the step
    new_state : Complex[Array, " ..."]
        Running sum including this step's ``|grads|²``

    Notes
    -----
    ``G = G + |g|²``; ``params -= lr / (sqrt(G) + ε) * g``.
    """
    sum_sq = state + jnp.abs(grads) ** 2
    per_element_lr = learning_rate / (jnp.sqrt(sum_sq) + eps)
    return params - per_element_lr * grads, sum_sq


@jaxtyped(typechecker=beartype)
def complex_rmsprop(
    params: Complex[Array, " ..."],
    grads: Complex[Array, " ..."],
    state: Complex[Array, " ..."],
    learning_rate: float = 0.001,
    decay_rate: float = 0.9,
    eps: float = 1e-8,
) -> Tuple[Complex[Array, " ..."], Complex[Array, " ..."]]:
    r"""Complex-valued RMSprop optimizer based on Wirtinger derivatives.

    Performs a single RMSprop step on complex parameters using an
    exponential moving average of the squared gradient modulus.

    Parameters
    ----------
    params : Complex[Array, " ..."]
        Parameters to update
    grads : Complex[Array, " ..."]
        Complex-valued gradients computed using Wirtinger derivatives
    state : Complex[Array, " ..."]
        Moving average of squared gradient magnitudes
    learning_rate : float, optional
        Base step size. Default is 0.001.
    decay_rate : float, optional
        Weight :math:`\rho` given to the previous moving average.
        Default is 0.9.
    eps : float, optional
        Added to the denominator for numerical stability.
        Default is 1e-8.

    Returns
    -------
    new_params : Complex[Array, " ..."]
        Parameters after the step
    new_state : Complex[Array, " ..."]
        Updated moving average

    Notes
    -----
    With gradient :math:`g`, moving average :math:`v` and learning rate
    :math:`\eta`:

    .. math::

        v &\leftarrow \rho\, v + (1 - \rho)\, |g|^2 \\
        \theta &\leftarrow \theta - \frac{\eta}{\sqrt{v} + \epsilon}\, g
    """
    avg_sq = decay_rate * state + (1 - decay_rate) * jnp.abs(grads) ** 2
    per_element_lr = learning_rate / (jnp.sqrt(avg_sq) + eps)
    return params - per_element_lr * grads, avg_sq


@jaxtyped(typechecker=beartype)
def init_adam(shape: Tuple) -> OptimizerState:
    """Initialize Adam optimizer state.

    Parameters
    ----------
    shape : Tuple
        Shape of the parameters to be optimized

    Returns
    -------
    state : OptimizerState
        Fresh state built by ``make_optimizer_state(shape)`` (intended to
        hold zero moments and step=0)
    """
    return make_optimizer_state(shape)
@jaxtyped(typechecker=beartype)
def init_adagrad(shape: Tuple) -> OptimizerState:
    """Initialize Adagrad optimizer state.

    Parameters
    ----------
    shape : Tuple
        Shape of the parameters to be optimized

    Returns
    -------
    state : OptimizerState
        Fresh state built by ``make_optimizer_state(shape)``; Adagrad
        accumulates squared gradients in its ``v`` component (intended to
        start at zero)
    """
    return make_optimizer_state(shape)
@jaxtyped(typechecker=beartype)
def init_rmsprop(shape: Tuple) -> OptimizerState:
    """Initialize RMSprop optimizer state.

    Parameters
    ----------
    shape : Tuple
        Shape of the parameters to be optimized

    Returns
    -------
    state : OptimizerState
        Fresh state built by ``make_optimizer_state(shape)``; RMSprop keeps
        its moving average in the ``v`` component (intended to start at
        zero)
    """
    return make_optimizer_state(shape)
@jaxtyped(typechecker=beartype)
def adam_update(
    params: Complex[Array, " ..."],
    grads: Complex[Array, " ..."],
    state: OptimizerState,
    learning_rate: float = 0.001,
    beta1: float = 0.9,
    beta2: float = 0.999,
    eps: float = 1e-8,
) -> Tuple[Complex[Array, " ..."], OptimizerState]:
    """Update parameters using Adam optimizer with Wirtinger derivatives.

    Parameters
    ----------
    params : Complex[Array, " ..."]
        Current complex-valued parameters
    grads : Complex[Array, " ..."]
        Complex-valued gradients computed using Wirtinger derivatives
    state : OptimizerState
        Current optimizer state, unpacked as ``(m, v, step)``
    learning_rate : float, optional
        Learning rate for parameter updates.
        Default is 0.001.
    beta1 : float, optional
        Exponential decay rate for first moment estimates.
        Default is 0.9.
    beta2 : float, optional
        Exponential decay rate for second moment estimates.
        Default is 0.999.
    eps : float, optional
        Small value to avoid division by zero.
        Default is 1e-8.

    Returns
    -------
    new_params : Complex[Array, " ..."]
        Updated complex-valued parameters
    new_state : OptimizerState
        Updated optimizer state

    Notes
    -----
    Algorithm:
    - Extract current state components (m, v, step)
    - Call complex_adam to perform the update
    - Return updated parameters and state
    """
    m, v, step = state
    # Keyword arguments guard against silently misordering hyperparameters.
    new_params, (new_m, new_v, new_step) = complex_adam(
        params,
        grads,
        (m, v, step),
        learning_rate=learning_rate,
        beta1=beta1,
        beta2=beta2,
        eps=eps,
    )
    return new_params, make_optimizer_state(
        shape=new_m.shape, m=new_m, v=new_v, step=new_step
    )


@jaxtyped(typechecker=beartype)
def adagrad_update(
    params: Complex[Array, " ..."],
    grads: Complex[Array, " ..."],
    state: OptimizerState,
    learning_rate: float = 0.01,
    eps: float = 1e-8,
) -> Tuple[Complex[Array, " ..."], OptimizerState]:
    """Update parameters using Adagrad optimizer with Wirtinger derivatives.

    Parameters
    ----------
    params : Complex[Array, " ..."]
        Current complex-valued parameters
    grads : Complex[Array, " ..."]
        Complex-valued gradients computed using Wirtinger derivatives
    state : OptimizerState
        Current optimizer state, unpacked as ``(m, v, step)``
    learning_rate : float, optional
        Learning rate for parameter updates.
        Default is 0.01.
    eps : float, optional
        Small value to avoid division by zero.
        Default is 1e-8.

    Returns
    -------
    new_params : Complex[Array, " ..."]
        Updated complex-valued parameters
    new_state : OptimizerState
        Updated optimizer state

    Notes
    -----
    Algorithm:
    - Extract current state components (m, v, step)
    - Call complex_adagrad on the ``v`` accumulator
    - Return updated parameters and state (``m`` passed through
      unchanged, ``step`` incremented)
    """
    m, v, step = state
    new_params, new_v = complex_adagrad(
        params, grads, v, learning_rate=learning_rate, eps=eps
    )
    return new_params, make_optimizer_state(
        shape=new_v.shape, m=m, v=new_v, step=step + 1
    )


@jaxtyped(typechecker=beartype)
def rmsprop_update(
    params: Complex[Array, " ..."],
    grads: Complex[Array, " ..."],
    state: OptimizerState,
    learning_rate: float = 0.001,
    decay_rate: float = 0.9,
    eps: float = 1e-8,
) -> Tuple[Complex[Array, " ..."], OptimizerState]:
    """Update parameters using RMSprop optimizer with Wirtinger derivatives.

    Parameters
    ----------
    params : Complex[Array, " ..."]
        Current complex-valued parameters
    grads : Complex[Array, " ..."]
        Complex-valued gradients computed using Wirtinger derivatives
    state : OptimizerState
        Current optimizer state, unpacked as ``(m, v, step)``
    learning_rate : float, optional
        Learning rate for parameter updates.
        Default is 0.001.
    decay_rate : float, optional
        Decay rate for moving average of squared gradients.
        Default is 0.9.
    eps : float, optional
        Small value to avoid division by zero.
        Default is 1e-8.

    Returns
    -------
    new_params : Complex[Array, " ..."]
        Updated complex-valued parameters
    new_state : OptimizerState
        Updated optimizer state

    Notes
    -----
    Algorithm:
    - Extract current state components (m, v, step)
    - Call complex_rmsprop on the ``v`` moving average
    - Return updated parameters and state (``m`` passed through
      unchanged, ``step`` incremented)
    """
    m, v, step = state
    new_params, new_v = complex_rmsprop(
        params,
        grads,
        v,
        learning_rate=learning_rate,
        decay_rate=decay_rate,
        eps=eps,
    )
    return new_params, make_optimizer_state(
        shape=new_v.shape, m=m, v=new_v, step=step + 1
    )