Hi, new to the community and bayesian modeling in general.
I am converting some example code from pymc v3 to pymc v5.
The original code comes from this notebook
The relevant section is this:
with pm.Model() as model:
# Index to true model
prior_model_prob = 0.5
#tau = pm.DiscreteUniform('tau', lower=0, upper=1)
tau = pm.Bernoulli('tau', prior_model_prob)
# Poisson parameters
mu_p = pm.Uniform('mu_p', 0, 60)
# Negative Binomial parameters
alpha = pm.Exponential('alpha', lam=0.2)
mu_nb = pm.Uniform('mu_nb', lower=0, upper=60)
y_like = pm.DensityDist('y_like',
lambda value: pm.math.switch(tau,
pm.Poisson.dist(mu_p).logp(value),
pm.NegativeBinomial.dist(mu_nb, alpha).logp(value)
),
observed=messages['time_delay_seconds'].values)
start = pm.find_MAP()
step1 = pm.Metropolis([mu_p, alpha, mu_nb])
step2 = pm.ElemwiseCategorical(vars=[tau], values=[0,1])
trace = pm.sample(200000, step=[step1, step2], start=start)
_ = pm.traceplot(trace[burnin:], varnames=['tau'])
My attempt at converting it looks like this:
def dist(
tau: TensorVariable,
mu_p: TensorVariable,
mu_nb: TensorVariable,
alpha: TensorVariable,
size: TensorVariable
) -> TensorVariable:
poisson_ = pm.Poisson.dist(mu_p, size=size)
negative_binomial_ = pm.NegativeBinomial.dist(mu_nb, alpha, size=size)
return pm.math.switch(
pm.math.eq(tau, 0),
poisson_,
negative_binomial_
)
with pm.Model() as model:
# Index to true model
prior_model_prob = 0.5
tau = pm.Bernoulli('tau', p=prior_model_prob)
# Poisson parameters
mu_p = pm.Uniform('mu_p', lower=0, upper=60)
# Negative Binomial parameters
alpha = pm.Exponential('alpha', lam=0.2)
mu_nb = pm.Uniform('mu_nb', lower=0, upper=60)
y_like = pm.CustomDist(
'y_like',
tau,
mu_p,
mu_nb,
alpha,
logp=logp,
observed=messages['time_delay_seconds'].values
)
start = pm.find_MAP()
step1 = pm.Metropolis([mu_p, alpha, mu_nb])
step2 = pm.BinaryGibbsMetropolis([tau])
trace = pm.sample(
draws=200000,
step=[step1, step2],
start=start,
progressbar=False
)
When comparing the mean of the posterior values of tau, I get 0 instead of a value between 0 and 1. What am I missing?