"""Generic disperser implementations.
These are the building blocks for custom instruments.
Import from this module directly::

    from unite.instrument.generic import GenericDisperser, SimpleDisperser

The :class:`~unite.spectrum.Spectrum` class lives in :mod:`unite.spectrum`.
"""
from __future__ import annotations
from collections.abc import Callable
import jax.numpy as jnp
from astropy import units as u
from jax import Array
from jax.typing import ArrayLike
from unite._utils import C_KMS, _ensure_velocity, _ensure_wavelength
from unite.instrument.base import Disperser, FluxScale, PixOffset, RScale
# ---------------------------------------------------------------------------
# Generic dispersers
# ---------------------------------------------------------------------------
[docs]
class GenericDisperser(Disperser):
    """Disperser whose behaviour is supplied entirely by user callables.

    Nothing is tabulated or interpolated here: the resolving power *R(λ)*
    and the linear dispersion *dλ/dpix(λ)* are obtained by calling the two
    functions handed to the constructor and returning their results
    unchanged. This makes it the most flexible concrete disperser.

    Parameters
    ----------
    R_func : Callable[[ArrayLike], ArrayLike]
        JAX-jittable function mapping an array of wavelengths to the
        resolving power at those wavelengths.
    dlam_dpix_func : Callable[[ArrayLike], ArrayLike]
        JAX-jittable function mapping an array of wavelengths to the
        linear dispersion at those wavelengths.
    unit : astropy.units.UnitBase
        Wavelength unit in which both functions expect their input.
    name : str, optional
        Human-readable label.
    r_scale : RScale, optional
        Token for the resolving-power scale.
    flux_scale : FluxScale, optional
        Token for the flux normalisation.
    pix_offset : PixOffset, optional
        Token for the pixel shift.

    Examples
    --------
    >>> import jax.numpy as jnp
    >>> from astropy import units as u
    >>> d = GenericDisperser(
    ...     R_func=lambda w: jnp.full_like(w, 2700.0),
    ...     dlam_dpix_func=lambda w: w / 2700.0,
    ...     unit=u.Angstrom,
    ... )
    >>> d.R(jnp.array([5000.0]))
    Array([2700.], dtype=float32)
    """

    def __init__(
        self,
        R_func: Callable[[ArrayLike], ArrayLike],  # noqa: N803
        dlam_dpix_func: Callable[[ArrayLike], ArrayLike],
        unit: u.UnitBase,
        *,
        name: str = '',
        r_scale: RScale | None = None,
        flux_scale: FluxScale | None = None,
        pix_offset: PixOffset | None = None,
    ) -> None:
        # The calibration tokens are passed straight through to the base
        # class; this subclass only adds the two callables.
        calibration_tokens = {
            'r_scale': r_scale,
            'flux_scale': flux_scale,
            'pix_offset': pix_offset,
        }
        super().__init__(unit, name=name, **calibration_tokens)

        self._R_func, self._dlam_dpix_func = R_func, dlam_dpix_func

    def R(self, wavelength: ArrayLike) -> ArrayLike:
        """Evaluate the user-supplied resolving-power function at *wavelength*."""
        resolving_power = self._R_func(wavelength)
        return resolving_power

    def dlam_dpix(self, wavelength: ArrayLike) -> ArrayLike:
        """Evaluate the user-supplied linear-dispersion function at *wavelength*."""
        dispersion = self._dlam_dpix_func(wavelength)
        return dispersion
[docs]
class SimpleDisperser(Disperser):
    """Disperser tabulated on a 1-D grid of pixel-center wavelengths.

    Each element of *wavelength* is taken to be the center of one pixel, so
    the linear dispersion *dλ/dpix* follows directly from the spacing of the
    grid (computed with ``jnp.gradient``).

    The resolving power on that grid is set by exactly **one** of three
    keyword arguments:

    * ``R`` — the resolving power itself.
    * ``dlam`` — the resolution element Δλ; converted via *R = λ / Δλ*.
    * ``dvel`` — the velocity resolution; converted to km/s and then via
      *R = c / dvel*.

    Whichever one is given may be a scalar, which is held constant across
    the grid, or a 1-D array with the same shape as *wavelength*.

    Parameters
    ----------
    wavelength : u.Quantity
        Pixel-center wavelengths. Must be 1-D.
    R : ArrayLike, optional
        Resolving power (scalar or per-pixel array).
    dlam : u.Quantity, optional
        Spectral resolution Δλ in wavelength units (scalar or per-pixel
        array). It is converted to the unit of *wavelength*.
    dvel : u.Quantity, optional
        Velocity resolution in velocity units (scalar or per-pixel array).
    name : str, optional
        Human-readable label.
    r_scale : RScale, optional
        Token for the resolving-power scale.
    flux_scale : FluxScale, optional
        Token for the flux normalisation.
    pix_offset : PixOffset, optional
        Token for the pixel shift.

    Raises
    ------
    ValueError
        If none, or more than one, of ``R``, ``dlam``, ``dvel`` is given,
        or if an array argument does not match the shape of *wavelength*.

    Notes
    -----
    :meth:`R` and :meth:`dlam_dpix` linearly interpolate the stored grid
    values onto the requested wavelengths with ``jnp.interp``.
    """

    def __init__(
        self,
        wavelength: u.Quantity,
        *,
        R: ArrayLike | None = None,  # noqa: N803
        dlam: u.Quantity | None = None,
        dvel: u.Quantity | None = None,
        name: str = '',
        r_scale: RScale | None = None,
        flux_scale: FluxScale | None = None,
        pix_offset: PixOffset | None = None,
    ) -> None:
        wavelength = _ensure_wavelength(wavelength, 'wavelength', ndim=1)
        super().__init__(
            wavelength.unit,
            name=name,
            r_scale=r_scale,
            flux_scale=flux_scale,
            pix_offset=pix_offset,
        )

        # The three resolution specifications are mutually exclusive.
        supplied = [arg for arg in (R, dlam, dvel) if arg is not None]
        n_specified = len(supplied)
        if n_specified != 1:
            raise ValueError(
                f'Exactly one of R, dlam, or dvel must be provided, but {n_specified} were given.'
            )

        grid = jnp.asarray(wavelength.value, dtype=float)
        self._wavelength = grid

        # jnp.gradient is typed as Array | list[Array]; wrapping it in
        # jnp.asarray pins the attribute to a single Array for type checkers.
        self._dlam_dpix_grid: Array = jnp.asarray(jnp.gradient(grid))

        # Exactly one of the three is non-None here, so the order in which
        # the branches are tested does not matter.
        if dvel is not None:
            dvel = _ensure_velocity(dvel, 'dvel', ndim=[0, 1])
            fwhm_kms = self._validated_input(dvel.to(u.km / u.s).value, 'dvel')
            self._R_grid = C_KMS / fwhm_kms
        elif dlam is not None:
            dlam = _ensure_wavelength(dlam, 'dlam', ndim=[0, 1])
            delta_lambda = self._validated_input(
                dlam.to(wavelength.unit).value, 'dlam'
            )
            self._R_grid = grid / delta_lambda
        else:
            self._R_grid = self._validated_input(R, 'R')

    def _validated_input(self, value: ArrayLike, name: str) -> jnp.ndarray:
        """Return *value* as a float array shaped like the wavelength grid.

        A 0-d input is broadcast to the grid shape. Any other input must
        already have exactly the grid shape.

        Parameters
        ----------
        value : ArrayLike
            Scalar or array supplied by the caller.
        name : str
            Parameter name used in the error message.

        Returns
        -------
        jnp.ndarray
            Array with the same shape as ``self._wavelength``.

        Raises
        ------
        ValueError
            If *value* is not 0-d and its shape differs from the grid shape.
        """
        grid_shape = self._wavelength.shape
        candidate = jnp.asarray(value, dtype=float)

        if candidate.ndim == 0:
            return jnp.broadcast_to(candidate, grid_shape)
        if candidate.shape == grid_shape:
            return candidate

        raise ValueError(
            f'{name} must be a scalar or have the same shape as '
            f'wavelength {grid_shape}, '
            f'got shape {candidate.shape}.'
        )

    def R(self, wavelength: ArrayLike) -> ArrayLike:
        """Linearly interpolate the stored resolving power onto *wavelength*."""
        grid, values = self._wavelength, self._R_grid
        return jnp.interp(wavelength, grid, values)

    def dlam_dpix(self, wavelength: ArrayLike) -> ArrayLike:
        """Linearly interpolate the stored linear dispersion onto *wavelength*."""
        grid, values = self._wavelength, self._dlam_dpix_grid
        return jnp.interp(wavelength, grid, values)