# Source code for unite.continuum.functions

"""JAX-jitted continuum evaluation kernels.

All functions are pure JAX with no numpyro dependency and are designed to
be called from within :func:`jax.jit`-compiled model code.
"""

from __future__ import annotations

from functools import partial
from typing import Final

import jax.numpy as jnp
from astropy import units as u
from astropy.constants import (
    c as _c,  # ty: ignore[unresolved-import]
    h as _h,  # ty: ignore[unresolved-import]
    k_B as _k_B,  # ty: ignore[unresolved-import]
)
from jax import Array, jit
from jax.typing import ArrayLike

# ---------------------------------------------------------------------------
# Physical constants
# ---------------------------------------------------------------------------

# Combined constant h*c / k_B, converted by astropy to micron * Kelvin and
# stripped of its unit (``.value``).  Dividing it by a temperature in Kelvin
# and a wavelength in microns therefore gives the dimensionless exponent
# ``hc / (lambda * k_B * T)`` used by the Planck function below.  It is
# evaluated once, at import time.
_HC_KB: Final[float] = ((_c * _h) / _k_B).to(u.um * u.K).value

# ---------------------------------------------------------------------------
# Planck blackbody
# ---------------------------------------------------------------------------


@jit
def planck_function(
    wavelength_micron: ArrayLike, temperature_k: ArrayLike, pivot_micron: float
) -> Array:
    """Return the normalized Planck function ``B_λ(T) / B_λ(pivot, T)``.

    Returns the blackbody spectral radiance normalized to unity at
    *pivot_micron*, so the fitted amplitude directly represents the
    observed flux at the pivot wavelength.

    Parameters
    ----------
    wavelength_micron : ArrayLike
        Rest-frame wavelengths in microns.
    temperature_k : ArrayLike
        Temperature in Kelvin.
    pivot_micron : float
        Normalization wavelength in microns.

    Returns
    -------
    Array
        Normalized Planck function (= 1 at *pivot_micron*).

    Notes
    -----
    Physical constants are pre-combined to avoid gradient overflow in JAX
    when differentiating ``exp(hc / λkT)`` with respect to temperature.

    The ratio ``expm1(x_p) / expm1(x)`` is evaluated in the algebraically
    identical form ``exp(x_p - x) * expm1(-x_p) / expm1(-x)``.  For positive
    exponents ``expm1(-x)`` lies in ``(-1, 0)`` and cannot overflow, so the
    Wien tail (short wavelength and/or low temperature) smoothly underflows
    to ``0`` instead of producing ``inf / inf = NaN`` once both exponentials
    exceed the floating-point range.
    """
    hc_kbt = _HC_KB / temperature_k
    x = hc_kbt / wavelength_micron
    x_p = hc_kbt / pivot_micron

    # (e^{x_p} - 1) / (e^{x} - 1)
    #   = e^{x_p - x} * (1 - e^{-x_p}) / (1 - e^{-x})
    #   = e^{x_p - x} * expm1(-x_p) / expm1(-x)
    occupation_ratio = jnp.exp(x_p - x) * (jnp.expm1(-x_p) / jnp.expm1(-x))

    return ((pivot_micron / wavelength_micron) ** 5) * occupation_ratio


# ---------------------------------------------------------------------------
# Chebyshev polynomial
# ---------------------------------------------------------------------------


def chebval(x: ArrayLike, coeffs: ArrayLike) -> Array:
    """Sum a Chebyshev series via ``T_n(x) = cos(n * arccos(x))``.

    Parameters
    ----------
    x : ArrayLike
        Points at which to evaluate the series, normalized to ``[-1, 1]``.
    coeffs : ArrayLike
        Chebyshev coefficients ordered ``[c0, c1, ..., cN]``.

    Returns
    -------
    Array
        Value of the series at every entry of *x*.
    """
    points = jnp.atleast_1d(x)
    weights = jnp.asarray(coeffs)

    # Polynomial orders 0..N, kept in the dtype of the evaluation points.
    orders = jnp.arange(len(weights), dtype=points.dtype)

    # One angle per evaluation point.
    angles = jnp.acos(points)

    # Row n holds T_n evaluated at every point: shape (N + 1, len(points)).
    cheb_table = jnp.cos(orders[:, None] * angles[None, :])

    # Contract the coefficient axis to obtain the series value per point.
    return jnp.dot(weights, cheb_table)


# ---------------------------------------------------------------------------
# Bernstein polynomial and B-spline
# ---------------------------------------------------------------------------


@jit
def bernstein_eval(
    x: ArrayLike, coeffs: ArrayLike, binom_coeffs: ArrayLike
) -> Array:
    """Evaluate a Bernstein polynomial series using a vectorized basis matrix.

    Parameters
    ----------
    x : ArrayLike
        Evaluation points, must be normalized to the range [0, 1].
        Shape: (N,).
    coeffs : ArrayLike
        Bernstein coefficients (control points). Shape: (n + 1,).
    binom_coeffs : ArrayLike
        Pre-computed binomial coefficients for degree n, where
        ``binom_coeffs[i] = C(n, i)``. Shape: (n + 1,).

    Returns
    -------
    Array
        The evaluated polynomial values at each point in x. Shape: (N,).
    """
    # Normalize every input to an array first so that plain Python sequences
    # (which reach a jitted function as lists of tracers) are accepted, as
    # the ArrayLike contract promises.
    x = jnp.atleast_1d(jnp.asarray(x))
    coeffs = jnp.asarray(coeffs)
    binom_coeffs = jnp.asarray(binom_coeffs)

    # The degree is fixed by the (static) number of coefficients.
    n = coeffs.shape[0] - 1
    i = jnp.arange(n + 1)

    # Basis matrix B[k, i] = C(n, i) * x_k**i * (1 - x_k)**(n - i), built by
    # broadcasting the points (as a column) against the exponents.
    # Resulting shape: (len(x), n + 1).
    x_col = x[:, None]
    basis = binom_coeffs * (x_col**i) * ((1.0 - x_col) ** (n - i))

    return basis @ coeffs
def bspline_basis(t: ArrayLike, knots: ArrayLike, degree: int) -> Array:
    """Compute the B-spline basis matrix via iterative Cox-de Boor recursion.

    The Python loop over *degree* is unrolled at JAX trace time because
    *degree* is a concrete ``int``, not a traced value.

    Parameters
    ----------
    t : ArrayLike
        Evaluation points, shape ``(N,)``.  Points outside the knot range
        are clamped to the nearest boundary knot.
    knots : ArrayLike
        Clamped knot vector, shape ``(M,)``.
    degree : int
        Spline degree (e.g. 3 for cubic).

    Returns
    -------
    Array
        Basis matrix, shape ``(N, n_basis)`` where
        ``n_basis = M - degree - 1``.

    Notes
    -----
    The right boundary ``t == knots[-1]`` is handled by closing the last
    non-degenerate knot interval on the right.  Subtracting a fixed epsilon
    from the last knot is not reliable: a value such as ``1e-14`` is below
    float32 resolution (and below float64 resolution for large knot values),
    which leaves the boundary point outside every half-open interval and
    produces an all-zero basis row.

    Zero-width knot spans are replaced by a dummy denominator *before* the
    division, so that neither the value nor the gradient (with respect to
    *t* or *knots*) picks up NaNs from an unselected ``where`` branch.
    """
    t = jnp.asarray(t)
    knots = jnp.asarray(knots)
    n_knots = len(knots)

    # Clamp evaluation points into the knot range.
    k_last = knots[-1]
    t = jnp.clip(t, knots[0], k_last)
    t_col = t[:, None]

    # Degree-0 basis: indicators of the half-open spans [k_i, k_{i+1}) ...
    span_lo = knots[None, :-1]
    span_hi = knots[None, 1:]
    inside = (t_col >= span_lo) & (t_col < span_hi)
    # ... plus the right end point, assigned to the last span of non-zero
    # width (the only span with span_lo < span_hi == k_last).
    on_right_edge = (t_col == k_last) & (span_hi == k_last) & (span_lo < span_hi)
    basis = jnp.where(inside | on_right_edge, 1.0, 0.0)

    # Cox-de Boor recursion, raising the degree by one per pass.
    for d in range(1, degree + 1):
        n_basis = n_knots - d - 1

        k_left = knots[:n_basis]
        k_right = knots[d + 1 : d + 1 + n_basis]

        # Widths of the supports of the two lower-degree basis functions.
        dt_left = knots[d : d + n_basis] - k_left
        dt_right = k_right - knots[1 : 1 + n_basis]

        # Repeated knots give zero widths; divide by 1.0 there and let the
        # outer ``where`` zero the term (the usual 0/0 := 0 convention).
        ok_left = dt_left > 0
        ok_right = dt_right > 0
        left_w = jnp.where(
            ok_left, (t_col - k_left) / jnp.where(ok_left, dt_left, 1.0), 0.0
        )
        right_w = jnp.where(
            ok_right, (k_right - t_col) / jnp.where(ok_right, dt_right, 1.0), 0.0
        )

        # Update basis: linear combination of the lower-degree bases.
        basis = left_w * basis[:, :n_basis] + right_w * basis[:, 1 : n_basis + 1]

    return basis
@partial(jit, static_argnums=(3,))
def bspline_eval(
    wavelength: ArrayLike, coeffs: ArrayLike, knots: ArrayLike, degree: int
) -> Array:
    """Evaluate a B-spline continuum at the given wavelengths.

    Builds the basis matrix with :func:`bspline_basis` and contracts it
    with the coefficient vector.

    Parameters
    ----------
    wavelength : ArrayLike
        Wavelengths at which to evaluate the continuum, shape ``(N,)``.
    coeffs : ArrayLike
        One coefficient per basis function, shape ``(n_basis,)``.
    knots : ArrayLike
        Clamped knot vector.
    degree : int
        Degree of the spline; marked static so JIT specializes on it.

    Returns
    -------
    Array
        Continuum flux at each wavelength, shape ``(N,)``.
    """
    # Rows: wavelengths; columns: basis functions.
    design_matrix = bspline_basis(wavelength, knots, degree)
    flux = design_matrix @ coeffs
    return flux