Source code for pyhs3.functions
"""
HS3 Functions implementation.
Provides classes for handling HS3 functions including product functions,
generic functions with mathematical expressions, and interpolation functions.
"""
from __future__ import annotations
import logging
from collections.abc import Iterator
from typing import Annotated, Any
from pydantic import (
Field,
PrivateAttr,
RootModel,
)
from pyhs3.exceptions import custom_error_msg
from pyhs3.functions import standard
from pyhs3.functions.core import Function
# Module-level logger, named after this module's import path.
log = logging.getLogger(__name__)
# Bind the function classes defined in ``standard`` to names in this module's
# namespace, so they can be referenced directly from here.
SumFunction = standard.SumFunction
ProductFunction = standard.ProductFunction
GenericFunction = standard.GenericFunction
InterpolationFunction = standard.InterpolationFunction
ProcessNormalizationFunction = standard.ProcessNormalizationFunction
CMSAsymPowFunction = standard.CMSAsymPowFunction
HistogramFunction = standard.HistogramFunction
RooRecursiveFractionFunction = standard.RooRecursiveFractionFunction
# Combined registry of function classes. It is built by unpacking
# ``standard.functions``, which is currently the only registry merged in;
# further registries would be unpacked into the same literal.
# NOTE(review): keys are presumably the HS3 ``type`` tags -- confirm against
# ``standard.functions``.
registered_functions: dict[str, type[Function]] = {
**standard.functions,
}
# Type alias covering every function class aliased above, declared as a
# pydantic discriminated union: the ``type`` field is the discriminator that
# selects which member class an input is validated against.
FunctionType = Annotated[
SumFunction
| ProductFunction
| GenericFunction
| InterpolationFunction
| ProcessNormalizationFunction
| CMSAsymPowFunction
| HistogramFunction
| RooRecursiveFractionFunction,
Field(discriminator="type"),
]
[docs]
class Functions(RootModel[list[FunctionType]]):
    """
    Ordered, name-indexed container of HS3 function definitions.

    The root value of the model is a list of functions, each validated
    against the ``FunctionType`` discriminated union. Once validation has
    finished, a private name-to-function table is built from that list so
    that individual functions can be looked up by name.

    Supported operations:
        ``functions[name]``: fetch a function by name.
        ``name in functions``: test whether a name is present.
        ``iter(functions)``: iterate over the functions in list order.
        ``len(functions)``: number of functions in the list.

    Attributes:
        root: The validated list of functions; an empty list by default.
    """

    # An unknown ``type`` tag is reported with a customised message instead
    # of the stock ``union_tag_invalid`` text.
    root: Annotated[
        list[FunctionType],
        custom_error_msg(
            {
                "union_tag_invalid": "Unknown function type '{tag}' does not match any of the expected functions: {expected_tags}"
            }
        ),
    ] = Field(default_factory=list)
    # Name -> function lookup table; populated in ``model_post_init``.
    _map: dict[str, Function] = PrivateAttr(default_factory=dict)

    def model_post_init(self, __context: Any, /) -> None:
        """Build the name lookup table from the validated function list.

        If several functions share a name, the one appearing last in the
        list is the one kept in the table.
        """
        by_name: dict[str, Function] = {}
        for function in self.root:
            by_name[function.name] = function
        self._map = by_name

    def __getitem__(self, item: str) -> Function:
        """Return the function called ``item``; raise ``KeyError`` if absent."""
        function = self._map[item]
        return function

    def __contains__(self, item: str) -> bool:
        """Return whether a function called ``item`` is in the collection."""
        known = self._map
        return item in known

    def __iter__(self) -> Iterator[Function]:  # type: ignore[override] # https://github.com/pydantic/pydantic/issues/8872
        """Iterate over the functions in the order they appear in the list."""
        functions = self.root
        return iter(functions)

    def __len__(self) -> int:
        """Return how many functions the list contains."""
        functions = self.root
        return len(functions)