Source code for unite.instrument.generic

"""Generic disperser implementations.

These are the building blocks for custom instruments.

Import from this module directly::

    from unite.instrument.generic import GenericDisperser, SimpleDisperser

The :class:`~unite.spectrum.Spectrum` class lives in :mod:`unite.spectrum`.
"""

from __future__ import annotations

from collections.abc import Callable

import jax.numpy as jnp
from astropy import units as u
from jax import Array
from jax.typing import ArrayLike

from unite._utils import C_KMS, _ensure_velocity, _ensure_wavelength
from unite.instrument.base import Disperser, FluxScale, PixOffset, RScale

# ---------------------------------------------------------------------------
# Generic dispersers
# ---------------------------------------------------------------------------


class GenericDisperser(Disperser):
    """Disperser that delegates to two user-supplied callables.

    The most flexible concrete disperser: the caller hands in one function
    for the resolving power *R(λ)* and one for the linear dispersion
    *dλ/dpix(λ)*. Both are stored untouched and invoked as-is whenever
    :meth:`R` or :meth:`dlam_dpix` is called.

    Parameters
    ----------
    R_func : Callable[[ArrayLike], ArrayLike]
        JAX-jittable function mapping an array of wavelengths to the
        resolving power at those wavelengths.
    dlam_dpix_func : Callable[[ArrayLike], ArrayLike]
        JAX-jittable function mapping an array of wavelengths to the
        linear dispersion at those wavelengths.
    unit : astropy.units.UnitBase
        Wavelength unit that both callables expect their input in.
    name : str, optional
        Human-readable label.
    r_scale : RScale, optional
        Token for the resolving-power scale.
    flux_scale : FluxScale, optional
        Token for the flux normalisation.
    pix_offset : PixOffset, optional
        Token for the pixel shift.

    Examples
    --------
    >>> import jax.numpy as jnp
    >>> from astropy import units as u
    >>> d = GenericDisperser(
    ...     R_func=lambda w: jnp.full_like(w, 2700.0),
    ...     dlam_dpix_func=lambda w: w / 2700.0,
    ...     unit=u.Angstrom,
    ... )
    >>> d.R(jnp.array([5000.0]))
    Array([2700.], dtype=float32)
    """

    def __init__(
        self,
        R_func: Callable[[ArrayLike], ArrayLike],  # noqa: N803
        dlam_dpix_func: Callable[[ArrayLike], ArrayLike],
        unit: u.UnitBase,
        *,
        name: str = '',
        r_scale: RScale | None = None,
        flux_scale: FluxScale | None = None,
        pix_offset: PixOffset | None = None,
    ) -> None:
        # The three tokens are passed through to the base class unchanged.
        token_kwargs = {
            'r_scale': r_scale,
            'flux_scale': flux_scale,
            'pix_offset': pix_offset,
        }
        super().__init__(unit, name=name, **token_kwargs)

        # Keep the callables exactly as given; no wrapping or validation.
        self._R_func = R_func
        self._dlam_dpix_func = dlam_dpix_func

    def R(self, wavelength: ArrayLike) -> ArrayLike:
        """Evaluate the stored resolving-power callable at *wavelength*."""
        resolving_power = self._R_func(wavelength)
        return resolving_power

    def dlam_dpix(self, wavelength: ArrayLike) -> ArrayLike:
        """Evaluate the stored linear-dispersion callable at *wavelength*."""
        dispersion = self._dlam_dpix_func(wavelength)
        return dispersion
class SimpleDisperser(Disperser):
    """A disperser defined on a pixel-sampled wavelength grid.

    The wavelength array is interpreted as a sequence of pixel centers so
    that *dλ/dpix* is computed directly from the spacing of the array (via
    ``jnp.gradient``).

    The resolving power is derived from exactly **one** of three keyword
    arguments:

    * ``R`` — resolving power, scalar or array matching *wavelength*.
    * ``dlam`` — spectral resolution element Δλ (converted to the unit of
      *wavelength*). Converted via *R = λ / dlam*.
    * ``dvel`` — velocity FWHM (converted to **km/s**). Converted via
      *R = c / dvel*.

    A scalar value produces a **constant** R, Δλ, or Δv across the grid
    (and the corresponding resolving-power array is derived accordingly).
    An array value must have the same shape as *wavelength*.

    Parameters
    ----------
    wavelength : u.Quantity
        Pixel-center wavelengths. Must be 1-D, contain at least two
        pixels, and be sorted in increasing order.
    R : ArrayLike, optional
        Resolving power (scalar or per-pixel array).
    dlam : u.Quantity, optional
        Spectral resolution Δλ, must be in wavelength units (scalar or
        per-pixel array).
    dvel : u.Quantity, optional
        Velocity resolution in velocity units (scalar or per-pixel array).
    name : str, optional
        Human-readable label.
    r_scale : RScale, optional
        Token for the resolving-power scale.
    flux_scale : FluxScale, optional
        Token for the flux normalisation.
    pix_offset : PixOffset, optional
        Token for the pixel shift.

    Raises
    ------
    ValueError
        If zero or more than one of ``R``, ``dlam``, ``dvel`` is provided,
        if an array argument has the wrong shape, if *wavelength* has
        fewer than two pixels, or if *wavelength* is not sorted in
        increasing order.

    Notes
    -----
    When :meth:`R` or :meth:`dlam_dpix` is called at wavelengths that
    differ from the stored grid, the values are linearly interpolated with
    ``jnp.interp``. That routine expects sorted sample points, which is why
    the grid ordering is checked at construction time.
    """

    def __init__(
        self,
        wavelength: u.Quantity,
        *,
        R: ArrayLike | None = None,  # noqa: N803
        dlam: u.Quantity | None = None,
        dvel: u.Quantity | None = None,
        name: str = '',
        r_scale: RScale | None = None,
        flux_scale: FluxScale | None = None,
        pix_offset: PixOffset | None = None,
    ) -> None:
        wavelength = _ensure_wavelength(wavelength, 'wavelength', ndim=1)
        super().__init__(
            wavelength.unit,
            name=name,
            r_scale=r_scale,
            flux_scale=flux_scale,
            pix_offset=pix_offset,
        )

        n_specified = sum(x is not None for x in (R, dlam, dvel))
        if n_specified != 1:
            msg = f'Exactly one of R, dlam, or dvel must be provided, but {n_specified} were given.'
            raise ValueError(msg)

        self._wavelength = jnp.asarray(wavelength.value, dtype=float)

        # A pixel spacing needs at least two samples; fail here with a
        # message that names the offending argument.
        if self._wavelength.size < 2:
            msg = (
                'wavelength must contain at least 2 pixels to compute '
                f'dlam/dpix, got {self._wavelength.size}.'
            )
            raise ValueError(msg)

        # R() and dlam_dpix() interpolate with jnp.interp, which silently
        # returns meaningless values for unsorted sample points. Reject any
        # grid with a decreasing step instead of producing wrong numbers.
        if bool(jnp.any(jnp.diff(self._wavelength) < 0)):
            msg = (
                'wavelength must be sorted in increasing order '
                '(required for interpolation with jnp.interp).'
            )
            raise ValueError(msg)

        # Compute dlam_dpix from the pixel grid.
        # jnp.gradient returns Array | list[Array]; cast to Array for type checker.
        self._dlam_dpix_grid: Array = jnp.asarray(jnp.gradient(self._wavelength))

        # Compute resolving power on the grid.
        if R is not None:
            self._R_grid = self._validated_input(R, 'R')
        elif dlam is not None:
            dlam = _ensure_wavelength(dlam, 'dlam', ndim=[0, 1])
            # Express dlam in the grid's unit before taking the ratio.
            dlam_arr = self._validated_input(dlam.to(wavelength.unit).value, 'dlam')
            self._R_grid = self._wavelength / dlam_arr
        else:
            dvel = _ensure_velocity(dvel, 'dvel', ndim=[0, 1])
            # C_KMS is divided by dvel expressed in km/s.
            dvel_arr = self._validated_input(dvel.to(u.km / u.s).value, 'dvel')
            self._R_grid = C_KMS / dvel_arr

    def _validated_input(self, value: ArrayLike, name: str) -> Array:
        """Convert *value* to a grid-shaped array.

        Scalars (0-d) are broadcast to the grid shape. Any other input must
        have exactly the same shape as the stored wavelength grid; anything
        else raises `ValueError`.

        Parameters
        ----------
        value : ArrayLike
            Scalar or 1-d array supplied by the caller.
        name : str
            Parameter name used in error messages.

        Returns
        -------
        Array
            Array with the same shape as ``self._wavelength``.

        Raises
        ------
        ValueError
            If *value* is not 0-d and its shape differs from the grid shape.
        """
        arr = jnp.asarray(value, dtype=float)
        if arr.ndim == 0:
            return jnp.broadcast_to(arr, self._wavelength.shape)
        if arr.shape != self._wavelength.shape:
            msg = (
                f'{name} must be a scalar or have the same shape as '
                f'wavelength {self._wavelength.shape}, '
                f'got shape {arr.shape}.'
            )
            raise ValueError(msg)
        return arr

    def R(self, wavelength: ArrayLike) -> ArrayLike:
        """Return the resolving power, interpolated onto *wavelength*."""
        return jnp.interp(wavelength, self._wavelength, self._R_grid)

    def dlam_dpix(self, wavelength: ArrayLike) -> ArrayLike:
        """Return the linear dispersion, interpolated onto *wavelength*."""
        return jnp.interp(wavelength, self._wavelength, self._dlam_dpix_grid)