Temporal Representation

Capturing long-term temporal dependencies requires larger quantities of data, because long-term events tend to occur less frequently as the time scale grows. In this blog, we will discuss how to make learning more sample efficient using a temporal difference (TD) version of the ranking-based contrastive loss.

Contrastive Objective

If two states are temporally close, a transition from one to the other is more probable than a transition between states that are temporally far apart.

In other words, the distance between the representations of two states $s_1$ and $s_2$ encodes their temporal similarity: the probability that, starting from $s_1$, the agent transitions to $s_2$ and to the states likely to be visited after $s_2$.

Ranking-Based Contrastive Loss

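InfoNCE, or the ranking-based contrastive loss, is the ranking version of binary-classification NCE (noise-contrastive estimation): rather than classifying each pair independently as positive or negative, the model must rank the positive above a set of negatives. In InfoNCE, states are ranked by their relevance, i.e. their likelihood of being visited under a policy. For instance: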

  1. Positive pair: states that are likely to be visited together under a given policy.
  2. Negative pair: states that are unlikely to be visited together.
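To make the ranking concrete, here is a minimal Python sketch of the InfoNCE loss for a single anchor state, assuming the similarity scores to one positive and several negatives have already been computed (the function names are illustrative, not from any library). The loss is the cross-entropy of picking the positive out of all candidates, so it shrinks as the positive is ranked further above the negatives.

```python
import math


def log_softmax(scores):
    """Numerically stable log-softmax of a list of floats."""
    m = max(scores)
    log_z = m + math.log(sum(math.exp(s - m) for s in scores))
    return [s - log_z for s in scores]


def infonce_loss(pos_score, neg_scores):
    """InfoNCE loss for one anchor state.

    pos_score:  similarity to a state likely to be visited with the anchor.
    neg_scores: similarities to states unlikely to be visited with it.
    Returns -log of the softmax probability assigned to the positive.
    """
    return -log_softmax([pos_score, *neg_scores])[0]
```

With equal scores for all three candidates, `infonce_loss(0.0, [0.0, 0.0])` is $\log 3$; raising the positive's score lowers the loss.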

Temporal Difference Version

Incorporating a temporal difference (TD) update into the contrastive InfoNCE loss ensures that states are ranked not only by their current visitation likelihood but also by their future visitation likelihood under the policy. The update relies on a Bellman recursion, the same idea that ties the value of a state $s$ to the expected reward and the values of future states $s′$.

The TD version of the InfoNCE loss lets us use data collected by one policy to estimate which states a different policy will visit. This is possible because the TD update relies on the environment's transition probabilities (and, for value functions, rewards), which are shared across policies in a given environment; only the next action used for bootstrapping has to come from the new policy. This allows the model to infer the visitation patterns of a new policy from data collected by another.
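The off-policy TD idea can be sketched in a tabular toy problem where the occupancy measure fits in a table. This Python sketch assumes a hypothetical three-state ring with deterministic dynamics (action 0 stays, action 1 moves to the next state), a hypothetical target policy that moves with probability 0.8, and transitions logged by a uniform random behavior policy. Each update moves the estimate toward the bootstrapped target $(1-\gamma)\,e_{s'} + \gamma\,\mathbb{E}_{a' \sim \pi(\cdot \mid s')}[p_{\pi}(\cdot \mid s', a')]$, which takes $s'$ from the logged data and only $a'$ from the target policy.

```python
import random

N_STATES = 3
ACTIONS = (0, 1)  # toy assumption: 0 = stay, 1 = move to the next state on a ring
GAMMA = 0.9


def step(s, a):
    """Hypothetical deterministic dynamics of the toy environment."""
    return (s + a) % N_STATES


def target_pi(a, s):
    """Hypothetical target policy pi(a | s): move w.p. 0.8, stay w.p. 0.2."""
    return 0.8 if a == 1 else 0.2


def td_target(M, s_next):
    """(1 - gamma) * e_{s'} + gamma * E_{a' ~ pi(.|s')}[M(s', a')], per s_t+."""
    return [
        (1 - GAMMA) * (1.0 if k == s_next else 0.0)
        + GAMMA * sum(target_pi(a2, s_next) * M[(s_next, a2)][k] for a2 in ACTIONS)
        for k in range(N_STATES)
    ]


def td_occupancy(transitions, alpha=0.1):
    """Off-policy TD estimate of p_pi(s_t+ | s, a) from logged (s, a, s') tuples."""
    M = {(s, a): [1.0 / N_STATES] * N_STATES for s in range(N_STATES) for a in ACTIONS}
    for s, a, s_next in transitions:
        tgt = td_target(M, s_next)
        M[(s, a)] = [m + alpha * (t - m) for m, t in zip(M[(s, a)], tgt)]
    return M


def exact_occupancy(iters=2000):
    """Reference solution: iterate the same recursion with the known dynamics."""
    M = {(s, a): [0.0] * N_STATES for s in range(N_STATES) for a in ACTIONS}
    for _ in range(iters):
        M = {(s, a): td_target(M, step(s, a)) for (s, a) in M}
    return M


# Transitions logged by a *different*, uniform random behavior policy.
rng = random.Random(0)
data = []
for _ in range(20000):
    s, a = rng.randrange(N_STATES), rng.choice(ACTIONS)
    data.append((s, a, step(s, a)))

M_td = td_occupancy(data)
M_exact = exact_occupancy()
```

The logged transitions never come from $\pi$ itself, yet `M_td` matches `M_exact`; the exact solver uses the known dynamics only as a check, which is exactly what is unavailable in continuous spaces.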

Visitation Probability in Continuous Setting

In tabular settings, computing visitation probabilities is manageable: the states are discrete, so for each pair of states we can directly compute the expected number of times state $s′$ will be visited when starting from state $s$ and following policy $\pi$.

However, continuous spaces contain infinitely many states, so representing and estimating this table of pairwise visitation probabilities becomes computationally infeasible.

Workaround

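Instead of directly estimating visitation probabilities (or their discounted sums), the idea is to learn state representations in a latent space where the inner product between two state embeddings encodes the visitation probability; more precisely, the exponentiated inner product is proportional to the ratio of the visitation probability to the marginal state density.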

Learning a latent space allows the model to generalize across unseen states by leveraging the learned structure of the environment.

Mathematics

The method is a self-supervised actor-critic algorithm. The critic $f$ is a parametric similarity function defined as the inner product of two learned representations: $$ f(x, y) = \phi(x)^{\top} \psi(y) $$ Here, $\phi(\cdot)$ and $\psi(\cdot)$ map data to $\ell_2$-normalized vectors of dimension $d$; $\phi(x)$ and $\psi(y)$ are the contrastive representations of $x$ and $y$, respectively.
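A minimal pure-Python sketch of this critic follows. The encoders are hypothetical fixed linear maps standing in for learned networks. Because both embeddings are $\ell_2$-normalized, $f(x, y)$ is a cosine similarity in $[-1, 1]$; the `logits_matrix` helper builds the batch of pairwise scores that a contrastive loss consumes.

```python
import math


def l2_normalize(v, eps=1e-12):
    """Scale v to unit l2 norm (eps guards against division by zero)."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / max(norm, eps) for x in v]


def linear(weights, x):
    """Multiply the matrix `weights` (a list of rows) by the vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weights]


def make_critic(w_phi, w_psi):
    """Return phi, psi and f(x, y) = phi(x)^T psi(y) for hypothetical linear encoders."""
    def phi(x):
        return l2_normalize(linear(w_phi, x))

    def psi(y):
        return l2_normalize(linear(w_psi, y))

    def f(x, y):
        return sum(a * b for a, b in zip(phi(x), psi(y)))

    return phi, psi, f


def logits_matrix(f, xs, ys):
    """logits[i][j] = f(xs[i], ys[j]); for paired batches the diagonal holds positives."""
    return [[f(x, y) for y in ys] for x in xs]


# Usage: map 2-D inputs to d = 3 embeddings.
W_PHI = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
W_PSI = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
phi, psi, f = make_critic(W_PHI, W_PSI)
logits = logits_matrix(f, [[1.0, 0.0], [0.0, 2.0]], [[3.0, 0.0], [0.0, 1.0]])
```

Normalization bounds the logits, which is one reason contrastive implementations often pair it with a temperature; that choice is left out here.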

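The Bayes-optimal critic, denoted $f^{\star}$, satisfies: $$\exp(f^{\star}(x,y)) = \frac{p(y \mid x)}{p(y)\,c(x)}$$ Here, $c(\cdot)$ is an arbitrary positive function of $x$ alone; the InfoNCE objective cannot identify it, because adding any function of $x$ to the critic leaves the softmax over candidates $y$ unchanged. Averaging the exponentiated optimal critic over negatives $y \sim p(y)$ shows that $c(x)$ is the reciprocal of that expectation: $$ \mathbb{E}_{p(y)} \left[ \exp(f^{\star}(x, y)) \right] = \int p(y) \frac{p(y \mid x)}{p(y)\, c(x)} \, dy = \frac{1}{c(x)} \int p(y \mid x) \, dy = \frac{1}{c(x)} $$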
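The identity can be checked numerically. The sketch below assumes a made-up discrete joint distribution $p(x, y)$ and an arbitrary positive function $c(x)$, builds $f^{\star}$ from the formula above, and confirms that averaging $\exp(f^{\star})$ over $y \sim p(y)$ returns $1/c(x)$. It also shows that the softmax over candidates $y$, the quantity InfoNCE compares, does not depend on the choice of $c$.

```python
import math

# Made-up joint distribution p(x, y) with x in {0, 1} and y in {0, 1, 2}.
P_XY = [[0.10, 0.25, 0.15],
        [0.30, 0.05, 0.15]]
P_X = [sum(row) for row in P_XY]
P_Y = [sum(P_XY[x][y] for x in range(2)) for y in range(3)]


def c(x):
    """An arbitrary positive function of x (any choice works)."""
    return 2.0 + 3.0 * x


def optimal_critic(x, y, c_fn):
    """f*(x, y) = log p(y|x) - log p(y) - log c(x)."""
    p_y_given_x = P_XY[x][y] / P_X[x]
    return math.log(p_y_given_x) - math.log(P_Y[y]) - math.log(c_fn(x))


def expected_exp_critic(x, c_fn):
    """E_{y ~ p(y)}[exp f*(x, y)], which should equal 1 / c(x)."""
    return sum(P_Y[y] * math.exp(optimal_critic(x, y, c_fn)) for y in range(3))


def softmax_over_y(x, c_fn):
    """Distribution over candidates y induced by the critic; c(x) cancels out."""
    scores = [math.exp(optimal_critic(x, y, c_fn)) for y in range(3)]
    total = sum(scores)
    return [s / total for s in scores]
```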

Discounted State Occupancy Measure

The discounted state occupancy measure quantifies how likely a policy is to be in a particular future state $s_{t+}$, with visits further in the future discounted by $\gamma$ (mirroring how rewards are discounted): $$p_{\pi}(s_{t+} \mid s, a) = (1 - \gamma) \sum_{t=1}^{\infty} \gamma^{t-1} p_{\pi}^{t}(s_{t+} \mid s, a)$$ Here, $p_{\pi}^{t}(s_{t+} \mid s, a)$ is the probability of being in state $s_{t+}$ exactly $t$ steps after taking action $a$ in state $s$. The normalization factor $(1 - \gamma)$ makes this a proper probability distribution: since $\sum_{t=1}^{\infty} \gamma^{t-1} = \frac{1}{1 - \gamma}$ for $0 \le \gamma < 1$, the weights $(1 - \gamma)\gamma^{t-1}$ sum to 1. Without this factor, the measure would sum to $\frac{1}{1 - \gamma}$ instead of 1.
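The weights $(1-\gamma)\gamma^{t-1}$ are exactly the probabilities of a geometric distribution over $t \in \{1, 2, \dots\}$, which gives a simple way to draw a Monte Carlo sample from $p_{\pi}(s_{t+} \mid s, a)$: sample an offset $t$ and look $t$ steps ahead in a trajectory collected by $\pi$. This sketch assumes a trajectory stored as a Python list of states; clipping offsets that run past its end is a simplification made here, not part of the definition.

```python
import random


def discount_weights(gamma, horizon):
    """w_t = (1 - gamma) * gamma**(t - 1) for t = 1..horizon (the weights in p_pi)."""
    return [(1 - gamma) * gamma ** (t - 1) for t in range(1, horizon + 1)]


def sample_offset(gamma, rng):
    """Draw t >= 1 with P(t) = (1 - gamma) * gamma**(t - 1), a geometric distribution."""
    t = 1
    while rng.random() < gamma:  # continue with prob. gamma, stop with prob. 1 - gamma
        t += 1
    return t


def sample_future_state(trajectory, i, gamma, rng):
    """Monte Carlo sample of s_t+ for the state-action pair at index i.

    Offsets past the end of the trajectory are clipped to its last state
    (a simplification assumed for this sketch).
    """
    t = sample_offset(gamma, rng)
    return trajectory[min(i + t, len(trajectory) - 1)]
```

For $\gamma = 0.9$ the weights sum to 1 while the unnormalized terms $\gamma^{t-1}$ sum to $1/(1-\gamma) = 10$, and the sampled offsets have mean $1/(1-\gamma)$.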

MC InfoNCE Estimator

Prior works such as contrastive RL and C-learning also aim to learn representations whose inner products correspond to the likelihoods of reaching future states. However, they use a Monte Carlo estimator, which has higher variance and is less sample efficient than TD methods. The Monte Carlo InfoNCE objective, which is maximized with respect to $f$ (equivalently, its negative is minimized as a cross-entropy loss), is: $$ \mathcal{L}_{\text{MC InfoNCE}}(f) = \mathbb{E}_{(s, a) \sim p(s, a),\; s_{t+}^{(1)} \sim p_{\pi}(s_{t+} \mid s, a),\; s_{t+}^{(2:N)} \sim p(s_{t+})} \left[ \log \frac{\exp(f(s, a, s_{t+}^{(1)}))}{\sum_{i=1}^{N} \exp(f(s, a, s_{t+}^{(i)}))} \right] $$ Its optimal critic $f^{\star}$ satisfies: $$ \exp(f^{\star}(s, a, s_{t+})) = \frac{p_{\pi}(s_{t+} \mid s, a)}{p(s_{t+}) \, c(s, a)} $$
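One way to estimate this objective on a batch (assumed here) is to pair each $(s_i, a_i)$ with a future state $s_{t+}^{(i)}$ sampled from its own trajectory and to reuse the other rows' future states as the negatives $s_{t+}^{(2:N)}$, treating them as approximate samples from $p(s_{t+})$. The Python sketch below assumes the $N \times N$ matrix of critic values has already been computed and returns the batch average of the log-softmax term, i.e. the quantity to maximize.

```python
import math


def log_softmax(row):
    """Numerically stable log-softmax of a list of floats."""
    m = max(row)
    log_z = m + math.log(sum(math.exp(v - m) for v in row))
    return [v - log_z for v in row]


def mc_infonce_objective(logits):
    """Batch estimate of the MC InfoNCE objective.

    logits[i][j] = f(s_i, a_i, s_t+^(j)), where s_t+^(i) was sampled from the
    discounted future of (s_i, a_i) along its own trajectory. The diagonal
    holds the positives; off-diagonal entries act as negatives.
    Returns the value to maximize; its negative is a cross-entropy loss.
    """
    n = len(logits)
    return sum(log_softmax(logits[i])[i] for i in range(n)) / n
```

Uninformative logits give $-\log N$; a critic that scores each row's own future state highest pushes the objective toward its upper bound of 0.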

Temporal Difference InfoNCE

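TD InfoNCE avoids sampling from the discounted state occupancy measure of policy π, which would require rolling out π. Instead, it uses importance sampling: candidate future states are drawn from the marginal distribution $p(s_{t+})$ and reweighted so that, in expectation, they behave like samples from the occupancy measure.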
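The importance weights follow from the optimal-critic expressions above. Rearranging $\exp(f^{\star}(s', a', s_{t+})) = \frac{p_{\pi}(s_{t+} \mid s', a')}{p(s_{t+})\, c(s', a')}$ and using $\frac{1}{c(s', a')} = \mathbb{E}_{p(s_{t+})}[\exp(f^{\star}(s', a', s_{t+}))]$ gives $$ \frac{p_{\pi}(s_{t+} \mid s', a')}{p(s_{t+})} = \frac{\exp(f^{\star}(s', a', s_{t+}))}{\mathbb{E}_{p(s_{t+})}\left[\exp(f^{\star}(s', a', s_{t+}))\right]} \approx \frac{\exp(f(s', a', s_{t+}^{(i)}))}{\frac{1}{N}\sum_{j=1}^{N} \exp(f(s', a', s_{t+}^{(j)}))} $$ where the expectation is replaced by an average over $N$ samples $s_{t+}^{(j)} \sim p(s_{t+})$ and $f^{\star}$ by the current critic at the next state. Combined with the recursion $p_{\pi}(s_{t+} \mid s, a) = (1 - \gamma)\, p(s' = s_{t+} \mid s, a) + \gamma\, \mathbb{E}_{s', a'}[p_{\pi}(s_{t+} \mid s', a')]$, which follows from the definition of the occupancy measure, this suggests the simplified single-example loss sketched below. It is a sketch under stated assumptions (the argument layout, a frozen target critic, and $a' \sim \pi(\cdot \mid s')$ already sampled), not a reproduction of any reference implementation.

```python
import math


def log_softmax(row):
    """Numerically stable log-softmax of a list of floats."""
    m = max(row)
    log_z = m + math.log(sum(math.exp(v - m) for v in row))
    return [v - log_z for v in row]


def softmax(row):
    return [math.exp(v) for v in log_softmax(row)]


def td_infonce_loss(logits_next, logits_rand, target_logits_rand, gamma):
    """Simplified single-example TD InfoNCE loss (sketch).

    logits_next[k]:        f(s, a, c_k), where c_0 = s' is the observed next
                           state and c_1..c_{N-1} are negatives from p(s_t+).
    logits_rand[i]:        f(s, a, s_i) for N states s_i ~ p(s_t+).
    target_logits_rand[i]: f(s', a', s_i) from a frozen target critic.
    """
    # (1 - gamma) term: ordinary InfoNCE with the next state s' as the positive.
    next_term = -log_softmax(logits_next)[0]
    # gamma term: w_i / N = softmax_i(f(s', a', s_i)) are the self-normalized
    # importance weights, so they act as soft labels over the N random states.
    labels = softmax(target_logits_rand)
    log_probs = log_softmax(logits_rand)
    future_term = -sum(w * lp for w, lp in zip(labels, log_probs))
    return (1 - gamma) * next_term + gamma * future_term
```

With $\gamma = 0$ this reduces to one-step InfoNCE on the next state; as $\gamma$ grows, more of the signal comes from the bootstrapped soft labels rather than from observed long-horizon futures.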

Goal-Conditioned Policy Learning

The TD InfoNCE loss described above estimates the discounted state occupancy measure for policy $\pi(a \mid s)$. It can be extended to the goal-conditioned setting by replacing $\pi(a \mid s)$ with $\pi(a \mid s, g)$ and $f(s, a, s_{t+})$ with $f(s, a, g, s_{t+})$, resulting in a goal-conditioned TD InfoNCE estimator.
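As a sketch of how a goal-conditioned critic might be used, assume (hypothetically) that the policy prefers actions under which the goal itself is a likely future state, i.e. actions with high $f(s, a, g, s_{t+} = g)$, and that the action set is small enough to enumerate. The factorization $f(s, a, g, s_{t+}) = \phi(s, a, g)^{\top} \psi(s_{t+})$ is also an assumption here, and the encoders are made-up closed-form functions standing in for learned networks.

```python
import math


def l2_normalize(v):
    """Scale v to unit l2 norm (a zero vector is returned unchanged)."""
    norm = math.sqrt(sum(x * x for x in v)) or 1.0
    return [x / norm for x in v]


def phi(s, a, g):
    """Hypothetical goal-conditioned encoder of (s, a, g); a real one is learned."""
    return l2_normalize([s[0] + a, s[1], g[0], g[1]])


def psi(s_future):
    """Hypothetical encoder of a future state, padded to phi's dimension."""
    return l2_normalize([s_future[0], s_future[1], s_future[0], s_future[1]])


def critic(s, a, g, s_future):
    """Goal-conditioned critic f(s, a, g, s_t+) = phi(s, a, g)^T psi(s_t+)."""
    return sum(x * y for x, y in zip(phi(s, a, g), psi(s_future)))


def greedy_action(s, g, actions):
    """Pick the action whose critic value is highest when s_t+ is the goal."""
    return max(actions, key=lambda a: critic(s, a, g, g))
```

For example, `greedy_action((0.0, 0.0), (1.0, 0.0), [-1.0, 0.0, 1.0])` picks `1.0`, the action that moves toward the goal in this toy encoding.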