pyhs3.jaxify

pyhs3.jaxify(output, *, inputs=None)

Convert a PyTensor expression into a JAX-callable JaxifiedGraph.

Parameters:
  • output (TensorVariable[TensorType, Apply[Any]]) – The PyTensor output variable to compile.

  • inputs (Sequence[TensorVariable[TensorType, Apply[Any]]] | None) – Explicit list of input variables. If None, the full set of graph inputs (variables with no owner, i.e. symbolic parameters) is discovered automatically via explicit_graph_inputs.

Returns:
  Wrapper exposing the compiled JAX function plus input metadata.

Return type:
  JaxifiedGraph

Examples

>>> import math
>>> import pytensor.tensor as pt
>>> x = pt.scalar("x")
>>> mu = pt.scalar("mu")
>>> sigma = pt.scalar("sigma")
>>> pdf = pt.exp(-0.5 * ((x - mu) / sigma) ** 2) / (
...     sigma * pt.sqrt(pt.constant(2 * math.pi, dtype="float64"))
... )
>>> from pyhs3.transpile import jaxify
>>> jg = jaxify(pdf)
>>> float(jg(x=0.0, mu=0.0, sigma=1.0)[0])
0.3989422804014327
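
The value printed above can be cross-checked without PyTensor or JAX: with x = mu = 0 and sigma = 1 the expression reduces to the standard normal density at zero, 1 / sqrt(2π). The following is a minimal sketch using only the standard library; the helper name normal_pdf is hypothetical and is not part of pyhs3.

```python
import math


def normal_pdf(x: float, mu: float, sigma: float) -> float:
    """Reference normal density (hypothetical helper, not part of pyhs3)."""
    z = (x - mu) / sigma
    return math.exp(-0.5 * z**2) / (sigma * math.sqrt(2 * math.pi))


# Same arguments as the jaxify example: x=0.0, mu=0.0, sigma=1.0.
reference = normal_pdf(0.0, 0.0, 1.0)
print(reference)
```

The printed number should agree with the result of jg(x=0.0, mu=0.0, sigma=1.0)[0] to within floating-point rounding, assuming the compiled graph evaluates in float64.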