pyhs3.jaxify
- pyhs3.jaxify(output, *, inputs=None)
Convert a PyTensor expression into a JAX-callable JaxifiedGraph.
- Parameters:
output (TensorVariable[TensorType, Apply[Any]]) – The PyTensor output variable to compile.
inputs (Sequence[TensorVariable[TensorType, Apply[Any]]] | None) – Explicit list of input variables. If None, the full set of graph inputs (variables with no owner, i.e. symbolic parameters) is discovered automatically via explicit_graph_inputs.
- Returns:
Wrapper exposing the compiled JAX function plus input metadata.
- Return type:
JaxifiedGraph
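For the inputs=None case described above, input discovery amounts to walking back from output and keeping the variables that have no owner. The toy sketch below illustrates that traversal on a miniature version of the Gaussian graph used in the example. Var, Node, make_op, and find_graph_inputs are hypothetical names, not pyhs3 or PyTensor API, and the sketch assumes, in line with the "symbolic parameters" wording, that ownerless constants such as 2*pi are skipped.

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass(eq=False)
class Var:
    """Hypothetical stand-in for a symbolic graph variable."""
    name: str
    owner: Node | None = None  # operation that produced this variable, if any
    is_constant: bool = False  # constants have no owner but are not parameters


@dataclass(eq=False)
class Node:
    """Hypothetical stand-in for an operation applied to input variables."""
    op: str
    inputs: list[Var] = field(default_factory=list)


def make_op(op: str, *inputs: Var) -> Var:
    """Return the output variable of a new operation node."""
    return Var(op, owner=Node(op, list(inputs)))


def find_graph_inputs(output: Var) -> list[Var]:
    """Collect ownerless, non-constant variables reachable from output.

    Iterative depth-first walk; each variable is visited once, and the
    result keeps first-encounter order (left to right through the inputs).
    """
    seen: set[int] = set()
    found: list[Var] = []
    stack = [output]
    while stack:
        var = stack.pop()
        if id(var) in seen:
            continue
        seen.add(id(var))
        if var.owner is None:
            if not var.is_constant:
                found.append(var)
        else:
            # Push in reverse so the leftmost input is popped first.
            stack.extend(reversed(var.owner.inputs))
    return found


# Toy version of exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * sqrt(2*pi))
x, mu, sigma = Var("x"), Var("mu"), Var("sigma")
z = make_op("div", make_op("sub", x, mu), sigma)
numerator = make_op("exp", make_op("mul", Var("-0.5", is_constant=True),
                                   make_op("pow", z, Var("2", is_constant=True))))
denominator = make_op("mul", sigma, make_op("sqrt", Var("2*pi", is_constant=True)))
pdf = make_op("div", numerator, denominator)

names = [v.name for v in find_graph_inputs(pdf)]
print(names)  # ['x', 'mu', 'sigma']
```

Because each variable is visited only once, sigma, which feeds both the numerator and the denominator, is reported a single time.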
Examples
>>> import math
>>> import pytensor.tensor as pt
>>> x = pt.scalar("x")
>>> mu = pt.scalar("mu")
>>> sigma = pt.scalar("sigma")
>>> pdf = pt.exp(-0.5 * ((x - mu) / sigma) ** 2) / (
...     sigma * pt.sqrt(pt.constant(2 * math.pi, dtype="float64"))
... )
>>> from pyhs3.transpile import jaxify
>>> jg = jaxify(pdf)
>>> float(jg(x=0.0, mu=0.0, sigma=1.0)[0])
0.3989422804014327
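The example calls the returned wrapper with keyword arguments and indexes the result with [0], which suggests the call returns a sequence of outputs. The plain-Python sketch below shows one way a wrapper holding a compiled function plus input names could bind keywords to positional order. ToyGraphWrapper and gaussian_pdf are hypothetical and do not describe pyhs3's actual JaxifiedGraph implementation; a plain math function stands in for the compiled JAX function so the sketch runs without JAX.

```python
from __future__ import annotations

import math
from dataclasses import dataclass
from typing import Callable, Sequence


@dataclass(frozen=True)
class ToyGraphWrapper:
    """Hypothetical wrapper: a compiled positional function plus input metadata."""
    fn: Callable[..., Sequence[float]]  # stands in for the compiled JAX function
    input_names: tuple[str, ...]        # input names, in the order fn expects

    def __call__(self, **kwargs: float) -> Sequence[float]:
        missing = [n for n in self.input_names if n not in kwargs]
        unexpected = sorted(set(kwargs) - set(self.input_names))
        if missing or unexpected:
            raise TypeError(f"missing inputs {missing}, unexpected inputs {unexpected}")
        # Reorder keyword arguments into the positional order fn was built with.
        return self.fn(*(kwargs[n] for n in self.input_names))


def gaussian_pdf(x: float, mu: float, sigma: float) -> tuple[float]:
    """Plain-Python stand-in for the compiled Gaussian graph; returns a 1-tuple."""
    z = (x - mu) / sigma
    return (math.exp(-0.5 * z * z) / (sigma * math.sqrt(2 * math.pi)),)


jg = ToyGraphWrapper(gaussian_pdf, ("x", "mu", "sigma"))
# Keyword order does not matter; the wrapper reorders them.
print(jg(sigma=1.0, mu=0.0, x=0.0)[0])  # about 0.39894, i.e. 1/sqrt(2*pi)
```

Keeping the input names next to the function lets callers pass parameters by name without knowing the order in which they were discovered, which is convenient when inputs is left as None.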