JAX-ified model throws error with SpecifyShape Op #6050

Closed
@martiningram

Description

Hi all,

I have a model that used to run under PyMC3 and that I am now upgrading to PyMC v4. I'm having a great time with v4, but I've run into a bit of trouble here. The model is a port of the US POTUS model originally written in Stan. Sampling appears to go fine with PyMC's default sampler (it takes a while and I haven't checked the results thoroughly yet, but it at least runs), whereas the JAX sampler errors out. For reference, this is roughly how I'm invoking the two (a sketch; I'm using the default arguments):
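import pymc as pm

with m2 as model:
    # The default PyMC sampler runs (slowly, but it runs):
    idata = pm.sample()

with m2 as model:
    # The numpyro-based JAX sampler fails:
    res = pm.sampling_jax.sample_numpyro_nuts()

The JAX run produces the following error: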

Compiling...
Compilation time =  0:00:08.213230
Sampling...
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Input In [7], in <cell line: 1>()
      1 with m2 as model:
----> 2     res = pm.sampling_jax.sample_numpyro_nuts()

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/sampling_jax.py:514, in sample_numpyro_nuts(draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progress_bar, keep_untransformed, chain_method, postprocessing_backend, idata_kwargs, nuts_kwargs)
    511 if chains > 1:
    512     map_seed = jax.random.split(map_seed, chains)
--> 514 pmap_numpyro.run(
    515     map_seed,
    516     init_params=init_params,
    517     extra_fields=(
    518         "num_steps",
    519         "potential_energy",
    520         "energy",
    521         "adapt_state.step_size",
    522         "accept_prob",
    523         "diverging",
    524     ),
    525 )
    527 raw_mcmc_samples = pmap_numpyro.get_samples(group_by_chain=True)
    529 tic3 = datetime.now()

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/numpyro/infer/mcmc.py:599, in MCMC.run(self, rng_key, extra_fields, init_params, *args, **kwargs)
    597     states, last_state = _laxmap(partial_map_fn, map_args)
    598 elif self.chain_method == "parallel":
--> 599     states, last_state = pmap(partial_map_fn)(map_args)
    600 else:
    601     assert self.chain_method == "vectorized"

    [... skipping hidden 17 frame]

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/numpyro/infer/mcmc.py:381, in MCMC._single_chain_mcmc(self, init, args, kwargs, collect_fields)
    379 rng_key, init_state, init_params = init
    380 if init_state is None:
--> 381     init_state = self.sampler.init(
    382         rng_key,
    383         self.num_warmup,
    384         init_params,
    385         model_args=args,
    386         model_kwargs=kwargs,
    387     )
    388 sample_fn, postprocess_fn = self._get_cached_fns()
    389 diagnostics = (
    390     lambda x: self.sampler.get_diagnostics_str(x[0])
    391     if rng_key.ndim == 1
    392     else ""
    393 )  # noqa: E731

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/numpyro/infer/hmc.py:746, in HMC.init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
    726 hmc_init_fn = lambda init_params, rng_key: self._init_fn(  # noqa: E731
    727     init_params,
    728     num_warmup=num_warmup,
   (...)
    743     rng_key=rng_key,
    744 )
    745 if rng_key.ndim == 1:
--> 746     init_state = hmc_init_fn(init_params, rng_key)
    747 else:
    748     # XXX it is safe to run hmc_init_fn under vmap despite that hmc_init_fn changes some
    749     # nonlocal variables: momentum_generator, wa_update, trajectory_len, max_treedepth,
    750     # wa_steps because those variables do not depend on traced args: init_params, rng_key.
    751     init_state = vmap(hmc_init_fn)(init_params, rng_key)

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/numpyro/infer/hmc.py:726, in HMC.init.<locals>.<lambda>(init_params, rng_key)
    723         dense_mass = [tuple(sorted(z))] if dense_mass else []
    724     assert isinstance(dense_mass, list)
--> 726 hmc_init_fn = lambda init_params, rng_key: self._init_fn(  # noqa: E731
    727     init_params,
    728     num_warmup=num_warmup,
    729     step_size=self._step_size,
    730     num_steps=self._num_steps,
    731     inverse_mass_matrix=inverse_mass_matrix,
    732     adapt_step_size=self._adapt_step_size,
    733     adapt_mass_matrix=self._adapt_mass_matrix,
    734     dense_mass=dense_mass,
    735     target_accept_prob=self._target_accept_prob,
    736     trajectory_length=self._trajectory_length,
    737     max_tree_depth=self._max_tree_depth,
    738     find_heuristic_step_size=self._find_heuristic_step_size,
    739     forward_mode_differentiation=self._forward_mode_differentiation,
    740     regularize_mass_matrix=self._regularize_mass_matrix,
    741     model_args=model_args,
    742     model_kwargs=model_kwargs,
    743     rng_key=rng_key,
    744 )
    745 if rng_key.ndim == 1:
    746     init_state = hmc_init_fn(init_params, rng_key)

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/numpyro/infer/hmc.py:322, in hmc.<locals>.init_kernel(init_params, num_warmup, step_size, inverse_mass_matrix, adapt_step_size, adapt_mass_matrix, dense_mass, target_accept_prob, num_steps, trajectory_length, max_tree_depth, find_heuristic_step_size, forward_mode_differentiation, regularize_mass_matrix, model_args, model_kwargs, rng_key)
    320 r = momentum_generator(z, wa_state.mass_matrix_sqrt, rng_key_momentum)
    321 vv_init, vv_update = velocity_verlet(pe_fn, kinetic_fn, forward_mode_ad)
--> 322 vv_state = vv_init(z, r, potential_energy=pe, z_grad=z_grad)
    323 energy = vv_state.potential_energy + kinetic_fn(
    324     wa_state.inverse_mass_matrix, vv_state.r
    325 )
    326 zero_int = jnp.array(0, dtype=jnp.result_type(int))

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/numpyro/infer/hmc_util.py:278, in velocity_verlet.<locals>.init_fn(z, r, potential_energy, z_grad)
    270 """
    271 :param z: Position of the particle.
    272 :param r: Momentum of the particle.
   (...)
    275 :return: initial state for the integrator.
    276 """
    277 if potential_energy is None or z_grad is None:
--> 278     potential_energy, z_grad = _value_and_grad(
    279         potential_fn, z, forward_mode_differentiation
    280     )
    281 return IntegratorState(z, r, potential_energy, z_grad)

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/numpyro/infer/hmc_util.py:246, in _value_and_grad(f, x, forward_mode_differentiation)
    244     return f(x), jacfwd(f)(x)
    245 else:
--> 246     return value_and_grad(f)(x)

    [... skipping hidden 8 frame]

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/sampling_jax.py:109, in get_jaxified_logp.<locals>.logp_fn_wrap(x)
    108 def logp_fn_wrap(x):
--> 109     return logp_fn(*x)[0]

File /var/folders/s_/l9wd4yls3mv1f4kkkz1xhhmm0000gn/T/tmpcdl0ef2w:94, in jax_funcified_fgraph(raw_polling_bias, raw_mu_b_T, raw_mu_b, raw_mu_c, raw_mu_m, raw_mu_pop, mu_e_bias, rho_e_bias_interval_, raw_e_bias, raw_measure_noise_national, raw_measure_noise_state)
     92 auto_114085 = equal(auto_114087, auto_116679)
     93 # Reshape{1}(Elemwise{Composite{(i0 * sqrt((i1 - sqr(i2))))}}.0, TensorConstant{(1,) of -1})
---> 94 auto_113078 = reshape(auto_122013, auto_19078)
     95 # Elemwise{Composite{(i0 * (i1 ** i2))}}(TensorConstant{[[1. 0. 0...1. 1. 1.]]}, InplaceDimShuffle{x,x}.0, TensorConstant{[[  0   1 ..   1   0]]})
     96 auto_121883 = composite12(auto_19117, auto_113086, auto_19111)

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/link/jax/dispatch.py:733, in jax_funcify_Reshape.<locals>.reshape(x, shape)
    732 def reshape(x, shape):
--> 733     return jnp.reshape(x, shape)

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:741, in reshape(a, newshape, order)
    739 _stackable(a) or _check_arraylike("reshape", a)
    740 try:
--> 741   return a.reshape(newshape, order=order)  # forward to method for ndarrays
    742 except AttributeError:
    743   return _reshape(a, newshape, order=order)

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:759, in _reshape(a, order, *args)
    758 def _reshape(a, *args, order="C"):
--> 759   newshape = _compute_newshape(a, args[0] if len(args) == 1 else args)
    760   if order == "C":
    761     return lax.reshape(a, newshape, None)

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:752, in _compute_newshape(a, newshape)
    750 except: iterable = False
    751 else: iterable = True
--> 752 newshape = core.canonicalize_shape(newshape if iterable else [newshape])
    753 return tuple(- core.divide_shape_sizes(np.shape(a), newshape)
    754              if core.symbolic_equal_dim(d, -1) else d
    755              for d in newshape)

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/jax/core.py:1790, in canonicalize_shape(shape, context)
   1788 except TypeError:
   1789   pass
-> 1790 raise _invalid_shape_error(shape, context)

TypeError: Shapes must be 1D sequences of concrete values of integer type, got [-1].
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.

I've tried to debug this, but I'm a bit stumped! One thing I notice: in the jaxified graph in the traceback, the failing Reshape takes TensorConstant{(1,) of -1} as its target shape, and its input Elemwise{Composite{(i0 * sqrt((i1 - sqr(i2))))}} looks like the sigma_rho expression in my model. My (unverified) assumption is that the underlying JAX behaviour is the one sketched below: reshape is fine under jit when the target shape is a static Python value, but fails when the shape arrives as a traced value.
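import jax
import jax.numpy as jnp

x = jnp.arange(6.0)

# Static target shape: works under jit.
jax.jit(lambda a: jnp.reshape(a, (-1,)))(x)

# Target shape passed in as a traced array: raises the same
# "Shapes must be 1D sequences of concrete values of integer type" TypeError.
jax.jit(lambda a, s: jnp.reshape(a, s))(x, jnp.array([-1]))

Here's the full model code: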

import json
import numpy as np
import pymc as pm
import aesara

# Load the exported Stan data and squeeze singleton dimensions:
with open('./exported_stan_data_2016.json') as f:
    data = json.load(f)

np_data = {x: np.squeeze(np.array(y)) for x, y in data.items()}

shapes = {
    'N_national_polls': int(np_data['N_national_polls']),
    'N_state_polls': int(np_data['N_state_polls']),
    'T': int(np_data['T']),
    'S': int(np_data['S']),
    'P': int(np_data['P']),
    'M': int(np_data['M']),
    'Pop': int(np_data['Pop'])
}

# Fixed covariance scalings, precomputed with NumPy outside the model:
national_cov_matrix_error_sd = np.sqrt(
    np.squeeze(
        np_data["state_weights"].reshape(1, -1)
        @ (np_data["state_covariance_0"] @ np_data["state_weights"].reshape(-1, 1))
    )
)

ss_cov_poll_bias = (
    np_data["state_covariance_0"]
    * (np_data["polling_bias_scale"] / national_cov_matrix_error_sd) ** 2
)

ss_cov_mu_b_T = (
    np_data["state_covariance_0"]
    * (np_data["mu_b_T_scale"] / national_cov_matrix_error_sd) ** 2
)

ss_cov_mu_b_walk = (
    np_data["state_covariance_0"]
    * (np_data["random_walk_scale"] / national_cov_matrix_error_sd) ** 2
)

cholesky_ss_cov_poll_bias = np.linalg.cholesky(ss_cov_poll_bias)
cholesky_ss_cov_mu_b_T = np.linalg.cholesky(ss_cov_mu_b_T)
cholesky_ss_cov_mu_b_walk = np.linalg.cholesky(ss_cov_mu_b_walk)

# Index grids and lower-triangular mask for the AR(1)-style matrix A_inv below:
i, j = np.indices((np_data["T"], np_data["T"]))
mask = np.tril(np.ones((np_data["T"], np_data["T"])))

with pm.Model() as m2:
    
    # Base distribution for the bounded rho_e_bias below:
    normal_dist = pm.Normal.dist(mu=0.7, sigma=0.1)

    #BoundedNormalZeroOne = pm.Bound(pm.Normal, lower=0.0, upper=1.0)
    raw_polling_bias = pm.Normal(
        "raw_polling_bias", mu=0.0, sigma=1.0, shape=(shapes["S"])
    )

    raw_mu_b_T = pm.Normal("raw_mu_b_T", mu=0.0, sigma=1.0, shape=(shapes["S"]))
    raw_mu_b = pm.Normal(
        "raw_mu_b", mu=0.0, sigma=1.0, shape=(shapes["S"], shapes["T"])
    )
    raw_mu_c = pm.Normal("raw_mu_c", mu=0.0, sigma=1.0, shape=(shapes["P"]))
    raw_mu_m = pm.Normal("raw_mu_m", mu=0.0, sigma=1.0, shape=(shapes["M"]))
    raw_mu_pop = pm.Normal("raw_mu_pop", mu=0.0, sigma=1.0, shape=(shapes["Pop"]))

    # This has offset multiplier syntax in Stan, but ignore for now.
    mu_e_bias = pm.Normal("mu_e_bias", mu=0.0, sigma=0.02)

    # This may be an issue?
    # rho_e_bias = BoundedNormalZeroOne("rho_e_bias", mu=0.7, sigma=0.1)
    rho_e_bias = pm.Bound("rho_e_bias", normal_dist, lower=0., upper=1.)
    
    raw_e_bias = pm.Normal("raw_e_bias", mu=0.0, sigma=1.0, shape=(shapes["T"]))
    raw_measure_noise_national = pm.Normal(
        "raw_measure_noise_national",
        mu=0.0,
        sigma=1.0,
        shape=(shapes["N_national_polls"]),
    )
    raw_measure_noise_state = pm.Normal(
        "raw_measure_noise_state",
        mu=0.0,
        sigma=1.0,
        shape=(shapes["N_state_polls"]),
    )

    polling_bias = pm.Deterministic(
        "polling_bias",
        aesara.tensor.dot(cholesky_ss_cov_poll_bias, raw_polling_bias),
    )
    national_polling_bias_average = pm.Deterministic(
        "national_polling_bias_average",
        pm.math.sum(polling_bias * np_data["state_weights"]),
    )

    mu_b_final = (
        pm.math.dot(cholesky_ss_cov_mu_b_T, raw_mu_b_T) + np_data["mu_b_prior"]
    )

    # Innovations
    innovs = pm.math.matrix_dot(cholesky_ss_cov_mu_b_walk, raw_mu_b[:, :-1])

    # Reverse these (?)
    innovs = aesara.tensor.transpose(aesara.tensor.transpose(innovs)[::-1])

    # Tack on the "first" one:
    together = pm.math.concatenate(
        [aesara.tensor.reshape(mu_b_final, (-1, 1)), innovs], axis=1
    )

    # Compute the cumulative sums:
    cumsums = aesara.tensor.cumsum(together, axis=1)

    # To be [time, states]
    transposed = aesara.tensor.transpose(cumsums)

    mu_b = pm.Deterministic("mu_b", aesara.tensor.transpose(transposed[::-1]))

    national_mu_b_average = pm.Deterministic(
        "national_mu_b_average",
        pm.math.matrix_dot(
            aesara.tensor.transpose(mu_b), np_data["state_weights"].reshape((-1, 1))
        ),
    )[:, 0]

    mu_c = pm.Deterministic("mu_c", raw_mu_c * np_data["sigma_c"])
    mu_m = pm.Deterministic("mu_m", raw_mu_m * np_data["sigma_m"])
    mu_pop = pm.Deterministic("mu_pop", raw_mu_pop * np_data["sigma_pop"])

    # Matrix version:

    # e_bias_init = raw_e_bias[0] * np_data['sigma_e_bias']
    # NB: this looks like the Elemwise{Composite{(i0 * sqrt((i1 - sqr(i2))))}}
    # feeding the failing Reshape in the traceback above.
    sigma_rho = pm.math.sqrt(1 - (rho_e_bias ** 2)) * np_data["sigma_e_bias"]
    sigma_vec = pm.math.concatenate(
        [
            [np_data["sigma_e_bias"]],
            aesara.tensor.repeat(sigma_rho, np_data["T"] - 1),
        ]
    )
    mus = pm.math.concatenate(
        [
            [0.0],
            aesara.tensor.repeat(mu_e_bias * (1 - rho_e_bias), shapes["T"] - 1),
        ]
    )

    A_inv = mask * (rho_e_bias ** (np.abs(i - j)))

    e_bias = pm.Deterministic(
        "e_bias", pm.math.matrix_dot(A_inv, sigma_vec * raw_e_bias + mus)
    )

    # The minus ones convert Stan's 1-based indices to Python's 0-based indexing
    logit_pi_democrat_state = (
        mu_b[np_data["state"] - 1, np_data["day_state"] - 1]
        + mu_c[np_data["poll_state"] - 1]
        + mu_m[np_data["poll_mode_state"] - 1]
        + mu_pop[np_data["poll_pop_state"] - 1]
        + np_data["unadjusted_state"] * e_bias[np_data["day_state"] - 1]
        + raw_measure_noise_state * np_data["sigma_measure_noise_state"]
        + polling_bias[np_data["state"] - 1]
    )

    logit_pi_democrat_state = pm.Deterministic(
        "logit_pi_democrat_state", logit_pi_democrat_state
    )

    logit_pi_democrat_national = (
        national_mu_b_average[np_data["day_national"] - 1]
        + mu_c[np_data["poll_national"] - 1]
        + mu_m[np_data["poll_mode_national"] - 1]
        + mu_pop[np_data["poll_pop_national"] - 1]
        + np_data["unadjusted_national"] * e_bias[np_data["day_national"] - 1]
        + raw_measure_noise_national * np_data["sigma_measure_noise_national"]
        + national_polling_bias_average
    )

    logit_pi_democrat_national = pm.Deterministic(
        "logit_pi_democrat_national", logit_pi_democrat_national
    )

    prob_state = aesara.tensor.sigmoid(logit_pi_democrat_state)
    prob_nat = aesara.tensor.sigmoid(logit_pi_democrat_national)

    state_lik = pm.Binomial(
        "state_lik",
        n=np_data["n_two_share_state"],
        p=prob_state,
        observed=np_data["n_democrat_state"],
    )
    national_lik = pm.Binomial(
        "nat_lik",
        n=np_data["n_two_share_national"],
        p=prob_nat,
        observed=np_data["n_democrat_national"],
    )

The data required to run the model is here:
exported_stan_data_2016.json.zip

I am using PyMC v4.1.3 and Aesara v2.7.7. This may be the same problem as reported on Discourse for another model here.

Thanks a lot for your help!
