Fix pm.distributions.transforms.ordered fails on >1 dimension #5660
Conversation
Hello @ricardoV94, after making these changes I ran the code mentioned in the issue, and it ran without error this time.
You can just add the original issue example as a new test to pymc/tests/test_transforms.py.
Codecov Report

@@            Coverage Diff             @@
##             main    #5660      +/-   ##
==========================================
+ Coverage   86.86%   87.21%    +0.34%
==========================================
  Files          75       75
  Lines       13715    13715
==========================================
+ Hits        11914    11961      +47
+ Misses       1801     1754      -47
Test has been added, please have a look.
Looks good. Left a suggestion and a request.
pymc/tests/test_transforms.py (outdated)
def test_transforms_ordered():
    COORDS = {"question": np.arange(10), "thresholds": np.arange(4)}

    with pm.Model(coords=COORDS) as model:
        kappa = pm.Normal(
            "kappa",
            mu=[-3, -1, 1, 2],
            sigma=1,
            dims=["question", "thresholds"],
            transform=pm.distributions.transforms.ordered,
        )
We don't need to use coords for this example.
Suggested change:

def test_transforms_ordered():
    with pm.Model() as model:
        kappa = pm.Normal(
            "kappa",
            mu=[-3, -1, 1, 2],
            sigma=1,
            size=(10, 4),
            transform=pm.distributions.transforms.ordered,
        )
Also, the test should cover the other transform that you have changed.
Done : )
But what about the other tests which are failing?
Thanks for the contribution so far @purna135! I have my eye on this, and it looks like there are just two failing tests before this can be merged?
Thank you so much, @drbenvincent; I was hoping for a response.
I would gladly, if I could, but so far my contributions have been to …
@purna135 Have you tried to investigate why those tests are failing locally?
Yes, I ran the test locally, but I couldn't figure out what was causing it. As a result, I think we should revise the …
You can print the results before and after the changes to be sure what's changed. Then we have to figure out if we introduced a bug or if the old results were wrong (probably the former).
Is this …
Hi @ricardoV94, I'm still waiting for a response and am unsure how to proceed. Could you please assist me?
- return at.sum(value[..., 1:], axis=-1)
+ return at.sum(value[..., 1:], axis=-1, keepdims=True)
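As a rough illustration of why the extra keyword matters, here is a NumPy stand-in for the aesara.tensor call (the (10, 4) shape is an assumption, chosen to match the kappa example above):

import numpy as np

# Without keepdims the Jacobian term loses its last axis, so it can no longer
# broadcast against an elementwise logp whose batch shape still has that axis.
value = np.random.randn(10, 4)
print(np.sum(value[..., 1:], axis=-1).shape)                 # (10,)
print(np.sum(value[..., 1:], axis=-1, keepdims=True).shape)  # (10, 1)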
Setting keepdims=True will keep the last dimension while computing the logprob for multivariate distributions. This behaviour can be quite confusing to interpret when adding transforms to multivariate distributions.
Example:

import numpy as np
import pymc as pm
from pymc.distributions.logprob import joint_logpt

with pm.Model() as model:
    mv = pm.MvNormal(
        "mv",
        mu=[-3, -1, 1, 2],
        cov=np.eye(4),
        initval=np.random.randn(4),
    )

print(joint_logpt(mv, sum=False)[0].ndim)  # will output 0
After adding the Ordered transform:

import numpy as np
import pymc as pm
from pymc.distributions.logprob import joint_logpt

with pm.Model() as model:
    mv = pm.MvNormal(
        "mv",
        mu=[-3, -1, 1, 2],
        cov=np.eye(4),
        initval=np.random.randn(4),
        transform=pm.distributions.transforms.ordered,
    )

print(joint_logpt(mv, sum=False)[0].ndim)  # will output 1
I think there can be a simple check here: https://github.com/aesara-devs/aeppl/blob/65d25c39fa771809b4d992c9323f59e78f282cd5/aeppl/transforms.py#L372-L388 to squeeze out the last dimension in the case where transforms are added to multivariate distributions.
cc @ricardoV94
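A minimal NumPy sketch of the shape logic behind that check (the function name and the ndim_supp argument are hypothetical; the real fix would live in AePPL's transform logprob rewrite):

import numpy as np

def adjust_log_jac_det(log_jac_det, ndim_supp):
    # keepdims=True leaves a trailing length-1 axis; for an RV whose support
    # spans that axis (ndim_supp >= 1) the core logp has no such axis, so we
    # squeeze it out to keep both terms shape-compatible.
    if ndim_supp >= 1:
        return np.squeeze(log_jac_det, axis=-1)
    return log_jac_det

print(adjust_log_jac_det(np.zeros((10, 4, 1)), ndim_supp=0).shape)  # (10, 4, 1)
print(adjust_log_jac_det(np.zeros((10, 1)), ndim_supp=1).shape)     # (10,)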
We might need to tell the transform what the support dimensionality of the RV is. It can be an internal variable like the n in the transform that's used by the LKJCholeskyCov.
> We might need to tell the transform what is the support dimensionality of the RV.

If I get it right, we need to somehow pass op.ndim_supp to transforms to decide whether or not to keep the last dimension. Seems like a simpler solution.

Also, I think the SimplexTransform defined in AePPL needs to account for this change.
We don't need to pass it to all transforms; we can instead initialize this transform with that information, like we initialize the Interval or LKJ transform with the specific information they need. This would follow a similar API and be less disruptive.

In this case, instead of having a single ordered already initialized, the user would initialize it (or we would behind the scenes, like pm.interval does). It can be a kwarg that defaults to 0 (for univariate RVs).
Pseudo-code:

class OrderedTransform(TransformRV):
    def __init__(self, ndim_supp):
        self.ndim_supp = ndim_supp

    def log_jac_det(self, value, *inputs):
        # use self.ndim_supp to return the right shape
        ...

def ordered(ndim_supp=0):
    return OrderedTransform(ndim_supp)
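Under that proposal, usage might look like the sketch below. This is only an illustration of the proposed (not yet existing) ordered(ndim_supp=...) factory from the pseudo-code above, not the current PyMC API:

import numpy as np
import pymc as pm

with pm.Model() as model:
    mv = pm.MvNormal(
        "mv",
        mu=[-3, -1, 1, 2],
        cov=np.eye(4),
        # hypothetical: ndim_supp=1 tells the transform the RV is
        # multivariate, so log_jac_det can return the right shape
        transform=ordered(ndim_supp=1),
    )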
Is there anything wrong with the Simplex? That seems to be a vector-specific transform; I don't think it would be used in other cases.

We can think of it as being a VectorSimplex (i.e., a Simplex for RVs with ndim_supp=1), which is what it is used for in PyMC. If we wanted to use it with other ndim_supp, we could make it more flexible...
Passing the required info to transforms during initialization indeed makes sense. @purna135, can you take up adding the proposed change in this PR itself?

> Is there anything wrong with the Simplex?

I noticed a test in test_transforms.py that applies the Simplex transform to a batched univariate distribution.
pymc/pymc/tests/test_transforms.py, lines 509 to 526 in bae5087:
@pytest.mark.parametrize(
    "lower,upper,size,transform",
    [
        (0.0, 1.0, (2,), tr.simplex),
        (0.5, 5.5, (2, 3), tr.simplex),
        (np.zeros(3), np.ones(3), (4, 3), tr.Chain([tr.sum_to_1, tr.logodds])),
    ],
)
def test_uniform_other(self, lower, upper, size, transform):
    initval = np.ones(size) / size[-1]
    model = self.build_model(
        pm.Uniform,
        {"lower": lower, "upper": upper},
        size=size,
        initval=initval,
        transform=transform,
    )
    self.check_vectortransform_elementwise_logp(model)
This prompted me to make sure that even Simplex should not reduce the last dimension when applied to univariate distributions.
Thank you, @Sayam753. I'll work on these suggested changes and let you know if I need any assistance.
Hi. Just dropping a message here to see if anyone is able to push this forward? I have a vested interest in this fix but am not familiar enough to help contribute efficiently.
Hey! @purna135, is there any news on this? I would also be interested in this fix. I tried to add the suggestions to the code, but I am still having failing tests and, to be honest, a lack of understanding as well. I would be happy to help here, but I can't accomplish it by myself either.
I've been advancing on this locally. There are two remaining issues, as far as I see: …
Hello, @TimOliverMaier. Thank you so much for addressing this issue, and sorry for the delay in responding; I was preoccupied with other issues and didn't have time to solve this one.
Hey, I opened a new PR: #6255. It is still a work in progress, though.
Fixed in #6255 |
Addressing #5659

Kept the original dims in the log_jac_det by changing

return at.sum(value[..., 1:], axis=-1)

to

return at.sum(value[..., 1:], axis=-1, keepdims=True)