@@ -342,6 +342,7 @@ def sample_blackjax_nuts(
342
342
initvals : Optional [Union [StartDict , Sequence [Optional [StartDict ]]]] = None ,
343
343
model : Optional [Model ] = None ,
344
344
var_names : Optional [Sequence [str ]] = None ,
345
+ progress_bar : bool = False ,
345
346
keep_untransformed : bool = False ,
346
347
chain_method : str = "parallel" ,
347
348
postprocessing_backend : Optional [Literal ["cpu" , "gpu" ]] = None ,
@@ -447,21 +448,22 @@ def sample_blackjax_nuts(
447
448
# Adapted from numpyro
448
449
if chain_method == "parallel" :
449
450
map_fn = jax .pmap
450
- if adaptation_kwargs . get ( " progress_bar" , False ) :
451
+ if progress_bar :
451
452
import warnings
452
453
453
454
warnings .warn (
454
- "BlackJax currently only display progress_bar correctly under "
455
- "`chain_method == 'vectorized'`. Setting `progress_bar =False`."
455
+ "BlackJax currently only display progress bar correctly under "
456
+ "`chain_method == 'vectorized'`. Setting `progressbar =False`."
456
457
)
457
- adaptation_kwargs [ " progress_bar" ] = False
458
+ progress_bar = False
458
459
elif chain_method == "vectorized" :
459
460
map_fn = jax .vmap
460
461
else :
461
462
raise ValueError (
462
463
"Only supporting the following methods to draw chains:" ' "parallel" or "vectorized"'
463
464
)
464
465
466
+ adaptation_kwargs ["progress_bar" ] = progress_bar
465
467
get_posterior_samples = partial (
466
468
_blackjax_inference_loop ,
467
469
logprob_fn = logprob_fn ,
0 commit comments