Description
PyMC users are used to getting an immediate report of divergences when sampling with PyMC's default NUTS sampler. I suggest we print (log) the same information for users when we call the JAX samplers on their behalf.
This can be especially puzzling when there are NaNs in the gradient: the whole sampling run finishes in an instant without a single warning, even though 100% of the transitions diverged.
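A minimal sketch of what such a report could look like, assuming the divergence flags are available as a boolean array of shape `(chains, draws)`. The helper name `log_divergences`, the logger name, and the message wording are hypothetical (the wording is only loosely modeled on PyMC's usual warning), not existing PyMC API:

```python
import logging

import numpy as np

# Hypothetical logger name, chosen for illustration only.
logger = logging.getLogger("pymc.sampling.jax")


def log_divergences(diverging, logger=logger):
    """Warn about divergent transitions in a (chains, draws) boolean array.

    Hypothetical helper, not part of PyMC. Returns divergences per chain.
    """
    diverging = np.asarray(diverging, dtype=bool)
    if diverging.ndim == 1:
        # Treat a flat array as a single chain.
        diverging = diverging[np.newaxis, :]
    per_chain = diverging.sum(axis=-1)
    total = int(per_chain.sum())
    n_draws = diverging.size
    if total > 0:
        logger.warning(
            "There were %d divergences after tuning (%.1f%% of %d draws). "
            "Increase `target_accept` or reparameterize.",
            total,
            100.0 * total / n_draws,
            n_draws,
        )
        if total == n_draws:
            # The silent failure mode from this issue: every draw diverged,
            # which often points at NaN/inf in the logp or its gradient.
            logger.warning(
                "Every draw diverged; check the model for NaN/inf values "
                "in the log-density or its gradient."
            )
    return per_chain
```

As a usage example, after something like `idata = pm.sample(nuts_sampler="numpyro")` one could call `log_divergences(idata.sample_stats["diverging"].values)`, assuming the returned `InferenceData` stores a `diverging` sample stat for the JAX samplers. The extra "every draw diverged" warning targets the NaN-gradient case above, where a plain count is easy to miss.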
Recent user question: https://discourse.pymc.io/t/sampling-with-blackjax-no-divergence-report/13394