Skip to content

Commit 2a5df64

Browse files
committed
added test for moment
1 parent fe7c2b2 commit 2a5df64

File tree

3 files changed

+33
-13
lines changed

3 files changed

+33
-13
lines changed

pymc/distributions/multivariate.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2277,9 +2277,10 @@ def dist(cls, alpha, K, *args, **kwargs):
22772277
return super().dist([alpha, K], **kwargs)
22782278

22792279
def moment(rv, size, alpha, K):
2280+
alpha = alpha[..., np.newaxis]
22802281
moment = (alpha / (1 + alpha)) ** at.arange(K)
22812282
moment *= 1 / (1 + alpha)
2282-
moment = at.concatenate([moment, [(alpha / (1 + alpha)) ** K]], axis=-1)
2283+
moment = at.concatenate([moment, (alpha / (1 + alpha)) ** K], axis=-1)
22832284
if not rv_size_is_none(size):
22842285
moment_size = at.concatenate(
22852286
[

pymc/tests/test_distributions.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -954,21 +954,14 @@ def test_hierarchical_obs_logp():
954954

955955

956956
@pytest.fixture(scope="module")
957-
def _compile_stickbreakingweights_logpdf():
957+
def stickbreakingweights_logpdf():
958958
_value = at.vector()
959959
_alpha = at.scalar()
960960
_k = at.iscalar()
961961
_logp = logp(StickBreakingWeights.dist(_alpha, _k), _value)
962-
return compile_pymc([_value, _alpha, _k], _logp)
962+
core_fn = compile_pymc([_value, _alpha, _k], _logp)
963963

964-
965-
def _stickbreakingweights_logpdf(value, alpha, k, _compile_stickbreakingweights_logpdf):
966-
return _compile_stickbreakingweights_logpdf(value, alpha, k)
967-
968-
969-
stickbreakingweights_logpdf = np.vectorize(
970-
_stickbreakingweights_logpdf, signature="(n),(),(),()->()"
971-
)
964+
return np.vectorize(core_fn, signature="(n),(),()->()")
972965

973966

974967
class TestMatchesScipy:
@@ -2338,14 +2331,14 @@ def test_stickbreakingweights_invalid(self):
23382331
(np.arange(1, 7, dtype="float64").reshape(2, 3), 5),
23392332
],
23402333
)
2341-
def test_stickbreakingweights_vectorized(self, alpha, K, _compile_stickbreakingweights_logpdf):
2334+
def test_stickbreakingweights_vectorized(self, alpha, K, stickbreakingweights_logpdf):
23422335
value = pm.StickBreakingWeights.dist(alpha, K).eval()
23432336
with Model():
23442337
sbw = StickBreakingWeights("sbw", alpha=alpha, K=K, transform=None)
23452338
pt = {"sbw": value}
23462339
assert_almost_equal(
23472340
pm.logp(sbw, value).eval(),
2348-
stickbreakingweights_logpdf(value, alpha, K, _compile_stickbreakingweights_logpdf),
2341+
stickbreakingweights_logpdf(value, alpha, K),
23492342
decimal=select_by_precision(float64=6, float32=2),
23502343
err_msg=str(pt),
23512344
)

pymc/tests/test_distributions_moments.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1166,6 +1166,32 @@ def test_rice_moment(nu, sigma, size, expected):
11661166
fill_value=np.append((1 / 3) ** np.arange(5) * 2 / 3, (1 / 3) ** 5),
11671167
),
11681168
),
1169+
(
1170+
np.array([1, 3]),
1171+
11,
1172+
None,
1173+
np.array(
1174+
[
1175+
np.append((1 / 2) ** np.arange(11) * 1 / 2, (1 / 2) ** 11),
1176+
np.append((3 / 4) ** np.arange(11) * 1 / 4, (3 / 4) ** 11),
1177+
]
1178+
),
1179+
),
1180+
(
1181+
np.array([1, 3, 5]),
1182+
9,
1183+
(5, 3),
1184+
np.full(
1185+
shape=(5, 3, 10),
1186+
fill_value=np.array(
1187+
[
1188+
np.append((1 / 2) ** np.arange(9) * 1 / 2, (1 / 2) ** 9),
1189+
np.append((3 / 4) ** np.arange(9) * 1 / 4, (3 / 4) ** 9),
1190+
np.append((5 / 6) ** np.arange(9) * 1 / 6, (5 / 6) ** 9),
1191+
]
1192+
),
1193+
),
1194+
),
11691195
],
11701196
)
11711197
def test_stickbreakingweights_moment(alpha, K, size, expected):

0 commit comments

Comments
 (0)