Skip to content

Commit e11969a

Browse files
zaxtaxtwiecki
andauthored
Add sigma to OrderedProbit fixes #5103 (#5141)
* Add sigma to OrderedProbit fixes #5103 * Updated tests and documentation * pep8 * Undo changes to tests * Update pymc/distributions/discrete.py Co-authored-by: Thomas Wiecki <[email protected]>
1 parent 444d66a commit e11969a

File tree

2 files changed

+9
-9
lines changed

2 files changed

+9
-9
lines changed

pymc/distributions/discrete.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1761,16 +1761,16 @@ class _OrderedProbit(Categorical):
17611761
rv_op = categorical
17621762

17631763
@classmethod
1764-
def dist(cls, eta, cutpoints, *args, **kwargs):
1764+
def dist(cls, eta, cutpoints, sigma=1, *args, **kwargs):
17651765
eta = at.as_tensor_variable(floatX(eta))
17661766
cutpoints = at.as_tensor_variable(cutpoints)
17671767

17681768
probits = at.shape_padright(eta) - cutpoints
17691769
_log_p = at.concatenate(
17701770
[
1771-
at.shape_padright(normal_lccdf(0, 1, probits[..., 0])),
1772-
log_diff_normal_cdf(0, 1, probits[..., :-1], probits[..., 1:]),
1773-
at.shape_padright(normal_lcdf(0, 1, probits[..., -1])),
1771+
at.shape_padright(normal_lccdf(0, sigma, probits[..., 0])),
1772+
log_diff_normal_cdf(0, sigma, probits[..., :-1], probits[..., 1:]),
1773+
at.shape_padright(normal_lcdf(0, sigma, probits[..., -1])),
17741774
],
17751775
axis=-1,
17761776
)
@@ -1816,12 +1816,12 @@ class OrderedProbit:
18161816
The length K - 1 array of cutpoints which break :math:`\eta` into
18171817
ranges. Do not explicitly set the first and last elements of
18181818
:math:`c` to negative and positive infinity.
1819+
sigma: float, default 1.0
1820+
Standard deviation of the probit function.
18191821
compute_p: boolean, default True
18201822
Whether to compute and store in the trace the inferred probabilities of each categories,
18211823
based on the cutpoints' values. Defaults to True.
18221824
Might be useful to disable it if memory usage is of interest.
1823-
sigma: float
1824-
The standard deviation of probit function.
18251825
Example
18261826
--------
18271827
.. code:: python

pymc/tests/test_distributions.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3132,11 +3132,11 @@ def test_ordered_logistic_probs():
31323132

31333133
def test_ordered_probit_probs():
31343134
with pm.Model() as m:
3135-
pm.OrderedProbit("op_p", cutpoints=np.array([-2, 0, 2]), eta=0)
3136-
pm.OrderedProbit("op_no_p", cutpoints=np.array([-2, 0, 2]), eta=0, compute_p=False)
3135+
pm.OrderedProbit("op_p", cutpoints=np.array([-2, 0, 2]), eta=0, sigma=1)
3136+
pm.OrderedProbit("op_no_p", cutpoints=np.array([-2, 0, 2]), eta=0, sigma=1, compute_p=False)
31373137
assert len(m.deterministics) == 1
31383138

3139-
x = pm.OrderedProbit.dist(cutpoints=np.array([-2, 0, 2]), eta=0)
3139+
x = pm.OrderedProbit.dist(cutpoints=np.array([-2, 0, 2]), eta=0, sigma=1)
31403140
assert isinstance(x, TensorVariable)
31413141

31423142

0 commit comments

Comments
 (0)