Description
Hi all,
I have a model that used to run under PyMC3 and which I am now upgrading to PyMC v4; it is a port of the US POTUS model originally written in Stan. I'm having a great time with PyMC v4, but I've run into a bit of trouble here. Sampling appears to go fine when I run it with PyMC's default sampler (it takes a while and I haven't checked the results thoroughly yet, but it runs at least), whereas the JAX sampler fails.
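For reference, this is roughly how I'm calling the two samplers (a minimal sketch; `m2` is the model defined in full below):

```python
import pymc as pm
import pymc.sampling_jax  # sampling_jax needs an explicit import

with m2:
    # The default PyMC sampler runs (slowly, but it runs).
    idata = pm.sample()

with m2:
    # This is the call that raises the TypeError below.
    idata_jax = pm.sampling_jax.sample_numpyro_nuts()
```

The `sample_numpyro_nuts` call produces the following error: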
```
Compiling...
Compilation time = 0:00:08.213230
Sampling...
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Input In [7], in <cell line: 1>()
1 with m2 as model:
----> 2 res = pm.sampling_jax.sample_numpyro_nuts()
File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/sampling_jax.py:514, in sample_numpyro_nuts(draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progress_bar, keep_untransformed, chain_method, postprocessing_backend, idata_kwargs, nuts_kwargs)
511 if chains > 1:
512 map_seed = jax.random.split(map_seed, chains)
--> 514 pmap_numpyro.run(
515 map_seed,
516 init_params=init_params,
517 extra_fields=(
518 "num_steps",
519 "potential_energy",
520 "energy",
521 "adapt_state.step_size",
522 "accept_prob",
523 "diverging",
524 ),
525 )
527 raw_mcmc_samples = pmap_numpyro.get_samples(group_by_chain=True)
529 tic3 = datetime.now()
File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/numpyro/infer/mcmc.py:599, in MCMC.run(self, rng_key, extra_fields, init_params, *args, **kwargs)
597 states, last_state = _laxmap(partial_map_fn, map_args)
598 elif self.chain_method == "parallel":
--> 599 states, last_state = pmap(partial_map_fn)(map_args)
600 else:
601 assert self.chain_method == "vectorized"
[... skipping hidden 17 frame]
File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/numpyro/infer/mcmc.py:381, in MCMC._single_chain_mcmc(self, init, args, kwargs, collect_fields)
379 rng_key, init_state, init_params = init
380 if init_state is None:
--> 381 init_state = self.sampler.init(
382 rng_key,
383 self.num_warmup,
384 init_params,
385 model_args=args,
386 model_kwargs=kwargs,
387 )
388 sample_fn, postprocess_fn = self._get_cached_fns()
389 diagnostics = (
390 lambda x: self.sampler.get_diagnostics_str(x[0])
391 if rng_key.ndim == 1
392 else ""
393 ) # noqa: E731
File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/numpyro/infer/hmc.py:746, in HMC.init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
726 hmc_init_fn = lambda init_params, rng_key: self._init_fn( # noqa: E731
727 init_params,
728 num_warmup=num_warmup,
(...)
743 rng_key=rng_key,
744 )
745 if rng_key.ndim == 1:
--> 746 init_state = hmc_init_fn(init_params, rng_key)
747 else:
748 # XXX it is safe to run hmc_init_fn under vmap despite that hmc_init_fn changes some
749 # nonlocal variables: momentum_generator, wa_update, trajectory_len, max_treedepth,
750 # wa_steps because those variables do not depend on traced args: init_params, rng_key.
751 init_state = vmap(hmc_init_fn)(init_params, rng_key)
File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/numpyro/infer/hmc.py:726, in HMC.init.<locals>.<lambda>(init_params, rng_key)
723 dense_mass = [tuple(sorted(z))] if dense_mass else []
724 assert isinstance(dense_mass, list)
--> 726 hmc_init_fn = lambda init_params, rng_key: self._init_fn( # noqa: E731
727 init_params,
728 num_warmup=num_warmup,
729 step_size=self._step_size,
730 num_steps=self._num_steps,
731 inverse_mass_matrix=inverse_mass_matrix,
732 adapt_step_size=self._adapt_step_size,
733 adapt_mass_matrix=self._adapt_mass_matrix,
734 dense_mass=dense_mass,
735 target_accept_prob=self._target_accept_prob,
736 trajectory_length=self._trajectory_length,
737 max_tree_depth=self._max_tree_depth,
738 find_heuristic_step_size=self._find_heuristic_step_size,
739 forward_mode_differentiation=self._forward_mode_differentiation,
740 regularize_mass_matrix=self._regularize_mass_matrix,
741 model_args=model_args,
742 model_kwargs=model_kwargs,
743 rng_key=rng_key,
744 )
745 if rng_key.ndim == 1:
746 init_state = hmc_init_fn(init_params, rng_key)
File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/numpyro/infer/hmc.py:322, in hmc.<locals>.init_kernel(init_params, num_warmup, step_size, inverse_mass_matrix, adapt_step_size, adapt_mass_matrix, dense_mass, target_accept_prob, num_steps, trajectory_length, max_tree_depth, find_heuristic_step_size, forward_mode_differentiation, regularize_mass_matrix, model_args, model_kwargs, rng_key)
320 r = momentum_generator(z, wa_state.mass_matrix_sqrt, rng_key_momentum)
321 vv_init, vv_update = velocity_verlet(pe_fn, kinetic_fn, forward_mode_ad)
--> 322 vv_state = vv_init(z, r, potential_energy=pe, z_grad=z_grad)
323 energy = vv_state.potential_energy + kinetic_fn(
324 wa_state.inverse_mass_matrix, vv_state.r
325 )
326 zero_int = jnp.array(0, dtype=jnp.result_type(int))
File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/numpyro/infer/hmc_util.py:278, in velocity_verlet.<locals>.init_fn(z, r, potential_energy, z_grad)
270 """
271 :param z: Position of the particle.
272 :param r: Momentum of the particle.
(...)
275 :return: initial state for the integrator.
276 """
277 if potential_energy is None or z_grad is None:
--> 278 potential_energy, z_grad = _value_and_grad(
279 potential_fn, z, forward_mode_differentiation
280 )
281 return IntegratorState(z, r, potential_energy, z_grad)
File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/numpyro/infer/hmc_util.py:246, in _value_and_grad(f, x, forward_mode_differentiation)
244 return f(x), jacfwd(f)(x)
245 else:
--> 246 return value_and_grad(f)(x)
[... skipping hidden 8 frame]
File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/sampling_jax.py:109, in get_jaxified_logp.<locals>.logp_fn_wrap(x)
108 def logp_fn_wrap(x):
--> 109 return logp_fn(*x)[0]
File /var/folders/s_/l9wd4yls3mv1f4kkkz1xhhmm0000gn/T/tmpcdl0ef2w:94, in jax_funcified_fgraph(raw_polling_bias, raw_mu_b_T, raw_mu_b, raw_mu_c, raw_mu_m, raw_mu_pop, mu_e_bias, rho_e_bias_interval_, raw_e_bias, raw_measure_noise_national, raw_measure_noise_state)
92 auto_114085 = equal(auto_114087, auto_116679)
93 # Reshape{1}(Elemwise{Composite{(i0 * sqrt((i1 - sqr(i2))))}}.0, TensorConstant{(1,) of -1})
---> 94 auto_113078 = reshape(auto_122013, auto_19078)
95 # Elemwise{Composite{(i0 * (i1 ** i2))}}(TensorConstant{[[1. 0. 0...1. 1. 1.]]}, InplaceDimShuffle{x,x}.0, TensorConstant{[[ 0 1 .. 1 0]]})
96 auto_121883 = composite12(auto_19117, auto_113086, auto_19111)
File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/link/jax/dispatch.py:733, in jax_funcify_Reshape.<locals>.reshape(x, shape)
732 def reshape(x, shape):
--> 733 return jnp.reshape(x, shape)
File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:741, in reshape(a, newshape, order)
739 _stackable(a) or _check_arraylike("reshape", a)
740 try:
--> 741 return a.reshape(newshape, order=order) # forward to method for ndarrays
742 except AttributeError:
743 return _reshape(a, newshape, order=order)
File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:759, in _reshape(a, order, *args)
758 def _reshape(a, *args, order="C"):
--> 759 newshape = _compute_newshape(a, args[0] if len(args) == 1 else args)
760 if order == "C":
761 return lax.reshape(a, newshape, None)
File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:752, in _compute_newshape(a, newshape)
750 except: iterable = False
751 else: iterable = True
--> 752 newshape = core.canonicalize_shape(newshape if iterable else [newshape])
753 return tuple(- core.divide_shape_sizes(np.shape(a), newshape)
754 if core.symbolic_equal_dim(d, -1) else d
755 for d in newshape)
File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/jax/core.py:1790, in canonicalize_shape(shape, context)
1788 except TypeError:
1789 pass
-> 1790 raise _invalid_shape_error(shape, context)
TypeError: Shapes must be 1D sequences of concrete values of integer type, got [-1].
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
```
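As far as I can tell, this is JAX's generic complaint about non-static shapes under `jit`: if the target shape of a `reshape` arrives as a traced value instead of a concrete tuple, `jnp.reshape` fails exactly like this. A minimal sketch in plain JAX that reproduces the same error path (nothing model-specific here):

```python
import jax
import jax.numpy as jnp

def f(x, shape):
    return jnp.reshape(x, shape)

x = jnp.arange(6.0)

# Fine: the shape is a concrete tuple marked static, so it is known at trace time.
jax.jit(f, static_argnums=1)(x, (-1, 1))

# Fails: inside jit the shape is a traced array, so JAX cannot resolve the -1
# and raises "Shapes must be 1D sequences of concrete values of integer type".
jax.jit(f)(x, jnp.array([-1, 1]))
```

In the generated `jax_funcified_fgraph` above, the failing node is a `Reshape{1}` whose input is the `Elemwise{Composite{(i0 * sqrt((i1 - sqr(i2))))}}` term; that pattern looks like my `sigma_rho = pm.math.sqrt(1 - (rho_e_bias ** 2)) * np_data["sigma_e_bias"]` expression, so the reshape is presumably introduced somewhere around the `repeat`/`concatenate` on it rather than by an explicit reshape in my model code, though I haven't been able to confirm this.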
I've tried to debug this, but I'm a bit stumped! Here's the full model code:
```python
import json

import numpy as np
import pymc as pm
import aesara

data = json.load(open('./exported_stan_data_2016.json'))
np_data = {x: np.squeeze(np.array(y)) for x, y in data.items()}

shapes = {
    'N_national_polls': int(np_data['N_national_polls']),
    'N_state_polls': int(np_data['N_state_polls']),
    'T': int(np_data['T']),
    'S': int(np_data['S']),
    'P': int(np_data['P']),
    'M': int(np_data['M']),
    'Pop': int(np_data['Pop'])
}

national_cov_matrix_error_sd = np.sqrt(
    np.squeeze(
        np_data["state_weights"].reshape(1, -1)
        @ (np_data["state_covariance_0"] @ np_data["state_weights"].reshape(-1, 1))
    )
)

ss_cov_poll_bias = (
    np_data["state_covariance_0"]
    * (np_data["polling_bias_scale"] / national_cov_matrix_error_sd) ** 2
)
ss_cov_mu_b_T = (
    np_data["state_covariance_0"]
    * (np_data["mu_b_T_scale"] / national_cov_matrix_error_sd) ** 2
)
ss_cov_mu_b_walk = (
    np_data["state_covariance_0"]
    * (np_data["random_walk_scale"] / national_cov_matrix_error_sd) ** 2
)

cholesky_ss_cov_poll_bias = np.linalg.cholesky(ss_cov_poll_bias)
cholesky_ss_cov_mu_b_T = np.linalg.cholesky(ss_cov_mu_b_T)
cholesky_ss_cov_mu_b_walk = np.linalg.cholesky(ss_cov_mu_b_walk)

i, j = np.indices((np_data["T"], np_data["T"]))
mask = np.tril(np.ones((np_data["T"], np_data["T"])))
with pm.Model() as m2:
    normal_dist = pm.Normal.dist(mu=0.7, sigma=0.1)
    # BoundedNormalZeroOne = pm.Bound(pm.Normal, lower=0.0, upper=1.0)
    raw_polling_bias = pm.Normal(
        "raw_polling_bias", mu=0.0, sigma=1.0, shape=(shapes["S"])
    )
    raw_mu_b_T = pm.Normal("raw_mu_b_T", mu=0.0, sigma=1.0, shape=(shapes["S"]))
    raw_mu_b = pm.Normal(
        "raw_mu_b", mu=0.0, sigma=1.0, shape=(shapes["S"], shapes["T"])
    )
    raw_mu_c = pm.Normal("raw_mu_c", mu=0.0, sigma=1.0, shape=(shapes["P"]))
    raw_mu_m = pm.Normal("raw_mu_m", mu=0.0, sigma=1.0, shape=(shapes["M"]))
    raw_mu_pop = pm.Normal("raw_mu_pop", mu=0.0, sigma=1.0, shape=(shapes["Pop"]))
    # This has offset/multiplier syntax in Stan, but ignore that for now.
    mu_e_bias = pm.Normal("mu_e_bias", mu=0.0, sigma=0.02)
    # This may be an issue?
    # rho_e_bias = BoundedNormalZeroOne("rho_e_bias", mu=0.7, sigma=0.1)
    rho_e_bias = pm.Bound("rho_e_bias", normal_dist, lower=0., upper=1.)
    raw_e_bias = pm.Normal("raw_e_bias", mu=0.0, sigma=1.0, shape=(shapes["T"]))
    raw_measure_noise_national = pm.Normal(
        "raw_measure_noise_national",
        mu=0.0,
        sigma=1.0,
        shape=(shapes["N_national_polls"]),
    )
    raw_measure_noise_state = pm.Normal(
        "raw_measure_noise_state",
        mu=0.0,
        sigma=1.0,
        shape=(shapes["N_state_polls"]),
    )
    polling_bias = pm.Deterministic(
        "polling_bias",
        aesara.tensor.dot(cholesky_ss_cov_poll_bias, raw_polling_bias),
    )
    national_polling_bias_average = pm.Deterministic(
        "national_polling_bias_average",
        pm.math.sum(polling_bias * np_data["state_weights"]),
    )
    mu_b_final = (
        pm.math.dot(cholesky_ss_cov_mu_b_T, raw_mu_b_T) + np_data["mu_b_prior"]
    )
    # Innovations
    innovs = pm.math.matrix_dot(cholesky_ss_cov_mu_b_walk, raw_mu_b[:, :-1])
    # Reverse these (?)
    innovs = aesara.tensor.transpose(aesara.tensor.transpose(innovs)[::-1])
    # Tack on the "first" one:
    together = pm.math.concatenate(
        [aesara.tensor.reshape(mu_b_final, (-1, 1)), innovs], axis=1
    )
    # Compute the cumulative sums:
    cumsums = aesara.tensor.cumsum(together, axis=1)
    # To be [time, states]
    transposed = aesara.tensor.transpose(cumsums)
    mu_b = pm.Deterministic("mu_b", aesara.tensor.transpose(transposed[::-1]))
    national_mu_b_average = pm.Deterministic(
        "national_mu_b_average",
        pm.math.matrix_dot(
            aesara.tensor.transpose(mu_b), np_data["state_weights"].reshape((-1, 1))
        ),
    )[:, 0]
    mu_c = pm.Deterministic("mu_c", raw_mu_c * np_data["sigma_c"])
    mu_m = pm.Deterministic("mu_m", raw_mu_m * np_data["sigma_m"])
    mu_pop = pm.Deterministic("mu_pop", raw_mu_pop * np_data["sigma_pop"])
    # Matrix version:
    # e_bias_init = raw_e_bias[0] * np_data['sigma_e_bias']
    sigma_rho = pm.math.sqrt(1 - (rho_e_bias ** 2)) * np_data["sigma_e_bias"]
    sigma_vec = pm.math.concatenate(
        [
            [np_data["sigma_e_bias"]],
            aesara.tensor.repeat(sigma_rho, np_data["T"] - 1),
        ]
    )
    mus = pm.math.concatenate(
        [
            [0.0],
            aesara.tensor.repeat(mu_e_bias * (1 - rho_e_bias), shapes["T"] - 1),
        ]
    )
    A_inv = mask * (rho_e_bias ** (np.abs(i - j)))
    e_bias = pm.Deterministic(
        "e_bias", pm.math.matrix_dot(A_inv, sigma_vec * raw_e_bias + mus)
    )
    # Minus-one shenanigans required for the different (0-based) indexing
    logit_pi_democrat_state = (
        mu_b[np_data["state"] - 1, np_data["day_state"] - 1]
        + mu_c[np_data["poll_state"] - 1]
        + mu_m[np_data["poll_mode_state"] - 1]
        + mu_pop[np_data["poll_pop_state"] - 1]
        + np_data["unadjusted_state"] * e_bias[np_data["day_state"] - 1]
        + raw_measure_noise_state * np_data["sigma_measure_noise_state"]
        + polling_bias[np_data["state"] - 1]
    )
    logit_pi_democrat_state = pm.Deterministic(
        "logit_pi_democrat_state", logit_pi_democrat_state
    )
    logit_pi_democrat_national = (
        national_mu_b_average[np_data["day_national"] - 1]
        + mu_c[np_data["poll_national"] - 1]
        + mu_m[np_data["poll_mode_national"] - 1]
        + mu_pop[np_data["poll_pop_national"] - 1]
        + np_data["unadjusted_national"] * e_bias[np_data["day_national"] - 1]
        + raw_measure_noise_national * np_data["sigma_measure_noise_national"]
        + national_polling_bias_average
    )
    logit_pi_democrat_national = pm.Deterministic(
        "logit_pi_democrat_national", logit_pi_democrat_national
    )
    prob_state = aesara.tensor.sigmoid(logit_pi_democrat_state)
    prob_nat = aesara.tensor.sigmoid(logit_pi_democrat_national)
    state_lik = pm.Binomial(
        "state_lik",
        n=np_data["n_two_share_state"],
        p=prob_state,
        observed=np_data["n_democrat_state"],
    )
    national_lik = pm.Binomial(
        "nat_lik",
        n=np_data["n_two_share_national"],
        p=prob_nat,
        observed=np_data["n_democrat_national"],
    )
```
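One aside, in case anyone wonders about the `A_inv` construction: it is meant to be a vectorized form of the AR(1) recursion for `e_bias` from the Stan model. A quick standalone NumPy check (with made-up numbers) of the identity I'm relying on:

```python
import numpy as np

rng = np.random.default_rng(0)
T, rho, sigma, mu = 5, 0.7, 0.3, 0.1
eps = rng.standard_normal(T)

# The AR(1) recursion, written out explicitly.
sigma_rho = np.sqrt(1 - rho ** 2) * sigma
e = np.empty(T)
e[0] = sigma * eps[0]
for t in range(1, T):
    e[t] = rho * e[t - 1] + mu * (1 - rho) + sigma_rho * eps[t]

# The vectorized version, as in the model above.
i, j = np.indices((T, T))
A_inv = np.tril(np.ones((T, T))) * rho ** np.abs(i - j)
sigma_vec = np.concatenate([[sigma], np.repeat(sigma_rho, T - 1)])
mus = np.concatenate([[0.0], np.repeat(mu * (1 - rho), T - 1)])
e_vec = A_inv @ (sigma_vec * eps + mus)

assert np.allclose(e, e_vec)
```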
The data required to run it is here:
exported_stan_data_2016.json.zip
I am using PyMC v4.1.3 and aesara v2.7.7. This may also be the same problem as the one reported on Discourse for another model here.
Thanks a lot for your help!