Refactor HMC and warning system #2677
Conversation
pymc3/step_methods/hmc/hmc.py
Outdated
        probability across the trajectories are close to target_accept.
        Higher values for target_accept lead to smaller step sizes.
    gamma : float, default .05
    k : float (.5,1) default .75
.5,1?
I just copied that from nuts, so if we change it here, we should also change it there.
k has to be in the open interval (0.5, 1).
I see, isn't `[0.5, 1]` the more common syntax?
That would imply a closed interval, including 0.5 and 1. I don't think that's allowed. I can just put it in words in the description, that would be less confusing I think.
I think common syntax might be `]0.5, 1[`, which makes me twitch, but at least makes it clear we're not talking about a tuple. Also fine with using words!
Please don't use `]0.5, 1[`... words are good.
Words it is.
I think the `]0.5, 1[` syntax is more common in the US, and `(0.5, 1)` in Europe.
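Since the constraint ends up in words, a small runtime check would make the intended semantics unambiguous. A minimal sketch, not part of this PR (the helper name is hypothetical):

```python
def _check_k(k):
    # k is the dual averaging decay exponent; it must lie in the
    # open interval (0.5, 1), i.e. both endpoints excluded.
    if not 0.5 < k < 1.0:
        raise ValueError("k must be between 0.5 and 1 (exclusive).")
    return k
```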
pymc3/sampling.py
Outdated
@@ -289,6 +296,9 @@ def sample(draws=500, step=None, init='auto', n_init=200000, start=None,
        Options for traceplot. Example: live_plot_kwargs={'varnames': ['x']}
    discard_tuned_samples : bool
        Whether to discard posterior samples of the tune interval.
+   compute_stats : bool, default=True
Any ideas for a better name than `compute_stats`?
`store_stats`?
Hm. That might be confusing with the normal sampler stats.
Maybe `convergence_checks` or `convergence_stats` or the very long `compute_convergence_stats`?
I think `compute_convergence_stats` suits.
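For illustration, the relevant part of the `sample` signature and docstring after the rename might look like this; a sketch only, with the other parameters elided:

```python
def sample(draws=500, step=None, init='auto', n_init=200000, start=None,
           compute_convergence_stats=True, **kwargs):
    """Draw samples from the posterior.

    compute_convergence_stats : bool, default=True
        Whether to compute convergence statistics (e.g. gelman_rubin
        and effective_n) on the trace after sampling.
    """
```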
I ended up refactoring more than I planned. The hmc related code is divided into three classes. I rewrote the whole warning system: warnings during sampling are no longer Python warnings (although we could still raise those if we want), but are stored as a `SamplerWarning`:

"""
At the moment `kind` can be one of
* "divergence" (for hmc and nuts)
* "tuning-divergence"
* "bad-params" (for problematic sampler parameters)
* "convergence" (for indications that the chains did not converge, eg Rhat)
* "bad-acceptance"
* "treedepth" (for nuts)
"""
SamplerWarning = namedtuple(
    'SamplerWarning',
    "kind, message, level, step, exec_info, extra")

Several classes have a …
I think we should think a bit more about how we want to present those warnings to the user (eg do we want to use the std logging module, do we only want to show one warning "please look at the report", etc.), but I think this can wait until after this PR, as at this point we only need to change the …
pymc3/step_methods/hmc/base_hmc.py
Outdated
        raise NotImplementedError("Abstract method")

    def astep(self, q0):
        """Perform a single NUTS iteration."""
Nitpick: "a single HMC iteration"
fixed
pymc3/backends/base.py
Outdated
@@ -447,10 +463,13 @@ def get_sampler_stats(self, varname, burn=0, thin=1, combine=True,
            for chain in chains]
        return _squeeze_cat(results, combine, squeeze)

-    def _slice(self, idx):
+    def _slice(self, slice):
        """Return a new MultiTrace object sliced according to `idx`."""
`according to idx` --> `according to slice`
fixed
pymc3/backends/report.py
Outdated
            'convergence', msg, 'info', None, None, gelman_rubin)
        warnings.append(warn)

        eff_min = min(val.min() for val in effective_n.values())
Would it be more informative to raise a warning if the effective sample size is lower than a certain percentage of the total number of samples? For example, raise a warning if the effective sample size is below 25% of the total number of samples.
Good idea. We could have both, too. One with level `info` if it is <25% and one with level `warn` if it is <200?
I would do it the other way around: `warn` if it is <25% (high autocorrelation and poor mixing is likely an indication of a modelling problem) and `info` if it is <200 (sometimes people sample only 200 for demo purposes and it would be annoying to see a warning).
I guess that comes down to how much we trust `effective_n`. If we don't trust the numbers it returns, but only use it as an indication that the sampler doesn't work well, then I think I agree. But if we take the values from `effective_n` at face value, then a low percentage of effective samples isn't in itself a problem (ie should only be `info`). But a very low number of effective samples is a problem.
Say your samples have some autocorrelation, but you don't get any divergences, and gelman_rubin looks fine. Then I don't see a problem with just running the sampler for a very long time, until you get a lot of effective samples. This seems to me like a valid use-case, and we'd print warnings if people did this, even though what they are doing is ok. (at least I think it is, right?)
On the other hand, few effective samples are always a problem if you plan to do anything with your trace.
Maybe a middle way would be:
* `info` warning if `neff < 0.25 * draws`
* `warn` if `neff < 200` and `draws > 500`

That way we don't issue a low-neff warning if users explicitly asked for a low number of samples.
Yeah that's a good point, I agree with this solution.
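A minimal sketch of the agreed-on thresholds, reusing the `SamplerWarning` namedtuple from the PR description and assuming `effective_n` is a dict of per-variable arrays, as in the diff above (the function name is hypothetical):

```python
def warn_on_low_neff(effective_n, draws, warnings):
    # Smallest effective sample size across all variables.
    eff_min = min(val.min() for val in effective_n.values())
    if eff_min < 200 and draws > 500:
        # Very few effective samples even though plenty were requested.
        warnings.append(SamplerWarning(
            'convergence',
            'The number of effective samples is smaller than 200.',
            'warn', None, None, eff_min))
    elif eff_min < 0.25 * draws:
        # High autocorrelation; informative but not necessarily a problem.
        warnings.append(SamplerWarning(
            'convergence',
            'The number of effective samples is smaller than 25% of draws.',
            'info', None, None, eff_min))
```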
pymc3/backends/report.py
Outdated
logger = logging.getLogger('pymc3')


SamplerWarning = namedtuple(
Might make sense to have `SamplerWarning.kind` be an `Enum`, so that there is a little more reuse?
yes, fixed
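For reference, a possible shape for that change; the member names mirror the kinds listed in the PR description, and the example values are illustrative:

```python
import enum
from collections import namedtuple

class WarningType(enum.Enum):
    DIVERGENCE = 1
    TUNING_DIVERGENCE = 2
    BAD_PARAMS = 3
    CONVERGENCE = 4
    BAD_ACCEPTANCE = 5
    TREEDEPTH = 6

SamplerWarning = namedtuple(
    'SamplerWarning',
    "kind, message, level, step, exec_info, extra")

warn = SamplerWarning(
    WarningType.TREEDEPTH, 'The maximum tree depth was reached.',
    'warn', None, None, None)
```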
Oh, and final nitpick: It would be great to change …
I agree that we should do that, but I already had to rebase this PR because of that at least 3 times...
LOL, agree.
Did you fix that?
Oops, I have now :-)
This is good to go.
Cool - excited about all of this :)
Ah, I thought we wanted to wait with this one until after 3.3, so that we have a bit more time to find problems with it? I guess we could also just branch off a 3.3 branch from before this merge, and push that to GitHub. We want to release that pretty soon, right?
Ohhh right. Well then definitely branch off 3.3 before this merge.
How about we just revert the merge commit?
For the release? Sure. Sorry about the potential mess in the history...
Re v3.3 release, maybe we can just include this PR (after merging #2808)?
We recently implemented a couple of improvements for NUTS, and most of them would also be applicable to HMC (and some also to metropolis, but this is a separate issue).
This is a first step toward using dual averaging in hmc. Proper divergence handling is still missing in hmc. It might be better to move some of this into `BaseHMC`, and just use it in both NUTS and HMC.
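To make that last point concrete, a rough skeleton of how the shared logic could live in `BaseHMC`, with the subclasses only providing the trajectory; the method names and class layout here are illustrative, not the actual design:

```python
class BaseHMC:
    """Shared machinery for HMC and NUTS: dual averaging step size
    adaptation, mass matrix handling, divergence bookkeeping."""

    def astep(self, q0):
        # Everything except the trajectory itself would be common code.
        raise NotImplementedError

    def _hamiltonian_step(self, start, p0, step_size):
        """Compute one trajectory; overridden by HMC and NUTS."""
        raise NotImplementedError("Abstract method")

class HamiltonianMC(BaseHMC):
    def _hamiltonian_step(self, start, p0, step_size):
        # Fixed-length leapfrog trajectory.
        ...

class NUTS(BaseHMC):
    def _hamiltonian_step(self, start, p0, step_size):
        # Dynamically grown binary-tree trajectory.
        ...
```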