Source code for pyhs3.transpile

"""
JAX transpilation helpers for pyhs3.

Provides :func:`jaxify` and :class:`JaxifiedGraph` for converting any PyTensor
expression into a callable JAX function suitable for use with JAX-based
optimizers (e.g. ``optimistix``).

Requires the ``jax`` optional extra::

    pip install pyhs3[jax]

which pulls in ``pytensor[jax]`` and, transitively, JAX itself.
"""

from __future__ import annotations

from collections import Counter
from collections.abc import Callable, Sequence
from dataclasses import dataclass
from typing import cast

from pytensor.compile import mode as _ptmode
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.traversal import explicit_graph_inputs

from pyhs3.typing.aliases import TensorVar

# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------


@dataclass(frozen=True)
class JaxifiedGraph:
    """A JAX-callable wrapper around a compiled PyTensor expression.

    Produced by :func:`jaxify`. Supports keyword-argument calls (the primary
    interface for dict-pytree-based optimizers like optimistix or everwillow)
    and positional calls via :meth:`call_positional`.

    Attributes:
        inputs: Tuple of PyTensor input variables, in evaluation order.
        input_names: Names of those variables (same order as ``inputs``).
        fn: The raw JAX callable returned by ``jax_funcify``.
    """

    inputs: tuple[TensorVar, ...]
    input_names: tuple[str, ...]
    fn: Callable[..., tuple]  # type: ignore[type-arg]

    def __call__(self, **kwargs: object) -> object:
        """Call by keyword argument — pure passthrough to :attr:`fn`.

        ``jax_funcify`` generates a function whose parameter names match the
        original PyTensor variable names, so kwargs are forwarded directly.
        Python itself raises ``TypeError`` for missing or unexpected names.

        The typical usage pattern with optimistix or everwillow is::

            @jax.jit
            def nll(free_params):
                # free_params is a dict pytree
                all_params = {**free_params, **fixed_params}
                return -2 * jnp.log(jg(**all_params)[0])

        Parameters
        ----------
        **kwargs:
            One value per input name, as JAX arrays or Python scalars.

        Returns
        -------
        Whatever :attr:`fn` returns (typically a 1-tuple of JAX arrays).
        """
        return self.fn(**kwargs)

    def call_positional(self, *args: object) -> object:
        """Call with positional arguments — pure passthrough to :attr:`fn`.

        Parameters
        ----------
        *args:
            Values in the same order as ``self.input_names``.

        Returns
        -------
        Whatever :attr:`fn` returns (typically a 1-tuple of JAX arrays).
        """
        return self.fn(*args)
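

# A minimal usage sketch (illustrative only, not part of the library): assuming
# ``jg`` is a :class:`JaxifiedGraph` produced by :func:`jaxify` below with
# ``jg.input_names == ("x", "mu", "sigma")``, the two call styles are
# equivalent, since both forward straight to ``jg.fn``::
#
#     by_name = jg(x=0.0, mu=0.0, sigma=1.0)
#     by_position = jg.call_positional(0.0, 0.0, 1.0)  # same order as input_names
#     assert float(by_name[0]) == float(by_position[0])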


def jaxify(
    output: TensorVar,
    *,
    inputs: Sequence[TensorVar] | None = None,
) -> JaxifiedGraph:
    """Convert a PyTensor expression into a JAX-callable :class:`JaxifiedGraph`.

    Parameters
    ----------
    output:
        The PyTensor output variable to compile.
    inputs:
        Explicit list of input variables. If ``None``, the full set of graph
        inputs (variables with no owner, i.e. symbolic parameters) is
        discovered automatically via ``explicit_graph_inputs``.

    Returns
    -------
    JaxifiedGraph
        Wrapper exposing the compiled JAX function plus input metadata.

    Examples
    --------
    >>> import math
    >>> import pytensor.tensor as pt
    >>> x = pt.scalar("x")
    >>> mu = pt.scalar("mu")
    >>> sigma = pt.scalar("sigma")
    >>> pdf = pt.exp(-0.5 * ((x - mu) / sigma) ** 2) / (
    ...     sigma * pt.sqrt(pt.constant(2 * math.pi, dtype="float64"))
    ... )
    >>> from pyhs3.transpile import jaxify
    >>> jg = jaxify(pdf)
    >>> float(jg(x=0.0, mu=0.0, sigma=1.0)[0])
    0.3989422804014327
    """
    try:
        from pytensor.link.jax.dispatch.basic import jax_funcify  # noqa: PLC0415
    except ImportError as exc:
        msg = "pyhs3.transpile requires JAX. Install with `pip install pyhs3[jax]`."
        raise ImportError(msg) from exc

    if inputs is None:
        # Filter out unnamed nodes (constants, shared vars without explicit names)
        # so that every entry in input_names is a non-None string.
        inputs = cast(
            list[TensorVar],
            [v for v in explicit_graph_inputs([output]) if v.name is not None],
        )

    named_inputs: tuple[TensorVar, ...] = tuple(inputs)
    raw_names = tuple(v.name for v in named_inputs)
    if any(name is None for name in raw_names):
        msg = (
            "All inputs must be named for kwargs-based dispatch. "
            "Provide named TensorVariables or use call_positional()."
        )
        raise ValueError(msg)
    names: tuple[str, ...] = tuple(cast(str, name) for name in raw_names)

    duplicates = sorted(name for name, n in Counter(names).items() if n > 1)
    if duplicates:
        msg = f"Input names must be unique; duplicates found: {duplicates}"
        raise ValueError(msg)

    fgraph = FunctionGraph(inputs=list(inputs), outputs=[output], clone=True)
    _ptmode.JAX.optimizer.rewrite(fgraph)
    fn = jax_funcify(fgraph)

    return JaxifiedGraph(named_inputs, names, fn)
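

# ---------------------------------------------------------------------------
# Illustrative end-to-end sketch (not part of the library API)
# ---------------------------------------------------------------------------
# A minimal sketch of differentiating a jaxified graph with JAX, assuming the
# ``jax`` extra is installed. It reuses the Gaussian pdf from the ``jaxify``
# docstring above; the names ``pdf``, ``nll``, and ``grad_nll`` are local to
# this example only::
#
#     import math
#     import jax
#     import jax.numpy as jnp
#     import pytensor.tensor as pt
#     from pyhs3.transpile import jaxify
#
#     x, mu, sigma = pt.scalar("x"), pt.scalar("mu"), pt.scalar("sigma")
#     pdf = pt.exp(-0.5 * ((x - mu) / sigma) ** 2) / (
#         sigma * pt.sqrt(pt.constant(2 * math.pi, dtype="float64"))
#     )
#     jg = jaxify(pdf)
#
#     def nll(free_params, data_x):
#         # free_params is a dict pytree, e.g. {"mu": 0.5, "sigma": 1.0}
#         return -2.0 * jnp.log(jg(x=data_x, **free_params)[0])
#
#     grad_nll = jax.jit(jax.grad(nll))
#     grads = grad_nll({"mu": 0.5, "sigma": 1.0}, 0.0)  # dict of d(nll)/d(param)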