"""Bessel functions for JAX.
Extended Summary
----------------
Differentiable Bessel functions written in JAX for use throughout janssen.
Routine Listings
----------------
bessel_j0 : function
Compute J_0(x), regular Bessel function of the first kind, order 0.
bessel_jn : function
Compute J_n(x), regular Bessel function of the first kind, order n.
bessel_iv_series : function
Compute I_v(x) using series expansion for Bessel function.
bessel_k0_series : function
Compute K_0(x) using series expansion.
bessel_kn_recurrence : function
Compute K_n(x) using recurrence relation.
bessel_kv_small_non_integer : function
Compute K_v(x) for small x and non-integer v.
bessel_kv_small_integer : function
Compute K_v(x) for small x and integer v.
bessel_kv : function
Compute K_v(x), modified Bessel function of the second kind.
bessel_k_half : function
Compute K_{1/2}(x) = sqrt(pi/(2x)) * exp(-x) in closed form.
"""
import jax
import jax.numpy as jnp
from beartype import beartype
from beartype.typing import Tuple
from jaxtyping import Array, Bool, Float, Int, jaxtyped
from janssen.types import ScalarFloat, ScalarInteger
[docs]
@jax.jit
@jaxtyped(typechecker=beartype)
def bessel_j0(x: Float[Array, " ..."]) -> Float[Array, " ..."]:
    r"""Compute J_0(x), regular Bessel function of the first kind, order 0.

    Parameters
    ----------
    x : Float[Array, "..."]
        Input array.

    Returns
    -------
    Float[Array, " ..."]
        Values of J_0(x).

    Notes
    -----
    This is a wrapper around JAX's scipy implementation for consistency.
    JAX's ``bessel_jn`` divides by ``x`` internally, so x == 0 is handled
    separately with the exact value J_0(0) = 1. At those positions the
    input to ``bessel_jn`` is replaced by a harmless placeholder (1.0), so
    the discarded value cannot make the gradient NaN ("double-where"
    pattern). The derivative at 0 is therefore 0, matching
    J_0'(0) = -J_1(0) = 0. The function is differentiable, JIT-compatible,
    and supports broadcasting.

    Examples
    --------
    >>> x = jnp.linspace(0, 10, 100)
    >>> j0_vals = bessel_j0(x)
    """
    at_zero: Bool[Array, " ..."] = x == 0.0
    # Keep the branch that jnp.where discards finite; otherwise its NaN
    # partial derivatives leak into the gradient as 0 * NaN = NaN.
    safe_x: Float[Array, " ..."] = jnp.where(at_zero, 1.0, x)
    result: Float[Array, " ..."] = jax.scipy.special.bessel_jn(safe_x, v=0)[0]
    j0_at_zero: ScalarFloat = 1.0
    return jnp.where(at_zero, j0_at_zero, result)
[docs]
@jaxtyped(typechecker=beartype)
def bessel_jn(
    n: ScalarInteger, x: Float[Array, " ..."]
) -> Float[Array, " ..."]:
    r"""Compute J_n(x), regular Bessel function of the first kind, order n.

    Parameters
    ----------
    n : ScalarInteger
        Order of the Bessel function (integer; negative orders are
        accepted). Must be a concrete, non-traced value, because JAX's
        ``bessel_jn`` needs a static order.
    x : Float[Array, "..."]
        Input array.

    Returns
    -------
    Float[Array, " ..."]
        Values of J_n(x).

    Notes
    -----
    This is a wrapper around JAX's scipy implementation for consistency.

    - Negative orders use the reflection J_{-n}(x) = (-1)^n J_n(x).
    - JAX's ``bessel_jn`` divides by ``x`` internally. At x == 0 the call
      is made on a placeholder input (1.0), and the output is replaced by
      the first-order Taylor expansion at the origin: 1 for n = 0, +/-x/2
      for |n| = 1, and 0 otherwise. This gives the exact value at 0 and
      also the exact derivative there, instead of a NaN gradient.
    - The function is differentiable and supports broadcasting. The order
      ``n`` must be a compile-time constant (not a traced value) when used
      inside JIT-compiled functions.

    Examples
    --------
    >>> x = jnp.linspace(0, 10, 100)
    >>> j1_vals = bessel_jn(1, x)
    >>> j2_vals = bessel_jn(2, x)
    """
    # int() works for Python ints and concrete JAX/NumPy integer scalars; a
    # traced order raises a ConcretizationTypeError, as bessel_jn would.
    order: int = int(n)
    abs_order: int = abs(order)
    # J_{-n} = (-1)^n J_n: only odd negative orders flip sign.
    sign: float = -1.0 if (order < 0 and abs_order % 2 == 1) else 1.0
    at_zero: Bool[Array, " ..."] = x == 0.0
    # Double-where: evaluate bessel_jn at a safe point where x == 0 so the
    # discarded branch cannot inject NaN into the gradient.
    safe_x: Float[Array, " ..."] = jnp.where(at_zero, 1.0, x)
    result: Float[Array, " ..."] = (
        sign * jax.scipy.special.bessel_jn(safe_x, v=abs_order)[abs_order]
    )
    value_at_zero: float = 1.0 if order == 0 else 0.0
    slope_at_zero: float = 0.5 * sign if abs_order == 1 else 0.0
    near_zero: Float[Array, " ..."] = value_at_zero + slope_at_zero * x
    return jnp.where(at_zero, near_zero, result)
@jaxtyped(typechecker=beartype)
def bessel_iv_series(
    v_order: ScalarFloat, x_val: Float[Array, " ..."], dtype: jnp.dtype
) -> Float[Array, " ..."]:
    """Evaluate the modified Bessel function I_v(x) by its power series.

    Computes the first 20 terms of

        I_v(x) = (x/2)^v * sum_k (x^2/4)^k / (k! * Gamma(k + v + 1)).

    The series is intended for small x; ``bessel_kv`` only selects its
    output for x <= 2.

    Parameters
    ----------
    v_order : ScalarFloat
        Order v of the Bessel function.
    x_val : Float[Array, " ..."]
        Evaluation points.
    dtype : jnp.dtype
        Floating dtype used for the term index array.

    Returns
    -------
    Float[Array, " ..."]
        Truncated-series approximation of I_v(x).

    Notes
    -----
    When Gamma(v + 1) is not finite (negative integer v), every
    non-finite Gamma value is swapped for 1.0 and the final result is
    forced to 0.0. This keeps NaN out of gradients of ``jnp.where``
    branches that are not selected. ``bessel_kv`` only relies on this
    function for non-integer v, so the substitution does not affect
    selected outputs.
    """
    gamma = jax.scipy.special.gamma

    def _finite_or_one(values):
        # Replace inf/NaN Gamma values with a neutral divisor.
        return jnp.where(jnp.isfinite(values), values, 1.0)

    n_terms: int = 20
    k: Float[Array, " 20"] = jnp.arange(n_terms, dtype=dtype)

    lead_gamma: Float[Array, ""] = gamma(v_order + 1)
    lead_is_finite: Bool[Array, ""] = jnp.isfinite(lead_gamma)
    lead_gamma_safe: Float[Array, ""] = _finite_or_one(lead_gamma)
    term_gammas_safe: Float[Array, " 20"] = _finite_or_one(
        gamma(k + v_order + 1)
    )

    denominators: Float[Array, " 20"] = (
        jax.scipy.special.factorial(k) * term_gammas_safe / lead_gamma_safe
    )
    quarter_x_sq: Float[Array, " ..."] = (x_val * x_val) / 4.0
    numerators: Float[Array, " ... 20"] = jnp.power(
        quarter_x_sq[..., jnp.newaxis], k
    )
    series_sum: Float[Array, " ..."] = jnp.sum(numerators / denominators, axis=-1)

    prefactor: Float[Array, " ..."] = jnp.power(x_val / 2.0, v_order)
    raw: Float[Array, " ..."] = prefactor / lead_gamma_safe * series_sum
    return jnp.where(lead_is_finite, raw, 0.0)
@jaxtyped(typechecker=beartype)
def bessel_k0_series(
    x: Float[Array, " ..."],
) -> Float[Array, " ..."]:
    """Compute K_0(x) for small x with a log-plus-polynomial approximation.

    Evaluates

        K_0(x) ~= -ln(x/2) * I_0(x) + sum_{k=0}^{6} c_k * (x/2)^(2k)

    using the coefficients of Abramowitz & Stegun 9.8.5 (intended for
    0 < x <= 2).

    Parameters
    ----------
    x : Float[Array, " ..."]
        Positive evaluation points.

    Returns
    -------
    Float[Array, " ..."]
        Approximate K_0(x), in the dtype of ``x``.

    Notes
    -----
    The coefficients are Python floats (weakly typed), so the result keeps
    the dtype of ``x``. The previous hard-coded ``float64`` coefficient
    array was silently truncated when x64 is disabled. When x64 is
    enabled, it promoted float32 inputs to float64, which disagreed with
    the other K_v branches that work in ``x.dtype``.
    """
    coeffs: Tuple[float, ...] = (
        -0.57721566,
        0.42278420,
        0.23069756,
        0.03488590,
        0.00262698,
        0.00010750,
        0.00000740,
    )
    i0: Float[Array, " ..."] = jax.scipy.special.i0(x)
    t: Float[Array, " ..."] = (x * x) / 4.0
    # Horner evaluation of sum_k coeffs[k] * t**k; after the first step
    # `poly` is an array with the dtype of `x`.
    poly = coeffs[-1]
    for coeff in reversed(coeffs[:-1]):
        poly = poly * t + coeff
    return poly - jnp.log(x / 2.0) * i0
@jaxtyped(typechecker=beartype)
def bessel_kn_recurrence(
    n: ScalarInteger,
    x: Float[Array, " ..."],
    k0: Float[Array, " ..."],
    k1: Float[Array, " ..."],
    *,
    max_order: int = 20,
) -> Float[Array, " ..."]:
    """Compute K_n(x) from K_0(x) and K_1(x) by upward recurrence.

    Applies K_{i+1}(x) = K_{i-1}(x) + (2 i / x) * K_i(x).

    Uses ``lax.scan`` instead of ``while_loop`` to support reverse-mode
    autodiff. The scan always runs ``max_order`` steps, but the carried
    pair only advances while ``i < |n|`` and is frozen at
    (K_{|n|-1}, K_{|n|}) afterwards. Orders above the target are therefore
    never formed. For small x they would overflow to inf, and reverse-mode
    autodiff would then turn the zero cotangent of those unused orders
    into 0 * inf = NaN in the gradient with respect to ``x``.

    Parameters
    ----------
    n : ScalarInteger
        Target order; may be a traced integer. Since K_{-n} = K_n, only
        |n| is used. Orders above ``max_order + 1`` are clamped to
        ``max_order + 1`` (the same clamping the earlier table lookup had).
    x : Float[Array, " ..."]
        Positive evaluation points.
    k0 : Float[Array, " ..."]
        K_0(x) at the same points.
    k1 : Float[Array, " ..."]
        K_1(x) at the same points.
    max_order : int, optional
        Static number of recurrence steps (default 20, i.e. |n| <= 21).

    Returns
    -------
    Float[Array, " ..."]
        K_{|n|}(x), broadcast over ``x``, ``k0`` and ``k1``.
    """
    # The scan carry must keep a fixed shape and dtype between steps, so
    # align the inputs up front.
    x_c, k0_c, k1_c = jnp.broadcast_arrays(x, k0, k1)
    common_dtype = jnp.result_type(x_c, k0_c, k1_c)
    x_c = x_c.astype(common_dtype)
    k0_c = k0_c.astype(common_dtype)
    k1_c = k1_c.astype(common_dtype)
    n_abs: Int[Array, ""] = jnp.abs(jnp.asarray(n))

    def scan_body(
        carry: Tuple[Float[Array, " ..."], Float[Array, " ..."]],
        i: Int[Array, ""],
    ) -> Tuple[Tuple[Float[Array, " ..."], Float[Array, " ..."]], None]:
        k_prev, k_curr = carry
        two_i_over_x = 2.0 * jnp.asarray(i, common_dtype) / x_c
        k_next = two_i_over_x * k_curr + k_prev
        # Advance (K_{i-1}, K_i) -> (K_i, K_{i+1}) only below the target.
        advance = i < n_abs
        new_carry = (
            jnp.where(advance, k_curr, k_prev),
            jnp.where(advance, k_next, k_curr),
        )
        return new_carry, None

    indices: Int[Array, " max_order"] = jnp.arange(
        1, max_order + 1, dtype=jnp.int32
    )
    (_, k_final), _ = jax.lax.scan(scan_body, (k0_c, k1_c), indices)
    # The carry starts at (K_0, K_1), so |n| == 0 needs K_0 explicitly.
    return jnp.where(n_abs == 0, k0_c, k_final)
@jaxtyped(typechecker=beartype)
def bessel_kv_small_non_integer(
    v: ScalarFloat, x: Float[Array, " ..."], dtype: jnp.dtype
) -> Float[Array, " ..."]:
    """Evaluate K_v(x) for small x and non-integer v via the I_v relation.

    Uses K_v(x) = pi / (2 sin(pi v)) * (I_{-v}(x) - I_v(x)), with both
    I terms taken from ``bessel_iv_series``.

    Parameters
    ----------
    v : ScalarFloat
        Non-integer order.
    x : Float[Array, " ..."]
        Evaluation points.
    dtype : jnp.dtype
        Forwarded to ``bessel_iv_series``.

    Returns
    -------
    Float[Array, " ..."]
        K_v(x). The value is 0.0 when |sin(pi v)| <= 1e-10.

    Notes
    -----
    For integer v, sin(pi v) vanishes. The divisor is then replaced by 1.0
    and the output is zeroed, so an unselected ``jnp.where`` branch cannot
    produce NaN gradients. ``bessel_kv`` only uses this result for
    non-integer v.
    """
    sin_pi_v: Float[Array, ""] = jnp.sin(jnp.pi * v)
    usable: Bool[Array, ""] = jnp.abs(sin_pi_v) > jnp.asarray(1e-10)
    denominator: Float[Array, ""] = 2.0 * jnp.where(usable, sin_pi_v, 1.0)
    reflected: Float[Array, " ..."] = bessel_iv_series(
        -v, x, dtype
    ) - bessel_iv_series(v, x, dtype)
    return jnp.where(usable, (jnp.pi / denominator) * reflected, 0.0)
@jaxtyped(typechecker=beartype)
def bessel_kv_small_integer(
    v: Float[Array, ""],
    x: Float[Array, " ..."],
) -> Float[Array, " ..."]:
    """Compute K_v(x) for small x and integer v.

    K_0 comes from ``bessel_k0_series``. K_1 comes from its power series
    (DLMF 10.31.1 with n = 1):

        K_1(x) = 1/x + sum_k (x/2)^(2k+1) / (k! (k+1)!)
                       * [ln(x/2) + gamma_E - (H_k + H_{k+1}) / 2],

    where H_k is the k-th harmonic number. The sum runs over k = 0..5;
    terms beyond k = 5 are below about 4e-6 relative at x = 2. Higher
    orders are obtained with ``bessel_kn_recurrence``.

    Parameters
    ----------
    v : Float[Array, ""]
        Integer-valued order (rounded here). Because K_{-n} = K_n, only
        |round(v)| is used.
    x : Float[Array, " ..."]
        Positive evaluation points.

    Returns
    -------
    Float[Array, " ..."]
        K_{|round(v)|}(x).
    """
    n: Int[Array, ""] = jnp.abs(jnp.round(v)).astype(jnp.int32)
    k0: Float[Array, " ..."] = bessel_k0_series(x)

    euler_gamma: Float[Array, ""] = jnp.array(
        0.57721566490153286060, dtype=x.dtype
    )
    base_log: Float[Array, " ..."] = jnp.log(x / 2.0) + euler_gamma

    x2: Float[Array, " ..."] = x * x
    x3: Float[Array, " ..."] = x2 * x
    x5: Float[Array, " ..."] = x3 * x2
    x7: Float[Array, " ..."] = x5 * x2
    x9: Float[Array, " ..."] = x7 * x2
    x11: Float[Array, " ..."] = x9 * x2
    # Denominators are 2^(2k+1) k! (k+1)!; shifts are (H_k + H_{k+1}) / 2.
    k1: Float[Array, " ..."] = (
        1.0 / x
        + 0.5 * x * (base_log - 0.5)
        + x3 * (base_log - 1.25) / 16.0
        + x5 * (base_log - 1.6666666666666667) / 384.0
        + x7 * (base_log - 1.9583333333333333) / 18432.0
        + x9 * (base_log - 2.1833333333333333) / 1474560.0
        + x11 * (base_log - 2.3666666666666667) / 176947200.0
    )
    # K_{-n} = K_n, so no sign correction is needed for negative v.
    return bessel_kn_recurrence(n, x, k0, k1)
def _bessel_kv_large(
    v: ScalarFloat, x: Float[Array, " ..."]
) -> Float[Array, " ..."]:
    """Large-argument asymptotic expansion of K_v(x).

    Uses

        K_v(x) ~ sqrt(pi / (2x)) * exp(-x) * sum_{k=0}^{4} a_k / x^k,

    with mu = 4 v^2 and
    a_k = (mu - 1)(mu - 9)...(mu - (2k-1)^2) / (k! * 8^k).
    The polynomial in 1/x is evaluated in nested (Horner) form.

    Parameters
    ----------
    v : ScalarFloat
        Order of the Bessel function.
    x : Float[Array, " ..."]
        Evaluation points (intended for x > 2 by ``bessel_kv``).

    Returns
    -------
    Float[Array, " ..."]
        Asymptotic approximation of K_v(x).
    """
    mu: Float[Array, ""] = 4.0 * (v * v)
    f1 = mu - 1.0
    f9 = mu - 9.0
    f25 = mu - 25.0
    f49 = mu - 49.0

    c1 = f1 / 8.0
    c2 = f1 * f9 / (2.0 * 64.0)
    c3 = f1 * f9 * f25 / (6.0 * 512.0)
    c4 = f1 * f9 * f25 * f49 / (24.0 * 4096.0)

    inv_x: Float[Array, " ..."] = 1.0 / x
    correction: Float[Array, " ..."] = 1.0 + inv_x * (
        c1 + inv_x * (c2 + inv_x * (c3 + inv_x * c4))
    )
    return jnp.sqrt(jnp.pi / (2.0 * x)) * jnp.exp(-x) * correction
@jaxtyped(typechecker=beartype)
def bessel_k_half(x: Float[Array, " ..."]) -> Float[Array, " ..."]:
    """Evaluate the closed form K_{1/2}(x) = sqrt(pi / (2x)) * exp(-x).

    Parameters
    ----------
    x : Float[Array, " ..."]
        Positive evaluation points.

    Returns
    -------
    Float[Array, " ..."]
        Exact K_{1/2}(x).
    """
    prefactor: Float[Array, " ..."] = jnp.sqrt(jnp.pi / (2.0 * x))
    return prefactor * jnp.exp(-x)
[docs]
@jax.jit
@jaxtyped(typechecker=beartype)
def bessel_kv(v: ScalarFloat, x: Float[Array, " ..."]) -> Float[Array, " ..."]:
    r"""Compute the modified Bessel function of the second kind K_v(x).

    Parameters
    ----------
    v : ScalarFloat
        Order of the Bessel function. Since K_{-v}(x) = K_v(x), negative
        orders are handled by using |v|.
    x : Float[Array, "..."]
        Positive real input array.

    Returns
    -------
    Float[Array, " ..."]
        Approximated values of K_v(x).

    Notes
    -----
    Computes K_v(x) for real order v and x > 0, using a numerically stable
    and differentiable JAX-compatible approximation.

    - Valid for x > 0; the order enters only through |v|
    - Supports broadcasting and autodiff
    - JIT-safe and VMAP-safe
    - Uses series expansions for small x (x <= 2.0) and an asymptotic
      expansion for large x; the transition point is x = 2.0
    - For non-integer v (small x), uses the reflection formula:
      K_v = π/(2sin(πv)) * (I_{-v} - I_v)
    - For integer v (small x), uses series for K_0 and K_1 and an upward
      recurrence for higher orders
    - Special exact formula for |v| = 0.5: K_{1/2}(x) = sqrt(π/(2x)) * exp(-x)
    - Every branch is evaluated for every element, and ``jnp.where`` then
      selects one. Each branch is therefore fed a substitute argument
      inside its own regime (1.0 for the small-x branch, 2.0 for the
      large-x branch) wherever it will be discarded. Otherwise overflow in
      a discarded branch (e.g. the series for large x, or 1/x**4 for tiny
      x) would produce 0 * inf = NaN gradients.
    """
    # K_{-v} = K_v: only the magnitude of the order matters.
    v: Float[Array, ""] = jnp.abs(jnp.asarray(v))
    x: Float[Array, " ..."] = jnp.asarray(x)
    dtype: jnp.dtype = x.dtype
    epsilon_tolerance: float = 1e-10
    small_x_threshold: float = 2.0

    is_integer: Bool[Array, ""] = (
        jnp.abs(v - jnp.round(v)) < epsilon_tolerance
    )
    is_half: Bool[Array, ""] = jnp.abs(v - 0.5) < epsilon_tolerance
    is_small: Bool[Array, " ..."] = x <= small_x_threshold

    # Double-where: keep each branch's input inside its own regime so the
    # branch that is not selected stays finite (and so does its gradient).
    x_small: Float[Array, " ..."] = jnp.where(is_small, x, 1.0)
    x_large: Float[Array, " ..."] = jnp.where(is_small, small_x_threshold, x)

    small_x_non_int: Float[Array, " ..."] = bessel_kv_small_non_integer(
        v, x_small, dtype
    )
    small_x_int: Float[Array, " ..."] = bessel_kv_small_integer(v, x_small)
    small_x_vals: Float[Array, " ..."] = jnp.where(
        is_integer, small_x_int, small_x_non_int
    )
    large_x_vals: Float[Array, " ..."] = _bessel_kv_large(v, x_large)
    general_result: Float[Array, " ..."] = jnp.where(
        is_small, small_x_vals, large_x_vals
    )

    k_half_vals: Float[Array, " ..."] = bessel_k_half(x)
    return jnp.where(is_half, k_half_vals, general_result)