Understanding the Variational Autoencoder

Before you read: this post is intended for readers who already have some knowledge of VAEs. If you are totally new to VAEs, I suggest reading other introductory posts first.

Probabilistic Generative Models (with maximum likelihood estimation)

Why do we care about the density of images $p(X)$?

We formulate the process of generating data as a probabilistic sampling process, under the assumption that all images in the world are drawn from an underlying distribution with a density function $p(X)$. Our goal is to approximate the underlying true distribution $p(X)$ with a model $p_\theta(X)$ with parameters $\theta$. The parameters are updated by maximizing the (log-)likelihood of a training set $\{x_1, x_2, \dots, x_M\}$. Formally, the objective is

\[\arg\max_\theta \log (p_\theta(x_1) \cdot p_\theta(x_2) \cdot ... \cdot p_\theta(x_M)) = \arg\max_\theta \sum_i \log p_\theta (x_i)\]

Therefore, modeling the density of images is a must if we want to use MLE to estimate model parameters.
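To make the MLE objective concrete, here is a minimal, hypothetical sketch in PyTorch that fits a one-dimensional Gaussian model $p_\theta(X)$ to toy data by maximizing the summed log-likelihood. The toy data, the 1-D Gaussian model, and all hyperparameters are arbitrary choices of mine for illustration only:

```python
import torch

# Toy training set {x_1, ..., x_M}: M = 1000 samples from an "unknown" distribution.
x = 2.0 * torch.randn(1000) + 3.0

# Model p_theta(X): a 1-D Gaussian whose parameters theta = (mu, log_sigma) are learned.
mu = torch.zeros(1, requires_grad=True)
log_sigma = torch.zeros(1, requires_grad=True)
optimizer = torch.optim.Adam([mu, log_sigma], lr=0.05)

for step in range(1000):
    dist = torch.distributions.Normal(mu, log_sigma.exp())
    # Maximizing sum_i log p_theta(x_i) is the same as minimizing the negative log-likelihood.
    nll = -dist.log_prob(x).sum()
    optimizer.zero_grad()
    nll.backward()
    optimizer.step()

print(mu.item(), log_sigma.exp().item())  # should approach roughly 3.0 and 2.0
```

The same idea carries over to VAE training, except that $p_\theta(X)$ will be defined through a latent variable and neural networks rather than a single Gaussian.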

The introduction of the latent variable $z$

In VAE, there is a latent variable $z$ assumed to follow a standard Gaussian distribution:

\[z \sim \mathcal{N}(\bf 0, \bf I)\]

During training, it is computed by the encoder network from the input image $x$. During inference, $z$ is sampled and fed to the decoder to generate images. The latent variable $z$ is considered to carry information about the underlying patterns of the images $\{x_1, x_2, \dots, x_M\}$. For example, given a set of human face images, one dimension of $z$ may control the subject's gender while another may control whether the subject wears glasses. Hence, $z$ can be seen as an unobserved label for $x$.

The Math Behind VAE

Deriving the objective function

As mentioned above, VAE uses the maximum likelihood estimation (MLE) method, which means we want to maximize the log-likelihood over all training examples $\textbf{X}=\{x^{(1)}, x^{(2)}, \dots, x^{(N)}\}$. Formally,

\[\log p_\theta (\textbf{X}) = \sum_{i=1}^N \log p_\theta (x^{(i)})\]

From probability theory, we know that an intuitive way of computing $p_\theta(X=x)$ would be integrating the joint probability density function of $(X, Z)$ over $z$:

\[p_\theta(X=x) = \int p_\theta(X=x, Z=z) dz\]

By substituting the joint probability density by the conditional density, we have

\[p_\theta(X=x) = \int p_\theta(X=x\vert Z=z)p_\theta(Z=z) dz = \mathbb{E}_z[\,p_\theta(X=x\vert Z=z)]\,\]

Note that $p_\theta(Z=z)$ also has $\theta$ in its notation. Here, $\theta$ can be regarded generally as the set of all parameters of these density functions. In the actual implementation of a VAE, however, $p_\theta(Z=z)$ does not depend on $\theta$, since it is assumed to be a standard Gaussian.

In other words, the likelihood function is just the expectation of the conditional probability $p_\theta(X=x\vert Z=z)$ over all possible $z$. A naive way to estimate its value is to use Monte Carlo sampling. Suppose we sample $L$ examples $z_1, z_2, …, z_L$, then the estimated $p_\theta(X=x)$ is given by

\[p_\theta(X=x) \approx {1 \over L} \sum_{l=1}^L p_\theta(X=x\vert Z=z_l)\]

This solution seems natural: we sample some values of $z$ and estimate the expectation with the sample mean. The problem, however, is that the conditional distribution $p_\theta(X\vert Z)$ is highly skewed: for most values of $z$, $p_\theta(X=x\vert Z=z_l)$ is negligible (near 0). This may sound abstract, but it becomes intuitive with an example. Suppose there is an image of a cat and an image of a car. Their latent variables $z_{cat}$ and $z_{car}$ are expected to be drastically different from each other. In other words, suitable latent variables for a given image fall only in an extremely small region of the entire latent space, so the Monte Carlo method is unlikely to yield an accurate estimate.
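To see this failure mode concretely, here is a small sketch of the naive estimator, with a made-up linear "decoder" standing in for $p_\theta(X\vert Z)$; the dimensions, the network, and the sample count are all arbitrary choices of mine:

```python
import torch

torch.manual_seed(0)
latent_dim, data_dim = 8, 64

# A made-up "decoder" defining p_theta(x | z) = N(x; W z + b, I).
decoder = torch.nn.Linear(latent_dim, data_dim)

def log_p_x_given_z(x, z):
    # log N(x; decoder(z), I), summed over the data dimensions.
    mean = decoder(z)
    return torch.distributions.Normal(mean, 1.0).log_prob(x).sum(dim=-1)

x = torch.randn(data_dim)  # one "image", flattened

# Naive Monte Carlo: p(x) ~ (1/L) * sum_l p(x | z_l), with z_l drawn from the prior N(0, I).
L = 10_000
z = torch.randn(L, latent_dim)
log_p = log_p_x_given_z(x, z)

# Almost all of the estimate comes from a handful of "lucky" z_l; the rest contribute ~0,
# which is exactly why this estimator is so inaccurate in practice.
weights = torch.softmax(log_p, dim=0)
print("fraction of samples carrying 99% of the estimate:",
      ((weights.sort(descending=True).values.cumsum(0) < 0.99).sum() + 1).item() / L)
```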

Now that it’s not possible to integrate over all $z$, we now try another way to compute $p_\theta(X=x)$:

\[p_\theta(X=x) = {p_\theta(X=x, Z=z) \over p_\theta(Z=z\vert X=x)}\]

However, $p_\theta(Z\vert X=x)$ is also intractable. We therefore introduce $q_\phi(Z\vert X=x)$ as an approximation of it, computed by a network with parameters $\phi$.

Now comes the key part: the derivation of the Variational Lower Bound (VLB), also known as the Evidence Lower Bound (ELBO). Here, I provide two derivations.

Method 1

Since $\log p_\theta (x^{(i)})$ is independent of $z$, we can take expectation over $z$ and it stays the same:

\[\begin{align*} \log p_\theta (x^{(i)}) &= \mathbb{E}_{z\sim q_\phi(z\vert x^{(i)})}[\,\log p_\theta (x^{(i)})]\,\\ \end{align*}\]

Next, we expand the expression using properties of the conditional probability and multiply by $q_\phi(z\vert x^{(i)}) \over q_\phi(z\vert x^{(i)})$:

\[\begin{align*} \log p_\theta (x^{(i)}) &= \mathbb{E}_{z\sim q_\phi(z\vert x^{(i)})} [\,\log {p_\theta(x^{(i)}\vert z) p_\theta(z) \over p_\theta(z\vert x^{(i)}) } ]\, \\ &= \mathbb{E}_{z\sim q_\phi(z\vert x^{(i)})} [\,\log {p_\theta(x^{(i)}\vert z) p_\theta(z) \over p_\theta(z\vert x^{(i)}) } { q_\phi(z\vert x^{(i)}) \over q_\phi(z\vert x^{(i)})} ]\, \\ \end{align*}\]

Then, we expand the logarithm and rearrange:

\[\begin{align*} \log p_\theta (x^{(i)}) &= \mathbb{E}_{z\sim q_\phi(z\vert x^{(i)})} [\,\log p_\theta(x^{(i)}\vert z) ]\, - \mathbb{E}_{z\sim q_\phi(z\vert x^{(i)})} (\log{ q_\phi(z\vert x^{(i)})\over p_\theta(z)} ) + \mathbb{E}_{z\sim q_\phi(z\vert x^{(i)})} [\,\log {q_\phi(z\vert x^{(i)}) \over p_\theta(z\vert x^{(i)})} ]\, \\ \end{align*}\]

Note that the KL divergence is defined as

\[\mathcal{D}_{KL}(P\vert \vert Q) = \int p(x) \log {p(x) \over q(x)}dx = \mathbb{E}_{x\sim p(x)}[\,\log {p(x) \over q(x)} ]\,\]

Then we convert the last two terms into KL divergences:

\[\begin{align*} \log p_\theta (x^{(i)}) &= \mathbb{E}_{z\sim q_\phi(z\vert x^{(i)})} [\,\log p_\theta(x^{(i)}\vert z) ]\, - \mathcal{D}_{KL}(q_\phi(z\vert x^{(i)}) \vert \vert p_\theta(z))+ \mathcal{D}_{KL}(q_\phi(z\vert x^{(i)}) \vert \vert p_\theta(z\vert x^{(i)})) \\ &= {\mathcal{L} (x^{(i)}, \theta, \phi ) } + \mathcal{D}_{KL}(q_\phi(z\vert x^{(i)}) \vert \vert p_\theta(z\vert x^{(i)})) \end{align*}\]

Since the KL divergence is non-negative, we have $\log p_\theta (x^{(i)}) \geq \mathcal{L} (x^{(i)}, \theta, \phi )$, which makes $\mathcal{L} (x^{(i)}, \theta, \phi )$ a lower bound on the log-likelihood. We therefore only need to maximize this lower bound as our objective.

Method 2

First, we multiply by $q_\phi(Z=z\vert X=x) \over q_\phi(Z=z\vert X=x)$:

\[\begin{align*} p_\theta (X=x) &= \int p_\theta(X=x, Z=z)dz \\ &= \int {p_\theta(X=x, Z=z) \over q_\phi(Z=z\vert X=x)} q_\phi(Z=z\vert X=x) dz \\ &= \mathbb{E}_{q_\phi(Z\vert X=x)}[\,{p_\theta(X=x, Z=z) \over q_\phi(Z=z\vert X=x)} ]\, \end{align*}\]

Then, we expand $p_\theta(X=x, Z=z)$ with conditional probability:

\[\begin{align*} p_\theta (X=x) &= \mathbb{E}_{q_\phi(Z\vert X=x)}[\,{p_\theta(X=x\vert Z)p_\theta(Z) \over q_\phi(Z=z\vert X=x)}]\, \end{align*}\]

Next, we take the log of the expression:

\[\begin{align*} \log p_\theta (X=x) &= \log \mathbb{E}_{q_\phi(Z\vert X=x)}[\,{p_\theta(X=x\vert Z)p_\theta(Z) \over q_\phi(Z=z\vert X=x)} ]\, \end{align*}\]

Using Jensen's inequality:

\[\begin{align*} \log p_\theta (X=x) &= \log \mathbb{E}_{q_\phi(Z\vert X=x)}[\,{p_\theta(X=x\vert Z)p_\theta(Z) \over q_\phi(Z=z\vert X=x)} ]\,\\ &\geq \mathbb{E}_{q_\phi(Z\vert X=x)} [\,\log {p_\theta(X=x\vert Z)p_\theta(Z) \over q_\phi(Z=z\vert X=x)} ]\, \end{align*}\]

Thus,

\[\mathbb{E}_{q_\phi(Z\vert X=x)} [\,\log {p_\theta(X=x\vert Z)p_\theta(Z) \over q_\phi(Z=z\vert X=x)} ]\,\]

is a lower bound of $\log p_\theta(X=x)$, also called the variational lower bound. The rest of the derivation is the same as in Method 1.

To summarize, the final objective function is

\[\mathcal{L} (x^{(i)}, \theta, \phi ) = \mathbb{E}_{z\sim q_\phi(z\vert x^{(i)})} [\,\log p_\theta(x^{(i)}\vert z) ]\, - \mathcal{D}_{KL}(q_\phi(z\vert x^{(i)}) \vert \vert p_\theta(z))\]

Interpreting the objective function

  1. Essentially, the first term is still the log-likelihood, measuring how well our model captures the image distribution, so it is also called the reconstruction error. The only difference is that instead of taking the expectation over $z\sim p_\theta(z)$, i.e., the prior, we take it over $z \sim q_\phi(Z\vert X)$, i.e., the approximate posterior. This solves the problem that only a small region of $z$ values gives non-negligible $p_\theta(X=x\vert Z=z)$, because we can now predict which $z$'s correspond to the input $x$ by modeling $q_\phi(Z\vert X)$.
  2. The second term is the negative KL divergence between $q_\phi(Z\vert X)$ and $p_\theta(z)$, so maximizing the objective is equivalent to minimizing this KL divergence. Why would we do that? During training, the input $z$ to the decoder is sampled from $q_\phi(Z\vert X)$, while at generation time $z$ is sampled from $p_\theta(z)$. Intuitively, we want the two distributions to be close so that the decoder also produces valid images at generation time.

Implementation

A few assumptions

Now that we’ve derived everything analytically, let’s look into the implementation. Recall once again that the objective function is

\[\mathcal{L} (x^{(i)}, \theta, \phi ) = \mathbb{E}_{z\sim q_\phi(z\vert x^{(i)})} [\,\log p_\theta(x^{(i)}\vert z) ]\, - \mathcal{D}_{KL}(q_\phi(z\vert x^{(i)}) \vert \vert p_\theta(z))\]

We can see that three distributions are involved:

\[p_\theta(Z),p_\theta(X\vert Z), q_\phi(Z\vert X)\]

We will go through them one by one:

  • $p_\theta(Z)$ (Prior): As mentioned at the beginning, $p_\theta(Z)$ is assumed to be a standard normal distribution, i.e., $Z\sim \mathcal{N}(\textbf{0}, \textbf{I})$.
  • $p_\theta(X\vert Z)$ (Likelihood): We also assume that it is Gaussian, so we only need to estimate its parameters, i.e., mean and variance, and we do that using a neural network, with parameters $\theta$. Formally,
\[X\sim \mathcal{N}(\mu_\theta(z), \Sigma_\theta(z))\]

, where $\mu_\theta(z)$ and $\Sigma_\theta(z)$ are outputs of the neural network with $z$ being the input.

  • $q_\phi(Z\vert X)$ (Posterior): Again, it’s assumed to be Gaussian and we estimate its mean and variance using another neural network, with parameters $\phi$.
\[Z\sim \mathcal{N}(\mu_\phi(x), \Sigma_\phi(x) )\]

Moreover, another assumption is that dimensions of $z$ are independent, meaning that $\Sigma_\phi(x)$ is a diagonal matrix. This assumption simplifies the calculation of the objective, as will be shown in a moment.
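Under these assumptions, the two networks can be sketched roughly as follows. This is a minimal PyTorch sketch; the latent dimension, layer sizes, and MLP architecture are arbitrary choices of mine, not prescribed by the derivation:

```python
import torch
import torch.nn as nn

LATENT_DIM = 16        # dimension d of z (arbitrary choice for this sketch)
DATA_DIM = 28 * 28     # flattened image size, e.g. 28x28 (arbitrary choice)

class Encoder(nn.Module):
    """Approximate posterior q_phi(z|x) = N(mu_phi(x), diag(sigma_phi(x)^2))."""
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(DATA_DIM, 400), nn.ReLU())
        self.mu = nn.Linear(400, LATENT_DIM)
        # One log-variance per dimension: the diagonal-covariance assumption.
        self.log_var = nn.Linear(400, LATENT_DIM)

    def forward(self, x):
        h = self.net(x)
        return self.mu(h), self.log_var(h)

class Decoder(nn.Module):
    """Likelihood p_theta(x|z) = N(mu_theta(z), I): only the mean is predicted,
    anticipating the identity-covariance assumption made below."""
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(LATENT_DIM, 400), nn.ReLU(),
                                 nn.Linear(400, DATA_DIM))

    def forward(self, z):
        return self.net(z)
```

Predicting $\log \sigma^2$ rather than $\sigma$ itself is a convenient choice: it keeps the variance positive without any explicit constraint on the network output.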

Now let’s carry out the calculations.

Calculating the first term

Recall that the first term

\(\mathbb{E}_{z\sim q_\phi(z\vert x^{(i)})} [\,\log p_\theta(x^{(i)}\vert z) ]\,\) takes the expectation over $z\sim q_\phi(Z\vert X)$. Therefore, we can sample $L$ values of $z$, $\{z_1, z_2, \dots, z_L\}$, and take the average:

\[{1\over L}\sum_{i=1}^L \log p_\theta(x\vert z_i)\]

However, the sampling operation is not differentiable, so it blocks back-propagation. Here comes the reparameterization trick (a short code sketch follows the list below):

  • Sample $\epsilon_i^k \sim \mathcal{N}(0,1)$, where $k\in \{1,2,\dots,d\}$ and $d$ is the dimension of $z$.
  • Compute $z_i^k = \mu_\phi^k(x)+\epsilon_i^k \cdot \sigma_\phi^k(x)$
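Here is the trick as a minimal code sketch; the function name and the log-variance parameterization are my own choices, matching the encoder sketch above:

```python
import torch

def reparameterize(mu, log_var):
    """Sample z = mu + sigma * eps with eps ~ N(0, I), so gradients flow through mu and log_var."""
    sigma = (0.5 * log_var).exp()   # sigma = exp(log(sigma^2) / 2)
    eps = torch.randn_like(sigma)   # the randomness lives only in eps, outside the learned parameters
    return mu + eps * sigma
```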

This decouples the sampling of $z$ into first sampling $\epsilon$ and then shifting and scaling its distribution. Suppose $x$ is an image of shape $(H,W)$. Then we have:

\[\begin{align*} \log p_\theta(x\vert z_i) &= -{1\over 2} (x-\mu_\theta(z_i))^T \Sigma_\theta(z_i)^{-1}(x-\mu_\theta(z_i))-\log((2\pi)^{WH\over 2}\sqrt{\vert \Sigma_\theta(z_i)\vert })\\ &=-{1\over 2} (x-\mu_\theta(z_i))^T \Sigma_\theta(z_i)^{-1}(x-\mu_\theta(z_i))-{1\over 2}\log((2\pi)^{WH}\vert \Sigma_\theta(z_i)\vert )\\ \end{align*}\]

We further assume that the covariance matrix $\Sigma_\theta(z_i)$ is an identity matrix.

\[\begin{align*} \log p_\theta(x\vert z_i) &=-{1\over 2} (x-\mu_\theta(z_i))^T I(x-\mu_\theta(z_i))-{1\over 2}\log((2\pi)^{WH}\vert I\vert )\\ &\propto - (x-\mu_\theta(z_i))^T (x-\mu_\theta(z_i))\\ &= -\sum_{u=1}^{W} \sum_{v=1}^H (x^{u,v} - \mu_\theta(z_i)^{u,v})^2 \end{align*}\]

Voila! Essentially, maximizing the first term is equivalent to minimizing the pixel-wise L2 loss between $x$ and $\mu_\theta(z_i)$. That's why it's called the reconstruction error.
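In code, this term reduces to a plain sum-of-squares loss between the input and the decoder mean. A minimal sketch, assuming flattened inputs, $\Sigma_\theta(z) = I$, and a single sample of $z$ per data point:

```python
import torch

def reconstruction_loss(x, x_mean):
    """Negative of the first ELBO term, up to additive constants, assuming Sigma_theta(z) = I.

    x, x_mean: flattened images of shape (batch, H * W); x_mean is the decoder output mu_theta(z).
    """
    return 0.5 * ((x - x_mean) ** 2).sum(dim=-1).mean()
```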

Calculating the second term

The second term calculates the KL divergence between $q_\phi(Z\vert X)$ and $p(Z)$. They are both Gaussian and their covariance matrices are both diagonal, which means we can calculate the KL divergence for each dimension and then sum them up.

\[\mathcal{D}_{KL}(q_\phi(Z\vert X)\vert \vert p(Z)) = \sum_{i=1}^d \mathcal{D}_{KL}(q_\phi(Z^{(i)}\vert X)\vert \vert p(Z^{(i)} ))\]

Recall how to compute the KL divergence between two Gaussian distributions:

\[\mathcal{D}_{KL}(\mathcal{N}(\mu^q, \sigma^q)\vert \vert \mathcal{N}(\mu^p, \sigma^p)) = \log {\sigma^p \over \sigma^q} - {1\over 2} + {(\sigma^q)^2 + (\mu^q - \mu^p)^2 \over 2(\sigma^p)^2}\]

, where $\mu^p=0$ and $\sigma^p = 1$ in our case. So

\[\mathcal{D}_{KL}(\mathcal{N}(\mu^q, \sigma^q)\vert \vert \mathcal{N}(\mu^p, \sigma^p)) = \log {1\over \sigma^q} - {1\over 2} + {(\sigma^q)^2 + (\mu^q)^2 \over 2}\]
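Summing this per-dimension expression over all $d$ dimensions gives the full KL term. A minimal sketch, reusing the log-variance output of the encoder sketch above:

```python
import torch

def kl_to_standard_normal(mu, log_var):
    """D_KL( N(mu, diag(exp(log_var))) || N(0, I) ), summed over latent dimensions
    and averaged over the batch.

    Per dimension: log(1/sigma) - 1/2 + (sigma^2 + mu^2)/2 = -0.5 * (1 + log sigma^2 - sigma^2 - mu^2).
    """
    return -0.5 * (1 + log_var - log_var.exp() - mu ** 2).sum(dim=-1).mean()
```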

With that, we have gone through the implementation of the VAE.
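Putting the pieces together, a single training step might look like the following sketch, reusing the hypothetical Encoder, Decoder, reparameterize, reconstruction_loss, and kl_to_standard_normal defined above; the optimizer and learning rate are arbitrary choices:

```python
import torch

encoder, decoder = Encoder(), Decoder()
params = list(encoder.parameters()) + list(decoder.parameters())
optimizer = torch.optim.Adam(params, lr=1e-3)

def train_step(x):
    """One gradient step on the negative ELBO for a batch of flattened images x: (batch, H * W)."""
    mu, log_var = encoder(x)                  # parameters of q_phi(z | x)
    z = reparameterize(mu, log_var)           # one differentiable sample of z (L = 1)
    x_mean = decoder(z)                       # mu_theta(z)
    loss = reconstruction_loss(x, x_mean) + kl_to_standard_normal(mu, log_var)  # -ELBO, up to constants
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```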