Skip to content

Commit 3eb6a66

Browse files
committed
Change SamplerWarning.kind to enum
1 parent 1fa4b8f commit 3eb6a66

File tree

7 files changed

+42
-26
lines changed

7 files changed

+42
-26
lines changed

pymc3/backends/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -464,7 +464,7 @@ def get_sampler_stats(self, varname, burn=0, thin=1, combine=True,
464464
return _squeeze_cat(results, combine, squeeze)
465465

466466
def _slice(self, slice):
467-
"""Return a new MultiTrace object sliced according to `idx`."""
467+
"""Return a new MultiTrace object sliced according to `slice`."""
468468
new_traces = [trace._slice(slice) for trace in self._straces.values()]
469469
trace = MultiTrace(new_traces)
470470
idxs = slice.indices(len(self))

pymc3/backends/report.py

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,25 @@
11
from collections import namedtuple
22
import logging
3+
import enum
34

45

56
logger = logging.getLogger('pymc3')
67

78

9+
@enum.unique
10+
class WarningType(enum.Enum):
11+
# For HMC and NUTS
12+
DIVERGENCE = 1
13+
TUNING_DIVERGENCE = 2
14+
DIVERGENCES = 3
15+
TREEDEPTH = 4
16+
# Problematic sampler parameters
17+
BAD_PARAMS = 5
18+
# Indications that chains did not converge, eg Rhat
19+
CONVERGENCE = 6
20+
BAD_ACCEPTANCE = 7
21+
22+
823
SamplerWarning = namedtuple(
924
'SamplerWarning',
1025
"kind, message, level, step, exec_info, extra")
@@ -47,7 +62,8 @@ def _run_convergence_checks(self, trace):
4762
if trace.nchains == 1:
4863
msg = ("Only one chain was sampled, this makes it impossible to "
4964
"run some convergence checks")
50-
warn = SamplerWarning('bad-params', msg, 'info', None, None, None)
65+
warn = SamplerWarning(WarningType.BAD_PARAMS, msg, 'info',
66+
None, None, None)
5167
self._add_warnings([warn])
5268
return
5369

@@ -62,34 +78,35 @@ def _run_convergence_checks(self, trace):
6278
msg = ("The gelman-rubin statistic is larger than 1.4 for some "
6379
"parameters. The sampler did not converge.")
6480
warn = SamplerWarning(
65-
'convergence', msg, 'error', None, None, gelman_rubin)
81+
WarningType.CONVERGENCE, msg, 'error', None, None, gelman_rubin)
6682
warnings.append(warn)
6783
elif rhat_max > 1.2:
6884
msg = ("The gelman-rubin statistic is larger than 1.2 for some "
6985
"parameters.")
7086
warn = SamplerWarning(
71-
'convergence', msg, 'warn', None, None, gelman_rubin)
87+
WarningType.CONVERGENCE, msg, 'warn', None, None, gelman_rubin)
7288
warnings.append(warn)
7389
elif rhat_max > 1.05:
7490
msg = ("The gelman-rubin statistic is larger than 1.05 for some "
7591
"parameters. This indicates slight problems during "
7692
"sampling.")
7793
warn = SamplerWarning(
78-
'convergence', msg, 'info', None, None, gelman_rubin)
94+
WarningType.CONVERGENCE, msg, 'info', None, None, gelman_rubin)
7995
warnings.append(warn)
8096

8197
eff_min = min(val.min() for val in effective_n.values())
82-
if eff_min < 100:
98+
n_samples = len(trace) * trace.nchains
99+
if eff_min < 200 and n_samples >= 500:
83100
msg = ("The estimated number of effective samples is smaller than "
84-
"100 for some parameters.")
101+
"200 for some parameters.")
85102
warn = SamplerWarning(
86-
'convergence', msg, 'error', None, None, effective_n)
103+
WarningType.CONVERGENCE, msg, 'error', None, None, effective_n)
87104
warnings.append(warn)
88-
elif eff_min < 300:
89-
msg = ("The estimated number of effective samples is smaller than "
90-
"300 for some parameters.")
105+
elif eff_min / n_samples < 0.25:
106+
msg = ("The number of effective samples is smaller than "
107+
"25% for some parameters.")
91108
warn = SamplerWarning(
92-
'convergence', msg, 'warn', None, None, effective_n)
109+
WarningType.CONVERGENCE, msg, 'warn', None, None, effective_n)
93110
warnings.append(warn)
94111

95112
self._add_warnings(warnings)

pymc3/sampling.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from .util import update_start_vals
1919
from .vartypes import discrete_types
2020
from pymc3.step_methods.hmc import quadpotential
21-
from pymc3.backends.report import SamplerWarning
2221
from pymc3 import plots
2322
import pymc3 as pm
2423
from tqdm import tqdm

pymc3/step_methods/hmc/base_hmc.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from pymc3.tuning import guess_scaling
1010
from .quadpotential import quad_potential, QuadPotentialDiagAdapt
1111
from pymc3.step_methods import step_sizes
12-
from pymc3.backends.report import SamplerWarning
12+
from pymc3.backends.report import SamplerWarning, WarningType
1313

1414

1515
HMCStepData = namedtuple(
@@ -107,7 +107,7 @@ def _hamiltonian_step(self, start, p0, step_size):
107107
raise NotImplementedError("Abstract method")
108108

109109
def astep(self, q0):
110-
"""Perform a single NUTS iteration."""
110+
"""Perform a single HMC iteration."""
111111
p0 = self.potential.random()
112112
start = self.integrator.compute_state(q0, p0)
113113

@@ -130,10 +130,10 @@ def astep(self, q0):
130130
if hmc_step.divergence_info:
131131
info = hmc_step.divergence_info
132132
if self.tune:
133-
kind = 'tuning-divergence'
133+
kind = WarningType.TUNING_DIVERGENCE
134134
point = None
135135
else:
136-
kind = 'divergence'
136+
kind = WarningType.DIVERGENCE
137137
self._num_divs_sample += 1
138138
# We don't want to fill up all memory with divergence info
139139
if self._num_divs_sample < 100:
@@ -174,25 +174,25 @@ def warnings(self, strace):
174174
msg = ('The chain contains only diverging samples. The model is '
175175
'probably misspecified.')
176176
warning = SamplerWarning(
177-
'divergences', msg, 'error', None, None, None)
177+
WarningType.DIVERGENCES, msg, 'error', None, None, None)
178178
warnings.append(warning)
179179
elif n_divs > 0:
180180
message = ('Divergences after tuning. Increase `target_accept` or '
181181
'reparameterize.')
182182
warning = SamplerWarning(
183-
'divergences', message, 'error', None, None, None)
183+
WarningType.DIVERGENCES, message, 'error', None, None, None)
184184
warnings.append(warning)
185185

186186
# small trace
187187
if self._samples_after_tune == 0:
188188
msg = "Tuning was enabled throughout the whole trace."
189189
warning = SamplerWarning(
190-
'bad-params', msg, 'error', None, None, None)
190+
WarningType.BAD_PARAMS, msg, 'error', None, None, None)
191191
warnings.append(warning)
192192
elif self._samples_after_tune < 500:
193193
msg = "Only %s samples in chain." % self._samples_after_tune
194194
warning = SamplerWarning(
195-
'bad-params', msg, 'error', None, None, None)
195+
WarningType.BAD_PARAMS, msg, 'error', None, None, None)
196196
warnings.append(warning)
197197

198198
warnings.extend(self.step_adapt.warnings())

pymc3/step_methods/hmc/nuts.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from ..arraystep import Competence
99
from .base_hmc import BaseHMC, HMCStepData, DivergenceInfo
1010
from .integration import IntegrationError
11-
from pymc3.backends.report import SamplerWarning
11+
from pymc3.backends.report import SamplerWarning, WarningType
1212
from pymc3.theanof import floatX
1313
from pymc3.vartypes import continuous_types
1414

@@ -189,7 +189,8 @@ def warnings(self, strace):
189189
if np.mean(self._reached_max_treedepth) > 0.05:
190190
msg = ('The chain reached the maximum tree depth. Increase '
191191
'max_treedepth, increase target_accept or reparameterize.')
192-
warn = SamplerWarning('treedepth', msg, 'warn', None, None, None)
192+
warn = SamplerWarning(WarningType.TREEDEPTH, msg, 'warn',
193+
None, None, None)
193194
warnings.append(warn)
194195
return warnings
195196

pymc3/step_methods/step_sizes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import numpy as np
44
from scipy import stats
55

6-
from pymc3.backends.report import SamplerWarning
6+
from pymc3.backends.report import SamplerWarning, WarningType
77

88

99
class DualAverageAdaptation(object):
@@ -61,7 +61,7 @@ def warnings(self):
6161
% (mean_accept, target_accept))
6262
info = {'target': target_accept, 'actual': mean_accept}
6363
warning = SamplerWarning(
64-
'bad-acceptance', msg, 'warn', None, None, info)
64+
WarningType.BAD_ACCEPTANCE, msg, 'warn', None, None, info)
6565
return [warning]
6666
else:
6767
return []

pymc3/tests/test_step.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
MultivariateNormalProposal, HamiltonianMC,
1212
EllipticalSlice, smc, DEMetropolis)
1313
from pymc3.theanof import floatX
14-
from pymc3 import SamplingError
1514
from pymc3.distributions import (
1615
Binomial, Normal, Bernoulli, Categorical, Beta, HalfNormal)
1716

0 commit comments

Comments
 (0)