
Port Dirichlet Multinomial to v4 #4758


Merged · 2 commits · Jul 5, 2021

pymc3/distributions/multivariate.py: 145 changes (53 additions, 92 deletions)
@@ -396,7 +396,6 @@ class Dirichlet(Continuous):
a: array
Concentration parameters (a > 0).
"""

rv_op = dirichlet

def __new__(cls, name, *args, **kwargs):
@@ -504,11 +503,6 @@ def dist(cls, n, p, *args, **kwargs):
n = at.as_tensor_variable(n)
p = at.as_tensor_variable(p)

# mean = n * p
# mode = at.cast(at.round(mean), "int32")
# diff = n - at.sum(mode, axis=-1, keepdims=True)
# inc_bool_arr = at.abs_(diff) > 0
# mode = at.inc_subtensor(mode[inc_bool_arr.nonzero()], diff[inc_bool_arr.nonzero()])
return super().dist([n, p], *args, **kwargs)

def logp(value, n, p):
@@ -518,7 +512,7 @@ def logp(value, n, p):

Parameters
----------
x: numeric
value: numeric
Value for which log-probability is calculated.

Returns
@@ -536,6 +530,46 @@ def logp(value, n, p):
)


class DirichletMultinomialRV(RandomVariable):
    name = "dirichlet_multinomial"
    ndim_supp = 1
    ndims_params = [0, 1]
    dtype = "int64"
    _print_name = ("DirichletMN", "\\operatorname{DirichletMN}")

    def _shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None):
        return default_shape_from_params(self.ndim_supp, dist_params, rep_param_idx, param_shapes)

    @classmethod
    def rng_fn(cls, rng, n, a, size):

        if n.ndim > 0 or a.ndim > 1:
            n, a = broadcast_params([n, a], cls.ndims_params)
            size = tuple(size or ())

            if size:
                n = np.broadcast_to(n, size + n.shape)
                a = np.broadcast_to(a, size + a.shape)

            res = np.empty(a.shape)
            for idx in np.ndindex(a.shape[:-1]):
                p = rng.dirichlet(a[idx])
                res[idx] = rng.multinomial(n[idx], p)
            return res
        else:
            # n is a scalar, a is a 1d array
            p = rng.dirichlet(a, size=size)  # (size, a.shape)

            res = np.empty(p.shape)
            for idx in np.ndindex(p.shape[:-1]):
                res[idx] = rng.multinomial(n, p[idx])

            return res


dirichlet_multinomial = DirichletMultinomialRV()
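
Note on the sampler: `rng_fn` draws a Dirichlet probability vector and feeds it to a Multinomial, which is exactly the compound definition of the distribution. As a quick standalone sanity check (not part of this diff; the numbers are illustrative), the empirical mean of such draws should approach `n * a / a.sum()`:

```python
import numpy as np

rng = np.random.default_rng(123)
n, a = 50, np.array([0.5, 1.5, 3.0])

# Compound sampling, mirroring the scalar-n branch of rng_fn above:
# one Dirichlet draw of p, then one Multinomial draw given that p.
draws = np.array([rng.multinomial(n, rng.dirichlet(a)) for _ in range(20_000)])

print(draws.mean(axis=0))  # empirical mean, close to...
print(n * a / a.sum())     # ...the analytical mean [ 5. 15. 30.]
```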


class DirichletMultinomial(Discrete):
r"""Dirichlet Multinomial log-likelihood.

@@ -569,92 +603,16 @@ class DirichletMultinomial(Discrete):
Describes shape of distribution. For example if n=array([5, 10]), and
a=array([1, 1, 1]), shape should be (2, 3).
"""
rv_op = dirichlet_multinomial

def __init__(self, n, a, shape, *args, **kwargs):

super().__init__(shape=shape, defaults=("_defaultval",), *args, **kwargs)

@classmethod
def dist(cls, n, a, *args, **kwargs):
n = intX(n)
a = floatX(a)
        if len(self.shape) > 1:
            self.n = at.shape_padright(n)
            self.a = at.as_tensor_variable(a) if a.ndim > 1 else at.shape_padleft(a)
        else:
            # n is a scalar, p is a 1d array
            self.n = at.as_tensor_variable(n)
            self.a = at.as_tensor_variable(a)

p = self.a / self.a.sum(-1, keepdims=True)

self.mean = self.n * p
# Mode is only an approximation. Exact computation requires a complex
# iterative algorithm as described in https://doi.org/10.1016/j.spl.2009.09.013
mode = at.cast(at.round(self.mean), "int32")
diff = self.n - at.sum(mode, axis=-1, keepdims=True)
inc_bool_arr = at.abs_(diff) > 0
mode = at.inc_subtensor(mode[inc_bool_arr.nonzero()], diff[inc_bool_arr.nonzero()])
self._defaultval = mode

def _random(self, n, a, size=None):
# numpy will cast dirichlet and multinomial samples to float64 by default
original_dtype = a.dtype

# Thanks to the default shape handling done in generate_values, the last
# axis of n is a dummy axis that allows it to broadcast well with `a`
n = np.broadcast_to(n, size)
a = np.broadcast_to(a, size)
n = n[..., 0]

# np.random.multinomial needs `n` to be a scalar int and `a` a
# sequence so we semi flatten them and iterate over them
n_ = n.reshape([-1])
a_ = a.reshape([-1, a.shape[-1]])
p_ = np.array([np.random.dirichlet(aa) for aa in a_])
samples = np.array([np.random.multinomial(nn, pp) for nn, pp in zip(n_, p_)])
samples = samples.reshape(a.shape)

# We cast back to the original dtype
return samples.astype(original_dtype)

def random(self, point=None, size=None):
"""
Draw random values from Dirichlet-Multinomial distribution.

Parameters
----------
point: dict, optional
Dict of variable values on which random values are to be
conditioned (uses default point if not specified).
size: int, optional
Desired size of random sample (returns one sample if not
specified).

Returns
-------
array
"""
# n, a = draw_values([self.n, self.a], point=point, size=size)
# samples = generate_samples(
# self._random,
# n,
# a,
# dist_shape=self.shape,
# size=size,
# )
#
# # If distribution is initialized with .dist(), valid init shape is not asserted.
# # Under normal use in a model context valid init shape is asserted at start.
# expected_shape = to_tuple(size) + to_tuple(self.shape)
# sample_shape = tuple(samples.shape)
# if sample_shape != expected_shape:
# raise ShapeError(
# f"Expected sample shape was {expected_shape} but got {sample_shape}. "
# "This may reflect an invalid initialization shape."
# )
#
# return samples
return super().dist([n, a], **kwargs)
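
Under the v4 pattern, `dist` just wires the parameters into the `RandomVariable` op, so draws come straight from evaluating the resulting variable. A minimal usage sketch, consistent with how the tests below exercise the API (the parameter values are hypothetical):

```python
import pymc3 as pm

# .dist() returns an Aesara random variable; eval() produces one draw
# of 10 trials spread over 3 categories.
dm = pm.DirichletMultinomial.dist(n=10, a=[0.5, 0.5, 1.0])
draw = dm.eval()
print(draw, draw.sum())  # e.g. [2 1 7] 10
```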

def logp(self, value):
def logp(value, n, a):
"""
Calculate log-probability of DirichletMultinomial distribution
at specified value.
@@ -668,13 +626,16 @@ def logp(self, value):
-------
TensorVariable
"""
        a = self.a
        n = self.n
        sum_a = a.sum(axis=-1, keepdims=True)
        if value.ndim >= 1:
            n = at.shape_padright(n)
            if a.ndim > 1:
                a = at.shape_padleft(a)

        sum_a = a.sum(axis=-1, keepdims=True)
const = (gammaln(n + 1) + gammaln(sum_a)) - gammaln(n + sum_a)
series = gammaln(value + a) - (gammaln(value + 1) + gammaln(a))
result = const + series.sum(axis=-1, keepdims=True)

# Bounds checking to confirm parameters and data meet all constraints
# and that each observation value_i sums to n_i.
return bound(
Expand Down Expand Up @@ -855,7 +816,7 @@ def logp(X, nu, V):


def WishartBartlett(name, S, nu, is_cholesky=False, return_cholesky=False, initval=None):
R"""
r"""
Bartlett decomposition of the Wishart distribution. As the Wishart
distribution requires the matrix to be symmetric positive semi-definite
it is impossible for MCMC to ever propose acceptable matrices.
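For reference, the closed form implemented by `logp` above is, with `A = a.sum()`: `gammaln(n + 1) + gammaln(A) - gammaln(n + A) + sum_k [gammaln(x_k + a_k) - gammaln(x_k + 1) - gammaln(a_k)]`. The tests below compare against a `dirichlet_multinomial_logpmf` helper; a NumPy/SciPy sketch along those lines (an illustration, not necessarily the exact helper in the test suite):

```python
import numpy as np
from scipy.special import gammaln

def dm_logpmf(value, n, a):
    # log DM(value | n, a) with A = sum(a):
    #   gammaln(n + 1) + gammaln(A) - gammaln(n + A)
    #   + sum_k [gammaln(x_k + a_k) - gammaln(x_k + 1) - gammaln(a_k)]
    value, a = np.asarray(value), np.asarray(a)
    sum_a = a.sum(axis=-1)
    const = gammaln(n + 1) + gammaln(sum_a) - gammaln(n + sum_a)
    series = gammaln(value + a) - gammaln(value + 1) - gammaln(a)
    return const + series.sum(axis=-1)
```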
pymc3/tests/test_distributions.py: 69 changes (41 additions, 28 deletions)
@@ -2272,8 +2272,17 @@ def test_batch_multinomial(self):
sample = dist.eval()
assert_allclose(sample, np.stack([vals, vals], axis=0))

def test_multinomial_zero_probs(self):
# test multinomial accepts 0 probabilities / observations:
value = aesara.shared(np.array([0, 0, 100], dtype=int))
logp = pm.Multinomial.logp(value=value, n=100, p=at.constant([0.0, 0.0, 1.0]))
logp_fn = aesara.function(inputs=[], outputs=logp)
assert logp_fn() >= 0

value.set_value(np.array([50, 50, 0], dtype=int))
assert np.isneginf(logp_fn())
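
The zero-probability cases work because the Multinomial log-probability treats each `x * log(p)` term with the `0 * log(0) = 0` convention, so a category with `p = 0` only forces `-inf` when its count is nonzero. A standalone illustration with SciPy's `xlogy`, which encodes the same convention (not part of the PR):

```python
import numpy as np
from scipy.special import gammaln, xlogy

def multinomial_logpmf(x, n, p):
    # xlogy(0, 0) == 0, so zero-probability categories with zero counts
    # contribute nothing; a nonzero count there gives xlogy(x, 0) == -inf.
    return gammaln(n + 1) + np.sum(xlogy(x, p) - gammaln(np.asarray(x) + 1))

print(multinomial_logpmf([0, 0, 100], 100, [0.0, 0.0, 1.0]))  # 0.0
print(multinomial_logpmf([50, 50, 0], 100, [0.0, 0.0, 1.0]))  # -inf
```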

@pytest.mark.parametrize("n", [2, 3])
@pytest.mark.xfail(reason="Distribution not refactored yet")
def test_dirichlet_multinomial(self, n):
self.check_logp(
DirichletMultinomial,
@@ -2282,43 +2291,47 @@ def test_dirichlet_multinomial(self, n):
dirichlet_multinomial_logpmf,
)

@pytest.mark.xfail(reason="Distribution not refactored yet")
def test_dirichlet_multinomial_matches_beta_binomial(self):
a, b, n = 2, 1, 5
ns = np.arange(n + 1)
ns_dm = np.vstack((ns, n - ns)).T # covert ns=1 to ns_dm=[1, 4], for all ns...
bb_logp = logpt(pm.BetaBinomial.dist(n=n, alpha=a, beta=b), ns).tag.test_value
dm_logp = logpt(
pm.DirichletMultinomial.dist(n=n, a=[a, b], size=(1, 2)), ns_dm
).tag.test_value
dm_logp = dm_logp.ravel()
ns_dm = np.vstack((ns, n - ns)).T # convert ns=1 to ns_dm=[1, 4], for all ns...

bb = pm.BetaBinomial.dist(n=n, alpha=a, beta=b, size=2)
bb_value = bb.type()
bb.tag.value_var = bb_value
bb_logp = logpt(var=bb, rv_values={bb: bb_value}).eval({bb_value: ns})

dm = pm.DirichletMultinomial.dist(n=n, a=[a, b], size=2)
dm_value = dm.type()
dm.tag.value_var = dm_value
dm_logp = logpt(var=dm, rv_values={dm: dm_value}).eval({dm_value: ns_dm}).ravel()

assert_almost_equal(
dm_logp,
bb_logp,
decimal=select_by_precision(float64=6, float32=3),
)
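
This equivalence is exact: a two-category Dirichlet-Multinomial with `a = [alpha, beta]` evaluated at `[k, n - k]` is the Beta-Binomial pmf at `k`. It can be confirmed independently with `scipy.stats.betabinom` and the `dm_logpmf` sketch above (an illustrative check, not part of the diff):

```python
import numpy as np
from scipy.stats import betabinom

a, b, n = 2, 1, 5
ks = np.arange(n + 1)

bb = betabinom.logpmf(ks, n, a, b)
dm = np.array([dm_logpmf([k, n - k], n=n, a=[a, b]) for k in ks])

np.testing.assert_allclose(dm, bb)  # two-category DM reduces to Beta-Binomial
```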

@pytest.mark.xfail(reason="Distribution not refactored yet")
def test_dirichlet_multinomial_vec(self):
vals = np.array([[2, 4, 4], [3, 3, 4]])
a = np.array([0.2, 0.3, 0.5])
n = 10

with Model() as model_single:
DirichletMultinomial("m", n=n, a=a, size=len(a))
DirichletMultinomial("m", n=n, a=a)

with Model() as model_many:
DirichletMultinomial("m", n=n, a=a, size=vals.shape)
DirichletMultinomial("m", n=n, a=a, size=2)

assert_almost_equal(
np.asarray([dirichlet_multinomial_logpmf(v, n, a) for v in vals]),
np.asarray([dirichlet_multinomial_logpmf(val, n, a) for val in vals]),
np.asarray([model_single.fastlogp({"m": val}) for val in vals]),
decimal=4,
)

assert_almost_equal(
np.asarray([dirichlet_multinomial_logpmf(v, n, a) for v in vals]),
model_many.free_RVs[0].logp_elemwise({"m": vals}).squeeze(),
np.asarray([dirichlet_multinomial_logpmf(val, n, a) for val in vals]),
logpt(model_many.m, vals).eval().squeeze(),
decimal=4,
)

@@ -2328,56 +2341,52 @@ def test_dirichlet_multinomial_vec(self):
decimal=4,
)

@pytest.mark.xfail(reason="Distribution not refactored yet")
def test_dirichlet_multinomial_vec_1d_n(self):
vals = np.array([[2, 4, 4], [4, 3, 4]])
a = np.array([0.2, 0.3, 0.5])
ns = np.array([10, 11])

with Model() as model:
DirichletMultinomial("m", n=ns, a=a, size=vals.shape)
DirichletMultinomial("m", n=ns, a=a)

assert_almost_equal(
sum(dirichlet_multinomial_logpmf(val, n, a) for val, n in zip(vals, ns)),
model.fastlogp({"m": vals}),
decimal=4,
)

@pytest.mark.xfail(reason="Distribution not refactored yet")
def test_dirichlet_multinomial_vec_1d_n_2d_a(self):
vals = np.array([[2, 4, 4], [4, 3, 4]])
as_ = np.array([[0.2, 0.3, 0.5], [0.9, 0.09, 0.01]])
ns = np.array([10, 11])

with Model() as model:
DirichletMultinomial("m", n=ns, a=as_, size=vals.shape)
DirichletMultinomial("m", n=ns, a=as_)

assert_almost_equal(
sum(dirichlet_multinomial_logpmf(val, n, a) for val, n, a in zip(vals, ns, as_)),
model.fastlogp({"m": vals}),
decimal=4,
)

@pytest.mark.xfail(reason="Distribution not refactored yet")
def test_dirichlet_multinomial_vec_2d_a(self):
vals = np.array([[2, 4, 4], [3, 3, 4]])
as_ = np.array([[0.2, 0.3, 0.5], [0.3, 0.3, 0.4]])
n = 10

with Model() as model:
DirichletMultinomial("m", n=n, a=as_, size=vals.shape)
DirichletMultinomial("m", n=n, a=as_)

assert_almost_equal(
sum(dirichlet_multinomial_logpmf(val, n, a) for val, a in zip(vals, as_)),
model.fastlogp({"m": vals}),
decimal=4,
)

@pytest.mark.xfail(reason="Distribution not refactored yet")
def test_batch_dirichlet_multinomial(self):
# Test that DM can handle a 3d array for `a`

# Create an almost deterministic DM by setting a to 0.001, everywehere
# Create an almost deterministic DM by setting a to 0.001, everywhere
# except for one category / dimension which is given the value of 1000
n = 5
vals = np.zeros((4, 5, 3), dtype="int32")
@@ -2386,19 +2395,23 @@ def test_batch_dirichlet_multinomial(self):
np.put_along_axis(vals, inds, n, axis=-1)
np.put_along_axis(a, inds, 1000, axis=-1)

dist = DirichletMultinomial.dist(n=n, a=a, size=vals.shape)
dist = DirichletMultinomial.dist(n=n, a=a)

# Logp should be approx -9.924431e-06
dist_logp = logpt(dist, vals).tag.test_value
expected_logp = np.full(shape=vals.shape[:-1] + (1,), fill_value=-9.924431e-06)
# Logp should be approx -9.98004998e-06
value = at.tensor3(dtype="int32")
value.tag.test_value = np.zeros_like(vals, dtype="int32")
logp = logpt(dist, value)
f = aesara.function(inputs=[value], outputs=logp)
expected_logp = np.full(shape=f(vals).shape, fill_value=-9.98004998e-06)
assert_almost_equal(
dist_logp,
f(vals),
expected_logp,
decimal=select_by_precision(float64=6, float32=3),
)

# Samples should be equal given the almost deterministic DM
sample = dist.random(size=2)
dist = DirichletMultinomial.dist(n=n, a=a, size=2)
sample = dist.eval()
assert_allclose(sample, np.stack([vals, vals], axis=0))
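
For readers unfamiliar with `np.put_along_axis`: the test scatters the value 1000 into one category per row of `a`, leaving 0.001 everywhere else, so each Dirichlet draw is nearly one-hot and the Multinomial draw nearly deterministic. A reduced illustration (shapes and indices here are hypothetical):

```python
import numpy as np

# 2x3 version of the construction in the test above.
a = np.full((2, 3), 0.001)
inds = np.array([[0], [2]])           # winning category per row
np.put_along_axis(a, inds, 1000, axis=-1)
print(a)
# [[1.000e+03 1.000e-03 1.000e-03]
#  [1.000e-03 1.000e-03 1.000e+03]]
```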

@aesara.config.change_flags(compute_test_value="raise")