# Source code for pyhs3.distributions.composite

"""
Composite distribution implementations.

Provides classes for handling composite probability distributions that combine
multiple other distributions, including mixtures and products of distributions.
"""

from __future__ import annotations

from collections.abc import Callable, Mapping
from typing import TYPE_CHECKING, Any, Literal, cast

import pytensor.tensor as pt
from pydantic import (
    Field,
    PrivateAttr,
    ValidationInfo,
    field_serializer,
    field_validator,
    model_serializer,
)
from pytensor.graph.traversal import explicit_graph_inputs

from pyhs3.context import Context
from pyhs3.distributions.core import Distribution, LogProbTerms
from pyhs3.typing.aliases import TensorVar

if TYPE_CHECKING:
    from pyhs3.distributions import Distributions


class MixtureDist(Distribution):
    r"""
    Mixture of probability distributions.

    Implements a weighted combination of multiple distributions following
    ROOT's RooAddPdf. Supports both N and N-1 coefficient configurations where
    :math:`N` represents number of distributions (`summands`):

    **N-1 coefficients:**

    .. math::

        f(x) = \sum_{i=1}^{n-1} c_i \cdot f_i(x) + (1 - \sum_{i=1}^{n-1} c_i) \cdot f_n(x)

    **N coefficients:**

    .. math::

        f(x) = \frac{\sum_{i=1}^{n} c_i \cdot f_i(x)}{\sum_{i=1}^{n} c_i}

    **N coefficients with `ref_coef_norm`:**

    .. math::

        f(x) = \frac{\sum_{i=1}^{n} c_i \cdot f_i(x)}{\sum_{j \in \text{ref\_coef\_norm}} c_j}

    Parameters:
        coefficients (list[str]): Names of coefficient parameters.
        summands (list[str]): Names of component distributions.
        extended (bool): Whether the mixture is extended (affects normalization).
            Must be True for N coefficients, False for N-1 coefficients.
        ref_coef_norm (list[str] | None): Optional list of coefficient names for
            custom normalization. Only valid when using N coefficients
            (extended=True). May be given as a comma-separated string; whitespace
            around each name is ignored and empty entries are dropped.

    ROOT Reference: :root:`RooAddPdf`
    """

    type: Literal["mixture_dist"] = "mixture_dist"
    summands: list[str]
    coefficients: list[str]
    extended: bool = False
    ref_coef_norm: list[str] | None = Field(
        default=None, json_schema_extra={"preprocess": False}
    )

    # PyTensor nodes built by the most recent likelihood() call, reused by
    # unnormalized_expression() and expected_yield() so that model.log_prob
    # shares the same subgraphs instead of building duplicates the compiler
    # then has to merge (a measurable compile-time cost on large workspaces).
    # NOTE(review): the caches are not keyed on the context, so they always
    # refer to the context passed to the latest N-coefficient likelihood() call.
    _cached_unnorm_expr: TensorVar | None = PrivateAttr(default=None)
    _cached_nu_expr: TensorVar | None = PrivateAttr(default=None)

    @model_serializer(mode="wrap")
    def serialize_model(self, handler: Callable[[Any], Any]) -> Any:
        """
        Do not serialize ref_coef_norm if it is unspecified (None).

        Args:
            handler: The default serializer supplied by pydantic.

        Returns:
            The serialized data without a ``ref_coef_norm`` entry when the
            field is None.
        """
        data = handler(self)
        if self.ref_coef_norm is None:
            # pop() rather than del: the key is already absent when the caller
            # excluded it (e.g. exclude={"ref_coef_norm"} or exclude_none=True),
            # and that must not raise KeyError.
            data.pop("ref_coef_norm", None)
        return data

    @field_validator("ref_coef_norm", mode="before")
    @classmethod
    def split_comma_separated_ref_coef_norm(cls, v: object) -> object:
        """
        Convert comma-separated string to list for ref_coef_norm.

        Whitespace around each name is stripped and empty entries are dropped,
        so ``"a, b"`` becomes ``["a", "b"]``. A string containing no names at
        all (``""``, ``" "``, ``","``) is treated as unspecified and becomes
        None. Non-string input is returned unchanged.
        """
        if isinstance(v, str):
            names = [name.strip() for name in v.split(",")]
            names = [name for name in names if name]
            return names or None
        return v

    @field_validator("ref_coef_norm", mode="after")
    @classmethod
    def validate_ref_coef_norm_usage(
        cls, ref_coef_norm: list[str] | None, info: ValidationInfo
    ) -> list[str] | None:
        """
        Validate that ref_coef_norm is only used with N=N coefficient case.

        Raises:
            ValueError: If ref_coef_norm is given but the number of
                coefficients differs from the number of summands.
        """
        if ref_coef_norm is not None:
            # Get summands and coefficients from the values being validated
            summands = info.data.get("summands", [])
            coefficients = info.data.get("coefficients", [])

            n_coeffs = len(coefficients)
            n_summands = len(summands)

            if n_coeffs != n_summands:
                msg = (
                    f"ref_coef_norm can only be used with N coefficients and N summands "
                    f"(N={n_summands}), but got {n_coeffs} coefficients."
                )
                raise ValueError(msg)

        return ref_coef_norm

    @field_serializer("ref_coef_norm")
    def serialize_ref_coef_norm(self, ref_coef_norm: list[str] | None) -> str | None:
        """Convert list back to comma-separated string for serialization."""
        if ref_coef_norm is None:
            return None
        return ",".join(ref_coef_norm)

    @field_validator("coefficients", mode="after")
    @classmethod
    def validate_coefficient_count(
        cls, coefficients: list[str], info: ValidationInfo
    ) -> list[str]:
        """
        Validate that coefficient count matches summand count appropriately.

        Raises:
            ValueError: If the number of coefficients is neither N nor N-1
                for N summands.
        """
        # Get summands from the values being validated
        summands = info.data.get("summands", [])

        n_coeffs = len(coefficients)
        n_summands = len(summands)

        if n_coeffs not in (n_summands, n_summands - 1):
            msg = (
                f"Invalid coefficient configuration: {n_coeffs} coefficients "
                f"for {n_summands} summands. Must have N ({n_summands}) or "
                f"N-1 ({n_summands - 1}) coefficients."
            )
            raise ValueError(msg)

        return coefficients

    @field_validator("extended", mode="after")
    @classmethod
    def validate_extended_matches_coefficients(
        cls, extended: bool, info: ValidationInfo
    ) -> bool:
        """
        Validate that extended matches coefficient configuration.

        Raises:
            ValueError: If extended is False with N coefficients, or True
                with fewer coefficients than summands.
        """
        # Get summands and coefficients from the values being validated
        summands = info.data.get("summands", [])
        coefficients = info.data.get("coefficients", [])

        n_coeffs = len(coefficients)
        n_summands = len(summands)

        # Validate extended matches coefficient configuration
        if n_coeffs == n_summands:
            if not extended:
                msg = (
                    f"extended must be True when N coefficients = N summands "
                    f"({n_coeffs} coefficients, {n_summands} summands)."
                )
                raise ValueError(msg)
        elif extended:
            msg = (
                f"extended must be False when N-1 coefficients with N summands "
                f"({n_coeffs} coefficients, {n_summands} summands)."
            )
            raise ValueError(msg)

        return extended

    def likelihood(self, context: Context) -> TensorVar:
        """
        Builds a symbolic expression for the mixture distribution.

        Handles both N and N-1 coefficient cases:

        - N-1 coefficients: Traditional approach with automatic normalization
        - N coefficients: Direct summation with optional custom normalization

        In the N-coefficient case this also stores the unnormalized mixture
        and the normalization sum on the instance so that
        :meth:`unnormalized_expression`, :meth:`expected_yield` and
        :meth:`log_prob_terms` can reuse the same nodes.

        Args:
            context (dict): Mapping of names to pytensor variables.

        Returns:
            pytensor.tensor.variable.TensorVariable: Symbolic representation of the mixture PDF.
        """
        n_coeffs = len(self.coefficients)
        n_summands = len(self.summands)

        if n_coeffs == n_summands:
            # N coefficients case: direct summation with normalization
            mixturesum = pt.constant(0.0)

            # Calculate the mixture sum (unnormalized: Σ cᵢ fᵢ)
            for coeff, summand in zip(self.coefficients, self.summands):
                mixturesum += context[coeff] * context[summand]

            # Cache the unnormalized Σcᵢfᵢ so that model.log_prob can build
            # the extended log-likelihood as Σⱼ log(Σcᵢfᵢ(xⱼ)) - nu without
            # introducing a separate pt.log(nu) node that triggers costly
            # optimizer rewrites when combined with log(Σcᵢfᵢ/nu).
            self._cached_unnorm_expr = mixturesum

            # Handle normalization
            if self.ref_coef_norm is not None:
                # Custom normalization using specified coefficients
                norm_sum = pt.constant(0.0)
                for norm_coeff in self.ref_coef_norm:
                    norm_sum += context[norm_coeff]
                # Cache the normalisation sum so expected_yield() can reuse
                # the same PyTensor node rather than building a duplicate
                # summation subgraph that the compiler then has to merge.
                self._cached_nu_expr = norm_sum
                mixturesum = mixturesum / norm_sum
            else:
                # Standard normalization: divide by sum of all coefficients
                coeffsum = pt.constant(0.0)
                for coeff in self.coefficients:
                    coeffsum += context[coeff]
                # Cache so expected_yield() reuses this node.
                self._cached_nu_expr = coeffsum
                mixturesum = mixturesum / coeffsum

        else:
            # N-1 coefficients case: traditional approach with automatic last term
            mixturesum = pt.constant(0.0)
            coeffsum = pt.constant(0.0)

            # Sum the first N-1 terms; zip() stops at the shorter coefficient
            # list, leaving the last summand for the remainder term below.
            for coeff, summand in zip(self.coefficients, self.summands):
                coeffsum += context[coeff]
                mixturesum += context[coeff] * context[summand]

            # Add the last term with remaining coefficient
            last_index = len(self.summands) - 1
            f_last = context[self.summands[last_index]]
            mixturesum += (1 - coeffsum) * f_last

        return cast(TensorVar, mixturesum)

    def expected_yield(self, context: Context) -> TensorVar:
        """
        Compute the total expected yield nu in the extended case.

        - N coefficients case: nu = sum(coefficients) or sum(ref_coef_norm) if specified
        - N-1 coefficients case: not defined (extended=False always)

        Args:
            context: Mapping of names to pytensor variables

        Returns:
            Expected yield (nu) for extended likelihood

        Raises:
            RuntimeError: If called on non-extended PDF
        """
        if not self.extended:
            msg = "expected_yield only valid for extended PDFs"
            raise RuntimeError(msg)

        # Reuse the normalisation sum cached by likelihood() to avoid building
        # a duplicate PyTensor summation subgraph (which would slow compilation).
        if self._cached_nu_expr is not None:
            return self._cached_nu_expr

        # Fallback: likelihood() hasn't been called yet, compute fresh.
        nu = pt.constant(0.0)

        if self.ref_coef_norm is not None:
            # Use only the coefficients specified in ref_coef_norm
            for coeff in self.ref_coef_norm:
                nu += context[coeff]
        else:
            # Use all coefficients
            for coeff in self.coefficients:
                nu += context[coeff]

        return cast(TensorVar, nu)

    def unnormalized_expression(self, context: Context) -> TensorVar:
        """Return the unnormalized mixture Σcᵢfᵢ (before dividing by Σcᵢ).

        After :meth:`likelihood` has been called this just returns the cached
        intermediate result; it is used by :meth:`pyhs3.model.Model.log_prob`
        to build the extended log-likelihood as

            Σⱼ log(Σcᵢfᵢ(xⱼ)) - nu

        which is algebraically equivalent to

            Σⱼ log(Σcᵢfᵢ(xⱼ)/nu) + N·log(nu) - nu

        but avoids introducing a separate ``pt.log(nu)`` node that can trigger
        expensive rewrite cascades in the PyTensor graph optimiser.

        The Poisson yield term enters the likelihood once per channel and
        involves the observed event count, so it is assembled by
        :meth:`pyhs3.model.Model.log_prob` (which owns the channel-dataset
        pairing) rather than via :meth:`extended_likelihood` (whose result is
        multiplied into the per-event density and would be overcounted when
        summing over events).

        Args:
            context: Mapping of names to pytensor variables; only read when
                no cached expression is available.

        Returns:
            The cached or freshly built unnormalized mixture expression.
        """
        if self._cached_unnorm_expr is not None:
            return self._cached_unnorm_expr
        # Fallback: recompute (likelihood() hasn't been called yet).
        unnorm = pt.constant(0.0)
        for coeff, summand in zip(self.coefficients, self.summands):
            unnorm += context[coeff] * context[summand]
        return cast(TensorVar, unnorm)

    def log_prob_terms(  # pylint: disable=redefined-outer-name
        self,
        expressions: Mapping[str, TensorVar],
        distributions: Distributions,
    ) -> LogProbTerms:
        """
        Extended-likelihood contributions for the mixture.

        For an extended mixture the per-channel log-likelihood is built from the
        unnormalized mixture and the expected yield:

            Σⱼ log(Σcᵢfᵢ(xⱼ)) - nu

        which simplifies from Σⱼ log(Σcᵢfᵢ(xⱼ)/nu) + N·log(nu) - nu (the log(nu)
        terms cancel). This is algebraically identical to the full extended
        likelihood but avoids introducing pt.log(nu) into the graph, preventing
        costly optimizer rewrite cascades triggered by log(a/b) + N·log(b).

        For weighted data this form also reproduces RooFit's sum-of-weights
        convention: weighting the log(Σcᵢfᵢ) term gives
        Σⱼwⱼ·log(PDF) + (Σwⱼ)·log(nu) - nu. The data-only -log(N!) constant is
        omitted and N_eff = Σwⱼ is used for weighted data; both are RooFit
        conventions adopted because HS3 does not specify the extended term
        exactly. See
        https://github.com/hep-statistics-serialization-standard/hep-statistics-serialization-standard/issues/91
        and https://github.com/scipp-atlas/pyhs3/issues/241.

        Non-extended mixtures contribute the default per-event log(PDF).

        Args:
            expressions: Mapping of names to built expressions; forwarded to
                the base implementation for non-extended mixtures.
            distributions: Collection of distributions; forwarded likewise.

        Returns:
            The per-event and per-channel log-probability terms.

        Raises:
            RuntimeError: If the mixture is extended but likelihood() has not
                populated the cached unnormalized and yield expressions.
        """
        if not self.extended:
            return super().log_prob_terms(expressions, distributions)
        if self._cached_unnorm_expr is None or self._cached_nu_expr is None:
            msg = (
                f"log_prob_terms for extended mixture '{self.name}' requires "
                f"likelihood() to have been called (the model graph build does "
                f"this) so the unnormalized mixture and yield nodes exist."
            )
            raise RuntimeError(msg)
        return LogProbTerms(
            per_event=[pt.log(self._cached_unnorm_expr)],
            channel=[-self._cached_nu_expr],
        )
class ProductDist(Distribution):
    r"""
    Product of probability distributions.

    Mirrors ROOT's RooProdPdf: the density is the product of the component
    densities,

    .. math::

        f(x, \ldots) = \prod_{i=1}^{N} \text{PDF}_i(x, \ldots)

    where the component PDFs may share observables.

    Parameters:
        factors: List of component distribution names to multiply together

    Note:
        On pytensor variables/tensors the product is taken elementwise over
        all factor distributions.
    """

    type: Literal["product_dist"] = "product_dist"
    factors: list[str]

    def likelihood(self, context: Context) -> TensorVar:
        """
        Build the symbolic product of all factor expressions.

        Args:
            context: Mapping of names to pytensor variables

        Returns:
            Symbolic representation of the product PDF; the constant 1.0 when
            there are no factors.
        """
        if self.factors:
            factor_values = [context[name] for name in self.factors]
            return cast(TensorVar, pt.mul(*factor_values))
        # An empty product is the multiplicative identity.
        return cast(TensorVar, pt.constant(1.0))

    def log_prob_terms(  # pylint: disable=redefined-outer-name
        self,
        expressions: Mapping[str, TensorVar],
        distributions: Distributions,
    ) -> LogProbTerms:
        """
        Sort the factors into per-event shape terms and per-channel constraints.

        A factor counts as observable-dependent when its graph has a named
        free input of rank >= 1 (observable arrays are built as pt.vector,
        NP scalars as pt.scalar). Such a factor delegates to its own
        log_prob_terms, so e.g. an extended mixture nested in a product still
        contributes its yield term correctly. A factor that depends only on
        scalar nuisance parameters is a constraint PDF; its log is returned
        as a named constraint for the model to add once globally.

        Without this split, summing log(shape x Π_j constr_j) over N events
        would count every constr_j term N times instead of once.

        The shape-vs-constraint classification is structural inference
        (matching RooProdPdf), not HS3 spec semantics — nothing in an HS3
        file marks a factor as a constraint. If the spec adopts explicit
        constraint encoding, this inference can be removed; see
        https://github.com/hep-statistics-serialization-standard/hep-statistics-serialization-standard/issues/90
        and https://github.com/scipp-atlas/pyhs3/issues/240.

        Args:
            expressions: Mapping of distribution names to built expressions.
            distributions: Collection used to look up each shape factor by name.

        Returns:
            The accumulated per-event, per-channel and constraint terms.
        """
        collected = LogProbTerms()

        for name in self.factors:
            expr = expressions[name]

            # Observable-dependent factors have at least one named vector
            # (rank >= 1) input among the graph's explicit inputs.
            depends_on_observable = any(
                var.name is not None and var.type.ndim >= 1
                for var in explicit_graph_inputs([expr])
            )

            if not depends_on_observable:
                # Pure constraint: record its log once under the factor name.
                collected.constraints[name] = cast(TensorVar, pt.log(expr))
                continue

            nested = distributions[name].log_prob_terms(expressions, distributions)
            collected.per_event.extend(nested.per_event)
            collected.channel.extend(nested.channel)
            collected.constraints.update(nested.constraints)

        return collected
# Public names exported by this module
__all__ = [
    "MixtureDist",
    "ProductDist",
    "distributions",
]

# Lookup table from each composite distribution's ``type`` tag to its class
distributions: dict[str, type[Distribution]] = {
    "mixture_dist": MixtureDist,
    "product_dist": ProductDist,
}