Describe the issue:
For larger models, `sample_blackjax_nuts` reports an unrealistically small sampling time and a very large transformation time. This is caused by JAX's asynchronous dispatch: the sampling call returns as soon as the work has been queued, so the sampling actually finishes later, during the transformation step, and is timed there. I'd guess that what is currently reported as "sampling time" is really JAX compilation. It should be fixable with a simple `block_until_ready()`. I can send a PR shortly.
Mostly a cosmetic issue, but it cost me some time working out why the transformation was so slow 😄.
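To illustrate the asynchronous-dispatch effect described above, here is a minimal sketch in plain JAX. `fake_sampling` and `fake_transform` are hypothetical stand-ins (not PyMC or BlackJAX code): the first is deliberately expensive, the second is cheap. Timing the first without blocking only measures dispatch; the wait then shows up in whichever later step first needs concrete values. How large the gap is depends on the backend and JAX version (it is most visible on GPU/TPU).

```python
import time

import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def fake_sampling(x):
    # Hypothetical stand-in for an expensive sampler: 20 chained matrix products.
    return jax.lax.fori_loop(0, 20, lambda i, acc: jnp.tanh(acc @ x), x)


@jax.jit
def fake_transform(draws):
    # Hypothetical stand-in for a cheap transformation step.
    return jnp.exp(draws)


x = jnp.full((500, 500), 1.0 / 500)
# Warm-up call so that compilation is not part of any measurement below.
fake_transform(fake_sampling(x)).block_until_ready()

# Naive timing: the clock stops as soon as the work has been dispatched.
t0 = time.perf_counter()
draws = fake_sampling(x)
sampling_time = time.perf_counter() - t0

# Copying to NumPy needs concrete values, so this step waits for the
# pending sampling work and its time is charged here instead.
t0 = time.perf_counter()
result = np.asarray(fake_transform(draws))
transform_time = time.perf_counter() - t0

# Blocking timing: block_until_ready() waits for the device computation.
t0 = time.perf_counter()
draws_blocked = fake_sampling(x).block_until_ready()
blocked_sampling_time = time.perf_counter() - t0

print(f"naive sampling:   {sampling_time:.4f} s")
print(f"transformation:   {transform_time:.4f} s")
print(f"blocked sampling: {blocked_sampling_time:.4f} s")
```

On a backend that dispatches asynchronously, the naive sampling time is typically far smaller than the blocked one, with the difference absorbed by the transformation time, which matches the pattern in the log below.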
Reproducible code example:
```python
import pymc as pm
from pymc.sampling.jax import sample_blackjax_nuts
import numpy as np

# One million observations, so that sampling takes long enough to notice.
x = np.random.normal(0, 1, 1_000_000)

with pm.Model() as model:
    sigma = pm.HalfCauchy("sigma", 1)
    mu = pm.Normal("mu", 0, 1)
    pm.Normal("x", mu, sigma, observed=x)
    trace_blackjax = sample_blackjax_nuts()
```
Error message:
```
Compiling...
Compilation time = 0:00:02.020721
Sampling...
Sampling time = 0:00:09.755916
Transforming variables...
Transformation time = 0:02:29.152843
```
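The log above prints each phase in `H:MM:SS.ffffff` form, which is what Python's `datetime.timedelta` produces. A minimal sketch of a phase timer that waits on every array in a (possibly nested) result before stopping the clock might look like the following. `timed_phase`, `draw_samples` and `transform` are hypothetical toys for illustration, not PyMC's or BlackJAX's actual code.

```python
import time
from datetime import timedelta

import jax
import jax.numpy as jnp


def _wait(leaf):
    # Block on device arrays; pass any non-array leaves through unchanged.
    return leaf.block_until_ready() if hasattr(leaf, "block_until_ready") else leaf


def timed_phase(label, fn, *args):
    """Hypothetical helper: time fn(*args) including the device computation."""
    start = time.perf_counter()
    out = fn(*args)
    # Without this wait, only the dispatch would be timed and the real cost
    # would be charged to whichever later phase first touches `out`.
    out = jax.tree_util.tree_map(_wait, out)
    elapsed = timedelta(seconds=time.perf_counter() - start)
    print(f"{label} time = {elapsed}")
    return out, elapsed


@jax.jit
def draw_samples(key):
    # Toy stand-in for the sampler: unconstrained draws for 4 chains.
    k_mu, k_sigma = jax.random.split(key)
    return {
        "mu": jax.random.normal(k_mu, (4, 1000)),
        "log_sigma": jax.random.normal(k_sigma, (4, 1000)),
    }


@jax.jit
def transform(draws):
    # Toy stand-in for the transformation step: back to the constrained scale.
    return {"mu": draws["mu"], "sigma": jnp.exp(draws["log_sigma"])}


samples, t_sample = timed_phase("Sampling", draw_samples, jax.random.PRNGKey(0))
posterior, t_transform = timed_phase("Transformation", transform, samples)
```

Because these toy functions are not warmed up, their first timings also include JIT compilation, which the log reports as a separate phase.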
PyMC version information:
macOS
pytensor 2.9.1
blackjax 0.9.6
pymc 5.0.2+58.g33d641db
jax 0.4.1
jaxlib 0.4.1
Context for the issue:
No response