@@ -254,7 +254,9 @@ def sample_blackjax_nuts(
254
254
idata_kwargs : dict, optional
255
255
Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as value
256
256
for the ``log_likelihood`` key to indicate that the pointwise log likelihood should
257
- not be included in the returned object.
257
+ not be included in the returned object. Values for ``observed_data``, ``constant_data``,
258
+ ``coords``, and ``dims`` are inferred from the ``model`` argument if not provided
259
+ in ``idata_kwargs``.
258
260
259
261
Returns
260
262
-------
@@ -366,16 +368,16 @@ def sample_blackjax_nuts(
366
368
367
369
posterior = mcmc_samples
368
370
# Use 'partial' to set default arguments before passing 'idata_kwargs'
369
- az_trace = partial (
371
+ to_trace = partial (
370
372
az .from_dict ,
371
- posterior = posterior ,
372
373
log_likelihood = log_likelihood ,
373
374
observed_data = find_observations (model ),
374
375
constant_data = find_constants (model ),
375
376
coords = coords ,
376
377
dims = dims ,
377
378
attrs = make_attrs (attrs , library = blackjax ),
378
- )(** idata_kwargs )
379
+ )
380
+ az_trace = to_trace (posterior = posterior , ** idata_kwargs )
379
381
380
382
return az_trace
381
383
@@ -432,7 +434,9 @@ def sample_numpyro_nuts(
432
434
idata_kwargs : dict, optional
433
435
Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as value
434
436
for the ``log_likelihood`` key to indicate that the pointwise log likelihood should
435
- not be included in the returned object.
437
+ not be included in the returned object. Values for ``observed_data``, ``constant_data``,
438
+ ``coords``, and ``dims`` are inferred from the ``model`` argument if not provided
439
+ in ``idata_kwargs``.
436
440
nuts_kwargs: dict, optional
437
441
Keyword arguments for :func:`numpyro.infer.NUTS`.
438
442
@@ -562,16 +566,16 @@ def sample_numpyro_nuts(
562
566
563
567
posterior = mcmc_samples
564
568
# Use 'partial' to set default arguments before passing 'idata_kwargs'
565
- az_trace = partial (
569
+ to_trace = partial (
566
570
az .from_dict ,
567
- posterior = posterior ,
568
571
log_likelihood = log_likelihood ,
569
572
observed_data = find_observations (model ),
570
573
constant_data = find_constants (model ),
571
574
sample_stats = _sample_stats_to_xarray (pmap_numpyro ),
572
575
coords = coords ,
573
576
dims = dims ,
574
577
attrs = make_attrs (attrs , library = numpyro ),
575
- )(** idata_kwargs )
578
+ )
579
+ az_trace = to_trace (posterior = posterior , ** idata_kwargs )
576
580
577
581
return az_trace
0 commit comments