@@ -1688,6 +1688,28 @@ class TestOrderedLogistic(BaseTestDistributionRandom):
1688
1688
"check_rv_size" ,
1689
1689
]
1690
1690
1691
+ @pytest .mark .parametrize (
1692
+ "eta, cutpoints, expected" ,
1693
+ [
1694
+ (0 , [- 2.0 , 0 , 2.0 ], (4 ,)),
1695
+ ([- 1 ], [- 2.0 , 0 , 2.0 ], (1 , 4 )),
1696
+ ([1.0 , - 2.0 ], [- 1.0 , 0 , 1.0 ], (2 , 4 )),
1697
+ (np .zeros ((3 , 2 )), [- 2.0 , 0 , 1.0 ], (3 , 2 , 4 )),
1698
+ (np .ones ((5 , 2 )), [[- 2.0 , 0 , 1.0 ], [- 1.0 , 0 , 1.0 ]], (5 , 2 , 4 )),
1699
+ (np .ones ((3 , 5 , 2 )), [[- 2.0 , 0 , 1.0 ], [- 1.0 , 0 , 1.0 ]], (3 , 5 , 2 , 4 )),
1700
+ ],
1701
+ )
1702
+ def test_shape_inputs (self , eta , cutpoints , expected ):
1703
+ """
1704
+ This test checks when providing different shapes for `eta` parameters.
1705
+ """
1706
+ categorical = _OrderedLogistic .dist (
1707
+ eta = eta ,
1708
+ cutpoints = cutpoints ,
1709
+ )
1710
+ p = categorical .owner .inputs [3 ].eval ()
1711
+ assert p .shape == expected
1712
+
1691
1713
1692
1714
class TestOrderedProbit (BaseTestDistributionRandom ):
1693
1715
pymc_dist = _OrderedProbit
@@ -1698,6 +1720,30 @@ class TestOrderedProbit(BaseTestDistributionRandom):
1698
1720
"check_rv_size" ,
1699
1721
]
1700
1722
1723
+ @pytest .mark .parametrize (
1724
+ "eta, cutpoints, sigma, expected" ,
1725
+ [
1726
+ (0 , [- 2.0 , 0 , 2.0 ], 1.0 , (4 ,)),
1727
+ ([- 1 ], [- 1.0 , 0 , 2.0 ], [2.0 ], (1 , 4 )),
1728
+ ([1.0 , - 2.0 ], [- 1.0 , 0 , 1.0 ], 1.0 , (2 , 4 )),
1729
+ ([1.0 , - 2.0 , 3.0 ], [- 1.0 , 0 , 2.0 ], np .ones ((1 , 3 )), (1 , 3 , 4 )),
1730
+ (np .zeros ((2 , 3 )), [- 2.0 , 0 , 1.0 ], [1.0 , 2.0 , 5.0 ], (2 , 3 , 4 )),
1731
+ (np .ones ((2 , 3 )), [- 1.0 , 0 , 1.0 ], np .ones ((2 , 3 )), (2 , 3 , 4 )),
1732
+ (np .zeros ((5 , 2 )), [[- 2 , 0 , 1 ], [- 1 , 0 , 1 ]], np .ones ((2 , 5 , 2 )), (2 , 5 , 2 , 4 )),
1733
+ ],
1734
+ )
1735
+ def test_shape_inputs (self , eta , cutpoints , sigma , expected ):
1736
+ """
1737
+ This test checks when providing different shapes for `eta` and `sigma` parameters.
1738
+ """
1739
+ categorical = _OrderedProbit .dist (
1740
+ eta = eta ,
1741
+ cutpoints = cutpoints ,
1742
+ sigma = sigma ,
1743
+ )
1744
+ p = categorical .owner .inputs [3 ].eval ()
1745
+ assert p .shape == expected
1746
+
1701
1747
1702
1748
class TestOrderedMultinomial (BaseTestDistributionRandom ):
1703
1749
pymc_dist = _OrderedMultinomial
0 commit comments