-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Fix JAX sampling funcs overwriting existing var's dims and coords #6041
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Codecov Report
@@ Coverage Diff @@
## main #6041 +/- ##
==========================================
+ Coverage 89.26% 89.27% +0.01%
==========================================
Files 72 72
Lines 12890 12897 +7
==========================================
+ Hits 11506 11514 +8
+ Misses 1384 1383 -1
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've added some minor comments & questions. Otherwise, looks all good to me.
pymc/tests/test_sampling_jax.py
Outdated
assert list(x_coords_expected) == list(posterior["x"].coords[x_dim_expected].values) | ||
|
||
assert posterior["z"].dims[2] == "z_coord" | ||
assert set(posterior["z"].coords["z_coord"].values) == {"apple", "banana", "orange"} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do you compare sets here rather than the stronger requirement of equality of lists?
pymc/sampling_jax.py
Outdated
coords.update(idata_kwargs.pop("coords")) | ||
if "dims" in idata_kwargs: | ||
dims.update(idata_kwargs.pop("dims")) | ||
return None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Consider dropping the this line as it's not necessary. Or maybe this is a recommended style, which I am not yet aware of, for a function that should not return anything?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it's purely a style choice that I generally do, but will of course follow the style of the repo.
pymc/sampling_jax.py
Outdated
@@ -376,6 +388,8 @@ def sample_blackjax_nuts( | |||
} | |||
|
|||
posterior = mcmc_samples | |||
# Update 'coords' and 'dims' extracted from the model with user 'idata_kwargs'. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Consider adding to the comment something like "and drop keys 'coords' and 'dims' from 'idata_kwargs' if present'
pymc/sampling_jax.py
Outdated
@@ -596,6 +611,8 @@ def sample_numpyro_nuts( | |||
} | |||
|
|||
posterior = mcmc_samples | |||
# Update 'coords' and 'dims' extracted from the model with user 'idata_kwargs'. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Consider extending the comment as above.
- add details to comments about changing idata_kwargs - change test to comparing lists instead of sets - remove an explicit 'return None'
Thanks @jhrcook and @bherwerth! |
What is this PR about?
The original solution to fix Issue #5932 overwrites existing dimensions and coordinate data on the variables when creating the ArviZ InferenceData object. (See the issue for a demonstration of this behavior.) This PR address that by using the
dims
andcoords
in theidata_kwargs
argument to update the extracted dimensions and coordinates.Checklist
Bugfixes / New features