
Refactor HMC and warning system #2677


Merged
merged 6 commits into pymc-devs:master from adapt-step-hmc on Jan 8, 2018

Conversation

@aseyboldt (Member)

We recently implemented a couple of improvements for NUTS, and most of them would also be applicable to HMC (and some also to Metropolis, but that is a separate issue).
This is a first step toward using dual averaging in HMC. Proper divergence handling is still missing in HMC. It might be better to move some of this into BaseHMC, and just use it in both NUTS and HMC.
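
For background, here is a minimal self-contained sketch of the dual-averaging step-size adaptation from Hoffman & Gelman (2014) that this PR wraps in a separate class; the class and attribute names below are illustrative, not the PR's:

import numpy as np

class DualAverageAdaptation:
    """Sketch of dual averaging for the step size; names are made up."""

    def __init__(self, initial_step, target_accept=0.8,
                 gamma=0.05, k=0.75, t0=10):
        self._mu = np.log(10 * initial_step)  # bias toward larger steps
        self._target = target_accept
        self._gamma = gamma
        self._k = k
        self._t0 = t0
        self._t = 1
        self._hbar = 0.0
        self._log_step = np.log(initial_step)
        self._log_averaged_step = np.log(initial_step)

    def update(self, accept_prob):
        # Running average of the deviation from the target acceptance rate.
        w = 1.0 / (self._t + self._t0)
        self._hbar = (1 - w) * self._hbar + w * (self._target - accept_prob)
        # Step size used while tuning, driven by the averaged deviation.
        self._log_step = self._mu - self._hbar * np.sqrt(self._t) / self._gamma
        # Smoothed step size used after tuning; k controls the decay.
        eta = self._t ** -self._k
        self._log_averaged_step = (eta * self._log_step
                                   + (1 - eta) * self._log_averaged_step)
        self._t += 1

    def current(self, tune):
        """Return the step size for the next iteration."""
        return np.exp(self._log_step if tune else self._log_averaged_step)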

probability across the trajectories are close to target_accept.
Higher values for target_accept lead to smaller step sizes.
gamma : float, default .05
k : float (.5,1) default .75
Member

.5,1?

Member Author

I just copied that from NUTS, so if we change it here, we should also change it there.
k has to be in the open interval (0.5, 1).

Member

I see, isn't [0.5, 1] the more common syntax?

Member Author

That would imply a closed interval, including 0.5 and 1. I don't think that's allowed. I can just put it in words in the description; that would be less confusing, I think.

Member

I think common syntax might be ]0.5, 1[, which makes me twitch, but at least makes it clear we're not talking about a tuple. Also fine with using words!

Member

Please don't use ]0.5, 1[... words are good.

Member Author

Words it is.
I think the ]0.5, 1[ syntax is more common in the US, and (0.5, 1) in Europe.
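
For illustration, the reworded docstring entry might read something like this (the function name is hypothetical and the exact wording is an assumption, not the merged text):

def dual_averaging_step(k=0.75):
    """Hypothetical example of stating the interval in words.

    Parameters
    ----------
    k : float, default .75
        Scaling parameter for dual averaging. Must be strictly
        greater than 0.5 and strictly smaller than 1.
    """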

@twiecki (Member) commented Nov 3, 2017

pymc3/step_methods/hmc/hmc.py:7: [W0611(unused-import), ] Unused floatX imported from pymc3.theanof

@aseyboldt aseyboldt force-pushed the adapt-step-hmc branch 2 times, most recently from e2626c9 to e45c465 Compare November 6, 2017 08:04
@aseyboldt aseyboldt changed the title WIP: Use improvements to NUTS for HMC WIP: Refactor HMC and warning system Nov 6, 2017
@aseyboldt aseyboldt force-pushed the adapt-step-hmc branch 2 times, most recently from 86d7612 to fe95101 Compare November 24, 2017 21:05
@aseyboldt aseyboldt force-pushed the adapt-step-hmc branch 6 times, most recently from e9e9ecf to acaaeda Compare December 3, 2017 15:22
@@ -289,6 +296,9 @@ def sample(draws=500, step=None, init='auto', n_init=200000, start=None,
Options for traceplot. Example: live_plot_kwargs={'varnames': ['x']}
discard_tuned_samples : bool
Whether to discard posterior samples of the tune interval.
compute_stats : bool, default=True
Member Author

Any ideas for a better name than compute_stats?

Member

store_stats?

Member Author

Hm. That might be confusing with the normal sampler stats.
Maybe convergence_checks or convergence_stats or the very long compute_convergence_stats?

Contributor

I think compute_convergence_stats suits.

@aseyboldt (Member Author)

I ended up refactoring more than I planned:

The HMC-related code is divided into three classes: BaseHMC, HamiltonianMC and NUTS. The first implements things that are common to both HMC and NUTS. In this PR I moved dual averaging (now wrapped in a separate class, which might also be useful for Metropolis) and mass matrix adaptation from NUTS to BaseHMC, so that HamiltonianMC can take advantage of them as well. The actual trajectory is computed and sampled in NUTS._hamiltonian_step and HamiltonianMC._hamiltonian_step, which are called by BaseHMC.astep. BaseHMC.astep handles divergences, adaptation, warnings, and some error checking.
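
A minimal sketch of that layout (signatures and bodies here are placeholders; only the class and method names are from the PR):

class BaseHMC:
    """Logic shared by HMC and NUTS."""

    def astep(self, q0):
        # Shared work lives here: step-size and mass matrix
        # adaptation, divergence handling, warnings, error checking.
        return self._hamiltonian_step(q0)

    def _hamiltonian_step(self, q0):
        raise NotImplementedError("Abstract method")

class HamiltonianMC(BaseHMC):
    def _hamiltonian_step(self, q0):
        # Simulate a single fixed-length trajectory and accept/reject.
        return q0

class NUTS(BaseHMC):
    def _hamiltonian_step(self, q0):
        # Build the trajectory by iterative doubling.
        return q0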

I rewrote the whole warning system. Warnings during sampling are no longer Python warnings (although we could still raise those if we want), but are stored as a namedtuple:

"""
At the moment `kind` can be one of
* "divergence" (for hmc and nuts)
* "tuning-divergence"
* "bad-params" (for problematic sampler parameters)
* "convergence" (for indications that the chains did not converge, eg Rhat)
* "bad-acceptance"
* "treedepth" (for nuts)
"""
SamplerWarning = namedtuple(
    'SamplerWarning',
    "kind, message, level, step, exec_info, extra")

Several classes now have a warnings(strace) method that generates these warnings. This is still missing for the mass matrix adaptation, because I'm not sure yet how to diagnose problems there.

level is the same as the logging levels from the std logging module. exec_info stores an exception if there was one, and extra can store additional data related to that particular warning. For divergences it stores the integration state at the point of the divergence (which is different from the later accepted point, and should be useful for plotting). If there is a very large number of divergences, we stop storing that state at some point to avoid memory exhaustion.
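
For example, a divergence warning could be constructed like this (reusing the namedtuple defined above; the message text and placeholder values are made up):

divergence = SamplerWarning(
    kind='divergence',
    message='Divergence after a large energy change.',
    level='warn',      # name of a std logging level
    step=123,          # iteration at which the divergence occurred
    exec_info=None,    # no exception was attached
    extra=None,        # would hold the integration state for plotting
)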

pm.sample collects the warnings and attaches them to the trace in a report that keeps track of the chain in which each warning occurred. This report (trace.report) also computes gelman_rubin and effective_n and generates global warnings if those don't look ok. The report object could probably use a lot of work (an HTML representation for the notebook would be cool). Right now it has only two public functions: ok and raise_ok. I'm not sure if or how we should expose the warnings to the users directly.
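
A sketch of how those two functions might be used (the model here is a placeholder):

import pymc3 as pm

with pm.Model():
    x = pm.Normal('x', mu=0, sd=1)
    trace = pm.sample(1000)

if not trace.report.ok:
    trace.report.raise_ok()  # raises, describing the collected problems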

pm.sample then calls trace.report._log_warnings, which uses the std logging module to print the warnings. The global logging level and the level of the individual warnings decide whether a warning is shown to the user.

I think we should think a bit more about how we want to present those warnings to the user (e.g. do we want to use the std logging module, do we only want to show one warning saying "please look at the report", etc.), but I think this can wait until after this PR, as at this point we only need to change the _log_warnings method.

        raise NotImplementedError("Abstract method")

    def astep(self, q0):
        """Perform a single NUTS iteration."""
Member

Nitpick: "a single HMC iteration"

Member Author

fixed

@@ -447,10 +463,13 @@ def get_sampler_stats(self, varname, burn=0, thin=1, combine=True,
for chain in chains]
return _squeeze_cat(results, combine, squeeze)

def _slice(self, idx):
def _slice(self, slice):
"""Return a new MultiTrace object sliced according to `idx`."""
Member

according to idx --> according to slice

Member Author

fixed

'convergence', msg, 'info', None, None, gelman_rubin)
warnings.append(warn)

eff_min = min(val.min() for val in effective_n.values())
Member

Would it be more informative to raise a warning if the effective sample size is lower than a certain percentage of the total number of samples? For example, raise a warning if the effective sample size is below 25% of the total number of samples.

Member Author

Good idea. We could have both, too. One with level info if it is <25% and one with level warn if it is <200?

@junpenglao (Member) commented Jan 8, 2018

I would do it the other way around:
warn if it is <25% (high autocorrelation and poor mixing is likely an indication of a modelling problem)
and info if it is <200 (sometimes people sample only 200 for demo purposes, and it would be annoying to see a warning)

Member Author

I guess that comes down to how much we trust effective_n. If we don't trust the numbers it returns, but only use it as an indication that the sampler doesn't work well, then I think I agree. But if we take the values from effective_n at face value, then a low percentage of effective samples isn't in itself a problem (i.e. it should only be info). But a very low number of effective samples is a problem.
Say your samples have some autocorrelation, but you don't get any divergences, and gelman_rubin looks fine. Then I don't see a problem with just running the sampler for a very long time, until you get a lot of effective samples. This seems to me like a valid use case, and we'd print warnings if people did this, even though what they are doing is ok (at least I think it is, right?).
On the other hand, few effective samples are always a problem if you plan to do anything with your trace.
Maybe a middle way would be:

  • info warning if neff < 0.25 * draws
  • warn if neff < 200 and draws > 500

That way we don't issue a low-neff warning if users explicitly asked for a low number of samples.
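
A self-contained sketch of these thresholds (the function name and message wording are assumptions, not the merged code):

from collections import namedtuple

SamplerWarning = namedtuple(
    'SamplerWarning',
    "kind, message, level, step, exec_info, extra")

def effective_n_warnings(eff_min, draws):
    """Generate warnings for the proposed effective-sample thresholds."""
    warns = []
    if eff_min < 200 and draws > 500:
        warns.append(SamplerWarning(
            'convergence',
            'Estimated effective number of samples is smaller than 200.',
            'warn', None, None, eff_min))
    elif eff_min < 0.25 * draws:
        warns.append(SamplerWarning(
            'convergence',
            'The number of effective samples is below 25% of the draws.',
            'info', None, None, eff_min))
    return warns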

Member

Yeah that's a good point, I agree with this solution.

logger = logging.getLogger('pymc3')


SamplerWarning = namedtuple(
Member

might make sense to have SamplerWarning.kind be an Enum, so that there is a little more reuse?

Member Author

yes, fixed
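
For reference, such an Enum could look like this, with one member per kind listed in the summary above (the member names are assumptions):

import enum

class WarningType(enum.Enum):
    DIVERGENCE = 1
    TUNING_DIVERGENCE = 2
    BAD_PARAMS = 3
    CONVERGENCE = 4
    BAD_ACCEPTANCE = 5
    TREEDEPTH = 6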

@junpenglao (Member)

Oh, and a final nitpick: it would be great to change pm._log to _log (using logging) everywhere else in the code base. Of course, we can also do it in a separate PR.

@aseyboldt (Member Author)

I agree that we should do that, but I already had to rebase this PR because of that at least 3 times...
Maybe a separate PR is better :-)

@junpenglao (Member)

LOL, agree.
Travis was complaining about unused imports previously:

************* Module pymc3.sampling
pymc3/sampling.py:21: [W0611(unused-import), ] Unused SamplerWarning imported from pymc3.backends.report
************* Module pymc3.tests.test_step
pymc3/tests/test_step.py:14: [W0611(unused-import), ] Unused SamplingError imported from pymc3

Did you fix that?

@aseyboldt (Member Author)

Oops, I have now :-)

@junpenglao junpenglao changed the title WIP: Refactor HMC and warning system Refactor HMC and warning system Jan 8, 2018
@junpenglao (Member)

This is good to go.

@springcoil (Contributor)

Cool - excited about all of this :)

@junpenglao junpenglao merged commit 39cd75d into pymc-devs:master Jan 8, 2018
@aseyboldt (Member Author)

Ah, I thought we wanted to wait with this one until after 3.3, so that we have a bit more time to find problems with it? I guess we could also just branch off a 3.3 branch from before this merge and push that to GitHub. We want to release that pretty soon, right?

@junpenglao (Member)

Ohhh right. Well then definitely branch off 3.3 before this merge.
cc @twiecki @fonnesbeck

@twiecki (Member) commented Jan 10, 2018

How about we just revert the merge commit?

@junpenglao (Member)

For the release? Sure. Sorry about the potential mess in the history...

@junpenglao (Member)

Re v3.3 release, maybe we can just include this PR (after merging #2808)?
