"""USAF 1951 resolution test pattern generation.
Extended Summary
----------------
Generates USAF 1951 resolution test patterns using pure JAX operations.
The pattern follows the MIL-STD-150A specification with correctly scaled
and positioned groups and elements.
Routine Listings
----------------
create_bar_triplet : function
Creates 3 parallel bars (horizontal or vertical)
create_element_pattern : function
Creates a single element (horizontal + vertical bar triplets)
create_group_pattern : function
Creates a complete group with 6 elements
get_bar_width_pixels : function
Calculates bar width in pixels for given group and element
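calculate_usaf_group_range : function
Calculates the viable range of USAF groups for a given image and pixel size
generate_usaf_pattern : function
Generates USAF 1951 resolution test pattern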
Notes
-----
All functions use JAX operations and support automatic differentiation.
The USAF 1951 pattern follows the resolution formula:
Resolution = 2^(group + (element-1)/6) line pairs per mm
Each successive element increases resolution by a factor of 2^(1/6) ≈ 1.122
Each successive group increases resolution by a factor of 2.
"""
import math
import jax.numpy as jnp
from beartype import beartype
from beartype.typing import Optional, Tuple
from jaxtyping import (
Array,
Float,
jaxtyped,
)
from janssen.types import (
SampleFunction,
ScalarFloat,
make_sample_function,
)
[docs]
@jaxtyped(typechecker=beartype)
def create_bar_triplet(
    width: int,
    length: int,
    horizontal: bool = True,
) -> Float[Array, "..."]:
    """Create 3 parallel bars (horizontal or vertical).

    Parameters
    ----------
    width : int
        Width of each bar in pixels (values below 1 are clamped to 1)
    length : int
        Length of each bar in pixels (values below 1 are clamped to 1)
    horizontal : bool, optional
        Whether to create horizontal bars, by default True

    Returns
    -------
    pattern : Float[Array, "..."]
        The bar triplet pattern as float32, 1.0 on bars and 0.0 on
        spaces. Shape depends on orientation:
        - Horizontal: (5*width, length)
        - Vertical: (length, 5*width)

    Notes
    -----
    Creates three bars where bar spacing equals bar width. Total extent
    is 5 × bar_width (3 bars + 2 spaces).

    Profile across the bars (rows for horizontal, columns for vertical):
    - Bar 1: [0, width)
    - Space: [width, 2*width)
    - Bar 2: [2*width, 3*width)
    - Space: [3*width, 4*width)
    - Bar 3: [4*width, 5*width)

    Only the requested orientation is built. Branching on ``horizontal``
    with a plain Python ``if`` is fine here because the flag is a Python
    bool, so the choice is made before any array operation runs.
    """
    width_val: int = max(1, width)
    length_val: int = max(1, length)
    total_bar_extent: int = 5 * width_val

    # Index of the width-sized stripe each position falls in (0..4).
    # Even stripes (0, 2, 4) are bars; odd stripes (1, 3) are spaces.
    stripe_index = jnp.arange(total_bar_extent) // width_val
    profile: Float[Array, " n"] = (stripe_index % 2 == 0).astype(jnp.float32)

    if horizontal:
        # Bars run along the columns; the profile varies down the rows.
        return jnp.broadcast_to(
            profile[:, None], (total_bar_extent, length_val)
        )
    # Bars run along the rows; the profile varies across the columns.
    return jnp.broadcast_to(profile[None, :], (length_val, total_bar_extent))
[docs]
@jaxtyped(typechecker=beartype)
def create_element_pattern(
    bar_width_px: int,
    gap_factor: float = 0.5,
) -> Float[Array, "..."]:
    """Build one USAF element: a horizontal and a vertical bar triplet.

    Parameters
    ----------
    bar_width_px : int
        Width of each bar in pixels (values below 1 are clamped to 1)
    gap_factor : float, optional
        Gap between the two triplets as a fraction of the bar width,
        by default 0.5. The gap is never narrower than 1 pixel.

    Returns
    -------
    element : Float[Array, "..."]
        2-D float32 array holding both triplets on a zero background

    Notes
    -----
    Layout, left to right:
        [horizontal triplet] [gap] [vertical triplet]
    Bars are 5× as long as they are wide. Each triplet is centered
    vertically within the element (extra row, if any, goes below).
    """
    stroke: int = max(1, bar_width_px)
    stroke_length: int = 5 * stroke
    separation: int = max(1, int(stroke * gap_factor))

    left_block: Float[Array, " lh lw"] = create_bar_triplet(
        stroke, stroke_length, horizontal=True
    )
    right_block: Float[Array, " rh rw"] = create_bar_triplet(
        stroke, stroke_length, horizontal=False
    )

    canvas_rows: int = max(left_block.shape[0], right_block.shape[0])

    def center_rows(block):
        """Zero-pad ``block`` above and below to ``canvas_rows`` rows."""
        above: int = (canvas_rows - block.shape[0]) // 2
        below: int = canvas_rows - block.shape[0] - above
        return jnp.pad(block, ((above, below), (0, 0)))

    blank_gap: Float[Array, " eh g"] = jnp.zeros(
        (canvas_rows, separation), dtype=jnp.float32
    )
    return jnp.concatenate(
        [center_rows(left_block), blank_gap, center_rows(right_block)],
        axis=1,
    )
[docs]
@jaxtyped(typechecker=beartype)
def get_bar_width_pixels(
    group: int,
    element: int,
    pixels_per_mm: float,
) -> int:
    """Return the bar width, in whole pixels, of one group/element.

    Parameters
    ----------
    group : int
        Group number (typically -2 to 7)
    element : int
        Element number (1 to 6)
    pixels_per_mm : float
        Pixel density in pixels per millimeter

    Returns
    -------
    bar_width : int
        Bar width rounded to the nearest pixel, never less than 1

    Notes
    -----
    The resolution in line pairs per mm is
        R = 2^(group + (element-1)/6)
    A line pair is one bar plus one equally wide space, so a single
    bar is 1 / (2 × R) millimeters wide.
    """
    exponent: float = group + (element - 1) / 6.0
    line_pairs_per_mm: float = 2.0**exponent
    width_mm: float = 1.0 / (2.0 * line_pairs_per_mm)
    width_px: int = int(round(width_mm * pixels_per_mm))
    # Clamp so that even the finest element keeps a drawable bar.
    return width_px if width_px > 1 else 1
[docs]
@jaxtyped(typechecker=beartype)
def create_group_pattern(
    group: int,
    pixels_per_mm: float,
) -> Tuple[Float[Array, "..."], int]:
    """Assemble the six elements of one group into a 2-column layout.

    Parameters
    ----------
    group : int
        Group number
    pixels_per_mm : float
        Pixel density in pixels per millimeter

    Returns
    -------
    group_pattern : Float[Array, "..."]
        2-D float32 array containing all six elements on a zero
        background
    max_dimension : int
        The larger of the pattern's height and width, in pixels

    Notes
    -----
    Elements are stacked top to bottom in two columns:
    - Left column: elements 1, 2, 3
    - Right column: elements 4, 5, 6
    Each element is centered horizontally inside its column. Bar widths
    come from ``get_bar_width_pixels``, so later elements are never
    larger than earlier ones.

    The loops run in plain Python over a fixed number of elements.
    """
    tiles: list[Float[Array, "..."]] = [
        create_element_pattern(
            get_bar_width_pixels(group, number, pixels_per_mm)
        )
        for number in range(1, 7)
    ]
    heights: list[int] = [int(tile.shape[0]) for tile in tiles]
    widths: list[int] = [int(tile.shape[1]) for tile in tiles]

    # Vertical gap between stacked elements: 20% of the widest element.
    row_gap: int = max(2, int(max(widths) * 0.2))

    left_width: int = max(widths[:3])
    right_width: int = max(widths[3:])
    # Horizontal gap between the columns: 40% of the left column width.
    column_gap: int = max(2, int(left_width * 0.4))

    canvas_height: int = max(sum(heights[:3]), sum(heights[3:])) + 2 * row_gap
    canvas_width: int = left_width + column_gap + right_width
    canvas: Float[Array, " h w"] = jnp.zeros(
        (canvas_height, canvas_width), dtype=jnp.float32
    )

    # (element indices, x origin of the column, column width)
    columns = (
        (range(0, 3), 0, left_width),
        (range(3, 6), left_width + column_gap, right_width),
    )
    for indices, x_origin, column_width in columns:
        cursor_y: int = 0
        for idx in indices:
            tile_h: int = heights[idx]
            tile_w: int = widths[idx]
            left: int = x_origin + (column_width - tile_w) // 2
            canvas = canvas.at[
                cursor_y : cursor_y + tile_h, left : left + tile_w
            ].set(tiles[idx])
            cursor_y += tile_h + row_gap

    return canvas, max(canvas_height, canvas_width)
[docs]
@jaxtyped(typechecker=beartype)
def calculate_usaf_group_range(
    image_size: int,
    pixel_size: float,
    min_bar_pixels: int = 2,
    grid_fill_fraction: float = 0.95,
) -> dict:
    """Calculate the viable USAF group range for given parameters.

    Parameters
    ----------
    image_size : int
        Image size in pixels (square)
    pixel_size : float
        Pixel size in meters (must be positive)
    min_bar_pixels : int, optional
        Minimum bar width in pixels for visibility, by default 2
        (must be at least 1)
    grid_fill_fraction : float, optional
        Extra shrink factor applied when the largest group has to be
        scaled down to fit the image, by default 0.95

    Returns
    -------
    result : dict
        Dictionary containing:
        - max_group: finest group where bars are still >= min_bar_pixels
        - min_group: coarsest group that fits in image
        - recommended_range: suggested range() for generate_usaf_pattern
        - num_groups: how many groups in recommended range
        - pixels_per_mm: calculated pixel density
        - group_info: dict with bar width info for each group

    Raises
    ------
    ValueError
        If ``pixel_size`` is not positive or ``min_bar_pixels`` is less
        than 1.

    Notes
    -----
    This function maximizes the number of groups that can fit in the image
    using variable-density row packing (smaller groups pack more per row).
    Candidate ranges always end at ``max_group`` and start no lower than
    group -10; the widest candidate whose groups all fit is chosen.
    Group sizes are estimates based on element 1 of each group.

    Examples
    --------
    >>> result = calculate_usaf_group_range(
    ...     image_size=8192,
    ...     pixel_size=0.5e-6,
    ... )
    >>> min_g, max_g = result['min_group'], result['max_group']
    >>> print(f"Use: groups=range({min_g}, {max_g + 1})")
    """
    # Fail early with a clear message instead of a ZeroDivisionError or
    # "math domain error" from the density / log2 computations below.
    if pixel_size <= 0:
        raise ValueError(f"pixel_size must be positive, got {pixel_size}")
    if min_bar_pixels < 1:
        raise ValueError(
            f"min_bar_pixels must be at least 1, got {min_bar_pixels}"
        )

    pixels_per_mm: float = 1e-3 / pixel_size
    # Finest group whose element 6 (the +5/6 exponent) still has bars of
    # at least min_bar_pixels: ppm / (2 * 2^(g + 5/6)) >= min_bar_pixels.
    max_group: int = int(
        math.floor(math.log2(pixels_per_mm / (2 * min_bar_pixels)) - 5 / 6)
    )

    def get_group_size(g: int, ppm: float) -> Tuple[int, int]:
        """Estimate group pattern size in pixels.

        Returns (height, width) of the group pattern.
        """
        bar_width: float = ppm / (2 * (2**g))
        bar_width = max(1, bar_width)
        elem_height: float = 5 * bar_width
        # 5 (horizontal triplet) + 0.5 (gap) + 5 (vertical triplet)
        elem_width: float = 10.5 * bar_width
        elem_spacing: float = 0.2 * elem_width
        col_height: float = 3 * elem_height + 2 * elem_spacing
        col_gap: float = 0.4 * elem_width
        total_width: float = 2 * elem_width + col_gap
        return max(1, int(col_height)), max(1, int(total_width))

    def simulate_packing(groups: list[int], ppm: float, img_size: int) -> int:
        """Simulate row-packing and return how many groups fit."""
        margin: int = img_size // 40
        usable: int = img_size - 2 * margin
        spacing_h: int = max(5, img_size // 200)
        spacing_v: int = max(20, img_size // 80)

        # The lowest group number is the physically largest pattern.
        largest_h, largest_w = get_group_size(min(groups), ppm)
        effective_ppm: float = ppm
        if largest_h > usable or largest_w > usable:
            # Shrink everything uniformly so the largest group fits.
            scale: float = (
                min(usable / largest_h, usable / largest_w)
                * grid_fill_fraction
            )
            effective_ppm = ppm * scale

        current_x: int = margin
        current_y: int = margin
        row_max_h: int = 0
        count: int = 0
        for g in groups:
            gh, gw = get_group_size(g, effective_ppm)
            if current_x + gw > img_size - margin:
                # Group does not fit on this row: start a new one.
                current_x = margin
                current_y += row_max_h + spacing_v
                row_max_h = 0
            if current_y + gh > img_size - margin:
                break
            current_x += gw + spacing_h
            row_max_h = max(row_max_h, gh)
            count += 1
        return count

    # Candidates get narrower as candidate_min rises, so the first one
    # that fits completely is the widest possible; stop searching there.
    min_group: int = max_group
    for candidate_min in range(-10, max_group + 1):
        groups_to_try: list[int] = list(range(candidate_min, max_group + 1))
        fitted: int = simulate_packing(
            groups_to_try, pixels_per_mm, image_size
        )
        if fitted >= len(groups_to_try):
            min_group = candidate_min
            break

    group_info: dict = {}
    for g in range(min_group, max_group + 1):
        bar_width_e1: float = pixels_per_mm / (2 * (2**g))
        bar_width_e6: float = pixels_per_mm / (2 * (2 ** (g + 5 / 6)))
        gh, gw = get_group_size(g, pixels_per_mm)
        group_info[g] = {
            "bar_width_element1": round(bar_width_e1, 1),
            "bar_width_element6": round(bar_width_e6, 1),
            "group_size_approx": f"{gh}x{gw}",
        }

    recommended_range: range = range(min_group, max_group + 1)
    return {
        "max_group": max_group,
        "min_group": min_group,
        "recommended_range": recommended_range,
        "num_groups": len(recommended_range),
        "pixels_per_mm": pixels_per_mm,
        "group_info": group_info,
    }
[docs]
@jaxtyped(typechecker=beartype)
def generate_usaf_pattern(  # noqa: PLR0912, PLR0915
    image_size: int = 1024,
    groups: Optional[range] = None,
    pixel_size: ScalarFloat = 1.0e-6,
    background: float = 0.0,
    foreground: float = 1.0,
    max_phase: float = 0.0,
    auto: bool = False,
    min_bar_pixels: int = 2,
) -> SampleFunction:
    """Generate USAF 1951 resolution test pattern.

    Parameters
    ----------
    image_size : int, optional
        Size of the output image (square), by default 1024
    groups : range, optional
        Range of groups to include, by default range(-2, 8).
        Must not be empty. Ignored if auto=True.
    pixel_size : ScalarFloat, optional
        Physical size of each pixel in meters, by default 1.0e-6 (1 µm)
    background : float, optional
        Background value, by default 0.0 (black)
    foreground : float, optional
        Foreground (bar) value, by default 1.0 (white)
    max_phase : float, optional
        Maximum phase shift in radians applied to the bars, by default 0.0.
        The phase pattern follows the same structure as the amplitude,
        scaling from 0 (at background) to max_phase (at foreground).
    auto : bool, optional
        If True, automatically calculate the optimal group range to
        fill the image based on image_size and pixel_size. Overrides
        the groups parameter. By default False.
    min_bar_pixels : int, optional
        Minimum bar width in pixels for visibility when auto=True,
        by default 2. Ignored if auto=False.

    Returns
    -------
    pattern : SampleFunction
        SampleFunction PyTree containing the USAF test pattern as a
        complex array with both amplitude and phase information.

    Raises
    ------
    ValueError
        If the selected group range is empty.

    Notes
    -----
    The USAF 1951 test pattern consists of groups arranged in a grid.
    Each group contains 6 elements of progressively higher resolution.

    Resolution formula per MIL-STD-150A:
        Resolution = 2^(group + (element-1)/6) line pairs per mm
    Standard groups range from -2 (coarsest) to 7 (finest).

    Each element consists of:
    - 3 horizontal bars (bar triplet)
    - 3 vertical bars (bar triplet)
    with bar length = 5 × bar width.

    The output is a complex field: amplitude * exp(i * phase), where
    the phase follows the same spatial pattern as the amplitude.

    The loop over groups is unrolled at Python trace time since
    groups_list is known before tracing. Python-level conditionals
    for bounds checking, scaling, and phase normalization are evaluated
    at trace time since all controlling values are Python scalars.

    A global scale factor is computed from the largest (coarsest) group
    to ensure all groups fit within their grid cells while preserving
    the correct relative size ratios between groups.

    Groups are packed left to right into rows using size estimates based
    on element 1; rows are spread evenly over the image height and each
    row is centered horizontally. A row that would not fit vertically is
    skipped, and any group overhanging the image edge is clipped.

    Examples
    --------
    >>> from janssen.models import generate_usaf_pattern
    >>> pattern = generate_usaf_pattern(image_size=1024, pixel_size=1e-6)
    >>> pattern.sample.shape
    (1024, 1024)
    >>> # Auto mode: fill the image optimally
    >>> pattern = generate_usaf_pattern(
    ...     image_size=8192, pixel_size=0.5e-6, auto=True)
    >>> # Camera with 6.5 µm pixels
    >>> pattern = generate_usaf_pattern(
    ...     pixel_size=6.5e-6)
    >>> # White background with black bars (typical)
    >>> pattern = generate_usaf_pattern(background=1.0, foreground=0.0)
    >>> # Phase object with π phase shift on bars
    >>> pattern = generate_usaf_pattern(max_phase=jnp.pi)
    >>> # Specific group range
    >>> pattern = generate_usaf_pattern(groups=range(0, 5))
    """
    if auto:
        range_info = calculate_usaf_group_range(
            image_size=image_size,
            pixel_size=float(pixel_size),
            min_bar_pixels=min_bar_pixels,
        )
        groups_list: list[int] = list(range_info["recommended_range"])
    elif groups is not None:
        groups_list = list(groups)
    else:
        groups_list = list(range(-2, 8))
    if not groups_list:
        # Without this guard the failure surfaces later as the opaque
        # "min() arg is an empty sequence".
        raise ValueError("groups must contain at least one group")

    dx_calculated: ScalarFloat = float(pixel_size)
    pixels_per_mm: float = 1.0e-3 / float(pixel_size)
    canvas: Float[Array, " h w"] = jnp.full(
        (image_size, image_size), background, dtype=jnp.float32
    )
    margin: int = image_size // 40
    usable_size: int = image_size - 2 * margin
    spacing_h: int = max(5, image_size // 200)

    def estimate_group_size(group: int, ppm: float) -> Tuple[int, int]:
        """Estimate group size without generating full pattern."""
        bar_width: int = get_bar_width_pixels(group, 1, ppm)
        bar_length: int = 5 * bar_width
        elem_height: int = 5 * bar_width
        elem_width: int = (
            bar_length + max(1, int(bar_width * 0.5)) + 5 * bar_width
        )
        elem_spacing: int = max(2, int(elem_width * 0.2))
        col_height: int = 3 * elem_height + 2 * elem_spacing
        col_gap: int = max(2, int(elem_width * 0.4))
        total_width: int = 2 * elem_width + col_gap
        return col_height, total_width

    # The lowest group number is the physically largest pattern. If it
    # does not fit, shrink the pixel density for every group alike so
    # the relative sizes between groups are preserved.
    largest_h, largest_w = estimate_group_size(
        min(groups_list), pixels_per_mm
    )
    effective_ppm: float = pixels_per_mm
    if largest_h > usable_size or largest_w > usable_size:
        scale: float = (
            min(usable_size / largest_h, usable_size / largest_w) * 0.95
        )
        effective_ppm = pixels_per_mm * scale

    # Pack groups into rows; each entry is (group, est_height, est_width).
    rows: list[list[Tuple[int, int, int]]] = []
    current_row: list[Tuple[int, int, int]] = []
    current_x: int = margin
    for group in groups_list:
        gh_est, gw_est = estimate_group_size(group, effective_ppm)
        # Never wrap an empty row, so an oversized group still gets one.
        if current_row and current_x + gw_est > image_size - margin:
            rows.append(current_row)
            current_row = []
            current_x = margin
        current_row.append((group, gh_est, gw_est))
        current_x += gw_est + spacing_h
    # groups_list is non-empty, so the last row always holds a group.
    rows.append(current_row)

    row_heights: list[int] = [max(gh for _, gh, _ in row) for row in rows]
    # Spread the leftover height evenly above, between and below rows.
    spacing_v: int = (image_size - sum(row_heights)) // (len(rows) + 1)

    current_y: int = spacing_v
    for row, row_height in zip(rows, row_heights):
        # Whether a row fits vertically does not depend on the group, so
        # decide once per row; a row that does not fit is skipped whole.
        if current_y + row_height <= image_size - spacing_v // 2:
            row_width: int = sum(gw for _, _, gw in row) + spacing_h * (
                len(row) - 1
            )
            current_x = (image_size - row_width) // 2
            for group, _, _ in row:
                group_pattern, _ = create_group_pattern(group, effective_ppm)
                gh: int = int(group_pattern.shape[0])
                gw: int = int(group_pattern.shape[1])
                # Center the group vertically within its row.
                x_pos: int = current_x
                y_pos: int = current_y + (row_height - gh) // 2
                gh_clipped: int = min(gh, image_size - y_pos)
                gw_clipped: int = min(gw, image_size - x_pos)
                if (
                    gh_clipped > 0
                    and gw_clipped > 0
                    and y_pos >= 0
                    and x_pos >= 0
                ):
                    clipped_pattern = group_pattern[:gh_clipped, :gw_clipped]
                    # Map the 0/1 mask onto background/foreground.
                    scaled_pattern: Float[Array, " gh gw"] = (
                        background
                        + clipped_pattern * (foreground - background)
                    )
                    canvas = canvas.at[
                        y_pos : y_pos + gh_clipped,
                        x_pos : x_pos + gw_clipped,
                    ].set(scaled_pattern)
                current_x += gw + spacing_h
        current_y += row_height + spacing_v

    # Recover a 0..1 mask from the canvas to drive the phase; with equal
    # foreground and background there is no contrast to recover.
    if foreground != background:
        normalized_pattern: Float[Array, " h w"] = (canvas - background) / (
            foreground - background
        )
    else:
        normalized_pattern = jnp.zeros_like(canvas)
    phase_pattern: Float[Array, " h w"] = normalized_pattern * max_phase
    complex_field = canvas.astype(jnp.complex64) * jnp.exp(
        1j * phase_pattern.astype(jnp.complex64)
    )
    sample: SampleFunction = make_sample_function(
        complex_field, dx_calculated
    )
    return sample