The mathematics of the Evidence Lower Bound (ELBO) objective is inevitably encountered when working with latent dynamics models and VAEs. In this blog, we break it down along with the necessary supporting concepts: the reparameterization trick, importance sampling, and Jensen's inequality.

Latent Variable Models

Latent variable models aim to model observed data $o_{1:T}$ while leveraging hidden (latent) variables $s_{1:T}$.

Joint Distribution

In these models, the joint distribution can be defined as: $$ p(o_{1:T}, s_{1:T} \mid a_{1:T}) = \prod_{t=1}^T p(o_t \mid s_t) p(s_t \mid s_{t-1}, a_{t-1}) \tag{1} \label{eq:jointlikelihood}$$ where $p(o_t \mid s_t)$ is the likelihood of observing $o_t$ given the latent state $s_t$, and $p(s_t \mid s_{t-1}, a_{t-1})$ describes the transition dynamics of the latent states.
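To make the factorization concrete, here is a minimal numpy sketch that evaluates $\ln p(o_{1:T}, s_{1:T} \mid a_{1:T})$ as a sum of per-step log-likelihood and log-transition terms. The linear-Gaussian transition and likelihood (and the `0.9`/`0.1` coefficients) are made-up toy choices for illustration, not part of the model above:

```python
import numpy as np

def gaussian_logpdf(x, mean, std):
    """Log-density of a diagonal Gaussian, summed over dimensions."""
    return float(np.sum(-0.5 * np.log(2 * np.pi * std**2) - (x - mean) ** 2 / (2 * std**2)))

def log_joint(obs, states, actions):
    """ln p(o_{1:T}, s_{1:T} | a_{1:T}) = sum_t [ln p(o_t|s_t) + ln p(s_t|s_{t-1}, a_{t-1})]."""
    total = 0.0
    prev_state = np.zeros_like(states[0])    # s_0: fixed initial state (a toy convention)
    prev_action = np.zeros_like(actions[0])  # a_0: null action
    for t in range(len(obs)):
        # toy transition p(s_t | s_{t-1}, a_{t-1}): Gaussian around a linear function
        total += gaussian_logpdf(states[t], 0.9 * prev_state + 0.1 * prev_action, 1.0)
        # toy likelihood p(o_t | s_t): Gaussian centered on the state itself
        total += gaussian_logpdf(obs[t], states[t], 0.5)
        prev_state, prev_action = states[t], actions[t]
    return total
```

The product over time steps in equation (1) becomes a sum once we work in log space, which is also how such models are trained in practice.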

Marginal Likelihood

To compute the marginal likelihood, we marginalize out $s_{1:T}$ as in the following equation: $$ p(o_{1:T} \mid a_{1:T}) = \int p(o_{1:T}, s_{1:T} \mid a_{1:T}) \, ds_{1:T} \tag{2} \label{eq:marginallikelihood} $$ This integral is typically intractable due to high dimensionality and complex dependencies between variables. Instead, we approximate the intractable true posterior with a tractable variational distribution.

Importance sampling

Importance sampling is a statistical technique for approximating expectations or probabilities by sampling from an easier-to-handle distribution q(x) (called the importance or proposal distribution) rather than the true, often intractable, distribution p(x). While sampling from q(x), we reweight the samples to account for the difference between q(x) and p(x), which is why the technique is also known as importance weighting. $$ \mathbb{E}_{p(x)}[f(x)] = \int f(x) p(x) \, dx $$ Rewriting using q(x), we multiply and divide by q(x): $$ \mathbb{E}_{p(x)}[f(x)] = \int f(x) \frac{p(x)}{q(x)} q(x) \, dx = \mathbb{E}_{q(x)} \left[ f(x) \frac{p(x)}{q(x)} \right] $$ where $w(x) = \frac{p(x)}{q(x)}$ are the importance weights.
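As a sanity check, the identity above can be simulated: the sketch below estimates $\mathbb{E}_{p}[x^2]$ for a standard normal $p$ by sampling from a wider proposal $q = \mathcal{N}(0, 2^2)$ and reweighting. The specific choice of $p$, $q$, and $f(x) = x^2$ is illustrative; the true answer is $1$:

```python
import numpy as np

rng = np.random.default_rng(0)

def log_p(x):  # target: standard normal N(0, 1)
    return -0.5 * np.log(2 * np.pi) - 0.5 * x**2

def log_q(x):  # proposal: wider normal N(0, 2^2), easy to sample from
    return -0.5 * np.log(2 * np.pi * 4.0) - x**2 / 8.0

# Draw from the easy distribution q, then reweight by w(x) = p(x)/q(x).
x = rng.normal(0.0, 2.0, size=200_000)
w = np.exp(log_p(x) - log_q(x))
estimate = np.mean(w * x**2)  # approximates E_p[x^2] = 1
```

Note that the proposal q should cover the support of p; a q that is too narrow relative to p makes the weights blow up and the estimator high-variance.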

Variational Posterior q

The variational posterior approximating the true posterior distribution $p(s_{1:T} \mid o_{1:T}, a_{1:T})$ is expressed as: $$q(s_{1:T} \mid o_{1:T}, a_{1:T}) = \prod_{t=1}^T q(s_t \mid o_t, a_{\lt t})$$

Tractability

Why is q tractable?

This factorization assumes that each latent variable $s_t$ depends only on the local observation $o_t$ and the preceding actions $a_{\lt t}$, rather than on the entire sequence of states, so sampling and evaluating each factor is cheap. By contrast, the marginal likelihood in equation (\ref{eq:marginallikelihood}) requires integrating over all possible state values (the probabilities of every state sequence that could have produced $o_{1:T}$), of which there are infinitely many when the state is continuous.

Why is p intractable?

p when $s_t$ is continuous

Equation (\ref{eq:marginallikelihood}) involves integrating over $s_{1:T}$, which consists of the T latent variables $s_1, s_2, \dots, s_T$. $$\int ds_{1:T} = \int_{\mathbb{R}^n} ds_1 \int_{\mathbb{R}^n} ds_2 \cdots \int_{\mathbb{R}^n} ds_T$$ The sequence $s_{1:T}$ has finite length T, but the integral ranges over all possible continuous values of each $s_t$, leading to a continuous T-fold integral.

p when $s_t$ is discrete

If $s_t$ takes on discrete values $s_t^{(1)}, s_t^{(2)}, \dots, s_t^{(K)}$, where K is the number of discrete states for $s_t$, the marginal likelihood becomes a nested summation: $$\sum_{s_{1:T}} = \sum_{s_1} \sum_{s_2} \cdots \sum_{s_T}$$ Each summation runs over all possible values of $s_t$. For a categorical $s_t$ with K states and T variables, this is a sum over $K^T$ combinations of $s_{1:T}$. Note: In discrete models (e.g., Hidden Markov Models), the marginal likelihood summation can often be computed efficiently using algorithms like the forward algorithm, avoiding an explicit sum over all $K^T$ possibilities.
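The gap between the $K^T$ brute-force sum and the forward algorithm can be seen numerically. The toy HMM below (random transition and emission matrices, $K = 3$, $T = 6$) is invented purely for illustration; both methods compute the same marginal likelihood, but the forward pass costs $O(T K^2)$ instead of $O(K^T)$:

```python
import itertools
import numpy as np

K, T = 3, 6
rng = np.random.default_rng(1)
pi = np.full(K, 1.0 / K)               # initial state distribution
A = rng.dirichlet(np.ones(K), size=K)  # transition matrix, rows sum to 1
B = rng.dirichlet(np.ones(K), size=K)  # emission matrix (K states x K symbols)
obs = rng.integers(0, K, size=T)       # a random observation sequence

# Brute force: sum over all K^T state sequences (exponential in T).
brute = 0.0
for seq in itertools.product(range(K), repeat=T):
    p = pi[seq[0]] * B[seq[0], obs[0]]
    for t in range(1, T):
        p *= A[seq[t - 1], seq[t]] * B[seq[t], obs[t]]
    brute += p

# Forward algorithm: the same quantity, computed recursively in O(T * K^2).
alpha = pi * B[:, obs[0]]
for t in range(1, T):
    alpha = (alpha @ A) * B[:, obs[t]]
forward = alpha.sum()
```

For continuous states there is no such exact recursion in general, which is where variational approximations come in.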

This is why

The marginal likelihood p is computationally intractable: exact integration over continuous spaces, or summation over the $K^T$ configurations of a high-dimensional discrete state with K values, is rarely feasible. Since the marginal likelihood cannot in general be computed analytically, numerical methods are needed. Note: This is why techniques like variational inference or Monte Carlo sampling are used to approximate the marginal likelihood. Before proceeding to the expressions of the ELBO, in addition to importance sampling we will also need Jensen's inequality.

Jensen’s Inequality

Jensen's inequality states that for a convex function f(x) and a random variable X: $$f(\mathbb{E}[X]) \leq \mathbb{E}[f(X)]$$ If f(x) is concave, the inequality is flipped: $f(\mathbb{E}[X]) \geq \mathbb{E}[f(X)]$. In particular, for the concave logarithm: $$ \ln \mathbb{E}[X] \geq \mathbb{E}[\ln X] $$
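A quick numerical illustration of the concave case, using an arbitrary positive random variable (a log-normal, chosen only for convenience):

```python
import numpy as np

rng = np.random.default_rng(42)
x = rng.lognormal(mean=0.0, sigma=1.0, size=100_000)  # a positive random variable

lhs = np.log(np.mean(x))  # ln E[X]
rhs = np.mean(np.log(x))  # E[ln X]
gap = lhs - rhs           # nonnegative by Jensen's inequality for the concave ln
```

For this log-normal, $\ln \mathbb{E}[X] = 0.5$ and $\mathbb{E}[\ln X] = 0$ exactly, so the gap is strictly positive; this gap is precisely the slack between the log marginal likelihood and the ELBO in the derivation below.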

Deriving the Variational Bound ELBO

In equation (\ref{eq:marginallikelihood}), we introduce a variational approximation $q(s_{1:T} \mid o_{1:T}, a_{1:T})$ to the true posterior $p(s_{1:T} \mid o_{1:T}, a_{1:T})$ using importance sampling: $$ p(o_{1:T} \mid a_{1:T}) = \int q(s_{1:T} \mid o_{1:T}, a_{1:T}) \frac{p(o_{1:T}, s_{1:T} \mid a_{1:T})}{q(s_{1:T} \mid o_{1:T}, a_{1:T})} \, ds_{1:T} $$ Taking the logarithm: $$ \ln p(o_{1:T} \mid a_{1:T}) = \ln \int q(s_{1:T} \mid o_{1:T}, a_{1:T}) \frac{p(o_{1:T}, s_{1:T} \mid a_{1:T})}{q(s_{1:T} \mid o_{1:T}, a_{1:T})} \, ds_{1:T} $$ Since the logarithm is concave, we can apply Jensen's inequality, $\ln(\mathbb{E}[X]) \geq \mathbb{E}[\ln(X)]$, so that $$ \ln p(o_{1:T} \mid a_{1:T}) \geq \int q(s_{1:T} \mid o_{1:T}, a_{1:T}) \ln \frac{p(o_{1:T}, s_{1:T} \mid a_{1:T})}{q(s_{1:T} \mid o_{1:T}, a_{1:T})} \, ds_{1:T} $$ This forms the Evidence Lower Bound (ELBO): $$ \ln p(o_{1:T} \mid a_{1:T}) \geq \mathbb{E}_{q(s_{1:T} \mid o_{1:T}, a_{1:T})} \left[ \ln p(o_{1:T}, s_{1:T} \mid a_{1:T}) - \ln q(s_{1:T} \mid o_{1:T}, a_{1:T}) \right] $$ Substituting the joint likelihood from equation (\ref{eq:jointlikelihood}) and the factorized form of $q$ into this ELBO: $$\ln p(o_{1:T} \mid a_{1:T}) \geq \mathbb{E}_{q(s_{1:T} \mid o_{1:T}, a_{1:T})} \left[ \sum_{t=1}^T \ln p(o_t \mid s_t) + \ln p(s_t \mid s_{t-1}, a_{t-1}) - \ln q(s_t \mid o_t, a_{\lt t}) \right].$$

ELBO Components

We can now split the ELBO into two main components. The Reconstruction Term measures how well the model reconstructs the observations given the latent states, and the KL Divergence term measures how much the variational posterior $q(s_t \mid o_t, a_{\lt t})$ deviates from the prior dynamics $p(s_t \mid s_{t-1}, a_{t-1})$. The final ELBO can be written as: $$ \text{ELBO} = \underbrace{\sum_{t=1}^T \mathbb{E}_{q(s_t \mid o_t, a_{\lt t})} [\ln p(o_t \mid s_t)]}_{\text{Reconstruction Term}} - \underbrace{\sum_{t=1}^T D_{\mathrm{KL}} \left( q(s_t \mid o_t, a_{\lt t}) \parallel p(s_t \mid s_{t-1}, a_{t-1}) \right)}_{\text{Complexity}} $$ The training loss is the negative ELBO: $$\mathcal{L} = -\text{ELBO} = \text{Reconstruction Loss} + \text{Complexity Penalty}$$
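For diagonal Gaussian posteriors and priors, both ELBO components have simple closed forms. The sketch below uses the standard closed-form KL between two diagonal Gaussians, with made-up numbers standing in for one time step; it is an illustration of the two terms, not the model's actual parameterization:

```python
import numpy as np

def kl_diag_gaussians(mu_q, std_q, mu_p, std_p):
    """Closed-form KL( N(mu_q, std_q^2) || N(mu_p, std_p^2) ), summed over dimensions."""
    var_q, var_p = std_q**2, std_p**2
    return float(np.sum(np.log(std_p / std_q) + (var_q + (mu_q - mu_p) ** 2) / (2 * var_p) - 0.5))

def gaussian_loglik(o, mean, std):
    """ln p(o | s) for a diagonal Gaussian decoder."""
    return float(np.sum(-0.5 * np.log(2 * np.pi * std**2) - (o - mean) ** 2 / (2 * std**2)))

# One time step, toy numbers: posterior q(s_t | o_t, a_<t) vs prior p(s_t | s_{t-1}, a_{t-1})
recon = gaussian_loglik(o=np.array([1.0, -0.5]), mean=np.array([0.9, -0.4]), std=np.array([0.5, 0.5]))
kl = kl_diag_gaussians(mu_q=np.array([0.2, 0.1]), std_q=np.array([0.8, 0.8]),
                       mu_p=np.array([0.0, 0.0]), std_p=np.array([1.0, 1.0]))
elbo_t = recon - kl  # per-step ELBO contribution; the loss is -elbo_t
```

The KL term is always nonnegative and vanishes only when posterior and prior coincide, which is what makes it act as a complexity penalty.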

Non-differentiability Problem: Stochastic Sampling from variational posterior

While maximizing the ELBO, if q is a complex function of its parameters (e.g., neural network outputs), gradients for optimization can become problematic due to the non-differentiable nature of sampling operations in q. How? To optimize the parameters $\phi$ of $q(s_{1:T} \mid o_{1:T}, a_{1:T})$, we compute gradients of $\mathcal{L}$ with respect to $\phi$. The expectation is over $q(s_{1:T} \mid o_{1:T}, a_{1:T})$, meaning we must sample $s_{1:T}$ to estimate it. But sampling $s_{1:T} \sim q(s_{1:T} \mid o_{1:T}, a_{1:T})$ is stochastic and hence non-differentiable: the stochasticity in sampling disrupts the computation of gradients via backpropagation.

Solution: Reparameterization Trick

The trick re-expresses the stochastic sample as a deterministic transformation of a fixed random variable. Instead of directly sampling $s_t$ from $q(s_t \mid o_t, a_{\lt t})$, we introduce a noise variable $\epsilon$ sampled from a fixed distribution (e.g., $\mathcal{N}(0, 1)$): the randomness moves into $\epsilon$, which is independent of the learnable parameters $\phi$, making the computation of $\nabla_\phi$ tractable. In the deterministic transformation below, gradients $\nabla_\phi$ can flow through $\mu_\phi$ and $\sigma_\phi$ via backpropagation: $$s_t = \mu_\phi + \sigma_\phi \cdot \epsilon, \qquad \epsilon \sim \mathcal{N}(0, 1)$$
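The effect can be checked without any autograd library: with $\epsilon$ held fixed, $s_t$ is an ordinary differentiable function of the parameters. The numpy sketch below estimates the gradient of $\mathbb{E}[s^2]$ with respect to $\mu$ pathwise and recovers the analytic value $2\mu$; the objective $s^2$ and the parameter values are illustrative choices:

```python
import numpy as np

rng = np.random.default_rng(7)
mu, sigma = 1.5, 0.5

# Reparameterized sample: the randomness lives in eps, not in the parameters.
eps = rng.normal(0.0, 1.0, size=100_000)
s = mu + sigma * eps  # deterministic in (mu, sigma) once eps is drawn

# Pathwise gradient of E[s^2] w.r.t. mu: d/dmu (mu + sigma*eps)^2 = 2*(mu + sigma*eps)
grad_mu = np.mean(2.0 * s)  # converges to the analytic gradient 2*mu
```

In a deep learning framework the same idea is what lets backpropagation pass through the sampling step (e.g., `rsample` in PyTorch distributions); the manual derivative here just makes the mechanism explicit.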