"""
HS3 Functions implementation.
Provides classes for handling HS3 functions including product functions,
generic functions with mathematical expressions, and interpolation functions.
"""
from __future__ import annotations
import logging
from enum import IntEnum
from typing import Annotated, Literal, cast
import pytensor.tensor as pt
from pydantic import (
ConfigDict,
Field,
)
from pyhs3.context import Context
from pyhs3.distributions.histfactory.interpolations import (
interpolate_code0,
interpolate_code1,
interpolate_code4,
interpolate_parabolic,
interpolate_poly6,
)
from pyhs3.distributions.histogram import HistogramData
from pyhs3.exceptions import custom_error_msg
from pyhs3.functions.core import Function
from pyhs3.generic_parse import GenericExpressionMixin
from pyhs3.typing.aliases import TensorVar
# Module-level logger, named after this module via ``__name__``.
log = logging.getLogger(__name__)
def _asym_interpolation(
    theta: TensorVar, kappa_sum: float, kappa_diff: float
) -> TensorVar:
    r"""
    Compute the asymmetric shift of one nuisance parameter for ProcessNormalization.

    Follows the jaxfit implementation
    https://github.com/nsmith-/jaxfit/blob/8479cd73e733ba35462287753fab44c0c560037b/src/jaxfit/roofit/combine.py#L197
    and the ``ProcessNormalization::logKappaForX`` logic of CMS Combine.

    The returned shift is

    .. math::

        \text{shift} = \tfrac{1}{2}\left(
            \kappa_\text{diff}\,\theta + \kappa_\text{sum}\,w(\theta)
        \right)

    where the weight :math:`w(\theta) = \theta\cdot\text{smoothStep}(\theta)` is

    .. math::

        w(\theta) = \begin{cases}
            \theta^2\,(3\theta^4 - 10\theta^2 + 15)/8 & |\theta| < 1 \\
            |\theta| & |\theta| \geq 1
        \end{cases}

    :math:`w` is a smooth stand-in for :math:`|\theta|`: it is ``0`` at
    :math:`\theta = 0`, equals ``1`` at :math:`|\theta| = 1`, and its first
    derivative agrees with that of :math:`|\theta|` at the boundary.

    Args:
        theta: The nuisance parameter value.
        kappa_sum: ``logKappaHi + logKappaLo``.
        kappa_diff: ``logKappaHi - logKappaLo``.

    Returns:
        The interpolated shift value.
    """
    magnitude = pt.abs(theta)
    squared = theta * theta

    # Polynomial branch, valid inside the unit interval.
    inside_unit = squared * (3.0 * (squared * squared) - 10.0 * squared + 15.0) / 8.0

    # Outside the unit interval the weight is simply |theta|.
    weight = cast(TensorVar, pt.switch(magnitude < 1.0, inside_unit, magnitude))

    return cast(TensorVar, 0.5 * (kappa_diff * theta + kappa_sum * weight))
[docs]
class SumFunction(Function):
    """Function whose value is the sum of its named summands.

    HS3 Reference:
        :ref:`hs3:hs3.sum`
    """

    type: Literal["sum"] = Field(default="sum", repr=False)
    summands: list[str] = Field(..., repr=False)

    def _expression(self, context: Context) -> TensorVar:
        """
        Build the PyTensor expression for the sum.

        Args:
            context: Mapping of names to PyTensor variables.

        Returns:
            TensorVar: The summands looked up in ``context`` and added left to
            right, or the constant ``0.0`` when ``summands`` is empty.
        """
        if self.summands:
            first_name, *other_names = self.summands
            # Seeding ``sum`` with the first term keeps the additions in the
            # order first + second + ... without a leading integer zero.
            return sum(
                (context[name] for name in other_names),
                context[first_name],
            )
        return pt.constant(0.0)
[docs]
class ProductFunction(Function):
    """Function whose value is the product of its factors.

    HS3 Reference:
        :ref:`hs3:hs3.product`
    """

    type: Literal["product"] = Field(default="product", repr=False)
    factors: list[int | float | str] = Field(..., repr=False)

    def _expression(self, context: Context) -> TensorVar:
        """
        Build the PyTensor expression for the product.

        Args:
            context: Mapping of names to PyTensor variables.

        Returns:
            TensorVar: The factors multiplied left to right, or the constant
            ``1.0`` when ``factors`` is empty.
        """
        if self.factors:
            # Resolve every factor through the flattened parameter keys.
            terms = self.get_parameter_list(context, "factors")
            product = terms[0]
            for term in terms[1:]:
                product = product * term
            return product
        return pt.constant(1.0)
[docs]
class GenericFunction(GenericExpressionMixin, Function):
    """
    Generic function defined by a custom mathematical expression.

    Evaluates arbitrary mathematical expressions using SymPy parsing
    and PyTensor computation. Supports common mathematical operations
    including arithmetic, trigonometric, exponential, and logarithmic functions.

    The expression is parsed once during initialization and converted to
    a PyTensor computation graph for efficient evaluation.

    Parameters:
        name (str): Name of the function.
        expression (str): Mathematical expression string to evaluate.

    Examples:
        >>> func = GenericFunction(name="quadratic", expression="x**2 + 2*x + 1")
        >>> func = GenericFunction(name="sinusoid", expression="sin(x) * exp(-t)")

    HS3 Reference:
        :hs3:label:`generic_function <hs3.generic-function>`
    """

    type: Literal["generic_function"] = Field(default="generic_function", repr=False)

    def _expression(self, context: Context) -> TensorVar:
        """
        Evaluate the generic function expression.

        The work is delegated entirely to ``self._eval_expression``.

        Args:
            context: Mapping of names to PyTensor variables.

        Returns:
            TensorVar: PyTensor expression representing the parsed mathematical expression.
        """
        return self._eval_expression(context)
class InterpolationCode(IntEnum):
    """
    Integer codes selecting how a systematic variation is interpolated.

    ``InterpolationFunction`` accepts one code per nuisance parameter. The
    code decides how to move between the nominal, low, and high values and
    whether the result is combined additively or multiplicatively. The
    per-member comments describe how ``InterpolationFunction`` treats each code.
    """

    # Additive: piecewise-linear interpolation and extrapolation.
    LIN_LIN_ADD = 0
    # Multiplicative: piecewise-exponential interpolation and extrapolation.
    EXP_EXP_MUL = 1
    # Additive: parabolic interpolation with linear extrapolation.
    EXP_LIN_ADD = 2
    # Additive: handled exactly like code 2.
    EXP_MIX_ADD = 3
    # Additive: 6th-degree polynomial interpolation with linear extrapolation.
    POL_LIN_ADD = 4
    # Multiplicative: 6th-degree polynomial in log space, exponential extrapolation.
    POL_EXP_MUL = 5
    # Multiplicative: 6th-degree polynomial with linear extrapolation in ratio space.
    POL_LIN_MUL = 6
[docs]
class InterpolationFunction(Function):
    r"""
    Piecewise interpolation between a nominal value and its variations.

    Mirrors ROOT's PiecewiseInterpolation: the nominal value is morphed
    towards the low/high variations according to the nuisance parameter
    values, using one interpolation code (0-6) per parameter.

    HS3 Reference:
        Note: Interpolation functions are not explicitly defined in the current HS3 specification.

    Mathematical Formulations:
        For **additive** interpolation modes (codes 0, 2, 3, 4):

        .. math::

            \text{result} = \text{nominal} + \sum_i I_i(\theta_i; \text{low}_i, \text{nominal}, \text{high}_i)

        For **multiplicative** interpolation modes (codes 1, 5, 6):

        .. math::

            \text{result} = \text{nominal} \times \prod_i [1 + I_i(\theta_i; \text{low}_i/\text{nominal}, 1, \text{high}_i/\text{nominal})]

    Parameters:
        name: Name of the function
        high: High variation parameter names
        low: Low variation parameter names
        nom: Nominal parameter name
        interpolationCodes: Interpolation method codes (0-6)
        positiveDefinite: Whether function should be positive definite
        vars: Variable names this function depends on (nuisance parameters)
    """

    model_config = ConfigDict(use_enum_values=True)

    type: Literal["interpolation"] = Field(default="interpolation", repr=False)
    high: list[str] = Field(..., repr=False)
    low: list[str] = Field(..., repr=False)
    nom: str = Field(..., repr=False)
    interpolationCodes: Annotated[
        list[InterpolationCode],
        custom_error_msg(
            {
                "enum": "Unknown interpolation code {input} in function '{name}'. Valid codes are {expected}."
            }
        ),
    ] = Field(..., repr=False)
    positiveDefinite: bool = Field(..., repr=False)
    vars: list[str] = Field(..., repr=False)

    def _flexible_interp_single(
        self,
        interp_code: int,
        low_val: TensorVar,
        high_val: TensorVar,
        boundary: float,
        nominal: TensorVar,
        param_val: TensorVar,
    ) -> TensorVar:
        r"""
        Compute the contribution :math:`I_i(\theta_i)` of one nuisance parameter.

        Modelled on ROOT's ``FlexibleInterpVar`` /
        ``RooFit::Detail::MathFuncs::flexibleInterpSingle`` for codes 0-6.
        The helpers from :mod:`pyhs3.distributions.histfactory.interpolations`
        return the full ``nominal + delta`` value, so the baseline is removed
        here to obtain either an additive delta or a multiplicative factor
        minus one:

        - **Code 0** (additive): ``interpolate_code0``, piecewise-linear.
        - **Code 1** (multiplicative): ``interpolate_code1``, piecewise-exponential.
        - **Codes 2, 3** (additive): ``interpolate_parabolic``, parabolic
          interpolation with linear extrapolation (code 3 is treated as code 2).
        - **Code 4** (additive): ``interpolate_poly6``, 6th-degree polynomial
          interpolation with linear extrapolation.
        - **Code 5** (multiplicative): ``interpolate_code4``, 6th-degree
          polynomial in log space with exponential extrapolation.
        - **Code 6** (multiplicative): ``interpolate_poly6`` applied to
          ``high/nominal`` and ``low/nominal`` with a baseline of one.

        Any code not listed for 0-5 is handled by the code 6 branch.

        Args:
            interp_code: Interpolation code (0-6) determining the mathematical approach
            low_val: Low variation value (used when :math:`\theta < 0`)
            high_val: High variation value (used when :math:`\theta \geq 0`)
            boundary: Accepted for interface compatibility but ignored; the
                helpers assume a boundary of 1.0
            nominal: Nominal value (baseline)
            param_val: Parameter value :math:`\theta` (nuisance parameter)

        Returns:
            For codes 0, 2, 3, 4 the additive contribution; for codes 1, 5, 6
            the multiplicative factor with 1 already subtracted.
        """
        # The interpolation helpers assume a boundary of 1.0, so the argument
        # is intentionally discarded.
        del boundary

        # Helpers whose result is turned into an additive delta.
        additive_helpers = {
            0: interpolate_code0,
            2: interpolate_parabolic,
            3: interpolate_parabolic,
            4: interpolate_poly6,
        }
        # Helpers whose result is turned into a relative (multiplicative) change.
        multiplicative_helpers = {
            1: interpolate_code1,
            5: interpolate_code4,
        }

        if interp_code in additive_helpers:
            morphed = additive_helpers[interp_code](
                param_val, nominal, high_val, low_val
            )
            return cast(TensorVar, morphed - nominal)

        if interp_code in multiplicative_helpers:
            morphed = multiplicative_helpers[interp_code](
                param_val, nominal, high_val, low_val
            )
            return cast(TensorVar, morphed / nominal - 1.0)

        # Code 6: polynomial + linear extrapolation evaluated in ratio space,
        # i.e. with the variations divided by nominal and a baseline of one.
        unit = pt.constant(1.0)
        morphed_ratio = interpolate_poly6(
            param_val, unit, high_val / nominal, low_val / nominal
        )
        return cast(TensorVar, morphed_ratio - 1.0)

    def _expression(self, context: Context) -> TensorVar:
        r"""
        Build the PyTensor expression for the interpolated value.

        Steps:

        1. Start from the nominal value: :math:`\text{result} = \text{nominal}`
        2. For each nuisance parameter :math:`\theta_i` that has a matching
           entry in ``high``, ``low`` and ``interpolationCodes``, compute
           :math:`I_i(\theta_i)`; parameters without one are skipped with a
           logged warning.
        3. Fold the contribution into the running result:

           - **Additive modes** (codes 0,2,3,4): :math:`\text{result} += I_i(\theta_i)`
           - **Multiplicative modes** (codes 1,5,6): :math:`\text{result} \times= (1 + I_i(\theta_i))`

        4. If ``positiveDefinite`` is set, clamp: :math:`\text{result} = \max(\text{result}, 0)`

        Args:
            context: Mapping of names to pytensor variables containing:
                - Nominal parameter (referenced by `nom`)
                - High/low variation parameters (referenced by `high`/`low` lists)
                - Nuisance parameters (referenced by `vars` list)

        Returns:
            PyTensor expression representing the interpolated result
        """
        nominal = context[self.nom]
        result = nominal

        # Only indices present in all three variation lists can be interpolated.
        usable = min(len(self.high), len(self.low), len(self.interpolationCodes))

        for index, var_name in enumerate(self.vars):
            if index >= usable:
                log.warning(
                    "Parameter index %d exceeds variation lists for function %s",
                    index,
                    self.name,
                )
                continue

            theta = context[var_name]
            low_variation = context[self.low[index]]
            high_variation = context[self.high[index]]
            code = self.interpolationCodes[index]

            contribution = self._flexible_interp_single(
                interp_code=code,
                low_val=low_variation,
                high_val=high_variation,
                boundary=1.0,
                nominal=nominal,
                param_val=theta,
            )

            if code in (0, 2, 3, 4):
                # Additive modes
                result = result + contribution
            else:
                # Multiplicative modes (1, 5, 6)
                result = result * (1.0 + contribution)

        # Clamp at zero only when a positive-definite result was requested.
        return pt.maximum(result, 0.0) if self.positiveDefinite else result
[docs]
class ProcessNormalizationFunction(Function):
    r"""
    Process normalization function with systematic variations.

    Implements the CMS Combine ProcessNormalization class which computes
    a normalization factor based on systematic variations. This matches
    the actual CMS Combine implementation and JSON structure from combine files.

    Mathematical formulation:
        result = nominalValue * exp(symShift + asymShift) * otherFactors

    where:
        - symShift = sum(logKappa[i] * theta[i]) for symmetric variations
        - asymShift = sum(_asym_interpolation(theta[i], kappa_sum[i], kappa_diff[i]))
          for asymmetric variations with kappa_sum = logKappaHi + logKappaLo
          and kappa_diff = logKappaHi - logKappaLo
        - otherFactors = product of all additional multiplicative terms

    A nuisance parameter without a matching kappa entry has no effect: a
    missing ``logKappa[i]`` is treated as ``0.0``, and a missing
    ``logAsymmKappa[i]`` contributes no shift (a warning is logged).

    Parameters:
        name: Name of the function
        nominalValue: Baseline normalization value (default 1.0)
        thetaList: Names of symmetric variation nuisance parameters
        logKappa: Log-kappa values for symmetric variations (optional, defaults to empty)
        asymmThetaList: Names of asymmetric variation nuisance parameters
        logAsymmKappa: List of [logKappaLo, logKappaHi] pairs for asymmetric variations (optional)
        otherFactorList: Names of additional multiplicative factors
    """

    type: Literal["CMS::process_normalization"] = Field(
        default="CMS::process_normalization", repr=False
    )
    nominalValue: float = Field(
        default=1.0, json_schema_extra={"preprocess": False}, repr=False
    )
    thetaList: list[str] = Field(default_factory=list, repr=False)
    logKappa: list[float] = Field(
        default_factory=list, json_schema_extra={"preprocess": False}, repr=False
    )
    asymmThetaList: list[str] = Field(default_factory=list, repr=False)
    logAsymmKappa: list[list[float]] = Field(
        default_factory=list, json_schema_extra={"preprocess": False}, repr=False
    )
    otherFactorList: list[str] = Field(default_factory=list, repr=False)

    def _expression(self, context: Context) -> TensorVar:
        """
        Evaluate the process normalization function.

        Implements the full CMS Combine ProcessNormalization logic:
            result = nominalValue * exp(symShift + asymShift) * otherFactors

        Args:
            context: Mapping of names to PyTensor variables.

        Returns:
            TensorVar: PyTensor expression representing the normalization factor.

        Raises:
            ValueError: If an entry of ``logAsymmKappa`` that is actually used
                does not contain exactly two values.
        """
        # Start with nominal value
        result = pt.constant(self.nominalValue)

        # Symmetric variations: sym_shift = sum(logKappa[i] * theta[i])
        sym_shift = pt.constant(0.0)
        for i, theta_name in enumerate(self.thetaList):
            theta = context[theta_name]
            # Use provided logKappa value if available, otherwise assume 0.0 (no effect)
            log_kappa = self.logKappa[i] if i < len(self.logKappa) else 0.0
            sym_shift = sym_shift + log_kappa * theta

        # Asymmetric variations: use asymmetric interpolation
        asym_shift = pt.constant(0.0)
        for i, theta_name in enumerate(self.asymmThetaList):
            theta = context[theta_name]

            # Same rule as the symmetric branch: a parameter without kappa
            # values has no effect. Previously this raised a bare IndexError.
            if i >= len(self.logAsymmKappa):
                log.warning(
                    "No logAsymmKappa entry for parameter %s (index %d) in "
                    "function %s; assuming no effect",
                    theta_name,
                    i,
                    self.name,
                )
                continue

            kappa_pair = self.logAsymmKappa[i]
            if len(kappa_pair) != 2:
                # Tuple-unpacking would raise ValueError anyway; give it a
                # message that identifies the offending entry.
                msg = (
                    f"logAsymmKappa[{i}] in function '{self.name}' must be a "
                    f"[logKappaLo, logKappaHi] pair, got {len(kappa_pair)} value(s)"
                )
                raise ValueError(msg)

            log_kappa_lo, log_kappa_hi = kappa_pair
            kappa_sum = log_kappa_hi + log_kappa_lo
            kappa_diff = log_kappa_hi - log_kappa_lo
            asym_shift = asym_shift + _asym_interpolation(theta, kappa_sum, kappa_diff)

        # Apply exponential scaling: nominal * exp(sym_shift + asym_shift)
        result = result * pt.exp(sym_shift + asym_shift)

        # Multiply by additional factors
        for factor_name in self.otherFactorList:
            result = result * context[factor_name]

        return cast(TensorVar, result)
class CMSAsymPowFunction(Function):
    r"""
    CMS AsymPow function.

    Provides the asymmetric power-law variation used by CMS combine for
    asymmetric systematic uncertainties.

    .. math::

        f(\theta; \kappa_{low}, \kappa_{high}) = \begin{cases}
            \kappa_{low}^{-\theta}, & \text{if } \theta < 0 \\
            \kappa_{high}^{\theta}, & \text{if } \theta \geq 0
        \end{cases}

    Parameters:
        name: Name of the function
        kappaLow: Low-side variation factor (used for θ < 0)
        kappaHigh: High-side variation factor (used for θ ≥ 0)
        theta: Parameter name for the nuisance parameter
    """

    type: Literal["CMS::asympow"] = Field(default="CMS::asympow", repr=False)
    kappaLow: str | float | int = Field(..., repr=False)
    kappaHigh: str | float | int = Field(..., repr=False)
    theta: str = Field(..., repr=False)

    def _expression(self, context: Context) -> TensorVar:
        """
        Build the PyTensor expression for the AsymPow function.

        Args:
            context: Mapping of names to PyTensor variables.

        Returns:
            TensorVar: ``kappaLow ** (-theta)`` where ``theta < 0`` and
            ``kappaHigh ** theta`` otherwise.
        """
        kappa_low = context[self._parameters["kappaLow"]]
        kappa_high = context[self._parameters["kappaHigh"]]
        theta = context[self._parameters["theta"]]

        # Build both branches up front; the switch selects between them.
        below_zero = cast(TensorVar, pt.power(kappa_low, -theta))  # type: ignore[no-untyped-call]
        at_or_above_zero = cast(TensorVar, pt.power(kappa_high, theta))  # type: ignore[no-untyped-call]

        return cast(TensorVar, pt.switch(theta < 0, below_zero, at_or_above_zero))
class HistogramFunction(Function):
    r"""
    Histogram function schema.

    Describes a histogram-based function that provides piecewise constant
    values based on bin lookup. Used for non-parametric functions and
    data-driven backgrounds.

    .. math::

        f(x) = h_i \quad \text{where } x \in \text{bin}_i

    Note:
        This class only declares its fields; it defines no ``_expression``
        method of its own and is deliberately left out of this module's
        ``functions`` registry.

    Parameters:
        name: Name of the function
        data: histogram data with binning and contents
    """

    type: Literal["histogram"] = Field(default="histogram", repr=False)
    data: HistogramData = Field(
        ..., json_schema_extra={"preprocess": False}, repr=False
    )
class RooRecursiveFractionFunction(Function):
    r"""
    ROOT ``RooRecursiveFraction`` function.

    Computes a recursive fraction, as used to build a set of fractions that
    automatically sum to one (e.g. ``RooAddPdf`` with
    ``recursiveFractions=True``).

    For a coefficient list :math:`(a_0, a_1, \dots, a_{n-1})` ROOT's
    ``RooRecursiveFraction::evaluate()`` returns

    .. math::

        f = a_0 \prod_{i=1}^{n-1} (1 - a_i)

    i.e. the leading coefficient scaled by the complement of every remaining
    one. With a single coefficient the product is empty and the result is
    :math:`a_0`.

    Example:
        :math:`(0.2, 0.5, 0.5) \to 0.2 \cdot (1-0.5) \cdot (1-0.5) = 0.05`.

    When ``recursive`` is false the simple normalization
    :math:`a_0 / \sum_j a_j` is used instead (the rare flat-fraction
    convention).

    Parameters:
        name: Name of the function
        coefficients: List of coefficient parameter names
        recursive: Whether to use recursive fraction calculation
    """

    type: Literal["roorecursivefraction_dist"] = Field(
        default="roorecursivefraction_dist", repr=False
    )
    coefficients: list[int | float | str] = Field(alias="list", repr=False)
    recursive: bool = Field(default=True, repr=False)

    def _expression(self, context: Context) -> TensorVar:
        """
        Build the PyTensor expression for the fraction.

        Args:
            context: Mapping of names to PyTensor variables.

        Returns:
            TensorVar: The recursive fraction, the flat normalized fraction
            when ``recursive`` is false, or the constant ``0.0`` when there
            are no coefficients.
        """
        if not self.coefficients:
            return cast(TensorVar, pt.constant(0.0))

        values = self.get_parameter_list(context, "coefficients")
        leading = values[0]

        if self.recursive:
            # ROOT RooRecursiveFraction::evaluate: a_0 * prod_{i>=1} (1 - a_i).
            fraction: TensorVar = leading
            for remaining in values[1:]:
                fraction = cast(TensorVar, fraction * (1.0 - remaining))
            return fraction

        # Flat convention: a_0 / sum(all coefficients).
        return cast(TensorVar, leading / sum(values))
# Registry mapping each function's ``type`` string to the class defined in
# this module that implements it.
# NOTE: HistogramFunction is intentionally NOT registered here because it has no
# _expression() implementation. Workspaces referencing "histogram" will get
# the normal clean unknown-type validation error from the discriminated union.
functions: dict[str, type[Function]] = {
    "sum": SumFunction,
    "product": ProductFunction,
    "generic_function": GenericFunction,
    "interpolation": InterpolationFunction,
    "CMS::process_normalization": ProcessNormalizationFunction,
    "CMS::asympow": CMSAsymPowFunction,
    "roorecursivefraction_dist": RooRecursiveFractionFunction,
}