# Black-box SVI example: standard gradient (adam) vs. natural gradient (sgd).
from __future__ import absolute_import
from __future__ import print_function
import matplotlib.pyplot as plt

import autograd.numpy as np
import autograd.scipy.stats.norm as norm

from autograd.misc.optimizers import adam, sgd

# same BBSVI function!
from black_box_svi import black_box_variational_inference

if __name__ == '__main__':

    # Specify an inference problem by its unnormalized log-density.
    # The benefit is difficult to see in low dimensions, so the problem below
    # uses `obs_dim` observed dimensions.
    # The model parameters are a mean and a log_sigma.
    # Fix the seed so the synthetic observations are reproducible.
    np.random.seed(42)
    obs_dim = 20
    # Draw a random square matrix first, then a random vector (same draw
    # order as a chained call), and use their product as the observations.
    rand_matrix = np.random.randn(obs_dim, obs_dim)
    rand_vector = np.random.randn(obs_dim)
    Y = np.dot(rand_matrix, rand_vector)
    def log_density(x, t):
        """Unnormalized log-density of the model, evaluated row by row.

        The first `obs_dim` columns of `x` are read as means and the columns
        from `obs_dim` onward as log standard deviations.  `t` is not used
        in this function.

        Returns, for each row of `x`, the sum of
          * the Normal(0, 1.35) log-pdf of every log_sigma entry, and
          * the Normal(mu, exp(log_sigma)) log-pdf of the observations `Y`.
        """
        means = x[:, :obs_dim]
        log_stds = x[:, obs_dim:]
        prior_terms = norm.logpdf(log_stds, 0, 1.35)
        likelihood_terms = norm.logpdf(Y, means, np.exp(log_stds))
        return np.sum(prior_terms, axis=1) + np.sum(likelihood_terms, axis=1)

    # Build the variational objective.
    # The posterior has two entries (mean, log_sigma) per observed dimension.
    D = 2 * obs_dim
    objective, gradient, unpack_params = black_box_variational_inference(
        log_density, D, num_samples=2000)

    # Define the natural gradient
    #   The natural gradient of the ELBO is the gradient of the elbo,
    #   preconditioned by the inverse Fisher Information Matrix.  The Fisher,
    #   in the case of a diagonal gaussian, is a diagonal matrix that is a
    #   simple function of the variance.  Intuitively, statistical distance
    #   created by perturbing the mean of an independent Gaussian is
    #   determined by how wide the distribution is along that dimension ---
    #   the wider the distribution, the less sensitive statistical distance is
    #   to perturbations of the mean; the narrower the distribution, the more
    #   the statistical distance changes when you perturb the mean (imagine
    #   an extremely narrow Gaussian --- basically a spike.  The KL between
    #   this Gaussian and a Gaussian $\epsilon$ away in location can be big ---
    #   moving the Gaussian could significantly reduce overlap in support
    #   which corresponds to a greater statistical distance).
    #
    #   When we want to move in directions of steepest ascent, we multiply by
    #   the inverse fisher --- that way we make quicker progress when the
    #   variance is wide, and we scale down our step size when the variance
    #   is small (which leads to more robust/less chaotic ascent).
    def fisher_diag(lam):
        """Return the diagonal of the Fisher information at parameters `lam`.

        `lam` is split by `unpack_params` into a mean part and a log_sigma
        part.  The returned vector concatenates exp(-2 * log_sigma) (the
        entries paired with the mean) with a constant 2 for every log_sigma
        entry.
        """
        _, log_std = unpack_params(lam)
        mean_entries = np.exp(-2. * log_std)
        log_std_entries = 2 * np.ones(len(log_std))
        return np.concatenate([mean_entries, log_std_entries])

    # simple! basically free!
    def natural_gradient(lam, i):
        """Return `gradient(lam, i)` scaled elementwise by 1 / fisher_diag(lam)."""
        inverse_fisher = 1. / fisher_diag(lam)
        return inverse_fisher * gradient(lam, i)

    # function for keeping track of callback ELBO values (for plotting below)
    def optimize_and_lls(optfun):
        """Run one optimizer and return the ELBO value seen at each callback.

        Parameters
        ----------
        optfun : callable
            Called as ``optfun(num_iters, init_params, callback)``.  It is
            expected to invoke ``callback(params, t, g)`` during optimization.

        Returns
        -------
        An array with one recorded ELBO value per callback invocation.
        """
        num_iters = 200
        elbos = []

        def callback(params, t, g):
            # Record the negated objective as the ELBO for this iteration.
            elbo_val = -objective(params, t)
            elbos.append(elbo_val)
            if t % 50 == 0:
                print("Iteration {} lower bound {}".format(t, elbo_val))

        # Start from mean = -1 and log-std = -5 in every one of the D dimensions.
        init_mean = -1 * np.ones(D)
        init_log_std = -5 * np.ones(D)
        init_var_params = np.concatenate([init_mean, init_log_std])

        # Only the ELBO trace collected by the callback is needed; the
        # optimizer's return value is deliberately not kept.
        optfun(num_iters, init_var_params, callback)
        return np.array(elbos)

    # Run both optimizers once per step size and keep the paired ELBO traces.
    elbo_lists = []
    step_sizes = [.1, .25, .5]
    for step_size in step_sizes:
        # Baseline: the standard gradient driven by adam.
        def optfun(n, init, cb):
            return adam(gradient, init, step_size=step_size,
                        num_iters=n, callback=cb)
        standard_lls = optimize_and_lls(optfun)

        # Natural gradient driven by sgd with a near-zero momentum (mass=.001).
        def optnat(n, init, cb):
            return sgd(natural_gradient, init, step_size=step_size,
                       num_iters=n, callback=cb, mass=.001)
        natural_lls = optimize_and_lls(optnat)

        elbo_lists.append((standard_lls, natural_lls))

    # Visually compare the ELBO traces: dashed lines for adam on the standard
    # gradient, solid lines for sgd on the natural gradient, one colour per
    # step size.
    plt.figure(figsize=(12, 8))
    colors = ['b', 'k', 'g']
    for colour, size, (adam_trace, nat_trace) in zip(colors, step_sizes, elbo_lists):
        adam_iters = np.arange(len(adam_trace))
        plt.plot(adam_iters, adam_trace, '--',
                 label="standard (adam, step-size = %2.2f)"%size,
                 alpha=.5, c=colour)
        nat_iters = np.arange(len(nat_trace))
        plt.plot(nat_iters, nat_trace, '-',
                 label="natural (sgd, step-size = %2.2f)"%size, c=colour)

    # Zoom the y-axis using the natural-gradient trace of the last step size:
    # from 10% of its range below its maximum up to 10 above its maximum.
    best = natural_lls.max()
    llrange = best - natural_lls.min()
    plt.ylim((best - llrange*.1, best + 10))
    plt.xlabel("optimization iteration")
    plt.ylabel("ELBO")
    plt.legend(loc='lower right')
    plt.title("%d dimensional posterior"%D)
    plt.show()