Fix JAX sampling funcs overwriting existing var's dims and coords #6041

Merged
merged 8 commits into from
Aug 20, 2022

Conversation

jhrcook
Contributor

@jhrcook jhrcook commented Aug 9, 2022

What is this PR about?

The original fix for Issue #5932 overwrites existing dimensions and coordinate data on the variables when creating the ArviZ InferenceData object. (See the issue for a demonstration of this behavior.) This PR addresses that by using the dims and coords in the idata_kwargs argument to update the extracted dimensions and coordinates rather than replace them.
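
A minimal sketch of the merge described above, assuming the model-derived coordinates and dimensions are plain dictionaries. The helper name `update_coords_and_dims` and all sample values are illustrative assumptions, not code taken from this PR's diff.

```python
from typing import Any, Dict


def update_coords_and_dims(
    coords: Dict[str, Any], dims: Dict[str, Any], idata_kwargs: Dict[str, Any]
) -> None:
    """Merge user-supplied 'coords' and 'dims' into the model-derived dicts, in place."""
    if "coords" in idata_kwargs:
        coords.update(idata_kwargs.pop("coords"))
    if "dims" in idata_kwargs:
        dims.update(idata_kwargs.pop("dims"))


# Hypothetical values extracted from a model.
coords = {"x_coord": ["a", "b", "c"]}
dims = {"x": ["x_coord"]}

# Hypothetical user input that adds coordinates for an extra variable 'z'.
idata_kwargs = {
    "coords": {"z_coord": ["apple", "banana", "orange"]},
    "dims": {"z": ["z_coord"]},
    "log_likelihood": False,
}

update_coords_and_dims(coords, dims, idata_kwargs)
```

After the call, the model's `x_coord` entry is still present alongside the user's `z_coord`, and `idata_kwargs` no longer carries the `coords` and `dims` keys.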

Checklist

Bugfixes / New features

  • Fix JAX sampling funcs overwriting existing var's dims and coords.

@codecov

codecov bot commented Aug 9, 2022

Codecov Report

Merging #6041 (3033125) into main (906fcdc) will increase coverage by 0.01%.
The diff coverage is 100.00%.

@@            Coverage Diff             @@
##             main    #6041      +/-   ##
==========================================
+ Coverage   89.26%   89.27%   +0.01%     
==========================================
  Files          72       72              
  Lines       12890    12897       +7     
==========================================
+ Hits        11506    11514       +8     
+ Misses       1384     1383       -1     
Impacted Files                      Coverage Δ
pymc/sampling_jax.py                97.15% <100.00%> (+0.09%) ⬆️
pymc/step_methods/hmc/base_hmc.py   90.55% <0.00%> (+0.78%) ⬆️

@jhrcook jhrcook marked this pull request as ready for review August 9, 2022 17:11
Contributor

@bherwerth bherwerth left a comment

I've added some minor comments & questions. Otherwise, looks all good to me.

assert list(x_coords_expected) == list(posterior["x"].coords[x_dim_expected].values)

assert posterior["z"].dims[2] == "z_coord"
assert set(posterior["z"].coords["z_coord"].values) == {"apple", "banana", "orange"}
Contributor

Why do you compare sets here rather than the stronger requirement of equality of lists?
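
To illustrate the difference the question points at, here is a small sketch using made-up coordinate values. Comparing sets passes even when the order differs or a value is duplicated, while comparing lists does not.

```python
expected = ["apple", "banana", "orange"]
reordered = ["orange", "apple", "banana"]
duplicated = ["apple", "apple", "banana", "orange"]

# Set comparison ignores both order and duplicates.
sets_match_reordered = set(reordered) == set(expected)
sets_match_duplicated = set(duplicated) == set(expected)

# List comparison requires the same elements in the same order.
lists_match_reordered = list(reordered) == list(expected)
lists_match_duplicated = list(duplicated) == list(expected)
```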

    coords.update(idata_kwargs.pop("coords"))
if "dims" in idata_kwargs:
    dims.update(idata_kwargs.pop("dims"))
return None
Contributor

Consider dropping this line, as it's not necessary. Or maybe this is a recommended style, which I am not yet aware of, for a function that should not return anything?

Contributor Author

I think it's purely a style choice that I generally make, but I will of course follow the style of the repo.
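
For reference, a short sketch showing that the two styles behave identically in Python: a function that falls off the end of its body returns `None` implicitly. The function names are hypothetical.

```python
def update_explicit(target: dict, extra: dict) -> None:
    target.update(extra)
    return None


def update_implicit(target: dict, extra: dict) -> None:
    # No return statement: Python returns None implicitly.
    target.update(extra)


result_explicit = update_explicit({}, {"a": 1})
result_implicit = update_implicit({}, {"a": 1})
```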

@@ -376,6 +388,8 @@ def sample_blackjax_nuts(
}

posterior = mcmc_samples
# Update 'coords' and 'dims' extracted from the model with user 'idata_kwargs'.
Contributor

Consider adding to the comment something like "and drop keys 'coords' and 'dims' from 'idata_kwargs' if present".
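
One likely reason the keys need to be dropped, sketched under the assumption that the merged `coords` and `dims` are later passed as explicit keywords next to `**idata_kwargs`. The function `build_idata` is a hypothetical stand-in, not the actual converter called in `sampling_jax.py`.

```python
def build_idata(posterior, coords=None, dims=None, **kwargs):
    """Hypothetical stand-in for a converter that accepts coords and dims keywords."""
    return {"posterior": posterior, "coords": coords, "dims": dims, **kwargs}


coords = {"x_coord": [0, 1]}
dims = {"x": ["x_coord"]}

# Left in place, the user's 'coords' collides with the explicit keyword.
idata_kwargs = {"coords": {"z_coord": ["apple"]}}
try:
    build_idata({}, coords=coords, dims=dims, **idata_kwargs)
    collided = False
except TypeError:
    collided = True

# Popped and merged first, the remaining kwargs pass through cleanly.
coords.update(idata_kwargs.pop("coords"))
result = build_idata({}, coords=coords, dims=dims, **idata_kwargs)
```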

@@ -596,6 +611,8 @@ def sample_numpyro_nuts(
}

posterior = mcmc_samples
# Update 'coords' and 'dims' extracted from the model with user 'idata_kwargs'.
Contributor

Consider extending the comment as above.

jhrcook and others added 2 commits August 14, 2022 14:38
- add details to comments about changing idata_kwargs
- change test to comparing lists instead of sets
- remove an explicit 'return None'
@michaelosthege michaelosthege requested a review from twiecki August 19, 2022 15:17
@twiecki twiecki merged commit 9024c2b into pymc-devs:main Aug 20, 2022
@twiecki
Member

twiecki commented Aug 20, 2022

Thanks @jhrcook and @bherwerth!


3 participants