"""The Spectrum class — a single observed spectrum."""
from __future__ import annotations
import jax.numpy as jnp
import numpy as np
from astropy import units as u
from unite._utils import _ensure_flux_density, _ensure_wavelength
from unite.instrument.base import Disperser
def _compute_edge_topology(
low: jnp.ndarray, high: jnp.ndarray
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
"""Build the unique-edge topology for a 1-D pixel grid.
Adjacent pixels share an edge when ``high[i] == low[i+1]``; chip gaps
(``high[i] != low[i+1]``) contribute two edges instead of one. The
returned ``edges`` array therefore has length ``npix + num_gaps + 1``
and the ``keep_mask`` marks which ``diff(edges)`` entries correspond
to real pixels (True) versus inter-pixel gaps (False).
Parameters
----------
low, high : jnp.ndarray, shape ``(npix,)``
Pixel edges in the spectrum's wavelength unit.
Returns
-------
edges : jnp.ndarray, shape ``(E,)``
Sorted unique edges, where ``E = npix + num_gaps + 1``.
keep_mask : jnp.ndarray, shape ``(E - 1,)`` of bool
True where the corresponding ``diff(edges)`` entry is a real pixel
width; False where it spans an inter-pixel gap.
pixel_idx : jnp.ndarray, shape ``(npix,)``
Indices into the length-``(E - 1)`` axis selecting the real pixels
in original order (i.e. ``edges[1:][pixel_idx] == high`` and
``edges[:-1][pixel_idx] == low``).
"""
low_np = np.asarray(low)
high_np = np.asarray(high)
npix = low_np.shape[0]
if npix == 0:
return (
jnp.zeros(0, dtype=low.dtype),
jnp.zeros(0, dtype=bool),
jnp.zeros(0, dtype=jnp.int32),
)
# Gap after pixel i if high[i] != low[i+1] (no shared edge with i+1).
gap_after = high_np[:-1] != low_np[1:]
n_gaps = int(gap_after.sum())
n_edges = npix + n_gaps + 1
edges_np = np.empty(n_edges, dtype=low_np.dtype)
keep_np = np.empty(n_edges - 1, dtype=bool)
pix_idx_np = np.empty(npix, dtype=np.int32)
edges_np[0] = low_np[0]
w = 1
for i in range(npix):
edges_np[w] = high_np[i]
keep_np[w - 1] = True
pix_idx_np[i] = w - 1
w += 1
if i < npix - 1 and gap_after[i]:
edges_np[w] = low_np[i + 1]
keep_np[w - 1] = False
w += 1
return jnp.asarray(edges_np), jnp.asarray(keep_np), jnp.asarray(pix_idx_np)
[docs]
class Spectrum:
"""A single observed spectrum.
A spectrum is defined by pixel bin edges (*low*, *high*), flux and error
arrays, and a :class:`~unite.instrument.base.Disperser`. Calibration
parameters live on the disperser as
:class:`~unite.instrument.base.CalibParam` tokens (``disperser.r_scale``,
``disperser.flux_scale``, ``disperser.pix_offset``).
Use :func:`~unite.spectrum.from_arrays`, :func:`~unite.spectrum.from_DJA`,
or :func:`~unite.spectrum.from_sdss_fits` to construct spectra from arrays
or instrument-native file formats.
Parameters
----------
low : astropy.units.Quantity
Lower wavelength edges of each pixel. Must be 1-D with wavelength
(length) dimensions.
high : astropy.units.Quantity
Upper wavelength edges of each pixel. Same shape and compatible
units as *low*.
flux : astropy.units.Quantity
Flux density values per pixel. Must be 1-D with the same length
as *low* and carry spectral flux density per wavelength units
(f_lambda, e.g. ``erg / s / cm^2 / Angstrom``).
error : astropy.units.Quantity
Flux density uncertainty per pixel. Must be 1-D with the same
length as *low* and carry units compatible with *flux*.
disperser : Disperser
Instrumental disperser associated with this spectrum. Carries any
calibration tokens (``r_scale``, ``flux_scale``, ``pix_offset``).
name : str, optional
Human-readable label (e.g. ``'G235H'``). Used in repr and for
constructing numpyro site names. Defaults to ``disperser.name``.
Raises
------
TypeError
If *low* / *high* are not Quantities with wavelength dimensions,
if *flux* / *error* are not Quantities with f_lambda dimensions,
or if *disperser* is not a :class:`Disperser` instance.
ValueError
If array shapes are inconsistent or *low* ≥ *high* for any pixel.
"""
def __init__(
self,
low: u.Quantity,
high: u.Quantity,
flux: u.Quantity,
error: u.Quantity,
disperser: Disperser,
*,
name: str = '',
) -> None:
# -- flux unit --------------------------------------------------------
flux = _ensure_flux_density(flux, 'flux', ndim=1)
error = _ensure_flux_density(error, 'error', ndim=1)
_flux_unit = flux.unit
if not _flux_unit.is_equivalent(error.unit):
msg = f'flux and error must have compatible units, got {flux.unit!r} and {error.unit!r}.'
raise ValueError(msg)
self._flux_unit: u.UnitBase = _flux_unit
# -- disperser --------------------------------------------------------
if not isinstance(disperser, Disperser):
msg = f'disperser must be a Disperser instance, got {type(disperser).__name__}.'
raise TypeError(msg)
self.disperser = disperser
# -- wavelength edges -------------------------------------------------
low = _ensure_wavelength(low, 'low', ndim=1)
high = _ensure_wavelength(high, 'high', ndim=1)
if low.shape != high.shape:
msg = f'low and high must have the same shape, got {low.shape} and {high.shape}.'
raise ValueError(msg)
# Store in the disperser's wavelength unit as JAX arrays.
self._low = jnp.asarray(low.to(disperser.unit).value, dtype=float)
self._high = jnp.asarray(high.to(disperser.unit).value, dtype=float)
self._wavelength = (self._low + self._high) / 2.0
# Edge topology: unique edges with inter-pixel gap bookkeeping.
self._edges, self._keep_mask, self._pixel_idx = _compute_edge_topology(
self._low, self._high
)
# -- flux and error ---------------------------------------------------
# Convert error to the same unit as flux, then store bare values.
error_converted = error.to(self._flux_unit)
flux_arr = jnp.asarray(flux.value, dtype=float)
error_arr = jnp.asarray(error_converted.value, dtype=float)
npix = self._low.shape[0]
for arr, label in ((flux_arr, 'flux'), (error_arr, 'error')):
if arr.shape[0] != npix:
msg = f'{label} length ({arr.shape[0]}) does not match the number of pixels ({npix}).'
raise ValueError(msg)
self._flux = flux_arr
self._error = error_arr
self._error_scale: jnp.ndarray | float = 1.0
self._scale_diagnostic: object = None
# -- metadata ---------------------------------------------------------
self.name = name or disperser.name
# -- properties -----------------------------------------------------------
@property
def low(self) -> jnp.ndarray:
"""Lower pixel-edge wavelengths in the disperser's unit."""
return self._low
@property
def high(self) -> jnp.ndarray:
"""Upper pixel-edge wavelengths in the disperser's unit."""
return self._high
@property
def wavelength(self) -> jnp.ndarray:
"""Pixel-center wavelengths (mean of low and high edges)."""
return self._wavelength
@property
def edges(self) -> jnp.ndarray:
"""Unique sorted pixel edges, length ``E = npix + num_gaps + 1``.
Contiguous pixels share an edge (one entry per shared boundary);
chip gaps contribute two entries — ``high`` of the pixel before
the gap and ``low`` of the pixel after. Use together with
:attr:`keep_mask` to recover per-pixel quantities from
edge-evaluated cumulative arrays via ``diff`` + masking.
"""
return self._edges
@property
def keep_mask(self) -> jnp.ndarray:
"""Boolean mask of length ``E - 1`` selecting real-pixel diffs.
``jnp.diff(edges)[keep_mask]`` gives the per-pixel widths; the
masked-out entries span inter-pixel gaps and should be discarded
from any edge-cumulative-then-diff reduction.
"""
return self._keep_mask
@property
def midpoints(self) -> jnp.ndarray:
"""Midpoint of each ``(edges[i], edges[i+1])`` pair, length ``E - 1``.
Entries where :attr:`keep_mask` is False fall inside chip gaps and
should be ignored.
"""
return (self._edges[1:] + self._edges[:-1]) / 2.0
@property
def widths(self) -> jnp.ndarray:
"""``diff(edges)``, length ``E - 1``.
Entries where :attr:`keep_mask` is False are inter-pixel gap widths,
not pixel widths, and should be ignored.
"""
return jnp.diff(self._edges)
@property
def pixel_idx(self) -> jnp.ndarray:
"""Indices into the length-``(E - 1)`` axis selecting real pixels in order.
Convenience accessor equivalent to ``jnp.where(keep_mask)[0]`` but
cached at construction time as a concrete array.
"""
return self._pixel_idx
@property
def flux(self) -> jnp.ndarray:
"""Observed flux values per pixel."""
return self._flux
@property
def error(self) -> jnp.ndarray:
"""Flux uncertainty per pixel."""
return self._error
@property
def npix(self) -> int:
"""Number of pixels."""
return int(self._low.shape[0])
@property
def unit(self) -> u.UnitBase:
"""Wavelength unit inherited from the disperser."""
return self.disperser.unit
@property
def flux_unit(self) -> u.UnitBase:
"""Flux density unit (f_lambda)."""
return self._flux_unit
@property
def error_scale(self) -> jnp.ndarray | float:
"""Multiplicative scale factor applied to errors.
Can be a scalar (applied uniformly) or a per-pixel array.
"""
return self._error_scale
@error_scale.setter
def error_scale(self, value: float | jnp.ndarray) -> None:
arr = jnp.asarray(value, dtype=float)
if arr.ndim == 0:
if float(arr) <= 0:
msg = f'error_scale must be > 0, got {float(arr)}'
raise ValueError(msg)
else:
if arr.shape != (self.npix,):
msg = (
f'error_scale array must have shape ({self.npix},), got {arr.shape}'
)
raise ValueError(msg)
if bool(jnp.any(arr <= 0)):
msg = 'error_scale values must all be > 0'
raise ValueError(msg)
self._error_scale = arr if arr.ndim > 0 else float(arr)
@property
def scaled_error(self) -> jnp.ndarray:
"""Flux uncertainty scaled by :attr:`error_scale`."""
return self._error * self._error_scale
@property
def scale_diagnostic(self):
"""Continuum-fit diagnostics from the most recent :meth:`~unite.spectrum.Spectra.compute_scales` call.
Returns a :class:`~unite.spectrum.SpectrumScaleDiagnostic`
holding the line mask, the fitted continuum model array, and
per-region fit details. ``None`` if
:meth:`~unite.spectrum.Spectra.compute_scales` has not been
called yet.
The spectrum's own :attr:`wavelength`, :attr:`flux`, :attr:`error`, and
unit attributes provide the full picture alongside this diagnostic.
Examples
--------
>>> diag = spectrum.scale_diagnostic
>>> if diag is not None:
... cont = diag.continuum_model # NaN outside fitted regions
... mask = diag.line_mask # True where a pixel was excluded
"""
return self._scale_diagnostic
@property
def wavelength_range(self) -> tuple[float, float]:
"""``(min, max)`` wavelength in the disperser's unit."""
return float(self._low[0]), float(self._high[-1])
# -- calibration ----------------------------------------------------------
@property
def has_calibration_priors(self) -> bool:
"""``True`` if any calibration token is set on the disperser."""
return self.disperser.has_calibration_params
# -- coverage -------------------------------------------------------------
[docs]
def covers(self, low: float, high: float) -> bool:
"""Return ``True`` if any pixel overlaps ``[low, high]``.
Parameters
----------
low : float
Lower bound in the disperser's unit.
high : float
Upper bound in the disperser's unit.
"""
return bool(jnp.any((self._high > low) & (self._low < high)))
[docs]
def pixel_mask(self, low: float, high: float) -> jnp.ndarray:
"""Return a boolean array selecting pixels that overlap ``[low, high]``.
Parameters
----------
low : float
Lower bound in the disperser's unit.
high : float
Upper bound in the disperser's unit.
Returns
-------
jnp.ndarray
Boolean array of shape ``(npix,)``.
"""
return (self._low > low) & (self._high < high)
# -- slicing (internal) ---------------------------------------------------
def _sliced(self, mask: jnp.ndarray) -> Spectrum:
"""Return a new spectrum with arrays selected by a boolean mask.
Bypasses ``__init__`` validation (arrays are already validated).
Used internally by :class:`ModelBuilder` to trim spectra to
continuum coverage before model evaluation.
Parameters
----------
mask : jnp.ndarray
Boolean array of shape ``(npix,)``.
"""
new = object.__new__(type(self))
new._low = self._low[mask]
new._high = self._high[mask]
new._wavelength = (new._low + new._high) / 2.0
new._edges, new._keep_mask, new._pixel_idx = _compute_edge_topology(
new._low, new._high
)
new._flux = self._flux[mask]
new._error = self._error[mask]
new._flux_unit = self._flux_unit
new.disperser = self.disperser
new.name = self.name
if isinstance(self._error_scale, (int, float)):
new._error_scale = self._error_scale
else:
new._error_scale = self._error_scale[mask]
return new
# -- repr -----------------------------------------------------------------
def __repr__(self) -> str:
lo, hi = self.wavelength_range
unit_str = self.unit.to_string()
cls_name = type(self).__name__
label = f'{cls_name} {self.name!r}' if self.name else cls_name
cal = ' [calibrated]' if self.has_calibration_priors else ''
return f'{label}: {self.npix} px, λ ∈ [{lo:.4g}, {hi:.4g}] {unit_str}{cal}'
# Re-export ArrayLike for type hints in downstream code (keeps it importable).
__all__ = ['Spectrum']