
BUG: sample_blackjax_nuts incorrectly reports sampling time #6550

Closed
@dehorsley

Description


Describe the issue:

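For larger models, sample_blackjax_nuts reports an unrealistically small sampling time and a very large transformation time. This is due to JAX's asynchronous dispatch model. The sampling call returns as soon as the work has been dispatched, so the sampling timer stops before the draws have actually been computed. The remaining work is then counted in whatever step next waits on the results, which here is the transformation. I'd guess what's currently reported as "sampling time" is really JAX compilation. It should be fixable with a simple block_until_ready() on the sampler output. I can send a PR shortly.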

Mostly a cosmetic issue, but it cost me some time working out why the transformation was so slow 😄.
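A minimal, standalone sketch (not PyMC code) of the asynchronous-dispatch effect described above. A call to a jitted function can return before the device has finished, so a timer stopped right after the call measures dispatch rather than computation. `heavy` is a hypothetical workload chosen only to keep the device busy. How large the gap is will depend on the backend and the JAX version.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def heavy(x):
    # Hypothetical workload: a chain of matrix products to keep the device busy.
    for _ in range(10):
        x = jnp.tanh(x @ x)
    return x


x = jnp.ones((1000, 1000)) / 1000.0
heavy(x).block_until_ready()  # warm-up, so compilation is not counted below

start = time.perf_counter()
out = heavy(x)  # may return as soon as the work has been dispatched
dispatch_time = time.perf_counter() - start

out.block_until_ready()  # wait for the computation to actually finish
total_time = time.perf_counter() - start

print(f"timer stopped right after the call:      {dispatch_time:.4f} s")
print(f"timer stopped after block_until_ready(): {total_time:.4f} s")
```

Only the second number reflects the real compute time. A timer that stops without blocking hands the unfinished work to whatever code next touches `out`.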

Reproducible code example:

import pymc as pm
from pymc.sampling.jax import sample_blackjax_nuts
import numpy as np

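# A large observed dataset makes the mis-attributed time easy to see.
x = np.random.normal(0, 1, 1_000_000)
with pm.Model() as model:
    sigma = pm.HalfCauchy("sigma", 1)
    mu = pm.Normal("mu", 0, 1)
    pm.Normal("x", mu, sigma, observed=x)
    # Prints per-phase timings: compilation, sampling, transformation.
    trace_blackjax = sample_blackjax_nuts()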

Error message:

Compiling...
Compilation time =  0:00:02.020721
Sampling...
Sampling time =  0:00:09.755916
Transforming variables...
Transformation time =  0:02:29.152843
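The log shows the symptom: a short "Sampling time" followed by a very long "Transformation time". Below is a minimal sketch of timing the same three phases so that each timer covers only its own work. Each phase blocks on its results before the clock is stopped. This is not PyMC's implementation. `toy_sampler`, `toy_transform`, `block`, and `N_DRAWS` are hypothetical stand-ins for the NUTS run, the variable transformation, and the blocking helper.

```python
import time
from datetime import timedelta

import jax
import jax.numpy as jnp

N_DRAWS = 10_000  # hypothetical number of draws


def block(tree):
    """Wait until every array in a pytree has actually been computed."""
    return jax.tree_util.tree_map(lambda a: a.block_until_ready(), tree)


def toy_sampler(key):
    # Hypothetical stand-in for a NUTS run: a random walk in unconstrained space.
    steps = jax.random.normal(key, (N_DRAWS, 2)) * 0.01
    return jnp.cumsum(steps, axis=0)


def toy_transform(draws):
    # Hypothetical stand-in for mapping draws back, e.g. log(sigma) -> sigma.
    return {"mu": draws[:, 0], "sigma": jnp.exp(draws[:, 1])}


key = jax.random.PRNGKey(0)

print("Compiling...")
t0 = time.perf_counter()
compiled = jax.jit(toy_sampler).lower(key).compile()  # ahead-of-time compile
print("Compilation time = ", timedelta(seconds=time.perf_counter() - t0))

print("Sampling...")
t0 = time.perf_counter()
draws = block(compiled(key))  # block before stopping the sampling timer
print("Sampling time = ", timedelta(seconds=time.perf_counter() - t0))

print("Transforming variables...")
t0 = time.perf_counter()
samples = block(toy_transform(draws))  # eager ops are dispatched asynchronously too
print("Transformation time = ", timedelta(seconds=time.perf_counter() - t0))
```

Blocking via `tree_map` over `.block_until_ready()` handles a single array and a dict of arrays alike, which suits sampler output that is usually a pytree.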

PyMC version information:

macOS

pytensor            2.9.1
blackjax            0.9.6
pymc                5.0.2+58.g33d641db
jax                 0.4.1
jaxlib              0.4.1

Context for the issue:

No response
