"""Continuum functional forms: abstract base and concrete implementations."""
from __future__ import annotations
from abc import ABC, abstractmethod
from pathlib import Path
from typing import TYPE_CHECKING, cast, override
import jax
import jax.numpy as jnp
import numpy as np
from astropy import units as u
from jax import Array
from jax.typing import ArrayLike
from unite._utils import _ensure_wavelength, _get_conversion_factor, _make_register
from unite.continuum.functions import (
bernstein_eval,
bspline_eval,
chebval,
planck_function,
)
from unite.prior import Fixed, Prior, Uniform
if TYPE_CHECKING:
pass
def _gaussian_convolve_poly(coeffs: Array, lsf_fwhm: ArrayLike) -> Array:
"""Analytically convolve a polynomial with a Gaussian LSF.
Given a polynomial ``p(x) = Σ c_i x^(n-i)`` (NumPy descending-order
convention, degree *n* = ``len(coeffs) - 1``) and a Gaussian kernel with
FWHM ``lsf_fwhm``, return the coefficients of the convolved polynomial
``(p * G)(x)``.
A polynomial convolved with a Gaussian is still a polynomial of the same
degree. Only even Gaussian moments contribute (odd moments vanish by
symmetry), giving the coefficient update:
.. code-block:: text
c_j_new = sum_{k=0,2,4,...}^{N-j} c_{j+k} * C(j+k, k) * (k-1)!! * sigma^k
where ``(k-1)!!`` is the double factorial (the *k*-th even moment of
a standard normal). See the
:doc:`polynomial derivation </derivations/polynomial>` for a full derivation.
For a monomial ``x^k`` convolved with ``N(0, s^2)``, the result is::
sum_{j=0}^{floor(k/2)} C(k, 2j) * (2j-1)!! * s^{2j} * x^{k-2j}
Parameters
----------
coeffs : Array, shape ``(n+1,)``
Polynomial coefficients in **descending** order
(i.e. ``coeffs[0]`` is the leading coefficient for ``x^n``).
lsf_fwhm : ArrayLike
Scalar (or broadcastable) LSF FWHM in the same unit as *x*.
Returns
-------
Array, shape ``(n+1,)``
Convolved polynomial coefficients, same descending-order convention.
"""
sigma2 = (jnp.asarray(lsf_fwhm) / (2.0 * np.sqrt(2.0 * np.log(2.0)))) ** 2
n = coeffs.shape[0] - 1 # polynomial degree; static at trace time
max_half = n // 2 + 1
# Even Gaussian moments M[j] = (2j-1)!! * sigma^{2j}.
# M[0] = 1; M[j] = M[j-1] * (2j-1) * sigma^2 for j >= 1.
if max_half > 1:
def _moment_step(carry, j):
cur = carry * (2 * j - 1) * sigma2
return cur, cur
_, rest = jax.lax.scan(_moment_step, 1.0, jnp.arange(1, max_half))
moments = jnp.concatenate([jnp.array([1.0]), rest])
else:
moments = jnp.ones(1)
# Binomial coefficients C(k, 2j): pure NumPy — depends only on n (static).
binom_np = np.zeros((n + 1, max_half))
for k in range(n + 1):
for j in range(min(k // 2, max_half - 1) + 1):
val = 1.0
for m in range(2 * j):
val = val * (k - m) / (m + 1)
binom_np[k, j] = val
# Static scatter tensor: A[out_idx, i, j] = binom_np[n-i, j] when
# out_idx == i + 2*j (and the entry is in range), else 0.
# out[out_idx] = sum_i sum_j coeffs[i] * A[out_idx, i, j] * moments[j]
# = einsum('oij,j,i->o', A, moments, coeffs).
a_np = np.zeros((n + 1, n + 1, max_half))
for i in range(n + 1):
k = n - i
for j in range(min(k // 2, max_half - 1) + 1):
a_np[i + 2 * j, i, j] = binom_np[k, j]
return jnp.einsum('oij,j,i->o', jnp.asarray(a_np), moments, coeffs)
def _polyint_at(coeffs: Array, x: ArrayLike) -> Array:
"""Antiderivative of a polynomial evaluated at *x*.
Given descending-order coefficients ``[a_n, ..., a_0]``, return the
antiderivative ``P(x) = ∫ p(x') dx' = a_n/(n+1) * x^{n+1} + ... + a_0 * x``
evaluated at each entry of *x* (constant of integration zero).
Used by polynomial-based continuum forms to produce a cumulative-at-edges
array: ``jnp.diff(P_at_edges) / jnp.diff(edges)`` recovers the exact
pixel-averaged value of *p* over each pixel.
Parameters
----------
coeffs : Array, shape ``(n+1,)``
Polynomial coefficients in descending order.
x : ArrayLike
Points at which to evaluate the antiderivative.
Returns
-------
Array
Antiderivative value at each *x*.
"""
# Antiderivative coefficients (also descending), with zero constant term.
n = coeffs.shape[0]
divisors = jnp.arange(n, 0, -1, dtype=coeffs.dtype)
anti = jnp.concatenate([coeffs / divisors, jnp.array([0.0], dtype=coeffs.dtype)])
return jnp.polyval(anti, x)
def _cheb_to_mono_matrix(n: int) -> Array:
"""Build the (n x n) Chebyshev-to-monomial conversion matrix.
Returns a matrix *M* such that ``M @ cheb_coeffs`` gives monomial
coefficients in **ascending** order (``[a0, a1, ..., a_{n-1}]``).
``cheb_coeffs = [c0, c1, ..., c_{n-1}]`` represents
``c0*T0(x) + c1*T1(x) + ... + c_{n-1}*T_{n-1}(x)``.
Built at Python time (not traced), so the matrix is a static constant.
"""
import numpy as np
from numpy.polynomial.chebyshev import cheb2poly
M = np.zeros((n, n)) # noqa: N806
for k in range(n):
basis = np.zeros(n)
basis[k] = 1.0
# cheb2poly returns ascending monomial coefficients for T_k
poly = cheb2poly(basis)
M[: len(poly), k] = poly
return jnp.array(M)
def _bernstein_to_mono_matrix(n: int) -> Array:
"""Build the ((n+1) x (n+1)) Bernstein-to-monomial conversion matrix.
Returns a matrix *M* such that ``M @ bern_coeffs`` gives monomial
coefficients in **ascending** order for the polynomial on ``[0, 1]``.
``bern_coeffs = [b0, b1, ..., b_n]`` represents
``Σ b_i * C(n,i) * t^i * (1-t)^{n-i}``.
"""
import numpy as np
from scipy.special import comb
M = np.zeros((n + 1, n + 1)) # noqa: N806
# Monomial coeff for t^k: Σ_{i=0}^{k} C(n,i) C(n-i, k-i) (-1)^{k-i} b_i
for k in range(n + 1):
for i in range(k + 1):
M[k, i] = (
comb(n, i, exact=True)
* comb(n - i, k - i, exact=True)
* (-1) ** (k - i)
)
return jnp.array(M)
# ---------------------------------------------------------------------------
# Registry
# ---------------------------------------------------------------------------
# Registry of continuum form classes, keyed by a string name (see annotation).
_FORM_REGISTRY: dict[str, type[ContinuumForm]] = {}
# Class decorator built around the registry and applied to the concrete form
# classes below. NOTE(review): presumably it inserts the decorated class into
# ``_FORM_REGISTRY`` — confirm against ``unite._utils._make_register``.
_register = _make_register(_FORM_REGISTRY)
# ---------------------------------------------------------------------------
# Abstract base class
# ---------------------------------------------------------------------------
# ---------------------------------------------------------------------------
# Deserialization helper
# ---------------------------------------------------------------------------
# ---------------------------------------------------------------------------
# Piecewise / region-local forms
# ---------------------------------------------------------------------------
@_register
class Linear(ContinuumForm):
    """Linear continuum: ``scale + tan(angle) * (wavelength - norm_wav)``.

    The slope is parameterized through the inclination ``angle`` of the line;
    the slope entering the model is ``tan(angle)``.

    This form has no constructor parameters.

    Notes
    -----
    **Model parameters** (sampled with priors, overridable via
    ``ContinuumRegion(params={...})``):

    * ``scale`` — Continuum level at ``norm_wav``.
      Default prior: ``Uniform(0, 2)``.
    * ``angle`` — Inclination angle in radians; the continuum slope is
      ``tan(angle)``.
      Default prior: ``Uniform(-pi/2, pi/2)``.
    * ``norm_wav`` — Reference wavelength where the
      continuum equals ``scale``.
      Default prior: ``Fixed(region_center)``.
    """

    @property
    @override
    def is_linear(self) -> bool:
        # NOTE(review): the model is linear in ``scale`` but nonlinear in
        # ``angle`` (it enters through ``tan``). Confirm that ``True`` matches
        # the meaning ``ContinuumForm.is_linear`` gives to this flag.
        return True

    @override
    def param_names(self) -> tuple[str, ...]:
        """Return the parameter names read from ``params`` by this form."""
        return ('scale', 'angle', 'norm_wav')

    @override
    def default_priors(self, region_center: float = 1.0) -> dict[str, Prior]:
        """Return the default prior of every parameter.

        ``norm_wav`` is fixed to *region_center*.
        """
        return {
            'scale': Uniform(0, 2),
            'angle': Uniform(-np.pi / 2, np.pi / 2),
            'norm_wav': Fixed(region_center),
        }

    @override
    def param_units(
        self, flux_unit: u.UnitBase, wl_unit: u.UnitBase
    ) -> dict[str, tuple[bool, u.UnitBase | None]]:
        """Return a ``(flag, unit)`` pair per parameter.

        The meaning of the boolean flag is defined by the base class, which
        is not restated here.
        """
        return {
            'scale': (True, flux_unit),
            'angle': (False, None),
            'norm_wav': (False, wl_unit),
        }

    @override
    def evaluate(
        self,
        wavelength: ArrayLike,
        center: float,
        params: dict[str, ArrayLike],
        obs_low: float,
        obs_high: float,
        lsf_fwhm: ArrayLike = 0.0,
        z_sys: float = 0.0,
    ) -> Array:
        """Evaluate the line at *wavelength*.

        Only ``params`` is used; ``center``, the region bounds, ``lsf_fwhm``
        and ``z_sys`` are ignored (a straight line is unchanged by Gaussian
        convolution).
        """
        # Coerce like ``integrate`` does, so plain sequences are accepted too.
        offset = jnp.asarray(wavelength) - params['norm_wav']
        return params['scale'] + jnp.tan(params['angle']) * offset

    @override
    def integrate(
        self,
        edges: ArrayLike,
        center: float,
        params: dict[str, ArrayLike],
        obs_low: float,
        obs_high: float,
        lsf_fwhm: ArrayLike = 0.0,
        z_sys: float = 0.0,
    ) -> Array:
        """Return the antiderivative of the line evaluated at *edges*.

        Differencing the returned array along *edges* gives the exact pixel
        integrals.
        """
        # Linear is preserved exactly under Gaussian convolution, so LSF is a
        # no-op. The antiderivative of ``scale + slope * (λ - nw)`` is
        # ``scale * (λ - nw) + slope * (λ - nw)² / 2`` (constant of integration
        # zero), with ``slope = tan(angle)``.
        offset = jnp.asarray(edges) - params['norm_wav']
        slope = jnp.tan(params['angle'])
        return jnp.asarray(params['scale']) * offset + 0.5 * slope * offset * offset

    @override
    def to_dict(self) -> dict:
        """Serialize to a plain dict (the form has no configuration)."""
        return {'type': 'Linear'}

    @classmethod
    @override
    def from_dict(cls, d: dict) -> Linear:
        """Build an instance from :meth:`to_dict` output; *d* is not inspected."""
        return cls()
@_register
class Polynomial(ContinuumForm):
    """Polynomial continuum of configurable degree.

    Evaluates ``scale + c1*x + c2*x**2 + ...`` where
    ``x = wavelength - norm_wav``. A Gaussian LSF is applied analytically to
    the coefficients before evaluation.

    Parameters
    ----------
    degree : int
        Polynomial degree (default 1).

    Raises
    ------
    ValueError
        If *degree* is negative.

    Notes
    -----
    **Model parameters** (sampled with priors, overridable via
    ``ContinuumRegion(params={...})``):

    * ``scale`` — Continuum level at ``norm_wav`` (before LSF broadening).
      Default prior: ``Uniform(0, 2)``.
    * ``c1, c2, ...`` — Higher-order polynomial coefficients.
      Default prior: ``Uniform(-10, 10)`` each.
    * ``norm_wav`` — Reference wavelength.
      Default prior: ``Fixed(region_center)``.
    """

    def __init__(self, degree: int = 1) -> None:
        if degree < 0:
            msg = f'Polynomial degree must be >= 0, got {degree}'
            raise ValueError(msg)
        self._degree = degree

    @property
    @override
    def is_linear(self) -> bool:
        return True

    @property
    def degree(self) -> int:
        """Polynomial degree."""
        return self._degree

    @override
    def param_names(self) -> tuple[str, ...]:
        """Return ``('scale', 'c1', ..., 'c<degree>', 'norm_wav')``."""
        # For degree 0 the generator is empty, leaving ('scale', 'norm_wav').
        return ('scale', *(f'c{i}' for i in range(1, self._degree + 1)), 'norm_wav')

    @override
    def default_priors(self, region_center: float = 1.0) -> dict[str, Prior]:
        """Return the default prior of every parameter.

        ``norm_wav`` is fixed to *region_center*.
        """
        priors: dict[str, Prior] = {'scale': Uniform(0, 2)}
        for i in range(1, self._degree + 1):
            priors[f'c{i}'] = Uniform(-10, 10)
        priors['norm_wav'] = Fixed(region_center)
        return priors

    @override
    def param_units(
        self, flux_unit: u.UnitBase, wl_unit: u.UnitBase
    ) -> dict[str, tuple[bool, u.UnitBase | None]]:
        """Return a ``(flag, unit)`` pair per parameter.

        Coefficient ``c_i`` multiplies ``x**i`` and therefore carries
        ``flux_unit / wl_unit**i``.
        """
        d: dict[str, tuple[bool, u.UnitBase | None]] = {'scale': (True, flux_unit)}
        for i in range(1, self._degree + 1):
            d[f'c{i}'] = (True, flux_unit / wl_unit**i)
        d['norm_wav'] = (False, wl_unit)
        return d

    def _convolved_coeffs(
        self, params: dict[str, ArrayLike], lsf_fwhm: ArrayLike
    ) -> Array:
        """Return the LSF-convolved monomial coefficients, descending order."""
        # Descending order: c_n, ..., c_1, scale.
        mono = jnp.array(
            [params[f'c{i}'] for i in range(self._degree, 0, -1)] + [params['scale']]
        )
        return _gaussian_convolve_poly(mono, lsf_fwhm)

    @override
    def evaluate(
        self,
        wavelength: ArrayLike,
        center: float,
        params: dict[str, ArrayLike],
        obs_low: float,
        obs_high: float,
        lsf_fwhm: ArrayLike = 0.0,
        z_sys: float = 0.0,
    ) -> Array:
        """Evaluate the LSF-convolved polynomial at *wavelength*.

        ``center``, the region bounds and ``z_sys`` are not used.
        """
        # Coerce like ``integrate`` does, so plain sequences are accepted too.
        x = jnp.asarray(wavelength) - params['norm_wav']
        return jnp.polyval(self._convolved_coeffs(params, lsf_fwhm), x)

    @override
    def integrate(
        self,
        edges: ArrayLike,
        center: float,
        params: dict[str, ArrayLike],
        obs_low: float,
        obs_high: float,
        lsf_fwhm: ArrayLike = 0.0,
        z_sys: float = 0.0,
    ) -> Array:
        """Return the antiderivative of the convolved polynomial at *edges*.

        Differencing the result along *edges* gives exact pixel integrals.
        """
        x = jnp.asarray(edges) - params['norm_wav']
        return _polyint_at(self._convolved_coeffs(params, lsf_fwhm), x)

    @override
    def to_dict(self) -> dict:
        """Serialize the form type and its degree."""
        return {'type': 'Polynomial', 'degree': self._degree}

    @classmethod
    @override
    def from_dict(cls, d: dict) -> Polynomial:
        """Rebuild an instance from :meth:`to_dict` output (requires ``'degree'``)."""
        return cls(degree=d['degree'])

    @override
    def __repr__(self) -> str:
        return f'Polynomial(degree={self._degree})'
def _normalize_wavelength(
wavelength: ArrayLike,
center: float,
obs_low: float,
obs_high: float,
stretch: float,
) -> tuple[Array, float]:
"""Map observed wavelengths to the normalised coordinate used by Chebyshev/Bernstein.
Both forms normalize ``wavelength`` to the interval ``[-1, 1]`` (for
Chebyshev) or ``[0, 1]`` (for Bernstein) using the same scale factor
``(obs_high - obs_low) / 2 * stretch``. This helper computes the
shared scale factor and the intermediate ``u = (w - center) / scale``
coordinate; callers apply the form-specific final transformation.
Returns ``(u, scale_factor)`` where ``u = (wavelength - center) / scale_factor``.
"""
scale_factor = (obs_high - obs_low) / 2 * stretch
u = (jnp.asarray(wavelength) - center) / scale_factor
return u, scale_factor
@_register
class Chebyshev(ContinuumForm):
    """Chebyshev polynomial continuum of configurable order.

    Evaluates a Chebyshev series on the coordinate
    ``x = (wavelength - center) / (half_width * stretch)``, where
    ``half_width = (obs_high - obs_low) / 2`` comes from the region bounds
    passed to :meth:`evaluate` and ``stretch`` is a form-specific scaling
    factor (default ``1.0``).

    The continuum is parameterized as ``scale * T(x) / T(x_nw)`` where
    ``T`` is the Chebyshev series with constant term fixed at 1.0 and
    ``x_nw`` is the normalized coordinate of ``norm_wav``. The LSF is applied
    analytically to ``T(x)``; the normalizing value ``T(x_nw)`` is taken from
    the unconvolved series.

    Parameters
    ----------
    order : int
        Chebyshev order (default 2). Number of coefficients = order + 1.
    stretch : float
        Stretch factor to scale the region normalization (default 1.0).

    Raises
    ------
    ValueError
        If *order* is negative or *stretch* is not strictly positive.

    Notes
    -----
    **Model parameters** (sampled with priors, overridable via
    ``ContinuumRegion(params={...})``):

    * ``scale`` — Continuum level at ``norm_wav``.
      Default prior: ``Uniform(0, 2)``.
    * ``c1, c2, ...`` — Higher-order Chebyshev coefficients (relative to a
      constant term of 1.0).
      Default prior: ``Uniform(-10, 10)`` each.
    * ``norm_wav`` — Reference wavelength.
      Default prior: ``Fixed(region_center)``.
    """

    def __init__(self, order: int = 2, stretch: float = 1.0) -> None:
        if order < 0:
            msg = f'Chebyshev order must be >= 0, got {order}'
            raise ValueError(msg)
        if stretch <= 0:
            msg = f'Chebyshev stretch factor must be > 0, got {stretch}'
            raise ValueError(msg)
        self._order = order
        self._stretch = stretch
        # Static Chebyshev-to-monomial conversion matrix.
        self._cheb2mono = _cheb_to_mono_matrix(order + 1)

    @property
    @override
    def is_linear(self) -> bool:
        return False

    @property
    def order(self) -> int:
        """Chebyshev order."""
        return self._order

    @property
    def stretch(self) -> float:
        """Stretch factor for the region normalization."""
        return self._stretch

    @override
    def param_names(self) -> tuple[str, ...]:
        """Return ``('scale', 'c1', ..., 'c<order>', 'norm_wav')``."""
        # For order 0 the generator is empty, leaving ('scale', 'norm_wav').
        return ('scale', *(f'c{i}' for i in range(1, self._order + 1)), 'norm_wav')

    @override
    def default_priors(self, region_center: float = 1.0) -> dict[str, Prior]:
        """Return the default prior of every parameter.

        ``norm_wav`` is fixed to *region_center*.
        """
        priors: dict[str, Prior] = {'scale': Uniform(0, 2)}
        for i in range(1, self._order + 1):
            priors[f'c{i}'] = Uniform(-10, 10)
        priors['norm_wav'] = Fixed(region_center)
        return priors

    @override
    def param_units(
        self, flux_unit: u.UnitBase, wl_unit: u.UnitBase
    ) -> dict[str, tuple[bool, u.UnitBase | None]]:
        """Return a ``(flag, unit)`` pair per parameter."""
        # NOTE(review): the c_i multiply a series whose constant term is fixed
        # to 1 and whose value at norm_wav is divided out, so they act as
        # ratios; confirm that tagging them with ``flux_unit`` is intended.
        d: dict[str, tuple[bool, u.UnitBase | None]] = {'scale': (True, flux_unit)}
        for i in range(1, self._order + 1):
            d[f'c{i}'] = (True, flux_unit)
        d['norm_wav'] = (False, wl_unit)
        return d

    def _shape_terms(
        self,
        coords: ArrayLike,
        center: float,
        params: dict[str, ArrayLike],
        obs_low: float,
        obs_high: float,
        lsf_fwhm: ArrayLike,
    ) -> tuple[Array, float, Array, Array]:
        """Compute the pieces shared by :meth:`evaluate` and :meth:`integrate`.

        Returns ``(x, scale_factor, convolved, shape_nw)``: the normalized
        coordinates of *coords*, the wavelength-per-unit-x factor, the
        LSF-convolved monomial coefficients (descending order) and the
        unconvolved series value at ``norm_wav``.
        """
        x, scale_factor = _normalize_wavelength(
            coords, center, obs_low, obs_high, self._stretch
        )
        x_nw, _ = _normalize_wavelength(
            params['norm_wav'], center, obs_low, obs_high, self._stretch
        )
        cheb_coeffs = jnp.array(
            [1.0] + [params[f'c{i}'] for i in range(1, self._order + 1)]
        )
        mono = (self._cheb2mono @ cheb_coeffs)[::-1]  # ascending → descending
        # The polynomial lives in x, so the LSF width must be expressed in x.
        lsf_fwhm_scaled = jnp.asarray(lsf_fwhm) / scale_factor
        convolved = _gaussian_convolve_poly(mono, lsf_fwhm_scaled)
        shape_nw = chebval(x_nw, cheb_coeffs)
        return x, scale_factor, convolved, shape_nw

    @override
    def evaluate(
        self,
        wavelength: ArrayLike,
        center: float,
        params: dict[str, ArrayLike],
        obs_low: float,
        obs_high: float,
        lsf_fwhm: ArrayLike = 0.0,
        z_sys: float = 0.0,
    ) -> Array:
        """Evaluate the LSF-convolved, normalized series at *wavelength*.

        ``z_sys`` is not used.
        """
        x, _, convolved, shape_nw = self._shape_terms(
            wavelength, center, params, obs_low, obs_high, lsf_fwhm
        )
        shape = jnp.polyval(convolved, x)
        return params['scale'] * shape / shape_nw

    @override
    def integrate(
        self,
        edges: ArrayLike,
        center: float,
        params: dict[str, ArrayLike],
        obs_low: float,
        obs_high: float,
        lsf_fwhm: ArrayLike = 0.0,
        z_sys: float = 0.0,
    ) -> Array:
        """Return the antiderivative of the continuum evaluated at *edges*.

        Differencing the result along *edges* gives exact pixel integrals.
        """
        x, scale_factor, convolved, shape_nw = self._shape_terms(
            edges, center, params, obs_low, obs_high, lsf_fwhm
        )
        # Antiderivative in normalised coord; rescale to λ by dλ/dx = scale_factor.
        shape_anti = _polyint_at(convolved, x) * scale_factor
        return params['scale'] * shape_anti / shape_nw

    @override
    def to_dict(self) -> dict:
        """Serialize the form type, order and stretch."""
        return {'type': 'Chebyshev', 'order': self._order, 'stretch': self._stretch}

    @classmethod
    @override
    def from_dict(cls, d: dict) -> Chebyshev:
        """Rebuild from :meth:`to_dict` output; ``'stretch'`` defaults to 1.0."""
        return cls(order=d['order'], stretch=d.get('stretch', 1.0))

    @override
    def __repr__(self) -> str:
        return f'Chebyshev(order={self._order}, stretch={self._stretch})'
@_register
class BSpline(ContinuumForm):
    """B-spline continuum with local knot control.

    The interior knot vector is supplied via *knots* at construction time
    (typically derived from the wavelength coverage of the spectrum).
    :meth:`_prepare` must be called with the region bounds before
    :meth:`evaluate`: it checks that every knot lies strictly inside the
    bounds, converts the knots to the unit of the lower bound and clamps the
    knot vector at both bounds.

    The continuum is normalized so that it equals ``scale`` at ``norm_wav``,
    parameterized as ``scale * S(u) / S(u_nw)`` where ``S`` is the B-spline
    series with first coefficient fixed at 1.0, and ``u_nw`` is the normalized
    coordinate at ``norm_wav``. LSF convolution is not applied by this form.

    Parameters
    ----------
    knots : u.Quantity
        Interior knot vector in wavelength units.
    degree : int
        Spline degree (default 3 for cubic).

    Raises
    ------
    ValueError
        If *knots* is not an astropy ``Quantity`` or *degree* is negative.

    Notes
    -----
    **Model parameters** (sampled with priors, overridable via
    ``ContinuumRegion(params={...})``):

    * ``scale`` — Continuum level at ``norm_wav``.
      Default prior: ``Uniform(0, 2)``.
    * ``coeff_1, coeff_2, …`` — Remaining B-spline coefficients (relative to
      a first coefficient of 1.0).
      Default prior: ``Uniform(-10, 10)`` each.
    * ``norm_wav`` — Reference wavelength.
      Default prior: ``Fixed(region_center)``.
    """

    def __init__(self, knots: u.Quantity, degree: int = 3) -> None:
        if not isinstance(knots, u.Quantity):
            raise ValueError(
                f'knots must be an astropy Quantity with length units, got {knots}'
            )
        self._knots: u.Quantity = _ensure_wavelength(knots, 'knots', ndim=1)
        if degree < 0:
            msg = f'BSpline degree must be >= 0, got {degree}'
            raise ValueError(msg)
        self._degree = degree
        # Clamping adds (degree + 1) knots per side, so the basis count is
        # len(clamped) - degree - 1 = len(knots) + degree + 1.
        self._n_basis = len(self._knots) + degree + 1

    @property
    @override
    def is_linear(self) -> bool:
        return False

    @property
    def degree(self) -> int:
        """Spline degree."""
        return self._degree

    @property
    def n_basis(self) -> int:
        """Number of B-spline basis functions."""
        return self._n_basis

    @property
    def knots(self) -> u.Quantity:
        """Knot vector (in the original units passed at construction)."""
        return self._knots

    @override
    def param_names(self) -> tuple[str, ...]:
        """Return ``('scale', 'coeff_1', ..., 'coeff_<n_basis-1>', 'norm_wav')``."""
        # With a single basis function the generator is empty.
        return ('scale', *(f'coeff_{i}' for i in range(1, self._n_basis)), 'norm_wav')

    @override
    def default_priors(self, region_center: float = 1.0) -> dict[str, Prior]:
        """Return the default prior of every parameter.

        ``norm_wav`` is fixed to *region_center*.
        """
        priors: dict[str, Prior] = {'scale': Uniform(0, 2)}
        for i in range(1, self._n_basis):
            priors[f'coeff_{i}'] = Uniform(-10, 10)
        priors['norm_wav'] = Fixed(region_center)
        return priors

    @override
    def param_units(
        self, flux_unit: u.UnitBase, wl_unit: u.UnitBase
    ) -> dict[str, tuple[bool, u.UnitBase | None]]:
        """Return a ``(flag, unit)`` pair per parameter."""
        # NOTE(review): the coefficients are relative to a first coefficient
        # fixed at 1 and the series is divided by its value at norm_wav;
        # confirm that tagging them with ``flux_unit`` is intended.
        d: dict[str, tuple[bool, u.UnitBase | None]] = {'scale': (True, flux_unit)}
        for i in range(1, self._n_basis):
            d[f'coeff_{i}'] = (True, flux_unit)
        d['norm_wav'] = (False, wl_unit)
        return d

    @override
    def evaluate(
        self,
        wavelength: ArrayLike,
        center: float,
        params: dict[str, ArrayLike],
        obs_low: float,
        obs_high: float,
        lsf_fwhm: ArrayLike = 0.0,
        z_sys: float = 0.0,
    ) -> Array:
        """Evaluate the normalized spline at *wavelength*.

        Requires a prior call to :meth:`_prepare`. ``center``, ``lsf_fwhm``
        and ``z_sys`` are not used; LSF convolution is not supported for
        BSpline (non-polynomial basis).
        """
        obs_range = obs_high - obs_low

        def _to_unit_interval(values):
            # Affine map sending [obs_low, obs_high] onto [-1, 1].
            return 2 * (values - obs_low) / obs_range - 1

        # Wavelengths, the reference wavelength and the clamped knots must all
        # live in the same coordinate system.
        coord = _to_unit_interval(jnp.asarray(wavelength))
        coord_nw = _to_unit_interval(params['norm_wav'])
        knots_norm = _to_unit_interval(self._knots_eval)

        shape_coeffs = jnp.concatenate(
            [jnp.array([1.0])]
            + [jnp.atleast_1d(params[f'coeff_{i}']) for i in range(1, self._n_basis)]
        )
        shape = bspline_eval(coord, shape_coeffs, knots_norm, self._degree)
        shape_nw = bspline_eval(
            jnp.atleast_1d(coord_nw), shape_coeffs, knots_norm, self._degree
        )[0]
        return params['scale'] * shape / shape_nw

    @override
    def _prepare(self, low: u.Quantity, high: u.Quantity) -> None:
        """Validate the knots against the region bounds and build the clamped vector.

        Raises
        ------
        ValueError
            If any knot is on or outside the bounds ``[low, high]``.
        """
        # Element-wise ``|`` is required here: Python ``or`` would ask for the
        # truth value of a whole array and raise for more than one knot.
        outside = (self._knots <= low) | (self._knots >= high)
        if np.any(outside):
            msg = f'All knots must be within the region bounds [{low}, {high}], got {self._knots}'
            raise ValueError(msg)
        # Repeat each bound (degree + 1) times for clamping, with everything
        # expressed in the unit of ``low``.
        p = self._degree + 1
        knots_clamped = jnp.concatenate(
            [
                low.value * jnp.ones(p),
                self._knots.to(low.unit).value,
                high.to(low.unit).value * jnp.ones(p),
            ]
        )
        self._knots_eval = knots_clamped

    @override
    def to_dict(self) -> dict:
        """Serialize the form type, knot values, knot unit and degree."""
        return {
            'type': 'BSpline',
            'knots': self._knots.value.tolist(),
            'unit': str(self._knots.unit),
            'degree': self._degree,
        }

    @classmethod
    @override
    def from_dict(cls, d: dict) -> BSpline:
        """Rebuild from :meth:`to_dict` output; ``'degree'`` defaults to 3."""
        return cls(knots=d['knots'] * u.Unit(d['unit']), degree=d.get('degree', 3))

    @override
    def __repr__(self) -> str:
        return f'BSpline(degree={self._degree}, knots={self._knots})'
@_register
class Bernstein(ContinuumForm):
    """Global Bernstein polynomial continuum, normalized at a reference wavelength.

    The wavelength is first mapped to
    ``u = (wavelength - center) / (half_width * stretch)`` with
    ``half_width = (obs_high - obs_low) / 2`` taken from the region bounds
    passed to :meth:`evaluate`, then to ``t = (u + 1) / 2`` for the Bernstein
    basis. The ``stretch`` parameter optionally scales the normalization.

    The continuum is parameterized as ``scale * B(t) / B(t_nw)`` where
    ``B`` is the Bernstein series with first coefficient fixed at 1.0, and
    ``t_nw`` is the normalized coordinate at ``norm_wav``. The LSF is applied
    analytically to ``B(t)``; the normalizing value ``B(t_nw)`` is taken from
    the unconvolved series.

    Parameters
    ----------
    degree : int
        Polynomial degree (default 4). Number of coefficients = degree + 1.
    stretch : float
        Stretch factor to scale the region normalization (default 1.0).

    Raises
    ------
    ValueError
        If *degree* is negative or *stretch* is not strictly positive.

    Notes
    -----
    **Model parameters** (sampled with priors, overridable via
    ``ContinuumRegion(params={...})``):

    * ``scale`` — Continuum level at ``norm_wav``.
      Default prior: ``Uniform(0, 2)``.
    * ``coeff_1, coeff_2, …`` — Remaining Bernstein coefficients (relative to
      a first coefficient of 1.0).
      Default prior: ``Uniform(0, 10)`` each.
    * ``norm_wav`` — Reference wavelength.
      Default prior: ``Fixed(region_center)``.
    """

    def __init__(self, degree: int = 4, stretch: float = 1.0) -> None:
        from scipy.special import comb

        if degree < 0:
            msg = f'Bernstein degree must be >= 0, got {degree}'
            raise ValueError(msg)
        if stretch <= 0:
            msg = f'Bernstein stretch factor must be > 0, got {stretch}'
            raise ValueError(msg)
        self._degree = degree
        self._stretch = stretch
        # Binomial weights C(degree, i) handed to ``bernstein_eval``.
        self._binom = jnp.array(
            [comb(degree, i, exact=True) for i in range(degree + 1)], dtype=float
        )
        # Static Bernstein-to-monomial conversion matrix.
        self._bern2mono = _bernstein_to_mono_matrix(degree)

    @property
    @override
    def is_linear(self) -> bool:
        return False

    @property
    def degree(self) -> int:
        """Polynomial degree."""
        return self._degree

    @property
    def stretch(self) -> float:
        """Stretch factor for the region normalization."""
        return self._stretch

    @override
    def param_names(self) -> tuple[str, ...]:
        """Return ``('scale', 'coeff_1', ..., 'coeff_<degree>', 'norm_wav')``."""
        # For degree 0 the generator is empty, leaving ('scale', 'norm_wav').
        return (
            'scale',
            *(f'coeff_{i}' for i in range(1, self._degree + 1)),
            'norm_wav',
        )

    @override
    def default_priors(self, region_center: float = 1.0) -> dict[str, Prior]:
        """Return the default prior of every parameter.

        ``norm_wav`` is fixed to *region_center*.
        """
        priors: dict[str, Prior] = {'scale': Uniform(0, 2)}
        for i in range(1, self._degree + 1):
            priors[f'coeff_{i}'] = Uniform(0, 10)
        priors['norm_wav'] = Fixed(region_center)
        return priors

    @override
    def param_units(
        self, flux_unit: u.UnitBase, wl_unit: u.UnitBase
    ) -> dict[str, tuple[bool, u.UnitBase | None]]:
        """Return a ``(flag, unit)`` pair per parameter."""
        # NOTE(review): the coefficients are relative to a first coefficient
        # fixed at 1 and the series is divided by its value at norm_wav;
        # confirm that tagging them with ``flux_unit`` is intended.
        d: dict[str, tuple[bool, u.UnitBase | None]] = {'scale': (True, flux_unit)}
        for i in range(1, self._degree + 1):
            d[f'coeff_{i}'] = (True, flux_unit)
        d['norm_wav'] = (False, wl_unit)
        return d

    def _shape_terms(
        self,
        coords: ArrayLike,
        center: float,
        params: dict[str, ArrayLike],
        obs_low: float,
        obs_high: float,
        lsf_fwhm: ArrayLike,
    ) -> tuple[Array, float, Array, Array]:
        """Compute the pieces shared by :meth:`evaluate` and :meth:`integrate`.

        Returns ``(t, scale_factor, convolved, shape_nw)``: the Bernstein
        coordinate of *coords*, the factor returned by
        ``_normalize_wavelength``, the LSF-convolved monomial coefficients in
        ``t`` (descending order) and the unconvolved series value at
        ``norm_wav``.
        """
        coord, scale_factor = _normalize_wavelength(
            coords, center, obs_low, obs_high, self._stretch
        )
        coord_nw, _ = _normalize_wavelength(
            params['norm_wav'], center, obs_low, obs_high, self._stretch
        )
        # Shift the symmetric coordinate to the Bernstein variable t.
        t = (coord + 1) / 2
        t_nw = (coord_nw + 1) / 2
        coeffs = jnp.concatenate(
            [jnp.array([1.0])]
            + [jnp.atleast_1d(params[f'coeff_{i}']) for i in range(1, self._degree + 1)]
        )
        mono = (self._bern2mono @ coeffs)[::-1]  # ascending → descending
        # LSF FWHM in t-coordinate: dt/dλ = 1 / (2 * scale_factor)
        lsf_fwhm_scaled = jnp.asarray(lsf_fwhm) / (2.0 * scale_factor)
        convolved = _gaussian_convolve_poly(mono, lsf_fwhm_scaled)
        shape_nw = bernstein_eval(jnp.atleast_1d(t_nw), coeffs, self._binom)
        return t, scale_factor, convolved, shape_nw

    @override
    def evaluate(
        self,
        wavelength: ArrayLike,
        center: float,
        params: dict[str, ArrayLike],
        obs_low: float,
        obs_high: float,
        lsf_fwhm: ArrayLike = 0.0,
        z_sys: float = 0.0,
    ) -> Array:
        """Evaluate the LSF-convolved, normalized series at *wavelength*.

        ``z_sys`` is not used.
        """
        t, _, convolved, shape_nw = self._shape_terms(
            wavelength, center, params, obs_low, obs_high, lsf_fwhm
        )
        shape = jnp.polyval(convolved, t)
        return cast(Array, params['scale'] * shape / shape_nw)

    @override
    def integrate(
        self,
        edges: ArrayLike,
        center: float,
        params: dict[str, ArrayLike],
        obs_low: float,
        obs_high: float,
        lsf_fwhm: ArrayLike = 0.0,
        z_sys: float = 0.0,
    ) -> Array:
        """Return the antiderivative of the continuum evaluated at *edges*.

        Differencing the result along *edges* gives exact pixel integrals.
        """
        t, scale_factor, convolved, shape_nw = self._shape_terms(
            edges, center, params, obs_low, obs_high, lsf_fwhm
        )
        # Antiderivative in t-coord; rescale to λ via dλ/dt = 2 * scale_factor.
        shape_anti = _polyint_at(convolved, t) * (2.0 * scale_factor)
        return params['scale'] * shape_anti / shape_nw

    @override
    def to_dict(self) -> dict:
        """Serialize the form type, degree and stretch."""
        return {'type': 'Bernstein', 'degree': self._degree, 'stretch': self._stretch}

    @classmethod
    @override
    def from_dict(cls, d: dict) -> Bernstein:
        """Rebuild from :meth:`to_dict` output; ``'stretch'`` defaults to 1.0."""
        return cls(degree=d['degree'], stretch=d.get('stretch', 1.0))

    @override
    def __repr__(self) -> str:
        return f'Bernstein(degree={self._degree}, stretch={self._stretch})'
@_register
class PowerLaw(ContinuumForm):
    """Power-law continuum: ``scale * (wavelength / norm_wav) ** beta``.

    This form has no constructor parameters. LSF convolution is not applied.

    To share a consistent reference wavelength across multiple regions
    (required for physically meaningful parameter sharing), pass a
    :class:`~unite.continuum.config.ContinuumNormalizationWavelength` with
    ``Fixed(value)`` carrying your chosen reference wavelength.

    Notes
    -----
    **Model parameters** (sampled with priors, overridable via
    ``ContinuumRegion(params={...})``):

    * ``scale`` — Continuum level at ``norm_wav``.
      Default prior: ``Uniform(0, 2)``.
    * ``beta`` — Power-law index (dimensionless).
      Default prior: ``Uniform(-5, 5)``.
    * ``norm_wav`` — Reference wavelength.
      Default prior: ``Fixed(region_center)``.
    """

    @override
    def param_names(self) -> tuple[str, ...]:
        """Return the parameter names read from ``params`` by this form."""
        return ('scale', 'beta', 'norm_wav')

    @override
    def default_priors(self, region_center: float = 1.0) -> dict[str, Prior]:
        """Return the default prior of every parameter.

        ``norm_wav`` is fixed to *region_center*.
        """
        return {
            'scale': Uniform(0, 2),
            'beta': Uniform(-5, 5),
            'norm_wav': Fixed(region_center),
        }

    @override
    def param_units(
        self, flux_unit: u.UnitBase, wl_unit: u.UnitBase
    ) -> dict[str, tuple[bool, u.UnitBase | None]]:
        """Return a ``(flag, unit)`` pair per parameter."""
        return {
            'scale': (True, flux_unit),
            'beta': (False, None),
            'norm_wav': (False, wl_unit),
        }

    @override
    def evaluate(
        self,
        wavelength: ArrayLike,
        center: float,
        params: dict[str, ArrayLike],
        obs_low: float,
        obs_high: float,
        lsf_fwhm: ArrayLike = 0.0,
        z_sys: float = 0.0,
    ) -> Array:
        """Evaluate the power law at *wavelength*.

        Only ``params`` is used; LSF convolution is not supported for PowerLaw.
        """
        ratio = jnp.asarray(wavelength) / params['norm_wav']
        return cast(Array, params['scale'] * ratio ** params['beta'])

    @override
    def integrate(
        self,
        edges: ArrayLike,
        center: float,
        params: dict[str, ArrayLike],
        obs_low: float,
        obs_high: float,
        lsf_fwhm: ArrayLike = 0.0,
        z_sys: float = 0.0,
    ) -> Array:
        """Return the cumulative antiderivative at *edges*, zeroed at ``edges[0]``.

        Differencing the result along *edges* gives pixel integrals. LSF
        convolution is not supported (same as :meth:`evaluate`).
        """
        edges_arr = jnp.asarray(edges)
        nw = params['norm_wav']
        bp1 = params['beta'] + 1.0
        ratio = edges_arr / nw
        use_power = jnp.abs(bp1) > 1e-10
        # ``jnp.where`` evaluates both branches. Feed the power branch a
        # harmless exponent when it is not selected, so that neither its value
        # nor its gradient picks up inf/NaN from dividing by ~0.
        safe_bp1 = jnp.where(use_power, bp1, 1.0)
        # ∫ (w/nw)^β dw = nw * (w/nw)^(β+1) / (β+1); written in terms of the
        # ratio to keep intermediate magnitudes moderate.
        anti_power = nw * ratio**safe_bp1 / safe_bp1
        # β = -1: ∫ nw/w dw = nw * ln(w). Using ln(w/nw) only shifts the
        # result by a constant, which the subtraction below removes.
        anti_log = nw * jnp.log(ratio)
        anti = jnp.where(use_power, anti_power, anti_log)
        return cast(Array, params['scale'] * (anti - anti[0]))

    @override
    def to_dict(self) -> dict:
        """Serialize to a plain dict (the form has no configuration)."""
        return {'type': 'PowerLaw'}

    @classmethod
    @override
    def from_dict(cls, d: dict) -> PowerLaw:
        """Build an instance from :meth:`to_dict` output; *d* is not inspected."""
        return cls()
@_register
class Blackbody(ContinuumForm):
    """Planck blackbody continuum normalized at a reference wavelength.

    Evaluates ``scale * B_λ(T) / B_λ(norm_wav, T)`` so that
    *scale* directly represents the continuum flux at
    ``norm_wav``. Wavelengths are multiplied by a conversion factor to
    microns before the Planck function is called; the factor is set by
    :meth:`_prepare` from the unit of the region's lower bound and is ``1.0``
    until then. LSF convolution is not applied.

    ``norm_wav`` is a named parameter with a default
    ``Fixed(region_center)`` prior. Pass an explicit
    :class:`~unite.continuum.config.ContinuumNormalizationWavelength` with
    ``Fixed(value)`` to pin it to a specific wavelength across multiple
    regions — essential for physically consistent normalization when fitting
    a single blackbody across disjoint spectral windows.

    This form has no constructor parameters.

    Notes
    -----
    **Model parameters** (sampled with priors, overridable via
    ``ContinuumRegion(params={...})``):

    * ``scale`` — Continuum flux at ``norm_wav``
      (in units of ``continuum_scale``).
      Default prior: ``Uniform(0, 2)``.
    * ``temperature`` — Blackbody temperature in Kelvin.
      Default prior: ``Uniform(100, 50000)``.
    * ``norm_wav`` — Reference wavelength.
      Default prior: ``Fixed(region_center)``.
    """

    def __init__(self) -> None:
        # Multiplier taking region wavelengths to microns; see ``_prepare``.
        self._micron_factor: float = 1.0

    @override
    def param_names(self) -> tuple[str, ...]:
        """Return the parameter names read from ``params`` by this form."""
        return ('scale', 'temperature', 'norm_wav')

    @override
    def default_priors(self, region_center: float = 1.0) -> dict[str, Prior]:
        """Return the default prior of every parameter.

        ``norm_wav`` is fixed to *region_center*.
        """
        return {
            'scale': Uniform(0, 2),
            'temperature': Uniform(100, 50000),
            'norm_wav': Fixed(region_center),
        }

    @override
    def param_units(
        self, flux_unit: u.UnitBase, wl_unit: u.UnitBase
    ) -> dict[str, tuple[bool, u.UnitBase | None]]:
        """Return a ``(flag, unit)`` pair per parameter."""
        return {
            'scale': (True, flux_unit),
            'temperature': (False, u.K),
            'norm_wav': (False, wl_unit),
        }

    @override
    def evaluate(
        self,
        wavelength: ArrayLike,
        center: float,
        params: dict[str, ArrayLike],
        obs_low: float,
        obs_high: float,
        lsf_fwhm: ArrayLike = 0.0,
        z_sys: float = 0.0,
    ) -> Array:
        """Evaluate the normalized blackbody at *wavelength*.

        Only ``params`` is used; LSF convolution is not supported for
        Blackbody.
        """
        # Coerce first so plain sequences are accepted as well as arrays.
        wl_um = jnp.asarray(wavelength) * self._micron_factor
        nw_um = params['norm_wav'] * self._micron_factor
        bb = planck_function(wl_um, params['temperature'], nw_um)
        return params['scale'] * bb

    @override
    def _prepare(self, low: u.Quantity, high: u.Quantity) -> None:
        """Compute the micron conversion factor for the region's wavelength unit."""
        self._micron_factor = _get_conversion_factor(low.unit, u.um)

    @override
    def to_dict(self) -> dict:
        """Serialize to a plain dict (the form has no configuration)."""
        return {'type': 'Blackbody'}

    @classmethod
    @override
    def from_dict(cls, d: dict) -> Blackbody:
        """Build an instance from :meth:`to_dict` output; *d* is not inspected."""
        return cls()
[docs]
@_register
class ModifiedBlackbody(ContinuumForm):
    """Modified blackbody: ``scale * B_λ(T) * (λ / norm_wav)^beta / B_λ(nw, T)``.

    The power-law modifier *beta* broadens (beta > 0) or narrows (beta < 0)
    the SED relative to a pure blackbody; *beta = 0* recovers
    :class:`Blackbody`.  Wavelength parameters may be in any unit, as they
    are converted to microns internally.

    ``norm_wav`` is a named parameter whose default prior is
    ``Fixed(region_center)``.  Share a
    :class:`~unite.continuum.config.ContinuumNormalizationWavelength` token
    across regions to enforce a consistent reference wavelength.

    This form has no constructor parameters.

    Notes
    -----
    **Model parameters** (sampled with priors, overridable via
    ``ContinuumRegion(params={...})``):

    * ``scale`` — Continuum flux at ``norm_wav``
      (in units of ``continuum_scale``).
      Default prior: ``Uniform(0, 2)``.
    * ``temperature`` — Blackbody temperature in Kelvin.
      Default prior: ``Uniform(100, 50000)``.
    * ``beta`` — Power-law modifier index (dimensionless).
      Default prior: ``Uniform(-4, 4)``.
    * ``norm_wav`` — Reference wavelength.
      Default prior: ``Fixed(region_center)``.
    """

    def __init__(self) -> None:
        """Initialise the form; it takes no constructor arguments."""
        # Region-unit -> micron factor; identity until ``_prepare()`` runs.
        self._micron_factor: float = 1.0

    @override
    def param_names(self) -> tuple[str, ...]:
        """Return the names of the model parameters, in a fixed order."""
        names = ('scale', 'temperature', 'beta', 'norm_wav')
        return names

    @override
    def default_priors(self, region_center: float = 1.0) -> dict[str, Prior]:
        """Return the default prior for every parameter.

        ``norm_wav`` is fixed at *region_center*; the other parameters get
        uniform priors.
        """
        priors: dict[str, Prior] = {}
        priors['scale'] = Uniform(0, 2)
        priors['temperature'] = Uniform(100, 50000)
        priors['beta'] = Uniform(-4, 4)
        priors['norm_wav'] = Fixed(region_center)
        return priors

    @override
    def param_units(
        self, flux_unit: u.UnitBase, wl_unit: u.UnitBase
    ) -> dict[str, tuple[bool, u.UnitBase | None]]:
        """Return a ``(flag, unit)`` pair for every parameter.

        ``scale`` is the only entry flagged ``True`` and carries *flux_unit*;
        ``beta`` has no unit.
        """
        units: dict[str, tuple[bool, u.UnitBase | None]] = {}
        units['scale'] = (True, flux_unit)
        units['temperature'] = (False, u.K)
        units['beta'] = (False, None)
        units['norm_wav'] = (False, wl_unit)
        return units

    @override
    def evaluate(
        self,
        wavelength: ArrayLike,
        center: float,
        params: dict[str, ArrayLike],
        obs_low: float,
        obs_high: float,
        lsf_fwhm: ArrayLike = 0.0,
        z_sys: float = 0.0,
    ) -> Array:
        """Evaluate the modified blackbody at *wavelength*.

        Only ``wavelength`` and ``params`` are used.  ``center``,
        ``obs_low``, ``obs_high``, ``lsf_fwhm`` and ``z_sys`` are accepted
        but ignored; no LSF convolution is applied for this form.
        """
        to_um = self._micron_factor
        lam_um = wavelength * to_um
        pivot_um = params['norm_wav'] * to_um
        planck_term = planck_function(lam_um, params['temperature'], pivot_um)
        # Power law in wavelength relative to the reference wavelength; it
        # equals 1 at norm_wav.
        power_law = (lam_um / pivot_um) ** params['beta']
        return params['scale'] * planck_term * power_law

    @override
    def _prepare(self, low: u.Quantity, high: u.Quantity) -> None:
        """Store the factor converting the region's wavelength unit to microns.

        The unit is read from *low*; *high* is not used.
        """
        region_unit = low.unit
        self._micron_factor = _get_conversion_factor(region_unit, u.um)

    @override
    def to_dict(self) -> dict:
        """Serialize the form; only the type tag is stored."""
        payload = {'type': 'ModifiedBlackbody'}
        return payload

    @classmethod
    @override
    def from_dict(cls, d: dict) -> ModifiedBlackbody:
        """Rebuild the form; the contents of *d* are not inspected."""
        instance = cls()
        return instance
[docs]
@_register
class AttenuatedBlackbody(ContinuumForm):
    """Dust-attenuated blackbody continuum.

    Evaluates
    ``scale * B_λ(T) / B_λ(nw,T) * exp(-tau_v * [(λ/lambda_ext)^alpha - (nw/lambda_ext)^alpha])``.

    Extinction is normalized at ``norm_wav`` so that *scale*
    represents the **observed** (attenuated) flux there.  Negative *alpha*
    gives steeper extinction at short wavelengths (typical dust law).
    Wavelength parameters may be in any unit; automatic unit conversion
    to microns is applied internally.

    Parameters
    ----------
    lambda_ext : astropy.units.Quantity
        Reference wavelength for the extinction law.  Must be
        :class:`~astropy.units.Quantity` with any length unit — it will
        be converted automatically.  Defaults to ``5500 * u.AA``.

    Raises
    ------
    ValueError
        If ``lambda_ext`` is not an astropy Quantity.

    Notes
    -----
    ``lambda_ext`` is a *static* configuration parameter (not
    sampled).  It is converted to microns once, at construction.  The
    factor that converts evaluation wavelengths to microns is computed at
    model-build time via :meth:`_prepare`.

    **Model parameters** (sampled with priors, overridable via
    ``ContinuumRegion(params={...})``):

    * ``scale`` — Observed continuum flux at ``norm_wav``
      (in units of ``continuum_scale``).
      Default prior: ``Uniform(0, 2)``.
    * ``temperature`` — Blackbody temperature in Kelvin.
      Default prior: ``Uniform(100, 50000)``.
    * ``tau_v`` — Optical depth at ``lambda_ext``.
      Default prior: ``Uniform(0, 5)``.
    * ``alpha`` — Dust extinction power-law index (negative = steeper at
      short λ).  Default prior: ``Uniform(-2, 0)``.
    * ``norm_wav`` — Reference wavelength.
      Default prior: ``Fixed(region_center)``.
    """

    def __init__(self, lambda_ext: u.Quantity = 5500 * u.AA) -> None:
        # Reject bare numbers up front: without a unit the reference
        # wavelength is ambiguous.
        if not isinstance(lambda_ext, u.Quantity):
            raise ValueError(
                f'lambda_ext must be an astropy Quantity with length units, got {lambda_ext}'
            )
        self._lambda_ext: u.Quantity = _ensure_wavelength(
            lambda_ext, 'lambda_ext', ndim=0
        )
        # Plain float in microns used by evaluate().  It is fixed here and
        # never modified afterwards; _prepare() only sets _micron_factor.
        self._lambda_ext_um: float = float(self._lambda_ext.to(u.um).value)
        # Region-unit -> micron factor; identity until ``_prepare()`` runs.
        self._micron_factor: float = 1.0

    @property
    def lambda_ext(self) -> u.Quantity:
        """Extinction reference wavelength as an astropy Quantity."""
        return self._lambda_ext

    @override
    def param_names(self) -> tuple[str, ...]:
        """Return the names of the model parameters, in a fixed order."""
        return ('scale', 'temperature', 'tau_v', 'alpha', 'norm_wav')

    @override
    def default_priors(self, region_center: float = 1.0) -> dict[str, Prior]:
        """Return the default prior for every parameter.

        ``norm_wav`` is fixed at *region_center*; the other parameters get
        uniform priors.
        """
        return {
            'scale': Uniform(0, 2),
            'temperature': Uniform(100, 50000),
            'tau_v': Uniform(0, 5),
            'alpha': Uniform(-2, 0),
            'norm_wav': Fixed(region_center),
        }

    @override
    def param_units(
        self, flux_unit: u.UnitBase, wl_unit: u.UnitBase
    ) -> dict[str, tuple[bool, u.UnitBase | None]]:
        """Return a ``(flag, unit)`` pair for every parameter.

        ``scale`` is the only entry flagged ``True`` and carries *flux_unit*;
        ``tau_v`` and ``alpha`` have no unit.
        """
        return {
            'scale': (True, flux_unit),
            'temperature': (False, u.K),
            'tau_v': (False, None),
            'alpha': (False, None),
            'norm_wav': (False, wl_unit),
        }

    @override
    def evaluate(
        self,
        wavelength: ArrayLike,
        center: float,
        params: dict[str, ArrayLike],
        obs_low: float,
        obs_high: float,
        lsf_fwhm: ArrayLike = 0.0,
        z_sys: float = 0.0,
    ) -> Array:
        """Evaluate the attenuated blackbody at *wavelength*.

        Only ``wavelength`` and ``params`` are used.  ``center``,
        ``obs_low``, ``obs_high``, ``lsf_fwhm`` and ``z_sys`` are accepted
        but ignored; no LSF convolution is applied for this form.
        """
        wl_um = wavelength * self._micron_factor
        nw_um = params['norm_wav'] * self._micron_factor
        bb = planck_function(wl_um, params['temperature'], nw_um)
        # Power-law extinction curve at the data wavelengths and at the
        # reference wavelength.  Their difference vanishes at norm_wav, so
        # the extinction factor is exactly 1 there.
        ext_data = (wl_um / self._lambda_ext_um) ** params['alpha']
        ext_pivot = (nw_um / self._lambda_ext_um) ** params['alpha']
        extinction = jnp.exp(-params['tau_v'] * (ext_data - ext_pivot))
        return params['scale'] * bb * extinction

    @override
    def _prepare(self, low: u.Quantity, high: u.Quantity) -> None:
        """Store the factor converting the region's wavelength unit to microns.

        The unit is read from *low*; *high* is not used.
        """
        self._micron_factor = _get_conversion_factor(low.unit, u.um)

    @override
    def to_dict(self) -> dict:
        """Serialize the form: type tag plus ``lambda_ext`` as a Quantity."""
        return {'type': 'AttenuatedBlackbody', 'lambda_ext': self.lambda_ext}

    @classmethod
    @override
    def from_dict(cls, d: dict) -> AttenuatedBlackbody:
        """Rebuild the form from its serialized dictionary.

        When ``'lambda_ext'`` is absent the constructor default is used.
        A value that is present is passed to the constructor unchanged, so
        it must be an astropy Quantity.
        """
        # The previous fallback was the bare float 0.55, which __init__
        # rejects; defer to the constructor default instead.
        if 'lambda_ext' not in d:
            return cls()
        return cls(lambda_ext=d['lambda_ext'])

    @override
    def __repr__(self) -> str:
        return f'AttenuatedBlackbody(lambda_ext={self.lambda_ext})'
[docs]
@_register
class Template(ContinuumForm):
    """Interpolated spectral template continuum.

    Loads one or more template spectra from a file readable by
    :func:`astropy.table.Table.read`.  The file must contain a wavelength
    column (identified by its physical unit) and one or more flux columns.
    Each flux column becomes a separate scale parameter named
    ``{column}_scale``.

    The template is evaluated by linearly interpolating each column in the
    rest frame and normalising so that ``{col}_scale`` equals the flux at
    ``norm_wav``:

    .. code-block:: text

        F(λ) = Σ_i {col_i}_scale * T_i(λ_rest) / T_i(norm_wav_rest)

    where ``λ_rest = λ_obs / (1 + z_sys)``.

    Parameters
    ----------
    path : str or Path
        Path to the template file.  Any format supported by
        :func:`~astropy.table.Table.read` is accepted (FITS, ECSV, …).
    wavelength_colname : str, optional
        Name of the wavelength column.  If omitted, the column whose unit
        has ``physical_type == 'length'`` is used; raises if the result is
        ambiguous.
    usecols : list or tuple of str, optional
        Flux columns to load.  Defaults to all non-wavelength columns.
        Raises if any requested column is absent.

    Raises
    ------
    ValueError
        If the wavelength column cannot be identified, has no unit,
        contains NaN/inf, or is not strictly increasing; if a requested
        flux column is missing; if no flux columns remain; or if a flux
        column contains NaN/inf.

    Notes
    -----
    The wavelength column must carry an astropy unit with
    ``physical_type == 'length'``.  Flux columns without units, or with
    units that are not spectral flux density (f_lambda), produce a
    :class:`UserWarning` but are still accepted.

    **Model parameters** (sampled with priors, overridable via
    ``ContinuumRegion(params={...})``):

    * ``{col}_scale`` — Template amplitude at ``norm_wav``, one per column.
      Default prior: ``Uniform(0, 2)``.
    * ``norm_wav`` — Rest-frame reference wavelength (shared across all
      columns).  Default prior: ``Fixed(region_center_rest)``.

    **LSF convolution behaviour:**

    In *analytic* mode the ``lsf_fwhm`` argument is ignored — the template
    returns raw interpolated values with no convolution applied.  In
    *convolution* mode (``integration_mode='convolution'``) the full
    numerical LSF kernel is applied externally by the model, but the kernel
    assumes the template is intrinsically unresolved.  Templates that
    already carry native spectral resolution (e.g. stellar population models
    convolved to a library resolution) will be further broadened by the
    instrument LSF, producing an effective resolution equal to the
    convolution of both.  Use a template whose native resolution is well
    below the instrument LSF, or deconvolve it beforehand, to avoid
    over-convolution.
    """

    def __init__(
        self,
        path: str | Path,
        *,
        wavelength_colname: str | None = None,
        usecols: list[str] | tuple[str, ...] | None = None,
    ) -> None:
        import warnings

        from astropy.table import Table

        self._path = Path(path)
        self._wavelength_colname = wavelength_colname
        self._usecols_arg = tuple(usecols) if usecols is not None else None

        table = Table.read(self._path)

        # --- identify wavelength column ---
        if wavelength_colname is not None:
            if wavelength_colname not in table.colnames:
                msg = (
                    f"wavelength_colname '{wavelength_colname}' not found in "
                    f'{self._path}. Available columns: {table.colnames}'
                )
                raise ValueError(msg)
            wl_col = wavelength_colname
        else:
            # Auto-detect: exactly one column may carry a length unit.
            length_cols = [
                c
                for c in table.colnames
                if getattr(table[c], 'unit', None) is not None
                and u.Unit(table[c].unit).is_equivalent(u.m)
            ]
            if not length_cols:
                msg = (
                    f'No column with a length unit found in {self._path}. '
                    'Set wavelength_colname explicitly.'
                )
                raise ValueError(msg)
            if len(length_cols) > 1:
                msg = (
                    f'Multiple length-unit columns in {self._path}: {length_cols}. '
                    'Set wavelength_colname explicitly.'
                )
                raise ValueError(msg)
            wl_col = length_cols[0]

        # --- wavelength array ---
        wl_col_data = table[wl_col]
        if getattr(wl_col_data, 'unit', None) is None:
            msg = f"Wavelength column '{wl_col}' has no unit."
            raise ValueError(msg)
        self._lam_qty: u.Quantity = u.Quantity(
            np.asarray(wl_col_data), unit=wl_col_data.unit
        )
        # Check finiteness BEFORE monotonicity: any comparison involving NaN
        # is False, so a NaN wavelength would otherwise be reported as
        # "not strictly monotonically increasing".
        if not np.all(np.isfinite(self._lam_qty.to(u.um).value)):
            msg = f"Wavelength column '{wl_col}' contains NaN or inf values."
            raise ValueError(msg)
        if not np.all(np.diff(self._lam_qty.value) > 0):
            msg = f"Wavelength column '{wl_col}' is not strictly monotonically increasing."
            raise ValueError(msg)

        # --- flux columns ---
        if usecols is not None:
            missing = [c for c in usecols if c not in table.colnames]
            if missing:
                msg = (
                    f'usecols columns not found in {self._path}: {missing}. '
                    f'Available columns: {table.colnames}'
                )
                raise ValueError(msg)
            flux_cols = list(usecols)
        else:
            flux_cols = [c for c in table.colnames if c != wl_col]
        if not flux_cols:
            msg = f'No flux columns found in {self._path} after excluding the wavelength column.'
            raise ValueError(msg)

        # --- unit warnings and NaN/inf checks ---
        # Each column is converted to a float array once and kept.
        flam_ref = u.Unit('erg s-1 cm-2 AA-1')
        flam_arrays: dict[str, np.ndarray] = {}
        for col in flux_cols:
            col_data = table[col]
            col_unit = getattr(col_data, 'unit', None)
            if col_unit is None or str(col_unit) in ('', 'None'):
                warnings.warn(
                    f"Template column '{col}' in {self._path} has no units. "
                    'Assuming f_lambda (spectral flux density per wavelength). '
                    'scale parameters will be in units of continuum_scale.',
                    UserWarning,
                    stacklevel=2,
                )
            elif not u.Unit(col_unit).is_equivalent(flam_ref):
                warnings.warn(
                    f"Template column '{col}' has unit '{col_unit}'. "
                    'Expected f_lambda (spectral flux density per wavelength).',
                    UserWarning,
                    stacklevel=2,
                )
            arr = np.asarray(col_data, dtype=float)
            if not np.all(np.isfinite(arr)):
                msg = f"Template column '{col}' contains NaN or inf values."
                raise ValueError(msg)
            flam_arrays[col] = arr

        self._flux_cols: list[str] = flux_cols
        self._flam_arrays: dict[str, np.ndarray] = flam_arrays

        # set by _prepare()
        self._lam_um: np.ndarray | None = None
        self._rest_low_um: float = 0.0
        self._flam_eval: Array | None = None
        self._lam_eval: Array | None = None

    # ------------------------------------------------------------------
    # ContinuumForm interface
    # ------------------------------------------------------------------

    @override
    def param_names(self) -> tuple[str, ...]:
        """Return one ``{col}_scale`` per flux column, then ``'norm_wav'``."""
        return (*[f'{c}_scale' for c in self._flux_cols], 'norm_wav')

    @override
    def default_priors(self, region_center: float = 1.0) -> dict[str, Prior]:
        """Return ``Uniform(0, 2)`` per scale and ``Fixed(region_center)`` for ``norm_wav``."""
        priors: dict[str, Prior] = {
            f'{col}_scale': Uniform(0, 2) for col in self._flux_cols
        }
        priors['norm_wav'] = Fixed(region_center)
        return priors

    @override
    def param_units(
        self, flux_unit: u.UnitBase, wl_unit: u.UnitBase
    ) -> dict[str, tuple[bool, u.UnitBase | None]]:
        """Return a ``(flag, unit)`` pair for every parameter.

        Every ``{col}_scale`` is flagged ``True`` with *flux_unit*;
        ``norm_wav`` is flagged ``False`` with *wl_unit*.
        """
        d: dict[str, tuple[bool, u.UnitBase | None]] = {
            f'{col}_scale': (True, flux_unit) for col in self._flux_cols
        }
        d['norm_wav'] = (False, wl_unit)
        return d

    @override
    def _prepare(self, low: u.Quantity, high: u.Quantity) -> None:
        """Convert template to microns and validate wavelength coverage.

        Raises
        ------
        ValueError
            If the template grid does not span ``[low, high]``.
        """
        self._lam_um = self._lam_qty.to(u.um).value
        self._rest_low_um = float(low.to(u.um).value)
        rest_high_um = float(high.to(u.um).value)
        # Use a small relative tolerance to absorb floating-point rounding from
        # unit conversions (e.g. 9000 AA * 1e-4 µm/AA = 0.9000000000000001 µm).
        _rtol = 1e-9
        if self._lam_um[0] > self._rest_low_um * (1.0 + _rtol):
            msg = (
                f'Template wavelength grid starts at {self._lam_um[0]:.4f} µm '
                f'but region lower bound is {low.to(u.um):.4f}. '
                'Template does not cover the full region.'
            )
            raise ValueError(msg)
        if self._lam_um[-1] < rest_high_um * (1.0 - _rtol):
            msg = (
                f'Template wavelength grid ends at {self._lam_um[-1]:.4f} µm '
                f'but region upper bound is {high.to(u.um):.4f}. '
                'Template does not cover the full region.'
            )
            raise ValueError(msg)
        # Stack flux arrays: shape (N_cols, N_lam)
        self._flam_eval = jnp.array(
            np.stack([self._flam_arrays[c] for c in self._flux_cols], axis=0)
        )
        self._lam_eval = jnp.array(self._lam_um)

    @override
    def evaluate(
        self,
        wavelength: ArrayLike,
        center: float,
        params: dict[str, ArrayLike],
        obs_low: float,
        obs_high: float,
        lsf_fwhm: ArrayLike = 0.0,
        z_sys: float = 0.0,
    ) -> Array:
        """Evaluate the summed, normalised templates at *wavelength*.

        ``center``, ``obs_high``, ``lsf_fwhm`` and ``z_sys`` are not used
        here; the rest-frame conversion is derived from ``obs_low``.
        :meth:`_prepare` must have been called first.
        """
        # lsf_fwhm is intentionally ignored: Template returns raw interpolated
        # values. In analytic mode no LSF convolution is applied. In convolution
        # mode the full numerical LSF kernel is applied externally by the caller,
        # which assumes the template is intrinsically unresolved — templates with
        # native spectral resolution will be further convolved by the instrument
        # LSF, producing an effective resolution equal to the convolution of both.
        # Convert observed-frame wavelength to rest-frame microns.
        # obs_low = rest_low_canonical * (1 + z_sys), and _rest_low_um is
        # rest_low in microns, so wavelength * _rest_low_um / obs_low gives
        # rest-frame microns regardless of the canonical wavelength unit.
        assert self._flam_eval is not None and self._lam_eval is not None, (
            'Template._prepare() must be called before evaluate(). '
            'Wrap the Template in a ContinuumRegion to trigger _prepare().'
        )
        wl = jnp.asarray(wavelength)
        scale = self._rest_low_um / obs_low
        lam_rest_um = wl * scale
        norm_wav_rest_um = params['norm_wav'] * scale
        total = jnp.zeros_like(wl)
        for i, col in enumerate(self._flux_cols):
            flam_row = self._flam_eval[i]
            t_at_lam = jnp.interp(lam_rest_um, self._lam_eval, flam_row)
            t_at_norm = jnp.interp(norm_wav_rest_um, self._lam_eval, flam_row)
            total = total + params[f'{col}_scale'] * t_at_lam / t_at_norm
        return total

    # ------------------------------------------------------------------
    # Serialization
    # ------------------------------------------------------------------

    @override
    def to_dict(self) -> dict:
        """Serialize the constructor arguments; optional ones only when set."""
        d: dict = {'type': 'Template', 'path': str(self._path)}
        if self._wavelength_colname is not None:
            d['wavelength_colname'] = self._wavelength_colname
        if self._usecols_arg is not None:
            d['usecols'] = list(self._usecols_arg)
        return d

    @classmethod
    @override
    def from_dict(cls, d: dict) -> Template:
        """Rebuild the template from its dictionary; this re-reads the file at ``d['path']``."""
        return cls(
            d['path'],
            wavelength_colname=d.get('wavelength_colname'),
            usecols=d.get('usecols'),
        )

    @override
    def __eq__(self, other: object) -> bool:
        # Equality is defined by the constructor arguments, not by the
        # loaded data.
        if not isinstance(other, Template):
            return NotImplemented
        return (
            self._path == other._path
            and self._wavelength_colname == other._wavelength_colname
            and self._usecols_arg == other._usecols_arg
        )

    @override
    def __hash__(self) -> int:
        # Hash the same fields that __eq__ compares.
        return hash(
            (
                type(self).__name__,
                self._path,
                self._wavelength_colname,
                self._usecols_arg,
            )
        )

    @override
    def __repr__(self) -> str:
        cols = ', '.join(self._flux_cols)
        return f'Template({self._path.name!r}, columns=[{cols}])'