pyhs3.transpile
JAX transpilation helpers for pyhs3.
Provides jaxify() and JaxifiedGraph for converting any PyTensor
expression into a callable JAX function suitable for use with JAX-based
optimizers (e.g. optimistix).
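The sketch below shows the kind of callable jaxify() is meant to produce and why JAX-based optimizers can use it. It is written directly in jax.numpy instead of being converted from PyTensor. The data, the nll function, and the step size are made up for illustration. Because the function is pure and traceable, jax.jit and jax.value_and_grad compose with it directly. Plain gradient descent stands in for an optimistix solver so the example depends only on JAX.

```python
import jax
import jax.numpy as jnp

# Hypothetical observations; the maximum-likelihood estimates are mean 1.0, sigma sqrt(0.08).
data = jnp.array([1.2, 0.8, 1.0, 1.4, 0.6])


def nll(params):
    """Gaussian negative log-likelihood (up to a constant) in (mu, log_sigma).

    Stand-in for a jaxified PyTensor expression: a pure function of JAX arrays
    with no Python side effects, so JAX can trace it.
    """
    mu, log_sigma = params[0], params[1]
    sigma = jnp.exp(log_sigma)
    return jnp.sum(0.5 * ((data - mu) / sigma) ** 2 + log_sigma)


# Since nll is traceable, JIT compilation and autodiff apply without changes.
value_and_grad = jax.jit(jax.value_and_grad(nll))

# Plain gradient descent used instead of an optimistix solver.
params = jnp.array([0.0, 0.0])
for _ in range(2000):
    _, g = value_and_grad(params)
    params = params - 0.01 * g

mu_hat, sigma_hat = params[0], jnp.exp(params[1])
```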
Requires the jax optional extra:
pip install "pyhs3[jax]"
which pulls in pytensor[jax] and, transitively, JAX itself.
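JAX support is optional, so calling code may need to check for the extra before it uses it. The sketch below shows one way to do this. The helpers jax_extra_available() and require_jax_extra() are hypothetical and are not part of pyhs3. importlib.util.find_spec checks whether jax and pytensor can be imported without actually importing them.

```python
import importlib.util


def jax_extra_available() -> bool:
    """Hypothetical helper: True if the modules the jax extra provides are importable."""
    # find_spec returns None, without importing anything, when a top-level module is absent.
    return all(importlib.util.find_spec(name) is not None for name in ("jax", "pytensor"))


def require_jax_extra() -> None:
    """Hypothetical helper: raise an ImportError that names the install command."""
    if not jax_extra_available():
        raise ImportError('JAX support requires the optional extra: pip install "pyhs3[jax]"')
```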
Functions
jaxify() | Convert a PyTensor expression into a JAX-callable
Classes
JaxifiedGraph | A JAX-callable wrapper around a compiled PyTensor expression.
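The sketch below illustrates the wrapper idea. JaxCallableWrapper is a hypothetical class and is not the real JaxifiedGraph. It assumes the wrapped function takes positional JAX arrays and returns a tuple of outputs. That return convention is an assumption made only for this example. The wrapper applies jax.jit once and unwraps single outputs, so the object still composes with jax.grad.

```python
from typing import Any, Callable, Sequence

import jax
import jax.numpy as jnp


class JaxCallableWrapper:
    """Hypothetical JAX-callable wrapper (not pyhs3's JaxifiedGraph)."""

    def __init__(self, fn: Callable[..., Sequence[Any]], input_names: Sequence[str]):
        self.input_names = tuple(input_names)
        # Wrap once with jax.jit; JAX compiles lazily on the first call and
        # caches the result per input shape and dtype.
        self._jitted = jax.jit(fn)

    def __call__(self, *args):
        if len(args) != len(self.input_names):
            raise TypeError(
                f"expected {len(self.input_names)} inputs {self.input_names}, got {len(args)}"
            )
        outputs = self._jitted(*args)
        # Unwrap single-output functions so callers get an array rather than a 1-tuple.
        return outputs[0] if len(outputs) == 1 else outputs


def _dot(x, y):
    # Hypothetical stand-in for a converted expression: returns a tuple of outputs.
    return (jnp.sum(x * y),)


wrapped = JaxCallableWrapper(_dot, ["x", "y"])
x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([4.0, 5.0, 6.0])
value = wrapped(x, y)  # 1*4 + 2*5 + 3*6 = 32.0
grad_x = jax.grad(lambda x_: wrapped(x_, y))(x)  # gradient w.r.t. x equals y
```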