Skip to content

Commit 18346ac

Browse files
authored
Allow OrderedProbit distribution to take vector inputs (#5418)
* fix and add test_vector_inputs for OrderedProbit * add test_shape_inputs for _OrderedLogistic
1 parent 1005d20 commit 18346ac

File tree

2 files changed

+49
-1
lines changed

2 files changed

+49
-1
lines changed

pymc/distributions/discrete.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1964,7 +1964,9 @@ def dist(cls, eta, cutpoints, sigma=1, *args, **kwargs):
19641964
_log_p = at.concatenate(
19651965
[
19661966
at.shape_padright(normal_lccdf(0, sigma, probits[..., 0])),
1967-
log_diff_normal_cdf(0, sigma, probits[..., :-1], probits[..., 1:]),
1967+
log_diff_normal_cdf(
1968+
0, at.shape_padright(sigma), probits[..., :-1], probits[..., 1:]
1969+
),
19681970
at.shape_padright(normal_lcdf(0, sigma, probits[..., -1])),
19691971
],
19701972
axis=-1,

pymc/tests/test_distributions_random.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1688,6 +1688,28 @@ class TestOrderedLogistic(BaseTestDistributionRandom):
16881688
"check_rv_size",
16891689
]
16901690

1691+
@pytest.mark.parametrize(
1692+
"eta, cutpoints, expected",
1693+
[
1694+
(0, [-2.0, 0, 2.0], (4,)),
1695+
([-1], [-2.0, 0, 2.0], (1, 4)),
1696+
([1.0, -2.0], [-1.0, 0, 1.0], (2, 4)),
1697+
(np.zeros((3, 2)), [-2.0, 0, 1.0], (3, 2, 4)),
1698+
(np.ones((5, 2)), [[-2.0, 0, 1.0], [-1.0, 0, 1.0]], (5, 2, 4)),
1699+
(np.ones((3, 5, 2)), [[-2.0, 0, 1.0], [-1.0, 0, 1.0]], (3, 5, 2, 4)),
1700+
],
1701+
)
1702+
def test_shape_inputs(self, eta, cutpoints, expected):
1703+
"""
1704+
This test checks when providing different shapes for `eta` parameters.
1705+
"""
1706+
categorical = _OrderedLogistic.dist(
1707+
eta=eta,
1708+
cutpoints=cutpoints,
1709+
)
1710+
p = categorical.owner.inputs[3].eval()
1711+
assert p.shape == expected
1712+
16911713

16921714
class TestOrderedProbit(BaseTestDistributionRandom):
16931715
pymc_dist = _OrderedProbit
@@ -1698,6 +1720,30 @@ class TestOrderedProbit(BaseTestDistributionRandom):
16981720
"check_rv_size",
16991721
]
17001722

1723+
@pytest.mark.parametrize(
1724+
"eta, cutpoints, sigma, expected",
1725+
[
1726+
(0, [-2.0, 0, 2.0], 1.0, (4,)),
1727+
([-1], [-1.0, 0, 2.0], [2.0], (1, 4)),
1728+
([1.0, -2.0], [-1.0, 0, 1.0], 1.0, (2, 4)),
1729+
([1.0, -2.0, 3.0], [-1.0, 0, 2.0], np.ones((1, 3)), (1, 3, 4)),
1730+
(np.zeros((2, 3)), [-2.0, 0, 1.0], [1.0, 2.0, 5.0], (2, 3, 4)),
1731+
(np.ones((2, 3)), [-1.0, 0, 1.0], np.ones((2, 3)), (2, 3, 4)),
1732+
(np.zeros((5, 2)), [[-2, 0, 1], [-1, 0, 1]], np.ones((2, 5, 2)), (2, 5, 2, 4)),
1733+
],
1734+
)
1735+
def test_shape_inputs(self, eta, cutpoints, sigma, expected):
1736+
"""
1737+
This test checks when providing different shapes for `eta` and `sigma` parameters.
1738+
"""
1739+
categorical = _OrderedProbit.dist(
1740+
eta=eta,
1741+
cutpoints=cutpoints,
1742+
sigma=sigma,
1743+
)
1744+
p = categorical.owner.inputs[3].eval()
1745+
assert p.shape == expected
1746+
17011747

17021748
class TestOrderedMultinomial(BaseTestDistributionRandom):
17031749
pymc_dist = _OrderedMultinomial

0 commit comments

Comments
 (0)