Fix pm.distributions.transforms.ordered fails on >1 dimension #5660


Closed
wants to merge 4 commits

Conversation

purna135
Member

Addressing #5659

Kept the original dims in the log_jac_det by changing

def log_jac_det(self, value, *inputs):
    return at.sum(value[..., 1:], axis=-1)

to

def log_jac_det(self, value, *inputs):
    return at.sum(value[..., 1:], axis=-1, keepdims=True)
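For illustration, the shape difference can be reproduced with plain NumPy, whose `keepdims` semantics `at.sum` follows; the `(10, 4)` shape mirrors the batched example from the issue:

```python
import numpy as np

# A batched value: 10 independent length-4 ordered vectors.
value = np.zeros((10, 4))

# Original: the sum drops the last axis entirely.
no_keep = np.sum(value[..., 1:], axis=-1)
print(no_keep.shape)  # (10,)

# With keepdims=True the last axis survives with length 1,
# so the result still broadcasts against a (10, 4) logp.
keep = np.sum(value[..., 1:], axis=-1, keepdims=True)
print(keep.shape)  # (10, 1)
```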

@purna135
Member Author

Hello @ricardoV94, after making these changes I ran the code mentioned in the issue, and it ran without error this time.
However, I am unsure how to write an appropriate test for this; could you please assist me?

@ricardoV94
Member

ricardoV94 commented Mar 28, 2022

You can just add the original issue example as a new test to test_transforms.py and assert the result of point_logps is what's expected.

@codecov

codecov bot commented Mar 28, 2022

Codecov Report

Merging #5660 (c42288c) into main (fa015e3) will increase coverage by 0.34%.
The diff coverage is 0.00%.


@@            Coverage Diff             @@
##             main    #5660      +/-   ##
==========================================
+ Coverage   86.86%   87.21%   +0.34%     
==========================================
  Files          75       75              
  Lines       13715    13715              
==========================================
+ Hits        11914    11961      +47     
+ Misses       1801     1754      -47     
Impacted Files Coverage Δ
pymc/distributions/transforms.py 59.63% <0.00%> (-40.37%) ⬇️
pymc/distributions/dist_math.py 57.71% <0.00%> (-29.72%) ⬇️
pymc/distributions/logprob.py 69.72% <0.00%> (-25.69%) ⬇️
pymc/data.py 63.02% <0.00%> (-20.59%) ⬇️
pymc/aesaraf.py 87.43% <0.00%> (-1.01%) ⬇️
pymc/parallel_sampling.py 86.04% <0.00%> (-0.67%) ⬇️
pymc/sampling.py 88.15% <0.00%> (+0.11%) ⬆️
pymc/distributions/multivariate.py 92.15% <0.00%> (+0.12%) ⬆️
pymc/model.py 85.89% <0.00%> (+0.27%) ⬆️
... and 10 more

@purna135
Member Author

The test has been added; please have a look.

Member

@ricardoV94 ricardoV94 left a comment


Looks good. Left a suggestion and a request.

Comment on lines 562 to 572
def test_transforms_ordered():
    COORDS = {"question": np.arange(10), "thresholds": np.arange(4)}

    with pm.Model(coords=COORDS) as model:
        kappa = pm.Normal(
            "kappa",
            mu=[-3, -1, 1, 2],
            sigma=1,
            dims=["question", "thresholds"],
            transform=pm.distributions.transforms.ordered,
        )
Member


We don't need to use coords for this example.

Suggested change
-def test_transforms_ordered():
-    COORDS = {"question": np.arange(10), "thresholds": np.arange(4)}
-    with pm.Model(coords=COORDS) as model:
-        kappa = pm.Normal(
-            "kappa",
-            mu=[-3, -1, 1, 2],
-            sigma=1,
-            dims=["question", "thresholds"],
-            transform=pm.distributions.transforms.ordered,
-        )
+def test_transforms_ordered():
+    with pm.Model() as model:
+        kappa = pm.Normal(
+            "kappa",
+            mu=[-3, -1, 1, 2],
+            sigma=1,
+            size=(10, 4),
+            transform=pm.distributions.transforms.ordered,
+        )

You should also test with the other transform that you have changed.

Member Author


Done : )
But what about the other tests which are failing?

@drbenvincent

Thanks for the contribution so far @purna135 ! I have my eye on this, and it looks like there are just two failing tests before this can be merged?

@purna135
Member Author

purna135 commented Apr 8, 2022

Thank you so much, @drbenvincent; I was hoping for a response.
Yes, some tests are failing, but I can't figure out how to fix them; could you please assist me?

@drbenvincent

Thank you so much, @drbenvincent; I was hoping for a response.
Yes, some tests are failing, but I can't figure out how to fix them; could you please assist me?

I would gladly, if I could, but so far my contributions have been to pymc-examples, and not to pymc itself. So I do not have enough experience. Maybe @ricardoV94 could give some hints when he has time?

@ricardoV94
Member

@purna135 Have you tried to investigate why those tests are failing locally?

@purna135
Member Author

purna135 commented Apr 8, 2022

Yes, I ran the test locally, but I couldn't figure out what was causing it.
I assumed that because we changed log_jac_det, the expectation of the following line changed as well.
jacob_det = transform.log_jac_det(test_array_transf, *x.owner.inputs)

As a result, I think we should revise the assert statement.

@ricardoV94
Member

You can print the results before and after the changes to be sure what's changed. Then we have to figure out if we introduced a bug or if the old results were wrong (probably the former)

@purna135
Member Author

Is this jacob_det the one I should print to check the results?

@purna135
Member Author

Hi @ricardoV94 I'm still waiting for a response and am unsure how to proceed. Could you please assist me?

-    return at.sum(value[..., 1:], axis=-1)
+    return at.sum(value[..., 1:], axis=-1, keepdims=True)
Member


Setting keepdims=True keeps the last dimension while computing the logprob for multivariate distributions. This behaviour can be confusing to interpret when adding transforms for multivariate distributions.

Example:

from pymc.distributions.logprob import joint_logpt

with pm.Model() as model:
    mv = pm.MvNormal(
        "mv",
        mu=[-3, -1, 1, 2],
        cov=np.eye(4),
        initval=np.random.randn(4),
    )

print(joint_logpt(mv, sum=False)[0].ndim)  # will output 0

After adding Ordered transform

from pymc.distributions.logprob import joint_logpt

with pm.Model() as model:
    mv = pm.MvNormal(
        "mv",
        mu=[-3, -1, 1, 2],
        cov=np.eye(4),
        initval=np.random.randn(4),
        transform=pm.distributions.transforms.ordered,
    )

print(joint_logpt(mv, sum=False)[0].ndim)  # will output 1

I think a simple check can be added here: https://github.com/aesara-devs/aeppl/blob/65d25c39fa771809b4d992c9323f59e78f282cd5/aeppl/transforms.py#L372-L388

to squeeze out the last dimension when transforms are added for multivariate distributions.

cc @ricardoV94

Member


We might need to tell the transform what the support dimensionality of the RV is. It can be an internal variable, like the n used by the LKJCholeskyCov transform.

Member

@Sayam753 Sayam753 Apr 23, 2022


We might need to tell the transform what is the support dimensionality of the RV.

If I get it right, we need to somehow pass op.ndim_supp to transforms to decide whether or not to keep last dimension. Seems like a simpler solution.

Also, I think SimplexTransform defined in AePPL needs to account for this change.

Member


We don't need to pass it to all transforms; we can instead initialize this transform with that information, like we initialize the Interval or LKJ transform with the specific information they need. This would follow a similar API and be less disruptive.

In this case, instead of having a single ordered already initialized, the user would initialize it (or we would behind the scenes, like pm.interval does). It can be a kwarg that defaults to 0 (for univariate RVs).

pseudo-code:

class OrderedTransform(TransformRV):

    def __init__(self, ndim_supp):
        self.ndim_supp = ndim_supp

    def log_jac_det(self, value, *inputs):
        # use self.ndim_supp to return the right shape
        ...


def ordered(ndim_supp=0):
    return OrderedTransform(ndim_supp)
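A hypothetical NumPy sketch of how ndim_supp could control the Jacobian shape (the function name and the exact rule are assumptions for illustration, not the final PyMC API): keep the summed axis only for univariate RVs (ndim_supp == 0), whose elementwise logp still has that dimension, and drop it for multivariate RVs whose logp has already reduced it:

```python
import numpy as np

def log_jac_det_sketch(value, ndim_supp=0):
    # Hypothetical helper: for univariate RVs (ndim_supp == 0) the logp
    # is elementwise, so keep the reduced axis for broadcasting; for
    # multivariate RVs the logp has already collapsed the support
    # dimension, so drop it here as well.
    return np.sum(value[..., 1:], axis=-1, keepdims=(ndim_supp == 0))

value = np.zeros((10, 4))
print(log_jac_det_sketch(value, ndim_supp=0).shape)  # (10, 1)
print(log_jac_det_sketch(value, ndim_supp=1).shape)  # (10,)
```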

Member


Is there anything wrong with the Simplex? That seems to be a vector-specific transform; I don't think it would be used in other cases.

We can think of it as being a VectorSimplex (i.e., a Simplex for RVs with ndim_supp=1), which is what it is used for in PyMC. If we wanted to use it with other ndim_supp we could make it more flexible...

Member


Passing the required info to transforms during initialization indeed makes sense. @purna135 can you take up adding the proposed change in this PR itself?

Is there anything wrong with the Simplex?

I noticed a test in test_transforms.py that applies Simplex transform to a batched univariate distribution.

@pytest.mark.parametrize(
    "lower,upper,size,transform",
    [
        (0.0, 1.0, (2,), tr.simplex),
        (0.5, 5.5, (2, 3), tr.simplex),
        (np.zeros(3), np.ones(3), (4, 3), tr.Chain([tr.sum_to_1, tr.logodds])),
    ],
)
def test_uniform_other(self, lower, upper, size, transform):
    initval = np.ones(size) / size[-1]
    model = self.build_model(
        pm.Uniform,
        {"lower": lower, "upper": upper},
        size=size,
        initval=initval,
        transform=transform,
    )
    self.check_vectortransform_elementwise_logp(model)

This prompted me to make sure that even Simplex does not reduce the last dimension when applied to univariate distributions.

Member Author


Thank you, @Sayam753. I'll work on these suggested changes and let you know if I need any assistance.

@purna135 purna135 marked this pull request as draft June 24, 2022 09:20
@drbenvincent

Hi. Just dropping a message here to see if anyone is able to push this forward? I have a vested interest in this fix, but am not familiar enough to help contribute efficiently.

@TimOliverMaier
Contributor

TimOliverMaier commented Oct 27, 2022

Hey! @purna135, is there any news on this? I would also be interested in this fix. I tried to add the suggestions into the code, but I am still having failing tests and, tbh, a lack of understanding as well. I would be happy to help here, but I can't accomplish it by myself either.

@TimOliverMaier
Contributor

I've been advancing on this locally. There are two remaining issues, as far as I see:

  1. Some remaining tests fail because of the changed dimensionality of the Jacobian determinant: some transforms keep the last
     dimension, others do not. Keeping the other transforms as they are will require some further cases in
     test_transform.check_vectortransform_elementwise_logp().
  2. Sampling of ordered multivariate distributions fails with ValueError:
     "array must not contain infs or NaNs\nApply node that caused the error: SolveTriangular..."
     However, this is also the case on the main branch for me.

What is the best way to collaborate on this PR? @purna135

@purna135
Member Author

Hello, @TimOliverMaier. Thank you so much for addressing this issue, and sorry for the delay in responding; I was preoccupied with other issues and didn't have time to work on this one.
It would be greatly appreciated if you could open a new PR from your local branch; I will then close this one.

@TimOliverMaier
Contributor

Hey, I opened a new PR: #6255. It is still a work in progress, though.

@ricardoV94
Member

Fixed in #6255

@ricardoV94 ricardoV94 closed this Nov 28, 2022