Model
Model builder and numpyro model function for spectral line fitting.
The ModelBuilder assembles a LineConfiguration, an optional ContinuumConfiguration, and a Spectra collection into a numpyro model function that can be passed to any numpyro inference algorithm (NUTS, SVI, etc.).
- class unite.model.ModelArgs(matrices, spectra, redshift, cont_config, cont_resolved_params, all_priors, dependency_order, name_to_token, spec_to_canonical, cont_low, cont_high, cont_center, cont_nw_conv, cont_forms, norm_factors, line_flux_scales, continuum_scales, canonical_unit, flux_units, line_scale_quantity, continuum_scale_quantity, line_labels, continuum_labels, cont_applies, integration_mode='analytic', quadrature_nodes=None, quadrature_weights=None, n_super=None, conv_half_width=None, _evaluators=None, _profile_codes_local=None, _integrate_fn=None, _evaluate_fn=None, _evaluate_at_centers_fn=None)
Bases: object
Bundle of arguments passed to unite_model(). Created by ModelBuilder.build(); not intended for direct construction by users.
- Parameters:
matrices (ConfigMatrices)
redshift (float)
cont_config (ContinuumConfiguration | None)
cont_forms (list | None)
canonical_unit (object)
flux_units (list)
line_scale_quantity (Quantity | None)
continuum_scale_quantity (Quantity | None)
cont_applies (Array)
integration_mode (str)
quadrature_nodes (Array | None)
quadrature_weights (Array | None)
n_super (int | None)
conv_half_width (int | None)
_profile_codes_local (Any | None)
_integrate_fn (Any | None)
_evaluate_fn (Any | None)
_evaluate_at_centers_fn (Any | None)
- matrices: ConfigMatrices
Precomputed parameter matrices and line metadata.
- cont_config: ContinuumConfiguration | None
Continuum configuration, or None if not used.
- cont_resolved_params: list[dict[str, Parameter]] | None
Resolved {param_name: ContinuumParam} mappings per region, from ContinuumConfiguration.resolved_params.
- cont_forms: list | None
Pre-converted continuum forms (static wavelength config in canonical unit).
- line_flux_scales: list[float]
Per-spectrum line flux scale (in each spectrum’s flux_unit * canonical_wl_unit).
- line_scale_quantity: Quantity | None
The line_scale Quantity from Spectra, kept for results output (continuum_scale_quantity holds the matching continuum_scale).
- line_labels: list[str]
Human-readable column labels for each line, parallel to
matrices.wavelengths. Derived from user-supplied line names and rest-frame wavelengths.
- continuum_labels: list[str]
Human-readable column labels for each continuum region, parallel to
cont_config. Derived from form type and wavelength bounds.
- cont_applies: Array
Boolean mask: which tau lines (by line index) attenuate the continuum. cont_applies[k] is True when line_zorders[k] > cont_zorder and is_tau[k]. Shape (n_lines,).
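- integration_mode: str = 'analytic'
Line integration mode: 'analytic' (default) integrates each line profile individually (exact CDF-based integration for emission profiles, pixel-center evaluation for absorption profiles); 'quadrature' uses Gauss-Legendre quadrature to integrate the full composed model over pixels; 'convolution' evaluates the model on a fine sub-pixel grid and numerically convolves it with the LSF. See ModelBuilder.build() for details.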
- quadrature_nodes: Array | None = None
Gauss-Legendre quadrature nodes on [-1, 1]. None when integration_mode != 'quadrature'.
- quadrature_weights: Array | None = None
Gauss-Legendre quadrature weights. None when integration_mode != 'quadrature'.
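The cont_applies rule can be illustrated with a small NumPy sketch. The arrays line_zorders, is_tau and cont_zorder follow the names in the description, but their values are made up. How the mask is then applied (multiplying the continuum by exp(-Σ τ_k φ_k) over the applying lines, mirroring the F · exp(-τ · φ) form quoted for convolution mode) is an assumption, not necessarily what unite_model() does internally.

```python
import numpy as np

# Hypothetical per-line metadata; names mirror the ModelArgs docs, values are made up.
line_zorders = np.array([0, 2, 2, 1])            # draw order of each line
is_tau = np.array([False, True, False, True])    # True for optical-depth (absorption) lines
cont_zorder = 1                                  # draw order of the continuum

# A line attenuates the continuum only if it is a tau line drawn above the continuum.
cont_applies = (line_zorders > cont_zorder) & is_tau
# cont_applies.tolist() -> [False, True, False, False]

# Assumed usage: attenuate the continuum by the masked tau lines only.
wave = np.linspace(-5.0, 5.0, 11)
tau = np.array([0.0, 0.8, 0.0, 0.5])
phi = np.exp(-0.5 * wave[None, :] ** 2)          # toy unit-peak profile, same for every line
continuum = np.ones_like(wave)
optical_depth = (cont_applies[:, None] * tau[:, None] * phi).sum(axis=0)
attenuated = continuum * np.exp(-optical_depth)
```

Line 3 is a tau line but sits at the same z-order as the continuum, so it is excluded: at the line center only line 1's optical depth (0.8) reaches the continuum.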
- unite.model.unite_model(args)
Numpyro model function for multi-spectrum emission-line fitting.
All lines are integrated simultaneously via jax.vmap() with lax.switch dispatching to the correct profile kernel per line. Parameter broadcasting from unique tokens to per-line arrays is done with precomputed indicator matrices. Wavelength unit conversion is handled via precomputed scalar factors stored in args.spec_to_canonical (one per spectrum). Flux is normalized per spectrum so that the likelihood operates on O(1) values.
- Parameters:
- args : ModelArgs
Pre-built data bundle from
ModelBuilder.build().
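The mechanics described for unite_model() (token-to-line broadcasting through an indicator matrix, per-line kernel dispatch by profile code, scalar wavelength conversion, and per-spectrum flux normalization) can be sketched in plain NumPy. Everything here is a hypothetical stand-in: the two kernels, the token layout, the micron-to-Angstrom factor and the normalization choice (peak absolute flux) are assumptions, and a Python loop replaces the jax.vmap / lax.switch combination.

```python
import numpy as np

# 1. Broadcast unique parameter tokens to per-line arrays with an indicator matrix.
#    Hypothetical setup: 3 lines, 2 unique redshift tokens (lines 0 and 1 are tied).
token_values = np.array([0.001, 0.002])    # one sampled value per unique token
indicator = np.array([[1.0, 0.0],          # line 0 uses token 0
                      [1.0, 0.0],          # line 1 uses token 0
                      [0.0, 1.0]])         # line 2 uses token 1
per_line_z = indicator @ token_values      # shape (n_lines,)

# 2. Dispatch each line to its profile kernel by an integer code.
def gaussian(x, center, width):
    return np.exp(-0.5 * ((x - center) / width) ** 2)

def lorentzian(x, center, width):
    return 1.0 / (1.0 + ((x - center) / width) ** 2)

kernels = [gaussian, lorentzian]           # branch table indexed by profile code
profile_codes = [0, 0, 1]
rest_wavelengths = np.array([6550.0, 6565.0, 6585.0])   # toy values, Angstrom
widths = np.array([3.0, 3.0, 3.0])

# 3. Convert a spectrum's wavelengths to the canonical unit with one scalar factor.
spec_to_canonical = 1e4                    # assumed: spectrum in micron, canonical Angstrom
wave_um = np.linspace(0.6500, 0.6650, 301)
x = wave_um * spec_to_canonical

# In JAX this loop would be a jax.vmap over lines, with
# jax.lax.switch(code, kernels, x, center, width) picking the branch.
centers = rest_wavelengths * (1.0 + per_line_z)
profiles = np.stack([kernels[c](x, mu, w)
                     for c, mu, w in zip(profile_codes, centers, widths)])

# 4. Normalize flux per spectrum so the likelihood sees O(1) numbers.
flux = 3e-18 * profiles.sum(axis=0)        # toy flux in physical units
norm_factor = np.max(np.abs(flux))         # assumed choice of normalization constant
flux_norm = flux / norm_factor
```

Because the indicator matrix is fixed at build time, the broadcast is a single matrix-vector product per sample, which stays cheap and differentiable inside the model.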
- class unite.model.ModelBuilder(line_config, continuum_config, spectra)
Bases: object
Assemble configuration objects into a numpyro model.
Collects all unique parameter tokens (line, calibration, continuum), builds precomputed indicator matrices, performs a topological sort for dependency resolution, and packages everything into a (model_fn, model_args) pair.
- Parameters:
- line_config : LineConfiguration
Emission/absorption line configuration.
- continuum_config : ContinuumConfiguration or None
Continuum configuration. None for a lines-only model.
- spectra : Spectra
Spectrum collection with systemic redshift.
Examples
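>>> import jax
>>> from numpyro.infer import MCMC, NUTS
>>> from unite.model import ModelBuilder
>>> # line_config, cont and spectra are built beforehand
>>> model_fn, args = ModelBuilder(line_config, cont, spectra).build()
>>> kernel = NUTS(model_fn)
>>> mcmc = MCMC(kernel, num_warmup=500, num_samples=1000)
>>> mcmc.run(jax.random.PRNGKey(0), args)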
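The topological sort mentioned in the class description orders parameter tokens so that every tied prior is sampled after the tokens it refers to. A minimal sketch with the standard-library graphlib module follows; the dependency graph and token names are hypothetical, and ModelBuilder's own implementation may differ.

```python
from graphlib import CycleError, TopologicalSorter

# Hypothetical dependency graph: each token maps to the tokens its prior refers to,
# e.g. a flux tied to another line's flux, or a width tied to another line's width.
deps = {
    "flux_b": {"flux_a"},     # flux_b is defined relative to flux_a
    "flux_a": set(),
    "fwhm_b": {"fwhm_a"},
    "fwhm_a": set(),
    "z": set(),
}

# static_order() yields every node after all of its dependencies.
order = list(TopologicalSorter(deps).static_order())

# Circular ties have no valid sampling order and are rejected.
try:
    list(TopologicalSorter({"a": {"b"}, "b": {"a"}}).static_order())
    cycle_detected = False
except CycleError:
    cycle_detected = True
```

Sampling in `order` guarantees that when a dependent prior is evaluated, the values of its parents already exist.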
- property matrices: ConfigMatrices
Precomputed matrices (after coverage filtering).
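The docs do not spell out what coverage filtering does. One plausible reading, sketched here purely as an assumption, is that lines whose observed wavelength falls outside every spectrum are dropped, together with any parameter token that no remaining line uses. The rest wavelengths, redshift and coverage ranges are toy values.

```python
import numpy as np

# Toy line list: rest wavelengths in Angstrom at an assumed redshift.
rest_wavelengths = np.array([4862.7, 5008.2, 6564.6])
z = 2.0
observed_um = rest_wavelengths * (1.0 + z) / 1e4       # Angstrom -> micron

# Hypothetical wavelength coverage (micron) of two spectra.
coverage = [(1.40, 1.48), (1.90, 2.10)]

covered = np.zeros(len(rest_wavelengths), dtype=bool)
for lo, hi in coverage:
    covered |= (observed_um >= lo) & (observed_um <= hi)

# Indicator matrix: rows are lines, columns are unique tokens.
indicator = np.array([[1, 0],
                      [0, 1],
                      [1, 0]])
kept = indicator[covered]                  # drop uncovered lines
kept = kept[:, kept.any(axis=0)]           # drop tokens no remaining line uses
```

Here the middle line lands in the gap between the two spectra, so it and its private token disappear, leaving a 2 x 1 indicator matrix.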
- build(*, integration_mode='analytic', n_nodes=7, n_super=10, conv_half_width=None)
Build the numpyro model function and its arguments.
- Parameters:
- integration_mode : str, optional
How line profiles are integrated over pixels. One of:
'analytic' (default): exact CDF-based integration for emission profiles and pixel-center evaluation for absorption profiles.
'quadrature': Gauss-Legendre quadrature for all profiles (both emission and absorption). More accurate for absorption lines at the cost of speed.
'convolution': evaluates the intrinsic model (lsf_fwhm=0) on a uniform fine sub-pixel grid of n_super points per pixel, numerically convolves it with the wavelength-dependent Gaussian LSF, then averages over each pixel. This correctly computes LSF ⊗ [F · exp(-τ · φ_intrinsic)] rather than F · exp(-τ · LSF ⊗ φ), eliminating the LSF pre-convolution approximation for absorption lines.
- n_nodes : int, optional
Number of Gauss-Legendre quadrature nodes per pixel (default: 7). Only used when integration_mode='quadrature'. Higher values give more accurate integration at greater computational cost.
- n_super : int, optional
Number of uniform sub-pixel evaluation points per pixel (default: 10). Only used when integration_mode='convolution'. Higher values resolve narrower intrinsic line profiles at the cost of speed. n_super=10 is adequate for NIRSpec gratings; increase to 20 for narrow absorbers at PRISM resolution.
- conv_half_width : int or None, optional
Half-width of the banded LSF convolution kernel in fine-grid indices (default: None). When None, it is auto-computed at build time as ceil(4 * max_sigma / min_dx_fine * 1.5), where max_sigma is the largest LSF sigma across all spectra and min_dx_fine is the finest sub-pixel spacing. Only used when integration_mode='convolution'.
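The convolution mode and the conv_half_width rule can be sketched in NumPy under simplifying assumptions: a single spectrum with uniform unit-width pixels and a constant Gaussian LSF (the real mode uses a wavelength-dependent LSF). The sketch computes the half-width with the formula given for conv_half_width, builds a banded kernel of that size, and contrasts LSF ⊗ [F · exp(-τ · φ)] with the pre-convolved approximation F · exp(-τ · LSF ⊗ φ) for a narrow absorber. All numbers are illustrative.

```python
import math
import numpy as np

# Toy setup (assumptions): uniform unit-width pixels and a constant Gaussian LSF.
n_pix, n_super = 40, 10
dx_pix = 1.0
dx_fine = dx_pix / n_super
x_fine = (np.arange(n_pix * n_super) + 0.5) * dx_fine    # sub-pixel centres
lsf_sigma = 1.5

# Half-width from the documented formula (one spectrum, so max/min are trivial).
conv_half_width = math.ceil(4 * lsf_sigma / dx_fine * 1.5)   # -> 90 fine-grid steps

# Banded Gaussian kernel spanning +/- conv_half_width fine-grid steps, unit sum.
k = np.arange(-conv_half_width, conv_half_width + 1) * dx_fine
kernel = np.exp(-0.5 * (k / lsf_sigma) ** 2)
kernel /= kernel.sum()

def lsf_convolve(y):
    # Edge padding avoids artificial dips from zero-padding at the grid ends.
    padded = np.pad(y, conv_half_width, mode="edge")
    return np.convolve(padded, kernel, mode="valid")

def pixel_average(y):
    # Fine-grid index p * n_super + s belongs to pixel p.
    return y.reshape(n_pix, n_super).mean(axis=1)

# Narrow intrinsic absorber (optical depth tau) on a flat continuum F = 1.
center, width, tau = 20.5, 0.2, 3.0
phi = np.exp(-0.5 * ((x_fine - center) / width) ** 2)

exact = pixel_average(lsf_convolve(np.exp(-tau * phi)))     # LSF ⊗ [F · exp(-τ · φ)]
approx = pixel_average(np.exp(-tau * lsf_convolve(phi)))    # F · exp(-τ · LSF ⊗ φ)
```

For this saturated, unresolved absorber the pre-convolved approximation produces a noticeably deeper line than the correct order of operations, which is the discrepancy convolution mode is meant to remove.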
- Returns:
- model_fn : callable
The numpyro model function (signature: model_fn(args)).
- model_args : ModelArgs
Pre-built data bundle to pass to the model function.
- Raises:
- ValueError
If integration_mode is not one of 'analytic', 'quadrature', or 'convolution'.
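The difference between the 'analytic' and 'quadrature' modes can be sketched for a single Gaussian. The helper names below are hypothetical, not part of unite: one computes a pixel average from differences of the Gaussian CDF, the other maps Gauss-Legendre nodes from [-1, 1] into each pixel (n_nodes=7, matching the build() default). For emission the two agree; for an absorber, the sketch contrasts pixel-center evaluation (what 'analytic' mode does for absorption profiles) with a quadrature pixel average.

```python
import math
import numpy as np

erf = np.vectorize(math.erf)

def pixel_avg_gaussian_cdf(lo, hi, mu, sigma):
    """Exact pixel average of a unit-flux Gaussian, from differences of its CDF."""
    def cdf(x):
        return 0.5 * (1.0 + erf((x - mu) / (sigma * math.sqrt(2.0))))
    return (cdf(hi) - cdf(lo)) / (hi - lo)

def pixel_avg_quadrature(f, lo, hi, n_nodes=7):
    """Pixel average of any f via Gauss-Legendre quadrature (n_nodes per pixel)."""
    t, w = np.polynomial.legendre.leggauss(n_nodes)               # nodes/weights on [-1, 1]
    x = 0.5 * (hi + lo)[:, None] + 0.5 * (hi - lo)[:, None] * t   # map nodes into each pixel
    return 0.5 * (w * f(x)).sum(axis=1)                           # weights sum to 2

edges = np.linspace(0.0, 20.0, 41)         # 40 pixels of width 0.5
lo, hi = edges[:-1], edges[1:]
centres = 0.5 * (lo + hi)

# Emission: both routes agree for a resolved Gaussian.
mu, sigma = 10.2, 0.8
gauss = lambda x: np.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * math.sqrt(2 * math.pi))
analytic = pixel_avg_gaussian_cdf(lo, hi, mu, sigma)
quadrature = pixel_avg_quadrature(gauss, lo, hi)

# Absorption: exp(-tau * phi) has no closed-form CDF; compare pixel-centre
# evaluation with a quadrature average across each pixel.
tau, mu_abs, sigma_abs = 2.0, 10.25, 0.1
transmission = lambda x: np.exp(-tau * np.exp(-0.5 * ((x - mu_abs) / sigma_abs) ** 2))
at_centres = transmission(centres)
averaged = pixel_avg_quadrature(transmission, lo, hi)
```

In the pixel containing the narrow absorber, center evaluation reports close to the full central depth, while the pixel average is much shallower; this is the accuracy gap that the 'quadrature' mode trades speed for.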
- fit(num_warmup=250, num_samples=1000, num_chains=1, seed=0, progress_bar=True, integration_mode='analytic', n_nodes=7, n_super=10)
Fit the model using NUTS sampling (convenience wrapper).
This method builds the model, runs MCMC with the NUTS kernel, and returns the posterior samples. For more control over the sampler (e.g., a custom kernel, SVI, or nested sampling), call build() directly and use numpyro’s inference APIs.
- Parameters:
- num_warmup : int, optional
Number of warmup iterations per chain (default: 250).
- num_samples : int, optional
Number of posterior samples per chain (default: 1000).
- num_chains : int, optional
Number of MCMC chains to run in parallel (default: 1).
- seed : int, optional
Random seed for JAX’s PRNG (default: 0).
- progress_bar : bool, optional
Whether to display a progress bar (default: True).
- integration_mode : str, optional
Line integration mode (default: 'analytic'). See build() for details.
- n_nodes : int, optional
Gauss-Legendre quadrature nodes per pixel (default: 7). See build() for details.
- n_super : int, optional
Sub-pixel evaluation points per pixel for convolution mode (default: 10). See build() for details.
- Returns:
- tuple
(samples, model_args), where samples is a dictionary with parameter names as keys and shape (num_chains, num_samples) per parameter, and model_args is the ModelArgs bundle.
Examples
>>> builder = ModelBuilder(line_config, cont, spectra)
>>> samples, model_args = builder.fit(num_warmup=200, num_samples=500, num_chains=4)
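Given the documented layout of samples (one array of shape (num_chains, num_samples) per parameter), a per-parameter summary pools the chain and sample axes before taking percentiles. The sketch below uses randomly generated stand-in draws rather than real fit output, and the parameter names are hypothetical; real keys depend on the line and continuum configuration.

```python
import numpy as np

# Stand-in for the `samples` dict returned by fit(): random draws, NOT real results.
rng = np.random.default_rng(0)
num_chains, num_samples = 4, 500
samples = {
    "z": rng.normal(2.0, 1e-4, size=(num_chains, num_samples)),
    "flux_a": rng.normal(10.0, 0.5, size=(num_chains, num_samples)),
}

def summarize(samples):
    """Median and 16th/84th-percentile offsets per parameter, pooling all chains."""
    summary = {}
    for name, draws in samples.items():
        draws = np.asarray(draws)
        pooled = draws.reshape((-1,) + draws.shape[2:])   # merge chain and sample axes
        p16, p50, p84 = np.percentile(pooled, [16, 50, 84], axis=0)
        summary[name] = (p50, p50 - p16, p84 - p50)
    return summary

summary = summarize(samples)
```

Reshaping with `draws.shape[2:]` keeps any trailing per-parameter dimensions intact, so the same helper works if a parameter is vector-valued.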