Closed
Description
Describe the issue:
Per the discussion here, SMC does not respect a custom initval set on a free RV.
The problem shows up when sampling from a distribution with the Ordered()
transform, where the initval must be ordered: otherwise the logp at the
initial point is NaN and the model fails to sample.
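To make the ordering requirement concrete, here is a minimal NumPy sketch of a log-difference ordered transform. It assumes the forward map keeps the first element and takes the log of consecutive differences, which is my understanding of how an ordered transform is usually built; `ordered_forward` and `ordered_backward` are hypothetical names for illustration, not PyMC functions.

```python
import numpy as np


def ordered_forward(x):
    """Map a constrained (increasing) vector to unconstrained space."""
    x = np.asarray(x, dtype=float)
    y = np.empty_like(x)
    y[0] = x[0]
    # log of consecutive differences; a non-positive difference is not a valid input
    with np.errstate(invalid="ignore", divide="ignore"):
        y[1:] = np.log(np.diff(x))
    return y


def ordered_backward(y):
    """Map an unconstrained vector back to an increasing one."""
    y = np.asarray(y, dtype=float)
    steps = np.empty_like(y)
    steps[0] = y[0]
    steps[1:] = np.exp(y[1:])  # strictly positive increments
    return np.cumsum(steps)


# An ordered initval maps to finite unconstrained values and round-trips.
ordered = ordered_forward([-1.0, 0.0, 1.0])
roundtrip = ordered_backward(ordered)

# A decreasing point has negative differences, so the log yields NaN.
unordered = ordered_forward([1.0, 0.0, -1.0])
```

Under this assumption, any starting point that is not strictly increasing produces NaN in the unconstrained space, so every quantity computed from it, including the logp, is NaN as well.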
Reproducible code example:
import pymc as pm
from pymc.distributions.transforms import Ordered

with pm.Model() as model:
    a = pm.Normal("a", mu=0.0, sigma=1.0, size=(3,), transform=Ordered(), initval=[-1.0, 0.0, 1.0])
    b = pm.Normal("b", mu=a, sigma=1.0, observed=[0.0, 0.0, 0.0])

with model:
    trace = pm.sample_smc()
Error message:
<details>
---------------------------------------------------------------------------
_RemoteTraceback Traceback (most recent call last)
_RemoteTraceback:
"""
Traceback (most recent call last):
File "/home/twenger/miniconda3/envs/pymc/lib/python3.11/concurrent/futures/process.py", line 256, in _process_worker
r = call_item.fn(*call_item.args, **call_item.kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/twenger/miniconda3/envs/pymc/lib/python3.11/site-packages/pymc/smc/sampling.py", line 346, in _sample_smc_int
smc.update_beta_and_weights()
File "/home/twenger/miniconda3/envs/pymc/lib/python3.11/site-packages/pymc/smc/kernels.py", line 273, in update_beta_and_weights
ESS = int(np.exp(-logsumexp(log_weights * 2)))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: cannot convert float NaN to integer
"""
The above exception was the direct cause of the following exception:
ValueError Traceback (most recent call last)
Cell In[1], line 9
6 b = pm.Normal("b", mu=a, sigma=1.0, observed=[0.0, 0.0, 0.0])
8 with model:
----> 9 trace = pm.sample_smc()
File ~/miniconda3/envs/pymc/lib/python3.11/site-packages/pymc/smc/sampling.py:217, in sample_smc(draws, kernel, start, model, random_seed, chains, cores, compute_convergence_checks, return_inferencedata, idata_kwargs, progressbar, **kernel_kwargs)
208 params = (
209 draws,
210 kernel,
211 start,
212 model,
213 )
215 t1 = time.time()
--> 217 results = run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores)
219 (
220 traces,
221 sample_stats,
222 sample_settings,
223 ) = zip(*results)
225 trace = MultiTrace(traces)
File ~/miniconda3/envs/pymc/lib/python3.11/site-packages/pymc/smc/sampling.py:422, in run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores)
415 # update the progress bar for this task:
416 progress.update(
417 status=f"Stage: {stage} Beta: {beta:.3f}",
418 task_id=task_id,
419 refresh=True,
420 )
--> 422 return tuple(cloudpickle.loads(r.result()) for r in done)
File ~/miniconda3/envs/pymc/lib/python3.11/site-packages/pymc/smc/sampling.py:422, in <genexpr>(.0)
415 # update the progress bar for this task:
416 progress.update(
417 status=f"Stage: {stage} Beta: {beta:.3f}",
418 task_id=task_id,
419 refresh=True,
420 )
--> 422 return tuple(cloudpickle.loads(r.result()) for r in done)
File ~/miniconda3/envs/pymc/lib/python3.11/concurrent/futures/_base.py:449, in Future.result(self, timeout)
447 raise CancelledError()
448 elif self._state == FINISHED:
--> 449 return self.__get_result()
451 self._condition.wait(timeout)
453 if self._state in [CANCELLED, CANCELLED_AND_NOTIFIED]:
File ~/miniconda3/envs/pymc/lib/python3.11/concurrent/futures/_base.py:401, in Future.__get_result(self)
399 if self._exception:
400 try:
--> 401 raise self._exception
402 finally:
403 # Break a reference cycle with the exception in self._exception
404 self = None
ValueError: cannot convert float NaN to integer
</details>
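The failing line in the traceback is the ESS computation, `int(np.exp(-logsumexp(log_weights * 2)))`. The sketch below reproduces just that expression in isolation to show how a single NaN log-weight surfaces as this ValueError. It uses a hand-written `logsumexp` so that it needs only NumPy, and `effective_sample_size` is a hypothetical helper name, not part of PyMC.

```python
import numpy as np


def logsumexp(a):
    """Numerically stable log(sum(exp(a))) for a 1-D array."""
    a = np.asarray(a, dtype=float)
    a_max = np.max(a)  # NaN propagates through max
    return a_max + np.log(np.sum(np.exp(a - a_max)))


def effective_sample_size(log_weights):
    """Same expression as the line shown in the traceback."""
    log_weights = np.asarray(log_weights, dtype=float)
    return int(np.exp(-logsumexp(log_weights * 2)))


# Finite, normalized log-weights give an integer ESS.
ess_ok = effective_sample_size(np.log(np.full(4, 0.25)))

# One NaN log-weight (e.g. from a NaN logp) makes the whole sum NaN,
# and int(NaN) raises the error reported above.
try:
    effective_sample_size([np.nan, -1.0, -1.0])
    failed = False
    message = ""
except ValueError as err:
    failed = True
    message = str(err)
```

This suggests the ValueError is only a symptom: the NaN most likely originates earlier, in the logp of the particles, and is carried into the weights.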
PyMC version information:
pymc 5.16.2 (conda)
Context for the issue:
No response