Defining Custom Functions and Distributions
This guide shows how to create custom Functions and Distributions for use with pyhs3.
Both Functions and Distributions inherit from the Evaluable base class, which provides automatic parameter preprocessing to eliminate boilerplate validation code.
Quick Start
Creating a custom distribution or function takes four steps:
1. Inherit from Distribution or Function
2. Define fields with appropriate type annotations
3. Implement the expression() method
4. Register your component (optional)
The Evaluable base class automatically handles parameter processing based on your field type annotations.
Basic Distribution Example
Here’s a simple custom Gaussian distribution:
import math
from typing import Literal

import pytensor.tensor as pt
from pyhs3.distributions.core import Distribution
from pyhs3.context import Context
from pyhs3.typing.aliases import TensorVar


class CustomGaussianDist(Distribution):
    """Custom Gaussian distribution implementation."""

    type: Literal["custom_gaussian"] = "custom_gaussian"
    mean: str | float  # Parameter name or numeric value
    sigma: str | float  # Parameter name or numeric value

    def expression(self, context: Context) -> TensorVar:
        """Evaluate the Gaussian PDF."""
        # Get processed parameters from context
        mean_val = context[self._parameters["mean"]]
        sigma_val = context[self._parameters["sigma"]]
        # Assume 'x' is the observable variable
        x = context["x"]  # Would come from domain/data definition
        # Gaussian PDF formula (math.pi is a plain Python float)
        norm = 1.0 / (sigma_val * pt.sqrt(2 * math.pi))
        exp_term = pt.exp(-0.5 * ((x - mean_val) / sigma_val) ** 2)
        return norm * exp_term
What happens automatically:
- If mean="mu_param", then self._parameters["mean"] == "mu_param"
- If mean=1.5, then self._parameters["mean"] == "constant_myname_mean" and a constant is created
- The parameters property returns all parameter names this distribution depends on
- The constants property provides PyTensor constants for numeric values
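The naming rule described above can be mimicked in a few lines of plain Python. This is only a toy illustration of the convention, not pyhs3's actual implementation: the helper map_scalar_field is hypothetical, and it assumes that "myname" in the generated constant name stands for the component's name.

```python
def map_scalar_field(component_name, field_name, value):
    """Toy model of the scalar mapping rule (hypothetical helper, not pyhs3 API).

    Returns (parameter_name, constants), where constants maps any generated
    constant name to its numeric value.
    """
    if isinstance(value, str):
        # A string is treated as a reference to an existing parameter.
        return value, {}
    # A number gets a generated name that encodes the component and the field
    # (assumed pattern: constant_<component>_<field>).
    const_name = f"constant_{component_name}_{field_name}"
    return const_name, {const_name: float(value)}


name_for_string, consts_for_string = map_scalar_field("myname", "mean", "mu_param")
name_for_number, consts_for_number = map_scalar_field("myname", "mean", 1.5)

print(name_for_string, consts_for_string)
print(name_for_number, consts_for_number)
```

In both cases expression() can look the value up through the same key, which is why the example distribution never needs to check whether a field held a string or a number.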
Basic Function Example
Here’s a custom product function:
from typing import Literal

import pytensor.tensor as pt
from pyhs3.functions.core import Function
from pyhs3.context import Context
from pyhs3.typing.aliases import TensorVar


class WeightedProductFunction(Function):
    """Product function with weights."""

    type: Literal["weighted_product"] = "weighted_product"
    factors: list[str | float]  # Mix of parameter names and values
    weights: list[str | float]  # Corresponding weights

    def expression(self, context: Context) -> TensorVar:
        """Evaluate weighted product: prod(factor[i] ** weight[i])."""
        # Get parameter lists in original order
        factor_vals = self.get_parameter_list(context, "factors")
        weight_vals = self.get_parameter_list(context, "weights")
        result = pt.constant(1.0)
        for factor, weight in zip(factor_vals, weight_vals, strict=True):
            result = result * (factor**weight)
        return result
What happens automatically:
- factors=["param1", 2.0, "param2"] creates indexed parameters: factors[0], factors[1], factors[2]
- get_parameter_list(context, "factors") reconstructs the original list from context
- Constants are generated for numeric values: constant_myname_factors[1] for 2.0
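The list handling can be pictured with a similar toy model. The sketch below is an assumption-laden stand-in, not pyhs3 code: map_list_field and rebuild_list are hypothetical helpers, the context is a plain dict, and the internal key layout (factors[0] mapping to a context key) is a guess based on the names shown above.

```python
def map_list_field(component_name, field_name, values):
    """Toy model of list processing (hypothetical helper, not pyhs3 API)."""
    mapping = {}    # indexed key such as "factors[1]" -> key to look up in the context
    constants = {}  # generated constant name -> numeric value
    for i, value in enumerate(values):
        key = f"{field_name}[{i}]"
        if isinstance(value, str):
            mapping[key] = value
        else:
            const_name = f"constant_{component_name}_{field_name}[{i}]"
            mapping[key] = const_name
            constants[const_name] = float(value)
    return mapping, constants


def rebuild_list(context, mapping, field_name):
    """Look up every indexed entry in its original order."""
    prefix = f"{field_name}["
    count = sum(1 for key in mapping if key.startswith(prefix))
    return [context[mapping[f"{field_name}[{i}]"]] for i in range(count)]


mapping, constants = map_list_field("myname", "factors", ["param1", 2.0, "param2"])
# Made-up parameter values plus the generated constants form the context.
context = {"param1": 3.0, "param2": 5.0, **constants}
factor_values = rebuild_list(context, mapping, "factors")

print(mapping)
print(factor_values)
```

The point of the indexed names is that order survives the round trip, so zipping factors with weights pairs up the entries the user wrote side by side.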
Automatic Parameter Processing
The Evaluable base class automatically processes field annotations:
Supported Field Types:
from pydantic import Field


class MyComponent(Evaluable):
    # String fields -> direct parameter mapping
    param_name: str  # -> self._parameters["param_name"] = field_value

    # Numeric fields -> generate constants
    numeric_val: float  # -> self._parameters["numeric_val"] = "constant_name_numeric_val"

    # Union types -> runtime detection
    mixed_param: str | float  # -> string or constant depending on value
    flexible: str | int | float  # -> handles any combination

    # Lists -> indexed processing
    string_list: list[str]  # -> string_list[0], string_list[1], ...
    mixed_list: list[str | float]  # -> mix of strings and generated constants

    # Excluded fields
    config_flag: bool  # -> automatically excluded
    internal_val: float = Field(  # -> explicitly excluded
        default=1.0, json_schema_extra={"preprocess": False}
    )
Exclusion Rules:
- Boolean fields are automatically excluded (not parameters)
- Fields marked with json_schema_extra={"preprocess": False} are excluded
- Base class fields (name, type) are excluded
- None values are skipped
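The four exclusion rules can be read as a simple filter. The following sketch applies them to a plain dict of field values. It is a simplified illustration rather than pyhs3's logic: fields_to_process is a hypothetical helper, the opted_out set stands in for fields marked with preprocess=False, booleans are detected from the value instead of the annotation, and the offset field is invented for the example.

```python
BASE_FIELDS = {"name", "type"}


def fields_to_process(values, opted_out=()):
    """Return the field names that would be treated as parameters (toy model)."""
    selected = []
    for field, value in values.items():
        if field in BASE_FIELDS:
            continue  # base class fields are never parameters
        if field in opted_out:
            continue  # stands in for json_schema_extra={"preprocess": False}
        if value is None:
            continue  # None values are skipped
        if isinstance(value, bool):
            continue  # booleans are configuration, not parameters
        selected.append(field)
    return selected


component_values = {
    "name": "flex",
    "type": "flexible",
    "location": "mu",
    "scale": 0.5,
    "use_log_scale": False,
    "tolerance": 1e-6,
    "offset": None,  # hypothetical optional field left unset
}
selected = fields_to_process(component_values, opted_out={"tolerance"})
print(selected)
```

Only location and scale survive the filter here, which matches the split between core parameters and configuration that the rules describe.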
Advanced Examples
Complex Distribution with Mixed Parameters:
from pydantic import Field


class FlexibleDist(Distribution):
    type: Literal["flexible"] = "flexible"

    # Core parameters (will be processed)
    location: str | float
    scale: str | float
    coefficients: list[str | float]

    # Configuration (excluded from processing)
    use_log_scale: bool = False
    tolerance: float = Field(default=1e-6, json_schema_extra={"preprocess": False})

    def expression(self, context: Context) -> TensorVar:
        loc = context[self._parameters["location"]]
        scale = context[self._parameters["scale"]]
        # Get coefficient list
        coeffs = self.get_parameter_list(context, "coefficients")
        # Use configuration values directly
        if self.use_log_scale:
            scale = pt.exp(scale)
        # ... implementation
        return result
Function with Validation:
from pydantic import Field, model_validator


class ValidatedFunction(Function):
    type: Literal["validated"] = "validated"
    inputs: list[str]
    weights: list[float] = Field(json_schema_extra={"preprocess": False})

    @model_validator(mode="after")
    def validate_lengths(self) -> "ValidatedFunction":
        """Custom validation after auto-processing."""
        if len(self.inputs) != len(self.weights):
            raise ValueError("inputs and weights must have same length")
        return self

    def expression(self, context: Context) -> TensorVar:
        # inputs were auto-processed into indexed parameters
        input_vals = self.get_parameter_list(context, "inputs")
        # weights were excluded from processing, so use them directly
        result = pt.constant(0.0)
        for inp, weight in zip(input_vals, self.weights, strict=True):
            result = result + inp * weight
        return result
Registration and Discovery
Option 1: Manual Registration
Add your components to the appropriate registry:
# For distributions
from pyhs3.distributions.core import registered_distributions
registered_distributions["custom_gaussian"] = CustomGaussianDist
# For functions
from pyhs3.functions.core import registered_functions
registered_functions["weighted_product"] = WeightedProductFunction
Option 2: Plugin System (if available)
Check if pyhs3 supports a plugin entry point system for automatic discovery.
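Manual registration follows a common registry pattern: a dictionary maps the type string to a class, and a loader uses the type entry of each definition to pick the class to instantiate. The sketch below shows the general idea only. The names registry, build_component and ToyGaussian are hypothetical, and pyhs3's own loading code may work differently.

```python
from dataclasses import dataclass


@dataclass
class ToyGaussian:
    """Stand-in for a custom component class (not a pyhs3 Distribution)."""

    name: str
    mean: object   # parameter name or number
    sigma: object  # parameter name or number


# Type string -> class, analogous to adding an entry to a registry dict.
registry = {"custom_gaussian": ToyGaussian}


def build_component(config):
    """Instantiate the registered class named by config["type"] (hypothetical helper)."""
    config = dict(config)  # copy so the caller's dict is not modified
    type_name = config.pop("type")
    try:
        cls = registry[type_name]
    except KeyError:
        raise ValueError(f"unknown component type: {type_name!r}") from None
    return cls(**config)


component = build_component(
    {"name": "signal_pdf", "type": "custom_gaussian", "mean": "mu_signal", "sigma": 0.1}
)
print(component)
```

A consequence of this pattern is that registration has to happen before a workspace referring to the new type is loaded, otherwise the lookup has nothing to find.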
Usage in Workspaces
Once defined, your custom components work like built-in ones:
# In JSON/YAML workspace definition
{
    "distributions": [
        {
            "name": "signal_pdf",
            "type": "custom_gaussian",
            "mean": "mu_signal",  # Parameter reference
            "sigma": 0.1,  # Numeric constant
        }
    ],
    "functions": [
        {
            "name": "weighted_norm",
            "type": "weighted_product",
            "factors": ["norm1", "norm2", 1.5],  # Mixed types
            "weights": [2.0, 1.0, 0.5],  # Numeric constants
        }
    ],
}

# In Python
workspace = Workspace.from_file("my_workspace.json")
model = workspace.model()
# Your components are automatically instantiated and available
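The definition above is written in Python dict notation. A strict JSON file allows neither # comments nor trailing commas, so one way to produce a valid file is to build the dict in Python and serialise it with the standard json module. A minimal sketch, writing to a temporary directory (an assumption made so the example cleans up after itself) and containing only the fragment shown above:

```python
import json
import tempfile
from pathlib import Path

workspace_definition = {
    "distributions": [
        {
            "name": "signal_pdf",
            "type": "custom_gaussian",
            "mean": "mu_signal",
            "sigma": 0.1,
        }
    ],
    "functions": [
        {
            "name": "weighted_norm",
            "type": "weighted_product",
            "factors": ["norm1", "norm2", 1.5],
            "weights": [2.0, 1.0, 0.5],
        }
    ],
}

with tempfile.TemporaryDirectory() as tmp_dir:
    path = Path(tmp_dir) / "my_workspace.json"
    # json.dumps emits strict JSON: no comments, no trailing commas.
    path.write_text(json.dumps(workspace_definition, indent=2))
    reloaded = json.loads(path.read_text())

print(reloaded["distributions"][0]["type"])
```

Whether a workspace needs further top-level sections before pyhs3 accepts it is not covered by this fragment.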
Error Handling and Debugging
Common Issues:
1. Unsupported field type:
   class BadDist(Distribution):
       type: Literal["bad"] = "bad"
       complex_field: dict  # Not supported!
   Fix: Add json_schema_extra={"preprocess": False} or use supported types.
2. Missing implementation:
   dist = MyDist(name="test", param="value")
   # RuntimeError: Component type my_type expression not implemented
   Fix: Implement the expression() method.
3. Context key errors:
   def expression(self, context):
       return context["missing_param"]  # KeyError!
   Fix: Use self._parameters or self.get_parameter_list() to get correct keys.
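When chasing a KeyError it can help to compare the names a component expects with the keys that are actually available. The sketch below uses a plain set and dict as stand-ins for a component's parameter names and the evaluation context. The helper missing_context_keys and the example names are hypothetical, and whether pyhs3's Context supports the same membership test is an assumption.

```python
def missing_context_keys(parameter_names, context):
    """Return the expected names that have no entry in the context (hypothetical helper)."""
    return sorted(name for name in parameter_names if name not in context)


# Stand-ins: names a component depends on, and the values currently available.
parameter_names = {"mu_signal", "constant_signal_pdf_sigma"}
context = {"mu_signal": 1.2}

missing = missing_context_keys(parameter_names, context)
print(missing)
```

An empty result means every expected key is present, so a remaining KeyError most likely comes from a hard-coded key inside expression().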
Debugging Tips:
# Inspect what was auto-processed
dist = MyDist(name="test", param1="alpha", param2=1.5)
print("Parameters:", dist.parameters) # All parameter names
print("Internal mapping:", dist._parameters) # Field -> parameter mapping
print("Constants:", list(dist.constants.keys())) # Generated constant names
print("Constant values:", dist._constants_values) # Stored numeric values
Best Practices
- Use descriptive type literals for easy identification
- Document your components with clear docstrings and examples
- Handle edge cases in your expression() method
- Test thoroughly with different parameter combinations
- Consider performance: PyTensor operations should be efficient
- Follow naming conventions: use clear, descriptive field names
- Validate inputs when auto-processing isn’t sufficient
The automatic parameter processing handles most common cases, letting you focus on the mathematical implementation rather than parameter management boilerplate.