# Source code for unite.compute

"""Model evaluator: decompose posterior predictions into per-line and per-region contributions.

Given posterior samples and :class:`~unite.model.ModelArgs`, this module
reconstructs the full model prediction for each spectrum, broken down into
individual line and continuum-region contributions in **original flux units**.
"""

from __future__ import annotations

from dataclasses import dataclass

import jax
import jax.numpy as jnp
import numpy as np

from unite._compose import compose_leave_one_out
from unite._lsf import _FWHM_TO_SIGMA, _lsf_convolve
from unite.continuum.compute import eval_continuum_regions
from unite.line.compute import (
    _build_line_params,
    _peak_to_area_tau,
    evaluate_lines,
    integrate_lines,
)
from unite.model import ModelArgs
from unite.prior import Fixed


@dataclass
class SpectrumPrediction:
    """Decomposed model prediction for a single spectrum.

    All arrays are in original (un-normalized) flux units.

    For **emission lines**, each entry in :attr:`lines` is the intrinsic
    (un-attenuated) flux profile: ``flux * profile``. Summing all line
    contributions and continuum regions always reconstructs :attr:`total`
    regardless of zorder configuration.

    For **absorption lines**, each entry in :attr:`lines` is the flux
    *removed* by that absorber (negative): ``total - total_without_j``.
    """

    #: Pixel-center wavelengths in the disperser's unit. Shape ``(n_pixels,)``.
    wavelength: np.ndarray

    #: Total model flux (lines + continuum). Shape ``(n_samples, n_pixels)``.
    total: np.ndarray

    #: Per-line contributions keyed by informative line labels
    #: (e.g. ``'Ha'``, ``'[NII]_6585'``).
    #: For emission lines: intrinsic (un-attenuated) flux profile (positive).
    #: For absorption lines: flux removed by the absorber (negative).
    #: Shape ``(n_samples, n_pixels)`` each.
    lines: dict[str, np.ndarray]

    #: Per-continuum-region contributions keyed by informative region labels
    #: (e.g. ``'linear_6400_6700'``, ``'powerlaw_0.95_2.5'``).
    #: Shape ``(n_samples, n_pixels)`` each.
    continuum_regions: dict[str, np.ndarray]

    #: LSF-convolved optical depth profiles ``tau_j * phi_j(λ)`` for absorption
    #: lines, evaluated at pixel midpoints. Dimensionless and non-negative.
    #: Keyed by line label; empty dict for emission-only models.
    #: Shape ``(n_samples, n_pixels)`` each.
    tau_profiles: dict[str, np.ndarray]
def evaluate_model(
    samples: dict[str, np.ndarray], args: ModelArgs
) -> list[SpectrumPrediction]:
    """Evaluate the model for each posterior sample and decompose contributions.

    Uses :func:`jax.vmap` to evaluate all samples in a single vectorised XLA
    kernel launch rather than a Python loop, giving a large speed-up when the
    number of posterior samples is large.

    Parameters
    ----------
    samples : dict of str to ndarray
        Posterior samples as returned by ``mcmc.get_samples()`` or
        ``Predictive``. Each value has shape ``(n_samples,)`` or
        ``(n_samples, ...)``.
    args : ModelArgs
        Pre-built data bundle from :meth:`ModelBuilder.build`.

    Returns
    -------
    list of SpectrumPrediction
        One prediction per spectrum in ``args.spectra``.

    Raises
    ------
    KeyError
        If a non-``Fixed`` parameter listed in ``args.dependency_order`` is
        missing from ``samples``.
    ValueError
        If ``args.integration_mode == 'convolution'`` but ``args.n_super`` or
        ``args.conv_half_width`` is ``None``.

    Notes
    -----
    The JIT-compiled, vmapped per-spectrum evaluators are cached on
    ``args._evaluators`` (one slot per spectrum), so this function mutates
    ``args`` on first use.
    """
    cm = args.matrices
    z_sys = args.redshift
    n_lines = cm.wavelengths.shape[0]
    has_tau = bool(cm.tau_names)

    # Resolve per-config dispatch (specialized to used profiles, or module-level fallback).
    _pcodes = (
        args._profile_codes_local
        if args._profile_codes_local is not None
        else cm.profile_codes
    )
    _int_fn = args._integrate_fn if args._integrate_fn is not None else integrate_lines
    _eval_fn = args._evaluate_fn if args._evaluate_fn is not None else evaluate_lines

    # --- Build parameter dict with a uniform (n_samples,) leading axis ---
    context: dict[str, jnp.ndarray] = {}
    n_samples = None
    for pname in args.dependency_order:
        prior = args.all_priors[pname]
        if isinstance(prior, Fixed):
            context[pname] = jnp.asarray(prior.value)
        else:
            arr = jnp.asarray(samples[pname])
            context[pname] = arr
            # The first sampled parameter with a leading axis defines n_samples.
            if n_samples is None and arr.ndim >= 1:
                n_samples = arr.shape[0]
    if n_samples is None:
        n_samples = 1

    # Broadcast Fixed (scalar) params to (n_samples,) so every leaf has the
    # same leading axis and vmap can map uniformly over the dict pytree.
    context = {
        k: (jnp.broadcast_to(v, (n_samples,)) if v.ndim == 0 else v)
        for k, v in context.items()
    }

    # --- Per-spectrum evaluation (vectorised over samples) ---
    # Lazily build and JIT-compile per-spectrum vmapped evaluators on first call.
    # This eliminates Python-level JAX retracing on every subsequent evaluate_model
    # call with the same ModelArgs (e.g. make_spectra_tables + make_parameter_table).
    if args._evaluators is None:
        args._evaluators = [None] * len(args.spectra)

    results: list[SpectrumPrediction] = []
    for i, spectrum in enumerate(args.spectra):
        disp = spectrum.disperser
        wl_scale = args.spec_to_canonical[i]
        inv_wl_scale = 1.0 / wl_scale
        wl_out = np.asarray(spectrum.wavelength)

        # Static arrays: do not depend on sample values.
        edges_base = jnp.asarray(spectrum.edges * wl_scale)
        edges_disp_base = jnp.asarray(spectrum.edges)
        keep_mask = jnp.asarray(spectrum.keep_mask)
        if disp.pix_offset is not None:
            dlam_edges = disp.dlam_dpix(edges_disp_base) * wl_scale  # (E,)
        else:
            dlam_edges = None

        line_scale = float(args.line_flux_scales[i])
        cont_scale = float(args.continuum_scales[i])

        # Keyword-only defaults bind the per-iteration values at definition
        # time, avoiding late-binding closure issues (ruff B023) while
        # allowing jax.vmap to vmap only over the positional `params` arg.
        def _single(
            params,
            *,
            _edges=edges_base,
            _keep=keep_mask,
            _dlam_edges=dlam_edges,
            _disp=disp,
            _inv_wl_scale=inv_wl_scale,
            _line_scale=line_scale,
            _cont_scale=cont_scale,
            _ifn=_int_fn,
            _efn=_eval_fn,
            _pc=_pcodes,
        ):
            """Evaluate one posterior sample.

            ``params`` is a dict of 0-D arrays.
            """
            # --- Line parameters ---
            flux_per_line, tau_per_line, centers, p0, p1, p2 = _build_line_params(
                cm, params, n_lines, z_sys
            )

            # --- Convert peak-tau to area-tau ---
            if cm.tau_names:
                tau_per_line = _peak_to_area_tau(
                    tau_per_line,
                    centers,
                    p0,
                    p1,
                    p2,
                    _pc,
                    cm.is_tau,
                    _eval_fn=args._evaluate_at_centers_fn,
                )

            # --- Calibration ---
            r_scale = params[_disp.r_scale.name] if _disp.r_scale is not None else 1.0
            flux_scale_val = (
                params[_disp.flux_scale.name] if _disp.flux_scale is not None else 1.0
            )
            pix_offset = (
                params[_disp.pix_offset.name] if _disp.pix_offset is not None else 0.0
            )

            # --- Edge topology (apply sub-pixel offset if present) ---
            edges = _edges
            if _dlam_edges is not None:
                edges = edges + pix_offset * _dlam_edges
            widths = jnp.diff(edges)

            # Per-pixel low/high for diagnostics and convolution mode.
            low = edges[:-1][_keep]
            high = edges[1:][_keep]
            wavelength = 0.5 * (low + high)

            # --- LSF ---
            # Per-line LSF (used for tau diagnostic at midpoints).
            lsf_fwhm = centers / (_disp.R(centers * _inv_wl_scale) * r_scale)
            # Per-edge LSF (shared across lines, used by analytic mode).
            edges_disp = edges * _inv_wl_scale
            lsf_at_edges = edges / (_disp.R(edges_disp) * r_scale)
            # LSF FWHM at pixel centres; only consumed by the continuum
            # evaluation in the non-convolution (analytic) branch below.
            cont_lsf_fwhm = wavelength / (_disp.R(wavelength * _inv_wl_scale) * r_scale)

            # Scaled line fluxes for this spectrum.
            scaled_flux = flux_per_line * _line_scale

            # --- Optical depth profiles at pixel midpoints (mode-independent) ---
            # For emission-only models has_tau is a Python False at trace time, so
            # the entire evaluate_lines call is compiled out of the XLA program.
            if has_tau:
                phi_mid = _efn(wavelength, centers, lsf_fwhm, p0, p1, p2, _pc)
                tau_profiles_arr = jnp.where(
                    cm.is_tau[:, None], tau_per_line[:, None] * phi_mid, 0.0
                )
            else:
                tau_profiles_arr = jnp.zeros((n_lines, wavelength.shape[0]))

            # --- Continuum (per-region for decomposition) ---
            # For analytic: evaluate at pixel centres with LSF.
            # For convolution: evaluated on the fine grid inside the branch below.
            cont_total_scaled: jnp.ndarray = jnp.zeros_like(wavelength)
            cont_regions_scaled: list[jnp.ndarray] = []
            if args.integration_mode != 'convolution':
                cont_regions = eval_continuum_regions(
                    wavelength, args, params, z_sys, cont_lsf_fwhm
                )
                cont_total = jnp.zeros_like(wavelength)
                for region in cont_regions:
                    cont_total = cont_total + region
                cont_total_scaled = cont_total * _cont_scale
                cont_regions_scaled = [
                    r * _cont_scale * flux_scale_val for r in cont_regions
                ]

            # --- Line decomposition ---
            if args.integration_mode == 'convolution':
                # Numerical LSF convolution: evaluate intrinsic model (lsf_fwhm=0)
                # on a fine sub-pixel grid, convolve with the wavelength-dependent
                # Gaussian LSF, then pixel-average. Per-line decomposition uses
                # compose_leave_one_out on the fine grid; convolution is linear so
                # convolving deltas separately is exact.
                n_super = args.n_super
                half_width = args.conv_half_width
                # Explicit check instead of `assert`: asserts are stripped under
                # `python -O`, which would defer the failure to an opaque error
                # further down. Raised at trace time (plain Python values).
                if n_super is None or half_width is None:
                    raise ValueError(
                        "integration_mode='convolution' requires both "
                        'args.n_super and args.conv_half_width to be set'
                    )

                n_pixels = low.shape[0]

                # Fine grid in wavelength order: all sub-bins of pixel 0, then pixel 1, etc.
                # Shape: (n_pixels * n_super,).
                offsets = (jnp.arange(n_super) + 0.5) / n_super
                x_fine = (
                    low[None, :] + offsets[:, None] * (high - low)[None, :]
                )  # (n_super, n_pixels)
                x_flat = x_fine.T.ravel()  # (n_pixels * n_super,) in wavelength order

                # Intrinsic profiles (lsf_fwhm=0) on fine grid.
                zero_lsf = jnp.zeros_like(centers)
                phi_fine = _efn(x_flat, centers, zero_lsf, p0, p1, p2, _pc)

                # LSF sigma at each fine-grid point.
                sigma_fine = (
                    x_flat
                    / (_disp.R(x_flat * _inv_wl_scale) * r_scale)
                    * _FWHM_TO_SIGMA
                )

                # Per-region continuum on fine grid (lsf_fwhm=0): convolve on fine
                # grid then pixel-average.
                cont_regions_fine = eval_continuum_regions(
                    x_flat, args, params, z_sys, 0.0
                )
                cont_regions_scaled = [
                    _lsf_convolve(x_flat, r * _cont_scale, sigma_fine, half_width)
                    .reshape(n_pixels, n_super)
                    .mean(axis=1)
                    * flux_scale_val
                    for r in cont_regions_fine
                ]

                # Total continuum on fine grid for model composition.
                cont_fine_total = jnp.zeros_like(x_flat)
                for r in cont_regions_fine:
                    cont_fine_total = cont_fine_total + r
                cont_fine_scaled = cont_fine_total * _cont_scale

                # Leave-one-out decomposition on fine grid.
                total_fine, deltas_fine = compose_leave_one_out(
                    phi_fine,
                    scaled_flux,
                    tau_per_line,
                    cm.is_tau,
                    cm.applies_matrix,
                    args.cont_applies,
                    cont_fine_scaled,
                    has_tau=has_tau,
                )

                # Convolve on fine grid, then pixel-average.
                total_conv = _lsf_convolve(x_flat, total_fine, sigma_fine, half_width)
                total_pix = total_conv.reshape(n_pixels, n_super).mean(axis=1)

                deltas_conv = jax.vmap(
                    lambda d: _lsf_convolve(x_flat, d, sigma_fine, half_width)
                )(deltas_fine)  # (n_lines, n_pixels * n_super)
                deltas_pix = deltas_conv.reshape(-1, n_pixels, n_super).mean(axis=2)

                total = flux_scale_val * total_pix
                line_contribs = flux_scale_val * deltas_pix
                return total, line_contribs, cont_regions_scaled, tau_profiles_arr
            else:
                # Analytic: cumulative-at-edges per line, then diff + mask
                # → per-pixel-averaged profile.
                cum_per_line = _ifn(
                    edges, centers, lsf_at_edges, p0, p1, p2, _pc
                )  # (n_lines, E)
                per_interval = jnp.diff(cum_per_line, axis=1) / widths
                pixints = per_interval[:, _keep]

                total, line_contribs = compose_leave_one_out(
                    pixints,
                    scaled_flux,
                    tau_per_line,
                    cm.is_tau,
                    cm.applies_matrix,
                    args.cont_applies,
                    cont_total_scaled,
                    has_tau=has_tau,
                )

                # Apply flux_scale to total and per-line contributions.
                total = flux_scale_val * total
                line_contribs = flux_scale_val * line_contribs
                return total, line_contribs, cont_regions_scaled, tau_profiles_arr

        # Build JIT'd vmapped evaluator once per spectrum; reuse on subsequent calls.
        if args._evaluators[i] is None:
            args._evaluators[i] = jax.jit(jax.vmap(_single))

        total_arr, line_arr, cont_arr, tau_arr = args._evaluators[i](context)
        # total_arr: (n_samples, n_pix)
        # line_arr:  (n_samples, n_lines, n_pix)
        # cont_arr:  list of (n_samples, n_pix), one per continuum region
        # tau_arr:   (n_samples, n_lines, n_pix)

        # Transfer line and tau arrays to host once, then use zero-copy NumPy
        # views for per-line slicing. Individual np.asarray(jax_arr[:, j, :])
        # calls each dispatch a separate XLA slice → host copy.
        line_arr_np = np.asarray(line_arr)
        lines_dict: dict[str, np.ndarray] = {
            args.line_labels[j]: line_arr_np[:, j, :] for j in range(n_lines)
        }

        cont_dict: dict[str, np.ndarray] = {}
        if args.cont_config is not None:
            cont_dict = {
                args.continuum_labels[k]: np.asarray(cont_arr[k])
                for k in range(len(args.cont_config))
            }

        tau_dict: dict[str, np.ndarray] = {}
        if has_tau:
            # Only absorption lines (is_tau True) get an optical-depth entry.
            is_abs = np.asarray(cm.is_tau)
            tau_arr_np = np.asarray(tau_arr)
            tau_dict = {
                args.line_labels[j]: tau_arr_np[:, j, :]
                for j in range(n_lines)
                if is_abs[j]
            }

        results.append(
            SpectrumPrediction(
                wavelength=wl_out,
                total=np.asarray(total_arr),
                lines=lines_dict,
                continuum_regions=cont_dict,
                tau_profiles=tau_dict,
            )
        )

    return results