
Commit 0923d25

Return InferenceData by default
Also removes some unnecessary XFAIL marks. Closes #4372, #4740.

Co-authored-by: Oriol Abril <[email protected]>

Parent: 660b95b

24 files changed: 225 additions & 231 deletions
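
In practice the change looks like this; the model below is an illustrative sketch, not code from this commit:

    import arviz as az
    import numpy as np
    import pymc3 as pm

    with pm.Model():
        mu = pm.Normal("mu", 0.0, 1.0)
        pm.Normal("y", mu=mu, sigma=1.0, observed=np.random.randn(100))

        # New default: pm.sample returns an arviz.InferenceData with a `posterior` group
        idata = pm.sample(1000, tune=1000)

        # The previous MultiTrace behavior is still available on request
        mtrace = pm.sample(1000, tune=1000, return_inferencedata=False)

    az.summary(idata, kind="stats")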

RELEASE-NOTES.md

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@
 - ArviZ `plots` and `stats` *wrappers* were removed. The functions are now just available by their original names (see [#4549](https://github.com/pymc-devs/pymc3/pull/4471) and `3.11.2` release notes).
 - The GLM submodule has been removed, please use [Bambi](https://bambinos.github.io/bambi/) instead.
 - The `Distribution` keyword argument `testval` has been deprecated in favor of `initval`.
+- `pm.sample` now returns results as `InferenceData` instead of `MultiTrace` by default (see [#4744](https://github.com/pymc-devs/pymc3/pull/4744)).
 - ...

 ### New Features

benchmarks/benchmarks/benchmarks.py

Lines changed: 8 additions & 8 deletions
@@ -181,7 +181,7 @@ def track_glm_hierarchical_ess(self, init):
             init=init, chains=self.chains, progressbar=False, random_seed=123
         )
         t0 = time.time()
-        trace = pm.sample(
+        idata = pm.sample(
             draws=self.draws,
             step=step,
             cores=4,
@@ -192,7 +192,7 @@ def track_glm_hierarchical_ess(self, init):
             compute_convergence_checks=False,
         )
         tot = time.time() - t0
-        ess = float(az.ess(trace, var_names=["mu_a"])["mu_a"].values)
+        ess = float(az.ess(idata, var_names=["mu_a"])["mu_a"].values)
         return ess / tot

     def track_marginal_mixture_model_ess(self, init):
@@ -203,7 +203,7 @@ def track_marginal_mixture_model_ess(self, init):
         )
         start = [{k: v for k, v in start.items()} for _ in range(self.chains)]
         t0 = time.time()
-        trace = pm.sample(
+        idata = pm.sample(
             draws=self.draws,
             step=step,
             cores=4,
@@ -214,7 +214,7 @@ def track_marginal_mixture_model_ess(self, init):
             compute_convergence_checks=False,
         )
         tot = time.time() - t0
-        ess = az.ess(trace, var_names=["mu"])["mu"].values.min()  # worst case
+        ess = az.ess(idata, var_names=["mu"])["mu"].values.min()  # worst case
         return ess / tot


@@ -235,7 +235,7 @@ def track_glm_hierarchical_ess(self, step):
         if step is not None:
             step = step()
         t0 = time.time()
-        trace = pm.sample(
+        idata = pm.sample(
             draws=self.draws,
             step=step,
             cores=4,
@@ -245,7 +245,7 @@ def track_glm_hierarchical_ess(self, step):
             compute_convergence_checks=False,
         )
         tot = time.time() - t0
-        ess = float(az.ess(trace, var_names=["mu_a"])["mu_a"].values)
+        ess = float(az.ess(idata, var_names=["mu_a"])["mu_a"].values)
         return ess / tot


@@ -302,9 +302,9 @@ def freefall(y, t, p):
             Y = pm.Normal("Y", mu=ode_solution, sd=sigma, observed=y)

         t0 = time.time()
-        trace = pm.sample(500, tune=1000, chains=2, cores=2, random_seed=0)
+        idata = pm.sample(500, tune=1000, chains=2, cores=2, random_seed=0)
         tot = time.time() - t0
-        ess = az.ess(trace)
+        ess = az.ess(idata)
         return np.mean([ess.sigma, ess.gamma]) / tot

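The benchmark changes are purely a variable rename: `az.ess` accepts the `InferenceData` returned by `pm.sample` just as it accepted a `MultiTrace`, so the effective-samples-per-second calculation is unchanged. A minimal sketch of the pattern (toy model, not the benchmark itself):

    import time

    import arviz as az
    import pymc3 as pm

    with pm.Model():
        mu_a = pm.Normal("mu_a", 0.0, 1.0)
        t0 = time.time()
        idata = pm.sample(draws=500, tune=500, chains=2, cores=1, progressbar=False)
        tot = time.time() - t0

    # az.ess returns an xarray.Dataset; pull out the scalar ESS for one variable
    ess = float(az.ess(idata, var_names=["mu_a"])["mu_a"].values)
    print(ess / tot)  # effective samples per second, as tracked in the benchmarks
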
docs/source/Advanced_usage_of_Aesara_in_PyMC3.rst

Lines changed: 4 additions & 4 deletions
@@ -40,12 +40,12 @@ be time consuming if the number of datasets is large)::
         pm.Normal('y', mu=mu, sigma=1, observed=data)

     # Generate one trace for each dataset
-    traces = []
+    idatas = []
     for data_vals in observed_data:
         # Switch out the observed dataset
         data.set_value(data_vals)
         with model:
-            traces.append(pm.sample())
+            idatas.append(pm.sample())

 We can also sometimes use shared variables to work around limitations
 in the current PyMC3 api. A common task in Machine Learning is to predict
@@ -63,7 +63,7 @@ variable for our observations::
     pm.Bernoulli('obs', p=logistic, observed=y)

     # fit the model
-    trace = pm.sample()
+    idata = pm.sample()

     # Switch out the observations and use `sample_posterior_predictive` to predict
     x_shared.set_value([-1, 0, 1.])
@@ -220,4 +220,4 @@ We can now define our model using this new `Op`::
     mu = pm.Deterministic('mu', at_mu_from_theta(theta))
     pm.Normal('y', mu=mu, sigma=0.1, observed=[0.2, 0.21, 0.3])

-    trace = pm.sample()
+    idata = pm.sample()

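With the rename, the loop above collects one `InferenceData` per dataset. These can be summarized or compared with ArviZ directly; for example (a sketch assuming the `mu` variable and `idatas` list from the documentation snippet):

    import arviz as az

    # One InferenceData per observed dataset; summarize each posterior in turn
    for i, idata in enumerate(idatas):
        print(f"dataset {i}:")
        print(az.summary(idata, var_names=["mu"], kind="stats"))
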
docs/source/Gaussian_Processes.rst

Lines changed: 1 addition & 1 deletion
@@ -231,7 +231,7 @@ other implementations. The first block fits the GP prior. We denote

     f = gp.marginal_likelihood("f", X, y, noise)

-    trace = pm.sample(1000)
+    idata = pm.sample(1000)


 To construct the conditional distribution of :code:`gp1` or :code:`gp2`, we

docs/source/about.rst

Lines changed: 2 additions & 2 deletions
@@ -237,9 +237,9 @@ Save this file, then from a python shell (or another file in the same directory)
     with bioassay_model:

         # Draw samples
-        trace = pm.sample(1000, tune=2000, cores=2)
+        idata = pm.sample(1000, tune=2000, cores=2)
         # Plot two parameters
-        az.plot_forest(trace, var_names=['alpha', 'beta'], r_hat=True)
+        az.plot_forest(idata, var_names=['alpha', 'beta'], r_hat=True)

 This example will generate 1000 posterior samples on each of two cores using the NUTS algorithm, preceded by 2000 tuning samples (these are good default numbers for most models).

pymc3/data.py

Lines changed: 2 additions & 2 deletions
@@ -498,12 +498,12 @@ class Data:
     ...     pm.Normal('y', mu=mu, sigma=1, observed=data)

     >>> # Generate one trace for each dataset
-    >>> traces = []
+    >>> idatas = []
     >>> for data_vals in observed_data:
     ...     with model:
     ...         # Switch out the observed dataset
     ...         model.set_data('data', data_vals)
-    ...         traces.append(pm.sample())
+    ...         idatas.append(pm.sample())

     To set the value of the data container variable, check out
     :func:`pymc3.model.set_data()`.

pymc3/distributions/discrete.py

Lines changed: 8 additions & 6 deletions
@@ -1691,14 +1691,15 @@ class OrderedLogistic(Categorical):
         cutpoints = pm.Normal("cutpoints", mu=[-1,1], sigma=10, shape=2,
                               transform=pm.distributions.transforms.ordered)
         y_ = pm.OrderedLogistic("y", cutpoints=cutpoints, eta=x, observed=y)
-        tr = pm.sample(1000)
+        idata = pm.sample(1000)

         # Plot the results
         plt.hist(cluster1, 30, alpha=0.5);
         plt.hist(cluster2, 30, alpha=0.5);
         plt.hist(cluster3, 30, alpha=0.5);
-        plt.hist(tr["cutpoints"][:,0], 80, alpha=0.2, color='k');
-        plt.hist(tr["cutpoints"][:,1], 80, alpha=0.2, color='k');
+        posterior = idata.posterior.stack(sample=("chain", "draw"))
+        plt.hist(posterior["cutpoints"][0], 80, alpha=0.2, color='k');
+        plt.hist(posterior["cutpoints"][1], 80, alpha=0.2, color='k');

     """

@@ -1782,14 +1783,15 @@ class OrderedProbit(Categorical):
         cutpoints = pm.Normal("cutpoints", mu=[-1,1], sigma=10, shape=2,
                               transform=pm.distributions.transforms.ordered)
         y_ = pm.OrderedProbit("y", cutpoints=cutpoints, eta=x, observed=y)
-        tr = pm.sample(1000)
+        idata = pm.sample(1000)

         # Plot the results
         plt.hist(cluster1, 30, alpha=0.5);
         plt.hist(cluster2, 30, alpha=0.5);
         plt.hist(cluster3, 30, alpha=0.5);
-        plt.hist(tr["cutpoints"][:,0], 80, alpha=0.2, color='k');
-        plt.hist(tr["cutpoints"][:,1], 80, alpha=0.2, color='k');
+        posterior = idata.posterior.stack(sample=("chain", "draw"))
+        plt.hist(posterior["cutpoints"][0], 80, alpha=0.2, color='k');
+        plt.hist(posterior["cutpoints"][1], 80, alpha=0.2, color='k');

     """

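These docstring updates show the main user-facing difference: posterior draws now live in `idata.posterior`, an xarray `Dataset` with `chain` and `draw` dimensions, instead of being indexed like `tr["cutpoints"][:, 0]`. Stacking the two sampling dimensions recovers a flat array of draws per parameter, as sketched below (variable names follow the docstring example):

    # idata.posterior["cutpoints"] has dimensions (chain, draw, cutpoint index).
    # Stacking chain and draw into a single "sample" dimension yields one flat
    # array of draws per cutpoint, comparable to the old tr["cutpoints"][:, k].
    posterior = idata.posterior.stack(sample=("chain", "draw"))
    first_cutpoint = posterior["cutpoints"][0]   # all draws of cutpoint 0
    second_cutpoint = posterior["cutpoints"][1]  # all draws of cutpoint 1
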
pymc3/distributions/distribution.py

Lines changed: 1 addition & 1 deletion
@@ -495,7 +495,7 @@ def __init__(
                 normal_dist.logp,
                 observed=np.random.randn(100),
             )
-            trace = pm.sample(100)
+            idata = pm.sample(100)

         .. code-block:: python
pymc3/model.py

Lines changed: 2 additions & 2 deletions
@@ -1696,15 +1696,15 @@ def set_data(new_data, model=None):
     ...     y = pm.Data('y', [1., 2., 3.])
     ...     beta = pm.Normal('beta', 0, 1)
     ...     obs = pm.Normal('obs', x * beta, 1, observed=y)
-    ...     trace = pm.sample(1000, tune=1000)
+    ...     idata = pm.sample(1000, tune=1000)

     Set the value of `x` to predict on new data.

     .. code:: ipython

     >>> with model:
     ...     pm.set_data({'x': [5., 6., 9.]})
-    ...     y_test = pm.sample_posterior_predictive(trace)
+    ...     y_test = pm.sample_posterior_predictive(idata)
     >>> y_test['obs'].mean(axis=0)
     array([4.6088569 , 5.54128318, 8.32953844])
     """

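`pm.sample_posterior_predictive` accepts the `InferenceData` object directly, so apart from the variable name the prediction workflow in this docstring is unchanged. A condensed sketch of the round trip (values taken from the docstring example):

    import pymc3 as pm

    with pm.Model() as model:
        x = pm.Data("x", [1.0, 2.0, 3.0])
        y = pm.Data("y", [1.0, 2.0, 3.0])
        beta = pm.Normal("beta", 0, 1)
        pm.Normal("obs", x * beta, 1, observed=y)
        idata = pm.sample(1000, tune=1000)

    with model:
        pm.set_data({"x": [5.0, 6.0, 9.0]})
        # sample_posterior_predictive works with InferenceData as well as MultiTrace
        y_test = pm.sample_posterior_predictive(idata)

    print(y_test["obs"].mean(axis=0))
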
pymc3/sampling.py

Lines changed: 6 additions & 16 deletions
@@ -27,7 +27,6 @@

 import aesara.gradient as tg
 import numpy as np
-import packaging
 import xarray

 from aesara.compile.mode import Mode
@@ -355,7 +354,7 @@ def sample(
         Maximum number of repeated attempts (per chain) at creating an initial matrix with uniform jitter
         that yields a finite probability. This applies to ``jitter+adapt_diag`` and ``jitter+adapt_full``
         init methods.
-    return_inferencedata : bool, default=False
+    return_inferencedata : bool, default=True
         Whether to return the trace as an :class:`arviz:arviz.InferenceData` (True) object or a `MultiTrace` (False)
         Defaults to `False`, but we'll switch to `True` in an upcoming release.
     idata_kwargs : dict, optional
@@ -430,9 +429,9 @@ def sample(
         In [2]: with pm.Model() as model: # context management
            ...:     p = pm.Beta("p", alpha=alpha, beta=beta)
            ...:     y = pm.Binomial("y", n=n, p=p, observed=h)
-           ...:     trace = pm.sample()
+           ...:     idata = pm.sample()

-        In [3]: az.summary(trace, kind="stats")
+        In [3]: az.summary(idata, kind="stats")

         Out[3]:
               mean     sd  hdi_3%  hdi_97%
@@ -471,6 +470,9 @@ def sample(
     if not isinstance(random_seed, abc.Iterable):
         raise TypeError("Invalid value for `random_seed`. Must be tuple, list or int")

+    if return_inferencedata is None:
+        return_inferencedata = True
+
     if not discard_tuned_samples and not return_inferencedata:
         warnings.warn(
             "Tuning samples will be included in the returned `MultiTrace` object, which can lead to"
@@ -480,18 +482,6 @@ def sample(
             stacklevel=2,
         )

-    if return_inferencedata is None:
-        v = packaging.version.parse(pm.__version__)
-        if v.release[0] > 3 or v.release[1] >= 10:  # type: ignore
-            warnings.warn(
-                "In v4.0, pm.sample will return an `arviz.InferenceData` object instead of a `MultiTrace` by default. "
-                "You can pass return_inferencedata=True or return_inferencedata=False to be safe and silence this warning.",
-                FutureWarning,
-                stacklevel=2,
-            )
-        # set the default
-        return_inferencedata = False
-
     if start is not None:
         for start_vals in start:
             _check_start_shape(model, start_vals)

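With the version check and `FutureWarning` removed, leaving `return_inferencedata` unset (i.e. `None`) now simply resolves to `True`. Callers that still need a `MultiTrace` have to opt out explicitly; a brief sketch of the two call styles (toy model for illustration):

    import arviz as az
    import pymc3 as pm

    with pm.Model():
        pm.Normal("x", 0.0, 1.0)

        idata = pm.sample(draws=200, tune=200)  # InferenceData, no FutureWarning emitted
        mtrace = pm.sample(draws=200, tune=200, return_inferencedata=False)  # MultiTrace

    assert isinstance(idata, az.InferenceData)
    assert not isinstance(mtrace, az.InferenceData)
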