Fix error when passing coords
and dims
in sampling_jax
#5983
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Fix #5932 : Allow for passing
coords
anddims
withinidata_kwargs
insample_blackjax_nuts
andsample_numpyro_nuts
.Previously, the arguments
dims
andcoords
as defined in the sampling function were passed toaz.from_dict
along withidata_kwargs
. This caused an error whenidata_kwargs
contained keyscoords
and/ordims
(duplicate keyword arguments), cf #5932.Proposed solution:
functools.partial
onaz.from_dict
, new defaults are set first. Only afterwardsidata_kwargs
is passed along. This results indims
andcoords
inidata_kwargs
taking precedence (if given) instead of causing duplicate keyword args.test_idata_kwargs
for providingcoords
and/ordims
inidata_kwargs
....
Checklist
Major / Breaking Changes
Bugfixes / New features
dims
toidata_kwargs
parameter insample_numpyro_nuts
#5932Docs / Maintenance