Source code for pyhs3.distributions.composite

"""
Composite distribution implementations.

Provides classes for handling composite probability distributions that combine
multiple other distributions, including mixtures and products of distributions.
"""

from __future__ import annotations

from collections.abc import Callable
from typing import Any, Literal, cast

import pytensor.tensor as pt
from pydantic import (
    Field,
    ValidationInfo,
    field_serializer,
    field_validator,
    model_serializer,
)

from pyhs3.context import Context
from pyhs3.distributions.core import Distribution
from pyhs3.typing.aliases import TensorVar


class MixtureDist(Distribution):
    r"""
    Mixture of probability distributions.

    Implements a weighted combination of multiple distributions following
    ROOT's RooAddPdf. Supports both N and N-1 coefficient configurations,
    where :math:`N` is the number of distributions (`summands`):

    **N-1 coefficients:**

    .. math::

        f(x) = \sum_{i=1}^{n-1} c_i \cdot f_i(x) + (1 - \sum_{i=1}^{n-1} c_i) \cdot f_n(x)

    **N coefficients:**

    .. math::

        f(x) = \frac{\sum_{i=1}^{n} c_i \cdot f_i(x)}{\sum_{i=1}^{n} c_i}

    **N coefficients with `ref_coef_norm`:**

    .. math::

        f(x) = \frac{\sum_{i=1}^{n} c_i \cdot f_i(x)}{\sum_{j \in \text{ref\_coef\_norm}} c_j}

    Parameters:
        coefficients (list[str]): Names of coefficient parameters.
        summands (list[str]): Names of component distributions.
        extended (bool): Whether the mixture is extended (affects
            normalization). Must be True for N coefficients, False for N-1
            coefficients.
        ref_coef_norm (list[str] | None): Optional list of coefficient names
            for custom normalization. Only valid when using N coefficients
            (extended=True). Accepted as a comma-separated string on input
            and serialized back to one.

    ROOT Reference:
        :rootref:`RooAddPdf <classRooAddPdf.html>`
    """

    type: Literal["mixture_dist"] = "mixture_dist"
    summands: list[str]
    coefficients: list[str]
    extended: bool = False
    ref_coef_norm: list[str] | None = Field(
        default=None, json_schema_extra={"preprocess": False}
    )

    @model_serializer(mode="wrap")
    def serialize_model(self, handler: Callable[[Any], Any]) -> Any:
        """Do not serialize ref_coef_norm if it is unspecified (None).

        The key can already be missing from the handler's output, for example
        when the caller dumps with ``exclude``, ``include``, ``exclude_none``
        or ``exclude_defaults``. It is therefore removed with ``pop`` rather
        than ``del``, which would raise ``KeyError`` in those cases.
        """
        data = handler(self)
        if self.ref_coef_norm is None and isinstance(data, dict):
            data.pop("ref_coef_norm", None)
        return data

    @field_validator("ref_coef_norm", mode="before")
    @classmethod
    def split_comma_separated_ref_coef_norm(cls, v: object) -> object:
        """Convert comma-separated string to list for ref_coef_norm.

        An empty or whitespace-only string means "unspecified" and becomes
        ``None``. Non-string values are passed through unchanged.
        """
        if isinstance(v, str):
            v = v.strip()
            return None if v == "" else v.split(",")
        return v

    @field_validator("ref_coef_norm", mode="after")
    @classmethod
    def validate_ref_coef_norm_usage(
        cls, ref_coef_norm: list[str] | None, info: ValidationInfo
    ) -> list[str] | None:
        """Validate that ref_coef_norm is only used with N=N coefficient case.

        Raises:
            ValueError: If ref_coef_norm is given but the number of
                coefficients differs from the number of summands.
        """
        if ref_coef_norm is None:
            return ref_coef_norm
        # info.data only holds earlier fields that validated successfully.
        # If summands/coefficients failed, their own errors are already
        # reported, so do not add a misleading cross-field error here.
        if "summands" not in info.data or "coefficients" not in info.data:
            return ref_coef_norm
        n_summands = len(info.data["summands"])
        n_coeffs = len(info.data["coefficients"])
        if n_coeffs != n_summands:
            msg = (
                f"ref_coef_norm can only be used with N coefficients and N summands "
                f"(N={n_summands}), but got {n_coeffs} coefficients."
            )
            raise ValueError(msg)
        return ref_coef_norm

    @field_serializer("ref_coef_norm")
    def serialize_ref_coef_norm(self, ref_coef_norm: list[str] | None) -> str | None:
        """Convert list back to comma-separated string for serialization."""
        if ref_coef_norm is None:
            return None
        return ",".join(ref_coef_norm)

    @field_validator("coefficients", mode="after")
    @classmethod
    def validate_coefficient_count(
        cls, coefficients: list[str], info: ValidationInfo
    ) -> list[str]:
        """Validate that coefficient count matches summand count appropriately.

        Raises:
            ValueError: If there are neither N nor N-1 coefficients for N
                summands.
        """
        # Skip when summands itself failed validation; its error is reported
        # separately and comparing against an empty list would be misleading.
        if "summands" not in info.data:
            return coefficients
        n_summands = len(info.data["summands"])
        n_coeffs = len(coefficients)
        if n_coeffs not in (n_summands, n_summands - 1):
            msg = (
                f"Invalid coefficient configuration: {n_coeffs} coefficients "
                f"for {n_summands} summands. Must have N ({n_summands}) or "
                f"N-1 ({n_summands - 1}) coefficients."
            )
            raise ValueError(msg)
        return coefficients

    @field_validator("extended", mode="after")
    @classmethod
    def validate_extended_matches_coefficients(
        cls, extended: bool, info: ValidationInfo
    ) -> bool:
        """Validate that extended matches coefficient configuration.

        NOTE(review): pydantic does not run field validators on defaulted
        values unless ``validate_default`` is enabled (not visible in this
        module), so omitting ``extended`` may bypass this check. Confirm
        against the base class config.

        Raises:
            ValueError: If extended is False with N coefficients, or True with
                N-1 coefficients.
        """
        # Skip when an earlier field failed; its own error is already reported.
        if "summands" not in info.data or "coefficients" not in info.data:
            return extended
        n_summands = len(info.data["summands"])
        n_coeffs = len(info.data["coefficients"])

        if n_coeffs == n_summands and not extended:
            msg = (
                f"extended must be True when N coefficients = N summands "
                f"({n_coeffs} coefficients, {n_summands} summands)."
            )
            raise ValueError(msg)
        if n_coeffs == n_summands - 1 and extended:
            msg = (
                f"extended must be False when N-1 coefficients with N summands "
                f"({n_coeffs} coefficients, {n_summands} summands)."
            )
            raise ValueError(msg)
        return extended

    def expression(self, context: Context) -> TensorVar:
        """
        Builds a symbolic expression for the mixture distribution.

        Handles both N and N-1 coefficient cases:

        - N-1 coefficients: the last summand receives weight
          ``1 - sum(coefficients)``.
        - N coefficients: weighted sum divided by the sum of all coefficients,
          or by the sum of the ``ref_coef_norm`` coefficients when given.

        Args:
            context (Context): Mapping of names to pytensor variables.

        Returns:
            pytensor.tensor.variable.TensorVariable: Symbolic representation
            of the mixture PDF.
        """
        mixturesum = pt.constant(0.0)

        if len(self.coefficients) == len(self.summands):
            # N coefficients: direct weighted sum, then normalize.
            for coeff, summand in zip(self.coefficients, self.summands):
                mixturesum += context[coeff] * context[summand]

            norm_names = (
                self.ref_coef_norm
                if self.ref_coef_norm is not None
                else self.coefficients
            )
            norm_sum = pt.constant(0.0)
            for name in norm_names:
                norm_sum += context[name]
            mixturesum = mixturesum / norm_sum
        else:
            # N-1 coefficients: zip stops after the last coefficient, so the
            # final summand is handled separately with the remaining weight.
            coeffsum = pt.constant(0.0)
            for coeff, summand in zip(self.coefficients, self.summands):
                coeffsum += context[coeff]
                mixturesum += context[coeff] * context[summand]
            mixturesum += (1 - coeffsum) * context[self.summands[-1]]

        return cast(TensorVar, mixturesum)

    def expected_yield(self, context: Context) -> TensorVar:
        """
        Compute the total expected yield nu in the extended case.

        - N coefficients case: nu = sum(coefficients), or sum(ref_coef_norm)
          if specified.
        - N-1 coefficients case: not defined (extended=False always).

        Args:
            context: Mapping of names to pytensor variables.

        Returns:
            Expected yield (nu) for extended likelihood.

        Raises:
            RuntimeError: If called on non-extended PDF.
        """
        if not self.extended:
            msg = "expected_yield only valid for extended PDFs"
            raise RuntimeError(msg)

        names = (
            self.ref_coef_norm if self.ref_coef_norm is not None else self.coefficients
        )
        nu = pt.constant(0.0)
        for name in names:
            nu += context[name]
        return cast(TensorVar, nu)

    def extended_likelihood(self, context: Context, n_observed: int) -> TensorVar:
        """
        Poisson term for the extended likelihood.

        Computes: log Pois(N_obs | nu) = N_obs * log(nu) - nu
        (the constant -log(N_obs!) term is dropped).

        Args:
            context: Mapping of names to pytensor variables.
            n_observed: Number of observed events.

        Returns:
            Log Poisson probability for extended likelihood.

        Raises:
            RuntimeError: If called on non-extended PDF.
        """
        if not self.extended:
            msg = "extended_likelihood only valid when extended=True"
            raise RuntimeError(msg)

        nu = self.expected_yield(context)
        n_obs = pt.constant(n_observed, dtype="float64")
        log_pois = n_obs * pt.log(nu) - nu
        return cast(TensorVar, log_pois)
class ProductDist(Distribution):
    r"""
    Product distribution implementation.

    Implements a product of PDFs as defined in ROOT's RooProdPdf.
    The probability density function is defined as:

    .. math::

        f(x, \ldots) = \prod_{i=1}^{N} \text{PDF}_i(x, \ldots)

    where each PDF_i is a component distribution that may share observables.

    Parameters:
        factors: List of component distribution names to multiply together.

    Note:
        In the context of pytensor variables/tensors, this is implemented as
        an elementwise product of all factor distributions. Factors are
        multiplied pairwise, so standard broadcasting applies and factors do
        not need identical shapes. An empty factor list yields the constant
        1.0.
    """

    type: Literal["product_dist"] = "product_dist"
    factors: list[str]

    def expression(self, context: Context) -> TensorVar:
        """
        Evaluate the product distribution.

        Args:
            context: Mapping of names to pytensor variables.

        Returns:
            Symbolic representation of the product PDF.
        """
        # Pairwise multiplication broadcasts, whereas stacking requires every
        # factor to have exactly the same shape. Starting from 1.0 also covers
        # the empty-factor case.
        product = pt.constant(1.0)
        for name in self.factors:
            product = product * context[name]
        return cast(TensorVar, product)
# Lookup table from each composite distribution's ``type`` tag to its model class.
distributions: dict[str, type[Distribution]] = {
    "mixture_dist": MixtureDist,
    "product_dist": ProductDist,
}

# Public API of this module.
__all__ = ["MixtureDist", "ProductDist", "distributions"]