"""Concrete line profile implementations."""
from __future__ import annotations
from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import override
from jax import Array
from jax.typing import ArrayLike
from unite._utils import _make_register
from unite.line import functions
from unite.prior import Prior, TruncatedNormal, Uniform
# -------------------------------------------------------------------
# Base Profile class
# -------------------------------------------------------------------
[docs]
class Profile(ABC):
    """Abstract base class for spectral line profiles.

    A profile declares which parameters it requires (via :meth:`param_names`
    and :meth:`default_priors`) and provides :meth:`integrate` and
    :meth:`evaluate` methods that compute, respectively, the profile integral
    over wavelength bins and the profile value at wavelength points.

    Each concrete subclass carries an integer :attr:`code` for dispatch
    in JAX arrays, and supports serialization via :meth:`to_dict` /
    :meth:`from_dict`.
    """

    #: Integer code for this profile type, used in JAX arrays.
    code: int

    #: Number of positional parameter slots (``p0, p1, p2``) in the fixed
    #: branch signatures documented on :meth:`integrate_branch` and
    #: :meth:`evaluate_branch`.
    _N_PARAM_SLOTS = 3

    @abstractmethod
    def param_names(self) -> tuple[str, ...]:
        """Return names of parameters this profile requires.

        Returns
        -------
        tuple of str
            For example, ``('fwhm_gauss',)`` for Gaussian,
            ``('fwhm_gauss', 'fwhm_lorentz')`` for pseudo-Voigt, or
            ``('fwhm_gauss', 'h3', 'h4')`` for Gauss-Hermite.
        """

    @abstractmethod
    def default_priors(self) -> dict[str, Prior]:
        """Return sensible default priors for each parameter.

        The keys must match :meth:`param_names`. These are used when the
        user does not supply an explicit token for a parameter.

        Returns
        -------
        dict of str to Prior
            For example, ``{'fwhm_gauss': Uniform(0, 1000)}``.
        """

    def _slot_params(
        self, params: dict[str, ArrayLike]
    ) -> tuple[ArrayLike, ArrayLike, ArrayLike]:
        """Map keyword parameters onto the fixed ``(p0, p1, p2)`` branch slots.

        Values are taken in :meth:`param_names` order; slots beyond the
        number of declared parameters are filled with ``0.0``. Keyword
        arguments not named by :meth:`param_names` are ignored.

        Raises
        ------
        ValueError
            If the profile declares more parameters than there are slots
            (the extra values would otherwise be silently dropped).
        KeyError
            If a parameter named by :meth:`param_names` is not supplied.
        """
        pnames = self.param_names()
        if len(pnames) > self._N_PARAM_SLOTS:
            msg = (
                f'{type(self).__name__} declares {len(pnames)} parameters '
                f'{pnames}, but branch callables only accept '
                f'{self._N_PARAM_SLOTS} parameter slots'
            )
            raise ValueError(msg)
        missing = [name for name in pnames if name not in params]
        if missing:
            # Still a KeyError so existing callers that catch it keep working.
            msg = (
                f'{type(self).__name__} missing required parameter(s): '
                f'{", ".join(missing)}'
            )
            raise KeyError(msg)
        values = [params[name] for name in pnames]
        # Unused slots receive zero, as the branch contract requires.
        values.extend([0.0] * (self._N_PARAM_SLOTS - len(values)))
        return tuple(values)

    def integrate(
        self,
        low: ArrayLike,
        high: ArrayLike,
        center: ArrayLike,
        lsf_fwhm: ArrayLike,
        **params: ArrayLike,
    ) -> Array:
        r"""Integrate the profile over wavelength bins.

        Delegates to :meth:`integrate_branch` by mapping keyword arguments to
        positional slots (p0, p1, p2) in :meth:`param_names` order.

        Parameters
        ----------
        low : ArrayLike
            Lower wavelength edges of bins.
        high : ArrayLike
            Upper wavelength edges of bins.
        center : ArrayLike
            Line center wavelength.
        lsf_fwhm : ArrayLike
            Instrumental line spread function FWHM at the line center.
        \*\*params : ArrayLike
            Parameter values, keyed by the names from :meth:`param_names`.

        Returns
        -------
        Array
            Fractional flux integrated in each bin (sums to 1 only when the
            bins cover the full extent of the profile).

        Raises
        ------
        KeyError
            If a required parameter is missing from ``params``.
        ValueError
            If the profile declares more than three parameters.
        """
        p0, p1, p2 = self._slot_params(params)
        return self.integrate_branch()(low, high, center, lsf_fwhm, p0, p1, p2)

    def evaluate(
        self,
        wavelength: ArrayLike,
        center: ArrayLike,
        lsf_fwhm: ArrayLike,
        **params: ArrayLike,
    ) -> Array:
        r"""Evaluate the normalised profile at wavelength points.

        Delegates to :meth:`evaluate_branch` by mapping keyword arguments to
        positional slots (p0, p1, p2) in :meth:`param_names` order.

        Parameters
        ----------
        wavelength : ArrayLike
            Wavelength points at which to evaluate the profile.
        center : ArrayLike
            Line center wavelength.
        lsf_fwhm : ArrayLike
            Instrumental line spread function FWHM at the line center.
        \*\*params : ArrayLike
            Parameter values, keyed by the names from :meth:`param_names`.

        Returns
        -------
        Array
            Normalised profile value at each wavelength point
            (1/wavelength units).

        Raises
        ------
        KeyError
            If a required parameter is missing from ``params``.
        ValueError
            If the profile declares more than three parameters.
        """
        p0, p1, p2 = self._slot_params(params)
        return self.evaluate_branch()(wavelength, center, lsf_fwhm, p0, p1, p2)

    @abstractmethod
    def integrate_branch(self) -> Callable[..., Array]:
        """Return a JAX-compatible branch callable for ``lax.switch`` dispatch.

        The returned function must have the fixed signature::

            fn(low, high, center, lsf_fwhm, p0, p1, p2) -> Array

        Parameters correspond to :meth:`param_names` in order: ``p0`` is
        ``param_names()[0]``, ``p1`` is ``param_names()[1]``, ``p2`` is
        ``param_names()[2]``. Unused slots receive zero from the model
        builder and must be ignored.

        Returns
        -------
        callable
            A pure-JAX function suitable as a ``lax.switch`` branch.
        """

    @abstractmethod
    def evaluate_branch(self) -> Callable[..., Array]:
        """Return a JAX-compatible branch callable for pointwise evaluation.

        The returned function must have the fixed signature::

            fn(wavelength, center, lsf_fwhm, p0, p1, p2) -> Array

        Slots follow the same :meth:`param_names` ordering as
        :meth:`integrate_branch`. Returns the normalised profile value at
        each wavelength point.

        Returns
        -------
        callable
            A pure-JAX function suitable as a ``lax.switch`` branch.
        """

    @abstractmethod
    def to_dict(self) -> dict:
        """Serialize to a YAML-safe dictionary."""

    @classmethod
    @abstractmethod
    def from_dict(cls, d: dict) -> Profile:
        """Deserialize from a dictionary."""
# -------------------------------------------------------------------
# Registry for deserialization
# -------------------------------------------------------------------
# Lookup table from a serialized profile 'type' string to its class, consumed
# by profile_from_dict(). Entries are added through the @_register decorator
# below (presumably keyed by class name, matching each to_dict() 'type' —
# verify against _make_register).
_PROFILE_REGISTRY: dict[str, type[Profile]] = dict()
_register = _make_register(_PROFILE_REGISTRY)
[docs]
def profile_from_dict(d: dict) -> Profile:
    """Deserialize a Profile from a dictionary using the ``'type'`` key.

    Parameters
    ----------
    d : dict
        Dictionary with a ``'type'`` key matching a registered profile class.

    Returns
    -------
    Profile
        The instance built by the registered class's ``from_dict``.

    Raises
    ------
    KeyError
        If ``d`` has no ``'type'`` key, or the type is not registered. The
        message lists the registered type names.
    """
    try:
        type_name = d['type']
    except KeyError:
        msg = "Profile dictionary has no 'type' key"
        raise KeyError(msg) from None
    try:
        cls = _PROFILE_REGISTRY[type_name]
    except KeyError:
        registered = ', '.join(sorted(_PROFILE_REGISTRY))
        msg = f'Unknown profile type {type_name!r}. Registered types: {registered}'
        raise KeyError(msg) from None
    return cls.from_dict(d)
# -------------------------------------------------------------------
# Concrete profiles
# -------------------------------------------------------------------
[docs]
@_register
class Gaussian(Profile):
    """Gaussian (normal) line profile.

    Takes one parameter, ``fwhm_gauss``. The instrumental LSF is combined
    with it in quadrature: ``total_fwhm = sqrt(lsf_fwhm² + fwhm_gauss²)``.
    """

    code = 0

    @override
    def param_names(self) -> tuple[str, ...]:
        """Return ``('fwhm_gauss',)``."""
        return ('fwhm_gauss',)

    @override
    def default_priors(self) -> dict[str, Prior]:
        """Return a flat ``Uniform(0, 1000)`` prior on ``fwhm_gauss``."""
        return dict(fwhm_gauss=Uniform(0, 1000))

    @override
    def integrate_branch(self) -> Callable[..., Array]:
        """Return the bin-integration branch (slot ``p0`` is ``fwhm_gauss``)."""

        def _integrate_gaussian(lo, hi, c, lsf, p0, p1, p2):
            # p1 and p2 are padding slots and are ignored.
            return functions.integrate_gaussian(lo, hi, c, lsf, p0)

        return _integrate_gaussian

    @override
    def evaluate_branch(self) -> Callable[..., Array]:
        """Return the pointwise-evaluation branch (slot ``p0`` is ``fwhm_gauss``)."""

        def _evaluate_gaussian(wavelength, c, lsf, p0, p1, p2):
            return functions.evaluate_gaussian(wavelength, c, lsf, p0)

        return _evaluate_gaussian

    @override
    def to_dict(self) -> dict:
        """Serialize as a ``'type'``-only dictionary."""
        return dict(type='Gaussian')

    @classmethod
    @override
    def from_dict(cls, d: dict) -> Gaussian:
        """Build a fresh instance; the profile carries no state to restore."""
        return cls()

    @override
    def __repr__(self) -> str:
        return 'Gaussian()'
[docs]
@_register
class Cauchy(Profile):
    """Cauchy (Lorentzian) line profile.

    Requires a single parameter ``fwhm_lorentz``. It is implemented as a
    pseudo-Voigt with zero *intrinsic* Gaussian FWHM. The instrumental LSF
    is still passed through, and the pseudo-Voigt adds it in quadrature to
    its Gaussian component. The Gaussian part of the result therefore comes
    solely from the LSF. This keeps the profile consistent with the package
    assumption that all lines are convolved with the instrumental LSF.
    """

    code = 1

    @override
    def param_names(self) -> tuple[str, ...]:
        """Return ``('fwhm_lorentz',)``."""
        return ('fwhm_lorentz',)

    @override
    def default_priors(self) -> dict[str, Prior]:
        """Return a flat ``Uniform(0, 1000)`` prior on ``fwhm_lorentz``."""
        return {'fwhm_lorentz': Uniform(0, 1000)}

    @override
    def integrate_branch(self) -> Callable[..., Array]:
        """Return the bin-integration branch (slot ``p0`` is ``fwhm_lorentz``)."""

        def _fn(lo, hi, c, lsf, p0, p1, p2):
            # p0 = fwhm_lorentz. Pseudo-Voigt with zero intrinsic Gaussian
            # width; the LSF still supplies the Gaussian component.
            return functions.integrate_voigt(lo, hi, c, lsf, 0.0, p0)

        return _fn

    @override
    def evaluate_branch(self) -> Callable[..., Array]:
        """Return the pointwise-evaluation branch (slot ``p0`` is ``fwhm_lorentz``)."""

        def _fn(wavelength, c, lsf, p0, p1, p2):
            return functions.evaluate_voigt(wavelength, c, lsf, 0.0, p0)

        return _fn

    @override
    def to_dict(self) -> dict:
        """Serialize as a ``'type'``-only dictionary."""
        return {'type': 'Cauchy'}

    @classmethod
    @override
    def from_dict(cls, d: dict) -> Cauchy:
        """Build a fresh instance; the profile carries no state to restore."""
        return cls()

    @override
    def __repr__(self) -> str:
        return 'Cauchy()'
[docs]
@_register
class PseudoVoigt(Profile):
    """Pseudo-Voigt line profile (Thompson et al. 1987).

    Takes two parameters: ``fwhm_gauss`` (Gaussian component) and
    ``fwhm_lorentz`` (Lorentzian component). The instrumental LSF is added
    in quadrature to the Gaussian component only.
    """

    code = 2

    @override
    def param_names(self) -> tuple[str, ...]:
        """Return ``('fwhm_gauss', 'fwhm_lorentz')``."""
        return ('fwhm_gauss', 'fwhm_lorentz')

    @override
    def default_priors(self) -> dict[str, Prior]:
        """Return flat ``Uniform(0, 1000)`` priors on both widths."""
        return dict(
            fwhm_gauss=Uniform(0, 1000),
            fwhm_lorentz=Uniform(0, 1000),
        )

    @override
    def integrate_branch(self) -> Callable[..., Array]:
        """Return the bin-integration branch (``p0`` = Gaussian, ``p1`` = Lorentzian FWHM)."""

        def _integrate_pseudo_voigt(lo, hi, c, lsf, p0, p1, p2):
            # p2 is a padding slot and is ignored.
            return functions.integrate_voigt(lo, hi, c, lsf, p0, p1)

        return _integrate_pseudo_voigt

    @override
    def evaluate_branch(self) -> Callable[..., Array]:
        """Return the pointwise-evaluation branch (same slot layout as integration)."""

        def _evaluate_pseudo_voigt(wavelength, c, lsf, p0, p1, p2):
            return functions.evaluate_voigt(wavelength, c, lsf, p0, p1)

        return _evaluate_pseudo_voigt

    @override
    def to_dict(self) -> dict:
        """Serialize as a ``'type'``-only dictionary."""
        return dict(type='PseudoVoigt')

    @classmethod
    @override
    def from_dict(cls, d: dict) -> PseudoVoigt:
        """Build a fresh instance; the profile carries no state to restore."""
        return cls()

    @override
    def __repr__(self) -> str:
        return 'PseudoVoigt()'
[docs]
@_register
class Laplace(Profile):
    """Laplace (double-exponential) line profile.

    Requires a single parameter ``fwhm_exp``. It is implemented as the
    Gaussian-Laplace (SEMG) convolution with zero *intrinsic* Gaussian FWHM.
    The instrumental LSF is still passed through, so the Laplace
    distribution is convolved with a Gaussian whose width comes solely
    from the LSF.
    """

    code = 3

    @override
    def param_names(self) -> tuple[str, ...]:
        """Return ``('fwhm_exp',)``."""
        return ('fwhm_exp',)

    @override
    def default_priors(self) -> dict[str, Prior]:
        """Return a flat ``Uniform(0, 1000)`` prior on ``fwhm_exp``."""
        return {'fwhm_exp': Uniform(0, 1000)}

    @override
    def integrate_branch(self) -> Callable[..., Array]:
        """Return the bin-integration branch (slot ``p0`` is ``fwhm_exp``)."""

        def _fn(lo, hi, c, lsf, p0, p1, p2):
            # p0 = fwhm_exp; Laplace convolved with the Gaussian LSF only
            # (intrinsic Gaussian FWHM fixed at 0).
            return functions.integrate_gaussianLaplace(lo, hi, c, lsf, 0.0, p0)

        return _fn

    @override
    def evaluate_branch(self) -> Callable[..., Array]:
        """Return the pointwise-evaluation branch (slot ``p0`` is ``fwhm_exp``)."""

        def _fn(wavelength, c, lsf, p0, p1, p2):
            return functions.evaluate_gaussianLaplace(wavelength, c, lsf, 0.0, p0)

        return _fn

    @override
    def to_dict(self) -> dict:
        """Serialize as a ``'type'``-only dictionary."""
        return {'type': 'Laplace'}

    @classmethod
    @override
    def from_dict(cls, d: dict) -> Laplace:
        """Build a fresh instance; the profile carries no state to restore."""
        return cls()

    @override
    def __repr__(self) -> str:
        return 'Laplace()'
[docs]
@_register
class SEMG(Profile):
    """Symmetric Exponentially Modified Gaussian (SEMG) line profile.

    The convolution of a Gaussian (including the LSF) with a symmetric
    Laplace (double-exponential) distribution. Takes two parameters:
    ``fwhm_gauss`` for the intrinsic Gaussian component and ``fwhm_exp``
    for the Laplacian component. The instrumental LSF is added in
    quadrature to the Gaussian component.
    """

    code = 4

    @override
    def param_names(self) -> tuple[str, ...]:
        """Return ``('fwhm_gauss', 'fwhm_exp')``."""
        return ('fwhm_gauss', 'fwhm_exp')

    @override
    def default_priors(self) -> dict[str, Prior]:
        """Return flat ``Uniform(0, 1000)`` priors on both widths."""
        return dict(
            fwhm_gauss=Uniform(0, 1000),
            fwhm_exp=Uniform(0, 1000),
        )

    @override
    def integrate_branch(self) -> Callable[..., Array]:
        """Return the bin-integration branch (``p0`` = Gaussian, ``p1`` = exponential FWHM)."""

        def _integrate_semg(lo, hi, c, lsf, p0, p1, p2):
            # p2 is a padding slot and is ignored.
            return functions.integrate_gaussianLaplace(lo, hi, c, lsf, p0, p1)

        return _integrate_semg

    @override
    def evaluate_branch(self) -> Callable[..., Array]:
        """Return the pointwise-evaluation branch (same slot layout as integration)."""

        def _evaluate_semg(wavelength, c, lsf, p0, p1, p2):
            return functions.evaluate_gaussianLaplace(wavelength, c, lsf, p0, p1)

        return _evaluate_semg

    @override
    def to_dict(self) -> dict:
        """Serialize as a ``'type'``-only dictionary."""
        return dict(type='SEMG')

    @classmethod
    @override
    def from_dict(cls, d: dict) -> SEMG:
        """Build a fresh instance; the profile carries no state to restore."""
        return cls()

    @override
    def __repr__(self) -> str:
        return 'SEMG()'
[docs]
@_register
class GaussHermite(Profile):
    """Gauss-Hermite line profile.

    A Gaussian (including the LSF) modified by Hermite-polynomial terms for
    skewness (``h3``) and kurtosis (``h4``). Takes three parameters:
    ``fwhm_gauss`` (intrinsic Gaussian FWHM), ``h3`` and ``h4``. The
    instrumental LSF is added in quadrature to the Gaussian component.
    """

    code = 5

    @override
    def param_names(self) -> tuple[str, ...]:
        """Return ``('fwhm_gauss', 'h3', 'h4')``."""
        return ('fwhm_gauss', 'h3', 'h4')

    @override
    def default_priors(self) -> dict[str, Prior]:
        """Return a flat width prior and narrow truncated-normal h3/h4 priors."""
        return dict(
            fwhm_gauss=Uniform(0, 1000),
            h3=TruncatedNormal(loc=0, scale=0.1, low=-0.3, high=0.3),
            h4=TruncatedNormal(loc=0, scale=0.1, low=-0.3, high=0.3),
        )

    @override
    def integrate_branch(self) -> Callable[..., Array]:
        """Return the bin-integration branch (``p0`` = FWHM, ``p1`` = h3, ``p2`` = h4)."""

        def _integrate_gauss_hermite(lo, hi, c, lsf, p0, p1, p2):
            return functions.integrate_gaussHermite(lo, hi, c, lsf, p0, p1, p2)

        return _integrate_gauss_hermite

    @override
    def evaluate_branch(self) -> Callable[..., Array]:
        """Return the pointwise-evaluation branch (same slot layout as integration)."""

        def _evaluate_gauss_hermite(wavelength, c, lsf, p0, p1, p2):
            return functions.evaluate_gaussHermite(wavelength, c, lsf, p0, p1, p2)

        return _evaluate_gauss_hermite

    @override
    def to_dict(self) -> dict:
        """Serialize as a ``'type'``-only dictionary."""
        return dict(type='GaussHermite')

    @classmethod
    @override
    def from_dict(cls, d: dict) -> GaussHermite:
        """Build a fresh instance; the profile carries no state to restore."""
        return cls()

    @override
    def __repr__(self) -> str:
        return 'GaussHermite()'
[docs]
@_register
class SplitNormal(Profile):
    """Split-normal (two-sided Gaussian) line profile.

    A Gaussian whose standard deviation differs on either side of the mean.
    Takes two parameters: ``fwhm_blue`` for the blue (left) half and
    ``fwhm_red`` for the red (right) half. The instrumental LSF is added in
    quadrature to both halves.
    """

    code = 6

    @override
    def param_names(self) -> tuple[str, ...]:
        """Return ``('fwhm_blue', 'fwhm_red')``."""
        return ('fwhm_blue', 'fwhm_red')

    @override
    def default_priors(self) -> dict[str, Prior]:
        """Return flat ``Uniform(0, 1000)`` priors on both half-widths."""
        return dict(
            fwhm_blue=Uniform(0, 1000),
            fwhm_red=Uniform(0, 1000),
        )

    @override
    def integrate_branch(self) -> Callable[..., Array]:
        """Return the bin-integration branch (``p0`` = blue, ``p1`` = red FWHM)."""

        def _integrate_split_normal(lo, hi, c, lsf, p0, p1, p2):
            # p2 is a padding slot and is ignored.
            return functions.integrate_split_normal(lo, hi, c, lsf, p0, p1)

        return _integrate_split_normal

    @override
    def evaluate_branch(self) -> Callable[..., Array]:
        """Return the pointwise-evaluation branch (same slot layout as integration)."""

        def _evaluate_split_normal(wavelength, c, lsf, p0, p1, p2):
            return functions.evaluate_split_normal(wavelength, c, lsf, p0, p1)

        return _evaluate_split_normal

    @override
    def to_dict(self) -> dict:
        """Serialize as a ``'type'``-only dictionary."""
        return dict(type='SplitNormal')

    @classmethod
    @override
    def from_dict(cls, d: dict) -> SplitNormal:
        """Build a fresh instance; the profile carries no state to restore."""
        return cls()

    @override
    def __repr__(self) -> str:
        return 'SplitNormal()'
[docs]
@_register
class BoxGauss(Profile):
    """Boxcar distribution convolved with a Gaussian.

    The intrinsic shape is a unit-area rectangular (boxcar) distribution of
    full width ``fwhm_box`` centred at zero, convolved with a Gaussian whose
    FWHM is the quadrature sum of ``fwhm_gauss`` and ``lsf_fwhm``. When
    ``fwhm_box`` → 0 the result reduces to a pure Gaussian. When
    ``fwhm_gauss`` → 0 (and ``lsf_fwhm`` → 0) it tends to the sharp
    rectangle.

    Takes two parameters: ``fwhm_box`` (boxcar full width) and
    ``fwhm_gauss`` (intrinsic Gaussian component).
    """

    code = 8

    @override
    def param_names(self) -> tuple[str, ...]:
        """Return ``('fwhm_box', 'fwhm_gauss')``."""
        return ('fwhm_box', 'fwhm_gauss')

    @override
    def default_priors(self) -> dict[str, Prior]:
        """Return flat ``Uniform(0, 1000)`` priors on both widths."""
        return dict(
            fwhm_box=Uniform(0, 1000),
            fwhm_gauss=Uniform(0, 1000),
        )

    @override
    def integrate_branch(self) -> Callable[..., Array]:
        """Return the bin-integration branch (``p0`` = box width, ``p1`` = Gaussian FWHM)."""

        def _integrate_box_gauss(lo, hi, c, lsf, p0, p1, p2):
            # p2 is a padding slot and is ignored.
            return functions.integrate_boxGauss(lo, hi, c, lsf, p0, p1)

        return _integrate_box_gauss

    @override
    def evaluate_branch(self) -> Callable[..., Array]:
        """Return the pointwise-evaluation branch (same slot layout as integration)."""

        def _evaluate_box_gauss(wavelength, c, lsf, p0, p1, p2):
            return functions.evaluate_boxGauss(wavelength, c, lsf, p0, p1)

        return _evaluate_box_gauss

    @override
    def to_dict(self) -> dict:
        """Serialize as a ``'type'``-only dictionary."""
        return dict(type='BoxGauss')

    @classmethod
    @override
    def from_dict(cls, d: dict) -> BoxGauss:
        """Build a fresh instance; the profile carries no state to restore."""
        return cls()

    @override
    def __repr__(self) -> str:
        return 'BoxGauss()'
[docs]
@_register
class SkewVoigt(Profile):
    r"""Skew pseudo-Voigt line profile.

    A pseudo-Voigt profile multiplied by a skew factor
    ``[1 + erf(alpha * (x - c) / (sqrt(2) * sigma_g))]``, where
    ``sigma_g`` is the standard deviation of the Gaussian component.
    The profile integrates to 1 for any value of ``alpha``. The ``erf``
    term is odd about the line center and the pseudo-Voigt is even, so
    their product integrates to zero. Only the unit-area pseudo-Voigt
    term remains.

    Convolution with the Gaussian LSF rescales the skewness parameter to

    .. math::

        \alpha_\text{eff} = \frac{\alpha\,\sigma_g}
        {\sqrt{\sigma_\text{tot}^2 + \alpha^2\sigma_\text{lsf}^2}}

    where :math:`\sigma_\text{tot} = \sqrt{\sigma_g^2 + \sigma_\text{lsf}^2}`.
    The skewness is reduced by the LSF and vanishes entirely when
    ``fwhm_gauss = 0`` (no intrinsic Gaussian component).

    Requires three parameters: ``fwhm_gauss`` for the Gaussian component,
    ``fwhm_lorentz`` for the Lorentzian component, and ``alpha`` for the
    skewness (positive values shift flux redward).
    """

    code = 7

    @override
    def param_names(self) -> tuple[str, ...]:
        """Return ``('fwhm_gauss', 'fwhm_lorentz', 'alpha')``."""
        return ('fwhm_gauss', 'fwhm_lorentz', 'alpha')

    @override
    def default_priors(self) -> dict[str, Prior]:
        """Return flat width priors and a truncated-normal prior on ``alpha``."""
        return {
            'fwhm_gauss': Uniform(0, 1000),
            'fwhm_lorentz': Uniform(0, 1000),
            'alpha': TruncatedNormal(loc=0, scale=100, low=-300, high=300),
        }

    @override
    def integrate_branch(self) -> Callable[..., Array]:
        """Return the bin-integration branch (``p0``/``p1`` = FWHMs, ``p2`` = alpha)."""

        def _fn(lo, hi, c, lsf, p0, p1, p2):
            # p0 = fwhm_gauss, p1 = fwhm_lorentz, p2 = alpha
            return functions.integrate_skewVoigt(lo, hi, c, lsf, p0, p1, p2)

        return _fn

    @override
    def evaluate_branch(self) -> Callable[..., Array]:
        """Return the pointwise-evaluation branch (same slot layout as integration)."""

        def _fn(wavelength, c, lsf, p0, p1, p2):
            return functions.evaluate_skewVoigt(wavelength, c, lsf, p0, p1, p2)

        return _fn

    @override
    def to_dict(self) -> dict:
        """Serialize as a ``'type'``-only dictionary."""
        return {'type': 'SkewVoigt'}

    @classmethod
    @override
    def from_dict(cls, d: dict) -> SkewVoigt:
        """Build a fresh instance; the profile carries no state to restore."""
        return cls()

    @override
    def __repr__(self) -> str:
        return 'SkewVoigt()'
# Case-insensitive name -> profile table used by resolve_profile(). Built from
# (class, aliases) pairs; every alias gets its own freshly constructed
# instance, and keys keep the insertion order listed here.
_PROFILE_ALIASES: dict[str, Profile] = {
    alias: profile_cls()
    for profile_cls, aliases in (
        (Gaussian, ('gaussian', 'normal')),
        (Cauchy, ('lorentzian', 'cauchy')),
        (Laplace, ('exponential', 'laplace')),
        (PseudoVoigt, ('voigt', 'pseudovoigt')),
        (SEMG, ('semg', 'exp-gaussian')),
        (GaussHermite, ('hermite', 'gauss-hermite')),
        (SplitNormal, ('split-normal', 'two-sided')),
        (SkewVoigt, ('skew-voigt', 'skewvoigt')),
        (BoxGauss, ('boxgauss', 'box-gauss', 'boxcar')),
    )
    for alias in aliases
}
[docs]
def resolve_profile(profile: str | Profile) -> Profile:
    """Turn a profile name or instance into a Profile object.

    Parameters
    ----------
    profile : str or Profile
        Either a profile instance, which is returned unchanged, or one of
        the known alias names, matched case-insensitively.

    Returns
    -------
    Profile
        The given instance, or the instance stored for the alias.

    Raises
    ------
    ValueError
        If ``profile`` is a string that is not a recognized alias.
    TypeError
        If ``profile`` is neither a string nor a Profile.
    """
    if not isinstance(profile, (Profile, str)):
        msg = f'profile must be a str or Profile, got {type(profile).__name__}'
        raise TypeError(msg)
    if isinstance(profile, Profile):
        return profile
    name = profile.lower()
    if name in _PROFILE_ALIASES:
        return _PROFILE_ALIASES[name]
    valid = ', '.join(sorted(_PROFILE_ALIASES))
    msg = f'Unknown profile {profile!r}. Valid names: {valid}'
    raise ValueError(msg)
# -------------------------------------------------------------------
# JAX dispatch: integration and evaluation branches
# -------------------------------------------------------------------
# Build the lax.switch branch lists once at import time.
# Each Profile subclass owns its branches via integrate_branch() and
# evaluate_branch(). lax.switch selects a branch by list index (and clamps
# out-of-range indices instead of failing). The list index equals
# Profile.code only if the registered codes are exactly 0..N-1, with no gaps
# and no duplicates. That is checked here, so a mis-numbered profile fails
# loudly at import instead of silently dispatching to the wrong branch.
_profiles_by_code = sorted(_PROFILE_REGISTRY.values(), key=lambda c: c.code)
_codes = [cls.code for cls in _profiles_by_code]
if _codes != list(range(len(_codes))):
    msg = f'Profile codes must be contiguous from 0 with no duplicates; got {_codes}'
    raise RuntimeError(msg)
# Instantiate each profile once and take both branches from the same instance.
_profile_instances = [cls() for cls in _profiles_by_code]
_INTEGRATE_BRANCHES = [p.integrate_branch() for p in _profile_instances]
_EVALUATE_BRANCHES = [p.evaluate_branch() for p in _profile_instances]
del _profiles_by_code, _codes, _profile_instances
# Re-export from compute module for backward compatibility.
from unite.line.compute import integrate_lines, evaluate_lines # noqa: I001, E402, F401