Cannot pass dims to idata_kwargs parameter in sample_numpyro_nuts #5932
Comments
Thanks for reporting. The same problem occurs with

    import pymc as pm
    from functools import partial
    from pymc import sampling_jax
    import pytest


    @pytest.mark.parametrize(
        "sampler",
        [
            partial(pm.sample, cores=1),
            sampling_jax.sample_blackjax_nuts,
            sampling_jax.sample_numpyro_nuts,
        ],
    )
    @pytest.mark.parametrize(
        "idata_kwargs",
        [
            {},
            {"dims": {"x": ["x_coord"]}},
            {"coords": {"x": ["a", "b"]}},
            {"coords": {"x": ["a", "b"]}, "dims": {"x": ["x_coord"]}},
        ],
    )
    def test_idata_kwargs(sampler, idata_kwargs):
        with pm.Model() as model:
            x = pm.Uniform("x", -10, 10, shape=(2,))
            trace = sampler(100, tune=100, idata_kwargs=idata_kwargs)

This produces the following output:
I think that the problem that Lines 368 to 377 in 2583b7f
A solution could be to pass Lines 685 to 687 in 2583b7f
The implementation to address this issue overwrites the dimensions extracted from the random variables in the model. Therefore, supplying some dimensions through

    import arviz as az
    import pymc as pm
    import pymc.sampling_jax

    coords = {
        "param": ["a", "b"],
        "animal": ["bear", "penguin", "lizard"],
    }

    with pm.Model(coords=coords) as model:
        a = pm.Normal("a", 0, 1, dims="animal")
        chol, corr, stds = pm.LKJCholeskyCov(
            "chol", n=2, eta=2.0, sd_dist=pm.Gamma.dist(2, 1)
        )
        trace_pm = pm.sample(
            10, tune=10, cores=2, chains=2, idata_kwargs={"dims": {"chol_stds": ["param"]}}
        )

    with model:
        trace_numpyro = pymc.sampling_jax.sample_numpyro_nuts(
            10, tune=10, chains=2, idata_kwargs={"dims": {"chol_stds": ["param"]}}
        )

The information is retained for variable

    >>> trace_pm.posterior["a"].dims
    ('chain', 'draw', 'animal')
    >>> trace_pm.posterior["a"].coords
    Coordinates:
      * chain    (chain) int64 0 1
      * draw     (draw) int64 0 1 2 3 4 5 6 7 8 9
      * animal   (animal) <U7 'bear' 'penguin' 'lizard'

The information is lost for variable

    >>> trace_numpyro.posterior["a"].dims
    ('chain', 'draw', 'a_dim_0')
    >>> trace_numpyro.posterior["a"].coords
    Coordinates:
      * chain    (chain) int64 0 1
      * draw     (draw) int64 0 1 2 3 4 5 6 7 8 9
      * a_dim_0  (a_dim_0) int64 0 1 2

I think it should be a simple enough fix, so I will try working on it and submit a PR. I'll also add cases to the tests added in the original solution.
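The behavior discussed here (dims from the model kept as defaults, with user-supplied dims layered on top instead of replacing them wholesale) can be sketched in plain Python. This is only an illustration of the merging idea, not the code in PyMC: the function name `merge_dims` and the dictionaries below are hypothetical.

```python
from typing import Dict, List, Mapping, Optional

Dims = Mapping[str, List[str]]


def merge_dims(model_dims: Dims, user_dims: Optional[Dims] = None) -> Dict[str, List[str]]:
    """Combine dims taken from the model with dims passed by the user.

    Hypothetical helper: model dims act as defaults, and an entry in
    user_dims replaces the model entry only for that same variable.
    """
    # Copy the lists so the caller's dictionaries are never mutated.
    merged = {name: list(names) for name, names in model_dims.items()}
    for name, names in (user_dims or {}).items():
        merged[name] = list(names)
    return merged


# Dims as they would be read from the example model (illustrative values).
model_dims = {"a": ["animal"]}
# Dims supplied by the user through idata_kwargs.
user_dims = {"chol_stds": ["param"]}

merged = merge_dims(model_dims, user_dims)
print(merged)  # {'a': ['animal'], 'chol_stds': ['param']}
```

With a merge of this kind, the entry for `a` would survive when only `chol_stds` is supplied, which is the outcome `pm.sample` shows in the output above.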
Thank you for noting this and proposing a fix. When I wrote the original solution, I was assuming the
I think that's a perfectly fair interpretation. Both use cases could be valuable.
A minimal, self-contained, and reproducible example.
Full traceback.
Details.
I am trying to add coordinate labels to the variables produced by LKJCholeskyCov() by passing the dimensions to the idata_kwargs parameter of the PyMC sampling function (as demonstrated in Oriol Abril's blog post: https://oriolabrilpla.cat/python/arviz/pymc/xarray/2022/06/07/pymc-arviz.html#2nd-example:-radon-multilevel-model). The method works when sampling with the default PyMC sampler, but fails with the NumPyro JAX backend. I have provided a full example in the code above, but please let me know if more details are needed.
Versions and main components
jax=v0.3.14, jaxlib=0.3.10, numpyro=0.9.2
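A possible stopgap on the versions listed above, sketched under assumptions: sample with the JAX backend without passing dims, then rename the default dimension names afterwards. The helper below only builds the rename mapping, assuming unnamed dimensions follow the `<variable>_dim_<axis>` pattern seen in this thread (for example `a_dim_0`). The name `default_dim_renames` is hypothetical and not part of PyMC or ArviZ.

```python
from typing import Dict, List, Mapping


def default_dim_renames(dims: Mapping[str, List[str]]) -> Dict[str, str]:
    """Map default dimension names to the desired ones.

    Hypothetical helper: assumes unnamed dimensions are called
    "<variable>_dim_<axis>", as in the 'a_dim_0' output in this thread.
    """
    renames = {}
    for var_name, dim_names in dims.items():
        for axis, dim_name in enumerate(dim_names):
            renames[f"{var_name}_dim_{axis}"] = dim_name
    return renames


# The dims that were intended to be passed through idata_kwargs.
renames = default_dim_renames({"chol_stds": ["param"]})
print(renames)  # {'chol_stds_dim_0': 'param'}
```

The resulting mapping could then be given to xarray's `Dataset.rename` on the posterior group, followed by `assign_coords` to attach the labels. This has not been verified against the versions above, and renaming dimensions of several variables to one shared name may need extra care.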