The mathematics of the Evidence Lower Bound (ELBO) objective is inevitably encountered when working with latent dynamics models and VAEs. In this blog, we break it down along with the concepts it builds on: importance sampling, Jensen's inequality, and the reparameterization trick.

Prior Distribution

Latent variable models aim to model observed data $o_{1:T}$ while leveraging hidden (latent) variables $s_{1:T}$.

Joint Distribution

In these models, the joint distribution can be defined as:

$$p(o_{1:T}, s_{1:T} \mid a_{1:T}) = \prod_{t=1}^{T} p(o_t \mid s_t)\, p(s_t \mid s_{t-1}, a_{t-1}) \tag{1}$$

where $p(o_t \mid s_t)$ is the likelihood of observing $o_t$ given the latent state $s_t$, and $p(s_t \mid s_{t-1}, a_{t-1})$ is the transition dynamics of the latent states.
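To make the factorization concrete, here is a minimal Python sketch that evaluates the joint log-probability of equation (1) for a toy 1-D model; the specific Gaussian transition and likelihood parameters are assumptions chosen purely for illustration:

```python
import numpy as np
from scipy.stats import norm

def joint_log_prob(o, s, a):
    """ln p(o_{1:T}, s_{1:T} | a_{1:T}) for a toy 1-D model (equation 1).

    Assumed (illustrative) model:
      transition: s_t ~ N(0.9 * s_{t-1} + 0.1 * a_{t-1}, 0.5^2)
      likelihood: o_t ~ N(s_t, 1.0)
    with s_0 = 0 and a_0 = 0 by convention.
    """
    log_p = 0.0
    s_prev, a_prev = 0.0, 0.0
    for t in range(len(o)):
        log_p += norm.logpdf(s[t], loc=0.9 * s_prev + 0.1 * a_prev, scale=0.5)  # ln p(s_t | s_{t-1}, a_{t-1})
        log_p += norm.logpdf(o[t], loc=s[t], scale=1.0)                         # ln p(o_t | s_t)
        s_prev, a_prev = s[t], a[t]
    return log_p
```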

Marginal Likelihood

To compute the marginal likelihood, we marginalize out $s_{1:T}$:

$$p(o_{1:T} \mid a_{1:T}) = \int p(o_{1:T}, s_{1:T} \mid a_{1:T})\, ds_{1:T} \tag{2}$$

This integral is typically intractable due to the high dimensionality and the complex dependencies between variables. So instead we bound this intractable data likelihood using a tractable approximation to the posterior.
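Before turning to variational methods, note that a naive Monte Carlo estimate of equation (2) is possible: sample state trajectories from the transition prior and average the observation likelihoods. A sketch for the same toy model as above (illustrative parameters only); its variance grows rapidly with $T$, which is part of why better tools are needed:

```python
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(0)

def marginal_log_prob_mc(o, a, n_samples=10_000):
    """Naive estimate of p(o_{1:T} | a_{1:T}) (equation 2): sample s_{1:T}
    from the transition prior, average the observation likelihoods."""
    total = 0.0
    for _ in range(n_samples):
        s_prev, a_prev = 0.0, 0.0
        like = 1.0
        for t in range(len(o)):
            s_t = rng.normal(0.9 * s_prev + 0.1 * a_prev, 0.5)  # s_t ~ p(s_t | s_{t-1}, a_{t-1})
            like *= norm.pdf(o[t], loc=s_t, scale=1.0)          # p(o_t | s_t)
            s_prev, a_prev = s_t, a[t]
        total += like
    return np.log(total / n_samples)
```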

Importance Sampling

This statistical technique approximates expectations or probabilities by sampling from an easier-to-handle distribution $q(x)$ (called the importance distribution) rather than from the true, often intractable, distribution $p(x)$. While sampling from $q(x)$, we reweight the samples to account for the difference between $q(x)$ and $p(x)$, which is why the technique is also known as importance weighting.

$$\mathbb{E}_{p(x)}[f(x)] = \int f(x)\, p(x)\, dx$$

Rewriting using $q(x)$, we multiply and divide by $q(x)$:

$$\mathbb{E}_{p(x)}[f(x)] = \int f(x)\, \frac{p(x)}{q(x)}\, q(x)\, dx = \mathbb{E}_{q(x)}\!\left[ f(x)\, \frac{p(x)}{q(x)} \right]$$

where $w(x) = \frac{p(x)}{q(x)}$ are the importance weights.
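A quick numerical check of the identity, with an assumed target $p = \mathcal{N}(2, 1)$ and proposal $q = \mathcal{N}(0, 2^2)$ chosen only for illustration:

```python
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(0)

x = rng.normal(0.0, 2.0, size=100_000)                                  # samples from q
w = norm.pdf(x, loc=2.0, scale=1.0) / norm.pdf(x, loc=0.0, scale=2.0)   # w(x) = p(x)/q(x)
estimate = np.mean(w * x**2)                                            # E_p[x^2] = E_q[w(x) x^2]
print(estimate)                                                         # ≈ Var + mean^2 = 1 + 4 = 5
```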

Variational Posterior q

The variational posterior approximating the true posterior distribution $p(s_{1:T} \mid o_{1:T}, a_{1:T})$ is expressed as:

$$q(s_{1:T} \mid o_{1:T}, a_{1:T}) = \prod_{t=1}^{T} q(s_t \mid o_t, a_{<t}) \tag{3}$$
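In practice, each factor $q(s_t \mid o_t, a_{<t})$ is usually amortized by a neural network that maps the conditioning variables to distribution parameters. A minimal sketch, where the architecture and the use of a simple past-action summary vector are assumptions for illustration:

```python
import torch
import torch.nn as nn

class PosteriorNet(nn.Module):
    """Hypothetical amortized posterior q(s_t | o_t, a_{<t}): maps the current
    observation plus a summary of past actions to Gaussian parameters."""
    def __init__(self, obs_dim, act_dim, latent_dim):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(obs_dim + act_dim, 64), nn.ReLU())
        self.mu = nn.Linear(64, latent_dim)
        self.log_sigma = nn.Linear(64, latent_dim)

    def forward(self, o_t, a_summary):
        h = self.body(torch.cat([o_t, a_summary], dim=-1))
        return self.mu(h), self.log_sigma(h).exp()  # mean and std of q(s_t | o_t, a_{<t})
```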

Tractability

Why is q easier to work with?

This factorization assumes that each latent variable $s_t$ depends only on the local observation $o_t$ and the preceding actions $a_{<t}$, rather than on the entire sequence of states. By contrast, the marginal likelihood in equation (2) requires integrating over all possible state values (the probability of every state trajectory that could produce $o_{1:T}$), of which there are infinitely many when the state is continuous.

Why is p difficult to compute?

p when $s_t$ is continuous

Equation (2) involves integrating over $s_{1:T}$, which consists of the $T$ latent variables $s_1, s_2, \ldots, s_T$:

$$\int ds_{1:T} = \int_{\mathbb{R}^n} ds_1 \int_{\mathbb{R}^n} ds_2 \cdots \int_{\mathbb{R}^n} ds_T \tag{4}$$

The sequence $s_{1:T}$ has a finite length $T$, but the integral ranges over all possible continuous values of $s_1, s_2, \ldots$, leading to a continuous $T$-fold integral.

p when $s_t$ is discrete

If $s_t$ takes on discrete values $s_t^{(1)}, s_t^{(2)}, \ldots, s_t^{(K)}$, where $K$ is the number of discrete states for $s_t$, the integral in the marginal likelihood becomes a sum:

$$\sum_{s_{1:T}} = \sum_{s_1} \sum_{s_2} \cdots \sum_{s_T} \tag{5}$$

Each summation is over all possible values of $s_t$. For a categorical $s_t$ with $K$ states and $T$ variables, we need to sum over $K^T$ combinations of $s_{1:T}$. Note: in discrete models (e.g., Hidden Markov Models), this summation can often be computed efficiently using algorithms like the forward algorithm, avoiding an explicit sum over all $K^T$ possibilities.
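For reference, here is a minimal sketch of the forward algorithm for an HMM with discrete observations; the argument names are illustrative:

```python
import numpy as np
from scipy.special import logsumexp

def forward_log_likelihood(obs, log_pi, log_A, log_B):
    """Computes ln p(o_{1:T}) in O(T * K^2) instead of O(K^T).

    obs    : length-T sequence of observation indices
    log_pi : (K,)   log initial state probabilities
    log_A  : (K, K) log transitions, log_A[i, j] = ln p(s_t = j | s_{t-1} = i)
    log_B  : (K, M) log emissions,   log_B[j, o] = ln p(o_t = o | s_t = j)
    """
    alpha = log_pi + log_B[:, obs[0]]  # ln p(o_1, s_1)
    for o_t in obs[1:]:
        # ln p(o_{1:t}, s_t): marginalize s_{t-1}, then add the emission term
        alpha = logsumexp(alpha[:, None] + log_A, axis=0) + log_B[:, o_t]
    return logsumexp(alpha)            # marginalize the final state
```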

Why Approximation Is Needed

The marginal likelihood $p(o_{1:T} \mid a_{1:T})$ is computationally intractable: exact integration over continuous state spaces, or summation over $K^T$ combinations of high-dimensional discrete states with $K$ values, is rarely feasible. Computing the marginal likelihood analytically is not possible, so numerical methods are needed. Note: this is why techniques like variational inference or Monte Carlo sampling are used to approximate the marginal likelihood. Before working through the expressions of the ELBO, in addition to importance sampling we also need Jensen's inequality.

Jensen’s Inequality

Jensen's inequality states that for a convex function $f(x)$ and a random variable $X$:

$$f(\mathbb{E}[X]) \le \mathbb{E}[f(X)] \tag{6}$$

If $f(x)$ is concave, the inequality is flipped. In particular, for the concave $\ln$:

$$\ln \mathbb{E}[X] \ge \mathbb{E}[\ln X] \tag{7}$$
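A quick numerical sanity check of equation (7) with a log-normal random variable (chosen so both sides have known values: $\ln \mathbb{E}[X] = 0.5$ and $\mathbb{E}[\ln X] = 0$):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.lognormal(mean=0.0, sigma=1.0, size=1_000_000)  # a positive random variable

lhs = np.log(np.mean(x))     # ln E[X]  ≈ 0.5
rhs = np.mean(np.log(x))     # E[ln X]  ≈ 0.0
print(lhs, rhs, lhs >= rhs)  # Jensen: ln E[X] >= E[ln X] for the concave ln
```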

Deriving the Variational Bound ELBO

In equation (2), we introduce a variational approximation $q(s_{1:T} \mid o_{1:T}, a_{1:T})$ to the true posterior $p(s_{1:T} \mid o_{1:T}, a_{1:T})$ using importance sampling:

$$p(o_{1:T} \mid a_{1:T}) = \int q(s_{1:T} \mid o_{1:T}, a_{1:T})\, \frac{p(o_{1:T}, s_{1:T} \mid a_{1:T})}{q(s_{1:T} \mid o_{1:T}, a_{1:T})}\, ds_{1:T} \tag{8}$$

Taking the logarithm:

$$\ln p(o_{1:T} \mid a_{1:T}) = \ln \int q(s_{1:T} \mid o_{1:T}, a_{1:T})\, \frac{p(o_{1:T}, s_{1:T} \mid a_{1:T})}{q(s_{1:T} \mid o_{1:T}, a_{1:T})}\, ds_{1:T} \tag{9}$$

Since the logarithm is concave, we can apply Jensen's inequality:

$$\ln(\mathbb{E}[X]) \ge \mathbb{E}[\ln(X)] \tag{10}$$

so that

$$\ln p(o_{1:T} \mid a_{1:T}) \ge \int q(s_{1:T} \mid o_{1:T}, a_{1:T})\, \ln \frac{p(o_{1:T}, s_{1:T} \mid a_{1:T})}{q(s_{1:T} \mid o_{1:T}, a_{1:T})}\, ds_{1:T} \tag{11}$$

This forms the Evidence Lower Bound (ELBO):

$$\ln p(o_{1:T} \mid a_{1:T}) \ge \mathbb{E}_{q(s_{1:T} \mid o_{1:T}, a_{1:T})}\!\left[ \ln p(o_{1:T}, s_{1:T} \mid a_{1:T}) - \ln q(s_{1:T} \mid o_{1:T}, a_{1:T}) \right] \tag{12}$$

Recall the factorized joint likelihood from equation (1):

$$p(o_{1:T}, s_{1:T} \mid a_{1:T}) = \prod_{t=1}^{T} p(o_t \mid s_t)\, p(s_t \mid s_{t-1}, a_{t-1}) \tag{13}$$

Substituting equation (13) and the factorized posterior of equation (3) into the ELBO:

$$\ln p(o_{1:T} \mid a_{1:T}) \ge \mathbb{E}_{q(s_{1:T} \mid o_{1:T}, a_{1:T})}\!\left[ \sum_{t=1}^{T} \ln p(o_t \mid s_t) + \ln p(s_t \mid s_{t-1}, a_{t-1}) - \ln q(s_t \mid o_t, a_{<t}) \right] \tag{14}$$
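To ground equation (14), here is a Monte Carlo ELBO estimator for the same toy 1-D model used earlier, with a Gaussian $q(s_t \mid o_t, a_{<t})$ whose per-step parameters `q_mu`, `q_sigma` are simply passed in (in a real model they would come from an encoder network):

```python
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(0)

def elbo_estimate(o, a, q_mu, q_sigma, n_samples=100):
    """Monte Carlo estimate of the ELBO in equation (14) for the toy model.
    q_mu, q_sigma: length-T arrays parameterizing q(s_t | o_t, a_{<t})."""
    elbo = 0.0
    for _ in range(n_samples):
        s = rng.normal(q_mu, q_sigma)  # s_{1:T} ~ q (one Gaussian factor per step)
        s_prev, a_prev = 0.0, 0.0
        for t in range(len(o)):
            elbo += norm.logpdf(o[t], loc=s[t], scale=1.0)                         # ln p(o_t | s_t)
            elbo += norm.logpdf(s[t], loc=0.9 * s_prev + 0.1 * a_prev, scale=0.5)  # ln p(s_t | s_{t-1}, a_{t-1})
            elbo -= norm.logpdf(s[t], loc=q_mu[t], scale=q_sigma[t])               # -ln q(s_t | o_t, a_{<t})
            s_prev, a_prev = s[t], a[t]
    return elbo / n_samples
```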

ELBO Components

We can now split the ELBO into two main components. The reconstruction term measures how well the model reconstructs the observations given the latent states, and the KL divergence measures how much the variational posterior $q(s_t \mid o_t, a_{<t})$ deviates from the prior dynamics $p(s_t \mid s_{t-1}, a_{t-1})$. The final ELBO can be written as:

$$\mathcal{L} = \underbrace{\sum_{t=1}^{T} \mathbb{E}_{q(s_t \mid o_t, a_{<t})}\!\left[ \ln p(o_t \mid s_t) \right]}_{\text{Reconstruction}} - \underbrace{\sum_{t=1}^{T} D_{\mathrm{KL}}\!\left( q(s_t \mid o_t, a_{<t}) \,\|\, p(s_t \mid s_{t-1}, a_{t-1}) \right)}_{\text{Complexity}} \tag{15}$$

The loss function becomes:

$$\mathcal{L}_{\text{loss}} = -\text{ELBO} = \text{Reconstruction Loss} + \text{Complexity Penalty} \tag{16}$$
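When both $q$ and the transition prior are Gaussian, as is common in latent dynamics models, the KL term in equation (15) has a closed form. A minimal per-dimension sketch:

```python
import numpy as np

def kl_gaussians(mu_q, sigma_q, mu_p, sigma_p):
    """Closed-form KL( N(mu_q, sigma_q^2) || N(mu_p, sigma_p^2) ), elementwise.
    Supplies the complexity term when q and the prior dynamics are Gaussian."""
    return (np.log(sigma_p / sigma_q)
            + (sigma_q**2 + (mu_q - mu_p)**2) / (2.0 * sigma_p**2)
            - 0.5)
```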

Non-differentiability Problem: Stochastic Sampling from the Variational Posterior

While maximizing the ELBO, if $q$ is a complex function of its parameters (e.g., neural network outputs), gradients for optimization can become problematic due to the non-differentiable nature of the sampling operation in $q$. How? To optimize the parameters $\phi$ of $q(s_{1:T} \mid o_{1:T}, a_{1:T})$, we compute gradients of $\mathcal{L}$ with respect to $\phi$. The expectation is taken over $q(s_{1:T} \mid o_{1:T}, a_{1:T})$, meaning we must sample $s_{1:T}$ to estimate it. But sampling $s_{1:T} \sim q(s_{1:T} \mid o_{1:T}, a_{1:T})$ is stochastic and hence non-differentiable: the stochasticity in sampling disrupts the computation of gradients via backpropagation.

Solution: Reparameterization Trick

The trick re-expresses the stochastic sample as a deterministic transformation of a fixed random variable. Instead of directly sampling $s_t$ from $q(s_t \mid o_t, a_{<t})$, it introduces a noise variable $\epsilon$ sampled from a fixed distribution (e.g., $\mathcal{N}(0, 1)$). We move the randomness into $\epsilon$, which is independent of the learnable parameters $\phi$, making gradients with respect to $\phi$ tractable. In the deterministic transformation below, gradients can flow through $\mu_\phi$ and $\sigma_\phi$ via backpropagation:

$$s_t = \mu_\phi + \sigma_\phi \cdot \epsilon, \qquad \epsilon \sim \mathcal{N}(0, 1) \tag{17}$$
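A small PyTorch check that gradients indeed reach the posterior parameters; the specific values and the stand-in loss are illustrative only:

```python
import torch

mu = torch.tensor([0.5, -1.0], requires_grad=True)        # mu_phi (would come from an encoder)
log_sigma = torch.tensor([0.0, 0.2], requires_grad=True)  # ln sigma_phi

eps = torch.randn_like(mu)              # eps ~ N(0, 1), independent of phi
s_t = mu + torch.exp(log_sigma) * eps   # s_t = mu_phi + sigma_phi * eps (equation 17)

loss = (s_t ** 2).sum()                 # stand-in for the ELBO terms that depend on s_t
loss.backward()                         # backprop flows through the deterministic transform
print(mu.grad, log_sigma.grad)          # non-None: phi receives gradients
```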