ReverseDiff support #1170
Conversation
ReverseDiff gives a mysterious error in Julia 1.0 and 1.1, so I removed the ReverseDiff inference tests from these two versions. We should probably drop support for Julia 1.1 (and 1.2?) at some point as well now that Julia 1.4 is out. I think we can keep support for the last two or three Julia minor versions and Julia 1.0.
ReverseDiff performance here is bad because I am building the tape every time. I will try to hack into Memoization.jl to memoize the compiled tape. This means that non-static control flow will not be supported with ReverseDiff, e.g. runtime if statements that depend on the values, or while loops. This is a strong argument against using ReverseDiff by default in Turing. I think it makes sense to enable ReverseDiff but warn the user that only limited control flow is supported, e.g. fixed-size loops and compile-time if statements.
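A minimal sketch of the kind of tape caching described above, assuming the tape can be keyed on the function and the input length (this is an illustration, not the PR's actual code):

```julia
using Memoization, ReverseDiff

# Cache one compiled tape per (function, input length) pair. Reusing the tape
# like this assumes static control flow, as discussed above.
@memoize function compiled_tape(f, inputlength::Int)
    dummy = rand(inputlength)   # dummy input with the right shape for recording
    return ReverseDiff.compile(ReverseDiff.GradientTape(f, dummy))
end

function cached_gradient!(result, f, x::AbstractVector)
    tape = compiled_tape(f, length(x))
    ReverseDiff.gradient!(result, tape, x)   # replays the cached tape at x
    return result
end
```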
I will make the memoization optional. In some of the benchmarks, ReverseDiff is still a good choice of AD even without caching the tape.
Sounds good to support Julia 1.0 and the two most recent versions.
Do we really need |
The |
Some bijectors use if statements to guarantee numerical stability, e.g. not calling log on negative numbers. These need custom ReverseDiff adjoints to avoid compiling only one branch of the if statement. The Dirichlet distribution, for example, doesn't currently work properly with ReverseDiff.
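For illustration, a minimal sketch of such a custom adjoint using ReverseDiff's `track`/`@grad` mechanism (the `safelog` function here is hypothetical, not Bijectors' actual code):

```julia
using ReverseDiff: ReverseDiff, TrackedReal, track, value, @grad

# Hypothetical numerically guarded log; the `max` guard plays the role of the
# if statements mentioned above.
safelog(x::Real) = log(max(x, eps(typeof(float(x)))))

# Record a single custom instruction on the tape instead of tracing through
# the guard, so the guard is re-evaluated when the tape is replayed.
safelog(x::TrackedReal) = track(safelog, x)
@grad function safelog(x)
    xv = value(x)
    return safelog(xv), Δ -> (Δ / max(xv, eps(typeof(xv))),)
end
```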
This is only needed when caching is turned on.
Btw ReverseDiff with caching is giving me some decent speedup in some of the TuringExamples benchmarks. For the remaining ones, I have an idea why they are slow and I will work on them.
The variational inference code needs to be updated to add the caching bits.
Maybe one could wrap |
Sure, there are ways around Memoization.jl, but they are not convenient. The methods you describe will probably be a little bit faster but will probably require way more than the 15 lines of code that implement caching in this PR. Let's think about tradeoffs and alternatives in a separate PR if you don't mind. I think Memoization is far from being the bottleneck at this point, so I can focus my optimization and coding time elsewhere for now, e.g. custom adjoints. Memoization.jl is also a tiny package with only MacroTools as a dependency (which is also a dependency of DynamicPPL), so it should add very little to Turing's loading time if that's the concern.
Why do we need to cache the tape somewhere explicitly? Does something like this work: https://github.com/TuringLang/AdvancedHMC.jl/blob/kx/benchmarking/src/contrib/ad.jl#L112-L124 ?
But where and how often do we call |
Can https://github.com/TuringLang/Turing.jl/blob/master/src/core/ad.jl#L63 serve as |
Sorry, it doesn't solve the issue.
For HMC it can be cached here: Turing.jl/src/inference/hmc.jl (line 399 in ba7556d).
I think we need to rethink how we do AD in Turing to enable this kind of caching in a less hacky way than memoization. But I would leave this to another PR.
Note that this function is called more than once when doing Gibbs sampling: Turing.jl/src/inference/hmc.jl (line 349 in ba7556d).
The memoization approach will work with Gibbs sampling provided that no runtime if statements exist, so caching can be opt-in for users if they know the risks.
And no stochastic control flow or dynamically sized variables are allowed either. We need to make the limitations of caching clear to the user.
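For example, a hypothetical model like the following (written in current Turing syntax, purely as an illustration) could not safely reuse a cached tape:

```julia
using Turing

# Both the loop length and the size of `x` depend on a sampled value, so the
# operations recorded on one tape do not match later iterations.
@model function dynamic_model(y)
    n ~ Poisson(3.0)
    x = Vector{Float64}(undef, n + 1)
    for i in 1:(n + 1)
        x[i] ~ Normal(0, 1)
    end
    y ~ Normal(sum(x), 1.0)
end
```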
For Gibbs we are supposed to handle the case where the number of non-HMC variables actually changes, so I think the safest way is to rebuild the tape for each HMC iteration. I think it's relatively easy to cache the tape in |
What about Gibbs of HMCs? Given that caching is optional, by default ReverseDiff will work with a variable number of variables by creating the tape every iteration. That's why caching is opt-in.
OK I see your point here. I think it's fine to finish this PR without caching then.
Hmm caching is already implemented in the PR though :) And it gives a significant speedup in TuringExamples.
I think depending on Memoization.jl is not a terrible thing to do in and of itself. Whether or not this is the best design for the caching mechanism is up for discussion, but IMO there is no need to deny ourselves the speedup of caching until we figure out a better design. I say we keep memoization for now and discuss changes in a meeting or so after merging.
So I guess the choices here are to either:
Both will require a similar number of lines of code.
And option 1 will be faster for tiny models when caching is on for pure HMC.
But arguably, tiny models should be using ForwardDiff anyway.
Couldn't we get both without memoization by combining both approaches mentioned above? If enabled, we would cache the tape in each HMC iteration by computing it in Turing.jl/src/inference/hmc.jl (line 399 in ba7556d), e.g. with a `cachetape(::Model)` that returns a model with a cached tape,

```julia
struct ModelWithCachedTape{M<:Model,C}
    model::M
    cache::C
end
```

and for which we just implement the required functions in https://github.com/TuringLang/DynamicPPL.jl/blob/8028aaa912354321c33faeb30fafe1f45f637585/src/model.jl#L22-L24. In stuff like Turing.jl/src/inference/hmc.jl (line 399 in ba7556d) we would replace `::Model` with `::TuringModel`, which would be `Union{Model,ModelWithCachedTape}`.

That would also make it very clear to a user what the differences between both approaches are IMO, and as the default we would provide the possibly slower but definitely correct implementation.
Note that generating the tape needs |
I think there is a need to have more subtypes of |
I think this is good to go.
Are
I think it would be a major feature if it is not hidden away from the user. Of course, the user should not deal with the type (in the same way as she is not supposed to care about such internals).

I would go even further and say that one should avoid global switches if possible. As seen, e.g., in the CI tests, a setting in a completely different file might unexpectedly change the AD backend. So actually I would be in favour of making it a keyword argument. That would also allow removing all the code that is needed to obtain a Val representation from the state that is saved as a Symbol, which IMO feels a bit weird (and is not type stable, of course).
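A toy sketch of the type-stability point (the names here are illustrative, not Turing's actual internals):

```julia
# A global Symbol switch vs. passing the backend explicitly.
abstract type ADBackend end
struct ForwardDiffAD <: ADBackend end
struct TrackerAD <: ADBackend end

mygradient(::ForwardDiffAD, f, x) = :forwarddiff_result   # placeholder bodies
mygradient(::TrackerAD, f, x) = :tracker_result

# Backend stored as a Symbol in a global: the concrete backend type is only
# known at runtime, so the compiler cannot infer which method is called.
const ADBACKEND = Ref(:forwarddiff)
backend_from_symbol(s::Symbol) = s === :forwarddiff ? ForwardDiffAD() : TrackerAD()
grad_global(f, x) = mygradient(backend_from_symbol(ADBACKEND[]), f, x)

# Backend passed as a keyword argument: its concrete type is known at the call
# site, so dispatch on `mygradient` is resolved statically.
grad_kwarg(f, x; adbackend::ADBackend = ForwardDiffAD()) = mygradient(adbackend, f, x)
```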
Do we have tests of using |
Yes, in the variational inference optimizers. But maybe we need more tests.
Looks good to me. Excellent work, as per usual.
If there are no more comments, this is ready to merge.
Thanks @mohamed82008!
This PR adds ReverseDiff support. The most important change here is that in order to use Tracker for AD, the user would need to do `Turing.setadbackend(:tracker)` instead of `:reverse_diff` after this PR is merged.

ReverseDiff would be optional after this PR, much like Zygote, so to use it the user needs to load ReverseDiff and switch the AD backend explicitly (see the sketch below).

This PR requires a new release of Bijectors, which I already made, but the TagBot is taking a while to make the release and tag in the Bijectors repo. The General PR was already merged a while back. Let's wait for the TagBot and then re-run the tests here.
Edit: Bijectors was released.
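A minimal sketch of the opt-in usage described above, assuming the backend symbol is `:reversediff` (as in the current Turing documentation; the exact commands are not quoted in this PR text):

```julia
using Turing
using ReverseDiff   # ReverseDiff is optional, so it must be loaded explicitly

# Switch the AD backend from the default to ReverseDiff.
Turing.setadbackend(:reversediff)
```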