-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Fix pm.distributions.transforms.ordered
fails on >1 dimension
#5660
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Setting
keepdims=True
will keep the last dimension while computing logprob for multivariate distributions. This behaviour can be a lot confusing to interpret when adding transforms for multivariate distributions.Example:
After adding Ordered transform
I think there can be a simple check here: https://github.com/aesara-devs/aeppl/blob/65d25c39fa771809b4d992c9323f59e78f282cd5/aeppl/transforms.py#L372-L388
to squeeze out the last dimension in case when transforms are added for multivariate distributions.
cc @ricardoV94
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We might need to tell the transform what is the support dimensionality of the RV. Can be an internal variable like
n
in the transform that's used by the LKJCholeskyCovUh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If I get it right, we need to somehow pass
op.ndim_supp
to transforms to decide whether or not to keep last dimension. Seems like a simpler solution.Also, I think SimplexTransform defined in AePPL needs to account for this change.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We don't need to pass it to all transforms, we can instead initialize this transform with that information. Like we initialize Interval or LKJ transform with the specific information they need. This would follow a similar API and be less disruptive
In this case instead of having a single ordered already initialized, the user would initialize it (or us behind the scenes, like pm.interval does). It can be a kwarg that by default is 0 (for univariate RVs)
pseudo-code:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there anything wrong with the Simplex? That seems to be vector specific transform, I don't think it would be used in other cases.
We can think of it as being a VectorSimplex (i.e., a Simplex for RVs with ndim_supp=1), which is what it is used for in PyMC. If we wanted to use it with other ndim_supp we could make it more flexible...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Passing the required info to transforms during initialization indeed makes sense. @purna135 can you take up adding the proposed change in this PR itself?
I noticed a test in
test_transforms.py
that applies Simplex transform to a batched univariate distribution.pymc/pymc/tests/test_transforms.py
Lines 509 to 526 in bae5087
This prompted me to make sure that even Simplex should not reduce the last dimension in case when applied to univariate distributions.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you, @Sayam753. I'll work on these suggested changes and let you know if I need any assistance.