"""
HS3 Functions implementation.
Provides classes for handling HS3 functions including product functions,
generic functions with mathematical expressions, and interpolation functions.
"""
from __future__ import annotations
import logging
from enum import IntEnum
from typing import Annotated, Literal, cast
import pytensor.tensor as pt
import sympy as sp
from pydantic import (
BaseModel,
ConfigDict,
Field,
PrivateAttr,
model_validator,
)
from pyhs3.context import Context
from pyhs3.data import Axis
from pyhs3.exceptions import custom_error_msg
from pyhs3.functions.core import Function
from pyhs3.generic_parse import analyze_sympy_expr, parse_expression, sympy_to_pytensor
from pyhs3.typing.aliases import TensorVar
log = logging.getLogger(__name__)
def _asym_interpolation(
theta: TensorVar, kappa_sum: float, kappa_diff: float
) -> TensorVar:
"""
Implement asymmetric interpolation for ProcessNormalization.
Based on the jaxfit implementation:
https://github.com/nsmith-/jaxfit/blob/8479cd73e733ba35462287753fab44c0c560037b/src/jaxfit/roofit/combine.py#L197
and CMS Combine logic.
Uses polynomial interpolation for |theta| < 1 and linear extrapolation beyond.
Args:
theta: The nuisance parameter value
kappa_sum: logKappaHi + logKappaLo
kappa_diff: logKappaHi - logKappaLo
Returns:
The interpolated shift value
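    Example (worked directly from the formula above): for ``theta = 0.5`` the
    polynomial factor is ``(3*0.5**4 - 10*0.5**2 + 15) / 8 = 1.5859375``, so the
    shift is ``0.5 * (kappa_diff*0.5 + kappa_sum*1.5859375)``.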
"""
# Polynomial interpolation for |theta| < 1
# Polynomial: (3*theta^4 - 10*theta^2 + 15) / 8
theta_sq = theta * theta
theta_quad = theta_sq * theta_sq
poly_result = (3.0 * theta_quad - 10.0 * theta_sq + 15.0) / 8.0
# Linear extrapolation for |theta| >= 1
linear_result = pt.abs(theta)
# Choose between polynomial and linear based on |theta|
abs_theta = pt.abs(theta)
smooth_function = cast(
TensorVar, pt.switch(abs_theta < 1.0, poly_result, linear_result)
)
# Final asymmetric interpolation formula
return cast(TensorVar, 0.5 * (kappa_diff * theta + kappa_sum * smooth_function))
class SumFunction(Function):
"""Sum function that adds summands together."""
type: Literal["sum"] = Field(default="sum", repr=False)
summands: list[str] = Field(..., repr=False)
def expression(self, context: Context) -> TensorVar:
"""
Evaluate the sum function.
Args:
context: Mapping of names to PyTensor variables.
Returns:
TensorVar: PyTensor expression representing the sum of all summands.
"""
if not self.summands:
return pt.constant(0.0)
result = context[self.summands[0]]
for summand in self.summands[1:]:
result = result + context[summand]
return result
class ProductFunction(Function):
"""Product function that multiplies factors together."""
type: Literal["product"] = Field(default="product", repr=False)
factors: list[int | float | str] = Field(..., repr=False)
def expression(self, context: Context) -> TensorVar:
"""
Evaluate the product function.
Args:
context: Mapping of names to PyTensor variables.
Returns:
TensorVar: PyTensor expression representing the product of all factors.
"""
if not self.factors:
return pt.constant(1.0)
# Get list of factors using flattened parameter keys
factor_values = self.get_parameter_list(context, "factors")
result = factor_values[0]
for factor_value in factor_values[1:]:
result = result * factor_value
return result
class GenericFunction(Function):
"""
Generic function with custom mathematical expression.
Evaluates arbitrary mathematical expressions using SymPy parsing
and PyTensor computation. Supports common mathematical operations
including arithmetic, trigonometric, exponential, and logarithmic functions.
The expression is parsed once during initialization and converted to
a PyTensor computation graph for efficient evaluation.
Parameters:
name (str): Name of the function.
expression (str): Mathematical expression string to evaluate.
Examples:
>>> func = GenericFunction(name="quadratic", expression="x**2 + 2*x + 1")
>>> func = GenericFunction(name="sinusoid", expression="sin(x) * exp(-t)")
"""
model_config = ConfigDict(arbitrary_types_allowed=True, serialize_by_alias=True)
type: Literal["generic_function"] = Field(default="generic_function", repr=False)
expression_str: str = Field(alias="expression", repr=False)
_sympy_expr: sp.Expr = PrivateAttr(default=None)
_dependent_vars: list[str] = PrivateAttr(default_factory=list)
@model_validator(mode="after")
def setup_expression(self) -> GenericFunction:
"""Parse and analyze the expression during initialization."""
# Parse and analyze the expression during initialization
self._sympy_expr = parse_expression(self.expression_str)
# Analyze the expression to determine dependencies
analysis = analyze_sympy_expr(self._sympy_expr)
independent_vars = [str(symbol) for symbol in analysis["independent_vars"]]
self._dependent_vars = [str(symbol) for symbol in analysis["dependent_vars"]]
# Set parameters based on the analyzed expression
self._parameters = {var: var for var in independent_vars}
return self
def expression(self, context: Context) -> TensorVar:
"""
Evaluate the generic function expression.
Args:
context: Mapping of names to PyTensor variables.
Returns:
TensorVar: PyTensor expression representing the parsed mathematical expression.
"""
# Get required variables using the parameters determined during initialization
variables = [context[name] for name in self._parameters.values()]
# Convert using the pre-parsed sympy expression
return sympy_to_pytensor(self._sympy_expr, variables)
class InterpolationCode(IntEnum):
"""
Enumeration of interpolation codes for systematic variations.
Defines the different interpolation methods used by InterpolationFunction
for systematic uncertainty variations. Each code represents a different
mathematical approach to interpolating between nominal, low, and high values.
"""
    LIN_LIN_ADD = 0  # linear interpolation, linear extrapolation (additive)
    EXP_EXP_MUL = 1  # exponential interpolation and extrapolation (multiplicative)
    EXP_LIN_ADD = 2  # exponential interpolation, linear extrapolation (additive)
    EXP_MIX_ADD = 3  # exponential interpolation, mixed/linear extrapolation (additive)
    POL_LIN_ADD = 4  # polynomial interpolation, linear extrapolation (additive)
    POL_EXP_MUL = 5  # polynomial interpolation, exponential extrapolation (multiplicative)
    POL_LIN_MUL = 6  # polynomial interpolation, linear extrapolation (multiplicative)
class InterpolationFunction(Function):
r"""
Piecewise interpolation function implementation.
Implements ROOT's PiecewiseInterpolation logic to morph between nominal
and variation distributions based on nuisance parameter values.
Supports multiple interpolation codes (0-6) for different mathematical approaches.
Mathematical Formulations:
For **additive** interpolation modes (codes 0, 2, 3, 4):
.. math::
\text{result} = \text{nominal} + \sum_i I_i(\theta_i; \text{low}_i, \text{nominal}, \text{high}_i)
For **multiplicative** interpolation modes (codes 1, 5, 6):
.. math::
\text{result} = \text{nominal} \times \prod_i [1 + I_i(\theta_i; \text{low}_i/\text{nominal}, 1, \text{high}_i/\text{nominal})]
Parameters:
name: Name of the function
high: High variation parameter names
low: Low variation parameter names
nom: Nominal parameter name
interpolationCodes: Interpolation method codes (0-6)
positiveDefinite: Whether function should be positive definite
vars: Variable names this function depends on (nuisance parameters)
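    Example (illustrative sketch; the parameter names are hypothetical and must
    exist in the evaluation context):
        >>> func = InterpolationFunction(
        ...     name="shape_interp",
        ...     high=["shape_up"],
        ...     low=["shape_down"],
        ...     nom="shape_nominal",
        ...     interpolationCodes=[4],
        ...     positiveDefinite=True,
        ...     vars=["alpha_syst"],
        ... )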
"""
model_config = ConfigDict(use_enum_values=True)
type: Literal["interpolation"] = Field(default="interpolation", repr=False)
high: list[str] = Field(..., repr=False)
low: list[str] = Field(..., repr=False)
nom: str = Field(..., repr=False)
interpolationCodes: Annotated[
list[InterpolationCode],
custom_error_msg(
{
"enum": "Unknown interpolation code {input} in function '{name}'. Valid codes are {expected}."
}
),
] = Field(..., repr=False)
positiveDefinite: bool = Field(..., repr=False)
vars: list[str] = Field(..., repr=False)
def _flexible_interp_single(
self,
interp_code: int,
low_val: TensorVar,
high_val: TensorVar,
boundary: float,
nominal: TensorVar,
param_val: TensorVar,
) -> TensorVar:
r"""
Implement flexible interpolation for a single parameter.
Based on ROOT's flexibleInterpSingle method with support for
interpolation codes 0-6. This method computes the interpolation
contribution :math:`I_i(\theta_i)` for a single nuisance parameter.
Args:
interp_code: Interpolation code (0-6) determining the mathematical approach
low_val: Low variation value (used when :math:`\theta < 0`)
high_val: High variation value (used when :math:`\theta \geq 0`)
boundary: Boundary value for switching between interpolation and extrapolation (typically 1.0)
nominal: Nominal value (baseline)
param_val: Parameter value :math:`\theta` (nuisance parameter)
Returns:
Interpolated contribution :math:`I_i(\theta_i)` to be added (additive modes)
or multiplied (multiplicative modes) with the result
Note:
The returned value interpretation depends on the interpolation code:
- Codes 0,2,3,4: Direct additive contribution
            - Codes 1,5,6: Multiplicative factor minus one (the caller multiplies the running result by ``1 + contribution``)
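        Example (worked from the code-0 branch below): with ``interp_code=0`` and
        :math:`\theta = 0.5`, the contribution is ``0.5 * (high_val - nominal)``.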
"""
# Codes 0, 2, 3, 4 are additive modes
# Codes 1, 5, 6 are multiplicative modes
if interp_code == 0:
# Linear interpolation/extrapolation (additive)
return cast(
TensorVar,
pt.switch(
param_val >= 0,
param_val * (high_val - nominal),
param_val * (nominal - low_val),
),
)
if interp_code == 1:
# Exponential interpolation/extrapolation (multiplicative)
ratio_high = high_val / nominal
ratio_low = low_val / nominal
return cast(
TensorVar,
pt.switch(
param_val >= 0,
cast(TensorVar, pt.power(ratio_high, param_val)) - 1.0, # type: ignore[no-untyped-call]
cast(TensorVar, pt.power(ratio_low, -param_val)) - 1.0, # type: ignore[no-untyped-call]
),
)
if interp_code == 2:
# Exponential interpolation, linear extrapolation (additive)
return cast(
TensorVar,
pt.switch(
pt.abs(param_val) <= boundary,
# Exponential interpolation for |theta| <= 1
pt.switch(
param_val >= 0,
(high_val - nominal) * (pt.exp(param_val) - 1),
(nominal - low_val) * (pt.exp(-param_val) - 1),
),
# Linear extrapolation for |theta| > 1
pt.switch(
param_val >= 0,
(high_val - nominal)
* (
pt.exp(boundary)
- 1
+ (param_val - boundary) * pt.exp(boundary)
),
(nominal - low_val)
* (
pt.exp(boundary)
- 1
+ (-param_val - boundary) * pt.exp(boundary)
),
),
),
)
if interp_code == 3:
            # Code 3: exponential interpolation for |theta| <= boundary, plain linear
            # extrapolation beyond (unlike code 2, which slope-matches the exponential at the boundary)
return cast(
TensorVar,
pt.switch(
pt.abs(param_val) <= boundary,
# Exponential interpolation for |theta| <= 1
pt.switch(
param_val >= 0,
(high_val - nominal) * (pt.exp(param_val) - 1),
(nominal - low_val) * (pt.exp(-param_val) - 1),
),
# Linear extrapolation for |theta| > 1
pt.switch(
param_val >= 0,
param_val * (high_val - nominal),
param_val * (nominal - low_val),
),
),
)
if interp_code == 4:
# Polynomial interpolation + linear extrapolation (additive)
return cast(
TensorVar,
pt.switch(
pt.abs(param_val) >= boundary,
# Linear extrapolation for |theta| >= 1
pt.switch(
param_val >= 0,
param_val * (high_val - nominal),
param_val * (nominal - low_val),
),
# 6th order polynomial interpolation for |theta| < 1
pt.switch(
param_val >= 0,
param_val
* (high_val - nominal)
* (
1
+ param_val * param_val * (-3 + param_val * param_val) / 16
),
param_val
* (nominal - low_val)
* (
1
+ param_val * param_val * (-3 + param_val * param_val) / 16
),
),
),
)
if interp_code == 5:
# Polynomial interpolation + exponential extrapolation (multiplicative)
ratio_high = high_val / nominal
ratio_low = low_val / nominal
return cast(
TensorVar,
pt.switch(
pt.abs(param_val) >= boundary,
# Exponential extrapolation for |theta| >= 1
pt.switch(
param_val >= 0,
cast(TensorVar, pt.power(ratio_high, param_val)) - 1.0, # type: ignore[no-untyped-call]
cast(TensorVar, pt.power(ratio_low, -param_val)) - 1.0, # type: ignore[no-untyped-call]
),
# 6th order polynomial interpolation for |theta| < 1
pt.switch(
param_val >= 0,
param_val
* (ratio_high - 1.0)
* (
1
+ param_val * param_val * (-3 + param_val * param_val) / 16
),
param_val
* (ratio_low - 1.0)
* (
1
+ param_val * param_val * (-3 + param_val * param_val) / 16
),
),
),
)
# Code 6: Polynomial interpolation + linear extrapolation (multiplicative)
ratio_high = high_val / nominal
ratio_low = low_val / nominal
return cast(
TensorVar,
pt.switch(
pt.abs(param_val) >= boundary,
# Linear extrapolation for |theta| >= 1
pt.switch(
param_val >= 0,
param_val * (ratio_high - 1.0),
param_val * (ratio_low - 1.0),
),
# 6th order polynomial interpolation for |theta| < 1
pt.switch(
param_val >= 0,
param_val
* (ratio_high - 1.0)
* (1 + param_val * param_val * (-3 + param_val * param_val) / 16),
param_val
* (ratio_low - 1.0)
* (1 + param_val * param_val * (-3 + param_val * param_val) / 16),
),
),
)
def expression(self, context: Context) -> TensorVar:
r"""
Evaluate the interpolation function.
Implements ROOT's PiecewiseInterpolation algorithm following the mathematical
formulations described in the class docstring. The algorithm proceeds as:
1. Start with nominal value: :math:`\text{result} = \text{nominal}`
2. For each nuisance parameter :math:`\theta_i`, compute interpolation contribution :math:`I_i(\theta_i)`
3. Combine contributions based on interpolation mode:
- **Additive modes** (codes 0,2,3,4): :math:`\text{result} += I_i(\theta_i)`
- **Multiplicative modes** (codes 1,5,6): :math:`\text{result} \times= (1 + I_i(\theta_i))`
4. Apply positive definite constraint: :math:`\text{result} = \max(\text{result}, 0)` if requested
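        For example, with a single nuisance parameter using code 0 and :math:`\theta = 0.5`,
        the result is :math:`\text{nominal} + 0.5 \cdot (\text{high} - \text{nominal})`.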
Args:
context: Mapping of names to pytensor variables containing:
- Nominal parameter (referenced by `nom`)
- High/low variation parameters (referenced by `high`/`low` lists)
- Nuisance parameters (referenced by `vars` list)
Returns:
PyTensor expression representing the interpolated result
Note:
The evaluation order ensures that all interpolation contributions are properly
combined according to their mathematical modes before applying constraints.
"""
# Start with nominal value
nominal = context[self.nom]
result = nominal
# Apply interpolation for each nuisance parameter
for i, var_name in enumerate(self.vars):
if (
i >= len(self.high)
or i >= len(self.low)
or i >= len(self.interpolationCodes)
):
log.warning(
"Parameter index %d exceeds variation lists for function %s",
i,
self.name,
)
continue
param_val = context[var_name]
low_val = context[self.low[i]]
high_val = context[self.high[i]]
interp_code = self.interpolationCodes[i]
# Calculate interpolated contribution
contribution = self._flexible_interp_single(
interp_code=interp_code,
low_val=low_val,
high_val=high_val,
boundary=1.0,
nominal=nominal,
param_val=param_val,
)
# Add contribution based on interpolation mode
if interp_code in [0, 2, 3, 4]: # Additive modes
result = result + contribution
else: # Multiplicative modes (1, 5, 6)
result = result * (1.0 + contribution)
# Apply positive definite constraint if requested
if self.positiveDefinite:
result = pt.maximum(result, 0.0)
return result
class ProcessNormalizationFunction(Function):
r"""
Process normalization function with systematic variations.
Implements the CMS Combine ProcessNormalization class which computes
a normalization factor based on systematic variations. This matches
the actual CMS Combine implementation and JSON structure from combine files.
Mathematical formulation:
result = nominalValue * exp(symShift + asymShift) * otherFactors
where:
- symShift = sum(logKappa[i] * theta[i]) for symmetric variations
- asymShift = sum(_asym_interpolation(theta[i], kappa_sum[i], kappa_diff[i]))
for asymmetric variations with kappa_sum = logKappaHi + logKappaLo
and kappa_diff = logKappaHi - logKappaLo
- otherFactors = product of all additional multiplicative terms
Parameters:
name: Name of the function
nominalValue: Baseline normalization value (default 1.0)
thetaList: Names of symmetric variation nuisance parameters
logKappa: Log-kappa values for symmetric variations (optional, defaults to empty)
asymmThetaList: Names of asymmetric variation nuisance parameters
logAsymmKappa: List of [logKappaLo, logKappaHi] pairs for asymmetric variations (optional)
otherFactorList: Names of additional multiplicative factors
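    Example (illustrative sketch; the parameter names are hypothetical):
        >>> func = ProcessNormalizationFunction(
        ...     name="signal_norm",
        ...     nominalValue=1.0,
        ...     thetaList=["lumi"],
        ...     logKappa=[0.025],
        ...     asymmThetaList=["jes"],
        ...     logAsymmKappa=[[-0.02, 0.03]],
        ...     otherFactorList=["mu_signal"],
        ... )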
"""
type: Literal["CMS::process_normalization"] = Field(
default="CMS::process_normalization", repr=False
)
nominalValue: float = Field(
default=1.0, json_schema_extra={"preprocess": False}, repr=False
)
thetaList: list[str] = Field(default_factory=list, repr=False)
logKappa: list[float] = Field(
default_factory=list, json_schema_extra={"preprocess": False}, repr=False
)
asymmThetaList: list[str] = Field(default_factory=list, repr=False)
logAsymmKappa: list[list[float]] = Field(
default_factory=list, json_schema_extra={"preprocess": False}, repr=False
)
otherFactorList: list[str] = Field(default_factory=list, repr=False)
def expression(self, context: Context) -> TensorVar:
"""
Evaluate the process normalization function.
Implements the full CMS Combine ProcessNormalization logic:
result = nominalValue * exp(symShift + asymShift) * otherFactors
Args:
context: Mapping of names to PyTensor variables.
Returns:
TensorVar: PyTensor expression representing the normalization factor.
"""
# Start with nominal value
result = pt.constant(self.nominalValue)
# Symmetric variations: symShift = sum(logKappa[i] * theta[i])
symShift = pt.constant(0.0)
for i, theta_name in enumerate(self.thetaList):
theta = context[theta_name]
# Use provided logKappa value if available, otherwise assume 0.0 (no effect)
log_kappa = self.logKappa[i] if i < len(self.logKappa) else 0.0
symShift = symShift + log_kappa * theta
# Asymmetric variations: use asymmetric interpolation
asymShift = pt.constant(0.0)
for i, theta_name in enumerate(self.asymmThetaList):
theta = context[theta_name]
log_kappa_lo, log_kappa_hi = self.logAsymmKappa[i]
kappa_sum = log_kappa_hi + log_kappa_lo
kappa_diff = log_kappa_hi - log_kappa_lo
asymShift = asymShift + _asym_interpolation(theta, kappa_sum, kappa_diff)
# Apply exponential scaling: nominal * exp(symShift + asymShift)
result = result * pt.exp(symShift + asymShift)
# Multiply by additional factors
for factor_name in self.otherFactorList:
factor = context[factor_name]
result = result * factor
return cast(TensorVar, result)
class CMSAsymPowFunction(Function):
r"""
CMS AsymPow function implementation.
Implements CMS's AsymPow function which provides asymmetric power-law
variations for systematic uncertainties. Used in CMS combine for
asymmetric systematic variations.
.. math::
f(\theta; \kappa_{low}, \kappa_{high}) = \begin{cases}
\kappa_{low}^{-\theta}, & \text{if } \theta < 0 \\
\kappa_{high}^{\theta}, & \text{if } \theta \geq 0
\end{cases}
Parameters:
name: Name of the function
kappaLow: Low-side variation factor (used for θ < 0)
kappaHigh: High-side variation factor (used for θ ≥ 0)
theta: Parameter name for the nuisance parameter
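    Example (illustrative sketch; the parameter name is hypothetical):
        >>> func = CMSAsymPowFunction(
        ...     name="asympow_jes", kappaLow=0.9, kappaHigh=1.1, theta="theta_jes"
        ... )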
"""
type: Literal["CMS::asympow"] = Field(default="CMS::asympow", repr=False)
kappaLow: str | float | int = Field(..., repr=False)
kappaHigh: str | float | int = Field(..., repr=False)
theta: str = Field(..., repr=False)
def expression(self, context: Context) -> TensorVar:
"""
Evaluate the AsymPow function.
Args:
context: Mapping of names to PyTensor variables.
Returns:
TensorVar: PyTensor expression representing the asymmetric power function.
"""
kappa_low = context[self._parameters["kappaLow"]]
kappa_high = context[self._parameters["kappaHigh"]]
theta = context[self._parameters["theta"]]
# AsymPow: kappaLow^(-theta) for theta < 0, kappaHigh^theta for theta >= 0
return cast(
TensorVar,
pt.switch(
theta < 0,
cast(TensorVar, pt.power(kappa_low, -theta)), # type: ignore[no-untyped-call]
cast(TensorVar, pt.power(kappa_high, theta)), # type: ignore[no-untyped-call]
),
)
class HistogramData(BaseModel):
"""
Histogram data implementation for the HistogramFunction.
Parameters:
axes: list of Axis used to describe the binning
        contents: list of bin content values
"""
model_config = ConfigDict()
axes: list[Axis] = Field(..., repr=False)
contents: list[float] = Field(..., repr=False)
class HistogramFunction(Function):
r"""
Histogram function implementation.
Implements a histogram-based function that provides piecewise constant
values based on bin lookup. Used for non-parametric functions and
data-driven backgrounds.
.. math::
f(x) = h_i \quad \text{where } x \in \text{bin}_i
Parameters:
name: Name of the function
data: histogram data with binning and contents
"""
type: Literal["histogram"] = Field(default="histogram", repr=False)
data: HistogramData = Field(
..., json_schema_extra={"preprocess": False}, repr=False
)
class RooRecursiveFractionFunction(Function):
r"""
ROOT RooRecursiveFraction function implementation.
Implements ROOT's RooRecursiveFraction which computes fractions recursively.
Used for constrained fraction calculations where fractions must sum to 1.
.. math::
f_i = \frac{a_i}{\sum_{j=i}^n a_j}
where the recursive fractions ensure proper normalization.
Parameters:
name: Name of the function
coefficients: List of coefficient parameter names
recursive: Whether to use recursive fraction calculation
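    Example (illustrative sketch; note the coefficient list is supplied via its
    JSON alias ``list``):
        >>> func = RooRecursiveFractionFunction(name="frac0", list=["a0", "a1", "a2"])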
"""
type: Literal["roorecursivefraction_dist"] = Field(
default="roorecursivefraction_dist", repr=False
)
coefficients: list[int | float | str] = Field(alias="list", repr=False)
recursive: bool = Field(default=True, repr=False)
def expression(self, context: Context) -> TensorVar:
"""
Evaluate the recursive fraction function.
Args:
context: Mapping of names to PyTensor variables.
Returns:
TensorVar: PyTensor expression representing the recursive fraction.
"""
if not self.coefficients:
return cast(TensorVar, pt.constant(0.0))
coeffs = self.get_parameter_list(context, "coefficients")
if not self.recursive:
# Simple normalization
total = sum(coeffs)
return cast(TensorVar, coeffs[0] / total)
        # Recursive fraction calculation
        # For the first coefficient: a_0 / (a_0 + a_1 + ... + a_n)
        # For the i-th coefficient: a_i / (a_i + a_{i+1} + ... + a_n) * (1 - sum of previous fractions)
        # Only the leading fraction a_0 / sum(all coefficients) is computed below.
if len(coeffs) == 1:
return cast(TensorVar, pt.constant(1.0))
# Calculate the first recursive fraction: a_0 / sum(all)
total_sum = sum(coeffs)
first_fraction = coeffs[0] / total_sum
return cast(TensorVar, first_fraction)
# Registry for functions defined in this module
functions: dict[str, type[Function]] = {
"sum": SumFunction,
"product": ProductFunction,
"generic_function": GenericFunction,
"interpolation": InterpolationFunction,
"CMS::process_normalization": ProcessNormalizationFunction,
"CMS::asympow": CMSAsymPowFunction,
"histogram": HistogramFunction,
"roorecursivefraction_dist": RooRecursiveFractionFunction,
}