"""Instrument configuration: disperser models with calibration tokens.
:class:`InstrumentConfig` collects one :class:`~unite.instrument.base.Disperser`
per observing configuration. Each disperser carries optional
:class:`~unite.instrument.base.RScale` / :class:`~unite.instrument.base.FluxScale` /
:class:`~unite.instrument.base.PixOffset` tokens (``r_scale``, ``flux_scale``,
``pix_offset``) for shared calibration parameters.
Sharing tokens
--------------
Pass the **same** token instance to multiple dispersers to share a single
parameter in the fitted model::

    from unite.instrument import InstrumentConfig, RScale
    from unite.instrument.nirspec import G235H, G395H
    from unite.prior import TruncatedNormal

    r = RScale(prior=TruncatedNormal(1.0, 0.05, 0.8, 1.2))
    cfg = InstrumentConfig([
        G235H(r_scale=r),
        G395H(r_scale=r),  # same r — shared parameter
    ])

Degeneracy warning
------------------
A multi-disperser fit is only identified if at least one disperser has
``flux_scale=None`` (flux anchor) and at least one has ``pix_offset=None``
(pixel-offset anchor). :meth:`validate` issues :class:`UserWarning` when
these conditions are not met. Validation is called automatically when this
object is passed to :class:`~unite.config.Configuration`.
Serialization
-------------
:meth:`to_dict` hoists all unique calibration tokens to a top-level
``calib_params`` section keyed by token name. Disperser entries reference
tokens by name. Shared tokens round-trip correctly — the same object is
reconstructed for both entries.
Examples
--------
>>> from unite.instrument import InstrumentConfig, RScale, FluxScale
>>> from unite.instrument.nirspec import G235H, G395H
>>> from unite.prior import TruncatedNormal
>>> r = RScale(prior=TruncatedNormal(1.0, 0.05, 0.8, 1.2))
>>> flux_0 = FluxScale(prior=TruncatedNormal(1.0, 0.1, 0.5, 2.0))
>>> cfg = InstrumentConfig([
... G235H(r_scale=r),
... G395H(r_scale=r, flux_scale=flux_0),
... ])
>>> cfg.names
['G235H', 'G395H']
"""
from __future__ import annotations
import warnings
from collections.abc import Iterator, Sequence
from pathlib import Path
import yaml
from unite._utils import _alpha_name
from unite.instrument.base import Disperser, FluxScale, PixOffset, RScale
from unite.instrument.nirspec.disperser import (
G140H,
G140M,
G235H,
G235M,
G395H,
G395M,
PRISM,
NIRSpec,
)
from unite.instrument.sdss.disperser import SDSSDisperser
from unite.prior import Parameter, prior_from_dict
# ---------------------------------------------------------------------------
# Disperser serialization registry
# ---------------------------------------------------------------------------
# Maps the serialized ``'type'`` string of a disperser entry to the class used
# to rebuild it.  ``_disperser_to_entry`` writes ``type(disperser).__name__``
# and refuses any class name that is not a key here; ``_disperser_from_entry``
# looks the class up by the same key.
_DISPERSER_REGISTRY: dict[str, type[Disperser]] = {
    'G140H': G140H,
    'G140M': G140M,
    'G235H': G235H,
    'G235M': G235M,
    'G395H': G395H,
    'G395M': G395M,
    'PRISM': PRISM,
    'SDSSDisperser': SDSSDisperser,
}
# Maps the serialized ``'type'`` string of a calibration token to its class.
# ``_calib_param_to_dict`` writes ``type(token).__name__`` and
# ``_calib_param_from_dict`` resolves it through this table.
_CALIB_REGISTRY: dict[str, type[Parameter]] = {
    'RScale': RScale,
    'FluxScale': FluxScale,
    'PixOffset': PixOffset,
}
# ---------------------------------------------------------------------------
# CalibParam (de)serialization helpers
# ---------------------------------------------------------------------------
def _calib_param_to_dict(token: Parameter) -> dict:
    """Turn a calibration token into a plain, YAML-friendly mapping.

    Parameters
    ----------
    token : Parameter
        Calibration token; its class name and its prior are recorded.

    Returns
    -------
    dict
        ``{'type': <class name of the token>, 'prior': token.prior.to_dict()}``.
    """
    kind = type(token).__name__
    prior_spec = token.prior.to_dict()
    return {'type': kind, 'prior': prior_spec}
def _calib_param_from_dict(name: str, d: dict) -> Parameter:
    """Reconstruct a calibration token from its serialized dict.

    Parameters
    ----------
    name : str
        Name the token was serialized under; assigned to both ``name`` and
        ``label`` of the rebuilt token.
    d : dict
        Mapping with a ``'type'`` key (a key of ``_CALIB_REGISTRY``) and a
        ``'prior'`` key (passed to ``prior_from_dict``).

    Returns
    -------
    Parameter
        A new token of the registered class.

    Raises
    ------
    KeyError
        If ``'type'`` or ``'prior'`` is missing, or the type is not registered.
    """
    kind = d['type']
    if kind not in _CALIB_REGISTRY:
        # Same error style as the disperser loader: say what was asked for and
        # what is available instead of surfacing a bare KeyError.
        msg = f'Unknown calibration token type {kind!r}. Registered: {sorted(_CALIB_REGISTRY)}.'
        raise KeyError(msg)
    cls = _CALIB_REGISTRY[kind]
    prior = prior_from_dict(d['prior'])
    tok = cls(prior=prior)
    tok.name = name  # set finalized site name directly to avoid re-prefixing
    tok.label = name  # keep label consistent with name
    return tok
# ---------------------------------------------------------------------------
# Disperser (de)serialization helpers
# ---------------------------------------------------------------------------
def _disperser_to_entry(disperser: Disperser) -> dict:
    """Build the serialized entry for one disperser.

    Calibration tokens are not embedded; each of ``r_scale``, ``flux_scale``
    and ``pix_offset`` is stored as the token's ``name``, or ``None`` when the
    disperser has no token in that slot.

    Parameters
    ----------
    disperser : Disperser
        Its class name must be a key of ``_DISPERSER_REGISTRY``.

    Returns
    -------
    dict
        Keys, in order: ``type``, ``name``, then ``grating`` and ``r_source``
        for :class:`NIRSpec` instances, then the three token references.

    Raises
    ------
    TypeError
        If the disperser's class name is not registered.
    """
    kind = type(disperser).__name__
    if kind not in _DISPERSER_REGISTRY:
        raise TypeError(
            f'Cannot serialize disperser of type {kind!r}. '
            f'Only registered dispersers support serialization. '
            f'Registered types: {sorted(_DISPERSER_REGISTRY)}.'
        )

    entry: dict = {'type': kind, 'name': disperser.name}

    # NIRSpec-derived dispersers additionally record grating and r_source.
    if isinstance(disperser, NIRSpec):
        entry.update(grating=disperser.grating, r_source=disperser.r_source)

    # Tokens are referenced by name; a missing token is written as null.
    for slot in ('r_scale', 'flux_scale', 'pix_offset'):
        calib = getattr(disperser, slot)
        entry[slot] = None if calib is None else calib.name

    return entry
def _disperser_from_entry(d: dict, token_registry: dict[str, Parameter]) -> Disperser:
    """Rebuild a disperser from a serialized entry.

    Parameters
    ----------
    d : dict
        Entry in the layout written by :func:`_disperser_to_entry`.
    token_registry : dict
        Already-reconstructed calibration tokens keyed by name; token
        references in ``d`` are resolved against it.

    Returns
    -------
    Disperser
        Instance of the registered class, built from keyword arguments.

    Raises
    ------
    KeyError
        If ``'type'`` is missing or not registered, or a referenced token name
        is absent from ``token_registry``.
    """
    kind = d['type']
    if kind not in _DISPERSER_REGISTRY:
        raise KeyError(f'Unknown disperser type {kind!r}. Registered: {sorted(_DISPERSER_REGISTRY)}.')
    factory = _DISPERSER_REGISTRY[kind]

    init_args: dict = {'name': d.get('name', '')}

    if kind == 'NIRSpec':
        # NOTE(review): 'NIRSpec' is not a key of the registry as defined in
        # this file, so this branch looks unreachable unless the registry is
        # extended elsewhere — confirm before relying on it.
        init_args['grating'] = d['grating']
        init_args['r_source'] = d.get('r_source', 'point')
    elif 'r_source' in d:
        init_args['r_source'] = d['r_source']

    # A null reference means the slot carries no token.
    for slot in ('r_scale', 'flux_scale', 'pix_offset'):
        ref = d.get(slot)
        init_args[slot] = None if ref is None else token_registry[ref]

    return factory(**init_args)
# ---------------------------------------------------------------------------
# Degeneracy warnings
# ---------------------------------------------------------------------------
# Message emitted by InstrumentConfig.validate() when every disperser carries a
# flux_scale token, i.e. nothing anchors the flux calibration.
_FLUX_DEGENERACY_WARNING = ' '.join(
    (
        'InstrumentConfig: no disperser has flux_scale=None (fixed).',
        'Relative flux scales are degenerate — set flux_scale=None on one disperser to anchor the flux calibration.',
    )
)
# Message emitted by InstrumentConfig.validate() when every disperser carries a
# pix_offset token, i.e. nothing anchors the wavelength solution.
_PIX_DEGENERACY_WARNING = ' '.join(
    (
        'InstrumentConfig: no disperser has pix_offset=None (fixed).',
        'Pixel-offset (dispersion) scales are degenerate — set pix_offset=None on one disperser to anchor the wavelength solution.',
    )
)
# ---------------------------------------------------------------------------
# InstrumentConfig
# ---------------------------------------------------------------------------
class InstrumentConfig:
    """Configuration for a multi-disperser spectral dataset.

    An ordered collection of :class:`~unite.instrument.base.Disperser` objects,
    one per observing disperser.  Each disperser may carry calibration tokens
    in its ``r_scale``, ``flux_scale`` and ``pix_offset`` attributes; a slot
    set to ``None`` is treated as fixed.

    The configuration is **data-free**: it only describes which dispersers are
    used and how they are calibrated.

    Parameters
    ----------
    dispersers : sequence of Disperser
        One entry per disperser.  Names (``disperser.name``) must be unique
        and non-empty.  Any iterable is accepted; it is materialized once.

    Raises
    ------
    ValueError
        If any disperser names are duplicated or empty.
    """

    # Calibration-token attributes inspected on every disperser, in the order
    # used for naming, serialization and display.
    _CALIB_SLOTS = ('r_scale', 'flux_scale', 'pix_offset')

    def __init__(self, dispersers: Sequence[Disperser]) -> None:
        # Materialize first: validating names and then storing the dispersers
        # would otherwise iterate the argument twice and silently yield an
        # empty configuration for a one-shot iterable.
        items: list[Disperser] = list(dispersers)
        names = [d.name for d in items]

        empty = [i for i, n in enumerate(names) if not n]
        if empty:
            msg = f'Dispersers at indices {empty} have empty names. Set a non-empty name on each disperser.'
            raise ValueError(msg)

        # Single pass duplicate detection (instead of names.count per name).
        seen_names: set = set()
        repeated: set = set()
        for n in names:
            (repeated if n in seen_names else seen_names).add(n)
        dupes = sorted(repeated)
        if dupes:
            msg = f'Duplicate disperser name(s) in InstrumentConfig: {dupes}'
            raise ValueError(msg)

        self._dispersers: list[Disperser] = items
        self._assign_calib_names()

    # -- calibration token naming --------------------------------------------

    def _assign_calib_names(self) -> None:
        """Assign names to calibration tokens that do not have one yet.

        Only tokens whose ``_name`` is ``None`` are considered, and each token
        instance is handled once (tracked by ``id``):

        * a token whose ``label`` is already set is left untouched;
        * a token that appears on exactly one disperser with a non-empty name
          gets ``name = '{slot}_{disperser.name}'`` and
          ``label = disperser.name``, where ``slot`` is the attribute it was
          first found in (``r_scale``, ``flux_scale`` or ``pix_offset``);
        * any other token gets ``name = '{slot}_{alpha}'`` and
          ``label = alpha``, where ``alpha = _alpha_name(idx)`` and ``idx`` is
          a per-slot counter starting at 0.
        """
        slots = self._CALIB_SLOTS

        # Pre-compute which dispersers each anonymous token appears on, so the
        # second pass can tell shared tokens from unshared ones.
        tok_dispersers: dict[int, list] = {}
        for d in self._dispersers:
            for slot in slots:
                tok = getattr(d, slot)
                if tok is not None and tok._name is None:
                    tok_dispersers.setdefault(id(tok), []).append(d)

        counters: dict[str, int] = {}  # slot -> next alpha index
        seen: set[int] = set()  # id(tok) already handled
        for d in self._dispersers:
            for slot in slots:
                tok = getattr(d, slot)
                if tok is None or tok._name is not None or id(tok) in seen:
                    continue
                seen.add(id(tok))

                if tok.label is not None:
                    # A label is already present: leave the token as it is.
                    continue

                if len(tok_dispersers.get(id(tok), [])) == 1 and d.name:
                    # Unshared token on a single named disperser: use its name.
                    tok.name = f'{slot}_{d.name}'
                    tok.label = d.name
                else:
                    # Shared across dispersers (or unnamed disperser): counter.
                    idx = counters.get(slot, 0)
                    counters[slot] = idx + 1
                    label = _alpha_name(idx)
                    tok.name = f'{slot}_{label}'
                    tok.label = label

    # -- validation ----------------------------------------------------------

    def validate(self) -> None:
        """Warn about flux and pixel-offset degeneracies.

        With more than one disperser, a :class:`UserWarning` is issued for each
        calibration axis on which *every* disperser carries a token (none has
        ``flux_scale=None`` / ``pix_offset=None``).  Configurations with zero
        or one disperser are never warned about.
        """
        if len(self._dispersers) > 1:
            if all(d.flux_scale is not None for d in self._dispersers):
                warnings.warn(_FLUX_DEGENERACY_WARNING, UserWarning, stacklevel=2)
            if all(d.pix_offset is not None for d in self._dispersers):
                warnings.warn(_PIX_DEGENERACY_WARNING, UserWarning, stacklevel=2)

    # -- names ---------------------------------------------------------------

    @property
    def names(self) -> list[str]:
        """Names of all dispersers in this configuration, in order."""
        return [d.name for d in self._dispersers]

    # -- container interface -------------------------------------------------

    def __len__(self) -> int:
        """Return the number of dispersers."""
        return len(self._dispersers)

    def __iter__(self) -> Iterator[Disperser]:
        """Iterate over the dispersers in order."""
        return iter(self._dispersers)

    def __getitem__(self, key: int | str) -> Disperser:
        """Look up a disperser by index or name.

        Raises
        ------
        KeyError
            If *key* is a string and no disperser has that name.
        """
        if isinstance(key, str):
            for d in self._dispersers:
                if d.name == key:
                    return d
            msg = f'No disperser named {key!r} in InstrumentConfig.'
            raise KeyError(msg)
        return self._dispersers[key]

    # -- serialization -------------------------------------------------------

    def to_dict(self) -> dict:
        """Serialize to a YAML-safe dictionary.

        Every distinct token instance is hoisted into a top-level
        ``calib_params`` mapping keyed by token name; disperser entries
        reference tokens by name, so sharing survives a round-trip.

        Returns
        -------
        dict
            Contains ``'calib_params'`` and ``'entries'`` keys.

        Warns
        -----
        UserWarning
            If two *distinct* token instances have the same name.  Only one of
            them can be stored under that key, so they would come back as a
            single shared token on reload.
        """
        # Collect all unique tokens (preserving first-seen order).
        unique: dict[int, Parameter] = {}
        for d in self._dispersers:
            for attr in self._CALIB_SLOTS:
                token = getattr(d, attr)
                if token is not None and id(token) not in unique:
                    unique[id(token)] = token

        calib_params: dict = {}
        for token in unique.values():
            if token.name in calib_params:
                # Surface the collision instead of silently dropping a token;
                # the resulting dict is the same as before (last one wins).
                warnings.warn(
                    f'InstrumentConfig: multiple distinct calibration tokens share the name '
                    f'{token.name!r}; only one is serialized and they will be merged into a '
                    f'single shared token on reload.',
                    UserWarning,
                    stacklevel=2,
                )
            calib_params[token.name] = _calib_param_to_dict(token)

        entries = [_disperser_to_entry(d) for d in self._dispersers]
        return {'calib_params': calib_params, 'entries': entries}

    @classmethod
    def from_dict(cls, d: dict) -> InstrumentConfig:
        """Deserialize from a dictionary.

        Parameters
        ----------
        d : dict
            As produced by :meth:`to_dict`.  ``'calib_params'`` may be missing
            or null; ``'entries'`` is required.

        Returns
        -------
        InstrumentConfig
        """
        # Reconstruct token registry (name -> token).  ``or {}`` also covers a
        # YAML file in which ``calib_params:`` was left empty (loads as None).
        token_registry: dict[str, Parameter] = {}
        for name, token_d in (d.get('calib_params') or {}).items():
            token_registry[name] = _calib_param_from_dict(name, token_d)

        dispersers = [_disperser_from_entry(e, token_registry) for e in d['entries']]

        # Bypass __init__ on load: its name checks and token naming are not
        # repeated, on the premise that the configuration was already checked
        # when it was saved.
        obj = cls.__new__(cls)
        obj._dispersers = dispersers
        return obj

    # -- YAML serialization --------------------------------------------------

    def to_yaml(self) -> str:
        """Serialize to a YAML string (block style, keys in insertion order).

        Returns
        -------
        str
        """
        return yaml.dump(self.to_dict(), default_flow_style=False, sort_keys=False)

    @classmethod
    def from_yaml(cls, text: str) -> InstrumentConfig:
        """Deserialize from a YAML string.

        Parameters
        ----------
        text : str
            YAML string as produced by :meth:`to_yaml`; parsed with
            ``yaml.safe_load``.

        Returns
        -------
        InstrumentConfig
        """
        return cls.from_dict(yaml.safe_load(text))

    # -- File I/O ------------------------------------------------------------

    def save(self, path: str | Path) -> None:
        """Save to a YAML file.

        Parameters
        ----------
        path : str or Path
            Output file path.  Written as UTF-8 regardless of locale.
        """
        Path(path).write_text(self.to_yaml(), encoding='utf-8')

    @classmethod
    def load(cls, path: str | Path) -> InstrumentConfig:
        """Load from a YAML file.

        Parameters
        ----------
        path : str or Path
            Path to a YAML file written by :meth:`save`.  Read as UTF-8
            regardless of locale.

        Returns
        -------
        InstrumentConfig
        """
        return cls.from_yaml(Path(path).read_text(encoding='utf-8'))

    # -- addition ------------------------------------------------------------

    def __add__(self, other: InstrumentConfig) -> InstrumentConfig:
        """Combine two configurations (strict mode — raises on duplicate names).

        Parameters
        ----------
        other : InstrumentConfig

        Returns
        -------
        InstrumentConfig
            New configuration containing dispersers from both *self* and
            *other*, in that order.

        Raises
        ------
        ValueError
            If any disperser name appears in both configurations.
        TypeError
            Raised by Python when *other* is not an :class:`InstrumentConfig`
            (this method returns ``NotImplemented`` in that case).
        """
        if not isinstance(other, InstrumentConfig):
            return NotImplemented
        self_names = set(self.names)
        collisions = [n for n in other.names if n in self_names]
        if collisions:
            msg = f'Duplicate disperser name(s) in InstrumentConfig addition: {collisions}'
            raise ValueError(msg)
        return InstrumentConfig(list(self._dispersers) + list(other._dispersers))

    # -- repr ----------------------------------------------------------------

    def __repr__(self) -> str:
        """Return a multi-line summary of this configuration.

        The summary is a header, a table with one row per disperser (name,
        ``repr`` and the token name in each calibration slot), and one section
        per slot listing each distinct token with the ``repr`` of its prior.
        """
        if not self._dispersers:
            return 'InstrumentConfig: empty'

        header = f'InstrumentConfig: {len(self._dispersers)} disperser(s)'

        def _tok_name(token: Parameter | None) -> str:
            if token is None:
                return '— (fixed)'
            return token.name or '—'

        rows = [
            (
                d.name,
                repr(d),
                _tok_name(d.r_scale),
                _tok_name(d.flux_scale),
                _tok_name(d.pix_offset),
            )
            for d in self._dispersers
        ]
        col_headers = ('Name', 'Disperser', 'r_scale', 'flux_scale', 'pix_offset')

        # Each column is as wide as its widest cell (header included).
        widths = [len(h) for h in col_headers]
        for row in rows:
            for i, cell in enumerate(row):
                widths[i] = max(widths[i], len(cell))

        fmt = ' '.join(f'{{:<{w}}}' for w in widths)
        sep = ' '.join('-' * w for w in widths)
        lines = [header, '', ' ' + fmt.format(*col_headers), ' ' + sep]
        lines.extend(' ' + fmt.format(*row) for row in rows)

        # Calibration parameter details: distinct tokens per slot, in the
        # order they are first met while walking the dispersers.
        calib_params: dict[str, list[tuple[str, Parameter]]] = {slot: [] for slot in self._CALIB_SLOTS}
        seen_ids: dict[str, set[int]] = {slot: set() for slot in self._CALIB_SLOTS}
        for d in self._dispersers:
            for attr in self._CALIB_SLOTS:
                token = getattr(d, attr)
                if token is None or id(token) in seen_ids[attr]:
                    continue
                seen_ids[attr].add(id(token))
                # Same fallback as the table, so a token without a name cannot
                # make repr() raise on len()/format below.
                calib_params[attr].append((_tok_name(token), token))

        for attr, params in calib_params.items():
            if not params:
                continue
            lines.append('')
            section_title = attr.replace('_', ' ').title()
            lines.append(f' {section_title}:')
            name_width = max(len(name) for name, _ in params)
            for name, token in params:
                lines.append(f' {name:<{name_width}} {token.prior!r}')

        return '\n'.join(lines)