Skip to content

Fix error when passing coords and dims in sampling_jax #5983

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 17, 2022

Conversation

bherwerth
Copy link
Contributor

Fix #5932 : Allow for passing coords and dims within idata_kwargs in sample_blackjax_nuts and sample_numpyro_nuts.

Previously, the arguments dims and coords as defined in the sampling function were passed to az.from_dict along with idata_kwargs. This caused an error when idata_kwargs contained keys coords and/or dims (duplicate keyword arguments), cf #5932.

Proposed solution:

  • By using functools.partial on az.from_dict, new defaults are set first. Only afterwards idata_kwargs is passed along. This results in dims and coords in idata_kwargs taking precedence (if given) instead of causing duplicate keyword args.
  • Test cases were added to test_idata_kwargs for providing coords and/or dims in idata_kwargs.
    ...

Checklist

Major / Breaking Changes

  • ...

Bugfixes / New features

Docs / Maintenance

  • ...

@codecov
Copy link

codecov bot commented Jul 16, 2022

Codecov Report

Merging #5983 (13e6ed1) into main (2583b7f) will decrease coverage by 0.00%.
The diff coverage is 100.00%.

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #5983      +/-   ##
==========================================
- Coverage   89.35%   89.35%   -0.01%     
==========================================
  Files          73       73              
  Lines       13251    13253       +2     
==========================================
+ Hits        11841    11842       +1     
- Misses       1410     1411       +1     
Impacted Files Coverage Δ
pymc/sampling_jax.py 96.96% <100.00%> (+0.03%) ⬆️
pymc/step_methods/hmc/base_hmc.py 89.76% <0.00%> (-0.79%) ⬇️

@twiecki twiecki merged commit 2a9e86c into pymc-devs:main Jul 17, 2022
@bherwerth bherwerth deleted the sampling-jax-fix-idata-kwargs branch July 18, 2022 19:24
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Cannot pass dims to idata_kwargs parameter in sample_numpyro_nuts
3 participants