Source code for unite.spectrum.spectrum

"""The Spectrum class — a single observed spectrum."""

from __future__ import annotations

import jax.numpy as jnp
from astropy import units as u

from unite._utils import _ensure_flux_density, _ensure_wavelength
from unite.instrument.base import Disperser


class Spectrum:
    """A single observed spectrum.

    A spectrum is defined by pixel bin edges (*low*, *high*), flux and error
    arrays, and a :class:`~unite.instrument.base.Disperser`. Calibration
    parameters live on the disperser as
    :class:`~unite.instrument.base.CalibParam` tokens (``disperser.r_scale``,
    ``disperser.flux_scale``, ``disperser.pix_offset``).

    Use :func:`~unite.spectrum.from_arrays`, :func:`~unite.spectrum.from_DJA`,
    or :func:`~unite.spectrum.from_sdss_fits` to construct spectra from arrays
    or instrument-native file formats.

    Parameters
    ----------
    low : astropy.units.Quantity
        Lower wavelength edges of each pixel. Must be 1-D with wavelength
        (length) dimensions.
    high : astropy.units.Quantity
        Upper wavelength edges of each pixel. Same shape and compatible units
        as *low*.
    flux : astropy.units.Quantity
        Flux density values per pixel. Must be 1-D with the same length as
        *low* and carry spectral flux density per wavelength units (f_lambda,
        e.g. ``erg / s / cm^2 / Angstrom``).
    error : astropy.units.Quantity
        Flux density uncertainty per pixel. Must be 1-D with the same length
        as *low* and carry units compatible with *flux*.
    disperser : Disperser
        Instrumental disperser associated with this spectrum. Carries any
        calibration tokens (``r_scale``, ``flux_scale``, ``pix_offset``).
    name : str, optional
        Human-readable label (e.g. ``'G235H'``). Used in repr and for
        constructing numpyro site names. Defaults to ``disperser.name``.

    Raises
    ------
    TypeError
        If *low* / *high* are not Quantities with wavelength dimensions, if
        *flux* / *error* are not Quantities with f_lambda dimensions, or if
        *disperser* is not a :class:`Disperser` instance.
    ValueError
        If array shapes are inconsistent, if *flux* and *error* units are not
        compatible, or if *low* >= *high* for any pixel.
    """

    def __init__(
        self,
        low: u.Quantity,
        high: u.Quantity,
        flux: u.Quantity,
        error: u.Quantity,
        disperser: Disperser,
        *,
        name: str = '',
    ) -> None:
        # -- flux unit --------------------------------------------------------
        flux = _ensure_flux_density(flux, 'flux', ndim=1)
        error = _ensure_flux_density(error, 'error', ndim=1)
        _flux_unit = flux.unit
        if not _flux_unit.is_equivalent(error.unit):
            msg = f'flux and error must have compatible units, got {flux.unit!r} and {error.unit!r}.'
            raise ValueError(msg)
        self._flux_unit: u.UnitBase = _flux_unit

        # -- disperser --------------------------------------------------------
        if not isinstance(disperser, Disperser):
            msg = f'disperser must be a Disperser instance, got {type(disperser).__name__}.'
            raise TypeError(msg)
        self.disperser = disperser

        # -- wavelength edges -------------------------------------------------
        low = _ensure_wavelength(low, 'low', ndim=1)
        high = _ensure_wavelength(high, 'high', ndim=1)
        if low.shape != high.shape:
            msg = f'low and high must have the same shape, got {low.shape} and {high.shape}.'
            raise ValueError(msg)
        # Store in the disperser's wavelength unit as JAX arrays.
        self._low = jnp.asarray(low.to(disperser.unit).value, dtype=float)
        self._high = jnp.asarray(high.to(disperser.unit).value, dtype=float)
        # The class contract promises a ValueError for empty or inverted
        # pixels. Check after conversion so both edges share one unit.
        if bool(jnp.any(self._low >= self._high)):
            msg = 'low must be strictly less than high for every pixel.'
            raise ValueError(msg)

        # -- flux and error ---------------------------------------------------
        # Convert error to the same unit as flux, then store bare values.
        error_converted = error.to(self._flux_unit)
        flux_arr = jnp.asarray(flux.value, dtype=float)
        error_arr = jnp.asarray(error_converted.value, dtype=float)
        npix = self._low.shape[0]
        for arr, label in ((flux_arr, 'flux'), (error_arr, 'error')):
            if arr.shape[0] != npix:
                msg = f'{label} length ({arr.shape[0]}) does not match the number of pixels ({npix}).'
                raise ValueError(msg)
        self._flux = flux_arr
        self._error = error_arr
        self._error_scale: jnp.ndarray | float = 1.0
        self._scale_diagnostic: object = None

        # -- metadata ---------------------------------------------------------
        self.name = name or disperser.name

    # -- properties -----------------------------------------------------------
    @property
    def low(self) -> jnp.ndarray:
        """Lower pixel-edge wavelengths in the disperser's unit."""
        return self._low

    @property
    def high(self) -> jnp.ndarray:
        """Upper pixel-edge wavelengths in the disperser's unit."""
        return self._high

    @property
    def wavelength(self) -> jnp.ndarray:
        """Pixel-center wavelengths (mean of low and high edges)."""
        return (self._low + self._high) / 2.0

    @property
    def flux(self) -> jnp.ndarray:
        """Observed flux values per pixel."""
        return self._flux

    @property
    def error(self) -> jnp.ndarray:
        """Flux uncertainty per pixel."""
        return self._error

    @property
    def npix(self) -> int:
        """Number of pixels."""
        return int(self._low.shape[0])

    @property
    def unit(self) -> u.UnitBase:
        """Wavelength unit inherited from the disperser."""
        return self.disperser.unit

    @property
    def flux_unit(self) -> u.UnitBase:
        """Flux density unit (f_lambda)."""
        return self._flux_unit

    @property
    def error_scale(self) -> jnp.ndarray | float:
        """Multiplicative scale factor applied to errors.

        Can be a scalar (applied uniformly) or a per-pixel array.
        """
        return self._error_scale

    @error_scale.setter
    def error_scale(self, value: float | jnp.ndarray) -> None:
        """Validate and store a strictly positive scalar or per-pixel scale.

        Raises
        ------
        ValueError
            If a scalar is not > 0 (NaN included), if an array does not have
            shape ``(npix,)``, or if any array entry is not > 0.
        """
        arr = jnp.asarray(value, dtype=float)
        if arr.ndim == 0:
            # ``not x > 0`` (rather than ``x <= 0``) also rejects NaN.
            if not float(arr) > 0:
                msg = f'error_scale must be > 0, got {float(arr)}'
                raise ValueError(msg)
        else:
            if arr.shape != (self.npix,):
                msg = (
                    f'error_scale array must have shape ({self.npix},), got {arr.shape}'
                )
                raise ValueError(msg)
            # ``all(arr > 0)`` is False for any NaN entry, unlike ``any(arr <= 0)``.
            if not bool(jnp.all(arr > 0)):
                msg = 'error_scale values must all be > 0'
                raise ValueError(msg)
        self._error_scale = arr if arr.ndim > 0 else float(arr)

    @property
    def scaled_error(self) -> jnp.ndarray:
        """Flux uncertainty scaled by :attr:`error_scale`."""
        return self._error * self._error_scale

    @property
    def scale_diagnostic(self):
        """Continuum-fit diagnostics from the most recent
        :meth:`~unite.spectrum.Spectra.compute_scales` call.

        Returns a :class:`~unite.spectrum.SpectrumScaleDiagnostic` holding the
        line mask, the fitted continuum model array, and per-region fit
        details. ``None`` if :meth:`~unite.spectrum.Spectra.compute_scales`
        has not been called yet.

        The spectrum's own :attr:`wavelength`, :attr:`flux`, :attr:`error`,
        and unit attributes provide the full picture alongside this
        diagnostic.

        Examples
        --------
        >>> diag = spectrum.scale_diagnostic
        >>> if diag is not None:
        ...     cont = diag.continuum_model  # NaN outside fitted regions
        ...     mask = diag.line_mask  # True where a pixel was excluded
        """
        return self._scale_diagnostic

    @property
    def wavelength_range(self) -> tuple[float, float]:
        """``(min, max)`` wavelength in the disperser's unit.

        Computed from the smallest low edge and the largest high edge, so the
        result does not depend on pixel ordering.
        """
        return float(jnp.min(self._low)), float(jnp.max(self._high))

    # -- calibration ----------------------------------------------------------
    @property
    def has_calibration_priors(self) -> bool:
        """``True`` if any calibration token is set on the disperser."""
        return self.disperser.has_calibration_params

    # -- coverage -------------------------------------------------------------
    def covers(self, low: float, high: float) -> bool:
        """Return ``True`` if any pixel overlaps ``[low, high]``.

        Equivalent to ``pixel_mask(low, high).any()``.

        Parameters
        ----------
        low : float
            Lower bound in the disperser's unit.
        high : float
            Upper bound in the disperser's unit.
        """
        return bool(jnp.any(self.pixel_mask(low, high)))

    def pixel_mask(self, low: float, high: float) -> jnp.ndarray:
        """Return a boolean array selecting pixels that overlap ``[low, high]``.

        A pixel is selected when its upper edge lies above *low* and its lower
        edge lies below *high*. Pixels only partly inside the interval are
        included. A pixel that only touches an endpoint is not.

        Parameters
        ----------
        low : float
            Lower bound in the disperser's unit.
        high : float
            Upper bound in the disperser's unit.

        Returns
        -------
        jnp.ndarray
            Boolean array of shape ``(npix,)``.
        """
        return (self._high > low) & (self._low < high)

    # -- slicing (internal) ---------------------------------------------------
    def _sliced(self, mask: jnp.ndarray) -> Spectrum:
        """Return a new spectrum with arrays selected by a boolean mask.

        Bypasses ``__init__`` validation (arrays are already validated).
        Used internally by :class:`ModelBuilder` to trim spectra to continuum
        coverage before model evaluation.

        Parameters
        ----------
        mask : jnp.ndarray
            Boolean array of shape ``(npix,)``.
        """
        new = object.__new__(type(self))
        new._low = self._low[mask]
        new._high = self._high[mask]
        new._flux = self._flux[mask]
        new._error = self._error[mask]
        new._flux_unit = self._flux_unit
        new.disperser = self.disperser
        new.name = self.name
        if isinstance(self._error_scale, (int, float)):
            new._error_scale = self._error_scale
        else:
            new._error_scale = self._error_scale[mask]
        # The diagnostic's per-pixel arrays belong to the unsliced pixel grid
        # and would be misaligned here. Reset it so the attribute still exists
        # and ``scale_diagnostic`` does not raise AttributeError.
        new._scale_diagnostic = None
        return new

    # -- repr -----------------------------------------------------------------
    def __repr__(self) -> str:
        lo, hi = self.wavelength_range
        unit_str = self.unit.to_string()
        cls_name = type(self).__name__
        label = f'{cls_name} {self.name!r}' if self.name else cls_name
        cal = ' [calibrated]' if self.has_calibration_priors else ''
        return f'{label}: {self.npix} px, λ ∈ [{lo:.4g}, {hi:.4g}] {unit_str}{cal}'
# Public API of this module.
__all__ = ['Spectrum']