Skip to content

Commit ac99306

Browse files
committed
Add DiscreteWeibull moment
1 parent 71b18e6 commit ac99306

File tree

2 files changed

+27
-1
lines changed

2 files changed

+27
-1
lines changed

pymc/distributions/discrete.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -490,6 +490,12 @@ def dist(cls, q, beta, *args, **kwargs):
490490
beta = at.as_tensor_variable(floatX(beta))
491491
return super().dist([q, beta], **kwargs)
492492

493+
def get_moment(rv, size, q, beta):
494+
median = at.power(at.log(0.5) / at.log(q), 1 / beta) - 1
495+
if not rv_size_is_none(size):
496+
median = at.full(size, median)
497+
return median
498+
493499
def logp(value, q, beta):
494500
r"""
495501
Calculate log-probability of DiscreteWeibull distribution at specified value.

pymc/tests/test_distributions_moments.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
DensityDist,
2323
Dirichlet,
2424
DiscreteUniform,
25+
DiscreteWeibull,
2526
ExGaussian,
2627
Exponential,
2728
Flat,
@@ -110,7 +111,6 @@ def test_all_distributions_have_moments():
110111

111112
# Distributions that have been refactored but don't yet have moments
112113
not_implemented |= {
113-
dist_module.discrete.DiscreteWeibull,
114114
dist_module.multivariate.DirichletMultinomial,
115115
dist_module.multivariate.Wishart,
116116
}
@@ -751,6 +751,26 @@ def test_discrete_uniform_moment(lower, upper, size, expected):
751751
DiscreteUniform("x", lower=lower, upper=upper, size=size)
752752

753753

754+
@pytest.mark.parametrize(
755+
"q, beta, size, expected",
756+
[
757+
(0.5, 0.5, None, 0),
758+
(0.6, 0.1, 5, (20,) * 5),
759+
(np.linspace(0.25, 0.99, 4), 0.42, None, [0, 0, 6, 23862]),
760+
(
761+
np.linspace(0.5, 0.99, 3),
762+
[[1, 1.25, 1.75], [1.25, 0.75, 0.5]],
763+
None,
764+
[[0, 0, 10], [0, 2, 4755]],
765+
),
766+
],
767+
)
768+
def test_discrete_weibull_moment(q, beta, size, expected):
769+
with Model() as model:
770+
DiscreteWeibull("x", q=q, beta=beta, size=size)
771+
assert_moment_is_expected(model, expected)
772+
773+
754774
@pytest.mark.parametrize(
755775
"a, size, expected",
756776
[

0 commit comments

Comments
 (0)