VI & ADVI improvements #902
Conversation
Now that …
What's left to do for this PR?
There are the two remaining points on the TODO above, and I'd like to get most of the Bijectors.jl PRs merged first, and then look into making this work with arbitrary variational posteriors. I'll have a proper look at this soon; then I'll be able to better judge what's left to do.
@torfjelde, maybe focus on this PR for the next 2 weeks, when you've got time.
It seems there are some conflicts. Can you fix them?
On it! I'll also go through it one last time before we merge 👍
Cool. We plan to do another review tonight.
@@ -4,8 +4,8 @@ using ..Core, ..Utilities
using Distributions, Bijectors, DynamicPPL
using ProgressMeter, LinearAlgebra
using ..Turing: PROGRESS
using DynamicPPL: Model, SampleFromPrior, SampleFromUniform
using ..Turing: Turing
Shall we merge the two lines

using ..Turing: PROGRESS
using ..Turing: Turing

into one?
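Presumably the merged import would look something like this (my reading of the suggestion, not a line from the diff):

```julia
using ..Turing: Turing, PROGRESS
```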
Is it ready to merge?
src/variational/advi.jl (Outdated)

@@ -68,7 +68,7 @@ end
 """
 Creates a mean-field approximation with multivariate normal as underlying distribution.
 """
-function meanfield(model::Model)
+function meanfield(model::Model; rng::AbstractRNG = GLOBAL_RNG)
The common approach (as far as I have seen in Random, Distributions, and some other packages) is to allow a random number generator as the first argument, i.e., to define
meanfield(model::Model) = meanfield(Random.GLOBAL_RNG, model)
function meanfield(rng::AbstractRNG, model::Model)
...
end
I'm wondering if that could/should be done here as well?
Yeah, so this is what I do for `ELBO` (see `variational/objectives.jl`) and so on. I didn't do it with `meanfield` because I initially didn't think it necessary. What say you, @xukai92? Should I change from kwarg to this?
Not explicitly, but we currently do both, for two different situations:
- For internal functions, we stick to the way @devmotion describes above.
- For functions exposed to end-users, we use keyword arguments.
> For functions to end-users, we use keyword arguments

I'm curious, which functions are you referring to? At least the user-facing `sample`, `psample`, and `rand` only allow a random number generator as the first argument, not as a keyword argument.
IIRC our `sample` takes a keyword for `rng`.
That's not the case; I guess that might have changed at some point. See the current API of AbstractMCMC and the implementation of `sample` in Turing (Turing.jl/src/inference/Inference.jl, lines 143 to 167 in 9dc8a29):
```julia
function AbstractMCMC.sample(
    rng::AbstractRNG,
    model::AbstractModel,
    alg::InferenceAlgorithm,
    N::Integer;
    chain_type=Chains,
    kwargs...
)
    return sample(rng, model, Sampler(alg, model), N; progress=PROGRESS[], chain_type=chain_type, kwargs...)
end

function AbstractMCMC.sample(
    model::AbstractModel,
    alg::InferenceAlgorithm,
    N::Integer;
    chain_type=Chains,
    resume_from=nothing,
    kwargs...
)
    if resume_from === nothing
        return sample(model, Sampler(alg, model), N; progress=PROGRESS[], chain_type=chain_type, kwargs...)
    else
        return resume(resume_from, N)
    end
end
```
BTW, I just noticed that this implementation is quite inconsistent, since the `resume_from` keyword is only respected in the second case. I'll change that in the updates for AbstractMCMC 0.4.
OK, cool. If all interfaces follow this convention now, we should stick to it. Thanks for updating me on this. @torfjelde, let's make it the optional first argument :)
Grreaaat! Done!
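For reference, a minimal sketch of the convention as adopted (the signatures follow the discussion above; the body is elided and purely illustrative):

```julia
using Random: AbstractRNG, GLOBAL_RNG
using DynamicPPL: Model

# RNG as an optional first positional argument rather than a keyword.
meanfield(model::Model) = meanfield(GLOBAL_RNG, model)
function meanfield(rng::AbstractRNG, model::Model)
    # ... construct the mean-field approximation, drawing any initial
    # randomness from `rng` ...
end
```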
I will merge this in 24 hours if no objection.
Merged! Thanks for the great work from @torfjelde and everyone who helped review this PR!
Overview
This PR aims to improve VI, in particular ADVI, by adopting the new Bijectors.jl interface together with `Stacked <: Bijector` from this PR, introducing a new optimiser, `DecayedADAGrad`, and making general improvements. See the full list below.

Worth noting that there are things to do to improve this further, but since we will soon separate out the VI parts of Turing into AdvancedVI.jl, I think this is a good-enough implementation for Turing; we can then improve further as we make the separation.
Changes
- Added `meanfield(model::Model)`, which returns a mean-field approximation of the model that can in turn be passed to `vi(model, alg, q)`. This accommodates future plans of allowing "arbitrary" variational posteriors `q`.
- Uses `TrackedArray` instead of `Array{TrackedReal}` (with the help of a bit of behaviour-piracy from Tracker.jl to make it work with `SubArray`).
- Added a new optimiser, `DecayedADAGrad`, which is currently what's in use in Stan and suggested in [2] (an AdaGrad variant whose squared-gradient accumulator decays, roughly `s_k = α * g_k^2 + (1 - α) * s_{k-1}`, rather than growing without bound). `TruncatedADAGrad` is an implementation of what they proposed in their paper on ADVI [1], but looking at the source code it seems they've moved away from this. `DecayedADAGrad` seems to do better for more complex models; e.g. in an MNIST example, it seems like choosing the truncation length for `TruncatedADAGrad` affects convergence quite a bit (plus it's much more memory-intensive than `DecayedADAGrad`). `TruncatedADAGrad` works very nicely for smaller models, though. Now we can use both :)
- Uses `softplus` (i.e. `x ↦ log(1 + exp(x))`) instead of `exp` when enforcing positivity of the variances of `TuringDiagNormal`, which prevents numerical errors I've been running into with `exp` in certain cases. There was a brief discussion about `exp` vs. `softplus` on our Slack, and the consensus seemed to be that there's no clear winner in general. Nonetheless, I think it's better to go for the more numerically stable one, given that we've already run into numerical issues with `exp`. Also worth noting that in [Section 3.3, 2] they actually observe better end results with `softplus` than with `exp`.
- `setadbackend` now also sets the `ADBackend` of Bijectors.jl, which was introduced in the interface PR.
- `MeanField` has been fully replaced by `TransformedDistribution{<:TuringDiagNormal, <:Stacked}`. This results in several improvements:
  - For now we still only support a mean-field `q`. This we do because we need to enforce the positivity constraint on the variance of `TuringDiagNormal`. As soon as we have a way of dealing with constraints on the parameters of the distributions themselves (not the support) in a more generic manner, we'll be able to plug in any `q` with the correct dimensionality and run ADVI using it instead of a `TuringDiagNormal` as a base distribution. Even now, if you have a custom base distribution, e.g. a normalizing flow, we can implement `vi(model::Model, alg::ADVI, q::TransformedDistribution{<:NormalizingFlow})`, and similarly for `(elbo::ELBO)(...)`, and we're good.
  - With the current `advi` we can't do VI on LDA because of singular Jacobians when using AD. When the PR with analytical expressions for the Jacobian of `SimplexBijector` is merged in Bijectors.jl, we'll be able to do ADVI on LDA without any changes on the Turing side of things.
  - `Variational.meanfield(model::Model)` is now type-stable (same as for samplers, etc. in current Turing.jl). The `optimize` call inside `vi(model, advi)` works as a function barrier, thus everything other than the first line of `vi(model, advi)` is type-stable.
  - Makes use of `logabsdetjac` for all standard distributions. Should also benefit from eventual support for batch computation in Bijectors.jl (see this issue).
  - `src/variational/distributions.jl` is now redundant and so has been removed.
- Added a `weight` argument to `(elbo::ELBO)(...)`. This can be used to weight the `𝔼_q[log p(x, z)]` term of the ELBO, which is necessary for stuff like stochastic variational inference (SVI) (see the SVI PR) and more complex initialization, e.g. Eq. (20) in [3]. The `weight` argument is now baked into `make_logjoint`, which makes more sense. (A sketch of such a weighted ELBO estimator appears under Examples below.)
- The methods of `vi` now have the following definitions (see the usage sketch in the Examples section below):
  - `vi(model, alg)`: uses a default variational posterior `q`, which is `alg`-dependent (in the case of `ADVI` this is a mean-field approximation with an unconstrained-to-constrained `Bijector`).
  - `vi(model, alg, q, θ_init)`: allows the user to pass in an arbitrary variational posterior `q`, as long as `Variational.update(d::typeof(q), θ::AbstractArray)::typeof(q)` is defined.
  - `vi(model, alg, getq, θ_init)`: instead of passing an initial variational posterior `q`, pass a function `getq(θ::AbstractArray)::VariationalPosterior`, i.e. a map from parameters to a distribution.
  - In all of the above, `model` can be replaced by a function `logπ(z::AbstractArray)::Real` representing the map `z ↦ log p(x, z)`, where `x` are the observations, making this applicable even to non-Turing models (will probably be more useful once we separate out VI into AdvancedVI.jl).

Examples
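A quick usage sketch of the `vi` entry points listed above. The model and all names here are illustrative only, not taken from the PR:

```julia
using Turing
using Turing: Variational

# A toy model for illustration.
@model function gdemo(x)
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s))
    for i in eachindex(x)
        x[i] ~ Normal(m, sqrt(s))
    end
end

model = gdemo([1.5, 2.0])
alg = ADVI(10, 1000)  # 10 samples per gradient estimate, 1000 optimisation steps

# (1) Default variational posterior (mean-field for ADVI).
q1 = vi(model, alg)

# (2) User-supplied posterior plus initial parameters:
#     2 means + 2 unconstrained scales for the latents (s, m).
q_init = Variational.meanfield(model)
θ_init = randn(4)
q2 = vi(model, alg, q_init, θ_init)

# (3) A function mapping parameters to a distribution.
getq(θ) = Variational.update(q_init, θ)
q3 = vi(model, alg, getq, θ_init)
```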
ELBO implementation

Old and new implementations of the ELBO, together with a wrapper to make this work nicely with `Turing.Model`.
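The before/after code itself isn't reproduced here, but the following is a minimal sketch of a Monte Carlo ELBO estimator of the kind described above, including the `weight` argument on the `𝔼_q[log p(x, z)]` term; function and argument names are illustrative, not the PR's API:

```julia
using Distributions
using Random: AbstractRNG, GLOBAL_RNG

# Monte Carlo estimate of ELBO(q) = 𝔼_q[weight * log p(x, z) - log q(z)],
# where `logπ(z)` computes log p(x, z) for fixed observations `x`.
function elbo_estimate(
    rng::AbstractRNG, logπ, q::Distribution, num_samples::Int;
    weight = 1.0,
)
    res = 0.0
    for _ in 1:num_samples
        z = rand(rng, q)                        # z ~ q
        res += weight * logπ(z) - logpdf(q, z)  # weighted joint minus log q(z)
    end
    return res / num_samples
end

# RNG as an optional first positional argument, matching the convention
# settled on in the review discussion above.
elbo_estimate(logπ, q, num_samples; kwargs...) =
    elbo_estimate(GLOBAL_RNG, logπ, q, num_samples; kwargs...)
```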
TODOs
- `Stacked <: Bijector` PR has been merged, so we don't have to depend on a GitHub branch.
- `TruncatedADAGrad`, `DecayedADAGrad`, `ELBO` and `(elbo::ELBO)(...)`
- `DecayedADAGrad`, `TruncatedADAGrad`
References
[1] Kucukelbir, A., Ranganath, R., Gelman, A., & Blei, D. M. (2015). Automatic Variational Inference in Stan. CoRR.
[2] Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2016). Automatic Differentiation Variational Inference. CoRR.
[3] Rezende, D. J., & Mohamed, S. (2015). Variational Inference with Normalizing Flows. CoRR.