Cannot pass dims to idata_kwargs parameter in sample_numpyro_nuts #5932

Closed
jhrcook opened this issue Jun 26, 2022 · 4 comments · Fixed by #5983

@jhrcook
Contributor

jhrcook commented Jun 26, 2022

A minimal, self-contained, and reproducible example.

import pymc as pm
import pymc.sampling_jax
import arviz as az

coords = {"param": ["a", "b"]}
with pm.Model(coords=coords) as model:
    chol, corr, stds = pm.LKJCholeskyCov(
        "chol", n=2, eta=2.0, sd_dist=pm.Gamma.dist(2, 1)
    )
    trace = pm.sample(
        10, tune=10, cores=2, chains=2, idata_kwargs={"dims": {"chol_stds": ["param"]}}
    )

with model:
    trace = pymc.sampling_jax.sample_numpyro_nuts(
        10, tune=10, chains=2, idata_kwargs={"dims": {"chol_stds": ["param"]}}
    )

Full traceback.

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Input In [5], in <cell line: 1>()
      1 with model:
----> 2     trace = pymc.sampling_jax.sample_numpyro_nuts(10, tune=10, chains=2, idata_kwargs={"dims": {"chol_stds": ["param"]}})

File /usr/local/Caskroom/miniconda/base/envs/bluishred/lib/python3.10/site-packages/pymc/sampling_jax.py:564, in sample_numpyro_nuts(draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progress_bar, keep_untransformed, chain_method, postprocessing_backend, idata_kwargs, nuts_kwargs)
    559 attrs = {
    560     "sampling_time": (tic3 - tic2).total_seconds(),
    561 }
    563 posterior = mcmc_samples
--> 564 az_trace = az.from_dict(
    565     posterior=posterior,
    566     log_likelihood=log_likelihood,
    567     observed_data=find_observations(model),
    568     constant_data=find_constants(model),
    569     sample_stats=_sample_stats_to_xarray(pmap_numpyro),
    570     coords=coords,
    571     dims=dims,
    572     attrs=make_attrs(attrs, library=numpyro),
    573     **idata_kwargs,
    574 )
    576 return az_trace

TypeError: arviz.data.io_dict.from_dict() got multiple values for keyword argument 'dims'
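
The error itself is ordinary Python behavior and can be reproduced without PyMC or ArviZ: a call fails when the same keyword is supplied both explicitly and through `**` unpacking. A minimal sketch, where `from_dict_stub` is a hypothetical stand-in for `az.from_dict` and not the real function:

```python
def from_dict_stub(posterior=None, coords=None, dims=None, **kwargs):
    """Hypothetical stand-in for az.from_dict; it only echoes its arguments."""
    return {"posterior": posterior, "coords": coords, "dims": dims, **kwargs}


# Dims the sampler derives from the model (illustrative values).
model_dims = {"a": ["animal"]}
# Dims the user passes through idata_kwargs.
idata_kwargs = {"dims": {"chol_stds": ["param"]}}

message = None
try:
    # 'dims' arrives twice: once explicitly and once via **idata_kwargs.
    from_dict_stub(posterior={}, dims=model_dims, **idata_kwargs)
except TypeError as err:
    message = str(err)
    print(message)
```

Any key in `idata_kwargs` that the sampler also passes explicitly would trigger the same failure.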

Details.

I am trying to add coordinate labels to the variables produced by LKJCholeskyCov() by passing the dimensions to the idata_kwargs parameter of the PyMC sampling function, as demonstrated in Oriol Abril's blog post (https://oriolabrilpla.cat/python/arviz/pymc/xarray/2022/06/07/pymc-arviz.html#2nd-example:-radon-multilevel-model). The method works when sampling with the default PyMC sampler, but fails with the NumPyro JAX backend. I have provided a full example in the code above, but please let me know if more details are needed.

Versions and main components

  • PyMC Version: 4.0.1
  • Aesara Version: 2.7.3
  • Python Version: 3.10.5
  • Other relevant libraries: jax=v0.3.14, jaxlib=0.3.10, numpyro=0.9.2
  • Operating system: macOS Monterey (v12.4)
  • How did you install PyMC/PyMC3: conda
@bherwerth
Contributor

Thanks for reporting.

The same problem occurs when coords is passed in idata_kwargs, and it affects sample_blackjax_nuts as well as sample_numpyro_nuts. This is shown by running the following code with py.test:

import pymc as pm
from functools import partial

from pymc import sampling_jax
import pytest

@pytest.mark.parametrize(
    "sampler",
    [
        partial(pm.sample, cores=1),
        sampling_jax.sample_blackjax_nuts,
        sampling_jax.sample_numpyro_nuts,
    ]
)
@pytest.mark.parametrize(
    "idata_kwargs",
    [
        {},
        {"dims": {"x": ["x_coord"]}},
        {"coords": {"x": ["a", "b"]}},
        {"coords": {"x": ["a", "b"]}, "dims": {"x": ["x_coord"]}},
    ]
)
def test_idata_kwargs(sampler, idata_kwargs):
    with pm.Model() as model:
        x = pm.Uniform("x", -10, 10, shape=(2,))
        trace = sampler(100, tune=100, idata_kwargs=idata_kwargs)

This produces the following output:

================================================================================================ short test summary info ================================================================================================
FAILED issue_5932.py::test_idata_kwargs[idata_kwargs1-sample_blackjax_nuts] - TypeError: arviz.data.io_dict.from_dict() got multiple values for keyword argument 'dims'
FAILED issue_5932.py::test_idata_kwargs[idata_kwargs1-sample_numpyro_nuts] - TypeError: arviz.data.io_dict.from_dict() got multiple values for keyword argument 'dims'
FAILED issue_5932.py::test_idata_kwargs[idata_kwargs2-sample_blackjax_nuts] - TypeError: arviz.data.io_dict.from_dict() got multiple values for keyword argument 'coords'
FAILED issue_5932.py::test_idata_kwargs[idata_kwargs2-sample_numpyro_nuts] - TypeError: arviz.data.io_dict.from_dict() got multiple values for keyword argument 'coords'
FAILED issue_5932.py::test_idata_kwargs[idata_kwargs3-sample_blackjax_nuts] - TypeError: arviz.data.io_dict.from_dict() got multiple values for keyword argument 'coords'
FAILED issue_5932.py::test_idata_kwargs[idata_kwargs3-sample_numpyro_nuts] - TypeError: arviz.data.io_dict.from_dict() got multiple values for keyword argument 'coords'
================================================================================== 6 failed, 6 passed, 3 warnings in 101.87s (0:01:41) ==================================================================================

I think the problem is that dims and coords are defined in the sampling functions and then passed to az.from_dict together with idata_kwargs:

pymc/pymc/sampling_jax.py

Lines 368 to 377 in 2583b7f

az_trace = az.from_dict(
    posterior=posterior,
    log_likelihood=log_likelihood,
    observed_data=find_observations(model),
    constant_data=find_constants(model),
    coords=coords,
    dims=dims,
    attrs=make_attrs(attrs, library=blackjax),
    **idata_kwargs,
)

A solution could be to build ikwargs = dict(dims=dims, coords=coords, ...), update it with idata_kwargs, and pass the result to az.from_dict, similar to sampling.sample:

pymc/pymc/sampling.py

Lines 685 to 687 in 2583b7f

ikwargs = dict(model=model, save_warmup=not discard_tuned_samples)
if idata_kwargs:
    ikwargs.update(idata_kwargs)
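
The pattern proposed above might look roughly like the following sketch. It is an illustration only: `from_dict_stub` and `build_idata` are hypothetical names, and the stub stands in for `az.from_dict` so the example runs without ArviZ.

```python
def from_dict_stub(**kwargs):
    """Hypothetical stand-in for az.from_dict; returns the kwargs it received."""
    return kwargs


def build_idata(posterior, coords, dims, idata_kwargs=None):
    """Collect the sampler's defaults, then let idata_kwargs override them."""
    ikwargs = dict(coords=coords, dims=dims)
    if idata_kwargs:
        # Each key is now present only once, so no duplicate-keyword error.
        ikwargs.update(idata_kwargs)
    return from_dict_stub(posterior=posterior, **ikwargs)


result = build_idata(
    posterior={"x": [0.1, 0.2]},
    coords={"x_coord": ["a", "b"]},
    dims={"x": ["x_coord"]},
    idata_kwargs={"dims": {"y": ["y_coord"]}},
)
print(result["dims"])
print(result["coords"])
```

Note that `dict.update` replaces the whole `dims` entry here: the user-supplied mapping takes the place of the default one instead of being merged into it.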

bherwerth added a commit to bherwerth/pymc that referenced this issue Jul 16, 2022
twiecki pushed a commit that referenced this issue Jul 17, 2022
* Use 'partial' to construct inference data (#5932)

* Add test cases for dims & coords

* Comment on use of 'partial'

* Extend docstring and pass posterior after partial
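
Going only by the commit messages above, the fix appears to wrap the conversion call in `functools.partial`. The sketch below shows how that mechanism behaves in general; it is not the code from the commit, and `from_dict_stub` is a hypothetical stand-in for `az.from_dict`.

```python
from functools import partial


def from_dict_stub(posterior=None, **kwargs):
    """Hypothetical stand-in for az.from_dict; echoes its arguments."""
    return {"posterior": posterior, **kwargs}


# Defaults bound up front, as a sampler might do with model-derived values.
to_trace = partial(
    from_dict_stub,
    coords={"animal": ["bear", "penguin", "lizard"]},
    dims={"a": ["animal"]},
)

# Keywords given at call time override the ones bound by partial,
# so a repeated 'dims' no longer raises TypeError.
idata_kwargs = {"dims": {"chol_stds": ["param"]}}
idata = to_trace(posterior={}, **idata_kwargs)
print(idata["dims"])
```

With `partial`, a call-time keyword replaces the bound value wholesale, so the bound `dims` mapping is dropped, not merged.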
@jhrcook
Contributor Author

jhrcook commented Aug 9, 2022

The implementation that addresses this issue overwrites the dimensions extracted from the random variables in the model. Therefore, supplying some dimensions through idata_kwargs={"dims": {...}} causes all other variables to lose their dimension information. Here is an example:

import arviz as az
import pymc as pm
import pymc.sampling_jax

coords = {
    "param": ["a", "b"],
    "animal": ["bear", "penguin", "lizard"],
}

with pm.Model(coords=coords) as model:
    a = pm.Normal("a", 0, 1, dims="animal")
    chol, corr, stds = pm.LKJCholeskyCov(
        "chol", n=2, eta=2.0, sd_dist=pm.Gamma.dist(2, 1)
    )
    trace_pm = pm.sample(
        10, tune=10, cores=2, chains=2, idata_kwargs={"dims": {"chol_stds": ["param"]}}
    )

with model:
    trace_numpyro = pymc.sampling_jax.sample_numpyro_nuts(
        10, tune=10, chains=2, idata_kwargs={"dims": {"chol_stds": ["param"]}}
    )

The information is retained for variable a in the PyMC NUTS posterior:

>>> trace_pm.posterior["a"].dims
('chain', 'draw', 'animal')
>>> trace_pm.posterior["a"].coords
Coordinates:
  * chain    (chain) int64 0 1
  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 9
  * animal   (animal) <U7 'bear' 'penguin' 'lizard'

The information is lost for variable a in the Numpyro NUTS posterior:

>>> trace_numpyro.posterior["a"].dims
('chain', 'draw', 'a_dim_0')
>>> trace_numpyro.posterior["a"].coords
Coordinates:
  * chain    (chain) int64 0 1
  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 9
  * a_dim_0  (a_dim_0) int64 0 1 2

I think it should be a simple enough fix, so I will try working on it and submit a PR. I'll also add cases to the tests added in the original solution.

@bherwerth
Contributor

Thank you for noting this and proposing a fix. When I wrote the original solution, I was assuming the dims and coords from idata_kwargs should be passed without change. However, I agree that updating existing dims and coords makes more sense and ensures consistency with pm.sample.
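
One way to update the existing dims and coords instead of replacing them is sketched below. This is an assumption about how such a merge could be written, not the code that was eventually merged; `merge_idata_kwargs` is a hypothetical helper name.

```python
def merge_idata_kwargs(model_coords, model_dims, idata_kwargs=None):
    """Merge user-supplied coords/dims into the model's, user entries winning."""
    extra = dict(idata_kwargs or {})  # copy so the caller's dict is untouched
    coords = {**model_coords, **extra.pop("coords", {})}
    dims = {**model_dims, **extra.pop("dims", {})}
    # Remaining keys are passed through unchanged.
    return {"coords": coords, "dims": dims, **extra}


merged = merge_idata_kwargs(
    model_coords={"param": ["a", "b"], "animal": ["bear", "penguin", "lizard"]},
    model_dims={"a": ["animal"]},
    idata_kwargs={"dims": {"chol_stds": ["param"]}},
)
print(merged["dims"])
```

In this sketch the variable `a` keeps its `animal` dimension while `chol_stds` gains `param`, which matches the behavior reported for pm.sample earlier in the thread.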

@jhrcook
Contributor Author

jhrcook commented Aug 9, 2022

I think that's a perfectly fair interpretation. Both use cases could be valuable.
