Source code for unite.compute

"""Model evaluator: decompose posterior predictions into per-line and per-region contributions.

Given posterior samples and :class:`~unite.model.ModelArgs`, this module
reconstructs the full model prediction for each spectrum, broken down into
individual line and continuum-region contributions in **original flux units**.
"""

from __future__ import annotations

from dataclasses import dataclass

import jax
import jax.numpy as jnp
import numpy as np

from unite._compose import compose_leave_one_out
from unite._lsf import _FWHM_TO_SIGMA, _lsf_convolve
from unite._utils import C_KMS
from unite.continuum.compute import eval_continuum, eval_continuum_regions
from unite.line.compute import _peak_to_area_tau, evaluate_lines, integrate_lines
from unite.model import ModelArgs
from unite.prior import Fixed


@dataclass
class SpectrumPrediction:
    """Decomposed model prediction for a single spectrum.

    All arrays are in original (un-normalized) flux units.

    For **emission lines**, each entry in :attr:`lines` is the intrinsic
    (un-attenuated) flux profile: ``flux * profile``. Summing all line
    contributions and continuum regions always reconstructs :attr:`total`
    regardless of zorder configuration.

    For **absorption lines**, each entry in :attr:`lines` is the flux
    *removed* by that absorber (negative): ``total - total_without_j``.
    """

    #: Pixel-center wavelengths in the disperser's unit. Shape ``(n_pixels,)``.
    wavelength: np.ndarray

    #: Total model flux (lines + continuum). Shape ``(n_samples, n_pixels)``.
    total: np.ndarray

    #: Per-line contributions keyed by informative line labels
    #: (e.g. ``'Ha'``, ``'[NII]_6585'``).
    #: For emission lines: intrinsic (un-attenuated) flux profile (positive).
    #: For absorption lines: flux removed by the absorber (negative).
    #: Shape ``(n_samples, n_pixels)`` each.
    lines: dict[str, np.ndarray]

    #: Per-continuum-region contributions keyed by informative region labels
    #: (e.g. ``'linear_6400_6700'``, ``'powerlaw_0.95_2.5'``).
    #: Shape ``(n_samples, n_pixels)`` each.
    continuum_regions: dict[str, np.ndarray]

    #: LSF-convolved optical depth profiles ``tau_j * phi_j(λ)`` for absorption
    #: lines, evaluated at pixel midpoints. Dimensionless and non-negative.
    #: Keyed by line label; empty dict for emission-only models.
    #: Shape ``(n_samples, n_pixels)`` each.
    tau_profiles: dict[str, np.ndarray]
def evaluate_model(
    samples: dict[str, np.ndarray], args: ModelArgs
) -> list[SpectrumPrediction]:
    """Evaluate the model for each posterior sample and decompose contributions.

    Uses :func:`jax.vmap` to evaluate all samples in a single vectorised
    XLA kernel launch rather than a Python loop, giving a large speed-up
    when the number of posterior samples is large.

    Parameters
    ----------
    samples : dict of str to ndarray
        Posterior samples as returned by ``mcmc.get_samples()`` or
        ``Predictive``. Each value has shape ``(n_samples,)`` or
        ``(n_samples, ...)``. Every name in ``args.dependency_order`` whose
        prior is not :class:`~unite.prior.Fixed` is looked up here.
    args : ModelArgs
        Pre-built data bundle from :meth:`ModelBuilder.build`.

    Returns
    -------
    list of SpectrumPrediction
        One prediction per spectrum in ``args.spectra``.

    Raises
    ------
    KeyError
        If a non-``Fixed`` parameter is missing from ``samples``.
    """
    cm = args.matrices
    z_sys = args.redshift
    n_lines = cm.wavelengths.shape[0]
    has_tau = bool(cm.tau_names)

    # --- Build parameter dict with a uniform (n_samples,) leading axis ---
    context: dict[str, jnp.ndarray] = {}
    n_samples = None
    for pname in args.dependency_order:
        prior = args.all_priors[pname]
        if isinstance(prior, Fixed):
            context[pname] = jnp.asarray(prior.value)
        else:
            arr = jnp.asarray(samples[pname])
            context[pname] = arr
            # The first sampled array with a leading axis defines n_samples.
            if n_samples is None and arr.ndim >= 1:
                n_samples = arr.shape[0]
    if n_samples is None:
        n_samples = 1

    # Broadcast Fixed (scalar) params to (n_samples,) so every leaf has the
    # same leading axis and vmap can map uniformly over the dict pytree.
    context = {
        k: (jnp.broadcast_to(v, (n_samples,)) if v.ndim == 0 else v)
        for k, v in context.items()
    }

    # --- Per-spectrum evaluation (vectorised over samples) ---
    results: list[SpectrumPrediction] = []
    for i, spectrum in enumerate(args.spectra):
        disp = spectrum.disperser
        wl_scale = args.spec_to_canonical[i]
        inv_wl_scale = 1.0 / wl_scale
        wl_out = np.asarray(spectrum.wavelength)

        # Static arrays: do not depend on sample values.
        low_base = jnp.asarray(spectrum.low * wl_scale)
        high_base = jnp.asarray(spectrum.high * wl_scale)
        if disp.pix_offset is not None:
            mid_disp = jnp.asarray((spectrum.low + spectrum.high) / 2.0)
            dlam = disp.dlam_dpix(mid_disp) * wl_scale  # (n_pix,)
        else:
            dlam = None

        line_scale = float(args.line_flux_scales[i])
        cont_scale = float(args.continuum_scales[i])

        # Keyword-only defaults bind the per-iteration values at definition
        # time, avoiding late-binding closure issues (ruff B023) while
        # allowing jax.vmap to vmap only over the positional `params` arg.
        def _single(
            params,
            *,
            _low=low_base,
            _high=high_base,
            _dlam=dlam,
            _disp=disp,
            _inv_wl_scale=inv_wl_scale,
            _line_scale=line_scale,
            _cont_scale=cont_scale,
        ):
            """Evaluate one posterior sample. ``params`` is a dict of 0-D arrays."""
            # --- Line parameters ---
            if cm.flux_names:
                flux_vec = jnp.stack([params[n] for n in cm.flux_names])
                flux_per_line = flux_vec @ cm.flux_matrix * cm.strengths
            else:
                flux_per_line = jnp.zeros(n_lines)

            if cm.tau_names:
                tau_vec = jnp.stack([params[n] for n in cm.tau_names])
                tau_per_line = tau_vec @ cm.tau_matrix
            else:
                tau_per_line = jnp.zeros(n_lines)

            z_vec = jnp.stack([params[n] for n in cm.z_names])
            z_per_line = z_vec @ cm.z_matrix
            centers = cm.wavelengths * (1.0 + z_sys + z_per_line)

            p0_kms = (
                jnp.stack([params[n] for n in cm.p0_names]) @ cm.p0_matrix
                if cm.p0_names
                else jnp.zeros(n_lines)
            )
            # Velocity (km/s) -> wavelength width at each line centre.
            p0 = centers * p0_kms / C_KMS

            p1v_kms = (
                jnp.stack([params[n] for n in cm.p1v_names]) @ cm.p1v_matrix
                if cm.p1v_names
                else jnp.zeros(n_lines)
            )
            p1v = centers * p1v_kms / C_KMS
            p1d = (
                jnp.stack([params[n] for n in cm.p1d_names]) @ cm.p1d_matrix
                if cm.p1d_names
                else jnp.zeros(n_lines)
            )
            p1 = p1v + p1d

            p2 = (
                jnp.stack([params[n] for n in cm.p2_names]) @ cm.p2_matrix
                if cm.p2_names
                else jnp.zeros(n_lines)
            )

            # --- Convert peak-tau to area-tau ---
            if cm.tau_names:
                tau_per_line = _peak_to_area_tau(
                    tau_per_line, centers, p0, p1, p2, cm.profile_codes, cm.is_tau
                )

            # --- Calibration ---
            r_scale = params[_disp.r_scale.name] if _disp.r_scale is not None else 1.0
            flux_scale_val = (
                params[_disp.flux_scale.name]
                if _disp.flux_scale is not None
                else 1.0
            )
            pix_offset = (
                params[_disp.pix_offset.name]
                if _disp.pix_offset is not None
                else 0.0
            )

            # --- Pixel edges (apply sub-pixel offset if present) ---
            low = _low
            high = _high
            if _dlam is not None:
                low = low + pix_offset * _dlam
                high = high + pix_offset * _dlam
            wavelength = (low + high) / 2.0

            # --- LSF ---
            lsf_fwhm = centers / (_disp.R(centers * _inv_wl_scale) * r_scale)

            # LSF FWHM at pixel centres for continuum convolution.
            cont_lsf_fwhm = wavelength / (
                _disp.R(wavelength * _inv_wl_scale) * r_scale
            )

            # Scaled line fluxes for this spectrum.
            scaled_flux = flux_per_line * _line_scale

            # --- Optical depth profiles at pixel midpoints (mode-independent) ---
            # Evaluate LSF-convolved profile at pixel centres for all lines, then
            # mask to absorption lines only. This is the same phi the analytic
            # integration mode uses for exp(-tau*phi), so it is consistent with
            # what the model actually computes.
            phi_mid = evaluate_lines(
                wavelength, centers, lsf_fwhm, p0, p1, p2, cm.profile_codes
            )
            tau_profiles_arr = jnp.where(
                cm.is_tau[:, None], tau_per_line[:, None] * phi_mid, 0.0
            )

            # --- Continuum (per-region for decomposition) ---
            # For analytic/quadrature: evaluate at pixel centres with LSF.
            # For convolution: evaluated on the fine grid inside the branch below.
            cont_total_scaled: jnp.ndarray = jnp.zeros_like(wavelength)
            cont_regions_scaled: list[jnp.ndarray] = []
            if args.integration_mode != 'convolution':
                cont_regions = eval_continuum_regions(
                    wavelength, args, params, z_sys, cont_lsf_fwhm
                )
                cont_total = sum(cont_regions, jnp.zeros_like(wavelength))
                cont_total_scaled = cont_total * _cont_scale
                cont_regions_scaled = [
                    r * _cont_scale * flux_scale_val for r in cont_regions
                ]

            # --- Line decomposition ---
            if args.integration_mode == 'quadrature':
                # GL quadrature: evaluate full composed model at sub-pixel
                # nodes and integrate. Leave-one-out at each node gives
                # exact per-line contributions.
                half_width = (high - low) / 2.0
                nodes = args.quadrature_nodes  # (n_nodes,)
                weights = args.quadrature_weights  # (n_nodes,)
                assert nodes is not None
                assert weights is not None

                # Sub-pixel wavelengths: (n_nodes, n_pix). The pixel midpoint
                # is the `wavelength` array computed above.
                x = wavelength[None, :] + half_width[None, :] * nodes[:, None]

                # Evaluate all profiles and continuum at each node.
                def _at_node(wav):
                    phi = evaluate_lines(
                        wav, centers, lsf_fwhm, p0, p1, p2, cm.profile_codes
                    )
                    node_lsf = wav / (_disp.R(wav * _inv_wl_scale) * r_scale)
                    cont = (
                        eval_continuum(wav, args, params, z_sys, node_lsf)
                        * _cont_scale
                    )
                    total, deltas = compose_leave_one_out(
                        phi,
                        scaled_flux,
                        tau_per_line,
                        cm.is_tau,
                        cm.applies_matrix,
                        args.cont_applies,
                        cont,
                        has_tau=has_tau,
                    )
                    return total, deltas

                # (n_nodes, n_pix) and (n_nodes, n_lines, n_pix)
                node_totals, node_deltas = jax.vmap(_at_node)(x)

                # Pixel-average via GL weighted sum (factor 0.5 from
                # the [-1,1] → [low,high] change of variable).
                total = 0.5 * jnp.dot(weights, node_totals)
                line_contribs = 0.5 * jnp.einsum('n,nlp->lp', weights, node_deltas)

            elif args.integration_mode == 'convolution':
                # Numerical LSF convolution: evaluate intrinsic model (lsf_fwhm=0)
                # on a fine sub-pixel grid, convolve with the wavelength-dependent
                # Gaussian LSF, then pixel-average. Per-line decomposition uses
                # compose_leave_one_out on the fine grid; convolution is linear so
                # convolving deltas separately is exact.
                n_super = args.n_super
                half_width = args.conv_half_width
                assert n_super is not None
                assert half_width is not None

                n_pixels = low.shape[0]

                # Fine grid: (n_super, n_pixels) → flattened (N,).
                offsets = (jnp.arange(n_super) + 0.5) / n_super
                x_fine = low[None, :] + offsets[:, None] * (high - low)[None, :]
                x_flat = x_fine.ravel()

                # Intrinsic profiles (lsf_fwhm=0) on fine grid.
                zero_lsf = jnp.zeros_like(centers)
                phi_fine = evaluate_lines(
                    x_flat, centers, zero_lsf, p0, p1, p2, cm.profile_codes
                )

                # LSF sigma at pixel centres (reuse cont_lsf_fwhm computed above).
                sigma_pix = cont_lsf_fwhm * _FWHM_TO_SIGMA

                # Per-region continuum on fine grid (lsf_fwhm=0): pixel-average
                # then convolve at pixel resolution.
                cont_regions_fine = eval_continuum_regions(
                    x_flat, args, params, z_sys, 0.0
                )
                cont_regions_scaled = [
                    _lsf_convolve(
                        wavelength,
                        (r * _cont_scale).reshape(n_super, n_pixels).mean(axis=0),
                        sigma_pix,
                        half_width,
                    )
                    * flux_scale_val
                    for r in cont_regions_fine
                ]

                # Total continuum on fine grid for model composition.
                cont_fine_total = sum(cont_regions_fine, jnp.zeros_like(x_flat))
                cont_fine_scaled = cont_fine_total * _cont_scale

                # Leave-one-out decomposition on fine grid.
                total_fine, deltas_fine = compose_leave_one_out(
                    phi_fine,
                    scaled_flux,
                    tau_per_line,
                    cm.is_tau,
                    cm.applies_matrix,
                    args.cont_applies,
                    cont_fine_scaled,
                    has_tau=has_tau,
                )

                # Pixel-average first, then convolve at pixel resolution.
                # Cost: O(n_super * N + N * sigma/dlam) vs
                # O(n_super² * N * sigma/dlam).
                total_pix = total_fine.reshape(n_super, n_pixels).mean(axis=0)
                total_conv = _lsf_convolve(
                    wavelength, total_pix, sigma_pix, half_width
                )

                deltas_pix = deltas_fine.reshape(-1, n_super, n_pixels).mean(axis=1)
                deltas_conv = jax.vmap(
                    lambda d: _lsf_convolve(wavelength, d, sigma_pix, half_width)
                )(deltas_pix)  # (n_lines, n_pixels)

                # This branch applies flux_scale itself and returns early, so
                # the shared scaling at the end of the function is skipped.
                total = flux_scale_val * total_conv
                line_contribs = flux_scale_val * deltas_conv
                return total, line_contribs, cont_regions_scaled, tau_profiles_arr

            else:
                # Analytic: CDF-based per-line integration, then compose.
                pixints = integrate_lines(
                    low, high, centers, lsf_fwhm, p0, p1, p2, cm.profile_codes
                ) / (high - low)
                total, line_contribs = compose_leave_one_out(
                    pixints,
                    scaled_flux,
                    tau_per_line,
                    cm.is_tau,
                    cm.applies_matrix,
                    args.cont_applies,
                    cont_total_scaled,
                    has_tau=has_tau,
                )

            # Apply flux_scale to total and per-line contributions
            # (shared by the quadrature and analytic branches).
            total = flux_scale_val * total
            line_contribs = flux_scale_val * line_contribs

            return total, line_contribs, cont_regions_scaled, tau_profiles_arr

        # Vectorise over the leading sample axis of every parameter in context.
        # JAX treats the dict as a pytree and maps axis 0 of each leaf.
        total_arr, line_arr, cont_arr, tau_arr = jax.vmap(_single)(context)
        # total_arr: (n_samples, n_pix)
        # line_arr:  (n_samples, n_lines, n_pix)
        # cont_arr:  list of (n_samples, n_pix), one per continuum region
        # tau_arr:   (n_samples, n_lines, n_pix)

        lines_dict: dict[str, np.ndarray] = {
            args.line_labels[j]: np.asarray(line_arr[:, j, :])
            for j in range(n_lines)
        }

        cont_dict: dict[str, np.ndarray] = {}
        if args.cont_config is not None:
            for k in range(len(args.cont_config)):
                cont_dict[args.continuum_labels[k]] = np.asarray(cont_arr[k])

        # Only absorption lines (is_tau set) get an optical-depth entry.
        is_abs = np.asarray(cm.is_tau)
        tau_dict: dict[str, np.ndarray] = {
            args.line_labels[j]: np.asarray(tau_arr[:, j, :])
            for j in range(n_lines)
            if is_abs[j]
        }

        results.append(
            SpectrumPrediction(
                wavelength=wl_out,
                total=np.asarray(total_arr),
                lines=lines_dict,
                continuum_regions=cont_dict,
                tau_profiles=tau_dict,
            )
        )

    return results