Variational Inference

We have some data from a population and we suspect that it is generated by some underlying process. Estimating the process that generates the data allows us to understand its fundamental properties.

Concretely, let $p(x)$ be the distribution of the data and $z$ its latent variables; the process which generates the data is $p(x|z)$. Estimating the generation process amounts to computing the true posterior $p(z|x)$. This is the (posterior) inference problem.

From Bayes' rule, we have $$p(z|x) = \frac{p(x,z)}{p(x)}$$ with $p(x) = \int_z p(x,z)$. But this is often intractable, as the quantities on the RHS are non-trivial to compute: the evidence $p(x)$ requires integrating over the entire latent space.

Variational inference is a technique which converts this estimation problem into an optimisation problem by approximating the posterior $p(z|x)$ with a family of simpler distributions $q_v(z)$. The best approximation is found by minimising the divergence between $q$ and $p$.

$$ \min_v D_{KL}\big(q_v(z) \,\|\, p(z|x)\big) $$

$q_v$ is a family of distributions and a specific member is selected by $v$. In this context, the $v$’s are called variational parameters.

There are many different divergence measures, but we are going to stick with the KL divergence as it has a straightforward definition. Note that it is asymmetric: $D_{KL}(p \,\|\, q) \neq D_{KL}(q \,\|\, p)$ in general.

$$D_{KL}(p \,\|\, q) = \int_{\Omega} p \log\frac{p}{q}$$
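As a quick numerical check of the asymmetry, here is a minimal sketch (assuming JAX is installed; the two Gaussians are arbitrary choices) that estimates both directions of the divergence by Monte Carlo:

```python
# Monte Carlo check that KL divergence is asymmetric, using two
# univariate Gaussians. Assumes JAX is available.
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

def kl_mc(key, loc_a, scale_a, loc_b, scale_b, n=100_000):
    """Monte Carlo estimate of D_KL(a || b) = E_a[log a - log b]."""
    x = loc_a + scale_a * jax.random.normal(key, (n,))
    return jnp.mean(norm.logpdf(x, loc_a, scale_a) - norm.logpdf(x, loc_b, scale_b))

key = jax.random.PRNGKey(0)
# p = N(0, 1), q = N(1, 2): the two directions disagree.
print(kl_mc(key, 0.0, 1.0, 1.0, 2.0))  # D_KL(p || q), about 0.44
print(kl_mc(key, 1.0, 2.0, 0.0, 1.0))  # D_KL(q || p), about 1.31
```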

Although this sounds promising and simple in theory, in practice we run into difficulties. Optimising the KL divergence directly is hard: it involves the very posterior $p(z|x)$ we cannot compute, and its gradient is only readily available for simple distributions while $p$ can be very complex. This would force us to use a very simple $q$ and accept a bad approximation.

Let's see if we can optimise the KL divergence indirectly.

Consider this: $\log p(x)$ is called the evidence (or log-evidence) of $x$ and, for a fixed dataset, it is a constant. Since $\int_z q_v(z)=1$ we can write

$$\log p(x) = \log p(x) \int_z q_v(z)$$

Since $\log p(x)$ does not depend on $z$, it can be moved inside the integral:

$$\log p(x) = \int_z q_v(z) \log p(x)$$

Substituting $p(x) = \frac{p(x,z)}{p(z|x)}$ inside the logarithm and multiplying and dividing by $q_v(z)$, the integral splits into two terms:

$$ \log p(x) =\int_z q_v(z) \log \frac{p(x,z)}{q_v(z)} + \int_z q_v(z) \log \frac{q_v(z)}{p(z|x)}$$

The second integral is exactly $D_{KL}(q_v(z) \,\|\, p(z|x))$. Calling the first term $L(q,p)$, we have the identity

$$\log p(x) = L(q,p) + D_{KL}\big(q_v(z) \,\|\, p(z|x)\big)$$

$\log p(x)$ is constant and the KL divergence is always $\geq 0$. $L(q,p)$ is therefore a lower bound on the log-evidence. It is hence called the Evidence Lower Bound, or ELBO.

Since the left-hand side is fixed, maximising the ELBO necessarily decreases the KL divergence, making $q$ a better approximation of $p$. We now have an alternative way to find the best approximation.

Evidence Lower Bound

One easy way to understand the effect of optimising the evidence lower bound is to view it as a combination of two entropy terms. $$L(q,p) = H(q) - H(q,p)$$ While maximising $L$, the entropy term pushes $q$ to spread out everywhere, while the negative cross-entropy term pushes $q$ to concentrate on regions where $p$ has high density. However, to use it as a training objective, we need its gradient.
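Before turning to gradients, note that this decomposition follows directly from the definitions $H(q) = -\mathbb{E}_{q_v(z)}[\log q_v(z)]$ and $H(q,p) = -\mathbb{E}_{q_v(z)}[\log p(x,z)]$:

$$L(q,p) = \mathbb{E}_{q_v(z)}[\log p(x,z)] - \mathbb{E}_{q_v(z)}[\log q_v(z)] = H(q) - H(q,p)$$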

Black Box Variational Inference

The ELBO has the following equivalent form.

$$L(q,p) = \mathbb{E}_{q_v(z)} \left [ \log p(x,z) - \log q_v(z) \right ]$$

Ranganath et al. (2014) give the following form of the gradient. $$\nabla_v L(v) = \mathbb{E}_{q_v(z)} \left [ \nabla_v \log q_v(z)\, (\log p(x,z) - \log q_v(z) )\right ]$$

The gradient of the log of a probability distribution is called the score function, and this form of the gradient is known as the score gradient. Since the gradient is expressed as an expectation, we can draw Monte Carlo samples to approximate it.

Also note that the gradient assumes no knowledge of the model except that we can evaluate the quantities in the equation. More or less, a Black Box Variational Inference.

Unfortunately, this approximation of the gradient has high variance. Fortunately, Ranganath et al. (2014) also describe a few variance reduction techniques.
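To make this concrete, here is a minimal sketch of the score-gradient estimator on a toy conjugate model. The model (a standard normal prior with a unit-variance Gaussian likelihood) and names like `log_joint` are illustrative assumptions, not from the paper:

```python
# Score-gradient (BBVI) sketch for a toy model: z ~ N(0, 1),
# x | z ~ N(z, 1), and a Gaussian variational family q_v with
# variational parameters v = (mu, log_sigma).
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

x_obs = 2.0  # a single observed data point

def log_joint(z):
    # log p(x, z) = log p(z) + log p(x | z)
    return norm.logpdf(z, 0.0, 1.0) + norm.logpdf(x_obs, z, 1.0)

def log_q(v, z):
    mu, log_sigma = v
    return norm.logpdf(z, mu, jnp.exp(log_sigma))

def score_gradient(key, v, n=1000):
    # Monte Carlo estimate of
    # E_q[ grad_v log q_v(z) * (log p(x, z) - log q_v(z)) ]
    mu, log_sigma = v
    z = mu + jnp.exp(log_sigma) * jax.random.normal(key, (n,))
    score = jax.vmap(jax.grad(log_q), in_axes=(None, 0))(v, z)
    weight = log_joint(z) - log_q(v, z)
    return jnp.mean(score * weight[:, None], axis=0)

v = jnp.array([0.0, 0.0])  # initial (mu, log_sigma)
print(score_gradient(jax.random.PRNGKey(0), v))
```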

Reparameterisation gradient

The score gradient allows us to use complex distributions to approximate the posterior. But it is difficult to sample from complex distributions. The reparameterisation trick allows us to create complex distributions from simple ones.

Say we are able to write $z = t(\epsilon, v)$ with $\epsilon \sim s(\epsilon)$, a simple distribution which we can sample from. For example, a Gaussian $z \sim \mathcal{N}(\mu, \sigma^2)$ can be written as $z = \mu + \sigma\epsilon$ with $\epsilon \sim \mathcal{N}(0,1)$. What we have done is confine all the “randomness” to $s(\epsilon)$ and make the transformation from $\epsilon$ to $z$ deterministic.

With this, we have a simpler gradient for the ELBO.

$$\nabla_v L(v) = \mathbb{E}_{s(\epsilon)}\big[ \nabla_z \left[ \log p(x,z) - \log q_v(z) \right] \nabla_v t(\epsilon, v) \big]$$

To compute an approximate gradient, draw samples $\epsilon \sim s(\epsilon)$, compute $z = t(\epsilon, v)$, and then evaluate the reparameterisation gradient. $\log p(x,z) - \log q_v(z)$ is the model term, and $\nabla_z(\cdots)$ can be evaluated using automatic differentiation. See Ruiz et al. (2016) for more discussion on the reparameterisation gradient.
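Here is the same toy model with a reparameterisation-gradient sketch. Note that autodiff below differentiates the full expression, including the direct dependence of $\log q_v(z)$ on $v$; that extra term has zero expectation, so the estimator agrees with the formula above:

```python
# Reparameterisation-gradient sketch for the same toy model:
# z = t(eps, v) = mu + exp(log_sigma) * eps, with eps ~ N(0, 1).
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

x_obs = 2.0

def log_joint(z):
    return norm.logpdf(z, 0.0, 1.0) + norm.logpdf(x_obs, z, 1.0)

def elbo_sample(v, eps):
    mu, log_sigma = v
    z = mu + jnp.exp(log_sigma) * eps  # z = t(eps, v)
    return log_joint(z) - norm.logpdf(z, mu, jnp.exp(log_sigma))

def reparam_gradient(key, v, n=1000):
    # Autodiff through t(eps, v); the direct-dependence term of
    # log q on v has zero expectation, so this matches the formula.
    eps = jax.random.normal(key, (n,))
    grads = jax.vmap(jax.grad(elbo_sample), in_axes=(None, 0))(v, eps)
    return jnp.mean(grads, axis=0)

# Gradient ascent on the ELBO; the exact posterior here is N(1.0, 0.5).
v = jnp.array([0.0, 0.0])
for i in range(500):
    v = v + 0.05 * reparam_gradient(jax.random.PRNGKey(i), v, n=100)
print(v[0], jnp.exp(v[1]))  # should approach 1.0 and sqrt(0.5) ~ 0.707
```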

Amortised Variational Inference

Variational inference as described so far is still not scalable: we have to fit separate variational parameters for each observation, essentially minimising the KL divergence between each $q_{v_i}(z)$ and $p(z|x_i)$. Fitting per-observation parameters is expensive and can lead to overfitting.

We can get around this by constraining the variational parameters to be a function of the observations; a learnable function.

$$v = g_\phi(x)$$

In amortised inference, the variational parameters are the output of a network which takes the observations as input. The parameters $\phi$ of this network can be learned by optimising the reparameterisation gradient.

$$\nabla_\phi L(\phi) = \mathbb{E}_{s(\epsilon)}\big[ \nabla_z \left[ \log p(x,z) - \log q_v(z) \right] \nabla_\phi t(\epsilon, g_\phi(x)) \big]$$
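A minimal sketch of the idea, with a toy affine $g_\phi$ standing in for a real network (hypothetical, for illustration):

```python
# Amortisation sketch: one learnable map g_phi replaces per-observation
# variational parameters. The affine g_phi here is a toy stand-in.
import jax.numpy as jnp

def g_phi(phi, x):
    # v = g_phi(x): the same weights phi are shared by all observations.
    W, b = phi
    return W @ jnp.atleast_1d(x) + b  # returns v = (mu, log_sigma)

phi = (jnp.zeros((2, 1)), jnp.zeros(2))
v = g_phi(phi, 2.0)  # variational parameters for the observation x = 2.0
```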

Variational Autoencoders

A variational autoencoder is a generative model which learns to generate data from its true distribution. It has an architecture similar to that of a denoising autoencoder (Vincent et al. 2010) and uses variational inference to learn the distribution.

Variational autoencoders constrain the approximate posterior to be a Gaussian distribution with diagonal covariance. As a result, the dimensions of the latent representation are uncorrelated.

The encoder is a differentiable network used to approximate the posterior distribution $p(z|x)$. It is trained to predict the parameters of the approximating distribution from a data point.

$$\mu, \sigma = g_{\phi}(x)$$

then the posterior is approximated by the distribution $$ q_{\mu, \sigma}(z)= \mathcal{N}(z; \mu, \operatorname{diag}(\sigma^2))$$

The latent vector is obtained by sampling from $q_{\mu, \sigma}(z)$.

The decoder is also a differentiable network, trained to reconstruct the data point from the latent vector. $$\hat{x} = f_{\theta}(z)$$

Different latent samples produce different reconstructions, so the decoder effectively generates samples from $p(x|z)$. Some designs also add explicit noise to the latent vector: $\hat{x} = f_{\theta}(z+\epsilon)$.

The forward pass of the variational autoencoder defined in Kingma et al. (2013) has the following form.

$$(\mu, \log \sigma) = g_\phi(x) $$ $$q_{\mu,\sigma}(z|x) = \mathcal{N} (z; \mu, \operatorname{diag}(\sigma^2))$$ $$z \sim q_{\mu,\sigma}(z|x)$$ $$ \hat{x} = f_{\theta}(z)$$

Both the encoder and decoder are neural networks, with parameters $\phi$ and $\theta$ respectively. The two networks are trained end to end to minimise the following loss, where $C$ weights the reconstruction term: $$L = C\|x - f_\theta(z) \|^2 + D_{KL}\big(\mathcal{N}(\mu, \operatorname{diag}(\sigma^2)) \,\|\, \mathcal{N}(0,I)\big)$$
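Here is a minimal sketch of this objective in JAX. The layer sizes, tanh MLPs, and squared-error reconstruction term are illustrative assumptions; the closed-form expression for the Gaussian KL term is from Kingma et al. (2013):

```python
# Minimal VAE objective in JAX. Layer sizes and architectures are
# illustrative assumptions, not prescribed by the paper.
import jax
import jax.numpy as jnp

def init_mlp(key, sizes):
    keys = jax.random.split(key, len(sizes) - 1)
    return [(0.01 * jax.random.normal(k, (m, n)), jnp.zeros(n))
            for k, m, n in zip(keys, sizes[:-1], sizes[1:])]

def mlp(params, x):
    for W, b in params[:-1]:
        x = jnp.tanh(x @ W + b)
    W, b = params[-1]
    return x @ W + b

def vae_loss(enc, dec, key, x, C=1.0):
    mu, log_sigma = jnp.split(mlp(enc, x), 2)  # (mu, log sigma) = g_phi(x)
    eps = jax.random.normal(key, mu.shape)
    z = mu + jnp.exp(log_sigma) * eps          # reparameterised sample
    x_hat = mlp(dec, z)                        # x_hat = f_theta(z)
    recon = C * jnp.sum((x - x_hat) ** 2)
    # D_KL( N(mu, diag(sigma^2)) || N(0, I) ) in closed form
    kl = 0.5 * jnp.sum(jnp.exp(2 * log_sigma) + mu**2 - 1.0 - 2 * log_sigma)
    return recon + kl

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
enc = init_mlp(k1, [4, 16, 4])  # x_dim = 4, z_dim = 2 -> 4 outputs
dec = init_mlp(k2, [2, 16, 4])
loss, grads = jax.value_and_grad(vae_loss, argnums=(0, 1))(enc, dec, k3, jnp.ones(4))
```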

References

  • Blei et al. (2017): Blei, David M., Alp Kucukelbir, and Jon D. McAuliffe. “Variational inference: A review for statisticians.” Journal of the American Statistical Association 112.518 (2017): 859-877.

  • Kingma et al. (2013): Kingma, Diederik P., and Max Welling. “Auto-encoding variational bayes.” arXiv preprint arXiv:1312.6114 (2013).

  • Kingma et al. (2019): Kingma, Diederik P., and Max Welling. “An Introduction to Variational Autoencoders.” Foundations and Trends in Machine Learning (2019).

  • Ranganath et al. (2014): Ranganath, Rajesh, Sean Gerrish, and David M. Blei. “Black box variational inference.” Artificial Intelligence and Statistics. PMLR, 2014.

  • Rezende et al. (2014): Rezende, Danilo Jimenez, Shakir Mohamed, and Daan Wierstra. “Stochastic backpropagation and approximate inference in deep generative models.” International Conference on Machine Learning. PMLR, 2014.

  • Ruiz et al. (2016): Ruiz, Francisco J. R., Michalis K. Titsias, and David M. Blei. “The generalized reparameterization gradient.” Advances in Neural Information Processing Systems 29 (2016).

  • Vincent et al. (2010): Vincent, P., Larochelle, H., Lajoie, I., Bengio, Y., Manzagol, P. A., & Bottou, L. (2010). “Stacked denoising autoencoders: Learning useful representations in a deep network with a local denoising criterion.” Journal of Machine Learning Research, 11(12).

  • Amortized Inference by Tuan Anh Le
    https://www.tuananhle.co.uk/notes/amortized-inference.html

  • Amortized Inference and Variational Auto Encoders
    https://erdogdu.github.io/csc412/notes/lec11-1.pdf

  • Divergence
    https://en.wikipedia.org/wiki/Divergence_(statistics)

  • Evidence Lower Bound
    https://en.wikipedia.org/wiki/Evidence_lower_bound

  • KL Divergence
    https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence

  • Stochastic Approximation
    https://en.wikipedia.org/wiki/Stochastic_approximation

  • The variational auto-encoder
    https://ermongroup.github.io/cs228-notes/extras/vae

  • Variational Inference with Normalizing Flows
    https://www.depthfirstlearning.com/2021/VI-with-NFs

  • Variational inference
    https://ermongroup.github.io/cs228-notes/inference/variational

  • Variational Inference: Foundations and Innovations by David Blei




tags: #machine-learning #optimisation #probability