pyhs3.JaxifiedGraph

class pyhs3.JaxifiedGraph(inputs, input_names, fn)

A JAX-callable wrapper around a compiled PyTensor expression.

Produced by jaxify(). Supports keyword-argument calls (the primary interface for dict-pytree-based optimizers like optimistix or everwillow) and positional calls via call_positional().
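The two call styles can be illustrated with a minimal, hypothetical sketch. This is not pyhs3's implementation. The class name KeywordDispatchWrapper and the plain function standing in for the compiled JAX callable are assumptions made for illustration. The sketch assumes only that a keyword call reorders the named values by input_names and then calls fn positionally, while call_positional() forwards its arguments unchanged.

```python
from typing import Any, Callable, Sequence


class KeywordDispatchWrapper:
    """Hypothetical stand-in illustrating JaxifiedGraph's two call styles.

    Not pyhs3's code: `fn` is any positional callable, standing in for the
    JAX function that jax_funcify would produce.
    """

    def __init__(self, input_names: Sequence[str], fn: Callable[..., Any]):
        self.input_names = tuple(input_names)
        self.fn = fn

    def __call__(self, **kwargs: Any) -> Any:
        # Keyword interface (assumed behaviour): look up each value by name
        # and pass them to fn in input_names order. A missing name raises
        # KeyError in this sketch.
        return self.fn(*(kwargs[name] for name in self.input_names))

    def call_positional(self, *args: Any) -> Any:
        # Positional interface: forward the arguments to fn unchanged.
        return self.fn(*args)


# A plain function stands in for the compiled graph.
graph = KeywordDispatchWrapper(["mu", "sigma"], lambda mu, sigma: mu + 2 * sigma)
params = {"sigma": 0.5, "mu": 1.0}      # dict pytree; key order does not matter
print(graph(**params))                  # 2.0
print(graph.call_positional(1.0, 0.5))  # 2.0
```

Dispatching by name is what lets a dict of parameters be passed directly with `**params`, independent of the dict's key order.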

inputs

Tuple of PyTensor input variables, in evaluation order.

input_names

Names of those variables (same order as inputs).

fn

The raw JAX callable returned by jax_funcify.

Methods

__init__(inputs, input_names, fn)

call_positional(*args)

Call with positional arguments. This is a pure passthrough to fn: the arguments are forwarded unchanged.
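Because call_positional() forwards its arguments unchanged, a caller holding parameters in a dict presumably has to order them itself, following input_names (which shares the order of inputs). The sketch below shows that reordering step. The helper to_positional is hypothetical and not part of pyhs3. The tuple input_names and the function fn are plain-Python stand-ins for the attributes of the same names.

```python
from typing import Any, Mapping, Sequence


def to_positional(params: Mapping[str, Any], input_names: Sequence[str]) -> tuple:
    """Hypothetical helper: order a name -> value mapping by input_names."""
    return tuple(params[name] for name in input_names)


# Stand-ins for the `input_names` and `fn` attributes described above.
input_names = ("mu", "sigma")


def fn(mu: float, sigma: float) -> float:
    # Positional-only view of the compiled graph (illustrative only).
    return mu - sigma


params = {"sigma": 0.25, "mu": 1.0}
args = to_positional(params, input_names)
print(args)       # (1.0, 0.25)
print(fn(*args))  # 0.75
```

With a real JaxifiedGraph, the same ordered tuple would be passed as `graph.call_positional(*args)`.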

Attributes