Skip to content

Commit 13e6ed1

Browse files
committed
Extend docstring and pass posterior after partial
1 parent 221aa00 commit 13e6ed1

File tree

1 file changed

+12
-8
lines changed

1 file changed

+12
-8
lines changed

pymc/sampling_jax.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,9 @@ def sample_blackjax_nuts(
254254
idata_kwargs : dict, optional
255255
Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as value
256256
for the ``log_likelihood`` key to indicate that the pointwise log likelihood should
257-
not be included in the returned object.
257+
not be included in the returned object. Values for ``observed_data``, ``constant_data``,
258+
``coords``, and ``dims`` are inferred from the ``model`` argument if not provided
259+
in ``idata_kwargs``.
258260
259261
Returns
260262
-------
@@ -366,16 +368,16 @@ def sample_blackjax_nuts(
366368

367369
posterior = mcmc_samples
368370
# Use 'partial' to set default arguments before passing 'idata_kwargs'
369-
az_trace = partial(
371+
to_trace = partial(
370372
az.from_dict,
371-
posterior=posterior,
372373
log_likelihood=log_likelihood,
373374
observed_data=find_observations(model),
374375
constant_data=find_constants(model),
375376
coords=coords,
376377
dims=dims,
377378
attrs=make_attrs(attrs, library=blackjax),
378-
)(**idata_kwargs)
379+
)
380+
az_trace = to_trace(posterior=posterior, **idata_kwargs)
379381

380382
return az_trace
381383

@@ -432,7 +434,9 @@ def sample_numpyro_nuts(
432434
idata_kwargs : dict, optional
433435
Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as value
434436
for the ``log_likelihood`` key to indicate that the pointwise log likelihood should
435-
not be included in the returned object.
437+
not be included in the returned object. Values for ``observed_data``, ``constant_data``,
438+
``coords``, and ``dims`` are inferred from the ``model`` argument if not provided
439+
in ``idata_kwargs``.
436440
nuts_kwargs: dict, optional
437441
Keyword arguments for :func:`numpyro.infer.NUTS`.
438442
@@ -562,16 +566,16 @@ def sample_numpyro_nuts(
562566

563567
posterior = mcmc_samples
564568
# Use 'partial' to set default arguments before passing 'idata_kwargs'
565-
az_trace = partial(
569+
to_trace = partial(
566570
az.from_dict,
567-
posterior=posterior,
568571
log_likelihood=log_likelihood,
569572
observed_data=find_observations(model),
570573
constant_data=find_constants(model),
571574
sample_stats=_sample_stats_to_xarray(pmap_numpyro),
572575
coords=coords,
573576
dims=dims,
574577
attrs=make_attrs(attrs, library=numpyro),
575-
)(**idata_kwargs)
578+
)
579+
az_trace = to_trace(posterior=posterior, **idata_kwargs)
576580

577581
return az_trace

0 commit comments

Comments
 (0)