# Source code for unite.line.library

"""Concrete line profile implementations."""

from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import override

from jax import Array
from jax.typing import ArrayLike

from unite._utils import _make_register
from unite.line import functions
from unite.prior import Prior, TruncatedNormal, Uniform

# -------------------------------------------------------------------
# Base Profile class
# -------------------------------------------------------------------


class Profile(ABC):
    """Abstract base class for spectral line profiles.

    A profile declares which parameters it requires (via
    :meth:`param_names` and :meth:`default_priors`), provides an
    :meth:`integrate` method that computes the profile integral over
    wavelength bins, and an :meth:`evaluate` method for pointwise
    evaluation.

    Each concrete subclass carries an integer :attr:`code` for dispatch in
    JAX arrays, and supports serialization via :meth:`to_dict` /
    :meth:`from_dict`.
    """

    #: Integer code for this profile type, used in JAX arrays.
    code: int

    @abstractmethod
    def param_names(self) -> tuple[str, ...]:
        """Return names of parameters this profile requires.

        Returns
        -------
        tuple of str
            For example, ``('fwhm_gauss',)`` for Gaussian,
            ``('fwhm_gauss', 'fwhm_lorentz')`` for pseudo-Voigt, or
            ``('fwhm_gauss', 'h3', 'h4')`` for Gauss-Hermite.
        """

    @abstractmethod
    def default_priors(self) -> dict[str, Prior]:
        """Return sensible default priors for each parameter.

        The keys must match :meth:`param_names`. These are used when the
        user does not supply an explicit token for a parameter.

        Returns
        -------
        dict of str to Prior
            For example, ``{'fwhm_gauss': Uniform(0, 1000)}``.
        """

    def _slot_values(
        self, params: dict[str, ArrayLike]
    ) -> tuple[ArrayLike, ArrayLike, ArrayLike]:
        """Map keyword parameters onto the positional slots (p0, p1, p2).

        Parameters
        ----------
        params : dict of str to ArrayLike
            Parameter values keyed by the names from :meth:`param_names`.
            Keys that are not in :meth:`param_names` are ignored.

        Returns
        -------
        tuple
            ``(p0, p1, p2)`` in :meth:`param_names` order; slots beyond the
            number of declared parameters are filled with ``0.0``.

        Raises
        ------
        KeyError
            If a parameter named by :meth:`param_names` is missing.
        """
        names = self.param_names()
        try:
            # Only the first three names have a slot in the branch signature.
            slots = [params[name] for name in names[:3]]
        except KeyError as err:
            msg = (
                f'{type(self).__name__} is missing parameter {err.args[0]!r}; '
                f'expected parameters: {names}'
            )
            raise KeyError(msg) from None
        # Pad unused slots with zeros, as the branch callables expect.
        slots.extend([0.0] * (3 - len(slots)))
        return slots[0], slots[1], slots[2]

    def integrate(
        self,
        low: ArrayLike,
        high: ArrayLike,
        center: ArrayLike,
        lsf_fwhm: ArrayLike,
        **params: ArrayLike,
    ) -> Array:
        r"""Integrate the profile over wavelength bins.

        Delegates to :meth:`integrate_branch` by mapping keyword arguments
        to positional slots (p0, p1, p2) in :meth:`param_names` order.

        Parameters
        ----------
        low : ArrayLike
            Lower wavelength edges of bins.
        high : ArrayLike
            Upper wavelength edges of bins.
        center : ArrayLike
            Line center wavelength.
        lsf_fwhm : ArrayLike
            Instrumental line spread function FWHM at the line center.
        \*\*params : ArrayLike
            Parameter values, keyed by the names from :meth:`param_names`.

        Returns
        -------
        Array
            Fractional flux integrated in each bin (sums to 1 over all bins).

        Raises
        ------
        KeyError
            If a parameter named by :meth:`param_names` is not supplied.
        """
        p0, p1, p2 = self._slot_values(params)
        return self.integrate_branch()(low, high, center, lsf_fwhm, p0, p1, p2)

    def evaluate(
        self,
        wavelength: ArrayLike,
        center: ArrayLike,
        lsf_fwhm: ArrayLike,
        **params: ArrayLike,
    ) -> Array:
        r"""Evaluate the normalised profile at wavelength points.

        Delegates to :meth:`evaluate_branch` by mapping keyword arguments
        to positional slots (p0, p1, p2) in :meth:`param_names` order.

        Parameters
        ----------
        wavelength : ArrayLike
            Wavelength points at which to evaluate the profile.
        center : ArrayLike
            Line center wavelength.
        lsf_fwhm : ArrayLike
            Instrumental line spread function FWHM at the line center.
        \*\*params : ArrayLike
            Parameter values, keyed by the names from :meth:`param_names`.

        Returns
        -------
        Array
            Normalised profile value at each wavelength point
            (1/wavelength units).

        Raises
        ------
        KeyError
            If a parameter named by :meth:`param_names` is not supplied.
        """
        p0, p1, p2 = self._slot_values(params)
        return self.evaluate_branch()(wavelength, center, lsf_fwhm, p0, p1, p2)

    @abstractmethod
    def integrate_branch(self) -> Callable[..., Array]:
        """Return a JAX-compatible branch callable for ``lax.switch`` dispatch.

        The returned function must have the fixed signature::

            fn(low, high, center, lsf_fwhm, p0, p1, p2) -> Array

        Parameters correspond to :meth:`param_names` in order: ``p0`` is
        ``param_names()[0]``, ``p1`` is ``param_names()[1]``, ``p2`` is
        ``param_names()[2]``. Unused slots receive zero from the model
        builder and must be ignored.

        Returns
        -------
        callable
            A pure-JAX function suitable as a ``lax.switch`` branch.
        """

    @abstractmethod
    def evaluate_branch(self) -> Callable[..., Array]:
        """Return a JAX-compatible branch callable for pointwise evaluation.

        The returned function must have the fixed signature::

            fn(wavelength, center, lsf_fwhm, p0, p1, p2) -> Array

        Returns the normalised profile value at each wavelength point.

        Returns
        -------
        callable
            A pure-JAX function suitable as a ``lax.switch`` branch.
        """

    @abstractmethod
    def to_dict(self) -> dict:
        """Serialize to a YAML-safe dictionary."""

    @classmethod
    @abstractmethod
    def from_dict(cls, d: dict) -> Profile:
        """Deserialize from a dictionary."""
# -------------------------------------------------------------------
# Registry for deserialization
# -------------------------------------------------------------------

# Table consulted by profile_from_dict(): maps the serialized 'type'
# string to the Profile subclass that can rebuild it.
_PROFILE_REGISTRY: dict[str, type[Profile]] = dict()

# Class decorator applied to every concrete profile below.
# NOTE(review): presumably records the decorated class in the registry
# under its class name -- confirm against unite._utils._make_register.
_register = _make_register(_PROFILE_REGISTRY)
def profile_from_dict(d: dict) -> Profile:
    """Rebuild a Profile from its dictionary form.

    The ``'type'`` entry selects the registered profile class, whose
    ``from_dict`` classmethod then receives the full dictionary.

    Parameters
    ----------
    d : dict
        Dictionary with a ``'type'`` key matching a registered profile class.

    Returns
    -------
    Profile

    Raises
    ------
    KeyError
        If ``d`` has no ``'type'`` key or the type is not registered.
    """
    return _PROFILE_REGISTRY[d['type']].from_dict(d)
# ------------------------------------------------------------------- # Concrete profiles # -------------------------------------------------------------------
@_register
class Gaussian(Profile):
    """Gaussian (normal) line profile.

    Takes one parameter, ``fwhm_gauss``. The instrumental LSF is combined
    with it in quadrature: ``total_fwhm = sqrt(lsf_fwhm² + fwhm_gauss²)``.
    """

    code = 0

    @override
    def param_names(self) -> tuple[str, ...]:
        """Return ``('fwhm_gauss',)``."""
        return ('fwhm_gauss',)

    @override
    def default_priors(self) -> dict[str, Prior]:
        """Return a ``Uniform(0, 1000)`` prior for ``fwhm_gauss``."""
        return dict(fwhm_gauss=Uniform(0, 1000))

    @override
    def integrate_branch(self):
        """Return the bin-integration branch for ``lax.switch`` dispatch."""

        def _integrate(low, high, center, lsf_fwhm, fwhm_gauss, _p1, _p2):
            # Slots p1 and p2 are unused by this profile and are ignored.
            return functions.integrate_gaussian(
                low, high, center, lsf_fwhm, fwhm_gauss
            )

        return _integrate

    @override
    def evaluate_branch(self):
        """Return the pointwise-evaluation branch for ``lax.switch`` dispatch."""

        def _evaluate(wave, center, lsf_fwhm, fwhm_gauss, _p1, _p2):
            # Slots p1 and p2 are unused by this profile and are ignored.
            return functions.evaluate_gaussian(wave, center, lsf_fwhm, fwhm_gauss)

        return _evaluate

    @override
    def to_dict(self) -> dict:
        """Serialize as ``{'type': 'Gaussian'}``; there is no other state."""
        return dict(type='Gaussian')

    @classmethod
    @override
    def from_dict(cls, d: dict) -> Gaussian:
        """Build an instance; ``d`` carries nothing beyond the type tag."""
        return cls()

    @override
    def __repr__(self) -> str:
        return 'Gaussian()'
@_register
class Cauchy(Profile):
    """Cauchy (Lorentzian) line profile.

    Requires a single parameter ``fwhm_lorentz``; the intrinsic line shape
    is a pure Lorentzian.

    Note: the profile is evaluated through the pseudo-Voigt routines with
    the intrinsic Gaussian width fixed to zero. ``lsf_fwhm`` is still
    passed through, so the instrumental LSF is applied in the same way as
    for :class:`PseudoVoigt`, consistent with the package assumption that
    all lines are convolved with the instrumental LSF.
    """

    code = 1

    @override
    def param_names(self) -> tuple[str, ...]:
        """Return ``('fwhm_lorentz',)``."""
        return ('fwhm_lorentz',)

    @override
    def default_priors(self) -> dict[str, Prior]:
        """Return a ``Uniform(0, 1000)`` prior for ``fwhm_lorentz``."""
        return {'fwhm_lorentz': Uniform(0, 1000)}

    @override
    def integrate_branch(self):
        """Return the bin-integration branch for ``lax.switch`` dispatch."""

        def _fn(lo, hi, c, lsf, p0, p1, p2):
            # p0 = fwhm_lorentz; p1 and p2 are unused.
            # Pseudo-Voigt with zero intrinsic Gaussian width: only the LSF
            # contributes to the Gaussian component.
            return functions.integrate_voigt(lo, hi, c, lsf, 0.0, p0)

        return _fn

    @override
    def evaluate_branch(self):
        """Return the pointwise-evaluation branch for ``lax.switch`` dispatch."""

        def _fn(wavelength, c, lsf, p0, p1, p2):
            # p0 = fwhm_lorentz; zero intrinsic Gaussian width, LSF kept.
            return functions.evaluate_voigt(wavelength, c, lsf, 0.0, p0)

        return _fn

    @override
    def to_dict(self) -> dict:
        """Serialize as ``{'type': 'Cauchy'}``; there is no other state."""
        return {'type': 'Cauchy'}

    @classmethod
    @override
    def from_dict(cls, d: dict) -> Cauchy:
        """Build an instance; ``d`` carries nothing beyond the type tag."""
        return cls()

    @override
    def __repr__(self) -> str:
        return 'Cauchy()'
@_register
class PseudoVoigt(Profile):
    """Pseudo-Voigt line profile (Thompson et al. 1987).

    Takes two parameters: ``fwhm_gauss`` (Gaussian component) and
    ``fwhm_lorentz`` (Lorentzian component). The instrumental LSF is added
    in quadrature to the Gaussian component.
    """

    code = 2

    @override
    def param_names(self) -> tuple[str, ...]:
        """Return ``('fwhm_gauss', 'fwhm_lorentz')``."""
        return ('fwhm_gauss', 'fwhm_lorentz')

    @override
    def default_priors(self) -> dict[str, Prior]:
        """Return ``Uniform(0, 1000)`` priors for both widths."""
        return dict(fwhm_gauss=Uniform(0, 1000), fwhm_lorentz=Uniform(0, 1000))

    @override
    def integrate_branch(self):
        """Return the bin-integration branch for ``lax.switch`` dispatch."""

        def _integrate(low, high, center, lsf_fwhm, fwhm_gauss, fwhm_lorentz, _p2):
            # Slot p2 is unused by this profile and is ignored.
            return functions.integrate_voigt(
                low, high, center, lsf_fwhm, fwhm_gauss, fwhm_lorentz
            )

        return _integrate

    @override
    def evaluate_branch(self):
        """Return the pointwise-evaluation branch for ``lax.switch`` dispatch."""

        def _evaluate(wave, center, lsf_fwhm, fwhm_gauss, fwhm_lorentz, _p2):
            # Slot p2 is unused by this profile and is ignored.
            return functions.evaluate_voigt(
                wave, center, lsf_fwhm, fwhm_gauss, fwhm_lorentz
            )

        return _evaluate

    @override
    def to_dict(self) -> dict:
        """Serialize as ``{'type': 'PseudoVoigt'}``; there is no other state."""
        return dict(type='PseudoVoigt')

    @classmethod
    @override
    def from_dict(cls, d: dict) -> PseudoVoigt:
        """Build an instance; ``d`` carries nothing beyond the type tag."""
        return cls()

    @override
    def __repr__(self) -> str:
        return 'PseudoVoigt()'
@_register
class Laplace(Profile):
    """Laplace (double-exponential) line profile.

    Requires a single parameter ``fwhm_exp``; the intrinsic line shape is a
    pure Laplace distribution.

    Note: the profile is evaluated through the Gaussian-Laplace routines
    with the intrinsic Gaussian width fixed to zero. ``lsf_fwhm`` is still
    passed through, so the result is the Laplace profile convolved with the
    Gaussian instrumental LSF (the same routines :class:`SEMG` uses).
    """

    code = 3

    @override
    def param_names(self) -> tuple[str, ...]:
        """Return ``('fwhm_exp',)``."""
        return ('fwhm_exp',)

    @override
    def default_priors(self) -> dict[str, Prior]:
        """Return a ``Uniform(0, 1000)`` prior for ``fwhm_exp``."""
        return {'fwhm_exp': Uniform(0, 1000)}

    @override
    def integrate_branch(self):
        """Return the bin-integration branch for ``lax.switch`` dispatch."""

        def _fn(lo, hi, c, lsf, p0, p1, p2):
            # p0 = fwhm_exp; p1 and p2 are unused.
            # Pure Laplace convolved with the Gaussian LSF: the intrinsic
            # Gaussian width is zero.
            return functions.integrate_gaussianLaplace(lo, hi, c, lsf, 0.0, p0)

        return _fn

    @override
    def evaluate_branch(self):
        """Return the pointwise-evaluation branch for ``lax.switch`` dispatch."""

        def _fn(wavelength, c, lsf, p0, p1, p2):
            # p0 = fwhm_exp; zero intrinsic Gaussian width, LSF kept.
            return functions.evaluate_gaussianLaplace(wavelength, c, lsf, 0.0, p0)

        return _fn

    @override
    def to_dict(self) -> dict:
        """Serialize as ``{'type': 'Laplace'}``; there is no other state."""
        return {'type': 'Laplace'}

    @classmethod
    @override
    def from_dict(cls, d: dict) -> Laplace:
        """Build an instance; ``d`` carries nothing beyond the type tag."""
        return cls()

    @override
    def __repr__(self) -> str:
        return 'Laplace()'
@_register
class SEMG(Profile):
    """Symmetric Exponentially Modified Gaussian (SEMG) line profile.

    A Gaussian (with LSF) convolved with a symmetric Laplace
    (double-exponential) distribution. Takes two parameters:
    ``fwhm_gauss`` (intrinsic Gaussian component) and ``fwhm_exp``
    (Laplacian component). The instrumental LSF is added in quadrature to
    the Gaussian component.
    """

    code = 4

    @override
    def param_names(self) -> tuple[str, ...]:
        """Return ``('fwhm_gauss', 'fwhm_exp')``."""
        return ('fwhm_gauss', 'fwhm_exp')

    @override
    def default_priors(self) -> dict[str, Prior]:
        """Return ``Uniform(0, 1000)`` priors for both widths."""
        return dict(fwhm_gauss=Uniform(0, 1000), fwhm_exp=Uniform(0, 1000))

    @override
    def integrate_branch(self):
        """Return the bin-integration branch for ``lax.switch`` dispatch."""

        def _integrate(low, high, center, lsf_fwhm, fwhm_gauss, fwhm_exp, _p2):
            # Slot p2 is unused by this profile and is ignored.
            return functions.integrate_gaussianLaplace(
                low, high, center, lsf_fwhm, fwhm_gauss, fwhm_exp
            )

        return _integrate

    @override
    def evaluate_branch(self):
        """Return the pointwise-evaluation branch for ``lax.switch`` dispatch."""

        def _evaluate(wave, center, lsf_fwhm, fwhm_gauss, fwhm_exp, _p2):
            # Slot p2 is unused by this profile and is ignored.
            return functions.evaluate_gaussianLaplace(
                wave, center, lsf_fwhm, fwhm_gauss, fwhm_exp
            )

        return _evaluate

    @override
    def to_dict(self) -> dict:
        """Serialize as ``{'type': 'SEMG'}``; there is no other state."""
        return dict(type='SEMG')

    @classmethod
    @override
    def from_dict(cls, d: dict) -> SEMG:
        """Build an instance; ``d`` carries nothing beyond the type tag."""
        return cls()

    @override
    def __repr__(self) -> str:
        return 'SEMG()'
@_register
class GaussHermite(Profile):
    """Gauss-Hermite line profile.

    A Gaussian (with LSF) modified by Hermite polynomial corrections for
    skewness (h3) and kurtosis (h4). Takes three parameters:
    ``fwhm_gauss`` (intrinsic Gaussian FWHM), ``h3`` (skewness
    coefficient) and ``h4`` (kurtosis coefficient). The instrumental LSF
    is added in quadrature to the Gaussian component.
    """

    code = 5

    @override
    def param_names(self) -> tuple[str, ...]:
        """Return ``('fwhm_gauss', 'h3', 'h4')``."""
        return ('fwhm_gauss', 'h3', 'h4')

    @override
    def default_priors(self) -> dict[str, Prior]:
        """Return a uniform width prior and truncated-normal ``h3``/``h4`` priors."""
        width_prior = Uniform(0, 1000)
        # h3 and h4 each get their own prior object, centred on zero and
        # truncated to [-0.3, 0.3].
        h3_prior = TruncatedNormal(loc=0, scale=0.1, low=-0.3, high=0.3)
        h4_prior = TruncatedNormal(loc=0, scale=0.1, low=-0.3, high=0.3)
        return dict(fwhm_gauss=width_prior, h3=h3_prior, h4=h4_prior)

    @override
    def integrate_branch(self):
        """Return the bin-integration branch for ``lax.switch`` dispatch."""

        def _integrate(low, high, center, lsf_fwhm, fwhm_gauss, h3, h4):
            return functions.integrate_gaussHermite(
                low, high, center, lsf_fwhm, fwhm_gauss, h3, h4
            )

        return _integrate

    @override
    def evaluate_branch(self):
        """Return the pointwise-evaluation branch for ``lax.switch`` dispatch."""

        def _evaluate(wave, center, lsf_fwhm, fwhm_gauss, h3, h4):
            return functions.evaluate_gaussHermite(
                wave, center, lsf_fwhm, fwhm_gauss, h3, h4
            )

        return _evaluate

    @override
    def to_dict(self) -> dict:
        """Serialize as ``{'type': 'GaussHermite'}``; there is no other state."""
        return dict(type='GaussHermite')

    @classmethod
    @override
    def from_dict(cls, d: dict) -> GaussHermite:
        """Build an instance; ``d`` carries nothing beyond the type tag."""
        return cls()

    @override
    def __repr__(self) -> str:
        return 'GaussHermite()'
@_register
class SplitNormal(Profile):
    """Split-normal (two-sided Gaussian) line profile.

    A Gaussian with different standard deviations on each side of the
    mean. Takes two parameters: ``fwhm_blue`` for the blue (left) side and
    ``fwhm_red`` for the red (right) side. The instrumental LSF is added in
    quadrature to both components.
    """

    code = 6

    @override
    def param_names(self) -> tuple[str, ...]:
        """Return ``('fwhm_blue', 'fwhm_red')``."""
        return ('fwhm_blue', 'fwhm_red')

    @override
    def default_priors(self) -> dict[str, Prior]:
        """Return ``Uniform(0, 1000)`` priors for both sides."""
        return dict(fwhm_blue=Uniform(0, 1000), fwhm_red=Uniform(0, 1000))

    @override
    def integrate_branch(self):
        """Return the bin-integration branch for ``lax.switch`` dispatch."""

        def _integrate(low, high, center, lsf_fwhm, fwhm_blue, fwhm_red, _p2):
            # Slot p2 is unused by this profile and is ignored.
            return functions.integrate_split_normal(
                low, high, center, lsf_fwhm, fwhm_blue, fwhm_red
            )

        return _integrate

    @override
    def evaluate_branch(self):
        """Return the pointwise-evaluation branch for ``lax.switch`` dispatch."""

        def _evaluate(wave, center, lsf_fwhm, fwhm_blue, fwhm_red, _p2):
            # Slot p2 is unused by this profile and is ignored.
            return functions.evaluate_split_normal(
                wave, center, lsf_fwhm, fwhm_blue, fwhm_red
            )

        return _evaluate

    @override
    def to_dict(self) -> dict:
        """Serialize as ``{'type': 'SplitNormal'}``; there is no other state."""
        return dict(type='SplitNormal')

    @classmethod
    @override
    def from_dict(cls, d: dict) -> SplitNormal:
        """Build an instance; ``d`` carries nothing beyond the type tag."""
        return cls()

    @override
    def __repr__(self) -> str:
        return 'SplitNormal()'
@_register
class BoxGauss(Profile):
    """Boxcar distribution convolved with a Gaussian.

    The intrinsic profile is a uniform rectangular (boxcar) distribution
    of full width ``fwhm_box`` centred at zero (area = 1), convolved with a
    Gaussian whose FWHM is the quadrature sum of ``fwhm_gauss`` and
    ``lsf_fwhm``.

    As ``fwhm_box`` → 0 the profile reduces to a pure Gaussian; as
    ``fwhm_gauss`` → 0 (and ``lsf_fwhm`` → 0) it approaches the sharp
    rectangular distribution.

    Takes two parameters: ``fwhm_box`` (boxcar full width) and
    ``fwhm_gauss`` (intrinsic Gaussian component).
    """

    code = 8

    @override
    def param_names(self) -> tuple[str, ...]:
        """Return ``('fwhm_box', 'fwhm_gauss')``."""
        return ('fwhm_box', 'fwhm_gauss')

    @override
    def default_priors(self) -> dict[str, Prior]:
        """Return ``Uniform(0, 1000)`` priors for both widths."""
        return dict(fwhm_box=Uniform(0, 1000), fwhm_gauss=Uniform(0, 1000))

    @override
    def integrate_branch(self):
        """Return the bin-integration branch for ``lax.switch`` dispatch."""

        def _integrate(low, high, center, lsf_fwhm, fwhm_box, fwhm_gauss, _p2):
            # Slot p2 is unused by this profile and is ignored.
            return functions.integrate_boxGauss(
                low, high, center, lsf_fwhm, fwhm_box, fwhm_gauss
            )

        return _integrate

    @override
    def evaluate_branch(self):
        """Return the pointwise-evaluation branch for ``lax.switch`` dispatch."""

        def _evaluate(wave, center, lsf_fwhm, fwhm_box, fwhm_gauss, _p2):
            # Slot p2 is unused by this profile and is ignored.
            return functions.evaluate_boxGauss(
                wave, center, lsf_fwhm, fwhm_box, fwhm_gauss
            )

        return _evaluate

    @override
    def to_dict(self) -> dict:
        """Serialize as ``{'type': 'BoxGauss'}``; there is no other state."""
        return dict(type='BoxGauss')

    @classmethod
    @override
    def from_dict(cls, d: dict) -> BoxGauss:
        """Build an instance; ``d`` carries nothing beyond the type tag."""
        return cls()

    @override
    def __repr__(self) -> str:
        return 'BoxGauss()'
@_register
class SkewVoigt(Profile):
    r"""Skew pseudo-Voigt line profile.

    A pseudo-Voigt profile multiplied by a skew factor
    ``[1 + erf(alpha * (x - c) / (sqrt(2) * sigma_g))]``, where ``sigma_g``
    is the standard deviation of the Gaussian component. The profile
    integrates to 1 for any value of ``alpha`` because the ``erf`` term is
    odd about the line centre while the pseudo-Voigt is even, so their
    product contributes nothing to the integral.

    Convolution with the Gaussian LSF rescales the skewness parameter to

    .. math::

        \alpha_\text{eff} = \frac{\alpha\,\sigma_g}
            {\sqrt{\sigma_\text{tot}^2 + \alpha^2\sigma_\text{lsf}^2}}

    where :math:`\sigma_\text{tot} = \sqrt{\sigma_g^2 + \sigma_\text{lsf}^2}`.
    The skewness is reduced by the LSF and vanishes entirely when
    ``fwhm_gauss = 0`` (no intrinsic Gaussian component).

    Requires three parameters: ``fwhm_gauss`` for the Gaussian component,
    ``fwhm_lorentz`` for the Lorentzian component, and ``alpha`` for the
    skewness (positive values shift flux redward).
    """

    code = 7

    @override
    def param_names(self) -> tuple[str, ...]:
        """Return ``('fwhm_gauss', 'fwhm_lorentz', 'alpha')``."""
        return ('fwhm_gauss', 'fwhm_lorentz', 'alpha')

    @override
    def default_priors(self) -> dict[str, Prior]:
        """Return uniform width priors and a truncated-normal ``alpha`` prior."""
        return {
            'fwhm_gauss': Uniform(0, 1000),
            'fwhm_lorentz': Uniform(0, 1000),
            'alpha': TruncatedNormal(loc=0, scale=100, low=-300, high=300),
        }

    @override
    def integrate_branch(self):
        """Return the bin-integration branch for ``lax.switch`` dispatch."""

        def _fn(lo, hi, c, lsf, p0, p1, p2):
            # p0 = fwhm_gauss, p1 = fwhm_lorentz, p2 = alpha
            return functions.integrate_skewVoigt(lo, hi, c, lsf, p0, p1, p2)

        return _fn

    @override
    def evaluate_branch(self):
        """Return the pointwise-evaluation branch for ``lax.switch`` dispatch."""

        def _fn(wavelength, c, lsf, p0, p1, p2):
            # p0 = fwhm_gauss, p1 = fwhm_lorentz, p2 = alpha
            return functions.evaluate_skewVoigt(wavelength, c, lsf, p0, p1, p2)

        return _fn

    @override
    def to_dict(self) -> dict:
        """Serialize as ``{'type': 'SkewVoigt'}``; there is no other state."""
        return {'type': 'SkewVoigt'}

    @classmethod
    @override
    def from_dict(cls, d: dict) -> SkewVoigt:
        """Build an instance; ``d`` carries nothing beyond the type tag."""
        return cls()

    @override
    def __repr__(self) -> str:
        return 'SkewVoigt()'
# Lower-case alias -> profile instance, consulted by resolve_profile().
# Every alias gets its own freshly constructed instance.
_PROFILE_ALIASES: dict[str, Profile] = {
    alias: profile_cls()
    for profile_cls, aliases in (
        (Gaussian, ('gaussian', 'normal')),
        (Cauchy, ('lorentzian', 'cauchy')),
        (Laplace, ('exponential', 'laplace')),
        (PseudoVoigt, ('voigt', 'pseudovoigt')),
        (SEMG, ('semg', 'exp-gaussian')),
        (GaussHermite, ('hermite', 'gauss-hermite')),
        (SplitNormal, ('split-normal', 'two-sided')),
        (SkewVoigt, ('skew-voigt', 'skewvoigt')),
        (BoxGauss, ('boxgauss', 'box-gauss', 'boxcar')),
    )
    for alias in aliases
}
def resolve_profile(profile: str | Profile) -> Profile:
    """Convert a profile string or instance to a Profile object.

    Parameters
    ----------
    profile : str or Profile
        Profile name (case-insensitive) or instance. An instance is
        returned unchanged.

    Returns
    -------
    Profile

    Raises
    ------
    ValueError
        If the string is not a recognized profile alias.
    TypeError
        If ``profile`` is neither a ``str`` nor a :class:`Profile`.
    """
    if isinstance(profile, Profile):
        return profile
    if not isinstance(profile, str):
        msg = f'profile must be a str or Profile, got {type(profile).__name__}'
        raise TypeError(msg)
    try:
        # Aliases are stored lower-case, so lookup is case-insensitive.
        return _PROFILE_ALIASES[profile.lower()]
    except KeyError:
        valid = ', '.join(sorted(_PROFILE_ALIASES))
        msg = f'Unknown profile {profile!r}. Valid names: {valid}'
        # The internal KeyError adds nothing for the caller; suppress it.
        raise ValueError(msg) from None
# ------------------------------------------------------------------- # JAX dispatch: integration and evaluation branches # ------------------------------------------------------------------- # Build the lax.switch branch lists once at import time. # Each Profile subclass owns its branches via integrate_branch() and # evaluate_branch(); sorted by code guarantees the list index matches # Profile.code. _INTEGRATE_BRANCHES = [ cls().integrate_branch() for cls in sorted(_PROFILE_REGISTRY.values(), key=lambda c: c.code) ] _EVALUATE_BRANCHES = [ cls().evaluate_branch() for cls in sorted(_PROFILE_REGISTRY.values(), key=lambda c: c.code) ] # Re-export from compute module for backward compatibility. from unite.line.compute import integrate_lines, evaluate_lines # noqa: I001, E402, F401