Skip to content

Commit 826cbe8

Browse files
committed
use standard processbar kwarg
1 parent 297b376 commit 826cbe8

File tree

2 files changed

+7
-4
lines changed

2 files changed

+7
-4
lines changed

pymc/sampling/jax.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,7 @@ def sample_blackjax_nuts(
342342
initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]] = None,
343343
model: Optional[Model] = None,
344344
var_names: Optional[Sequence[str]] = None,
345+
progress_bar: bool = False,
345346
keep_untransformed: bool = False,
346347
chain_method: str = "parallel",
347348
postprocessing_backend: Optional[Literal["cpu", "gpu"]] = None,
@@ -447,21 +448,22 @@ def sample_blackjax_nuts(
447448
# Adapted from numpyro
448449
if chain_method == "parallel":
449450
map_fn = jax.pmap
450-
if adaptation_kwargs.get("progress_bar", False):
451+
if progress_bar:
451452
import warnings
452453

453454
warnings.warn(
454-
"BlackJax currently only display progress_bar correctly under "
455-
"`chain_method == 'vectorized'`. Setting `progress_bar=False`."
455+
"BlackJax currently only display progress bar correctly under "
456+
"`chain_method == 'vectorized'`. Setting `progressbar=False`."
456457
)
457-
adaptation_kwargs["progress_bar"] = False
458+
progress_bar = False
458459
elif chain_method == "vectorized":
459460
map_fn = jax.vmap
460461
else:
461462
raise ValueError(
462463
"Only supporting the following methods to draw chains:" ' "parallel" or "vectorized"'
463464
)
464465

466+
adaptation_kwargs["progress_bar"] = progress_bar
465467
get_posterior_samples = partial(
466468
_blackjax_inference_loop,
467469
logprob_fn=logprob_fn,

pymc/sampling/mcmc.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,7 @@ def _sample_external_nuts(
372372
random_seed=random_seed,
373373
initvals=initvals,
374374
model=model,
375+
progress_bar=progressbar,
375376
idata_kwargs=idata_kwargs,
376377
**nuts_sampler_kwargs,
377378
)

0 commit comments

Comments
 (0)