"""Continuum functional forms: abstract base and concrete implementations."""
from __future__ import annotations
from abc import ABC, abstractmethod
from pathlib import Path
from typing import TYPE_CHECKING, cast, override
import jax
import jax.numpy as jnp
import numpy as np
from astropy import units as u
from jax import Array
from jax.typing import ArrayLike
from unite._utils import _ensure_wavelength, _get_conversion_factor, _make_register
from unite.continuum.functions import (
bernstein_eval,
bspline_eval,
chebval,
planck_function,
)
from unite.prior import Fixed, Prior, Uniform
if TYPE_CHECKING:
pass
def _gaussian_convolve_poly(coeffs: Array, lsf_fwhm: ArrayLike) -> Array:
"""Analytically convolve a polynomial with a Gaussian LSF.
Given a polynomial ``p(x) = Σ c_i x^(n-i)`` (NumPy descending-order
convention, degree *n* = ``len(coeffs) - 1``) and a Gaussian kernel with
FWHM ``lsf_fwhm``, return the coefficients of the convolved polynomial
``(p * G)(x)``.
A polynomial convolved with a Gaussian is still a polynomial of the same
degree. Only even Gaussian moments contribute (odd moments vanish by
symmetry), giving the coefficient update:
.. code-block:: text
c_j_new = sum_{k=0,2,4,...}^{N-j} c_{j+k} * C(j+k, k) * (k-1)!! * sigma^k
where ``(k-1)!!`` is the double factorial (the *k*-th even moment of
a standard normal). See the
:doc:`polynomial derivation </derivations/polynomial>` for a full derivation.
For a monomial ``x^k`` convolved with ``N(0, s^2)``, the result is::
sum_{j=0}^{floor(k/2)} C(k, 2j) * (2j-1)!! * s^{2j} * x^{k-2j}
Parameters
----------
coeffs : Array, shape ``(n+1,)``
Polynomial coefficients in **descending** order
(i.e. ``coeffs[0]`` is the leading coefficient for ``x^n``).
lsf_fwhm : ArrayLike
Scalar (or broadcastable) LSF FWHM in the same unit as *x*.
Returns
-------
Array, shape ``(n+1,)``
Convolved polynomial coefficients, same descending-order convention.
"""
sigma2 = (jnp.asarray(lsf_fwhm) / (2.0 * np.sqrt(2.0 * np.log(2.0)))) ** 2
n = coeffs.shape[0] - 1 # polynomial degree; static at trace time
max_half = n // 2 + 1
# Even Gaussian moments M[j] = (2j-1)!! * sigma^{2j}.
# M[0] = 1; M[j] = M[j-1] * (2j-1) * sigma^2 for j >= 1.
if max_half > 1:
def _moment_step(carry, j):
cur = carry * (2 * j - 1) * sigma2
return cur, cur
_, rest = jax.lax.scan(_moment_step, 1.0, jnp.arange(1, max_half))
moments = jnp.concatenate([jnp.array([1.0]), rest])
else:
moments = jnp.ones(1)
# Binomial coefficients C(k, 2j): pure NumPy — depends only on n (static).
binom_np = np.zeros((n + 1, max_half))
for k in range(n + 1):
for j in range(min(k // 2, max_half - 1) + 1):
val = 1.0
for m in range(2 * j):
val = val * (k - m) / (m + 1)
binom_np[k, j] = val
# Static scatter tensor: A[out_idx, i, j] = binom_np[n-i, j] when
# out_idx == i + 2*j (and the entry is in range), else 0.
# out[out_idx] = sum_i sum_j coeffs[i] * A[out_idx, i, j] * moments[j]
# = einsum('oij,j,i->o', A, moments, coeffs).
a_np = np.zeros((n + 1, n + 1, max_half))
for i in range(n + 1):
k = n - i
for j in range(min(k // 2, max_half - 1) + 1):
a_np[i + 2 * j, i, j] = binom_np[k, j]
return jnp.einsum('oij,j,i->o', jnp.asarray(a_np), moments, coeffs)
def _polyint_at(coeffs: Array, x: ArrayLike) -> Array:
"""Antiderivative of a polynomial evaluated at *x*.
Given descending-order coefficients ``[a_n, ..., a_0]``, return the
antiderivative ``P(x) = ∫ p(x') dx' = a_n/(n+1) * x^{n+1} + ... + a_0 * x``
evaluated at each entry of *x* (constant of integration zero).
Used by polynomial-based continuum forms to produce a cumulative-at-edges
array: ``jnp.diff(P_at_edges) / jnp.diff(edges)`` recovers the exact
pixel-averaged value of *p* over each pixel.
Parameters
----------
coeffs : Array, shape ``(n+1,)``
Polynomial coefficients in descending order.
x : ArrayLike
Points at which to evaluate the antiderivative.
Returns
-------
Array
Antiderivative value at each *x*.
"""
# Antiderivative coefficients (also descending), with zero constant term.
n = coeffs.shape[0]
divisors = jnp.arange(n, 0, -1, dtype=coeffs.dtype)
anti = jnp.concatenate([coeffs / divisors, jnp.array([0.0], dtype=coeffs.dtype)])
return jnp.polyval(anti, x)
def _cheb_to_mono_matrix(n: int) -> Array:
"""Build the (n x n) Chebyshev-to-monomial conversion matrix.
Returns a matrix *M* such that ``M @ cheb_coeffs`` gives monomial
coefficients in **ascending** order (``[a0, a1, ..., a_{n-1}]``).
``cheb_coeffs = [c0, c1, ..., c_{n-1}]`` represents
``c0*T0(x) + c1*T1(x) + ... + c_{n-1}*T_{n-1}(x)``.
Built at Python time (not traced), so the matrix is a static constant.
"""
import numpy as np
from numpy.polynomial.chebyshev import cheb2poly
M = np.zeros((n, n)) # noqa: N806
for k in range(n):
basis = np.zeros(n)
basis[k] = 1.0
# cheb2poly returns ascending monomial coefficients for T_k
poly = cheb2poly(basis)
M[: len(poly), k] = poly
return jnp.array(M)
def _bernstein_to_mono_matrix(n: int) -> Array:
"""Build the ((n+1) x (n+1)) Bernstein-to-monomial conversion matrix.
Returns a matrix *M* such that ``M @ bern_coeffs`` gives monomial
coefficients in **ascending** order for the polynomial on ``[0, 1]``.
``bern_coeffs = [b0, b1, ..., b_n]`` represents
``Σ b_i * C(n,i) * t^i * (1-t)^{n-i}``.
"""
import numpy as np
from scipy.special import comb
M = np.zeros((n + 1, n + 1)) # noqa: N806
# Monomial coeff for t^k: Σ_{i=0}^{k} C(n,i) C(n-i, k-i) (-1)^{k-i} b_i
for k in range(n + 1):
for i in range(k + 1):
M[k, i] = (
comb(n, i, exact=True)
* comb(n - i, k - i, exact=True)
* (-1) ** (k - i)
)
return jnp.array(M)
# ---------------------------------------------------------------------------
# Registry
# ---------------------------------------------------------------------------
# Registry of concrete continuum forms, filled in by the ``@_register``
# decorator applied to each class below. The key each class is stored under
# is decided by ``_make_register`` (presumably the class name, matching the
# ``'type'`` entries written by ``to_dict`` — TODO confirm).
# The annotation names ``ContinuumForm`` before that class is defined here;
# this is fine because ``from __future__ import annotations`` leaves
# annotations unevaluated.
_FORM_REGISTRY: dict[str, type[ContinuumForm]] = {}
_register = _make_register(_FORM_REGISTRY)
# ---------------------------------------------------------------------------
# Abstract base class
# ---------------------------------------------------------------------------
# ---------------------------------------------------------------------------
# Deserialization helper
# ---------------------------------------------------------------------------
# ---------------------------------------------------------------------------
# Piecewise / region-local forms
# ---------------------------------------------------------------------------
@_register
class Linear(ContinuumForm):
    """Linear continuum: ``scale + tan(angle) * (wavelength - norm_wav)``.

    The slope is parameterized through its angle, ``slope = tan(angle)``.
    A bounded prior on ``angle`` inside ``(-pi/2, pi/2)`` therefore reaches
    every finite slope.

    This form has no constructor parameters.

    Notes
    -----
    **Model parameters** (sampled with priors, overridable via
    ``ContinuumRegion(params={...})``):

    * ``scale`` — Continuum level at ``norm_wav``.
      Default prior: ``Uniform(0, 2)``.
    * ``angle`` — Slope angle in radians (dimensionless); the slope in flux
      per wavelength unit is ``tan(angle)``.
      Default prior: ``Uniform(-pi/2, pi/2)``.
    * ``norm_wav`` — Reference wavelength where the continuum equals
      ``scale``. Default prior: ``Fixed(region_center)``.

    A straight line is unchanged by convolution with a symmetric, unit-area
    kernel, so ``lsf_fwhm`` has no effect on this form.
    """

    @property
    @override
    def is_linear(self) -> bool:
        # NOTE(review): the model is linear in ``scale`` but ``angle`` enters
        # through tan(); confirm this matches what ``is_linear`` should report.
        return True

    @override
    def param_names(self) -> tuple[str, ...]:
        return ('scale', 'angle', 'norm_wav')

    @override
    def default_priors(self, region_center: float = 1.0) -> dict[str, Prior]:
        return {
            'scale': Uniform(0, 2),
            'angle': Uniform(-np.pi / 2, np.pi / 2),
            'norm_wav': Fixed(region_center),
        }

    @override
    def param_units(
        self, flux_unit: u.UnitBase, wl_unit: u.UnitBase
    ) -> dict[str, tuple[bool, u.UnitBase | None]]:
        return {
            'scale': (True, flux_unit),
            'angle': (False, None),
            'norm_wav': (False, wl_unit),
        }

    @override
    def evaluate(
        self,
        wavelength: ArrayLike,
        center: float,
        params: dict[str, ArrayLike],
        obs_low: float,
        obs_high: float,
        lsf_fwhm: ArrayLike = 0.0,
        z_sys: float = 0.0,
    ) -> Array:
        # ``lsf_fwhm`` is intentionally ignored: a line is invariant under
        # symmetric convolution.
        nw = params['norm_wav']
        return params['scale'] + jnp.tan(params['angle']) * (wavelength - nw)

    @override
    def integrate(
        self,
        edges: ArrayLike,
        center: float,
        params: dict[str, ArrayLike],
        obs_low: float,
        obs_high: float,
        lsf_fwhm: ArrayLike = 0.0,
        z_sys: float = 0.0,
    ) -> Array:
        # Linear is preserved exactly under Gaussian convolution, so the LSF
        # is a no-op. With slope = tan(angle), the antiderivative of
        # ``scale + slope * (λ - nw)`` is
        # ``scale * (λ - nw) + slope * (λ - nw)² / 2`` (integration constant
        # zero); ``jnp.diff`` of this then recovers the exact pixel integral.
        nw = params['norm_wav']
        x = jnp.asarray(edges) - nw
        slope = jnp.tan(params['angle'])
        return jnp.asarray(params['scale']) * x + 0.5 * slope * x * x

    @override
    def to_dict(self) -> dict:
        return {'type': 'Linear'}

    @classmethod
    @override
    def from_dict(cls, d: dict) -> Linear:
        return cls()
@_register
class Polynomial(ContinuumForm):
    """Polynomial continuum of configurable degree.

    Evaluates ``scale + c1*x + c2*x**2 + ...`` where
    ``x = wavelength - norm_wav``. A Gaussian LSF is applied analytically:
    a polynomial convolved with a Gaussian is again a polynomial of the same
    degree.

    Parameters
    ----------
    degree : int
        Polynomial degree (default 1). Must be ``>= 0``.

    Raises
    ------
    ValueError
        If ``degree`` is negative.

    Notes
    -----
    **Model parameters** (sampled with priors, overridable via
    ``ContinuumRegion(params={...})``):

    * ``scale`` — Continuum level at ``norm_wav``.
      Default prior: ``Uniform(0, 2)``.
    * ``c1, c2, ...`` — Coefficient of ``x**i`` (unit flux / wavelength**i).
      Default prior: ``Uniform(-10, 10)`` each.
    * ``norm_wav`` — Reference wavelength.
      Default prior: ``Fixed(region_center)``.
    """

    def __init__(self, degree: int = 1) -> None:
        if degree < 0:
            msg = f'Polynomial degree must be >= 0, got {degree}'
            raise ValueError(msg)
        self._degree = degree

    @property
    @override
    def is_linear(self) -> bool:
        return True

    @property
    def degree(self) -> int:
        """Polynomial degree."""
        return self._degree

    @override
    def param_names(self) -> tuple[str, ...]:
        # For degree 0 the generator is empty, giving ('scale', 'norm_wav').
        return ('scale', *(f'c{i}' for i in range(1, self._degree + 1)), 'norm_wav')

    @override
    def default_priors(self, region_center: float = 1.0) -> dict[str, Prior]:
        priors: dict[str, Prior] = {'scale': Uniform(0, 2)}
        for i in range(1, self._degree + 1):
            priors[f'c{i}'] = Uniform(-10, 10)
        priors['norm_wav'] = Fixed(region_center)
        return priors

    @override
    def param_units(
        self, flux_unit: u.UnitBase, wl_unit: u.UnitBase
    ) -> dict[str, tuple[bool, u.UnitBase | None]]:
        d: dict[str, tuple[bool, u.UnitBase | None]] = {'scale': (True, flux_unit)}
        for i in range(1, self._degree + 1):
            d[f'c{i}'] = (True, flux_unit / wl_unit**i)
        d['norm_wav'] = (False, wl_unit)
        return d

    def _mono_coeffs(self, params: dict[str, ArrayLike]) -> Array:
        """Monomial coefficients in descending order: ``[c_n, ..., c_1, scale]``."""
        return jnp.array(
            [params[f'c{i}'] for i in range(self._degree, 0, -1)] + [params['scale']]
        )

    @override
    def evaluate(
        self,
        wavelength: ArrayLike,
        center: float,
        params: dict[str, ArrayLike],
        obs_low: float,
        obs_high: float,
        lsf_fwhm: ArrayLike = 0.0,
        z_sys: float = 0.0,
    ) -> Array:
        x = wavelength - params['norm_wav']
        convolved = _gaussian_convolve_poly(self._mono_coeffs(params), lsf_fwhm)
        return jnp.polyval(convolved, x)

    @override
    def integrate(
        self,
        edges: ArrayLike,
        center: float,
        params: dict[str, ArrayLike],
        obs_low: float,
        obs_high: float,
        lsf_fwhm: ArrayLike = 0.0,
        z_sys: float = 0.0,
    ) -> Array:
        # Cumulative antiderivative (in x = λ - norm_wav, so dλ = dx) of the
        # LSF-convolved polynomial at each edge.
        x = jnp.asarray(edges) - params['norm_wav']
        convolved = _gaussian_convolve_poly(self._mono_coeffs(params), lsf_fwhm)
        return _polyint_at(convolved, x)

    @override
    def to_dict(self) -> dict:
        return {'type': 'Polynomial', 'degree': self._degree}

    @classmethod
    @override
    def from_dict(cls, d: dict) -> Polynomial:
        return cls(degree=d['degree'])

    @override
    def __repr__(self) -> str:
        return f'Polynomial(degree={self._degree})'
def _normalize_wavelength(
wavelength: ArrayLike,
center: float,
obs_low: float,
obs_high: float,
stretch: float,
) -> tuple[Array, float]:
"""Map observed wavelengths to the normalised coordinate used by Chebyshev/Bernstein.
Both forms normalize ``wavelength`` to the interval ``[-1, 1]`` (for
Chebyshev) or ``[0, 1]`` (for Bernstein) using the same scale factor
``(obs_high - obs_low) / 2 * stretch``. This helper computes the
shared scale factor and the intermediate ``u = (w - center) / scale``
coordinate; callers apply the form-specific final transformation.
Returns ``(u, scale_factor)`` where ``u = (wavelength - center) / scale_factor``.
"""
scale_factor = (obs_high - obs_low) / 2 * stretch
u = (jnp.asarray(wavelength) - center) / scale_factor
return u, scale_factor
@_register
class Chebyshev(ContinuumForm):
    """Chebyshev polynomial continuum of configurable order.

    Evaluates a Chebyshev series on coordinates normalized to ``[-1, 1]``
    within the continuum region, normalized so that the continuum equals
    ``scale`` at ``norm_wav``. Numerically more stable than a standard
    polynomial basis for higher orders.

    The x-coordinate is ``(wavelength - center) / (half_width * stretch)``.
    Here ``half_width = (obs_high - obs_low) / 2`` comes from the region
    bounds passed to :meth:`evaluate`. ``stretch`` is a form-specific scaling
    factor (default ``1.0`` for identity normalization).

    The continuum is parameterized as ``scale * T(x) / T(x_nw)``. ``T`` is the
    Chebyshev series with its constant term fixed at 1.0, and ``x_nw`` is the
    normalized coordinate at ``norm_wav``. A Gaussian LSF is applied
    analytically to ``T(x)`` through its monomial form. The normalization
    ``T(x_nw)`` is evaluated without the LSF, so the continuum equals ``scale``
    at ``norm_wav`` exactly only when ``lsf_fwhm == 0``.

    Parameters
    ----------
    order : int
        Chebyshev order (default 2). Number of coefficients = order + 1.
    stretch : float
        Stretch factor to scale the region normalization (default 1.0).

    Raises
    ------
    ValueError
        If ``order < 0`` or ``stretch <= 0``.

    Notes
    -----
    **Model parameters** (sampled with priors, overridable via
    ``ContinuumRegion(params={...})``):

    * ``scale`` — Continuum level at ``norm_wav``.
      Default prior: ``Uniform(0, 2)``.
    * ``c1, c2, ...`` — Higher-order Chebyshev coefficients (relative to the
      constant term, fixed at 1.0). Default prior: ``Uniform(-10, 10)`` each.
    * ``norm_wav`` — Reference wavelength.
      Default prior: ``Fixed(region_center)``.
    """

    def __init__(self, order: int = 2, stretch: float = 1.0) -> None:
        if order < 0:
            msg = f'Chebyshev order must be >= 0, got {order}'
            raise ValueError(msg)
        self._order = order
        if stretch <= 0:
            msg = f'Chebyshev stretch factor must be > 0, got {stretch}'
            raise ValueError(msg)
        self._stretch = stretch
        # Static Chebyshev-to-monomial conversion matrix (ascending output).
        self._cheb2mono = _cheb_to_mono_matrix(order + 1)

    @property
    @override
    def is_linear(self) -> bool:
        return False

    @property
    def order(self) -> int:
        """Chebyshev order."""
        return self._order

    @property
    def stretch(self) -> float:
        """Stretch factor for the region normalization."""
        return self._stretch

    @override
    def param_names(self) -> tuple[str, ...]:
        # For order 0 the generator is empty, giving ('scale', 'norm_wav').
        return ('scale', *(f'c{i}' for i in range(1, self._order + 1)), 'norm_wav')

    @override
    def default_priors(self, region_center: float = 1.0) -> dict[str, Prior]:
        priors: dict[str, Prior] = {'scale': Uniform(0, 2)}
        for i in range(1, self._order + 1):
            priors[f'c{i}'] = Uniform(-10, 10)
        priors['norm_wav'] = Fixed(region_center)
        return priors

    @override
    def param_units(
        self, flux_unit: u.UnitBase, wl_unit: u.UnitBase
    ) -> dict[str, tuple[bool, u.UnitBase | None]]:
        # x is normalised to [-1, 1], so all coefficients have unit flux_unit.
        d: dict[str, tuple[bool, u.UnitBase | None]] = {'scale': (True, flux_unit)}
        for i in range(1, self._order + 1):
            d[f'c{i}'] = (True, flux_unit)
        d['norm_wav'] = (False, wl_unit)
        return d

    def _shape_terms(
        self,
        wavelength: ArrayLike,
        center: float,
        params: dict[str, ArrayLike],
        obs_low: float,
        obs_high: float,
        lsf_fwhm: ArrayLike,
    ) -> tuple[Array, Array, Array, float]:
        """Set-up shared by :meth:`evaluate` and :meth:`integrate`.

        Returns ``(x, convolved, shape_nw, scale_factor)``:
        - ``x``: the normalized coordinate of *wavelength*.
        - ``convolved``: the LSF-convolved series as descending monomial
          coefficients in ``x``.
        - ``shape_nw``: the unconvolved series evaluated at ``norm_wav``.
        - ``scale_factor``: the number of wavelength units per unit of ``x``.
        """
        x, scale_factor = _normalize_wavelength(
            wavelength, center, obs_low, obs_high, self._stretch
        )
        x_nw, _ = _normalize_wavelength(
            params['norm_wav'], center, obs_low, obs_high, self._stretch
        )
        cheb_coeffs = jnp.array(
            [1.0] + [params[f'c{i}'] for i in range(1, self._order + 1)]
        )
        mono = (self._cheb2mono @ cheb_coeffs)[::-1]  # ascending -> descending
        # Express the FWHM in units of the normalized coordinate x.
        lsf_fwhm_scaled = jnp.asarray(lsf_fwhm) / scale_factor
        convolved = _gaussian_convolve_poly(mono, lsf_fwhm_scaled)
        shape_nw = chebval(x_nw, cheb_coeffs)
        return x, convolved, shape_nw, scale_factor

    @override
    def evaluate(
        self,
        wavelength: ArrayLike,
        center: float,
        params: dict[str, ArrayLike],
        obs_low: float,
        obs_high: float,
        lsf_fwhm: ArrayLike = 0.0,
        z_sys: float = 0.0,
    ) -> Array:
        x, convolved, shape_nw, _ = self._shape_terms(
            wavelength, center, params, obs_low, obs_high, lsf_fwhm
        )
        shape = jnp.polyval(convolved, x)
        return params['scale'] * shape / shape_nw

    @override
    def integrate(
        self,
        edges: ArrayLike,
        center: float,
        params: dict[str, ArrayLike],
        obs_low: float,
        obs_high: float,
        lsf_fwhm: ArrayLike = 0.0,
        z_sys: float = 0.0,
    ) -> Array:
        x, convolved, shape_nw, scale_factor = self._shape_terms(
            edges, center, params, obs_low, obs_high, lsf_fwhm
        )
        # Antiderivative in normalised coord; rescale to λ by dλ/dx = scale_factor.
        shape_anti = _polyint_at(convolved, x) * scale_factor
        return params['scale'] * shape_anti / shape_nw

    @override
    def to_dict(self) -> dict:
        return {'type': 'Chebyshev', 'order': self._order, 'stretch': self._stretch}

    @classmethod
    @override
    def from_dict(cls, d: dict) -> Chebyshev:
        return cls(order=d['order'], stretch=d.get('stretch', 1.0))

    @override
    def __repr__(self) -> str:
        return f'Chebyshev(order={self._order}, stretch={self._stretch})'
@_register
class BSpline(ContinuumForm):
    """B-spline continuum with local knot control.

    The knot vector must be set via *knots* at construction time (typically
    derived from the wavelength coverage of the spectrum). Knots should be
    in the same wavelength unit as the :class:`ContinuumRegion` bounds.
    :meth:`_prepare` converts them to the region's unit and builds the
    clamped knot vector. Knots must lie strictly inside the region bounds.

    The continuum is normalized so that it equals ``scale`` at ``norm_wav``.
    It is parameterized as ``scale * S(u) / S(u_nw)``, where ``S`` is the
    B-spline series with its first coefficient fixed at 1.0 and ``u_nw`` is
    the normalized coordinate at ``norm_wav``.

    Parameters
    ----------
    knots : u.Quantity
        Interior knot vector in wavelength units. It is automatically clamped
        at the region bounds.
    degree : int
        Spline degree (default 3 for cubic). Must be ``>= 0``.

    Raises
    ------
    ValueError
        If *knots* is not an astropy Quantity, or if ``degree < 0``.

    Notes
    -----
    **Model parameters** (sampled with priors, overridable via
    ``ContinuumRegion(params={...})``):

    * ``scale`` — Continuum level at ``norm_wav``.
      Default prior: ``Uniform(0, 2)``.
    * ``coeff_1, coeff_2, …`` — Remaining B-spline coefficients (relative to
      the first, fixed at 1.0). Default prior: ``Uniform(-10, 10)`` each.
    * ``norm_wav`` — Reference wavelength.
      Default prior: ``Fixed(region_center)``.
    """

    def __init__(self, knots: u.Quantity, degree: int = 3) -> None:
        if not isinstance(knots, u.Quantity):
            raise ValueError(
                f'knots must be an astropy Quantity with length units, got {knots}'
            )
        self._knots: u.Quantity = _ensure_wavelength(knots, 'knots', ndim=1)
        if degree < 0:
            msg = f'BSpline degree must be >= 0, got {degree}'
            raise ValueError(msg)
        self._degree = degree
        # With degree + 1 repeated knots at each end, the number of basis
        # functions is (len(knots) + 2 * (degree + 1)) - (degree + 1).
        self._n_basis = len(self._knots) + degree + 1
        # Clamped knot vector (in the region's unit); set by ``_prepare``.
        self._knots_eval: Array | None = None

    @property
    @override
    def is_linear(self) -> bool:
        return False

    @property
    def degree(self) -> int:
        """Spline degree."""
        return self._degree

    @property
    def n_basis(self) -> int:
        """Number of B-spline basis functions."""
        return self._n_basis

    @property
    def knots(self) -> u.Quantity:
        """Knot vector (in the original units passed at construction)."""
        return self._knots

    @override
    def param_names(self) -> tuple[str, ...]:
        # For a single basis function the generator is empty.
        return ('scale', *(f'coeff_{i}' for i in range(1, self._n_basis)), 'norm_wav')

    @override
    def default_priors(self, region_center: float = 1.0) -> dict[str, Prior]:
        priors: dict[str, Prior] = {'scale': Uniform(0, 2)}
        for i in range(1, self._n_basis):
            priors[f'coeff_{i}'] = Uniform(-10, 10)
        priors['norm_wav'] = Fixed(region_center)
        return priors

    @override
    def param_units(
        self, flux_unit: u.UnitBase, wl_unit: u.UnitBase
    ) -> dict[str, tuple[bool, u.UnitBase | None]]:
        # B-spline coefficients share the same unit as the function value.
        d: dict[str, tuple[bool, u.UnitBase | None]] = {'scale': (True, flux_unit)}
        for i in range(1, self._n_basis):
            d[f'coeff_{i}'] = (True, flux_unit)
        d['norm_wav'] = (False, wl_unit)
        return d

    @override
    def evaluate(
        self,
        wavelength: ArrayLike,
        center: float,
        params: dict[str, ArrayLike],
        obs_low: float,
        obs_high: float,
        lsf_fwhm: ArrayLike = 0.0,
        z_sys: float = 0.0,
    ) -> Array:
        # LSF convolution is not supported for BSpline (non-polynomial basis).
        knots_eval = getattr(self, '_knots_eval', None)
        if knots_eval is None:
            msg = (
                'BSpline knots have not been prepared; _prepare(low, high) must be '
                'called with the region bounds before evaluate()'
            )
            raise RuntimeError(msg)
        # Map wavelengths, norm_wav and knots into the same [-1, 1] coordinate
        # spanning [obs_low, obs_high].
        obs_range = obs_high - obs_low
        x = 2 * (wavelength - obs_low) / obs_range - 1
        x_nw = 2 * (params['norm_wav'] - obs_low) / obs_range - 1
        knots_norm = 2 * (knots_eval - obs_low) / obs_range - 1
        shape_coeffs = jnp.concatenate(
            [jnp.array([1.0])]
            + [jnp.atleast_1d(params[f'coeff_{i}']) for i in range(1, self._n_basis)]
        )
        shape = bspline_eval(x, shape_coeffs, knots_norm, self._degree)
        shape_nw = bspline_eval(
            jnp.atleast_1d(x_nw), shape_coeffs, knots_norm, self._degree
        )[0]
        return params['scale'] * shape / shape_nw

    @override
    def _prepare(self, low: u.Quantity, high: u.Quantity) -> None:
        """Validate the knots against the region bounds and build the clamped knot vector.

        Raises
        ------
        ValueError
            If any knot lies on or outside ``[low, high]``.
        """
        # Element-wise ``|`` (not ``or``): ``or`` on multi-element arrays
        # raises "truth value of an array is ambiguous".
        outside = (self._knots <= low) | (self._knots >= high)
        if np.any(outside):
            msg = f'All knots must be within the region bounds [{low}, {high}], got {self._knots}'
            raise ValueError(msg)
        # Clamp: repeat each bound degree + 1 times (B-spline convention).
        # Everything is expressed in ``low``'s unit, including ``high``.
        p = self._degree + 1
        knots_clamped = jnp.concatenate(
            [
                low.value * jnp.ones(p),
                self._knots.to(low.unit).value,
                high.to_value(low.unit) * jnp.ones(p),
            ]
        )
        self._knots_eval = knots_clamped

    @override
    def to_dict(self) -> dict:
        return {
            'type': 'BSpline',
            'knots': self._knots.value.tolist(),
            'unit': str(self._knots.unit),
            'degree': self._degree,
        }

    @classmethod
    @override
    def from_dict(cls, d: dict) -> BSpline:
        return cls(knots=d['knots'] * u.Unit(d['unit']), degree=d.get('degree', 3))

    @override
    def __repr__(self) -> str:
        return f'BSpline(degree={self._degree}, knots={self._knots})'
@_register
class Bernstein(ContinuumForm):
    """Global Bernstein polynomial continuum, normalized at a reference wavelength.

    Evaluates a Bernstein series on coordinates normalized to ``[0, 1]``
    within the continuum region, normalized so that the continuum equals
    ``scale`` at ``norm_wav``.

    The wavelength range is derived from the region bounds passed to
    :meth:`evaluate` and normalized to ``[0, 1]`` for the Bernstein basis.
    The ``stretch`` parameter optionally scales the region normalization.

    The continuum is parameterized as ``scale * B(t) / B(t_nw)``, where ``B``
    is the Bernstein series with its first term fixed at 1.0 and ``t_nw`` is
    the normalized coordinate at ``norm_wav``. A Gaussian LSF is applied
    analytically to ``B(t)`` through its monomial form. The normalization
    ``B(t_nw)`` is evaluated without the LSF, so the continuum equals
    ``scale`` at ``norm_wav`` exactly only when ``lsf_fwhm == 0``.

    Parameters
    ----------
    degree : int
        Polynomial degree (default 4). Number of coefficients = degree + 1.
    stretch : float
        Stretch factor to scale the region normalization (default 1.0).

    Raises
    ------
    ValueError
        If ``degree < 0`` or ``stretch <= 0``.

    Notes
    -----
    **Model parameters** (sampled with priors, overridable via
    ``ContinuumRegion(params={...})``):

    * ``scale`` — Continuum level at ``norm_wav``.
      Default prior: ``Uniform(0, 2)``.
    * ``coeff_1, coeff_2, …`` — Remaining Bernstein coefficients (relative to
      the first term, fixed at 1.0). Default prior: ``Uniform(0, 10)`` each.
    * ``norm_wav`` — Reference wavelength.
      Default prior: ``Fixed(region_center)``.
    """

    def __init__(self, degree: int = 4, stretch: float = 1.0) -> None:
        from scipy.special import comb

        # Same validation as Polynomial/Chebyshev; a negative degree would
        # otherwise yield an empty conversion matrix and fail at trace time.
        if degree < 0:
            msg = f'Bernstein degree must be >= 0, got {degree}'
            raise ValueError(msg)
        self._degree = degree
        if stretch <= 0:
            msg = f'Bernstein stretch factor must be > 0, got {stretch}'
            raise ValueError(msg)
        self._stretch = stretch
        # Binomial weights C(degree, i) used by bernstein_eval.
        self._binom = jnp.array(
            [comb(degree, i, exact=True) for i in range(degree + 1)], dtype=float
        )
        # Static Bernstein-to-monomial conversion matrix (ascending output).
        self._bern2mono = _bernstein_to_mono_matrix(degree)

    @property
    @override
    def is_linear(self) -> bool:
        return False

    @property
    def degree(self) -> int:
        """Polynomial degree."""
        return self._degree

    @property
    def stretch(self) -> float:
        """Stretch factor for the region normalization."""
        return self._stretch

    @override
    def param_names(self) -> tuple[str, ...]:
        # For degree 0 the generator is empty, giving ('scale', 'norm_wav').
        return (
            'scale',
            *(f'coeff_{i}' for i in range(1, self._degree + 1)),
            'norm_wav',
        )

    @override
    def default_priors(self, region_center: float = 1.0) -> dict[str, Prior]:
        priors: dict[str, Prior] = {'scale': Uniform(0, 2)}
        for i in range(1, self._degree + 1):
            priors[f'coeff_{i}'] = Uniform(0, 10)
        priors['norm_wav'] = Fixed(region_center)
        return priors

    @override
    def param_units(
        self, flux_unit: u.UnitBase, wl_unit: u.UnitBase
    ) -> dict[str, tuple[bool, u.UnitBase | None]]:
        # Bernstein coefficients share the same unit as the function value.
        d: dict[str, tuple[bool, u.UnitBase | None]] = {'scale': (True, flux_unit)}
        for i in range(1, self._degree + 1):
            d[f'coeff_{i}'] = (True, flux_unit)
        d['norm_wav'] = (False, wl_unit)
        return d

    def _shape_terms(
        self,
        wavelength: ArrayLike,
        center: float,
        params: dict[str, ArrayLike],
        obs_low: float,
        obs_high: float,
        lsf_fwhm: ArrayLike,
    ) -> tuple[Array, Array, Array, float]:
        """Set-up shared by :meth:`evaluate` and :meth:`integrate`.

        Returns ``(t, convolved, shape_nw, scale_factor)``:
        - ``t``: the ``[0, 1]`` Bernstein coordinate of *wavelength*.
        - ``convolved``: the LSF-convolved series as descending monomial
          coefficients in ``t``.
        - ``shape_nw``: the unconvolved series evaluated at ``norm_wav``.
        - ``scale_factor``: as returned by ``_normalize_wavelength``.
        """
        x, scale_factor = _normalize_wavelength(
            wavelength, center, obs_low, obs_high, self._stretch
        )
        x_nw, _ = _normalize_wavelength(
            params['norm_wav'], center, obs_low, obs_high, self._stretch
        )
        # x ∈ [-1, 1]  ->  t = (x + 1) / 2 ∈ [0, 1] for the Bernstein basis.
        t = (x + 1) / 2
        t_nw = (x_nw + 1) / 2
        coeffs = jnp.concatenate(
            [jnp.array([1.0])]
            + [jnp.atleast_1d(params[f'coeff_{i}']) for i in range(1, self._degree + 1)]
        )
        mono = (self._bern2mono @ coeffs)[::-1]  # ascending -> descending
        # LSF FWHM in t-coordinate: dt/dλ = 1 / (2 * scale_factor).
        lsf_fwhm_scaled = jnp.asarray(lsf_fwhm) / (2.0 * scale_factor)
        convolved = _gaussian_convolve_poly(mono, lsf_fwhm_scaled)
        shape_nw = bernstein_eval(jnp.atleast_1d(t_nw), coeffs, self._binom)
        return t, convolved, shape_nw, scale_factor

    @override
    def evaluate(
        self,
        wavelength: ArrayLike,
        center: float,
        params: dict[str, ArrayLike],
        obs_low: float,
        obs_high: float,
        lsf_fwhm: ArrayLike = 0.0,
        z_sys: float = 0.0,
    ) -> Array:
        t, convolved, shape_nw, _ = self._shape_terms(
            wavelength, center, params, obs_low, obs_high, lsf_fwhm
        )
        shape = jnp.polyval(convolved, t)
        return cast(Array, params['scale'] * shape / shape_nw)

    @override
    def integrate(
        self,
        edges: ArrayLike,
        center: float,
        params: dict[str, ArrayLike],
        obs_low: float,
        obs_high: float,
        lsf_fwhm: ArrayLike = 0.0,
        z_sys: float = 0.0,
    ) -> Array:
        t, convolved, shape_nw, scale_factor = self._shape_terms(
            edges, center, params, obs_low, obs_high, lsf_fwhm
        )
        # Antiderivative in t-coord; rescale to λ via dλ/dt = 2 * scale_factor.
        shape_anti = _polyint_at(convolved, t) * (2.0 * scale_factor)
        return params['scale'] * shape_anti / shape_nw

    @override
    def to_dict(self) -> dict:
        return {'type': 'Bernstein', 'degree': self._degree, 'stretch': self._stretch}

    @classmethod
    @override
    def from_dict(cls, d: dict) -> Bernstein:
        return cls(degree=d['degree'], stretch=d.get('stretch', 1.0))

    @override
    def __repr__(self) -> str:
        return f'Bernstein(degree={self._degree}, stretch={self._stretch})'
@_register
class PowerLaw(ContinuumForm):
    """Power-law continuum: ``scale * (wavelength / norm_wav) ** beta``.

    This form has no constructor parameters.

    To share a consistent reference wavelength across multiple regions
    (required for physically meaningful parameter sharing), pass a
    :class:`~unite.continuum.config.ContinuumNormalizationWavelength` with
    ``Fixed(value)`` carrying your chosen reference wavelength.

    LSF convolution is not applied by this form; ``lsf_fwhm`` is ignored.

    Notes
    -----
    **Model parameters** (sampled with priors, overridable via
    ``ContinuumRegion(params={...})``):

    * ``scale`` — Continuum level at ``norm_wav``.
      Default prior: ``Uniform(0, 2)``.
    * ``beta`` — Power-law index (dimensionless).
      Default prior: ``Uniform(-5, 5)``.
    * ``norm_wav`` — Reference wavelength.
      Default prior: ``Fixed(region_center)``.
    """

    @override
    def param_names(self) -> tuple[str, ...]:
        return ('scale', 'beta', 'norm_wav')

    @override
    def default_priors(self, region_center: float = 1.0) -> dict[str, Prior]:
        return {
            'scale': Uniform(0, 2),
            'beta': Uniform(-5, 5),
            'norm_wav': Fixed(region_center),
        }

    @override
    def param_units(
        self, flux_unit: u.UnitBase, wl_unit: u.UnitBase
    ) -> dict[str, tuple[bool, u.UnitBase | None]]:
        return {
            'scale': (True, flux_unit),
            'beta': (False, None),
            'norm_wav': (False, wl_unit),
        }

    @override
    def evaluate(
        self,
        wavelength: ArrayLike,
        center: float,
        params: dict[str, ArrayLike],
        obs_low: float,
        obs_high: float,
        lsf_fwhm: ArrayLike = 0.0,
        z_sys: float = 0.0,
    ) -> Array:
        # LSF convolution is not supported for PowerLaw.
        nw = params['norm_wav']
        return cast(Array, params['scale'] * (wavelength / nw) ** params['beta'])

    @override
    def integrate(
        self,
        edges: ArrayLike,
        center: float,
        params: dict[str, ArrayLike],
        obs_low: float,
        obs_high: float,
        lsf_fwhm: ArrayLike = 0.0,
        z_sys: float = 0.0,
    ) -> Array:
        """Cumulative integral of the power law at *edges*, zero at ``edges[0]``.

        ``jnp.diff`` of the result divided by ``jnp.diff(edges)`` gives the
        pixel-averaged continuum. ``lsf_fwhm`` is ignored, as in
        :meth:`evaluate`.
        """
        edges_arr = jnp.asarray(edges)
        nw = params['norm_wav']
        bp1 = params['beta'] + 1.0
        # Antiderivative of (w/nw)^beta written relative to norm_wav:
        #   F(w) = nw * ((w/nw)^bp1 - 1) / bp1 = nw * expm1(bp1 * L) / bp1,
        # with L = ln(w/nw). As bp1 -> 0 (beta = -1) it tends to nw * L.
        # This differs from w^bp1 / (bp1 * nw^beta) only by a constant, which
        # cancels below. It avoids that form's catastrophic cancellation for
        # small |bp1| and its overflow for large w**bp1.
        log_ratio = jnp.log(edges_arr / nw)
        use_power = jnp.abs(bp1) > 1e-10
        # Double-where: hand the unused branch a safe denominator so its 0/0
        # cannot turn gradients into NaN at bp1 == 0.
        safe_bp1 = jnp.where(use_power, bp1, 1.0)
        anti_power = nw * jnp.expm1(safe_bp1 * log_ratio) / safe_bp1
        anti_log = nw * log_ratio
        anti = jnp.where(use_power, anti_power, anti_log)
        return cast(Array, params['scale'] * (anti - anti[0]))

    @override
    def to_dict(self) -> dict:
        return {'type': 'PowerLaw'}

    @classmethod
    @override
    def from_dict(cls, d: dict) -> PowerLaw:
        return cls()
@_register
class Blackbody(ContinuumForm):
    """Planck blackbody continuum normalized at a reference wavelength.

    Evaluates ``scale * B_λ(T) / B_λ(norm_wav, T)``, so *scale* directly
    represents the continuum flux at ``norm_wav``. Wavelength parameters may
    be in any unit. :meth:`_prepare` records the factor that converts the
    region's wavelength unit to microns, and it is applied internally before
    calling the Planck function.

    ``norm_wav`` is a named parameter with a default ``Fixed(region_center)``
    prior. To pin it to a specific wavelength across multiple regions, pass
    an explicit
    :class:`~unite.continuum.config.ContinuumNormalizationWavelength` with
    ``Fixed(value)``. This is essential for physically consistent
    normalization when fitting a single blackbody across disjoint spectral
    windows.

    This form has no constructor parameters. LSF convolution is not applied;
    ``lsf_fwhm`` is ignored.

    Notes
    -----
    **Model parameters** (sampled with priors, overridable via
    ``ContinuumRegion(params={...})``):

    * ``scale`` — Continuum flux at ``norm_wav``
      (in units of ``continuum_scale``).
      Default prior: ``Uniform(0, 2)``.
    * ``temperature`` — Blackbody temperature in Kelvin.
      Default prior: ``Uniform(100, 50000)``.
    * ``norm_wav`` — Reference wavelength.
      Default prior: ``Fixed(region_center)``.
    """

    def __init__(self) -> None:
        # Region-unit -> micron conversion; 1.0 (i.e. microns assumed) until
        # ``_prepare`` is called with the region bounds.
        self._micron_factor: float = 1.0

    @override
    def param_names(self) -> tuple[str, ...]:
        return ('scale', 'temperature', 'norm_wav')

    @override
    def default_priors(self, region_center: float = 1.0) -> dict[str, Prior]:
        return {
            'scale': Uniform(0, 2),
            'temperature': Uniform(100, 50000),
            'norm_wav': Fixed(region_center),
        }

    @override
    def param_units(
        self, flux_unit: u.UnitBase, wl_unit: u.UnitBase
    ) -> dict[str, tuple[bool, u.UnitBase | None]]:
        return {
            'scale': (True, flux_unit),
            'temperature': (False, u.K),
            'norm_wav': (False, wl_unit),
        }

    @override
    def evaluate(
        self,
        wavelength: ArrayLike,
        center: float,
        params: dict[str, ArrayLike],
        obs_low: float,
        obs_high: float,
        lsf_fwhm: ArrayLike = 0.0,
        z_sys: float = 0.0,
    ) -> Array:
        # LSF convolution is not supported for Blackbody.
        wl_um = wavelength * self._micron_factor
        nw_um = params['norm_wav'] * self._micron_factor
        # planck_function receives norm_wav and, per the class contract,
        # returns B_λ(T) normalized to its value at norm_wav.
        bb = planck_function(wl_um, params['temperature'], nw_um)
        return params['scale'] * bb

    @override
    def _prepare(self, low: u.Quantity, high: u.Quantity) -> None:
        """Compute the micron conversion factor for the region's wavelength unit."""
        self._micron_factor = _get_conversion_factor(low.unit, u.um)

    @override
    def to_dict(self) -> dict:
        return {'type': 'Blackbody'}

    @classmethod
    @override
    def from_dict(cls, d: dict) -> Blackbody:
        return cls()
[docs]
@_register
class ModifiedBlackbody(ContinuumForm):
    """Modified blackbody: ``scale * B_λ(T) * (λ / norm_wav)^beta / B_λ(nw, T)``.

    The power-law factor ``(λ / norm_wav)^beta`` tilts the SED relative to
    a pure blackbody: *beta > 0* raises the flux redward of ``norm_wav``
    (and lowers it blueward), *beta < 0* does the opposite, and
    *beta = 0* recovers :class:`Blackbody`. The factor equals 1 at
    ``norm_wav``, so *scale* is still the flux there. Wavelength
    parameters may be in any unit; automatic unit conversion to microns is
    applied internally. LSF convolution is not applied.

    ``norm_wav`` is a named parameter with a default
    ``Fixed(region_center)`` prior. Share a
    :class:`~unite.continuum.config.ContinuumNormalizationWavelength` token
    across regions to enforce a consistent reference wavelength.

    This form has no constructor parameters.

    Notes
    -----
    **Model parameters** (sampled with priors, overridable via
    ``ContinuumRegion(params={...})``):

    * ``scale`` — Continuum flux at ``norm_wav``
      (in units of ``continuum_scale``).
      Default prior: ``Uniform(0, 2)``.
    * ``temperature`` — Blackbody temperature in Kelvin.
      Default prior: ``Uniform(100, 50000)``.
    * ``beta`` — Power-law modifier index (dimensionless).
      Default prior: ``Uniform(-4, 4)``.
    * ``norm_wav`` — Reference wavelength.
      Default prior: ``Fixed(region_center)``.
    """

    def __init__(self) -> None:
        # Identity factor until _prepare() sees the region's wavelength unit.
        self._micron_factor: float = 1.0

    @override
    def param_names(self) -> tuple[str, ...]:
        return ('scale', 'temperature', 'beta', 'norm_wav')

    @override
    def default_priors(self, region_center: float = 1.0) -> dict[str, Prior]:
        return {
            'scale': Uniform(0, 2),
            'temperature': Uniform(100, 50000),
            'beta': Uniform(-4, 4),
            'norm_wav': Fixed(region_center),
        }

    @override
    def param_units(
        self, flux_unit: u.UnitBase, wl_unit: u.UnitBase
    ) -> dict[str, tuple[bool, u.UnitBase | None]]:
        return {
            'scale': (True, flux_unit),
            'temperature': (False, u.K),
            'beta': (False, None),
            'norm_wav': (False, wl_unit),
        }

    @override
    def evaluate(
        self,
        wavelength: ArrayLike,
        center: float,
        params: dict[str, ArrayLike],
        obs_low: float,
        obs_high: float,
        lsf_fwhm: ArrayLike = 0.0,
        z_sys: float = 0.0,
    ) -> Array:
        """Evaluate the modified blackbody at *wavelength* (region units).

        *center*, *obs_low*, *obs_high*, *lsf_fwhm* and *z_sys* are accepted
        for interface compatibility but unused.
        """
        # LSF convolution is not supported for ModifiedBlackbody.
        wl_um = wavelength * self._micron_factor
        nw_um = params['norm_wav'] * self._micron_factor
        bb = planck_function(wl_um, params['temperature'], nw_um)
        modifier = (wl_um / nw_um) ** params['beta']
        return params['scale'] * bb * modifier

    @override
    def _prepare(self, low: u.Quantity, high: u.Quantity) -> None:
        """Compute the micron conversion factor for the region's wavelength unit."""
        self._micron_factor = _get_conversion_factor(low.unit, u.um)

    @override
    def to_dict(self) -> dict:
        return {'type': 'ModifiedBlackbody'}

    @classmethod
    @override
    def from_dict(cls, d: dict) -> ModifiedBlackbody:
        return cls()
[docs]
@_register
class AttenuatedBlackbody(ContinuumForm):
    """Dust-attenuated blackbody continuum.

    Evaluates
    ``scale * B_λ(T) / B_λ(nw,T) * exp(-tau_v * [(λ/lambda_ext)^alpha - (nw/lambda_ext)^alpha])``.

    Extinction is normalized at ``norm_wav`` so that *scale*
    represents the **observed** (attenuated) flux there. Negative *alpha*
    gives steeper extinction at short wavelengths (typical dust law).
    Wavelength parameters may be in any unit; automatic unit conversion
    to microns is applied internally. LSF convolution is not applied.

    Parameters
    ----------
    lambda_ext : astropy.units.Quantity
        Reference wavelength for the extinction law. Must be a scalar
        (``ndim=0``) :class:`~astropy.units.Quantity` with any length
        unit — it is converted to microns automatically. Defaults to
        ``5500 * u.AA``.

    Raises
    ------
    ValueError
        If *lambda_ext* is not an astropy Quantity. Further validation is
        delegated to ``_ensure_wavelength``.

    Notes
    -----
    ``lambda_ext`` is a *static* configuration parameter (not sampled).
    It is converted to microns once, at construction. :meth:`_prepare`
    only sets the factor that converts evaluation wavelengths and
    ``norm_wav`` from the region's wavelength unit to microns.

    **Model parameters** (sampled with priors, overridable via
    ``ContinuumRegion(params={...})``):

    * ``scale`` — Observed continuum flux at ``norm_wav``
      (in units of ``continuum_scale``).
      Default prior: ``Uniform(0, 2)``.
    * ``temperature`` — Blackbody temperature in Kelvin.
      Default prior: ``Uniform(100, 50000)``.
    * ``tau_v`` — Optical depth at ``lambda_ext``.
      Default prior: ``Uniform(0, 5)``.
    * ``alpha`` — Dust extinction power-law index (negative = steeper at
      short λ). Default prior: ``Uniform(-2, 0)``.
    * ``norm_wav`` — Reference wavelength.
      Default prior: ``Fixed(region_center)``.
    """

    def __init__(self, lambda_ext: u.Quantity = 5500 * u.AA) -> None:
        if not isinstance(lambda_ext, u.Quantity):
            raise ValueError(
                f'lambda_ext must be an astropy Quantity with length units, got {lambda_ext}'
            )
        self._lambda_ext: u.Quantity = _ensure_wavelength(
            lambda_ext, 'lambda_ext', ndim=0
        )
        # Extinction pivot in microns, used by evaluate(). Fixed at
        # construction; _prepare() does not modify it.
        self._lambda_ext_um: float = float(self._lambda_ext.to(u.um).value)
        # Region-unit -> micron factor; identity until _prepare() runs.
        self._micron_factor: float = 1.0

    @property
    def lambda_ext(self) -> u.Quantity:
        """Extinction reference wavelength as an astropy Quantity."""
        return self._lambda_ext

    @override
    def param_names(self) -> tuple[str, ...]:
        return ('scale', 'temperature', 'tau_v', 'alpha', 'norm_wav')

    @override
    def default_priors(self, region_center: float = 1.0) -> dict[str, Prior]:
        return {
            'scale': Uniform(0, 2),
            'temperature': Uniform(100, 50000),
            'tau_v': Uniform(0, 5),
            'alpha': Uniform(-2, 0),
            'norm_wav': Fixed(region_center),
        }

    @override
    def param_units(
        self, flux_unit: u.UnitBase, wl_unit: u.UnitBase
    ) -> dict[str, tuple[bool, u.UnitBase | None]]:
        return {
            'scale': (True, flux_unit),
            'temperature': (False, u.K),
            'tau_v': (False, None),
            'alpha': (False, None),
            'norm_wav': (False, wl_unit),
        }

    @override
    def evaluate(
        self,
        wavelength: ArrayLike,
        center: float,
        params: dict[str, ArrayLike],
        obs_low: float,
        obs_high: float,
        lsf_fwhm: ArrayLike = 0.0,
        z_sys: float = 0.0,
    ) -> Array:
        """Evaluate the attenuated blackbody at *wavelength* (region units).

        *center*, *obs_low*, *obs_high*, *lsf_fwhm* and *z_sys* are accepted
        for interface compatibility but unused.
        """
        # LSF convolution is not supported for AttenuatedBlackbody.
        wl_um = wavelength * self._micron_factor
        nw_um = params['norm_wav'] * self._micron_factor
        bb = planck_function(wl_um, params['temperature'], nw_um)
        ext_data = (wl_um / self._lambda_ext_um) ** params['alpha']
        ext_pivot = (nw_um / self._lambda_ext_um) ** params['alpha']
        # Subtracting the pivot term makes the extinction factor 1 at norm_wav.
        extinction = jnp.exp(-params['tau_v'] * (ext_data - ext_pivot))
        return params['scale'] * bb * extinction

    @override
    def _prepare(self, low: u.Quantity, high: u.Quantity) -> None:
        """Compute the micron conversion factor for the region's wavelength unit."""
        self._micron_factor = _get_conversion_factor(low.unit, u.um)

    @override
    def to_dict(self) -> dict:
        return {'type': 'AttenuatedBlackbody', 'lambda_ext': self.lambda_ext}

    @classmethod
    @override
    def from_dict(cls, d: dict) -> AttenuatedBlackbody:
        """Rebuild an instance from :meth:`to_dict` output.

        A missing ``lambda_ext`` falls back to the constructor default
        (``5500 * u.AA``). A bare real number is interpreted as microns.
        """
        lambda_ext = d.get('lambda_ext')
        if lambda_ext is None:
            return cls()
        if isinstance(lambda_ext, (int, float, np.integer, np.floating)):
            # Bare numbers are taken as microns (e.g. 0.55 == 5500 Å).
            lambda_ext = lambda_ext * u.um
        return cls(lambda_ext=lambda_ext)

    @override
    def __repr__(self) -> str:
        return f'AttenuatedBlackbody(lambda_ext={self.lambda_ext})'
[docs]
@_register
class Template(ContinuumForm):
    """Interpolated spectral template continuum.

    Loads one or more template spectra from a file readable by
    :func:`astropy.table.Table.read`. The file must contain a wavelength
    column (identified by its physical unit) and one or more flux columns.
    Each flux column becomes a separate scale parameter named
    ``{column}_scale``.

    The template is evaluated by linearly interpolating each column in the
    rest frame and normalising so that ``{col}_scale`` equals the flux at
    ``norm_wav``:

    .. code-block:: text

        F(λ) = Σ_i {col_i}_scale * T_i(λ_rest) / T_i(norm_wav_rest)

    where ``λ_rest = λ_obs / (1 + z_sys)``.

    Parameters
    ----------
    path : str or Path
        Path to the template file. Any format supported by
        :func:`~astropy.table.Table.read` is accepted (FITS, ECSV, …).
    wavelength_colname : str, optional
        Name of the wavelength column. If omitted, the column whose unit
        has ``physical_type == 'length'`` is used; raises if the result is
        ambiguous.
    usecols : list or tuple of str, optional
        Flux columns to load. Defaults to all non-wavelength columns.
        Raises if any requested column is absent, repeated, or is the
        wavelength column.

    Raises
    ------
    ValueError
        If the wavelength column cannot be identified, has no unit or a
        non-length unit, is empty, contains NaN/inf, or is not strictly
        increasing; if a requested flux column is missing, repeated, or is
        the wavelength column; if no flux columns remain; or if a flux
        column contains NaN/inf.

    Notes
    -----
    The wavelength column must carry an astropy unit with
    ``physical_type == 'length'``. Flux columns without units, or with
    units that are not spectral flux density (f_lambda), produce a
    :class:`UserWarning` but are still accepted.

    **Model parameters** (sampled with priors, overridable via
    ``ContinuumRegion(params={...})``):

    * ``{col}_scale`` — Template amplitude at ``norm_wav``, one per column.
      Default prior: ``Uniform(0, 2)``.
    * ``norm_wav`` — Rest-frame reference wavelength (shared across all
      columns). Default prior: ``Fixed(region_center_rest)``.
    """

    def __init__(
        self,
        path: str | Path,
        *,
        wavelength_colname: str | None = None,
        usecols: list[str] | tuple[str, ...] | None = None,
    ) -> None:
        from astropy.table import Table

        self._path = Path(path)
        self._wavelength_colname = wavelength_colname
        self._usecols_arg = tuple(usecols) if usecols is not None else None

        table = Table.read(self._path)

        wl_col = self._resolve_wavelength_column(
            table, wavelength_colname, self._path
        )
        self._lam_qty: u.Quantity = self._load_wavelength(table, wl_col)
        # Use the stored tuple: a one-shot iterable passed as usecols has
        # already been consumed by tuple() above.
        flux_cols = self._resolve_flux_columns(
            table, wl_col, self._usecols_arg, self._path
        )

        self._flux_cols: list[str] = flux_cols
        self._flam_arrays: dict[str, np.ndarray] = self._load_flux_columns(
            table, flux_cols, self._path
        )

        # set by _prepare()
        self._lam_um: np.ndarray | None = None
        self._rest_low_um: float = 0.0
        self._flam_eval: Array | None = None
        self._lam_eval: Array | None = None

    # ------------------------------------------------------------------
    # Construction helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _resolve_wavelength_column(table, colname: str | None, path: Path) -> str:
        """Return the wavelength column name: explicit, or the unique length-unit column."""
        if colname is not None:
            if colname not in table.colnames:
                msg = (
                    f"wavelength_colname '{colname}' not found in "
                    f'{path}. Available columns: {table.colnames}'
                )
                raise ValueError(msg)
            return colname
        length_cols = [
            c
            for c in table.colnames
            if getattr(table[c], 'unit', None) is not None
            and u.Unit(table[c].unit).is_equivalent(u.m)
        ]
        if not length_cols:
            msg = (
                f'No column with a length unit found in {path}. '
                'Set wavelength_colname explicitly.'
            )
            raise ValueError(msg)
        if len(length_cols) > 1:
            msg = (
                f'Multiple length-unit columns in {path}: {length_cols}. '
                'Set wavelength_colname explicitly.'
            )
            raise ValueError(msg)
        return length_cols[0]

    @staticmethod
    def _load_wavelength(table, wl_col: str) -> u.Quantity:
        """Return the wavelength column as a validated, increasing length Quantity."""
        wl_col_data = table[wl_col]
        wl_unit = getattr(wl_col_data, 'unit', None)
        if wl_unit is None:
            msg = f"Wavelength column '{wl_col}' has no unit."
            raise ValueError(msg)
        if not u.Unit(wl_unit).is_equivalent(u.m):
            msg = (
                f"Wavelength column '{wl_col}' has unit '{wl_unit}', "
                'which is not a length.'
            )
            raise ValueError(msg)
        lam = u.Quantity(np.asarray(wl_col_data), unit=wl_unit)
        if lam.size == 0:
            msg = f"Wavelength column '{wl_col}' is empty."
            raise ValueError(msg)
        # Check finiteness before monotonicity: NaN makes every np.diff
        # comparison False, which would be misreported as non-monotonic.
        if not np.all(np.isfinite(lam.to(u.um).value)):
            msg = f"Wavelength column '{wl_col}' contains NaN or inf values."
            raise ValueError(msg)
        if not np.all(np.diff(lam.value) > 0):
            msg = f"Wavelength column '{wl_col}' is not strictly monotonically increasing."
            raise ValueError(msg)
        return lam

    @staticmethod
    def _resolve_flux_columns(
        table, wl_col: str, usecols: tuple[str, ...] | None, path: Path
    ) -> list[str]:
        """Return the flux column names, validating any explicit *usecols*."""
        from collections import Counter

        if usecols is None:
            flux_cols = [c for c in table.colnames if c != wl_col]
        else:
            missing = [c for c in usecols if c not in table.colnames]
            if missing:
                msg = (
                    f'usecols columns not found in {path}: {missing}. '
                    f'Available columns: {table.colnames}'
                )
                raise ValueError(msg)
            if wl_col in usecols:
                msg = (
                    f"usecols includes the wavelength column '{wl_col}'; "
                    'list only flux columns.'
                )
                raise ValueError(msg)
            # Repeated names would create duplicate '{col}_scale' parameters.
            repeated = [c for c, count in Counter(usecols).items() if count > 1]
            if repeated:
                msg = f'usecols contains repeated columns: {repeated}.'
                raise ValueError(msg)
            flux_cols = list(usecols)
        if not flux_cols:
            msg = f'No flux columns found in {path} after excluding the wavelength column.'
            raise ValueError(msg)
        return flux_cols

    @staticmethod
    def _load_flux_columns(
        table, flux_cols: list[str], path: Path
    ) -> dict[str, np.ndarray]:
        """Warn on non-f_lambda units, reject NaN/inf, and return float arrays."""
        import warnings

        flam_ref = u.Unit('erg s-1 cm-2 AA-1')
        arrays: dict[str, np.ndarray] = {}
        for col in flux_cols:
            col_data = table[col]
            col_unit = getattr(col_data, 'unit', None)
            # stacklevel=3: this helper -> __init__ -> the user's call site.
            if col_unit is None or str(col_unit) in ('', 'None'):
                warnings.warn(
                    f"Template column '{col}' in {path} has no units. "
                    'Assuming f_lambda (spectral flux density per wavelength). '
                    'scale parameters will be in units of continuum_scale.',
                    UserWarning,
                    stacklevel=3,
                )
            elif not u.Unit(col_unit).is_equivalent(flam_ref):
                warnings.warn(
                    f"Template column '{col}' has unit '{col_unit}'. "
                    'Expected f_lambda (spectral flux density per wavelength).',
                    UserWarning,
                    stacklevel=3,
                )
            arr = np.asarray(col_data, dtype=float)
            if not np.all(np.isfinite(arr)):
                msg = f"Template column '{col}' contains NaN or inf values."
                raise ValueError(msg)
            arrays[col] = arr
        return arrays

    # ------------------------------------------------------------------
    # ContinuumForm interface
    # ------------------------------------------------------------------

    @override
    def param_names(self) -> tuple[str, ...]:
        return (*[f'{c}_scale' for c in self._flux_cols], 'norm_wav')

    @override
    def default_priors(self, region_center: float = 1.0) -> dict[str, Prior]:
        priors: dict[str, Prior] = {}
        for col in self._flux_cols:
            priors[f'{col}_scale'] = Uniform(0, 2)
        priors['norm_wav'] = Fixed(region_center)
        return priors

    @override
    def param_units(
        self, flux_unit: u.UnitBase, wl_unit: u.UnitBase
    ) -> dict[str, tuple[bool, u.UnitBase | None]]:
        d: dict[str, tuple[bool, u.UnitBase | None]] = {}
        for col in self._flux_cols:
            d[f'{col}_scale'] = (True, flux_unit)
        d['norm_wav'] = (False, wl_unit)
        return d

    @override
    def _prepare(self, low: u.Quantity, high: u.Quantity) -> None:
        """Convert template to microns and validate wavelength coverage."""
        self._lam_um = self._lam_qty.to(u.um).value
        self._rest_low_um = float(low.to(u.um).value)
        rest_high_um = float(high.to(u.um).value)
        # Use a small relative tolerance to absorb floating-point rounding from
        # unit conversions (e.g. 9000 AA * 1e-4 µm/AA = 0.9000000000000001 µm).
        _rtol = 1e-9
        if self._lam_um[0] > self._rest_low_um * (1.0 + _rtol):
            msg = (
                f'Template wavelength grid starts at {self._lam_um[0]:.4f} µm '
                f'but region lower bound is {low.to(u.um):.4f}. '
                'Template does not cover the full region.'
            )
            raise ValueError(msg)
        if self._lam_um[-1] < rest_high_um * (1.0 - _rtol):
            msg = (
                f'Template wavelength grid ends at {self._lam_um[-1]:.4f} µm '
                f'but region upper bound is {high.to(u.um):.4f}. '
                'Template does not cover the full region.'
            )
            raise ValueError(msg)
        # Stack flux arrays: shape (N_cols, N_lam)
        self._flam_eval = jnp.array(
            np.stack([self._flam_arrays[c] for c in self._flux_cols], axis=0)
        )
        self._lam_eval = jnp.array(self._lam_um)

    @override
    def evaluate(
        self,
        wavelength: ArrayLike,
        center: float,
        params: dict[str, ArrayLike],
        obs_low: float,
        obs_high: float,
        lsf_fwhm: ArrayLike = 0.0,
        z_sys: float = 0.0,
    ) -> Array:
        """Sum the normalised, interpolated template columns at *wavelength*.

        *center*, *obs_high*, *lsf_fwhm* and *z_sys* are unused here; the
        redshift enters only through *obs_low* (see the comment below).
        """
        # Convert observed-frame wavelength to rest-frame microns.
        # obs_low = rest_low_canonical * (1 + z_sys), and _rest_low_um is
        # rest_low in microns, so wavelength * _rest_low_um / obs_low gives
        # rest-frame microns regardless of the canonical wavelength unit.
        assert self._flam_eval is not None and self._lam_eval is not None, (
            'Template._prepare() must be called before evaluate(). '
            'Wrap the Template in a ContinuumRegion to trigger _prepare().'
        )
        wl = jnp.asarray(wavelength)
        scale = self._rest_low_um / obs_low
        lam_rest_um = wl * scale
        norm_wav_rest_um = params['norm_wav'] * scale
        total = jnp.zeros_like(wl)
        for i, col in enumerate(self._flux_cols):
            flam_row = self._flam_eval[i]
            t_at_lam = jnp.interp(lam_rest_um, self._lam_eval, flam_row)
            t_at_norm = jnp.interp(norm_wav_rest_um, self._lam_eval, flam_row)
            total = total + params[f'{col}_scale'] * t_at_lam / t_at_norm
        return total

    # ------------------------------------------------------------------
    # Serialization
    # ------------------------------------------------------------------

    @override
    def to_dict(self) -> dict:
        d: dict = {'type': 'Template', 'path': str(self._path)}
        if self._wavelength_colname is not None:
            d['wavelength_colname'] = self._wavelength_colname
        if self._usecols_arg is not None:
            d['usecols'] = list(self._usecols_arg)
        return d

    @classmethod
    @override
    def from_dict(cls, d: dict) -> Template:
        return cls(
            d['path'],
            wavelength_colname=d.get('wavelength_colname'),
            usecols=d.get('usecols'),
        )

    @override
    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Template):
            return NotImplemented
        return (
            self._path == other._path
            and self._wavelength_colname == other._wavelength_colname
            and self._usecols_arg == other._usecols_arg
        )

    @override
    def __hash__(self) -> int:
        return hash(
            (
                type(self).__name__,
                self._path,
                self._wavelength_colname,
                self._usecols_arg,
            )
        )

    @override
    def __repr__(self) -> str:
        cols = ', '.join(self._flux_cols)
        return f'Template({self._path.name!r}, columns=[{cols}])'