BUG: SMC does not respect initval #7438

Closed
@tvwenger

Describe the issue:

Per the discussion here, SMC does not properly handle a custom initval for free RVs. This is apparent when sampling from distributions with the Ordered() transform, where the initval must be ordered; otherwise the model fails to sample because of a NaN logp at the initial point.
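To illustrate why an unordered starting point is fatal, here is a minimal pure-Python sketch of the forward mapping an ordered transform applies, assuming it keeps the first element and takes the log of each successive gap. The names `ordered_forward` and `_log` are hypothetical helpers for this illustration, not PyMC functions, and `_log` returns NaN for negative input (as array libraries do) instead of raising.

```python
import math


def _log(v):
    # Hypothetical helper: follow IEEE-style semantics (NaN for negative
    # input, -inf for zero) rather than raising like math.log does.
    if v > 0:
        return math.log(v)
    return -math.inf if v == 0 else math.nan


def ordered_forward(x):
    # Assumed form of the mapping to unconstrained space:
    # y[0] = x[0], y[i] = log(x[i] - x[i-1]) for i >= 1.
    return [x[0]] + [_log(b - a) for a, b in zip(x, x[1:])]


# The initval from this issue is ordered, so every gap is positive.
ordered = ordered_forward([-1.0, 0.0, 1.0])

# An unordered point (e.g. an arbitrary draw) has a negative gap.
unordered = ordered_forward([0.3, -0.5, 1.2])

print(ordered)
print(unordered)
```

Under this assumption, any starting point that ignores the ordered initval can contain a NaN in the unconstrained space, and that NaN then propagates into the logp.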

Reproducible code example:

import pymc as pm
from pymc.distributions.transforms import Ordered

with pm.Model() as model:
    a = pm.Normal("a", mu=0.0, sigma=1.0, size=(3,), transform=Ordered(), initval=[-1.0, 0.0, 1.0])
    b = pm.Normal("b", mu=a, sigma=1.0, observed=[0.0, 0.0, 0.0])

with model:
    trace = pm.sample_smc()

Error message:

<details>
---------------------------------------------------------------------------
_RemoteTraceback                          Traceback (most recent call last)
_RemoteTraceback:
"""
Traceback (most recent call last):
  File "/home/twenger/miniconda3/envs/pymc/lib/python3.11/concurrent/futures/process.py", line 256, in _process_worker
    r = call_item.fn(*call_item.args, **call_item.kwargs)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/twenger/miniconda3/envs/pymc/lib/python3.11/site-packages/pymc/smc/sampling.py", line 346, in _sample_smc_int
    smc.update_beta_and_weights()
  File "/home/twenger/miniconda3/envs/pymc/lib/python3.11/site-packages/pymc/smc/kernels.py", line 273, in update_beta_and_weights
    ESS = int(np.exp(-logsumexp(log_weights * 2)))
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: cannot convert float NaN to integer
"""

The above exception was the direct cause of the following exception:

ValueError                                Traceback (most recent call last)
Cell In[1], line 9
      6     b = pm.Normal("b", mu=a, sigma=1.0, observed=[0.0, 0.0, 0.0])
      8 with model:
----> 9     trace = pm.sample_smc()

File ~/miniconda3/envs/pymc/lib/python3.11/site-packages/pymc/smc/sampling.py:217, in sample_smc(draws, kernel, start, model, random_seed, chains, cores, compute_convergence_checks, return_inferencedata, idata_kwargs, progressbar, **kernel_kwargs)
    208 params = (
    209     draws,
    210     kernel,
    211     start,
    212     model,
    213 )
    215 t1 = time.time()
--> 217 results = run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores)
    219 (
    220     traces,
    221     sample_stats,
    222     sample_settings,
    223 ) = zip(*results)
    225 trace = MultiTrace(traces)

File ~/miniconda3/envs/pymc/lib/python3.11/site-packages/pymc/smc/sampling.py:422, in run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores)
    415                 # update the progress bar for this task:
    416                 progress.update(
    417                     status=f"Stage: {stage} Beta: {beta:.3f}",
    418                     task_id=task_id,
    419                     refresh=True,
    420                 )
--> 422 return tuple(cloudpickle.loads(r.result()) for r in done)

File ~/miniconda3/envs/pymc/lib/python3.11/site-packages/pymc/smc/sampling.py:422, in <genexpr>(.0)
    415                 # update the progress bar for this task:
    416                 progress.update(
    417                     status=f"Stage: {stage} Beta: {beta:.3f}",
    418                     task_id=task_id,
    419                     refresh=True,
    420                 )
--> 422 return tuple(cloudpickle.loads(r.result()) for r in done)

File ~/miniconda3/envs/pymc/lib/python3.11/concurrent/futures/_base.py:449, in Future.result(self, timeout)
    447     raise CancelledError()
    448 elif self._state == FINISHED:
--> 449     return self.__get_result()
    451 self._condition.wait(timeout)
    453 if self._state in [CANCELLED, CANCELLED_AND_NOTIFIED]:

File ~/miniconda3/envs/pymc/lib/python3.11/concurrent/futures/_base.py:401, in Future.__get_result(self)
    399 if self._exception:
    400     try:
--> 401         raise self._exception
    402     finally:
    403         # Break a reference cycle with the exception in self._exception
    404         self = None

ValueError: cannot convert float NaN to integer
</details>
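The final `ValueError` is only a symptom: the ESS line in the traceback converts a float to `int`, and that conversion fails once any log weight is NaN. The sketch below mirrors that expression in pure Python. The `logsumexp` here is a simplified stand-in written for this illustration (not the function PyMC imports), and `effective_sample_size` is a hypothetical name.

```python
import math


def logsumexp(values):
    # Simplified stand-in for a library logsumexp; NaN inputs propagate.
    m = max(values)
    if math.isinf(m):
        return m
    return m + math.log(sum(math.exp(v - m) for v in values))


def effective_sample_size(log_weights):
    # Same expression as the line shown in the traceback:
    # ESS = int(exp(-logsumexp(2 * log_weights)))
    return int(math.exp(-logsumexp([2.0 * w for w in log_weights])))


# One particle with a NaN log weight is enough to break the conversion.
try:
    effective_sample_size([math.log(0.5), math.nan])
    error = None
except ValueError as exc:
    error = str(exc)

print(error)
```

So the place to look is wherever the NaN log weights originate (the initial points), not the ESS computation itself.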

PyMC version information:

pymc 5.16.2 (conda)

Context for the issue:

No response
