Deriving the KL divergence loss in variational autoencoders
Let's derive some things related to variational auto-encoders (VAEs).
Evidence Lower Bound (ELBO)
First, we'll state some assumptions. We have a dataset of images, $x$. We'll assume that each image is generated from some unseen latent code $z$, and that there's an underlying distribution of latents, $p(z)$. We'd like to discover parameters $\theta$ so that the model $p_\theta(x|z)$ assigns maximal likelihood to the data. In practical terms -- train a neural network that will generate images from the dataset.
Let's take a closer look at this term by expanding it with Bayes' rule.
$$p_\theta(x|z) = \dfrac{p(z|x)p(x)}{p(z)}$$
- We can calculate $p_\theta(x|z)$ for any $z$.
- We can assume $p(z)$. In the VAE case, it's the standard Gaussian.
- We don't know $p(x)$. It is intractable to compute $p(x) = \int p_\theta(x|z)p(z)\,dz$, because the space of all $z$ is large (see the sketch after this list).
- We also don't know $p(z|x)$.
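To get a feel for why that integral is intractable, here's a rough numerical sketch (an addition, using a made-up decoder $p(x|z) = N(x; z, 0.1I)$): a naive Monte Carlo estimate of $p(x)$ from prior samples degenerates as the latent dimension grows, because a single sample ends up carrying essentially all of the weight.

```python
# Rough sketch: naive Monte Carlo estimation of p(x) = E_{z~p(z)}[p(x|z)]
# with a made-up decoder p(x|z) = N(x; z, 0.1 I). As the latent dimension
# grows, one prior sample dominates, so the estimator is useless in practice.
import numpy as np

rng = np.random.default_rng(0)

for dim in [1, 8, 64]:
    x = np.zeros(dim)
    z = rng.standard_normal((100_000, dim))  # samples from the prior p(z)
    # log p(x|z), summed over independent dimensions
    log_px_given_z = np.sum(
        -0.5 * (np.log(2 * np.pi * 0.1) + (x - z) ** 2 / 0.1), axis=1
    )
    w = np.exp(log_px_given_z - log_px_given_z.max())
    # fraction of the estimate contributed by the single largest sample
    print(dim, w.max() / w.sum())
```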
To help us, we'll introduce an approximate model: $q_\phi(z|x)$. The idea is to closely match the true distribution $p(z|x)$, but in a way that we can sample from it. We can quantify the distance between the two distributions via KL divergence.
$$KL(q_\phi(z|x), p(z|x)) = E_q[\log q_\phi(z|x) - \log p(z|x)]$$
In fact, breaking down this term will show us how to get a handle on $p(x)$.
$$= E_q[\log q_\phi(z|x) - \log p(z|x)]$$
$$= E_q[\log q_\phi(z|x) - \log p(z,x)/p(x)]$$
$$= E_q[\log q_\phi(z|x) - \log p(z,x) + \log p(x)]$$
$$= E_q[\log q_\phi(z|x) - \log p(z,x)] + \log p(x)$$
$$= E_q[\log q_\phi(z|x) - \log p_\theta(x|z)p(z)] + \log p(x)$$
We can rearrange terms as:
$$\log p(x) = E_q[\log p_\theta(x|z)p(z) - \log q_\phi(z|x)] + KL(q_\phi(z|x), p(z|x))$$
So, $\log p(x)$ breaks down into two terms.
- The KL term here is intractable, because it involves $p(z|x)$, which we can't evaluate. At least we know that KL is always greater than or equal to zero.
- The expectation is tractable! It involves three functions that we can evaluate. Note that $E_q$ means $x$ is sampled from the data, then $z$ is sampled from $q_\phi(z|x)$. Since the KL term is non-negative, this expectation is a lower bound on $\log p(x)$. We call it the evidence lower bound, or ELBO (see the numerical sketch below).
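Here's a toy numerical check (an addition, not part of the original argument) that the expectation really is a lower bound. The model pieces are chosen so the exact $\log p(x)$ is known: prior $p(z) = N(0,1)$ and decoder $p_\theta(x|z) = N(x; z, 0.1)$, which gives the marginal $p(x) = N(0, 1.1)$; the encoder $q_\phi(z|x) = N(0.5, 0.25)$ is deliberately imperfect.

```python
import numpy as np

rng = np.random.default_rng(0)

def log_norm(x, mu, var):
    # log density of N(mu, var) evaluated at x
    return -0.5 * (np.log(2 * np.pi * var) + (x - mu) ** 2 / var)

x = 1.0
mu_q, var_q = 0.5, 0.25                             # q(z|x), deliberately off
z = rng.normal(mu_q, np.sqrt(var_q), size=200_000)  # z ~ q(z|x)

elbo = np.mean(
    log_norm(x, z, 0.1)          # log p(x|z), the decoder
    + log_norm(z, 0.0, 1.0)      # log p(z), the standard normal prior
    - log_norm(z, mu_q, var_q)   # log q(z|x)
)
print(elbo, log_norm(x, 0.0, 1.1))  # ELBO ~= -2.71 < log p(x) ~= -1.42
```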
The ELBO can be further broken down into two components.
$$E_q[\log p_\theta(x|z)p(z) - \log q_\phi(z|x)]$$
$$= E_q[\log p_\theta(x|z) + \log p(z) - \log q_\phi(z|x)]$$
$$= E_q[\log p_\theta(x|z)] + E_q[\log p(z) - \log q_\phi(z|x)]$$
$$= E_q[\log p_\theta(x|z)] - KL(q_\phi(z|x), p(z))$$

Note that the second term is exactly minus a KL divergence, since the expectation is taken under $q$.
The ELBO is equal to:
- A reconstruction objective. For each $x$ in the dataset, encoding it via $q_\phi$ then decoding via $p_\theta$ should assign high probability to the original $x$.
- Minus a prior-matching penalty. For each $x$ in the dataset, the distribution $q_\phi(z|x)$ should be similar to the prior $p(z)$; the KL term penalizes any deviation from it.
Here's a practical way to look at the ELBO objective. We can't maximize $p(x)$ directly because we don't have access to $p(z|x)$. But if we approximate $p(z|x)$ with $q_\phi(z|x)$, we can get a lower bound on $\log p(x)$ and maximize that instead. The lower bound depends on 1) how well $x' = p_\theta(q_\phi(x))$ recreates the data, and 2) how well $q_\phi(z|x)$ matches the prior $p(z)$.
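To make this concrete, here's a minimal PyTorch-flavored sketch of the resulting training loss. The `encoder`/`decoder` modules, the Bernoulli pixel model (decoder ending in a sigmoid), and the `beta` weight are illustrative assumptions, not something the derivation prescribes.

```python
import torch
import torch.nn.functional as F

def vae_loss(x, encoder, decoder, beta=1.0):
    # assumption: the encoder returns the mean and log-variance of q(z|x)
    mu, logvar = encoder(x)

    # reparameterization trick: z = mu + sigma * eps, with eps ~ N(0, I)
    z = mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)

    # reconstruction term, -E_q[log p(x|z)]; assumes Bernoulli pixels and
    # a decoder that ends in a sigmoid
    x_hat = decoder(z)
    recon = F.binary_cross_entropy(x_hat, x, reduction="sum")

    # prior-matching term, KL(q(z|x) || N(0, I)), using the closed form
    # derived in the next section, summed over latent dims and the batch
    kl = 0.5 * torch.sum(logvar.exp() + mu.pow(2) - 1.0 - logvar)

    # maximizing the ELBO == minimizing recon + kl (beta weights the KL)
    return recon + beta * kl
```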
Analytical KL divergence for Gaussians
In the classic VAE setup, $p(z)$ is a standard Gaussian. The prior-matching objective above can then be computed analytically, without the need to sample $q_\phi(z|x)$.
Since the VAE setup defines an independent Gaussian for each latent dimension (for each data point in the batch), we only need to consider the univariate Gaussian case instead of a multivariate one. The encoder network will output a mean $\mu$ and standard deviation $\sigma$. We'll refer to this as $P = N(\mu, \sigma^2)$. The standard Gaussian is $Q = N(0, 1)$.
The KL divergence measures how different two probability distributions are:
$$KL(P, Q) = E_P[\log P(x) - \log Q(x)]$$
We'll also need the probability density for a Gaussian:
$$p(x \mid \mu, \sigma) = \frac{1}{\sigma\sqrt{2\pi}} \, e^{-\frac{1}{2}\left(\frac{x-\mu}{\sigma}\right)^2}$$
For $Q$, $\mu = 0$ and $\sigma = 1$, so the term simplifies.
$$q(x) = \frac{1}{\sqrt{2\pi}} \, e^{-\frac{1}{2}x^2}$$
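As a quick aside (added here), the density formulas above are easy to sanity-check against scipy:

```python
import numpy as np
from scipy.stats import norm

def gaussian_pdf(x, mu, sigma):
    # the density formula written out above
    return np.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * np.sqrt(2 * np.pi))

x = np.linspace(-3.0, 3.0, 7)
assert np.allclose(gaussian_pdf(x, 0.7, 1.3), norm.pdf(x, loc=0.7, scale=1.3))
assert np.allclose(gaussian_pdf(x, 0.0, 1.0), norm.pdf(x))  # the q(x) case
```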
Given the above, let's plug everything into the KL divergence.
$$KL(P,Q) = E_P\left[\log\left(\frac{1}{\sigma\sqrt{2\pi}}\right) - \frac{1}{2}\left(\frac{x-\mu}{\sigma}\right)^2 - \log\left(\frac{1}{\sqrt{2\pi}}\right) + \frac{1}{2}x^2\right]$$
Let's break down this expectation and deal with each term one-by-one.
$$KL(P,Q) = E_P\left[\log\left(\frac{1}{\sigma\sqrt{2\pi}}\right) - \log\left(\frac{1}{\sqrt{2\pi}}\right)\right] + E_P\left[-\frac{1}{2}\left(\frac{x-\mu}{\sigma}\right)^2\right] + E_P\left[\frac{1}{2}x^2\right]$$
The first expectation involves two logs that can be combined into one. There is no $x$ term, so the expectation can be dropped.
$$= E_P\left[-\log\left(\frac{\sigma\sqrt{2\pi}}{\sqrt{2\pi}}\right)\right]$$
$$= E_P[-\log(\sigma)]$$
$$= -\log(\sigma) = -\frac{1}{2}\log(\sigma^2)$$
The second expectation can be simplified by noting that $E_P[(x-\mu)^2]$ is the definition of the variance, $\sigma^2$.
$$E_P\left[-\frac{1}{2}\,\frac{(x-\mu)^2}{\sigma^2}\right]$$
$$= -\frac{1}{2}\,\frac{1}{\sigma^2}\,E_P\left[(x-\mu)^2\right]$$
$$= -\frac{1}{2}\,\frac{1}{\sigma^2}\,\sigma^2$$
$$= -\frac{1}{2}$$
The third expectation can also be explained in terms of variance. An equivalent equation for variance is $\sigma^2 = E[X^2] - E[X]^2$. In our case, $E[X] = \mu$, and $E[X^2]$ is what we want to find. So,
$$\sigma^2 = E[X^2] - \mu^2$$
$$E[X^2] = \sigma^2 + \mu^2$$
$$\frac{1}{2}E[X^2] = \frac{1}{2}(\sigma^2 + \mu^2)$$
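Both variance identities used above are easy to confirm numerically; here's a quick sampling check (an addition, with arbitrary demo values):

```python
import numpy as np

rng = np.random.default_rng(0)
mu, sigma = 0.7, 1.3  # arbitrary demo values
x = rng.normal(mu, sigma, size=1_000_000)

print(np.mean((x - mu) ** 2), sigma ** 2)     # E[(X-mu)^2] == sigma^2
print(np.mean(x ** 2), sigma ** 2 + mu ** 2)  # E[X^2] == sigma^2 + mu^2
```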
Nice, all of our expectations are now expressed as functions of $\mu$ and $\sigma$. Put together, we get the final equation for the KL divergence loss:
$$KL\_Loss(\mu, \sigma) = \frac{1}{2}\left[-\log(\sigma^2) - 1 + \sigma^2 + \mu^2\right]$$
Note that our encoder gives us a vector of $\mu$ and $\sigma$ values for each image in the batch (one pair per latent dimension). Because we assume each $(\mu, \sigma)$ pair parametrizes an independent Gaussian, the total loss is the sum of the above equation applied elementwise.
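As a sanity check of the closed form (an addition, not part of the original derivation), we can compare it against a direct Monte Carlo estimate of $E_P[\log P(x) - \log Q(x)]$, and then apply it elementwise over a batch of hypothetical encoder outputs:

```python
import numpy as np

rng = np.random.default_rng(0)

def kl_loss(mu, sigma):
    # closed-form KL(N(mu, sigma^2) || N(0, 1)), applied elementwise
    return 0.5 * (-np.log(sigma ** 2) - 1.0 + sigma ** 2 + mu ** 2)

def log_norm(x, mu, var):
    return -0.5 * (np.log(2 * np.pi * var) + (x - mu) ** 2 / var)

# single (mu, sigma): closed form vs. sampling from P
mu, sigma = 0.5, 2.0
x = rng.normal(mu, sigma, size=500_000)
kl_mc = np.mean(log_norm(x, mu, sigma ** 2) - log_norm(x, 0.0, 1.0))
print(kl_loss(mu, sigma), kl_mc)  # both ~0.932

# a batch of hypothetical encoder outputs: the total loss is the
# elementwise loss summed over latent dims and the batch
mus = rng.standard_normal((8, 16))                  # [batch, latent_dim]
sigmas = np.exp(0.1 * rng.standard_normal((8, 16)))
print(kl_loss(mus, sigmas).sum())
```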
Sanity check: the KL loss should be minimized when $\mu = 0, \sigma = 1$.
$$\frac{d}{d\mu}\left[-\log(\sigma^2) - 1 + \sigma^2 + \mu^2\right] = 2\mu \;\rightarrow\; \text{min. at } \mu = 0$$
$$\frac{d}{d\sigma}\left[-\log(\sigma^2) - 1 + \sigma^2 + \mu^2\right] = 2\sigma - \frac{2}{\sigma} \;\rightarrow\; \text{min. at } \sigma = 1$$
All good!
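A quick grid evaluation (illustrative) confirms the same minimum numerically:

```python
import numpy as np

def kl_loss(mu, sigma):
    return 0.5 * (-np.log(sigma ** 2) - 1.0 + sigma ** 2 + mu ** 2)

mus = np.linspace(-2.0, 2.0, 41)    # grid includes mu = 0
sigmas = np.linspace(0.5, 2.0, 31)  # grid includes sigma = 1
grid = kl_loss(mus[:, None], sigmas[None, :])
i, j = np.unravel_index(grid.argmin(), grid.shape)
print(mus[i], sigmas[j], grid[i, j])  # -> 0.0 1.0 0.0
```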
An aside on batch-wise KL vs. element-wise KL
A common misconception when examining this KL loss is how it relates to each batch of data. One could imagine that the KL loss is meant to shape a given batch of latent vectors so that the approximate distribution within the batch is standard Gaussian. Under this interpretation, the latents should have mean 0 and variance 1 on average -- but each sample can vary. It would then be possible to nearly-perfectly match the standard Gaussian while still conveying information about the input image.
The above interpretation is incorrect. Instead, the KL loss is applied element-wise. Each image is encoded into its own mean and variance pair, and the loss encourages that explicit mean and variance to resemble 0 and 1. Thus, to nearly-match the standard Gaussian, each image would have to encode to a standard Gaussian (and thus convey no unique information). In this way, the KL loss and the reconstruction objective are in conflict, and a balance is struck depending on the relative weighting of the two objectives.
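To illustrate the difference (a sketch with made-up numbers, not from the original text): a batch whose aggregate latent distribution looks standard Gaussian can still incur a large element-wise KL loss, because each individual $q_\phi(z|x_i)$ is far from $N(0,1)$.

```python
import numpy as np

rng = np.random.default_rng(0)

def kl_loss(mu, sigma):
    return 0.5 * (-np.log(sigma ** 2) - 1.0 + sigma ** 2 + mu ** 2)

# Case A: each image encodes to a different, confident code (mu_i drawn
# from N(0,1), sigma_i small). Sampled latents across the batch look
# standard normal in aggregate, yet the element-wise KL is large.
mu_a = rng.standard_normal(1024)
sigma_a = np.full(1024, 0.1)
print(kl_loss(mu_a, sigma_a).mean())  # large (~2.3)

# Case B: every image encodes to exactly N(0,1). The element-wise KL is
# zero, but the code carries no information about the image.
print(kl_loss(np.zeros(1024), np.ones(1024)).mean())  # 0.0
```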