
VI & ADVI improvements #902

Merged · 35 commits · Feb 18, 2020

Conversation

@torfjelde (Member) commented Sep 2, 2019

Overview

This PR aims to improve VI, in particular ADVI, by adopting the new Bijectors.jl interface together with Stacked<:Bijector from this PR, introducing a new optimiser DecayedADAGrad, and making general improvements. See the full list below.

It's worth noting that there are still things to improve, but since we will soon separate the VI parts of Turing out into AdvancedVI.jl, I think this is a good-enough implementation for Turing; we can improve it further as we make that separation.

Changes

  1. Added meanfield(model::Model), which returns a mean-field approximation of the model that can in turn be passed to vi(model, alg, q). This accommodates future plans to allow an "arbitrary" variational posterior q.
  2. This branch doesn't perform assignments in the same way we did before, so we can now use TrackedArray instead of Array{TrackedReal} (with the help of a bit of behavior-piracy on Tracker.jl to make it work with SubArray).
  3. Added DecayedADAGrad, which is what Stan currently uses and what is suggested in [2]. TruncatedADAGrad is an implementation of what they proposed in their original paper on ADVI, but looking at the source code it seems they've since moved away from this [1]. DecayedADAGrad seems to do better for more complex models, e.g. in the MNIST example the choice of truncation length for TruncatedADAGrad affects convergence quite a bit (and it's much more memory-intensive than DecayedADAGrad). TruncatedADAGrad works very nicely for smaller models though. Now we can use both :) (A rough sketch of a decayed-AdaGrad-style update is given after this list.)
  4. Using softplus instead of exp when enforcing positivity of the variances of TuringDiagNormal, which prevents numerical errors I've been running into with exp in certain cases. There was a brief discussion about exp vs. softplus on our Slack, and the consensus seemed to be that there's no clear winner in general. Nonetheless, I think it's better to go for the more numerically stable one given that we've already run into numerical issues with exp. It's also worth noting that in [Section 3.3, 2] they actually observe better end results with softplus than with exp. (The usage sketch after this list shows this parameterization in its getq example.)
  5. setadbackend now also sets the ADBackend of Bijectors.jl, which was introduced in the interface PR.
  6. MeanField has been fully replaced by TransformedDistribution{<:TuringDiagNormal, <:Stacked}. This results in several improvements:
    • Flexibility. With this change, the only part of VI / ADVI that is hard-coded is the handling of the parameters of the variational posterior q. We do this because we need to enforce the positivity constraint on the variance of TuringDiagNormal. As soon as we have a way of dealing with constraints on the parameters of the distributions themselves (not the support) in a more generic manner, we'll be able to plug in any q with the correct dimensionality and run ADVI with it instead of a TuringDiagNormal as the base distribution. Even now, if you have a custom base distribution, e.g. a normalizing flow, we can implement vi(model::Model, alg::ADVI, q::TransformedDistribution{<:NormalizingFlow}) and similarly for (elbo::ELBO)(...), and we're good. With the current ADVI we can't do VI on LDA because of singular Jacobians when using AD. Once the PR with analytical expressions for the Jacobian of SimplexBijector is merged in Bijectors.jl, we'll be able to do ADVI on LDA without any changes on the Turing side of things.
    • Type-stability. Though my testing hasn't been extensive, from what I can tell everything but Variational.meanfield(model::Model) is now type-stable (the same as for samplers, etc. in current Turing.jl). The optimize call inside vi(model, advi) acts as a function barrier, so everything other than the first line of vi(model, advi) is type-stable.
    • Performance. All of the above gives us a performance increase. Moreover, since everything involving the actual transformations, etc. is now outsourced to Bijectors.jl, it's much easier to optimize performance in the future, and it gives us analytical expressions for logabsdetjac for all standard distributions. We should also benefit from eventual support for batch computation in Bijectors.jl (see this issue).
    • Less code. The entire src/variational/distributions.jl is now redundant and so has been removed.
  7. Some extra cleanups of redundant code & comments left over from the previous ADVI PR.
  8. Added a weight argument to (elbo::ELBO)(...). This can be used to weight the 𝔼_q[log p(x, z)] term of the ELBO, which is necessary for things like stochastic variational inference (SVI) (see the SVI PR) and more complex initialization, e.g. Eq. (20) in [3]. This weight argument is now baked into make_logjoint, which makes more sense.
  9. vi now has the following definitions (see the usage sketch after this list):
    • vi(model, alg): uses a default variational posterior q which is alg-dependent (in the case of ADVI this is a mean-field approximation with an unconstrained-to-constrained Bijector).
    • vi(model, alg, q, θ_init): allows the user to pass in an arbitrary variational posterior q as long as Variational.update(d::typeof(q), θ::AbstractArray)::typeof(q) is defined.
    • vi(model, alg, getq, θ_init): instead of passing an initial variational posterior q, pass a function getq(θ::AbstractArray)::VariationalPosterior, i.e. a map from parameters to a distribution.
    • In all of the above, we can also replace model with logπ(z::AbstractArray)::Real, which represents the map z → log p(x, z) where x are the observations, making this applicable even to non-Turing models (this will probably become more useful once we separate VI out into AdvancedVI.jl).
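
Re (3), here is a rough, self-contained sketch of what a decayed-AdaGrad-style update does, to make the contrast with plain (truncated) AdaGrad concrete. The type, field names, defaults, and exact decay scheme below are illustrative assumptions and not the PR's DecayedADAGrad implementation:

mutable struct DecayedAdaGradSketch
    η::Float64              # step size
    ρ::Float64              # decay applied to the accumulated squared gradients
    ϵ::Float64              # numerical stabiliser
    acc::Vector{Float64}    # (decayed) running sum of squared gradients
end

DecayedAdaGradSketch(n::Int; η = 0.1, ρ = 0.9, ϵ = 1e-8) =
    DecayedAdaGradSketch(η, ρ, ϵ, zeros(n))

function step!(opt::DecayedAdaGradSketch, θ::AbstractVector, g::AbstractVector)
    # Plain AdaGrad keeps accumulating g.^2, so the effective step size shrinks
    # towards zero on long runs; decaying the accumulator keeps the optimiser
    # responsive, which is what helps on more complex models such as MNIST.
    opt.acc .= opt.ρ .* opt.acc .+ (1 - opt.ρ) .* g .^ 2
    θ .-= opt.η .* g ./ (sqrt.(opt.acc) .+ opt.ϵ)
    return θ
end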
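
And re (4) and (9), a hedged usage sketch of the three vi call styles on a toy model. The model, the dimensions, and the softplus-based getq are illustrative assumptions; only the call signatures come from the list above, and whether a particular custom q works out of the box will depend on the chosen AD backend:

using Turing, Distributions
using Turing: Variational
using LinearAlgebra: Diagonal

# naive softplus for the sketch; a careful version would use log1pexp
softplus(x) = log1p(exp(x))

@model function demo(x)
    m ~ Normal(0, 1)
    x ~ Normal(m, 1)
end

model = demo(1.0)
advi  = ADVI(10, 1_000)   # 10 MC samples per step, 1000 optimisation steps

# (a) default: mean-field Gaussian plus unconstrained-to-constrained bijector
q1 = vi(model, advi)

# (b) user-supplied q and initial parameters; requires
#     Variational.update(d::typeof(q), θ::AbstractArray)::typeof(q)
q0 = Variational.meanfield(model)
q2 = vi(model, advi, q0, zeros(2))

# (c) a map θ ↦ q(θ); the variance is kept positive via softplus, as in (4)
getq(θ) = MvNormal([θ[1]], Diagonal([softplus(θ[2])]))
q3 = vi(model, advi, getq, randn(2))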

Examples

ELBO implementation

Old

# Everything is hard-coded for mean-field Gaussian
elbo_acc = 0.0

for i = 1:num_samples
    # iterate through priors, sample and update
    idx = 0
    z = zeros(T, num_params)
    
    for sym ∈ keys(varinfo.metadata)
        md = varinfo.metadata[sym]
        
        for i = 1:size(md.dists, 1)
            prior = md.dists[i]
            r = md.ranges[i] .+ idx

            # mean-field params for this set of model params
            μ_i = μ[r]
            ω_i = ω[r]

            # obtain samples from mean-field posterior approximation
            η = randn(length(μ_i))
            ζ = center_diag_gaussian_inv(η, μ_i, exp.(ω_i))
            
            # inverse-transform back to domain of original prior
            z[r] .= invlink(prior, ζ)

            # add the log-det-jacobian of inverse transform;
            # `logabsdet` returns `(log(abs(det(M))), sign(det(M)))`, so take the first entry;
            # add `eps` to ensure a SingularException does not occur in `logabsdet`
            elbo_acc += logabsdet(jac_inv_transform(prior, ζ) .+ eps(T))[1] / num_samples
        end

        idx += md.ranges[end][end]
    end
    
    # compute log density
    varinfo = VarInfo(varinfo, SampleFromUniform(), z)
    model(varinfo)
    elbo_acc += varinfo.logp / num_samples
end

# add the term for the entropy of the variational posterior
variational_posterior_entropy = sum(ω)
elbo_acc += variational_posterior_entropy

return elbo_acc

New

function (elbo::ELBO)(
    rng::AbstractRNG,
    alg::ADVI,
    q::VariationalPosterior,
    logπ,  # ← arbitrary callable z → log p(x, z)
    num_samples
)
    _, z, logjac, _ = forward(rng, q)
    res = (logπ(z) + logjac) / num_samples

    res += entropy(q.dist)
    
    for i = 2:num_samples
        _, z, logjac, _ = forward(rng, q)
        res += (logπ(z) + logjac) / num_samples
    end

    return res
end

with

function make_logjoint(model; weight = 1.0)
    # setup
    ctx = DynamicPPL.MiniBatchContext(
        DynamicPPL.DefaultContext(),
        weight
    )
    varinfo = Turing.VarInfo(model, ctx)

    function logπ(z)
        varinfo = VarInfo(varinfo, SampleFromUniform(), z)
        model(varinfo)
        
        return varinfo.logp
    end

    return logπ
end

to make this work nicely with Turing.Model.
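
In symbols (a hedged reading of the new estimator, not text from the PR): with ηᵢ drawn from the base distribution q.dist and zᵢ = f(ηᵢ) the corresponding sample pushed through the bijector f of q,

    ELBO ≈ (1/N) ∑ᵢ [ log p(x, zᵢ) + log|det J_f(ηᵢ)| ] + ℍ[q.dist]

The usual −𝔼_q[log q(z)] term splits into the entropy of the base distribution plus the expected log-Jacobian, which is why entropy(q.dist) and logjac appear as separate terms in the implementation.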
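
A hedged end-to-end sketch of how the pieces might be wired together (the toy model is an illustrative assumption; ELBO, meanfield, and make_logjoint are the names used in this PR, assumed here to be reachable via Turing.Variational):

using Turing, Random
using Turing: Variational

@model function demo(x)
    m ~ Normal(0, 1)
    x ~ Normal(m, 1)
end

model = demo(1.0)

logπ = Variational.make_logjoint(model)   # z ↦ log p(x, z), optionally weighted
q    = Variational.meanfield(model)       # default variational posterior

# Monte-Carlo estimate of the ELBO using 10 samples
elbo  = Variational.ELBO()
value = elbo(Random.GLOBAL_RNG, ADVI(10, 1_000), q, logπ, 10)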

TODOs

  • Wait until the Stacked<:Bijector PR has been merged so we don't have to depend on a GitHub branch.
  • Improve docstrings for the following:
    • TruncatedADAGrad
    • DecayedADAGrad
    • ELBO and (elbo::ELBO)(...)
  • Tests for
    • DecayedADAGrad
    • TruncatedADAGrad

References

[1] Kucukelbir, A., Ranganath, R., Gelman, A., & Blei, D. M. (2015). Automatic Variational Inference in Stan. CoRR.
[2] Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2016). Automatic Differentiation Variational Inference. CoRR.
[3] Rezende, D. J., & Mohamed, S. (2015). Variational Inference with Normalizing Flows. CoRR.

@torfjelde (Member Author):

Now that Stacked is merged into Bijectors.jl, I'll finish this up and then complete the tutorial on ADVI (which depends on this branch).

@xukai92 (Member) commented Oct 27, 2019

What's left to do for this PR?

@torfjelde (Member Author):

> What's left to do for this PR?

There are the two remaining points on the TODO list above, and I'd like to get most of the Bijectors.jl PRs merged first and then look into making this work with arbitrary variational posteriors q. To make that work, we need some way of doing parameter updates for any q.

I'll have a proper look at this soon, then I'll be able to better judge what's left to do.

@yebai (Member) commented Jan 5, 2020

@torfjelde, maybe focus on this PR for the next 2 weeks, when you have time.

@xukai92 (Member) commented Feb 12, 2020

It seems there are some conflicts. Can you fix them?

@torfjelde (Member Author):

On it! I'll also go through it one last time before we merge 👍

@xukai92 (Member) commented Feb 12, 2020

Cool. We plan to do another review tonight.

@@ -4,8 +4,8 @@ using ..Core, ..Utilities
using Distributions, Bijectors, DynamicPPL
using ProgressMeter, LinearAlgebra
using ..Turing: PROGRESS
using DynamicPPL: Model, SampleFromPrior, SampleFromUniform
using ..Turing: Turing
Review comment (Member):

Shall we merge the two lines?

using ..Turing: PROGRESS
using ..Turing: Turing	

@xukai92 (Member) commented Feb 16, 2020

Is it ready to merge?

@@ -68,7 +68,7 @@ end

Creates a mean-field approximation with multivariate normal as underlying distribution.
"""
-function meanfield(model::Model)
+function meanfield(model::Model; rng::AbstractRNG = GLOBAL_RNG)
Review comment (@devmotion, Member):

The common approach (as far as I have seen in Random, Distributions and some other packages) is to allow a random number generator as first argument, i.e., to define

meanfield(model::Model) = meanfield(Random.GLOBAL_RNG, model)
function meanfield(rng::AbstractRNG, model::Model)
    ...
end

I'm wondering if that could/should be done here as well?

Reply (Member Author, @torfjelde):

Yeah, so this is what I do for ELBO (see variational/objectives.jl) and so on. I didn't do it with meanfield because I initially didn't think it necessary. What say you, @xukai92? Should I change from kwarg to this?

Reply (@xukai92, Member, Feb 16, 2020):

Not explicitly, but we currently do both, in two different situations:

  • For internal functions, we stick to the way @devmotion describes above
  • For functions to end-users, we use keyword arguments

Reply (Member):

> For functions to end-users, we use keyword arguments

I'm curious, which functions are you referring to? At least the user-facing sample, psample, and rand only allow a random number generator as first argument but not as keyword argument.

Reply (Member):

IIRC our sample takes a keyword for rng

Reply (Member):

That's not the case; I guess that might have changed at some point. The current API of AbstractMCMC and the implementation of sample in Turing do not take a random number generator as a keyword argument, but only as the first positional argument:

function AbstractMCMC.sample(
    rng::AbstractRNG,
    model::AbstractModel,
    alg::InferenceAlgorithm,
    N::Integer;
    chain_type=Chains,
    kwargs...
)
    return sample(rng, model, Sampler(alg, model), N; progress=PROGRESS[], chain_type=chain_type, kwargs...)
end

function AbstractMCMC.sample(
    model::AbstractModel,
    alg::InferenceAlgorithm,
    N::Integer;
    chain_type=Chains,
    resume_from=nothing,
    kwargs...
)
    if resume_from === nothing
        return sample(model, Sampler(alg, model), N; progress=PROGRESS[], chain_type=chain_type, kwargs...)
    else
        return resume(resume_from, N)
    end
end

Reply (Member):

BTW I just noticed that this implementation is quite inconsistent, since the resume_from keyword is only respected in the second case. I'll change that in the updates for AbstractMCMC 0.4.

Reply (Member):

OK cool. If all interfaces follow this convention now, we should stick to it. Thanks for updating me on this. @torfjelde, let's make it the optional first argument :)

Reply (Member Author, @torfjelde):

Grreaaat! Done!

@xukai92 (Member) commented Feb 17, 2020

I will merge this in 24 hours if no objection.

@yebai yebai changed the title [WIP] VI & ADVI improvements VI & ADVI improvements Feb 18, 2020
@xukai92 xukai92 merged commit 1f71d8d into TuringLang:master Feb 18, 2020
@xukai92 (Member) commented Feb 18, 2020

Merged! Thanks for the great work from @torfjelde and everyone who helped review this PR!
