Source code for janssen.plots.wavefront

"""Wavefront plotting functions.

Extended Summary
----------------
Functions for visualizing optical wavefronts including amplitude, phase,
and complex field representations using matplotlib.
This module is NOT JAX-accelerated.

Routine Listings
----------------
plot_amplitude : function
    Plot the amplitude of an optical wavefront
plot_complex_wavefront : function
    Plot a complex optical wavefront using HSV color mapping
plot_intensity : function
    Plot the intensity of an optical wavefront
plot_phase : function
    Plot the phase of an optical wavefront using HSV color mapping
_plot_field : function, internal
    Internal function for plotting the field of an optical wavefront

Notes
-----
All plotting functions use matplotlib and matplotlib-scalebar for
publication-quality figures with proper scale annotations.
"""

from collections.abc import Callable

import matplotlib.pyplot as plt
import numpy as np
from beartype import beartype
from beartype.typing import List, Literal, Optional, Tuple, Union
from jaxtyping import Complex, Float
from matplotlib.axes import Axes
from matplotlib.colors import hsv_to_rgb
from matplotlib.figure import Figure
from matplotlib.image import AxesImage
from matplotlib_scalebar.scalebar import ScaleBar
from mpl_toolkits.axes_grid1 import make_axes_locatable
from mpl_toolkits.axes_grid1.axes_divider import AxesDivider
from numpy import ndarray as NDArray

from janssen.types import OpticalWavefront


def _decorate_panel(
    fig: Figure,
    ax: Axes,
    im: AxesImage,
    dx: float,
    scalebar_length: float | None,
    scalebar_units: str,
    colorbar_location: Optional[str],
) -> None:
    """Hide the axes frame, optionally attach a colorbar, and add a scalebar.

    Parameters
    ----------
    fig : Figure
        Figure that owns ``ax``; used to create the colorbar.
    ax : Axes
        Panel to decorate.
    im : AxesImage
        Image drawn on ``ax``; the colorbar is tied to it.
    dx : float
        Pixel size forwarded to ``ScaleBar`` as its first argument.
    scalebar_length : float | None
        Forwarded to ``ScaleBar`` as ``fixed_value``; None lets
        matplotlib-scalebar choose the length.
    scalebar_units : str
        Forwarded to ``ScaleBar`` as ``units``.
    colorbar_location : Optional[str]
        Side of ``ax`` on which to attach the colorbar ("left", "right",
        "top" or "bottom"). A falsy value means no colorbar.
    """
    ax.axis("off")

    if colorbar_location:
        divider: AxesDivider = make_axes_locatable(ax)
        is_vertical: bool = colorbar_location in ("left", "right")
        orientation: str = "vertical" if is_vertical else "horizontal"
        cax: Axes = divider.append_axes(colorbar_location, size="5%", pad=0.05)
        # Put tick labels on the side facing away from the image. The default
        # ("auto") places them right/bottom, which overlaps the image when the
        # colorbar sits on the left or top.
        fig.colorbar(
            im,
            cax=cax,
            orientation=orientation,
            ticklocation=colorbar_location,
        )

    # NOTE(review): ScaleBar interprets ``dx`` as the size of one pixel in
    # ``scalebar_units``. If wavefront.dx is always stored in metres, a
    # non-default ``scalebar_units`` would mis-scale the bar -- confirm.
    scalebar: ScaleBar = ScaleBar(
        dx,
        units=scalebar_units,
        length_fraction=0.25,
        location="lower right",
        color="white",
        box_alpha=0.5,
        fixed_value=scalebar_length,
    )
    ax.add_artist(scalebar)


def _plot_field(
    wavefront: OpticalWavefront,
    plot_fn: Callable[[Axes, Complex[NDArray, " hh ww"]], AxesImage],
    figsize: Tuple[float, float],
    scalebar_length: float | None,
    scalebar_units: str,
    title: Optional[str],
    colorbar_location: Optional[str] = None,
) -> Union[Tuple[Figure, Axes], Tuple[Figure, Tuple[Axes, Axes]]]:
    """Handle plotting logic for both scalar and polarized fields.

    Parameters
    ----------
    wavefront : OpticalWavefront
        Wavefront to plot; only its ``field`` and ``dx`` attributes are read.
    plot_fn : Callable[[Axes, Complex[NDArray, " hh ww"]], AxesImage]
        Draws one 2D complex field component on the given axes and returns
        the resulting image.
    figsize : Tuple[float, float]
        Size of a single panel in inches. For polarized fields the width is
        doubled to fit two side-by-side panels.
    scalebar_length : float | None
        Fixed scalebar length, or None to let matplotlib-scalebar choose.
    scalebar_units : str
        Units label passed to the scalebar.
    title : Optional[str]
        Figure title. It is a figure suptitle for polarized fields and an
        axes title otherwise. None means no title.
    colorbar_location : Optional[str], optional
        Side on which to attach a colorbar to each panel. None (the
        default) means no colorbar.

    Returns
    -------
    fig : Figure
        The created figure.
    ax : Axes or Tuple[Axes, Axes]
        A single Axes for a scalar field, or the (Ex, Ey) Axes pair for a
        polarized field.

    Notes
    -----
    A field with three dimensions is treated as polarized. The entries 0
    and 1 along its last axis are plotted as Ex and Ey; the field is
    presumably shaped (H, W, 2).
    """
    field: Complex[NDArray, " hh ww ..."] = np.asarray(wavefront.field)
    dx: float = float(wavefront.dx)
    is_polarized: bool = field.ndim == 3  # noqa: PLR2004

    if is_polarized:
        fig, (ax1, ax2) = plt.subplots(
            1, 2, figsize=(figsize[0] * 2, figsize[1])
        )
        axes: Tuple[Axes, Axes] = (ax1, ax2)
        labels: Tuple[str, str] = ("Ex", "Ey")
        for index, (panel, label) in enumerate(
            zip(axes, labels, strict=True)
        ):
            im = plot_fn(panel, field[:, :, index])
            panel.set_title(label)
            _decorate_panel(
                fig,
                panel,
                im,
                dx,
                scalebar_length,
                scalebar_units,
                colorbar_location,
            )

        if title is not None:
            fig.suptitle(title)

        fig.tight_layout()
        return fig, axes

    fig, ax = plt.subplots(figsize=figsize)
    im = plot_fn(ax, field)
    _decorate_panel(
        fig,
        ax,
        im,
        dx,
        scalebar_length,
        scalebar_units,
        colorbar_location,
    )

    if title is not None:
        ax.set_title(title)

    fig.tight_layout()
    return fig, ax


@beartype
def plot_complex_wavefront(
    wavefront: OpticalWavefront,
    figsize: Tuple[float, float] = (6, 5),
    scalebar_length: float | None = None,
    scalebar_units: str = "m",
    title: Optional[str] = None,
) -> Union[Tuple[Figure, Axes], Tuple[Figure, Tuple[Axes, Axes]]]:
    """Plot a complex optical wavefront using HSV color mapping.

    Parameters
    ----------
    wavefront : OpticalWavefront
        The optical wavefront to plot. Contains field, wavelength, dx,
        z_position, and polarization attributes.
    figsize : Tuple[float, float], optional
        Figure size in inches (width, height). Default is (6, 5).
    scalebar_length : Optional[float], optional
        Length of the scalebar in the units specified by scalebar_units.
        If None, matplotlib-scalebar will choose automatically.
    scalebar_units : str, optional
        Units for the scalebar. Default is "m" (meters).
    title : Optional[str], optional
        Title for the figure. If None, no title is added.

    Returns
    -------
    fig : Figure
        The matplotlib Figure object.
    ax : Axes or Tuple[Axes, Axes]
        The matplotlib Axes object. For polarized wavefronts, returns
        a tuple of two Axes (Ex, Ey).

    Notes
    -----
    The function creates a single plot using HSV color mapping where:

    - Hue: Represents the phase arg(U(x,y)), mapped from [-π, π] to [0, 1]
    - Saturation: Set to 1 (fully saturated colors)
    - Value: Represents the amplitude |U(x,y)| divided by its peak, so the
      brightest pixel has value 1 regardless of the field's overall scale.
      An all-zero field is rendered black.

    This representation allows simultaneous visualization of both
    amplitude and phase in a single image. Bright colors indicate high
    amplitude, while dark regions indicate low amplitude. The color itself
    encodes the phase.

    A scalebar is added based on the wavefront's dx (pixel size) parameter.

    For polarized wavefronts (3D field with shape [H, W, 2]), two
    side-by-side plots are created showing Ex and Ey components.
    """

    def _plot_complex(ax: Axes, f: Complex[NDArray, " hh ww"]) -> AxesImage:
        amplitude: Float[NDArray, " hh ww"] = np.abs(f)
        phase: Float[NDArray, " hh ww"] = np.angle(f)
        peak: float = float(np.max(amplitude))
        # Divide by the true peak instead of (peak + epsilon): an additive
        # epsilon visibly dims fields whose amplitudes are small in absolute
        # terms. A zero peak means there is nothing to show, so render black.
        value: Float[NDArray, " hh ww"]
        if peak > 0:
            value = amplitude / peak
        else:
            value = np.zeros_like(amplitude)
        # np.angle returns values in [-pi, pi]; shift and scale to hue [0, 1].
        hue: Float[NDArray, " hh ww"] = (phase + np.pi) / (2 * np.pi)
        saturation: Float[NDArray, " hh ww"] = np.ones_like(value)
        hsv_image: Float[NDArray, " hh ww 3"] = np.stack(
            [hue, saturation, value], axis=-1
        )
        rgb_image: Float[NDArray, " hh ww 3"] = hsv_to_rgb(hsv_image)
        return ax.imshow(rgb_image, origin="lower")

    return _plot_field(
        wavefront,
        _plot_complex,
        figsize,
        scalebar_length,
        scalebar_units,
        title,
        colorbar_location=None,
    )
@beartype
def plot_amplitude(
    wavefront: OpticalWavefront,
    figsize: Tuple[float, float] = (6, 5),
    cmap: str = "gray",
    colorbar_location: Literal["top", "bottom", "left", "right"] = "right",
    colorbar_min: float | None = None,
    scalebar_length: float | None = None,
    scalebar_units: str = "m",
    title: Optional[str] = None,
) -> Union[Tuple[Figure, Axes], Tuple[Figure, Tuple[Axes, Axes]]]:
    """Plot the amplitude of an optical wavefront.

    Parameters
    ----------
    wavefront : OpticalWavefront
        The optical wavefront to plot. Contains field, wavelength, dx,
        z_position, and polarization attributes.
    figsize : Tuple[float, float], optional
        Figure size in inches (width, height). Default is (6, 5).
    cmap : str, optional
        Colormap for the amplitude plot. Default is "gray".
    colorbar_location : Literal["top", "bottom", "left", "right"], optional
        Location of the colorbar. Default is "right".
    colorbar_min : float | None, optional
        Minimum value for the colorbar. If None, matplotlib autoscales the
        lower limit to the minimum of the amplitude data, ignoring NaN/inf
        pixels. Default is None.
    scalebar_length : Optional[float], optional
        Length of the scalebar in the units specified by scalebar_units.
        If None, matplotlib-scalebar will choose automatically.
    scalebar_units : str, optional
        Units for the scalebar. Default is "m" (meters).
    title : Optional[str], optional
        Title for the figure. If None, no title is added.

    Returns
    -------
    fig : Figure
        The matplotlib Figure object.
    ax : Axes or Tuple[Axes, Axes]
        The matplotlib Axes object. For polarized wavefronts, returns
        a tuple of two Axes (Ex, Ey).

    Notes
    -----
    The function plots the amplitude |U(x,y)| of the complex field.
    A colorbar is included to show the amplitude scale, sized to match
    the image dimensions.

    A scalebar is added based on the wavefront's dx (pixel size) parameter.

    For polarized wavefronts (3D field with shape [H, W, 2]), two
    side-by-side plots are created showing Ex and Ey components.
    """

    def _plot_amp(ax: Axes, f: Complex[NDArray, " hh ww"]) -> AxesImage:
        amplitude: Float[NDArray, " hh ww"] = np.abs(f)
        # vmin=None lets matplotlib autoscale from the masked image data.
        # Computing np.min here would yield NaN whenever any pixel is NaN,
        # which blanks the entire image.
        return ax.imshow(
            amplitude, cmap=cmap, origin="lower", vmin=colorbar_min
        )

    return _plot_field(
        wavefront,
        _plot_amp,
        figsize,
        scalebar_length,
        scalebar_units,
        title,
        colorbar_location=colorbar_location,
    )
@beartype
def plot_intensity(
    wavefront: OpticalWavefront,
    figsize: Tuple[float, float] = (6, 5),
    cmap: str = "gray",
    colorbar_location: Literal["top", "bottom", "left", "right"] = "right",
    colorbar_min: float | None = None,
    scalebar_length: float | None = None,
    scalebar_units: str = "m",
    title: Optional[str] = None,
) -> Union[Tuple[Figure, Axes], Tuple[Figure, Tuple[Axes, Axes]]]:
    """Plot the intensity of an optical wavefront.

    Parameters
    ----------
    wavefront : OpticalWavefront
        The optical wavefront to plot. Contains field, wavelength, dx,
        z_position, and polarization attributes.
    figsize : Tuple[float, float], optional
        Figure size in inches (width, height). Default is (6, 5).
    cmap : str, optional
        Colormap for the intensity plot. Default is "gray".
    colorbar_location : Literal["top", "bottom", "left", "right"], optional
        Location of the colorbar. Default is "right".
    colorbar_min : float | None, optional
        Minimum value for the colorbar. If None, matplotlib autoscales the
        lower limit to the minimum of the intensity data, ignoring NaN/inf
        pixels. Default is None.
    scalebar_length : Optional[float], optional
        Length of the scalebar in the units specified by scalebar_units.
        If None, matplotlib-scalebar will choose automatically.
    scalebar_units : str, optional
        Units for the scalebar. Default is "m" (meters).
    title : Optional[str], optional
        Title for the figure. If None, no title is added.

    Returns
    -------
    fig : Figure
        The matplotlib Figure object.
    ax : Axes or Tuple[Axes, Axes]
        The matplotlib Axes object. For polarized wavefronts, returns
        a tuple of two Axes (Ex, Ey).

    Notes
    -----
    The function plots the intensity |U(x,y)|² of the complex field.
    A colorbar is included to show the intensity scale, sized to match
    the image dimensions.

    A scalebar is added based on the wavefront's dx (pixel size) parameter.

    For polarized wavefronts (3D field with shape [H, W, 2]), two
    side-by-side plots are created showing Ex and Ey components.
    """

    def _plot_int(ax: Axes, f: Complex[NDArray, " hh ww"]) -> AxesImage:
        intensity: Float[NDArray, " hh ww"] = np.abs(f) ** 2
        # vmin=None lets matplotlib autoscale from the masked image data.
        # Computing np.min here would yield NaN whenever any pixel is NaN,
        # which blanks the entire image.
        return ax.imshow(
            intensity, cmap=cmap, origin="lower", vmin=colorbar_min
        )

    return _plot_field(
        wavefront,
        _plot_int,
        figsize,
        scalebar_length,
        scalebar_units,
        title,
        colorbar_location=colorbar_location,
    )
@beartype
def plot_phase(
    wavefront: OpticalWavefront,
    figsize: Tuple[float, float] = (6, 5),
    scalebar_length: float | None = None,
    scalebar_units: str = "m",
    title: Optional[str] = None,
) -> Union[Tuple[Figure, Axes], Tuple[Figure, Tuple[Axes, Axes]]]:
    """Plot the phase of an optical wavefront using HSV color mapping.

    Parameters
    ----------
    wavefront : OpticalWavefront
        Wavefront whose phase is shown. Contains field, wavelength, dx,
        z_position, and polarization attributes.
    figsize : Tuple[float, float], optional
        Size of one panel in inches, as (width, height). Default is (6, 5).
    scalebar_length : Optional[float], optional
        Fixed scalebar length expressed in ``scalebar_units``. When None,
        matplotlib-scalebar picks a length itself.
    scalebar_units : str, optional
        Unit label for the scalebar. Default is "m" (meters).
    title : Optional[str], optional
        Figure title; omitted when None.

    Returns
    -------
    fig : Figure
        The matplotlib Figure object.
    ax : Axes or Tuple[Axes, Axes]
        The Axes holding the image, or the (Ex, Ey) pair of Axes when the
        wavefront is polarized.

    Notes
    -----
    Each pixel's color is built in HSV space:

    - Hue: the phase arg(U(x,y)), rescaled from [-π, π] to [0, 1]
    - Saturation: constant 1 (fully saturated)
    - Value: constant 1 (full brightness)

    Brightness is therefore uniform and only the color carries phase
    information. The hue circle wraps around, mirroring the periodicity of
    phase. No colorbar is drawn.

    A scalebar derived from the wavefront's dx (pixel size) is added.

    A polarized wavefront (3D field of shape [H, W, 2]) produces two
    side-by-side panels for the Ex and Ey components.
    """

    def _render_phase(
        axis: Axes, component: Complex[NDArray, " hh ww"]
    ) -> AxesImage:
        # Map the principal phase value onto the unit hue interval.
        hue_channel = (np.angle(component) + np.pi) / (2 * np.pi)
        # Saturation and value are both held at 1 across the image.
        unit_channel = np.ones_like(hue_channel)
        rgb = hsv_to_rgb(
            np.stack((hue_channel, unit_channel, unit_channel), axis=-1)
        )
        return axis.imshow(rgb, origin="lower")

    return _plot_field(
        wavefront,
        plot_fn=_render_phase,
        figsize=figsize,
        scalebar_length=scalebar_length,
        scalebar_units=scalebar_units,
        title=title,
        colorbar_location=None,
    )