pyhs3.transpile

JAX transpilation helpers for pyhs3.

Provides jaxify() and JaxifiedGraph for converting any PyTensor expression into a callable JAX function suitable for use with JAX-based optimizers (e.g. optimistix).

Requires the jax optional extra:

pip install "pyhs3[jax]"

which pulls in pytensor[jax] and, transitively, JAX itself.
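What "JAX-callable" buys you can be illustrated with a small, pure-JAX sketch. It does not import pyhs3 or PyTensor. `NamedJaxFunction` and `gaussian_nll` are hypothetical stand-ins: the first plays the role of a JaxifiedGraph-style wrapper (ordered input names plus a JAX function), and the second plays the role of a compiled PyTensor expression. The real calling convention of JaxifiedGraph is not stated on this page, so treat this as an assumption and check its docstring. The sketch shows that such a wrapper composes directly with `jax.jit` and `jax.value_and_grad`, which is the property JAX-based optimizers rely on.

```python
from dataclasses import dataclass
from typing import Callable, Mapping, Sequence

import jax
import jax.numpy as jnp


@dataclass(frozen=True)
class NamedJaxFunction:
    """Hypothetical stand-in for a JaxifiedGraph-style wrapper (NOT pyhs3's API):
    a pure JAX function plus the ordered names of its inputs."""

    input_names: Sequence[str]
    fn: Callable[..., jax.Array]

    def __call__(self, params: Mapping[str, jax.Array]) -> jax.Array:
        # Unpack a name -> value mapping into the positional order fn expects.
        return self.fn(*(params[name] for name in self.input_names))


def gaussian_nll(mu, sigma):
    # Hypothetical stand-in for a compiled PyTensor expression: Gaussian
    # negative log-likelihood of three fixed observations, constant dropped.
    data = jnp.array([0.8, 1.1, 1.3])
    return jnp.sum(0.5 * ((data - mu) / sigma) ** 2 + jnp.log(sigma))


wrapped = NamedJaxFunction(input_names=("mu", "sigma"), fn=gaussian_nll)

# The wrapper is a pure function of a pytree (a dict of arrays), so JAX
# transformations such as value_and_grad and jit apply to it directly.
objective = jax.jit(jax.value_and_grad(wrapped))
params = {"mu": jnp.asarray(1.0), "sigma": jnp.asarray(0.5)}
value, grads = objective(params)

# A few plain gradient-descent steps, standing in for a real optimizer.
for _ in range(3):
    _, g = objective(params)
    params = jax.tree_util.tree_map(lambda p, dp: p - 0.01 * dp, params, g)
```

Keying inputs by name is one possible design choice in this sketch: it lets parameters be addressed by the names used in the model rather than by position, while the underlying JAX function stays positional.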

Functions

jaxify(output, *[, inputs])

Convert a PyTensor expression into a JAX-callable JaxifiedGraph.

Classes

JaxifiedGraph(inputs, input_names, fn)

A JAX-callable wrapper around a compiled PyTensor expression.