"""Continuum functional forms: abstract base and concrete implementations."""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, cast, override
import jax
import jax.numpy as jnp
import numpy as np
from astropy import units as u
from jax import Array
from jax.typing import ArrayLike
from unite._utils import _ensure_wavelength, _get_conversion_factor, _make_register
from unite.continuum.functions import (
bernstein_eval,
bspline_eval,
chebval,
planck_function,
)
from unite.prior import Fixed, Prior, Uniform
if TYPE_CHECKING:
pass
def _gaussian_convolve_poly(coeffs: Array, lsf_fwhm: ArrayLike) -> Array:
    """Convolve a polynomial with a zero-mean Gaussian, in closed form.

    The polynomial is ``p(x) = sum_i coeffs[i] * x**(n - i)`` (descending
    order, degree ``n = len(coeffs) - 1``).  Convolving with a Gaussian of
    standard deviation ``s`` maps each monomial to::

        x**k  ->  sum_{j=0}^{floor(k/2)} C(k, 2j) * (2j-1)!! * s**(2j) * x**(k-2j)

    so the result is again a polynomial of degree ``n``; only the even
    Gaussian moments ``(2j-1)!! * s**(2j)`` appear.  See the
    :doc:`polynomial derivation </derivations/polynomial>` for details.

    Parameters
    ----------
    coeffs : Array, shape ``(n+1,)``
        Polynomial coefficients, leading (``x**n``) coefficient first.
    lsf_fwhm : ArrayLike
        LSF full width at half maximum, in the same unit as *x*.  The
        moment recursion and the final contraction treat it as a scalar.

    Returns
    -------
    Array, shape ``(n+1,)``
        Coefficients of the convolved polynomial, leading coefficient first.
    """
    # FWHM -> variance of the Gaussian kernel.
    fwhm_per_sigma = 2.0 * np.sqrt(2.0 * np.log(2.0))
    variance = (jnp.asarray(lsf_fwhm) / fwhm_per_sigma) ** 2

    degree = coeffs.shape[0] - 1  # known at trace time
    n_moments = degree // 2 + 1

    # Even moments via the recursion m_0 = 1, m_j = m_{j-1} * (2j-1) * s^2.
    if n_moments > 1:

        def _next_moment(previous, order):
            current = previous * (2 * order - 1) * variance
            return current, current

        _, higher = jax.lax.scan(_next_moment, 1.0, jnp.arange(1, n_moments))
        moments = jnp.concatenate([jnp.array([1.0]), higher])
    else:
        moments = jnp.ones(1)

    # Static scatter tensor: scatter[dst, src, j] = C(n - src, 2j) whenever
    # dst == src + 2j, otherwise 0.  It depends only on the degree, so it is
    # assembled with plain NumPy and becomes a constant under tracing.
    scatter = np.zeros((degree + 1, degree + 1, n_moments))
    for src in range(degree + 1):
        power = degree - src
        weight = 1.0  # C(power, 0)
        for half in range(power // 2 + 1):
            scatter[src + 2 * half, src, half] = weight
            # Step C(power, 2*half) -> C(power, 2*half + 2), one factor at a time.
            taken = 2 * half
            weight = weight * (power - taken) / (taken + 1)
            weight = weight * (power - taken - 1) / (taken + 2)

    # out[dst] = sum_{src, j} scatter[dst, src, j] * moments[j] * coeffs[src]
    return jnp.einsum('oij,j,i->o', jnp.asarray(scatter), moments, coeffs)
def _polyint_avg(coeffs: Array, x_low: ArrayLike, x_high: ArrayLike) -> Array:
    """Average a polynomial over the interval(s) ``[x_low, x_high]``.

    Computes ``(P(x_high) - P(x_low)) / (x_high - x_low)`` where ``P`` is
    the analytic antiderivative of the polynomial whose descending-order
    coefficients are ``coeffs = [a_n, ..., a_0]``.

    Parameters
    ----------
    coeffs : Array, shape ``(n+1,)``
        Polynomial coefficients, leading coefficient first.
    x_low, x_high : ArrayLike
        Lower and upper bin edges (scalars or arrays).

    Returns
    -------
    Array
        Mean value of the polynomial on each bin.
    """
    # Term a_k * x**k integrates to a_k / (k + 1) * x**(k + 1); in descending
    # order the divisors run n+1, n, ..., 1 and the constant of integration
    # (taken as zero) is appended last.
    n_coeffs = coeffs.shape[0]
    new_powers = jnp.arange(n_coeffs, 0, -1, dtype=coeffs.dtype)
    antiderivative = jnp.concatenate([coeffs / new_powers, jnp.array([0.0])])

    upper = jnp.polyval(antiderivative, x_high)
    lower = jnp.polyval(antiderivative, x_low)
    width = x_high - x_low
    return (upper - lower) / width
def _cheb_to_mono_matrix(n: int) -> Array:
    """Return the ``(n, n)`` matrix converting Chebyshev to monomial coefficients.

    For ``cheb_coeffs = [c0, ..., c_{n-1}]`` describing
    ``c0*T0(x) + ... + c_{n-1}*T_{n-1}(x)``, the product
    ``M @ cheb_coeffs`` yields the monomial coefficients in **ascending**
    order (``[a0, a1, ..., a_{n-1}]``).

    The matrix is assembled with NumPy at call time, so it is a plain
    constant from JAX's point of view.
    """
    from numpy.polynomial.chebyshev import cheb2poly

    matrix = np.zeros((n, n))
    # Column k holds the ascending monomial expansion of T_k, obtained by
    # converting the k-th unit vector of Chebyshev coefficients.
    for order, unit_vec in enumerate(np.eye(n)):
        expansion = cheb2poly(unit_vec)
        # The expansion may be shorter than n; remaining entries stay zero.
        matrix[: expansion.shape[0], order] = expansion
    return jnp.array(matrix)
def _bernstein_to_mono_matrix(n: int) -> Array:
    """Return the ``(n+1, n+1)`` Bernstein-to-monomial conversion matrix.

    For ``bern_coeffs = [b0, ..., b_n]`` describing
    ``sum_i b_i * C(n, i) * t**i * (1 - t)**(n - i)``, the product
    ``M @ bern_coeffs`` yields the monomial coefficients in **ascending**
    order for the polynomial in ``t``.
    """
    from scipy.special import comb

    size = n + 1
    matrix = np.zeros((size, size))
    # Coefficient of t**k:  sum_{i<=k} (-1)**(k-i) * C(n, i) * C(n-i, k-i) * b_i,
    # hence the matrix is lower triangular.
    for power in range(size):
        for idx in range(power + 1):
            sign = -1 if (power - idx) % 2 else 1
            matrix[power, idx] = (
                sign
                * comb(n, idx, exact=True)
                * comb(n - idx, power - idx, exact=True)
            )
    return jnp.array(matrix)
# ---------------------------------------------------------------------------
# Registry
# ---------------------------------------------------------------------------
# Module-level table of continuum form classes, keyed by a string name.
# NOTE(review): presumably the key is the class name used as ``'type'`` in
# ``to_dict`` — confirm against ``unite._utils._make_register``.
_FORM_REGISTRY: dict[str, type[ContinuumForm]] = {}
# Decorator built around ``_FORM_REGISTRY``; applied as ``@_register`` to each
# concrete form class below.
_register = _make_register(_FORM_REGISTRY)
# ---------------------------------------------------------------------------
# Abstract base class
# ---------------------------------------------------------------------------
# ---------------------------------------------------------------------------
# Deserialization helper
# ---------------------------------------------------------------------------
# ---------------------------------------------------------------------------
# Piecewise / region-local forms
# ---------------------------------------------------------------------------
[docs]
@_register
class Linear(ContinuumForm):
    """Linear continuum: ``scale + tan(angle) * (wavelength - norm_wav)``.

    The slope is parameterised through an angle, i.e. the slope of the line
    is ``tan(angle)``.

    This form has no constructor parameters.

    Notes
    -----
    **Model parameters** (sampled with priors, overridable via
    ``ContinuumRegion(params={...})``):

    * ``scale`` — Continuum level at ``norm_wav``.
      Default prior: ``Uniform(0, 2)``.
    * ``angle`` — Inclination of the line in radians; the slope is
      ``tan(angle)``.
      Default prior: ``Uniform(-pi/2, pi/2)``.
    * ``norm_wav`` — Reference wavelength where the
      continuum equals ``scale``.
      Default prior: ``Fixed(region_center)``.
    """

    @property
    @override
    def is_linear(self) -> bool:
        return True

    @override
    def param_names(self) -> tuple[str, ...]:
        """Return the parameter names in canonical order."""
        return ('scale', 'angle', 'norm_wav')

    @override
    def default_priors(self, region_center: float = 1.0) -> dict[str, Prior]:
        """Return the default prior for each parameter.

        ``norm_wav`` is pinned to *region_center*.
        """
        return {
            'scale': Uniform(0, 2),
            'angle': Uniform(-np.pi / 2, np.pi / 2),
            'norm_wav': Fixed(region_center),
        }

    @override
    def param_units(
        self, flux_unit: u.UnitBase, wl_unit: u.UnitBase
    ) -> dict[str, tuple[bool, u.UnitBase | None]]:
        """Return a ``(flag, unit)`` pair per parameter.

        ``angle`` carries no unit.
        """
        return {
            'scale': (True, flux_unit),
            'angle': (False, None),
            'norm_wav': (False, wl_unit),
        }

    @override
    def evaluate(
        self,
        wavelength: ArrayLike,
        center: float,
        params: dict[str, ArrayLike],
        obs_low: float,
        obs_high: float,
        lsf_fwhm: ArrayLike = 0.0,
    ) -> Array:
        """Evaluate ``scale + tan(angle) * (wavelength - norm_wav)``.

        *center*, *obs_low*, *obs_high* and *lsf_fwhm* are part of the
        common signature but are not used by this form.
        """
        nw = params['norm_wav']
        return params['scale'] + jnp.tan(params['angle']) * (wavelength - nw)

    @override
    def to_dict(self) -> dict:
        """Serialize to a plain dictionary."""
        return {'type': 'Linear'}

    @classmethod
    @override
    def from_dict(cls, d: dict) -> Linear:
        """Rebuild the form from :meth:`to_dict` output (no state to restore)."""
        return cls()
[docs]
@_register
class Polynomial(ContinuumForm):
    """Polynomial continuum of configurable degree.

    Evaluates ``scale + c1*x + c2*x**2 + ...`` where
    ``x = wavelength - norm_wav``.  The polynomial is convolved analytically
    with a Gaussian of FWHM ``lsf_fwhm`` before evaluation.

    Parameters
    ----------
    degree : int
        Polynomial degree (default 1).

    Raises
    ------
    ValueError
        If *degree* is negative.

    Notes
    -----
    **Model parameters** (sampled with priors, overridable via
    ``ContinuumRegion(params={...})``):

    * ``scale`` — Continuum level at ``norm_wav``.
      Default prior: ``Uniform(0, 2)``.
    * ``c1, c2, ...`` — Higher-order polynomial coefficients.
      Default prior: ``Uniform(-10, 10)`` each.
    * ``norm_wav`` — Reference wavelength.
      Default prior: ``Fixed(region_center)``.
    """

    def __init__(self, degree: int = 1) -> None:
        if degree < 0:
            msg = f'Polynomial degree must be >= 0, got {degree}'
            raise ValueError(msg)
        self._degree = degree

    @property
    @override
    def is_linear(self) -> bool:
        return True

    @property
    def degree(self) -> int:
        """Polynomial degree."""
        return self._degree

    @override
    def param_names(self) -> tuple[str, ...]:
        """Return ``('scale', 'c1', ..., 'c<degree>', 'norm_wav')``."""
        # For degree 0 the generator is empty, leaving ('scale', 'norm_wav').
        return ('scale', *(f'c{i}' for i in range(1, self._degree + 1)), 'norm_wav')

    @override
    def default_priors(self, region_center: float = 1.0) -> dict[str, Prior]:
        """Return the default prior for each parameter."""
        priors: dict[str, Prior] = {'scale': Uniform(0, 2)}
        for i in range(1, self._degree + 1):
            priors[f'c{i}'] = Uniform(-10, 10)
        priors['norm_wav'] = Fixed(region_center)
        return priors

    @override
    def param_units(
        self, flux_unit: u.UnitBase, wl_unit: u.UnitBase
    ) -> dict[str, tuple[bool, u.UnitBase | None]]:
        """Return a ``(flag, unit)`` pair per parameter.

        Coefficient ``c<i>`` multiplies ``x**i`` and so has unit
        ``flux_unit / wl_unit**i``.
        """
        d: dict[str, tuple[bool, u.UnitBase | None]] = {'scale': (True, flux_unit)}
        for i in range(1, self._degree + 1):
            d[f'c{i}'] = (True, flux_unit / wl_unit**i)
        d['norm_wav'] = (False, wl_unit)
        return d

    def _convolved_coeffs(
        self, params: dict[str, ArrayLike], lsf_fwhm: ArrayLike
    ) -> Array:
        """Return the LSF-convolved coefficients in descending order."""
        # Monomial coefficients in descending order: c_n, ..., c_1, scale.
        mono = jnp.array(
            [params[f'c{i}'] for i in range(self._degree, 0, -1)] + [params['scale']]
        )
        return _gaussian_convolve_poly(mono, lsf_fwhm)

    @override
    def evaluate(
        self,
        wavelength: ArrayLike,
        center: float,
        params: dict[str, ArrayLike],
        obs_low: float,
        obs_high: float,
        lsf_fwhm: ArrayLike = 0.0,
    ) -> Array:
        """Evaluate the LSF-convolved polynomial at *wavelength*.

        *center*, *obs_low* and *obs_high* are not used by this form.
        """
        x = wavelength - params['norm_wav']
        return jnp.polyval(self._convolved_coeffs(params, lsf_fwhm), x)

    @override
    def integrate(
        self,
        low: ArrayLike,
        high: ArrayLike,
        center: float,
        params: dict[str, ArrayLike],
        obs_low: float,
        obs_high: float,
        lsf_fwhm: ArrayLike = 0.0,
    ) -> Array:
        """Return the mean of the LSF-convolved polynomial on ``[low, high]``.

        *center*, *obs_low* and *obs_high* are not used by this form.
        """
        nw = params['norm_wav']
        return _polyint_avg(
            self._convolved_coeffs(params, lsf_fwhm), low - nw, high - nw
        )

    @override
    def to_dict(self) -> dict:
        """Serialize to a plain dictionary."""
        return {'type': 'Polynomial', 'degree': self._degree}

    @classmethod
    @override
    def from_dict(cls, d: dict) -> Polynomial:
        """Rebuild the form from :meth:`to_dict` output."""
        return cls(degree=d['degree'])

    @override
    def __repr__(self) -> str:
        return f'Polynomial(degree={self._degree})'
[docs]
@_register
class Chebyshev(ContinuumForm):
    """Chebyshev polynomial continuum of configurable order.

    Evaluates a Chebyshev series on coordinates normalized to ``[-1, 1]``
    within the continuum region, normalized so that the continuum equals
    ``scale`` at ``norm_wav``.  Numerically more stable than a standard
    polynomial basis for higher orders.

    The x-coordinate is ``(wavelength - center) / (half_width * stretch)``
    where ``half_width`` is derived from the region bounds passed to
    :meth:`evaluate`, and ``stretch`` is a form-specific scaling factor
    (default ``1.0`` for identity normalization).

    The continuum is parameterized as ``scale * T(x) / T(x_nw)`` where
    ``T`` is the Chebyshev series with constant term fixed at 1.0,
    and ``x_nw`` is the normalized coordinate at ``norm_wav``.  The LSF
    convolution is applied to ``T(x)`` only; ``T(x_nw)`` uses the
    unconvolved series.

    Parameters
    ----------
    order : int
        Chebyshev order (default 2).  Number of coefficients = order + 1.
    stretch : float
        Stretch factor to scale the region normalization (default 1.0).

    Raises
    ------
    ValueError
        If *order* is negative or *stretch* is not strictly positive.

    Notes
    -----
    **Model parameters** (sampled with priors, overridable via
    ``ContinuumRegion(params={...})``):

    * ``scale`` — Continuum level at ``norm_wav``.
      Default prior: ``Uniform(0, 2)``.
    * ``c1, c2, ...`` — Higher-order Chebyshev coefficients (normalized to
      constant term 1.0).
      Default prior: ``Uniform(-10, 10)`` each.
    * ``norm_wav`` — Reference wavelength.
      Default prior: ``Fixed(region_center)``.
    """

    def __init__(self, order: int = 2, stretch: float = 1.0) -> None:
        if order < 0:
            msg = f'Chebyshev order must be >= 0, got {order}'
            raise ValueError(msg)
        self._order = order
        if stretch <= 0:
            msg = f'Chebyshev stretch factor must be > 0, got {stretch}'
            raise ValueError(msg)
        self._stretch = stretch
        # Static Chebyshev-to-monomial conversion matrix.
        self._cheb2mono = _cheb_to_mono_matrix(order + 1)

    @property
    @override
    def is_linear(self) -> bool:
        return False

    @property
    def order(self) -> int:
        """Chebyshev order."""
        return self._order

    @property
    def stretch(self) -> float:
        """Stretch factor for the region normalization."""
        return self._stretch

    @override
    def param_names(self) -> tuple[str, ...]:
        """Return ``('scale', 'c1', ..., 'c<order>', 'norm_wav')``."""
        # For order 0 the generator is empty, leaving ('scale', 'norm_wav').
        return ('scale', *(f'c{i}' for i in range(1, self._order + 1)), 'norm_wav')

    @override
    def default_priors(self, region_center: float = 1.0) -> dict[str, Prior]:
        """Return the default prior for each parameter."""
        priors: dict[str, Prior] = {'scale': Uniform(0, 2)}
        for i in range(1, self._order + 1):
            priors[f'c{i}'] = Uniform(-10, 10)
        priors['norm_wav'] = Fixed(region_center)
        return priors

    @override
    def param_units(
        self, flux_unit: u.UnitBase, wl_unit: u.UnitBase
    ) -> dict[str, tuple[bool, u.UnitBase | None]]:
        """Return a ``(flag, unit)`` pair per parameter."""
        # x is normalised to [-1, 1], so all coefficients have unit flux_unit.
        d: dict[str, tuple[bool, u.UnitBase | None]] = {'scale': (True, flux_unit)}
        for i in range(1, self._order + 1):
            d[f'c{i}'] = (True, flux_unit)
        d['norm_wav'] = (False, wl_unit)
        return d

    def _shape_terms(
        self,
        center: float,
        params: dict[str, ArrayLike],
        obs_low: float,
        obs_high: float,
        lsf_fwhm: ArrayLike,
    ) -> tuple[ArrayLike, Array, Array]:
        """Return ``(scale_factor, convolved_coeffs, shape_at_norm_wav)``.

        ``convolved_coeffs`` are descending monomial coefficients in the
        normalised coordinate, already convolved with the rescaled LSF.
        """
        half_width = (obs_high - obs_low) / 2
        scale_factor = half_width * self._stretch
        x_nw = (params['norm_wav'] - center) / scale_factor
        # Chebyshev coefficients -> monomial (ascending) via the static
        # matrix, reversed to descending order for jnp.polyval.
        cheb_coeffs = jnp.array(
            [1.0] + [params[f'c{i}'] for i in range(1, self._order + 1)]
        )
        mono = (self._cheb2mono @ cheb_coeffs)[::-1]
        # Rescale LSF FWHM into the normalised coordinate system.
        lsf_fwhm_scaled = jnp.asarray(lsf_fwhm) / scale_factor
        convolved = _gaussian_convolve_poly(mono, lsf_fwhm_scaled)
        # The normalisation at norm_wav uses the unconvolved series.
        shape_nw = chebval(x_nw, cheb_coeffs)
        return scale_factor, convolved, shape_nw

    @override
    def evaluate(
        self,
        wavelength: ArrayLike,
        center: float,
        params: dict[str, ArrayLike],
        obs_low: float,
        obs_high: float,
        lsf_fwhm: ArrayLike = 0.0,
    ) -> Array:
        """Evaluate ``scale * T(x) / T(x_nw)`` at *wavelength*."""
        scale_factor, convolved, shape_nw = self._shape_terms(
            center, params, obs_low, obs_high, lsf_fwhm
        )
        x = (wavelength - center) / scale_factor
        shape = jnp.polyval(convolved, x)
        return params['scale'] * shape / shape_nw

    @override
    def integrate(
        self,
        low: ArrayLike,
        high: ArrayLike,
        center: float,
        params: dict[str, ArrayLike],
        obs_low: float,
        obs_high: float,
        lsf_fwhm: ArrayLike = 0.0,
    ) -> Array:
        """Return the mean of the continuum over each bin ``[low, high]``."""
        scale_factor, convolved, shape_nw = self._shape_terms(
            center, params, obs_low, obs_high, lsf_fwhm
        )
        x_low = (low - center) / scale_factor
        x_high = (high - center) / scale_factor
        shape_avg = _polyint_avg(convolved, x_low, x_high)
        return params['scale'] * shape_avg / shape_nw

    @override
    def to_dict(self) -> dict:
        """Serialize to a plain dictionary."""
        return {'type': 'Chebyshev', 'order': self._order, 'stretch': self._stretch}

    @classmethod
    @override
    def from_dict(cls, d: dict) -> Chebyshev:
        """Rebuild the form from :meth:`to_dict` output (``stretch`` optional)."""
        return cls(order=d['order'], stretch=d.get('stretch', 1.0))

    @override
    def __repr__(self) -> str:
        return f'Chebyshev(order={self._order}, stretch={self._stretch})'
[docs]
@_register
class BSpline(ContinuumForm):
    """B-spline continuum with local knot control.

    The knot vector must be set via *knots* at construction time (typically
    derived from the wavelength coverage of the spectrum).  Knots are
    interior knots: :meth:`_prepare` requires every knot to lie strictly
    inside the region bounds and then pads the vector with ``degree + 1``
    copies of each bound (clamped spline), expressed in the unit of the
    lower bound.  :meth:`_prepare` must run before :meth:`evaluate`.

    The continuum is normalized so that it equals ``scale`` at ``norm_wav``,
    parameterized as ``scale * S(u) / S(u_nw)`` where ``S`` is the B-spline
    series with first coefficient fixed at 1.0, and ``u_nw`` is the normalized
    coordinate at ``norm_wav``.  LSF convolution is not applied by this form.

    Parameters
    ----------
    knots : u.Quantity
        Interior knot vector in wavelength units.
    degree : int
        Spline degree (default 3 for cubic).

    Raises
    ------
    ValueError
        If *knots* is not an astropy ``Quantity`` or *degree* is negative.

    Notes
    -----
    **Model parameters** (sampled with priors, overridable via
    ``ContinuumRegion(params={...})``):

    * ``scale`` — Continuum level at ``norm_wav``.
      Default prior: ``Uniform(0, 2)``.
    * ``coeff_1, coeff_2, …`` — Remaining B-spline coefficients (normalized
      to first 1.0).
      Default prior: ``Uniform(-10, 10)`` each.
    * ``norm_wav`` — Reference wavelength.
      Default prior: ``Fixed(region_center)``.
    """

    def __init__(self, knots: u.Quantity, degree: int = 3) -> None:
        if isinstance(knots, u.Quantity):
            self._knots: u.Quantity = _ensure_wavelength(knots, 'knots', ndim=1)
        else:
            raise ValueError(
                f'knots must be an astropy Quantity with length units, got {knots}'
            )
        if degree < 0:
            msg = f'BSpline degree must be >= 0, got {degree}'
            raise ValueError(msg)
        self._degree = degree
        # A clamped spline with K interior knots has K + 2*(degree+1) knots
        # in total, hence K + degree + 1 basis functions.
        self._n_basis = len(self._knots) + degree + 1

    @property
    @override
    def is_linear(self) -> bool:
        return False

    @property
    def degree(self) -> int:
        """Spline degree."""
        return self._degree

    @property
    def n_basis(self) -> int:
        """Number of B-spline basis functions."""
        return self._n_basis

    @property
    def knots(self) -> u.Quantity:
        """Knot vector (in the original units passed at construction)."""
        return self._knots

    @override
    def param_names(self) -> tuple[str, ...]:
        """Return ``('scale', 'coeff_1', ..., 'coeff_<n_basis-1>', 'norm_wav')``."""
        # With a single basis function the generator is empty.
        return ('scale', *(f'coeff_{i}' for i in range(1, self._n_basis)), 'norm_wav')

    @override
    def default_priors(self, region_center: float = 1.0) -> dict[str, Prior]:
        """Return the default prior for each parameter."""
        priors: dict[str, Prior] = {'scale': Uniform(0, 2)}
        for i in range(1, self._n_basis):
            priors[f'coeff_{i}'] = Uniform(-10, 10)
        priors['norm_wav'] = Fixed(region_center)
        return priors

    @override
    def param_units(
        self, flux_unit: u.UnitBase, wl_unit: u.UnitBase
    ) -> dict[str, tuple[bool, u.UnitBase | None]]:
        """Return a ``(flag, unit)`` pair per parameter."""
        # B-spline coefficients share the same unit as the function value.
        d: dict[str, tuple[bool, u.UnitBase | None]] = {'scale': (True, flux_unit)}
        for i in range(1, self._n_basis):
            d[f'coeff_{i}'] = (True, flux_unit)
        d['norm_wav'] = (False, wl_unit)
        return d

    @override
    def evaluate(
        self,
        wavelength: ArrayLike,
        center: float,
        params: dict[str, ArrayLike],
        obs_low: float,
        obs_high: float,
        lsf_fwhm: ArrayLike = 0.0,
    ) -> Array:
        """Evaluate ``scale * S(u) / S(u_nw)`` at *wavelength*.

        Wavelengths, ``norm_wav`` and the clamped knots are all mapped
        linearly so that ``[obs_low, obs_high]`` becomes ``[-1, 1]``.
        *center* and *lsf_fwhm* are not used (no LSF convolution for this
        non-polynomial basis).
        """
        obs_range = obs_high - obs_low
        # Local name deliberately differs from the module-level units alias.
        u_norm = 2 * (wavelength - obs_low) / obs_range - 1
        u_nw = 2 * (params['norm_wav'] - obs_low) / obs_range - 1
        knots_norm = 2 * (self._knots_eval - obs_low) / obs_range - 1

        shape_coeffs = jnp.concatenate(
            [jnp.array([1.0])]
            + [jnp.atleast_1d(params[f'coeff_{i}']) for i in range(1, self._n_basis)]
        )
        shape = bspline_eval(u_norm, shape_coeffs, knots_norm, self._degree)
        # Evaluate at the single reference point and unwrap the 1-element result.
        shape_nw = bspline_eval(
            jnp.atleast_1d(u_nw), shape_coeffs, knots_norm, self._degree
        )[0]
        return params['scale'] * shape / shape_nw

    @override
    def _prepare(self, low: u.Quantity, high: u.Quantity) -> None:
        """Validate the knots against the region and build the clamped vector.

        Raises
        ------
        ValueError
            If any knot is on or outside the bounds ``[low, high]``.
        """
        # Element-wise test: Python ``or`` cannot combine boolean arrays.
        outside = (self._knots <= low) | (self._knots >= high)
        if np.any(outside):
            msg = f'All knots must be within the region bounds [{low}, {high}], got {self._knots}'
            raise ValueError(msg)
        # Repeat each bound degree + 1 times (clamped B-spline convention),
        # with everything expressed in the unit of ``low``.
        p = self._degree + 1
        knots_clamped = jnp.concatenate(
            [
                low.value * jnp.ones(p),
                self._knots.to(low.unit).value,
                high.to(low.unit).value * jnp.ones(p),
            ]
        )
        self._knots_eval = knots_clamped

    @override
    def to_dict(self) -> dict:
        """Serialize to a plain dictionary (knot values plus their unit)."""
        return {
            'type': 'BSpline',
            'knots': self._knots.value.tolist(),
            'unit': str(self._knots.unit),
            'degree': self._degree,
        }

    @classmethod
    @override
    def from_dict(cls, d: dict) -> BSpline:
        """Rebuild the form from :meth:`to_dict` output (``degree`` optional)."""
        knots = u.Quantity(d['knots'], u.Unit(d['unit']))
        return cls(knots=knots, degree=d.get('degree', 3))

    @override
    def __repr__(self) -> str:
        return f'BSpline(degree={self._degree}, knots={self._knots})'
[docs]
@_register
class Bernstein(ContinuumForm):
    """Global Bernstein polynomial continuum, normalized at a reference wavelength.

    Evaluates a Bernstein series on coordinates normalized to ``[0, 1]``
    within the continuum region, normalized so that the continuum equals
    ``scale`` at ``norm_wav``.

    The wavelength range is derived from the region bounds passed to
    :meth:`evaluate`, normalized to ``[0, 1]`` for the Bernstein basis.
    The ``stretch`` parameter optionally scales the region normalization.

    The continuum is parameterized as ``scale * B(t) / B(t_nw)`` where
    ``B`` is the Bernstein series with first term fixed at 1.0, and ``t_nw``
    is the normalized coordinate at ``norm_wav``.  The LSF convolution is
    applied to ``B(t)`` only; ``B(t_nw)`` uses the unconvolved series.

    Parameters
    ----------
    degree : int
        Polynomial degree (default 4).  Number of coefficients = degree + 1.
    stretch : float
        Stretch factor to scale the region normalization (default 1.0).

    Raises
    ------
    ValueError
        If *degree* is negative or *stretch* is not strictly positive.

    Notes
    -----
    **Model parameters** (sampled with priors, overridable via
    ``ContinuumRegion(params={...})``):

    * ``scale`` — Continuum level at ``norm_wav``.
      Default prior: ``Uniform(0, 2)``.
    * ``coeff_1, coeff_2, …`` — Remaining Bernstein coefficients (normalized
      to first term 1.0).
      Default prior: ``Uniform(0, 10)`` each.
    * ``norm_wav`` — Reference wavelength.
      Default prior: ``Fixed(region_center)``.
    """

    def __init__(self, degree: int = 4, stretch: float = 1.0) -> None:
        from scipy.special import comb

        # Same guard as Polynomial / Chebyshev: fail at construction rather
        # than with a shape mismatch at evaluation time.
        if degree < 0:
            msg = f'Bernstein degree must be >= 0, got {degree}'
            raise ValueError(msg)
        self._degree = degree
        if stretch <= 0:
            msg = f'Bernstein stretch factor must be > 0, got {stretch}'
            raise ValueError(msg)
        self._stretch = stretch
        self._binom = jnp.array(
            [comb(degree, i, exact=True) for i in range(degree + 1)], dtype=float
        )
        # Static Bernstein-to-monomial conversion matrix.
        self._bern2mono = _bernstein_to_mono_matrix(degree)

    @property
    @override
    def is_linear(self) -> bool:
        return False

    @property
    def degree(self) -> int:
        """Polynomial degree."""
        return self._degree

    @property
    def stretch(self) -> float:
        """Stretch factor for the region normalization."""
        return self._stretch

    @override
    def param_names(self) -> tuple[str, ...]:
        """Return ``('scale', 'coeff_1', ..., 'coeff_<degree>', 'norm_wav')``."""
        # For degree 0 the generator is empty, leaving ('scale', 'norm_wav').
        return (
            'scale',
            *(f'coeff_{i}' for i in range(1, self._degree + 1)),
            'norm_wav',
        )

    @override
    def default_priors(self, region_center: float = 1.0) -> dict[str, Prior]:
        """Return the default prior for each parameter."""
        priors: dict[str, Prior] = {'scale': Uniform(0, 2)}
        for i in range(1, self._degree + 1):
            priors[f'coeff_{i}'] = Uniform(0, 10)
        priors['norm_wav'] = Fixed(region_center)
        return priors

    @override
    def param_units(
        self, flux_unit: u.UnitBase, wl_unit: u.UnitBase
    ) -> dict[str, tuple[bool, u.UnitBase | None]]:
        """Return a ``(flag, unit)`` pair per parameter."""
        # Bernstein coefficients share the same unit as the function value.
        d: dict[str, tuple[bool, u.UnitBase | None]] = {'scale': (True, flux_unit)}
        for i in range(1, self._degree + 1):
            d[f'coeff_{i}'] = (True, flux_unit)
        d['norm_wav'] = (False, wl_unit)
        return d

    def _shape_terms(
        self,
        center: float,
        params: dict[str, ArrayLike],
        obs_low: float,
        obs_high: float,
        lsf_fwhm: ArrayLike,
    ):
        """Return ``(to_t, convolved_coeffs, shape_at_norm_wav)``.

        ``to_t`` maps a wavelength to the ``t`` coordinate;
        ``convolved_coeffs`` are descending monomial coefficients in ``t``,
        already convolved with the rescaled LSF.
        """
        half_width = (obs_high - obs_low) / 2
        stretch_factor = half_width * self._stretch

        def to_t(w):
            # [-1, 1]-style coordinate first, then shifted/halved onto [0, 1].
            uu = (w - center) / stretch_factor
            return (uu + 1) / 2

        coeffs = jnp.concatenate(
            [jnp.array([1.0])]
            + [jnp.atleast_1d(params[f'coeff_{i}']) for i in range(1, self._degree + 1)]
        )
        # Ascending monomial coefficients in t, reversed for jnp.polyval.
        mono = (self._bern2mono @ coeffs)[::-1]
        # LSF FWHM in t-coordinate: dt/dλ = 1 / (2 * stretch_factor)
        lsf_fwhm_scaled = jnp.asarray(lsf_fwhm) / (2.0 * stretch_factor)
        convolved = _gaussian_convolve_poly(mono, lsf_fwhm_scaled)
        # The normalisation at norm_wav uses the unconvolved series.
        t_nw = to_t(params['norm_wav'])
        shape_nw = bernstein_eval(jnp.atleast_1d(t_nw), coeffs, self._binom)
        return to_t, convolved, shape_nw

    @override
    def evaluate(
        self,
        wavelength: ArrayLike,
        center: float,
        params: dict[str, ArrayLike],
        obs_low: float,
        obs_high: float,
        lsf_fwhm: ArrayLike = 0.0,
    ) -> Array:
        """Evaluate ``scale * B(t) / B(t_nw)`` at *wavelength*."""
        to_t, convolved, shape_nw = self._shape_terms(
            center, params, obs_low, obs_high, lsf_fwhm
        )
        shape = jnp.polyval(convolved, to_t(wavelength))
        # Normalize so that the continuum equals `scale` at `norm_wav`.
        return cast(Array, params['scale'] * shape / shape_nw)

    @override
    def integrate(
        self,
        low: ArrayLike,
        high: ArrayLike,
        center: float,
        params: dict[str, ArrayLike],
        obs_low: float,
        obs_high: float,
        lsf_fwhm: ArrayLike = 0.0,
    ) -> Array:
        """Return the mean of the continuum over each bin ``[low, high]``."""
        to_t, convolved, shape_nw = self._shape_terms(
            center, params, obs_low, obs_high, lsf_fwhm
        )
        shape_avg = _polyint_avg(convolved, to_t(low), to_t(high))
        return params['scale'] * shape_avg / shape_nw

    @override
    def to_dict(self) -> dict:
        """Serialize to a plain dictionary."""
        return {'type': 'Bernstein', 'degree': self._degree, 'stretch': self._stretch}

    @classmethod
    @override
    def from_dict(cls, d: dict) -> Bernstein:
        """Rebuild the form from :meth:`to_dict` output (``stretch`` optional)."""
        return cls(degree=d['degree'], stretch=d.get('stretch', 1.0))

    @override
    def __repr__(self) -> str:
        return f'Bernstein(degree={self._degree}, stretch={self._stretch})'
[docs]
@_register
class PowerLaw(ContinuumForm):
    """Power-law continuum: ``scale * (wavelength / norm_wav) ** beta``.

    This form has no constructor parameters.

    To share a consistent reference wavelength across multiple regions
    (required for physically meaningful parameter sharing), pass a
    :class:`~unite.continuum.config.ContinuumNormalizationWavelength` with
    ``Fixed(value)`` carrying your chosen reference wavelength.

    Notes
    -----
    **Model parameters** (sampled with priors, overridable via
    ``ContinuumRegion(params={...})``):

    * ``scale`` — Continuum level at ``norm_wav``.
      Default prior: ``Uniform(0, 2)``.
    * ``beta`` — Power-law index (dimensionless).
      Default prior: ``Uniform(-5, 5)``.
    * ``norm_wav`` — Reference wavelength.
      Default prior: ``Fixed(region_center)``.
    """

    @override
    def param_names(self) -> tuple[str, ...]:
        """Return the parameter names in canonical order."""
        return ('scale', 'beta', 'norm_wav')

    @override
    def default_priors(self, region_center: float = 1.0) -> dict[str, Prior]:
        """Return the default prior for each parameter."""
        return {
            'scale': Uniform(0, 2),
            'beta': Uniform(-5, 5),
            'norm_wav': Fixed(region_center),
        }

    @override
    def param_units(
        self, flux_unit: u.UnitBase, wl_unit: u.UnitBase
    ) -> dict[str, tuple[bool, u.UnitBase | None]]:
        """Return a ``(flag, unit)`` pair per parameter; ``beta`` has no unit."""
        return {
            'scale': (True, flux_unit),
            'beta': (False, None),
            'norm_wav': (False, wl_unit),
        }

    @override
    def evaluate(
        self,
        wavelength: ArrayLike,
        center: float,
        params: dict[str, ArrayLike],
        obs_low: float,
        obs_high: float,
        lsf_fwhm: ArrayLike = 0.0,
    ) -> Array:
        """Evaluate ``scale * (wavelength / norm_wav) ** beta``.

        LSF convolution is not supported for this form; *center*,
        *obs_low*, *obs_high* and *lsf_fwhm* are unused.
        """
        nw = params['norm_wav']
        return cast(Array, params['scale'] * (wavelength / nw) ** params['beta'])

    # NOTE: an analytic pixel average (power-rule antiderivative of
    # ``w**beta`` with a logarithmic fallback at ``beta == -1``) was
    # previously sketched here but is disabled, so this class does not
    # override ``integrate``.

    @override
    def to_dict(self) -> dict:
        """Serialize to a plain dictionary."""
        return {'type': 'PowerLaw'}

    @classmethod
    @override
    def from_dict(cls, d: dict) -> PowerLaw:
        """Rebuild the form from :meth:`to_dict` output (no state to restore)."""
        return cls()
[docs]
@_register
class Blackbody(ContinuumForm):
    """Planck blackbody continuum normalized at a reference wavelength.

    Evaluates ``scale * B_λ(T) / B_λ(norm_wav, T)`` so that
    *scale* directly represents the continuum flux at
    ``norm_wav``.  Wavelength parameters may be in any unit;
    automatic unit conversion to microns is applied internally, using the
    factor computed by :meth:`_prepare` (``1.0`` until it has run).

    ``norm_wav`` is a named parameter with a default
    ``Fixed(region_center)`` prior.  Pass an explicit
    :class:`~unite.continuum.config.ContinuumNormalizationWavelength` with
    ``Fixed(value)`` to pin it to a specific wavelength across multiple
    regions — essential for physically consistent normalization when fitting
    a single blackbody across disjoint spectral windows.

    This form has no constructor parameters.

    Notes
    -----
    **Model parameters** (sampled with priors, overridable via
    ``ContinuumRegion(params={...})``):

    * ``scale`` — Continuum flux at ``norm_wav``
      (in units of ``continuum_scale``).
      Default prior: ``Uniform(0, 2)``.
    * ``temperature`` — Blackbody temperature in Kelvin.
      Default prior: ``Uniform(100, 50000)``.
    * ``norm_wav`` — Reference wavelength.
      Default prior: ``Fixed(region_center)``.
    """

    def __init__(self) -> None:
        # Multiplier taking region wavelengths to microns; set by _prepare.
        self._micron_factor: float = 1.0

    @override
    def param_names(self) -> tuple[str, ...]:
        """Return the parameter names in canonical order."""
        return ('scale', 'temperature', 'norm_wav')

    @override
    def default_priors(self, region_center: float = 1.0) -> dict[str, Prior]:
        """Return the default prior for each parameter."""
        return {
            'scale': Uniform(0, 2),
            'temperature': Uniform(100, 50000),
            'norm_wav': Fixed(region_center),
        }

    @override
    def param_units(
        self, flux_unit: u.UnitBase, wl_unit: u.UnitBase
    ) -> dict[str, tuple[bool, u.UnitBase | None]]:
        """Return a ``(flag, unit)`` pair per parameter."""
        return {
            'scale': (True, flux_unit),
            'temperature': (False, u.K),
            'norm_wav': (False, wl_unit),
        }

    @override
    def evaluate(
        self,
        wavelength: ArrayLike,
        center: float,
        params: dict[str, ArrayLike],
        obs_low: float,
        obs_high: float,
        lsf_fwhm: ArrayLike = 0.0,
    ) -> Array:
        """Evaluate the normalized blackbody at *wavelength*.

        LSF convolution is not supported for this form; *center*,
        *obs_low*, *obs_high* and *lsf_fwhm* are unused.
        """
        wl_um = wavelength * self._micron_factor
        nw_um = params['norm_wav'] * self._micron_factor
        bb = planck_function(wl_um, params['temperature'], nw_um)
        return params['scale'] * bb

    @override
    def _prepare(self, low: u.Quantity, high: u.Quantity) -> None:
        """Compute the micron conversion factor for the region's wavelength unit."""
        self._micron_factor = _get_conversion_factor(low.unit, u.um)

    @override
    def to_dict(self) -> dict:
        """Serialize to a plain dictionary."""
        return {'type': 'Blackbody'}

    @classmethod
    @override
    def from_dict(cls, d: dict) -> Blackbody:
        """Rebuild the form from :meth:`to_dict` output (no state to restore)."""
        return cls()
[docs]
@_register
class ModifiedBlackbody(ContinuumForm):
    """Modified blackbody: ``scale * B_λ(T) * (λ / norm_wav)^beta / B_λ(nw, T)``.

    The power-law modifier *beta* broadens (beta > 0) or narrows (beta < 0)
    the SED relative to a pure blackbody.  *beta = 0* recovers
    :class:`Blackbody`.  Wavelength parameters may be in any unit;
    automatic unit conversion to microns is applied internally, using the
    factor computed by :meth:`_prepare` (``1.0`` until it has run).

    ``norm_wav`` is a named parameter with a default
    ``Fixed(region_center)`` prior.  Share a
    :class:`~unite.continuum.config.ContinuumNormalizationWavelength` token
    across regions to enforce a consistent reference wavelength.

    This form has no constructor parameters.

    Notes
    -----
    **Model parameters** (sampled with priors, overridable via
    ``ContinuumRegion(params={...})``):

    * ``scale`` — Continuum flux at ``norm_wav``
      (in units of ``continuum_scale``).
      Default prior: ``Uniform(0, 2)``.
    * ``temperature`` — Blackbody temperature in Kelvin.
      Default prior: ``Uniform(100, 50000)``.
    * ``beta`` — Power-law modifier index (dimensionless).
      Default prior: ``Uniform(-4, 4)``.
    * ``norm_wav`` — Reference wavelength.
      Default prior: ``Fixed(region_center)``.
    """

    def __init__(self) -> None:
        # Multiplier taking region wavelengths to microns; set by _prepare.
        self._micron_factor: float = 1.0

    @override
    def param_names(self) -> tuple[str, ...]:
        """Return the parameter names in canonical order."""
        return ('scale', 'temperature', 'beta', 'norm_wav')

    @override
    def default_priors(self, region_center: float = 1.0) -> dict[str, Prior]:
        """Return the default prior for each parameter."""
        return {
            'scale': Uniform(0, 2),
            'temperature': Uniform(100, 50000),
            'beta': Uniform(-4, 4),
            'norm_wav': Fixed(region_center),
        }

    @override
    def param_units(
        self, flux_unit: u.UnitBase, wl_unit: u.UnitBase
    ) -> dict[str, tuple[bool, u.UnitBase | None]]:
        """Return a ``(flag, unit)`` pair per parameter; ``beta`` has no unit."""
        return {
            'scale': (True, flux_unit),
            'temperature': (False, u.K),
            'beta': (False, None),
            'norm_wav': (False, wl_unit),
        }

    @override
    def evaluate(
        self,
        wavelength: ArrayLike,
        center: float,
        params: dict[str, ArrayLike],
        obs_low: float,
        obs_high: float,
        lsf_fwhm: ArrayLike = 0.0,
    ) -> Array:
        """Evaluate the normalized blackbody times ``(λ / norm_wav) ** beta``.

        LSF convolution is not supported for this form; *center*,
        *obs_low*, *obs_high* and *lsf_fwhm* are unused.
        """
        wl_um = wavelength * self._micron_factor
        nw_um = params['norm_wav'] * self._micron_factor
        bb = planck_function(wl_um, params['temperature'], nw_um)
        # The ratio is unit-free, so the micron conversion cancels here.
        modifier = (wl_um / nw_um) ** params['beta']
        return params['scale'] * bb * modifier

    @override
    def _prepare(self, low: u.Quantity, high: u.Quantity) -> None:
        """Compute the micron conversion factor for the region's wavelength unit."""
        self._micron_factor = _get_conversion_factor(low.unit, u.um)

    @override
    def to_dict(self) -> dict:
        """Serialize to a plain dictionary."""
        return {'type': 'ModifiedBlackbody'}

    @classmethod
    @override
    def from_dict(cls, d: dict) -> ModifiedBlackbody:
        """Rebuild the form from :meth:`to_dict` output (no state to restore)."""
        return cls()
[docs]
@_register
class AttenuatedBlackbody(ContinuumForm):
    """Dust-attenuated blackbody continuum.

    Evaluates
    ``scale * B_λ(T) / B_λ(nw,T) * exp(-tau_v * [(λ/lambda_ext)^alpha - (nw/lambda_ext)^alpha])``.

    Extinction is normalized at ``norm_wav`` so that *scale*
    represents the **observed** (attenuated) flux there. Negative *alpha*
    gives steeper extinction at short wavelengths (typical dust law).
    Wavelength parameters may be in any unit; automatic unit conversion
    to microns is applied internally.

    Parameters
    ----------
    lambda_ext : astropy.units.Quantity
        Reference wavelength for the extinction law. Must be
        :class:`~astropy.units.Quantity` with any length unit — it will
        be converted automatically. Defaults to ``5500 * u.AA``.

    Raises
    ------
    ValueError
        If *lambda_ext* is not an astropy Quantity.

    Notes
    -----
    ``lambda_ext`` is a *static* configuration parameter (not
    sampled). It is stored in microns internally at construction;
    evaluation wavelengths are converted to microns with a factor that
    :meth:`_prepare` computes from the region's wavelength unit.

    **Model parameters** (sampled with priors, overridable via
    ``ContinuumRegion(params={...})``):

    * ``scale`` — Observed continuum flux at ``norm_wav``
      (in units of ``continuum_scale``).
      Default prior: ``Uniform(0, 2)``.
    * ``temperature`` — Blackbody temperature in Kelvin.
      Default prior: ``Uniform(100, 50000)``.
    * ``tau_v`` — Optical depth at ``lambda_ext``.
      Default prior: ``Uniform(0, 5)``.
    * ``alpha`` — Dust extinction power-law index (negative = steeper at
      short λ). Default prior: ``Uniform(-2, 0)``.
    * ``norm_wav`` — Reference wavelength.
      Default prior: ``Fixed(region_center)``.
    """

    def __init__(self, lambda_ext: u.Quantity = 5500 * u.AA) -> None:
        if not isinstance(lambda_ext, u.Quantity):
            raise ValueError(
                f'lambda_ext must be an astropy Quantity with length units, got {lambda_ext}'
            )
        self._lambda_ext: u.Quantity = _ensure_wavelength(
            lambda_ext, 'lambda_ext', ndim=0
        )
        # Extinction pivot as a plain float in microns. evaluate() converts
        # its wavelengths to microns too, so this never needs rescaling.
        self._lambda_ext_um: float = float(self._lambda_ext.to(u.um).value)
        # Multiplier taking region wavelengths to microns; set by _prepare().
        self._micron_factor: float = 1.0

    @property
    def lambda_ext(self) -> u.Quantity:
        """Extinction reference wavelength as an astropy Quantity."""
        return self._lambda_ext

    @override
    def param_names(self) -> tuple[str, ...]:
        """Return the names of the model parameters."""
        return ('scale', 'temperature', 'tau_v', 'alpha', 'norm_wav')

    @override
    def default_priors(self, region_center: float = 1.0) -> dict[str, Prior]:
        """Return the default prior for each parameter.

        ``norm_wav`` is pinned to *region_center*.
        """
        return {
            'scale': Uniform(0, 2),
            'temperature': Uniform(100, 50000),
            'tau_v': Uniform(0, 5),
            'alpha': Uniform(-2, 0),
            'norm_wav': Fixed(region_center),
        }

    @override
    def param_units(
        self, flux_unit: u.UnitBase, wl_unit: u.UnitBase
    ) -> dict[str, tuple[bool, u.UnitBase | None]]:
        """Return a ``(flag, unit)`` pair per parameter.

        Only ``scale`` carries a ``True`` flag and the flux unit; ``tau_v``
        and ``alpha`` have no unit.
        """
        return {
            'scale': (True, flux_unit),
            'temperature': (False, u.K),
            'tau_v': (False, None),
            'alpha': (False, None),
            'norm_wav': (False, wl_unit),
        }

    @override
    def evaluate(
        self,
        wavelength: ArrayLike,
        center: float,
        params: dict[str, ArrayLike],
        obs_low: float,
        obs_high: float,
        lsf_fwhm: ArrayLike = 0.0,
    ) -> Array:
        """Evaluate the attenuated blackbody on *wavelength*.

        ``center``, ``obs_low``, ``obs_high`` and ``lsf_fwhm`` are accepted
        but not read: LSF convolution is not supported for
        AttenuatedBlackbody.
        """
        wl_um = wavelength * self._micron_factor
        nw_um = params['norm_wav'] * self._micron_factor
        bb = planck_function(wl_um, params['temperature'], nw_um)
        ext_data = (wl_um / self._lambda_ext_um) ** params['alpha']
        ext_pivot = (nw_um / self._lambda_ext_um) ** params['alpha']
        # Subtracting the pivot term makes the extinction factor exactly
        # one at norm_wav, so `scale` stays the attenuated flux there.
        extinction = jnp.exp(-params['tau_v'] * (ext_data - ext_pivot))
        return params['scale'] * bb * extinction

    @override
    def _prepare(self, low: u.Quantity, high: u.Quantity) -> None:
        """Compute the micron conversion factor for the region's wavelength unit.

        Only the unit of ``low`` is consulted; ``high`` is not read here.
        """
        self._micron_factor = _get_conversion_factor(low.unit, u.um)

    @override
    def to_dict(self) -> dict:
        """Return the type tag plus ``lambda_ext`` as the stored Quantity."""
        return {'type': 'AttenuatedBlackbody', 'lambda_ext': self.lambda_ext}

    @classmethod
    @override
    def from_dict(cls, d: dict) -> AttenuatedBlackbody:
        """Build an instance from a dict description.

        A missing (or ``None``) ``'lambda_ext'`` entry falls back to the
        constructor default. A Quantity is passed through unchanged; any
        other value is interpreted as a number of microns. Previously the
        fallback was the bare float ``0.55``, which the constructor always
        rejected with ``ValueError``.
        """
        raw = d.get('lambda_ext')
        if raw is None:
            return cls()
        if not isinstance(raw, u.Quantity):
            # The constructor only accepts Quantities; attach microns to
            # plain numbers so unit-less descriptions still load.
            raw = u.Quantity(raw, u.um)
        return cls(lambda_ext=raw)

    @override
    def __repr__(self) -> str:
        return f'AttenuatedBlackbody(lambda_ext={self.lambda_ext})'