Let's derive some things related to variational auto-encoders (VAEs).
Evidence Lower Bound (ELBO)
First, we'll state some assumptions. We have a dataset of images $X = \{x_1, \dots, x_N\}$. We'll assume that each image $x$ is generated from some unseen latent code $z$, and that there's an underlying distribution of latents $p(z)$. We'd like to discover parameters $\theta$ that maximize the likelihood $p_\theta(x)$ of the data. In practical terms -- train a neural network that will generate images from the dataset.
Let's look at the posterior term $p_\theta(z|x)$ more closely. By Bayes' rule, it ties together every quantity we care about:

$$p_\theta(z|x) = \frac{p_\theta(x|z)\,p(z)}{p_\theta(x)}$$
- We can calculate $p_\theta(x|z)$ for any $z$ -- this is the decoder network.
- We can assume a form for the prior $p(z)$. In the VAE case, it's the standard Gaussian $\mathcal{N}(0, I)$.
- We don't know $p_\theta(x)$. It is intractable to compute $p_\theta(x) = \int p_\theta(x|z)\,p(z)\,dz$, because the space of all $z$ is large (see the sketch after this list).
- We also don't know the posterior $p_\theta(z|x)$ itself.
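To make the intractability concrete, here is a minimal numpy sketch that estimates that integral by brute-force Monte Carlo sampling from the prior. The linear "decoder" `W`, the noise scale, and the sample count are all made up for illustration; a real VAE decoder is a neural network over a much larger latent space, where this kind of estimate becomes hopeless.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy "decoder": p(x|z) is a Gaussian whose mean is a fixed
# linear map of z. A real VAE would use a neural network here.
latent_dim, data_dim = 2, 4
W = rng.normal(size=(data_dim, latent_dim))

def log_p_x_given_z(x, z, noise_std=0.1):
    """log N(x; Wz, noise_std^2 I) -- the decoder's log-likelihood."""
    mean = W @ z
    return np.sum(-0.5 * ((x - mean) / noise_std) ** 2
                  - np.log(noise_std * np.sqrt(2 * np.pi)))

# Naive Monte Carlo estimate of p(x) = E_{z ~ p(z)}[p(x|z)].
# Even this tiny model needs many samples to land near the true value;
# at image scale the latent space is far too large to cover this way.
x = W @ rng.normal(size=latent_dim)                  # a "real" data point
z_samples = rng.normal(size=(100_000, latent_dim))   # z ~ p(z) = N(0, I)
log_probs = np.array([log_p_x_given_z(x, z) for z in z_samples])
m = log_probs.max()
log_p_x = m + np.log(np.mean(np.exp(log_probs - m)))  # log-mean-exp for stability
print("Monte Carlo estimate of log p(x):", log_p_x)
```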
To help us, we'll introduce an approximate model $q_\phi(z|x)$. The idea is to closely match the true distribution $p_\theta(z|x)$, but in a way that we can sample from it. We can quantify the distance between the two distributions via the KL divergence $D_{KL}\big(q_\phi(z|x)\,\|\,p_\theta(z|x)\big)$.
In fact, breaking down this term will give us a hint on how to measure $\log p_\theta(x)$.

$$D_{KL}\big(q_\phi(z|x)\,\|\,p_\theta(z|x)\big) = \mathbb{E}_{z\sim q_\phi(z|x)}\big[\log q_\phi(z|x) - \log p_\theta(z|x)\big] = \mathbb{E}_{z\sim q_\phi(z|x)}\left[\log q_\phi(z|x) - \log\frac{p_\theta(x|z)\,p(z)}{p_\theta(x)}\right]$$

Since $\log p_\theta(x)$ does not depend on $z$, it can be pulled out of the expectation. We can rearrange terms as:

$$\log p_\theta(x) = D_{KL}\big(q_\phi(z|x)\,\|\,p_\theta(z|x)\big) + \mathbb{E}_{z\sim q_\phi(z|x)}\big[\log p_\theta(x|z) + \log p(z) - \log q_\phi(z|x)\big]$$
So, $\log p_\theta(x)$ breaks down into two terms.
- $D_{KL}\big(q_\phi(z|x)\,\|\,p_\theta(z|x)\big)$. The KL term here is intractable because it involves $p_\theta(z|x)$, which we can't evaluate or sample from. At least, we know that a KL divergence is always greater than or equal to zero.
- $\mathbb{E}_{z\sim q_\phi(z|x)}\big[\log p_\theta(x|z) + \log p(z) - \log q_\phi(z|x)\big]$. The expectation is tractable! It involves three functions that we can evaluate. Note that $\mathbb{E}_{z\sim q_\phi(z|x)}$ means $x$ is sampled from the data, then $z$ is sampled via $q_\phi(z|x)$. Because the KL term is non-negative, this expectation is a lower bound on $\log p_\theta(x)$. We call it the evidence lower bound, or ELBO.
The ELBO can be further broken down into two components.
The ELBO is equal to:

$$\mathbb{E}_{z\sim q_\phi(z|x)}\big[\log p_\theta(x|z)\big] - D_{KL}\big(q_\phi(z|x)\,\|\,p(z)\big)$$
- A reconstruction objective. For every $x$ in the dataset, encoding it via $q_\phi(z|x)$ and then decoding via $p_\theta(x|z)$ should give high probability to the original $x$.
- A prior-matching objective. For every $x$ in the dataset, the distribution $q_\phi(z|x)$ should be similar to the prior $p(z)$.
Here's a practical way to look at the ELBO objective. We can't maximize $\log p_\theta(x)$ directly because we don't have access to $p_\theta(z|x)$. But if we approximate $p_\theta(z|x)$ with $q_\phi(z|x)$, we can get a lower bound on $\log p_\theta(x)$ and maximize that instead. The lower bound depends on 1) how well $p_\theta(x|z)$ recreates the data from the latents produced by $q_\phi(z|x)$, and 2) how well $q_\phi(z|x)$ matches the prior $p(z)$.
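As a concrete (toy) illustration of how the ELBO is evaluated, the sketch below estimates it for a single data point with one sample from $q_\phi(z|x)$, using the tractable form $\log p_\theta(x|z) + \log p(z) - \log q_\phi(z|x)$. The `encode` and `log_p_x_given_z` functions are hypothetical stand-ins for the encoder and decoder networks, not part of the derivation.

```python
import numpy as np

rng = np.random.default_rng(0)
latent_dim = 2

def log_gaussian(x, mean, std):
    """Log density of an independent Gaussian, summed over dimensions."""
    return np.sum(-0.5 * ((x - mean) / std) ** 2 - np.log(std * np.sqrt(2 * np.pi)))

# Hypothetical stand-ins for the pieces a real VAE would learn.
def encode(x):                       # q_phi(z|x): returns (mu, sigma)
    return x[:latent_dim], np.full(latent_dim, 0.5)

def log_p_x_given_z(x, z):           # decoder likelihood p_theta(x|z)
    return log_gaussian(x, np.concatenate([z, z]), 0.1)

def log_p_z(z):                      # prior p(z) = N(0, I)
    return log_gaussian(z, 0.0, 1.0)

# Single-sample estimate of the ELBO for one data point x:
#   ELBO ~= log p(x|z) + log p(z) - log q(z|x),  with z ~ q(z|x)
x = rng.normal(size=2 * latent_dim)
mu, sigma = encode(x)
z = mu + sigma * rng.normal(size=latent_dim)     # sample z from q(z|x)
elbo = log_p_x_given_z(x, z) + log_p_z(z) - log_gaussian(z, mu, sigma)
print("single-sample ELBO estimate:", elbo)
```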
Analytical KL divergence for Gaussians
In the classic VAE setup, the prior $p(z)$ is a standard Gaussian. The prior-matching objective above can then be computed analytically, without the need to sample $q_\phi(z|x)$.
Since the VAE encoder defines an independent univariate Gaussian for each latent dimension (and each data point in the batch), we only need to consider the univariate case instead of a multivariate one. The encoder network outputs a mean $\mu$ and standard deviation $\sigma$; we'll refer to the resulting distribution as $Q = \mathcal{N}(\mu, \sigma^2)$. The standard Gaussian prior is $P = \mathcal{N}(0, 1)$.
The KL divergence measures how different two probability distributions are:

$$D_{KL}(Q\,\|\,P) = \mathbb{E}_{z\sim Q}\left[\log\frac{Q(z)}{P(z)}\right] = \mathbb{E}_{z\sim Q}\big[\log Q(z) - \log P(z)\big]$$
We'll also need the probability density for a Gaussian:

$$\mathcal{N}(z; \mu, \sigma^2) = \frac{1}{\sigma\sqrt{2\pi}}\,\exp\!\left(-\frac{(z-\mu)^2}{2\sigma^2}\right)$$
For $P$, $\mu = 0$ and $\sigma = 1$, so its density simplifies to $\frac{1}{\sqrt{2\pi}}\exp\!\left(-\frac{z^2}{2}\right)$.
Given the above, let's plug the two log-densities into the KL divergence.

$$D_{KL}(Q\,\|\,P) = \mathbb{E}_{z\sim Q}\left[\log\frac{1}{\sigma\sqrt{2\pi}} - \frac{(z-\mu)^2}{2\sigma^2} - \log\frac{1}{\sqrt{2\pi}} + \frac{z^2}{2}\right]$$
Let's break this down into three separate expectations and deal with each term one-by-one.

$$D_{KL}(Q\,\|\,P) = \mathbb{E}_{z\sim Q}\left[\log\frac{1}{\sigma\sqrt{2\pi}} - \log\frac{1}{\sqrt{2\pi}}\right] + \mathbb{E}_{z\sim Q}\left[-\frac{(z-\mu)^2}{2\sigma^2}\right] + \mathbb{E}_{z\sim Q}\left[\frac{z^2}{2}\right]$$
The first expectation involves two logs that can be combined into one. There is no $z$ term inside, so the expectation can be dropped.

$$\mathbb{E}_{z\sim Q}\left[\log\frac{1}{\sigma\sqrt{2\pi}} - \log\frac{1}{\sqrt{2\pi}}\right] = \log\frac{1}{\sigma} = -\log\sigma$$
The second expectation can be simplified by knowing that $\mathbb{E}_{z\sim Q}\big[(z-\mu)^2\big]$ is the equation for variance ($\sigma^2$).

$$\mathbb{E}_{z\sim Q}\left[-\frac{(z-\mu)^2}{2\sigma^2}\right] = -\frac{\sigma^2}{2\sigma^2} = -\frac{1}{2}$$
The third expectation can also be explained in terms of variance. An equivalent equation for variance is $\mathrm{Var}(z) = \mathbb{E}\big[z^2\big] - \mathbb{E}[z]^2$. In our case, $\mathrm{Var}(z) = \sigma^2$, $\mathbb{E}[z] = \mu$, and $\mathbb{E}\big[z^2\big]$ is what we want to find. So,

$$\mathbb{E}_{z\sim Q}\left[\frac{z^2}{2}\right] = \frac{\sigma^2 + \mu^2}{2}$$
Nice, all of our expectations are now expressed as functions of $\mu$ and $\sigma$. Put together, we get the final equation for the KL divergence loss:

$$D_{KL}(Q\,\|\,P) = \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2} - \log\sigma$$
Note that our encoder gives us a $\mu$ and a $\sigma$ for every latent dimension of every image in the batch. Because we assume each $(\mu, \sigma)$ pair parametrizes an independent Gaussian, the total loss is the sum of the above equation applied element-wise.
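In code, this is a one-liner applied element-wise. A minimal numpy sketch (the batch of $\mu$ and $\sigma$ values below is made up; in practice they come from the encoder):

```python
import numpy as np

def kl_to_standard_normal(mu, sigma):
    """Element-wise KL( N(mu, sigma^2) || N(0, 1) ) from the formula above."""
    return 0.5 * (sigma**2 + mu**2 - 1.0) - np.log(sigma)

# Hypothetical encoder outputs: one row per image, one column per latent dim.
mu = np.array([[0.3, -1.2],
               [0.0,  0.5]])
sigma = np.array([[0.9, 0.4],
                  [1.0, 1.1]])

# Independent Gaussians => sum the element-wise KL over latent dimensions,
# then (typically) average over the batch.
kl_per_dim = kl_to_standard_normal(mu, sigma)
kl_loss = kl_per_dim.sum(axis=1).mean()
print(kl_per_dim)
print("KL loss:", kl_loss)
```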
Sanity check: the KL loss should be minimized when $Q$ matches $P$, i.e. when $\mu = 0$ and $\sigma = 1$. Grouping the terms that depend on each parameter:

$\frac{\sigma^2}{2} - \log\sigma$: derivative is $\sigma - \frac{1}{\sigma}$, so min. at $\sigma = 1$.

$\frac{\mu^2}{2}$: min. at $\mu = 0$.
All good!
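As an extra numerical check, we can compare the closed form against a brute-force Monte Carlo estimate of $\mathbb{E}_{z\sim Q}\big[\log Q(z) - \log P(z)\big]$ for some arbitrary (made-up) $\mu$ and $\sigma$:

```python
import numpy as np

rng = np.random.default_rng(0)
mu, sigma = 0.7, 1.3          # arbitrary values for the check

def log_normal(z, mean, std):
    return -0.5 * ((z - mean) / std) ** 2 - np.log(std * np.sqrt(2 * np.pi))

# Monte Carlo estimate of E_{z~Q}[log Q(z) - log P(z)], Q = N(mu, sigma^2), P = N(0, 1)
z = mu + sigma * rng.normal(size=1_000_000)
kl_mc = np.mean(log_normal(z, mu, sigma) - log_normal(z, 0.0, 1.0))

# Closed form derived above
kl_closed = 0.5 * (sigma**2 + mu**2 - 1.0) - np.log(sigma)

print(kl_mc, kl_closed)       # the two should agree to a few decimal places
```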
An aside on batch-wise KL vs. element-wise KL
A common misconception when examining this KL loss is how it relates to each batch of data. One could imagine that the KL loss is meant to shape a given batch of latent vectors so that the distribution of latents within the batch is a standard Gaussian. Under this interpretation, the latents should have mean 0 and variance 1 on average, but each individual sample could vary. It would then be possible to nearly-perfectly match the standard Gaussian while still conveying information about the input image.
The above interpretation is incorrect. Instead, the KL loss is applied element-wise. Each image is encoded into its own mean and variance pair, and the loss encourages that explicit mean and variance to be close to 0 and 1. So to nearly match the standard Gaussian, each image would have to encode to (approximately) the standard Gaussian, and thus convey no unique information. In this way, the KL loss and the reconstruction objective are in conflict, and a balance is found depending on the relative scaling of the two objectives.
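Here is a quick numerical illustration of the distinction, using made-up encoder outputs for a 1-D latent: the per-image means are spread out so that the batch of sampled latents looks like a standard Gaussian, yet the element-wise KL loss that the VAE actually optimizes stays well above zero.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical encoder outputs for a batch of 10,000 images: each image gets
# its own mean (so the latent carries information), with a smaller per-image
# standard deviation.
mu = rng.normal(0.0, 0.8, size=10_000)
sigma = np.full(10_000, 0.6)

# Batch-wise view: the aggregated latents are ~ N(0, 0.8^2 + 0.6^2) = N(0, 1),
# so the batch as a whole looks like a standard Gaussian...
z = mu + sigma * rng.normal(size=10_000)
print("batch mean / std of sampled z:", z.mean(), z.std())

# ...but the element-wise KL loss (what the VAE actually uses) is far from 0,
# because each individual (mu_i, sigma_i) is not (0, 1).
kl_elementwise = 0.5 * (sigma**2 + mu**2 - 1.0) - np.log(sigma)
print("mean element-wise KL:", kl_elementwise.mean())
```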