Model

Model builder and numpyro model function for spectral line fitting.

The ModelBuilder assembles a LineConfiguration, an optional ContinuumConfiguration, and a Spectra collection into a numpyro model function that can be passed to any numpyro inference algorithm (NUTS, SVI, etc.).

class unite.model.ModelArgs(matrices, spectra, redshift, cont_config, cont_resolved_params, all_priors, dependency_order, name_to_token, spec_to_canonical, cont_low, cont_high, cont_center, cont_nw_conv, cont_forms, norm_factors, line_flux_scales, continuum_scales, canonical_unit, flux_units, line_scale_quantity, continuum_scale_quantity, line_labels, continuum_labels, cont_applies, integration_mode='analytic', quadrature_nodes=None, quadrature_weights=None, n_super=None, conv_half_width=None, _evaluators=None, _profile_codes_local=None, _integrate_fn=None, _evaluate_fn=None, _evaluate_at_centers_fn=None)

Bases: object

Bundle of arguments passed to unite_model().

Created by ModelBuilder.build(); not intended for direct construction by users.

Parameters:
matrices: ConfigMatrices

Precomputed parameter matrices and line metadata.

spectra: list[Spectrum]

Individual spectra.

redshift: float

Systemic redshift.

cont_config: ContinuumConfiguration | None

Continuum configuration, or None if not used.

cont_resolved_params: list[dict[str, Parameter]] | None

Resolved {param_name: ContinuumParam} mappings per region, from ContinuumConfiguration.resolved_params.

all_priors: dict[str, Prior]

All parameters with their priors (line, calibration, continuum).

dependency_order: list[str]

Topological sampling order for all parameters.

name_to_token: dict[str, object]
spec_to_canonical: list[float]

Per-spectrum scalar factors that convert each spectrum's wavelength unit to the canonical unit.

cont_low: list[float] | None
cont_high: list[float] | None
cont_center: list[float] | None
cont_nw_conv: list[float] | None
cont_forms: list | None

Pre-converted continuum forms (static wavelength config in canonical unit).

norm_factors: list[float]

Per-spectrum flux normalization factors, chosen so that the likelihood operates on O(1) values.

line_flux_scales: list[float]

Per-spectrum line flux scale (in each spectrum’s flux_unit * canonical_wl_unit).

continuum_scales: list[float]

Per-spectrum continuum scale (in each spectrum’s flux_unit).

canonical_unit: object

Wavelength unit of canonical frame (first spectrum’s disperser unit).

flux_units: list

Per-spectrum flux density units.

line_scale_quantity: Quantity | None

The line_scale Quantity from Spectra (for results output).

continuum_scale_quantity: Quantity | None

The continuum_scale Quantity from Spectra (for results output).

line_labels: list[str]

Human-readable column labels for each line, parallel to matrices.wavelengths. Derived from user-supplied line names and rest-frame wavelengths.

continuum_labels: list[str]

Human-readable column labels for each continuum region, parallel to cont_config. Derived from form type and wavelength bounds.

cont_applies: Array

Boolean mask: which tau lines (by line index) attenuate the continuum. cont_applies[k] is True when line_zorders[k] > cont_zorder and is_tau[k]. Shape (n_lines,).

integration_mode: str = 'analytic'

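Line integration mode: 'analytic' (default) uses exact CDF-based integration for emission profiles and pixel-center evaluation for absorption profiles; 'quadrature' uses Gauss-Legendre quadrature to integrate the full composed model over pixels; 'convolution' evaluates the intrinsic model on a uniform sub-pixel grid, convolves it numerically with the LSF, and pixel-averages the result. See ModelBuilder.build() for details.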

quadrature_nodes: Array | None = None

Gauss-Legendre quadrature nodes on [-1, 1]. None when integration_mode != 'quadrature'.

quadrature_weights: Array | None = None

Gauss-Legendre quadrature weights. None when integration_mode != 'quadrature'.

n_super: int | None = None

Number of uniform sub-pixel evaluation points per pixel for convolution mode. None when integration_mode != 'convolution'.

conv_half_width: int | None = None

Half-width of the banded LSF convolution kernel in fine-grid indices. Pre-computed at build time as a Python int (not a traced value). None when integration_mode != 'convolution'.
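The three integration modes differ in how a profile is averaged over each pixel. The NumPy sketch below is a minimal illustration, not unite's actual kernels, and all function names in it are hypothetical. It compares the three approaches for a single unit-area Gaussian: the exact CDF difference (as in 'analytic'), Gauss-Legendre nodes and weights on [-1, 1] (the role of quadrature_nodes and quadrature_weights) mapped onto the pixel (as in 'quadrature'), and a uniform grid of n_super sub-pixel points averaged back to the pixel (the fine grid of 'convolution', with the LSF step omitted).

```python
import math

import numpy as np


def gaussian_pdf(x, mu, sigma):
    """Unit-area Gaussian profile evaluated at x (scalar or array)."""
    return np.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * math.sqrt(2.0 * math.pi))


def gaussian_cdf(x, mu, sigma):
    """Cumulative integral of gaussian_pdf up to scalar x."""
    return 0.5 * (1.0 + math.erf((x - mu) / (sigma * math.sqrt(2.0))))


def pixel_mean_analytic(lo, hi, mu, sigma):
    # 'analytic'-style: exact CDF difference divided by the pixel width.
    return (gaussian_cdf(hi, mu, sigma) - gaussian_cdf(lo, mu, sigma)) / (hi - lo)


def pixel_mean_quadrature(lo, hi, mu, sigma, n_nodes=7):
    # 'quadrature'-style: Gauss-Legendre nodes/weights on [-1, 1] mapped onto [lo, hi].
    nodes, weights = np.polynomial.legendre.leggauss(n_nodes)
    half, mid = 0.5 * (hi - lo), 0.5 * (hi + lo)
    integral = half * np.sum(weights * gaussian_pdf(mid + half * nodes, mu, sigma))
    return integral / (hi - lo)


def pixel_mean_supersampled(lo, hi, mu, sigma, n_super=10):
    # Fine-grid style: n_super uniform sub-pixel points (sub-bin centers), then averaged.
    edges = np.linspace(lo, hi, n_super + 1)
    centers = 0.5 * (edges[:-1] + edges[1:])
    return np.mean(gaussian_pdf(centers, mu, sigma))


if __name__ == "__main__":
    for fn in (pixel_mean_analytic, pixel_mean_quadrature, pixel_mean_supersampled):
        print(fn.__name__, fn(0.0, 1.0, 0.3, 0.4))
```

For this pixel and profile, the 7-node quadrature matches the CDF value to well below 1e-6, while the 10-point uniform average carries a small midpoint-rule error. The size of that error depends on how wide the profile is relative to the sub-pixel spacing.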

unite.model.unite_model(args)

Numpyro model function for multi-spectrum emission-line fitting.

All lines are integrated simultaneously via jax.vmap() with lax.switch dispatching to the correct profile kernel per line. Parameter broadcasting from unique tokens to per-line arrays is done with precomputed indicator matrices.
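Broadcasting from unique tokens to per-line arrays with an indicator matrix reduces to a matrix product, which is cheap and differentiable. Below is a minimal NumPy sketch. It assumes that each line uses exactly one token per parameter kind. The token names, line assignments, and the build_indicator helper are hypothetical and not part of unite's API.

```python
import numpy as np


def build_indicator(line_tokens, tokens):
    """Return an (n_lines, n_tokens) 0/1 matrix with exactly one 1 per row."""
    column = {tok: j for j, tok in enumerate(tokens)}
    indicator = np.zeros((len(line_tokens), len(tokens)))
    for i, tok in enumerate(line_tokens):
        indicator[i, column[tok]] = 1.0
    return indicator


# Hypothetical setup: three lines sharing two unique velocity-offset tokens.
tokens = ["dv_narrow", "dv_broad"]
line_tokens = ["dv_narrow", "dv_narrow", "dv_broad"]  # token used by each line
indicator = build_indicator(line_tokens, tokens)

# One sampled value per unique token, broadcast to one value per line.
token_values = np.array([25.0, -120.0])
per_line = indicator @ token_values  # shape (n_lines,)

# The product is linear, so a batch of draws works with the same matrix.
draws = np.array([[25.0, -120.0], [30.0, -100.0]])  # (n_draws, n_tokens)
per_line_draws = draws @ indicator.T  # (n_draws, n_lines)
```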

Wavelength unit conversion is handled via pre-computed scalar factors stored in args.spec_to_canonical (one per spectrum). Flux is normalized per spectrum so that the likelihood operates on O(1) values.
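Normalization does not change the posterior. Dividing the data, the model, and the uncertainties by the same positive per-spectrum factor leaves χ² unchanged. The Gaussian log-likelihood therefore shifts only by a parameter-independent constant, n·log(norm). A NumPy sketch of that argument follows, together with a scalar wavelength conversion. The micron-to-Angstrom factor, the data values, and the max-abs choice of normalization are illustrative assumptions, not necessarily what unite uses.

```python
import numpy as np


def gaussian_loglike(data, model, sigma):
    """Independent-Gaussian log-likelihood, including the normalization term."""
    resid = (data - model) / sigma
    return -0.5 * np.sum(resid**2 + np.log(2.0 * np.pi * sigma**2))


# Wavelength conversion: one scalar per spectrum (assumed micron -> Angstrom here).
spec_to_canonical = 1.0e4
wavelength_um = np.array([1.80, 1.81, 1.82])
wavelength_canonical = wavelength_um * spec_to_canonical

# Illustrative flux densities in physical units (tiny, poorly scaled numbers).
flux = np.array([3.0e-20, 5.0e-19, 4.0e-20])
error = np.full(3, 1.0e-21)
model = np.array([2.9e-20, 5.1e-19, 4.1e-20])

norm = np.max(np.abs(flux))  # hypothetical choice; any positive constant works
ll_raw = gaussian_loglike(flux, model, error)
ll_norm = gaussian_loglike(flux / norm, model / norm, error / norm)

chi2_raw = np.sum(((flux - model) / error) ** 2)
chi2_norm = np.sum(((flux / norm - model / norm) / (error / norm)) ** 2)
```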

Return type:

None

Parameters:
args: ModelArgs

Pre-built data bundle from ModelBuilder.build().

class unite.model.ModelBuilder(line_config, continuum_config, spectra)

Bases: object

Assemble configuration objects into a numpyro model.

Collects all unique parameter tokens (line, calibration, continuum), builds precomputed indicator matrices, performs a topological sort for dependency resolution, and packages everything into a (model_fn, model_args) pair.
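A topological sort is what guarantees that a tied parameter is sampled after the parameter it depends on. A minimal sketch using the standard-library graphlib module (Python 3.9+) follows. unite's own implementation may differ, and the parameter names and dependencies are hypothetical.

```python
from graphlib import CycleError, TopologicalSorter


def sampling_order(dependencies):
    """Return parameter names so each one follows every parameter it depends on.

    ``dependencies`` maps each parameter name to the set of names it needs.
    """
    try:
        return list(TopologicalSorter(dependencies).static_order())
    except CycleError as exc:
        # exc.args[1] holds the detected cycle as a list of nodes.
        raise ValueError(f"circular parameter dependency: {exc.args[1]}") from exc


# Hypothetical parameters: an [N II] width tied to H-alpha, and a fixed doublet ratio.
deps = {
    "fwhm_NII": {"fwhm_Ha"},
    "flux_NII_6549": {"flux_NII_6585"},
    "fwhm_Ha": set(),
    "flux_NII_6585": set(),
}
order = sampling_order(deps)
```

Raising on a cycle surfaces a mis-specified tie (for example, two widths each tied to the other) at build time rather than during sampling.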

Parameters:
line_config: LineConfiguration

Emission/absorption line configuration.

continuum_config: ContinuumConfiguration or None

Continuum configuration. None for a lines-only model.

spectra: Spectra

Spectrum collection with systemic redshift.

Examples

>>> model_fn, args = ModelBuilder(line_config, cont, spectra).build()
>>> kernel = numpyro.infer.NUTS(model_fn)
>>> mcmc = numpyro.infer.MCMC(kernel, num_warmup=500, num_samples=1000)
>>> mcmc.run(jax.random.PRNGKey(0), args)

property matrices: ConfigMatrices

Precomputed matrices (after coverage filtering).

build(*, integration_mode='analytic', n_nodes=7, n_super=10, conv_half_width=None)

Build the numpyro model function and its arguments.

Return type:

tuple[Callable, ModelArgs]

Parameters:
integration_mode: str, optional

How line profiles are integrated over pixels. One of:

  • 'analytic' (default): exact CDF-based integration for emission profiles and pixel-center evaluation for absorption profiles.

  • 'quadrature': Gauss-Legendre quadrature for all profiles (both emission and absorption). More accurate for absorption lines, at the cost of speed.

  • 'convolution': evaluates the intrinsic model (lsf_fwhm=0) on a uniform fine grid of n_super sub-pixel points per pixel, numerically convolves it with the wavelength-dependent Gaussian LSF, and then pixel-averages. This correctly computes LSF ⊗ [F · exp(-τ · φ_intrinsic)] rather than F · exp(-τ · (LSF ⊗ φ_intrinsic)), eliminating the LSF pre-convolution approximation for absorption lines.

n_nodes: int, optional

Number of Gauss-Legendre quadrature nodes per pixel (default: 7). Only used when integration_mode='quadrature'. Higher values give more accurate integration at greater computational cost.

n_super: int, optional

Number of uniform sub-pixel evaluation points per pixel (default: 10). Only used when integration_mode='convolution'. Higher values resolve narrower intrinsic line profiles at the cost of speed. n_super=10 is adequate for NIRSpec gratings; increase to 20 for narrow absorbers at PRISM resolution.

conv_half_width: int or None, optional

Half-width of the banded LSF convolution kernel in fine-grid indices (default: None). When None, auto-computed at build time as ceil(4 * max_sigma / min_dx_fine * 1.5) where max_sigma is the largest LSF sigma across all spectra and min_dx_fine is the finest sub-pixel spacing. Only used when integration_mode='convolution'.
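The reason 'convolution' mode matters for absorbers can be checked numerically. The NumPy sketch below uses a constant-width Gaussian LSF on a fine grid. This is a simplification, since the LSF described above is wavelength dependent. The sketch sizes the kernel with the auto rule quoted for conv_half_width and compares LSF ⊗ exp(-τφ) against exp(-τ (LSF ⊗ φ)) for a peak-normalized Gaussian absorber on a flat continuum (F = 1). The function names and grid values are hypothetical.

```python
import math

import numpy as np


def auto_conv_half_width(max_sigma, min_dx_fine):
    # Auto rule quoted in the conv_half_width description of build().
    return math.ceil(4 * max_sigma / min_dx_fine * 1.5)


def gaussian_kernel(sigma, dx, half_width):
    offsets = np.arange(-half_width, half_width + 1) * dx
    kernel = np.exp(-0.5 * (offsets / sigma) ** 2)
    return kernel / kernel.sum()  # unit sum keeps a flat continuum at 1


def lsf_convolve(y, kernel):
    # np.convolve zero-pads, so only points at least half_width from the edges are valid.
    return np.convolve(y, kernel, mode="same")


# Fine grid (arbitrary units), narrow intrinsic absorber, broader constant-width LSF.
x = np.linspace(-50.0, 50.0, 2001)
dx = x[1] - x[0]
sigma_int, sigma_lsf = 0.5, 2.0
phi = np.exp(-0.5 * (x / sigma_int) ** 2)  # peak-normalized: tau0 is the central optical depth
kernel = gaussian_kernel(sigma_lsf, dx, auto_conv_half_width(sigma_lsf, dx))


def observed_transmission(tau0):
    """Return (exact, pre-convolved approximation) for continuum F = 1."""
    exact = lsf_convolve(np.exp(-tau0 * phi), kernel)  # conv(LSF, exp(-tau*phi))
    approx = np.exp(-tau0 * lsf_convolve(phi, kernel))  # exp(-tau*conv(LSF, phi))
    return exact, approx
```

For τ much smaller than 1, the two forms agree to first order in τ. For a saturated line they diverge strongly: the exact form conserves the intrinsic equivalent width, while the pre-convolved form overestimates it.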

Returns:
model_fn: callable

The numpyro model function (signature: model_fn(args)).

model_args: ModelArgs

Pre-built data bundle to pass to the model function.

Raises:
ValueError

If integration_mode is not one of the valid values.

fit(num_warmup=250, num_samples=1000, num_chains=1, seed=0, progress_bar=True, integration_mode='analytic', n_nodes=7, n_super=10)

Fit the model using NUTS sampling (convenience wrapper).

This method builds the model, runs MCMC with the NUTS kernel, and returns the posterior samples. For more control over the sampler (e.g., custom kernel, SVI, nested sampling), call build() directly and use numpyro’s inference APIs.

Return type:

tuple[dict, ModelArgs]

Parameters:
num_warmup: int, optional

Number of warmup iterations per chain (default: 250).

num_samples: int, optional

Number of posterior samples per chain (default: 1000).

num_chains: int, optional

Number of MCMC chains to run in parallel (default: 1).

seed: int, optional

Random seed for JAX’s PRNG (default: 0).

progress_bar: bool, optional

Whether to display a progress bar (default: True).

integration_mode: str, optional

Line integration mode (default: 'analytic'). See build() for details.

n_nodes: int, optional

Gauss-Legendre quadrature nodes per pixel (default: 7). See build() for details.

n_super: int, optional

Sub-pixel evaluation points per pixel for convolution mode (default: 10). See build() for details.

Returns:
tuple

(samples, model_args) where samples is a dictionary with parameter names as keys and shape (num_chains, num_samples) per parameter, and model_args is the ModelArgs bundle.

Examples

>>> builder = ModelBuilder(line_config, cont, spectra)
>>> samples, model_args = builder.fit(num_warmup=200, num_samples=500, num_chains=4)
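If each entry of samples has shape (num_chains, num_samples), as described under Returns, then pooling the chains and taking percentiles is a short NumPy step. The sketch below allows for extra trailing dimensions. The fake draws and parameter names stand in for real fit() output and are hypothetical.

```python
import numpy as np


def summarize(samples, percentiles=(16.0, 50.0, 84.0)):
    """Pool chains and return the requested percentiles for each parameter."""
    summary = {}
    for name, draws in samples.items():
        draws = np.asarray(draws)
        # (num_chains, num_samples, ...) -> (num_chains * num_samples, ...)
        pooled = draws.reshape((-1,) + draws.shape[2:])
        summary[name] = np.percentile(pooled, percentiles, axis=0)
    return summary


# Stand-in for the samples dict from builder.fit(num_chains=4, num_samples=500).
rng = np.random.default_rng(0)
fake_samples = {
    "flux_Ha": rng.normal(10.0, 0.5, size=(4, 500)),
    "fwhm_Ha": rng.normal(300.0, 20.0, size=(4, 500)),
}
summary = summarize(fake_samples)
```

Pooling is only meaningful once the chains agree. Comparing per-chain medians, or using a convergence diagnostic, is a sensible check before reporting pooled percentiles.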