Fix pm.distributions.transforms.ordered fails on >1 dimension #5660

Closed · wants to merge 4 commits

4 changes: 2 additions & 2 deletions pymc/distributions/transforms.py

@@ -82,7 +82,7 @@ def forward(self, value, *inputs):
         return y

     def log_jac_det(self, value, *inputs):
-        return at.sum(value[..., 1:], axis=-1)
+        return at.sum(value[..., 1:], axis=-1, keepdims=True)
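
For context, a minimal reproduction of the failure this change addresses (assembled from the PR title and the new tests below; I have not re-run it against the pre-fix code):

import pymc as pm

with pm.Model() as model:
    pm.Normal(
        "x",
        mu=[-3, -1, 1, 2],
        sigma=1,
        size=(10, 4),  # more than one dimension: failed before this fix
        transform=pm.distributions.transforms.ordered,
    )

model.point_logps()  # previously failed; now returns the summed logp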
Member:

Setting keepdims=True will keep the last dimension while computing the logprob of multivariate distributions. This behaviour can be quite confusing to interpret when adding transforms to multivariate distributions.

Example:

import numpy as np
import pymc as pm
from pymc.distributions.logprob import joint_logpt

with pm.Model() as model:
    mv = pm.MvNormal(
        "mv",
        mu=[-3, -1, 1, 2],
        cov=np.eye(4),
        initval=np.random.randn(4),
    )

print(joint_logpt(mv, sum=False)[0].ndim)  # will output 0

After adding the Ordered transform:

import numpy as np
import pymc as pm
from pymc.distributions.logprob import joint_logpt

with pm.Model() as model:
    mv = pm.MvNormal(
        "mv",
        mu=[-3, -1, 1, 2],
        cov=np.eye(4),
        initval=np.random.randn(4),
        transform=pm.distributions.transforms.ordered,
    )

print(joint_logpt(mv, sum=False)[0].ndim)  # will output 1

The jacobian term computed with keepdims=True has shape (1,), and broadcasting it against the scalar MvNormal logp promotes the result to ndim 1.

I think there can be a simple check here:
https://github.com/aesara-devs/aeppl/blob/65d25c39fa771809b4d992c9323f59e78f282cd5/aeppl/transforms.py#L372-L388
to squeeze out the last dimension when transforms are applied to multivariate distributions.
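
Something along these lines, perhaps (a rough sketch under my own naming; this helper does not exist in AePPL):

def add_jacobian(logprob, log_jac_det, ndim_supp):
    # hypothetical helper, not actual AePPL code: for multivariate RVs
    # (ndim_supp > 0), squeeze out the trailing length-1 dimension that
    # keepdims=True leaves behind before combining with the logprob
    if ndim_supp > 0 and log_jac_det.ndim > logprob.ndim:
        log_jac_det = log_jac_det[..., 0]
    return logprob + log_jac_det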

cc @ricardoV94

Member:


We might need to tell the transform the support dimensionality of the RV. It could be an internal variable, like the n in the transform used by LKJCholeskyCov.
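
For reference, this is roughly the pattern CholeskyCovPacked already follows, with n passed in at initialization (an abridged sketch from memory, not a verbatim copy of the PyMC source):

import aesara.tensor as at
from aeppl.transforms import RVTransform

class CholeskyCovPacked(RVTransform):
    name = "cholesky-cov-packed"

    def __init__(self, n):
        # n is the dimension of the covariance matrix; it is stored at
        # initialization and used to index the diagonal entries of the
        # packed triangular values
        self.diag_idxs = at.arange(1, n + 1).cumsum() - 1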

Member @Sayam753 (Apr 23, 2022):


We might need to tell the transform what is the support dimensionality of the RV.

If I get it right, we need to somehow pass op.ndim_supp to the transforms to decide whether or not to keep the last dimension. That seems like the simpler solution.

Also, I think the SimplexTransform defined in AePPL needs to account for this change.

Member:


We don't need to pass it to all transforms; we can instead initialize this transform with that information, just as we initialize the Interval or LKJ transforms with the specific information they need. This would follow a similar API and be less disruptive.

In this case, instead of having a single pre-initialized ordered, the user would initialize it (or we would behind the scenes, like pm.interval does). It can be a kwarg that defaults to 0 (for univariate RVs).

pseudo-code:

class OrderedTransform(RVTransform):
    def __init__(self, ndim_supp=0):
        self.ndim_supp = ndim_supp

    def log_jac_det(self, value, *inputs):
        # use self.ndim_supp to return the right shape
        ...


def ordered(ndim_supp=0):
    return OrderedTransform(ndim_supp)
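
Under that proposal, usage would look something like this (a sketch of the suggested API, not code that exists yet):

import numpy as np
import pymc as pm

with pm.Model() as model:
    mv = pm.MvNormal(
        "mv",
        mu=[-3, -1, 1, 2],
        cov=np.eye(4),
        # ndim_supp=1 tells the transform the RV has vector support
        transform=pm.distributions.transforms.ordered(ndim_supp=1),
    )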

Member:


Is there anything wrong with the Simplex? That seems to be a vector-specific transform; I don't think it would be used in other cases.

We can think of it as a VectorSimplex (i.e., a Simplex for RVs with ndim_supp=1), which is what it is used for in PyMC. If we wanted to use it with other ndim_supp, we could make it more flexible...

Member:


Passing the required info to transforms during initialization indeed makes sense. @purna135, can you take up adding the proposed change in this PR itself?

Is there anything wrong with the Simplex?

I noticed a test in test_transforms.py that applies the Simplex transform to a batched univariate distribution.

@pytest.mark.parametrize(
    "lower,upper,size,transform",
    [
        (0.0, 1.0, (2,), tr.simplex),
        (0.5, 5.5, (2, 3), tr.simplex),
        (np.zeros(3), np.ones(3), (4, 3), tr.Chain([tr.sum_to_1, tr.logodds])),
    ],
)
def test_uniform_other(self, lower, upper, size, transform):
    initval = np.ones(size) / size[-1]
    model = self.build_model(
        pm.Uniform,
        {"lower": lower, "upper": upper},
        size=size,
        initval=initval,
        transform=transform,
    )
    self.check_vectortransform_elementwise_logp(model)

This prompted me to make sure that Simplex, too, does not reduce the last dimension when applied to univariate distributions.
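
A quick way to probe that behaviour (a sketch reusing joint_logpt from above and the parameters of the batched test case):

import numpy as np
import pymc as pm
import pymc.distributions.transforms as tr
from pymc.distributions.logprob import joint_logpt

with pm.Model() as model:
    x = pm.Uniform(
        "x",
        lower=0.5,
        upper=5.5,
        size=(2, 3),
        initval=np.ones((2, 3)) / 3,
        transform=tr.simplex,
    )

# for a batched univariate RV, the unreduced logp keeps its batch
# dimensions, so the jacobian term must broadcast rather than collapse them
print(joint_logpt(x, sum=False)[0].ndim)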

Member (Author):


Thank you, @Sayam753. I'll work on these suggested changes and let you know if I need any assistance.



class SumTo1(RVTransform):

@@ -102,7 +102,7 @@ def forward(self, value, *inputs):

     def log_jac_det(self, value, *inputs):
         y = at.zeros(value.shape)
-        return at.sum(y, axis=-1)
+        return at.sum(y, axis=-1, keepdims=True)


class CholeskyCovPacked(RVTransform):
28 changes: 28 additions & 0 deletions pymc/tests/test_transforms.py

@@ -557,3 +557,31 @@ def test_interval_transform_raises():
         tr.Interval(at.constant(5) + 1, None)

     assert tr.Interval(at.constant(5), None)
+
+
+def test_transforms_ordered():
+    with pm.Model() as model:
+        pm.Normal(
+            "x",
+            mu=[-3, -1, 1, 2],
+            sigma=1,
+            size=(10, 4),
+            transform=pm.distributions.transforms.ordered,
+        )
+
+    log_prob = model.point_logps()
+    np.testing.assert_allclose(list(log_prob.values()), np.array([18.69]))
+
+
+def test_transforms_sumto1():
+    with pm.Model() as model:
+        pm.Normal(
+            "x",
+            mu=[-3, -1, 1, 2],
+            sigma=1,
+            size=(10, 4),
+            transform=pm.distributions.transforms.sum_to_1,
+        )
+
+    log_prob = model.point_logps()
+    np.testing.assert_allclose(list(log_prob.values()), np.array([-56.76]))