ReverseDiff support #1170

Merged: 11 commits from mt/reversediff into master on Mar 24, 2020

Conversation

@mohamed82008 (Member) commented Mar 22, 2020

This PR adds ReverseDiff support. The most important change is that, once this PR is merged, a user who wants Tracker for AD needs to call Turing.setadbackend(:tracker) instead of :reverse_diff. ReverseDiff becomes optional, much like Zygote, so the user needs to do:

using ReverseDiff, Turing
Turing.setadbackend(:reverse_diff)

to use ReverseDiff. This PR requires a new release of Bijectors, which I have already made, but TagBot is taking a while to create the release and tag in the Bijectors repo. The General PR was already merged a while back. Let's wait for TagBot and then re-run the tests here.

Edit: Bijectors was released.

@mohamed82008 (Member, Author):

ReverseDiff gives a mysterious error on Julia 1.0 and 1.1, so I removed the ReverseDiff inference tests from these two versions. We should probably drop support for Julia 1.1 (and 1.2?) at some point as well now that Julia 1.4 is out. I think we can keep support for the last two or three Julia minor versions plus Julia 1.0.

@mohamed82008 (Member, Author) commented Mar 22, 2020

ReverseDiff performance here is bad because I am building the tape every time. I will try to use Memoization.jl to memoize the compiled tape. This means that non-static control flow will not be supported with ReverseDiff, e.g. if statements that depend on runtime values, or while loops. This is a strong argument against using ReverseDiff by default in Turing. I think it makes sense to enable ReverseDiff but warn the user that only limited control flow is supported, e.g. fixed-size loops and compile-time if statements.
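
For reference, here is a minimal sketch of what tape caching buys (logdensity is a placeholder target, not Turing's internals): recording and compiling the tape is the expensive step, and a compiled tape can be replayed cheaply, but it replays exactly the operations that were recorded.

using ReverseDiff

logdensity(θ) = -sum(abs2, θ) / 2   # placeholder log density, not a Turing model
θ = randn(5)

# Recording and compiling the tape is the expensive, once-per-model step.
tape = ReverseDiff.compile(ReverseDiff.GradientTape(logdensity, θ))

# Replaying the compiled tape is cheap and can be repeated for new values of θ,
# but only the operations recorded above are executed, so value-dependent
# branches and loops are frozen in their recorded form.
grad = similar(θ)
ReverseDiff.gradient!(grad, tape, θ)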

@mohamed82008 (Member, Author):

I will make the memoization optional. In some of the benchmarks, ReverseDiff is still a good choice of AD even without caching the tape.

@yebai (Member) commented Mar 22, 2020

ReverseDiff gives a mysterious error on Julia 1.0 and 1.1, so I removed the ReverseDiff inference tests from these two versions. We should probably drop support for Julia 1.1 (and 1.2?) at some point as well now that Julia 1.4 is out. I think we can keep support for the last two or three Julia minor versions plus Julia 1.0.

Sounds good to support Julia 1.0 and the two most recent versions.

I will make the memoization optional. In some of the benchmarks, ReverseDiff is still a good choice of AD even without caching the tape.

Do we really need Memoization.jl for caching a tape? Maybe cache it in the Model type?

@mohamed82008 (Member, Author):

The Model type is not owned by Turing; it's owned by DynamicPPL. The caching needs to be done per model and sampler, e.g. in the case of Gibbs sampling with HMC sub-samplers. The best tool for this is memoization, imo.
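
A rough sketch of the memoization idea (memoized_tape and make_logdensity are illustrative names, not the code in this PR): keying the cache on the model, the sampler, and the parameter length means each HMC sub-sampler in a Gibbs run builds its compiled tape once and reuses it afterwards.

using Memoization, ReverseDiff

# Hypothetical stand-in for the log-density closure that Turing differentiates.
make_logdensity(model, sampler) = θ -> -sum(abs2, θ) / 2

# Memoize on (model, sampler, n): the tape is built on the first call for a
# given combination and served from the cache on every later call.
@memoize function memoized_tape(model, sampler, n::Int)
    f = make_logdensity(model, sampler)
    ReverseDiff.compile(ReverseDiff.GradientTape(f, zeros(n)))
end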

@mohamed82008 (Member, Author):

Some bijectors use if statements to guarantee numerical stability, e.g. to avoid calling log on negative numbers. These need custom ReverseDiff adjoints to avoid having only one branch of the if statement compiled into the tape. The Dirichlet distribution, for example, doesn't currently work properly with ReverseDiff.
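
To illustrate the branch problem (a toy guard, not actual Bijectors code): the tape records only the branch taken at recording time, so replaying it on an input that should take the other branch silently differentiates the wrong expression.

using ReverseDiff

# Toy guarded log, standing in for the numerical-stability checks in bijectors.
g(x) = x[1] > 1 ? log(x[1]) : zero(x[1])

# Record the tape at a point where the log branch is taken.
tape = ReverseDiff.compile(ReverseDiff.GradientTape(g, [2.0]))

# Replay it at a point where the other branch should be taken: the guard is
# not re-evaluated, so the tape still differentiates log and returns
# 1/0.5 = 2.0 instead of the correct gradient 0.0. A custom adjoint (or no
# caching) is needed to handle both branches correctly.
grad = ReverseDiff.gradient!(similar([0.5]), tape, [0.5])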

@mohamed82008 (Member, Author):

These custom adjoints are only needed when caching is turned on.

@mohamed82008 (Member, Author):

Btw, ReverseDiff with caching is giving me a decent speedup in some of the TuringExamples benchmarks. For the remaining ones, I have an idea why they are slow and will work on them.

@mohamed82008 (Member, Author):

The variational inference code needs to be updated to add the caching bits.

@devmotion (Member) commented Mar 22, 2020

The Model type is not owned by Turing. It's owned by DynamicPPL.

Maybe one could wrap model::Model inside a TuringModel or a ReverseDiffTuringModel or ModelWithTape or ...? Then basically stuff like runmodel! would be forwarded to the wrapped model, but in the gradient computation one would have access to the cache. I was wondering at which stage that wrapping should happen, but maybe one could even have a user-facing cachetape(::Model) function that returns this struct containing the model and the cached tape. This cached model could then be used just like every other Model but would provide access to the tape. Maybe that would make it even clearer than a setcache switch that the tape is cached.

@mohamed82008 (Member, Author) commented Mar 23, 2020

Sure, there are ways around Memoization.jl, but they are not convenient. The approach you describe will probably be a little faster, but it will require far more than the 15 lines of code that implement caching in this PR. Let's think about the tradeoffs and alternatives in a separate PR if you don't mind. I think Memoization is far from being the bottleneck at this point, so I can focus my optimization and coding time elsewhere for now, e.g. custom adjoints. Memoization.jl is also a tiny package with only MacroTools as a dependency (which is also a dependency of DynamicPPL), so it should add very little to Turing's loading time if that's the concern.

@xukai92 (Member) commented Mar 23, 2020

Why do we need to cache the tape somewhere explicitly? Does something like this work: https://github.com/TuringLang/AdvancedHMC.jl/blob/kx/benchmarking/src/contrib/ad.jl#L112-L124 ?

@mohamed82008 (Member, Author):

But where and how often do we call get_∂ℓπ∂θ_reversediff?

@xukai92 (Member) commented Mar 23, 2020

Sorry, it doesn't solve the issue.

@xukai92 (Member) commented Mar 23, 2020

For HMC it can be cached here:

function gen_∂logπ∂θ(vi::VarInfo, spl::Sampler, model)

@mohamed82008 (Member, Author):

I think we need to rethink how we do AD in Turing to enable this kind of caching in a less hacky way than memoization. But I would leave this to another PR.

@mohamed82008 (Member, Author):

Note that this function is called more than once when doing Gibbs sampling:

∂logπ∂θ = gen_∂logπ∂θ(spl.state.vi, spl, model)

@mohamed82008 (Member, Author):

The memoization approach will work with Gibbs sampling provided that no runtime if statements exist, so caching can be opt-in for users who know the risks.

@mohamed82008 (Member, Author):

And no stochastic control flow or dynamically sized variables are allowed either. We need to make the limitations of caching clear to the user.

@xukai92 (Member) commented Mar 23, 2020

Note that this function is called more than once when doing Gibbs sampling

For Gibbs we are supposed to handle the case where the number of non-HMC variables actually changes, so I think the safest way is to rebuild the tape for each HMC iteration. I think it's relatively easy to cache the tape in gen_∂logπ∂θ (no matter how often it's called in Gibbs); at least for pure HMC it is a big gain, and we can see if we can make further improvements later.

@mohamed82008 (Member, Author):

What about Gibbs of HMCs? Given that caching is optional, by default ReverseDiff will work with a variable number of variables by creating the tape every iteration. That's why caching is opt-in.

@xukai92 (Member) commented Mar 23, 2020

OK I see your point here. I think it's fine to finish this PR without caching then.

@mohamed82008 (Member, Author):

Hmm caching is already implemented in the PR though :) And it gives a significant speedup in TuringExamples.

@mohamed82008 (Member, Author) commented Mar 23, 2020

I think depending on Memoization.jl is not a terrible thing in and of itself. Whether this is the best design for the caching mechanism is up for discussion, but imo there is no need to deny ourselves the speedup of caching until we figure out a better design. I say we keep memoization for now and discuss changes in a meeting or so after merging.

@mohamed82008 (Member, Author) commented Mar 23, 2020

So I guess the choices here are to either:

  1. Enable opt-in tape caching for pure HMC only and not depend on Memoization.jl, or
  2. Enable opt-in tape caching for pure HMC and Gibbs but depend on Memoization.jl.

Both will require a similar number of lines of code.

@mohamed82008 (Member, Author):

And option 1 will be faster for tiny models when caching is on for pure HMC.

@mohamed82008 (Member, Author):

But arguably, tiny models should be using ForwardDiff anyway.

@devmotion (Member):

Couldn't we get both without memoization by combining both approaches mentioned above?

If enabled, we would cache the tape in each HMC iteration by computing it in

function gen_∂logπ∂θ(vi::VarInfo, spl::Sampler, model)

Additionally, we provide a function cachetape(::Model) that returns a model with a cached tape:

struct ModelWithCachedTape{M<:Model,C}
	model::M
	cache::C
end

and for which we just implement the required functions in https://github.com/TuringLang/DynamicPPL.jl/blob/8028aaa912354321c33faeb30fafe1f45f637585/src/model.jl#L22-L24 and stuff like bijector by forwarding to model. Then the only thing left to do is to use that cache in

function gen_∂logπ∂θ(vi::VarInfo, spl::Sampler, model)

instead of recomputing it, and to replace most occurrences of ::Model with ::TuringModel, which would be Union{Model,ModelWithCachedTape}.

IMO that would also make it very clear to a user what the differences between the two approaches are, and as the default we would provide the possibly slower but definitely correct implementation.

@mohamed82008 (Member, Author):

Note that generating the tape needs vi, spl and model, not just model, so these would have to be inputs to cachetape. Also, constructing the model in Turing happens independently of the sampler or AD backend choice, so CachedModel should be an implementation detail, preferably hidden away from the user and turned on with a flag, e.g. Turing.setcache. For Gibbs, we would also need to generate a tape for each HMC sub-sampler. Again, all of this is possible, but it doesn't provide great value compared to the current approach imo.

@mohamed82008 (Member, Author):

I think there is a need for more subtypes of AbstractModel that trigger different optimizations, so in principle I am not against your proposal, @devmotion. I just don't think it is necessary right now.

@mohamed82008 (Member, Author):

I think this is good to go.

@devmotion (Member):

Note that generating the tape needs vi, spl and model not just model, so these would have to be inputs to cachetape.

Are model and sampler not sufficient? Is it not always possible to obtain vi from the sampler? One would cache the f that is defined inside gradient_logp as a callable struct together with the tape and just update vi inside gradient_logp.
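
A minimal sketch of the callable-struct idea (field names and the placeholder body are hypothetical, not Turing's gradient_logp internals): the state captured by the closure becomes explicit struct fields, the struct is differentiated like any other callable, and the vi it references can be updated in place before each gradient call.

using ReverseDiff

struct LogDensity{V,S,M}
    vi::V
    spl::S
    model::M
end

# Placeholder body; in Turing the actual evaluation happens inside gradient_logp.
(f::LogDensity)(θ) = -sum(abs2, θ) / 2

# Cache the callable together with its compiled tape; on later iterations only
# the contents of f.vi would need to be refreshed before replaying the tape.
f = LogDensity(nothing, nothing, nothing)   # placeholders for vi, spl, model
tape = ReverseDiff.compile(ReverseDiff.GradientTape(f, zeros(3)))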

Also constructing the model in Turing happens independent of the sampler or AD backend choice, so CachedModel should be an implementation detail, preferably hidden away from the user and turned on with a flag, e.g. Turing.setcache.

I think it would be a major feature if it is not hidden away from the user. Of course, the user should not deal with the type (in the same way as she is not supposed to care about Model), but IMO it would be a big advantage to be able to call cachetape(model, sampler). cachetape would only be available if ReverseDiff is loaded, and one could even provide a warning if sample is called with a CachedModel and another AD backend and then just forward the call to the underlying Model. Compared to the current approach in this PR, IMO this is a more "Julian" way and it allows fine-grained control of the tape for both the user and the developer.

I would go even further and say that one should avoid global switches if possible. As seen, e.g., in the CI tests, a setting in a completely different file might unexpectedly change the AD backend. So actually I would be in favour of making it a keyword argument. That would also allow removing all the code needed to obtain a Val representation from the state that is saved as a Symbol, which IMO feels a bit weird (and is not type stable, of course).
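
A small illustration of the Symbol-versus-Val point (toy backend types, not Turing's): branching on a runtime Symbol yields a return type the compiler cannot infer, whereas dispatching on a Val, or passing the backend object directly, e.g. as a keyword argument, is resolved at compile time.

struct ForwardDiffBackend end
struct TrackerBackend end

# Not type stable: the return type depends on a runtime value, so inference
# only sees a Union of the two backend types.
backend_from_symbol(s::Symbol) = s === :forwarddiff ? ForwardDiffBackend() : TrackerBackend()

# Type stable: each method has a single concrete return type known to the compiler.
backend(::Val{:forwarddiff}) = ForwardDiffBackend()
backend(::Val{:tracker}) = TrackerBackend()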

@xukai92 (Member) commented Mar 23, 2020

Do we have a test of using ReverseDiffAD{true}?

@mohamed82008 (Member, Author):

Do we have a test of using ReverseDiffAD{true}?

Yes, in the variational inference optimizers. But maybe we need more tests.

@cpfiffer (Member) left a review comment

Looks good to me. Excellent work, as per usual.

@mohamed82008 (Member, Author):

If there are no more comments, this is ready to merge.

@yebai merged commit e946541 into master on Mar 24, 2020.
The delete-merged-branch bot deleted the mt/reversediff branch on March 24, 2020 at 14:00.
@yebai (Member) commented Mar 24, 2020

Thanks @mohamed82008!
