The mathematics of the Evidence Lower Bound (ELBO) objective is inevitably encountered when working with latent dynamics models and VAEs. In this blog, we break it down along with the necessary supporting concepts: the reparameterization trick, importance sampling, and Jensen's inequality.
Prior Distribution#
Latent variable models aim to model observed data $o_{1:T}$ while leveraging hidden (latent) variables $s_{1:T}$. Conditioned on actions $a_{1:T}$, the prior over latent states factorizes according to the transition dynamics:

$$p(s_{1:T} \mid a_{1:T}) = \prod_{t=1}^{T} p(s_t \mid s_{t-1}, a_{t-1})$$
Joint Distribution#
In these models, the joint distribution can be defined as:

$$p(o_{1:T}, s_{1:T} \mid a_{1:T}) = \prod_{t=1}^{T} p(o_t \mid s_t)\, p(s_t \mid s_{t-1}, a_{t-1})$$

where $p(o_t \mid s_t)$ is the likelihood of observing $o_t$ given the latent state $s_t$, and $p(s_t \mid s_{t-1}, a_{t-1})$ is the transition dynamics of the latent states.
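To make this factorization concrete, here is a minimal ancestral-sampling sketch in NumPy. It assumes a hypothetical linear-Gaussian transition $p(s_t \mid s_{t-1}, a_{t-1})$ and likelihood $p(o_t \mid s_t)$; the matrices `A`, `B`, `C` and the noise scales are illustrative assumptions, not part of any specific model.

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_trajectory(T=10, state_dim=2, obs_dim=3):
    """Ancestral sampling from the joint p(o_{1:T}, s_{1:T} | a_{1:T})
    under assumed linear-Gaussian dynamics and likelihood."""
    A = 0.9 * np.eye(state_dim)                # transition matrix (assumed)
    B = rng.normal(size=(state_dim, 1))        # action effect (assumed)
    C = rng.normal(size=(obs_dim, state_dim))  # observation map (assumed)
    s = np.zeros(state_dim)
    states, obs = [], []
    for t in range(T):
        a = rng.normal(size=1)                                   # action a_{t-1}
        s = A @ s + (B @ a) + 0.1 * rng.normal(size=state_dim)   # s_t ~ p(s_t | s_{t-1}, a_{t-1})
        o = C @ s + 0.1 * rng.normal(size=obs_dim)               # o_t ~ p(o_t | s_t)
        states.append(s)
        obs.append(o)
    return np.stack(states), np.stack(obs)

states, obs = sample_trajectory()
```

Each step first advances the latent state through the transition model, then emits an observation, mirroring the product structure of the joint distribution.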
Marginal Likelihood#
To compute the marginal likelihood, we marginalize out the latent states $s_{1:T}$:

$$p(o_{1:T} \mid a_{1:T}) = \int \prod_{t=1}^{T} p(o_t \mid s_t)\, p(s_t \mid s_{t-1}, a_{t-1})\, ds_{1:T}$$

This integral is typically intractable due to high dimensionality and complex dependencies between variables, so we instead approximate the intractable true posterior with a tractable variational posterior.
Importance sampling#
Importance sampling is a statistical technique used to approximate expectations or probabilities by sampling from an easier-to-handle distribution q(x) (called the importance or proposal distribution) rather than the true, often intractable, distribution p(x). While sampling from q(x), we reweight the samples to account for the difference between q(x) and p(x), which is why the technique is also known as importance weighting.
Rewriting the expectation under p(x) using q(x), we multiply and divide by q(x):

$$\mathbb{E}_{p(x)}[f(x)] = \int f(x)\, p(x)\, dx = \int f(x)\, \frac{p(x)}{q(x)}\, q(x)\, dx = \mathbb{E}_{q(x)}\!\left[ f(x)\, \frac{p(x)}{q(x)} \right]$$

where $w(x) = \frac{p(x)}{q(x)}$ are the importance weights.
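As a quick numerical illustration (not tied to the latent dynamics model above), the sketch below estimates $\mathbb{E}_p[x^2]$ for a target $p = \mathcal{N}(2, 1)$ while drawing samples only from a proposal $q = \mathcal{N}(0, 2^2)$; both distributions are arbitrary choices for demonstration.

```python
import numpy as np

rng = np.random.default_rng(0)

def log_normal_pdf(x, mu, sigma):
    """Log density of N(mu, sigma^2) evaluated at x."""
    return -0.5 * ((x - mu) / sigma) ** 2 - np.log(sigma * np.sqrt(2 * np.pi))

x = rng.normal(0.0, 2.0, size=200_000)  # samples from the proposal q = N(0, 4)
# Importance weights w(x) = p(x) / q(x), computed in log space for stability.
w = np.exp(log_normal_pdf(x, 2.0, 1.0) - log_normal_pdf(x, 0.0, 2.0))
estimate = np.mean(w * x**2)            # importance-weighted estimate of E_p[x^2]
# True value: mu^2 + sigma^2 = 2^2 + 1 = 5, so estimate should be close to 5.
```

Computing the weights as a difference of log densities before exponentiating avoids numerical underflow, which matters in higher dimensions.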
Variational Posterior q#
The variational posterior approximating the true posterior distribution is expressed in factorized form as:

$$q(s_{1:T} \mid o_{1:T}, a_{1:T}) = \prod_{t=1}^{T} q(s_t \mid o_{\le t}, a_{<t})$$
Tractability#
Why is q tractable?#
This factorization assumes that each latent variable $s_t$ depends only on local observations $o_{\le t}$ and preceding actions $a_{<t}$, rather than the entire sequence of states. In the marginal likelihood or prior distribution, by contrast, one has to integrate over all possible state values (the probabilities of all state sequences leading to an observation), which is an infinite set when the state is continuous.
Why is p difficult to compute?#
p when $s_t$ is continuous#
The marginal likelihood involves integrating over $s_{1:T}$, which consists of the $T$ latent variables $s_1, \dots, s_T$. The sequence has finite length $T$, but the integral aggregates over all possible continuous values of each $s_t$, leading to a continuous $T$-fold integral.
p when $s_t$ is discrete#
If $s_t$ takes on discrete values $\{1, \dots, K\}$, where $K$ is the number of discrete states for $s_t$, the marginal likelihood becomes:

$$p(o_{1:T} \mid a_{1:T}) = \sum_{s_1=1}^{K} \cdots \sum_{s_T=1}^{K} \prod_{t=1}^{T} p(o_t \mid s_t)\, p(s_t \mid s_{t-1}, a_{t-1})$$

Each summation is over all possible values of one $s_t$. For categorical $s_t$ with $K$ states and $T$ variables, we need to sum over $K^T$ combinations of $s_{1:T}$.
Note: In discrete models (e.g., Hidden Markov Models), the marginal likelihood summation can often be computed efficiently using algorithms like the forward algorithm, avoiding an explicit sum over all possibilities.
This is why#
The marginal likelihood $p(o_{1:T} \mid a_{1:T})$ is computationally intractable: exact integration over continuous state spaces, or summation over $K^T$ discrete state combinations, is rarely feasible. Since computing the marginal likelihood analytically is not possible, approximate numerical methods are needed.
Note: This is why techniques like variational inference or Monte Carlo sampling are used to approximate the marginal likelihood.
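The note above can be demonstrated directly. For a toy discrete model with hypothetical random transition and emission tables, the brute-force sum over all $K^T$ latent paths agrees with the forward algorithm, which computes the same marginal in $O(T K^2)$ time:

```python
import itertools
import numpy as np

rng = np.random.default_rng(0)
K, T = 3, 6  # K discrete states, T steps -> K**T = 729 paths (assumed toy sizes)

# Hypothetical toy model: random transition and emission probability tables.
trans = rng.dirichlet(np.ones(K), size=K)  # trans[i, j] = p(s_t = j | s_{t-1} = i)
emit = rng.dirichlet(np.ones(4), size=K)   # emit[j, o] = p(o_t = o | s_t = j), 4 symbols
init = np.ones(K) / K                      # uniform initial state distribution
obs = rng.integers(0, 4, size=T)           # an arbitrary observation sequence

# Brute force: sum over every possible latent path (intractable for large T).
marginal = 0.0
for path in itertools.product(range(K), repeat=T):
    p = init[path[0]] * emit[path[0], obs[0]]
    for t in range(1, T):
        p *= trans[path[t - 1], path[t]] * emit[path[t], obs[t]]
    marginal += p

# Forward algorithm: the same quantity via dynamic programming in O(T * K^2).
alpha = init * emit[:, obs[0]]
for t in range(1, T):
    alpha = (alpha @ trans) * emit[:, obs[t]]
forward_marginal = alpha.sum()
```

Already at $K = 10$, $T = 100$ the brute-force sum has $10^{100}$ terms, while the forward recursion needs only $10^4$ multiply-adds; without such structure, approximation is the only option.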
Before proceeding to the ELBO expressions, in addition to importance sampling we will also need Jensen's inequality.
Jensen’s Inequality#
Jensen's inequality states that for a convex function $f(x)$ and a random variable $X$:

$$f(\mathbb{E}[X]) \le \mathbb{E}[f(X)]$$

If $f(x)$ is concave, the inequality is flipped:

$$f(\mathbb{E}[X]) \ge \mathbb{E}[f(X)]$$
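A quick numeric check of the concave case with $f = \log$, using an arbitrary positive random variable (log-normal, chosen purely for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
# X ~ LogNormal(0, 1): E[X] = e^0.5, E[log X] = 0.
x = rng.lognormal(mean=0.0, sigma=1.0, size=100_000)

# log is concave, so log(E[X]) >= E[log X]; the gap is strict unless X is constant.
lhs = np.log(np.mean(x))   # log E[X], approx 0.5
rhs = np.mean(np.log(x))   # E[log X], approx 0.0
```

The gap `lhs - rhs` is exactly the quantity the ELBO derivation gives away when moving the log inside the expectation; for this distribution it is about 0.5.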
Deriving the Variational Bound ELBO#
Starting from the marginal likelihood, we introduce a variational approximation $q(s_{1:T} \mid o_{1:T}, a_{1:T})$ to the true posterior using importance sampling:

$$p(o_{1:T} \mid a_{1:T}) = \mathbb{E}_{q(s_{1:T} \mid o_{1:T}, a_{1:T})}\!\left[ \frac{p(o_{1:T}, s_{1:T} \mid a_{1:T})}{q(s_{1:T} \mid o_{1:T}, a_{1:T})} \right]$$

Taking the logarithm:

$$\log p(o_{1:T} \mid a_{1:T}) = \log \mathbb{E}_{q}\!\left[ \frac{p(o_{1:T}, s_{1:T} \mid a_{1:T})}{q(s_{1:T} \mid o_{1:T}, a_{1:T})} \right]$$

Since the logarithm is concave, we can apply Jensen's inequality, $\log \mathbb{E}[X] \ge \mathbb{E}[\log X]$, so that:

$$\log p(o_{1:T} \mid a_{1:T}) \ge \mathbb{E}_{q}\!\left[ \log \frac{p(o_{1:T}, s_{1:T} \mid a_{1:T})}{q(s_{1:T} \mid o_{1:T}, a_{1:T})} \right]$$

The right-hand side forms the Evidence Lower Bound (ELBO).
Recall the equation of the joint likelihood:

$$p(o_{1:T}, s_{1:T} \mid a_{1:T}) = \prod_{t=1}^{T} p(o_t \mid s_t)\, p(s_t \mid s_{t-1}, a_{t-1})$$

Substituting this joint likelihood and the factorized posterior into the ELBO:

$$\mathcal{L} = \mathbb{E}_{q}\!\left[ \sum_{t=1}^{T} \log p(o_t \mid s_t) + \log p(s_t \mid s_{t-1}, a_{t-1}) - \log q(s_t \mid o_{\le t}, a_{<t}) \right]$$
ELBO Components#
We can now split the ELBO into two main components. The reconstruction term

$$\mathbb{E}_{q}\!\left[ \sum_{t=1}^{T} \log p(o_t \mid s_t) \right]$$

measures how well the model reconstructs the observations given the latent states, and the KL divergence term

$$\sum_{t=1}^{T} \mathbb{E}_{q}\!\left[ \mathrm{KL}\big( q(s_t \mid o_{\le t}, a_{<t}) \,\|\, p(s_t \mid s_{t-1}, a_{t-1}) \big) \right]$$

measures how much the variational posterior deviates from the prior dynamics. The final ELBO can be written as:

$$\mathcal{L} = \mathbb{E}_{q}\!\left[ \sum_{t=1}^{T} \log p(o_t \mid s_t) - \mathrm{KL}\big( q(s_t \mid o_{\le t}, a_{<t}) \,\|\, p(s_t \mid s_{t-1}, a_{t-1}) \big) \right]$$

Maximizing the ELBO is equivalent to minimizing its negative, so the loss function becomes:

$$\text{Loss} = -\mathcal{L}$$
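As a sketch of how one step of this loss might be estimated in practice, assume a diagonal-Gaussian posterior and prior, so the KL term has a closed form, and estimate the reconstruction term with a single reparameterized sample. The decoder, dimensions, and unit observation variance below are illustrative assumptions, not a specific model's implementation.

```python
import numpy as np

rng = np.random.default_rng(0)

def gaussian_kl(mu_q, sigma_q, mu_p, sigma_p):
    """Closed-form KL( N(mu_q, sigma_q^2) || N(mu_p, sigma_p^2) ), summed over dims."""
    return np.sum(np.log(sigma_p / sigma_q)
                  + (sigma_q**2 + (mu_q - mu_p)**2) / (2 * sigma_p**2) - 0.5)

def elbo_step(o_t, mu_q, sigma_q, mu_p, sigma_p, decode):
    """One-sample ELBO estimate for a single timestep: log p(o_t|s_t) minus the KL term."""
    eps = rng.normal(size=mu_q.shape)
    s_t = mu_q + sigma_q * eps                    # reparameterized sample from q
    o_hat = decode(s_t)                           # hypothetical decoder mean
    log_lik = -0.5 * np.sum((o_t - o_hat) ** 2)   # Gaussian log-likelihood (unit var, up to const)
    return log_lik - gaussian_kl(mu_q, sigma_q, mu_p, sigma_p)

# Toy usage with an identity decoder:
val = elbo_step(o_t=np.zeros(2), mu_q=np.zeros(2), sigma_q=np.ones(2),
                mu_p=np.zeros(2), sigma_p=np.ones(2), decode=lambda s: s)
```

Per-timestep values like `val` would be summed over $t$ and negated to obtain the training loss; in a real model the closed-form KL avoids a second source of sampling noise.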
Non-differentiability Problem: Stochastic Sampling from the Variational Posterior#
While maximizing the ELBO, if $q$ is a complex function of its parameters $\phi$ (e.g., neural network outputs), gradients for optimization can become problematic due to the non-differentiable nature of the sampling operation in $q$.

How? To optimize the parameters $\phi$ of $q$, we compute gradients of $\mathcal{L}$ with respect to $\phi$. The expectation in $\mathcal{L}$ is over $q$, meaning we must sample $s_{1:T}$ to estimate it. But sampling is stochastic and hence non-differentiable: the stochasticity in sampling disrupts the computation of gradients via backpropagation.
Solution: Reparameterization Trick#
The reparameterization trick re-expresses the stochastic sampling as a deterministic transformation of a fixed random variable. Instead of directly sampling from $q(s_t \mid o_{\le t}, a_{<t})$, it introduces a noise variable $\epsilon$ sampled from a fixed distribution (e.g., $\epsilon \sim \mathcal{N}(0, I)$). We move the randomness to $\epsilon$, which is independent of the learnable parameters, making the computation of gradients tractable.

With the deterministic transformation below, gradients can flow through $\mu_\phi$ and $\sigma_\phi$ via backpropagation:

$$s_t = \mu_\phi(o_{\le t}, a_{<t}) + \sigma_\phi(o_{\le t}, a_{<t}) \odot \epsilon, \quad \epsilon \sim \mathcal{N}(0, I)$$
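A minimal NumPy sketch of the trick for a scalar Gaussian posterior (the values of $\mu$ and $\sigma$ are arbitrary): sampling $s = \mu + \sigma \epsilon$ with $\epsilon \sim \mathcal{N}(0, 1)$ reproduces $\mathcal{N}(\mu, \sigma^2)$ while keeping the transform differentiable in $\mu$ and $\sigma$.

```python
import numpy as np

rng = np.random.default_rng(0)

mu, sigma = 1.5, 0.5            # learnable parameters of q (assumed scalar values)
eps = rng.normal(size=500_000)  # fixed noise, independent of the parameters
s = mu + sigma * eps            # deterministic transform of eps

# The transformed samples follow the intended distribution N(mu, sigma^2),
# so s.mean() is close to 1.5 and s.std() is close to 0.5.
# Gradients: ds/dmu = 1 and ds/dsigma = eps, so backprop can pass through
# the "sample" even though eps itself was drawn stochastically.
```

In an autodiff framework the same line (`s = mu + sigma * eps`) is all that changes: the sampling node is replaced by a differentiable expression whose only random input carries no parameters.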