In this note I prove that the InfoNCE loss gives a lower bound on mutual information. The original work (van den Oord et al., 2019) already proves this statement, but when I first read the paper two things were unclear to me: ⅰ. how the authors proved the inequality step (that is, why equation (9) implies inequality (10) in the appendix) and ⅱ. why the bound gets tighter with the number of negative samples. The following presentation offers a different (and hopefully clearer) perspective.
Proof. We start from the definition of the InfoNCE loss (which we denote by \(L\)):
where \(r\) denotes the score function, which assigns a non-negative value to each pair of data points. The expectation is over the following random variables:
In short:
We assume that the score function is optimal (that is, it minimizes the loss). As shown in the paper, the optimal score function is given by the following ratio of distributions:
Using this observation together with the linearity of expectation, we get:
The first term is the mutual information, while the second term can be simplified because the values \(\bar x_i\) are sampled independently of \(y\), and thus we have \(p(\bar x|y) = p(\bar x)\):
Now we apply Jensen's inequality to the second term. Jensen's inequality states that \(\mathbb E\left[\phi(z)\right] \ge \phi\left(\mathbb E z \right)\) if the function \(\phi\) is convex. In our case we apply the inequality to the function
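which is convex for \(z > 0\). (The composition rules on page 84 of Boyd and Vandenberghe require a convex outer function, which \(\log\) is not, so they do not apply directly; convexity follows instead from the second derivative: for \(\phi(z) = \log(1/z + c)\) with a constant \(c \ge 0\) we have \(\phi''(z) = (1 + 2cz) / (z + cz^2)^2 > 0\).) This implies that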
but since
we obtain
or, equivalently,
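For reference, the chain of steps in the proof can be written out in one place. This is a sketch under two assumptions that are not spelled out in the text above: the minibatch of size \(n\) consists of one positive pair and \(n - 1\) negatives \(\bar x_1, \dots, \bar x_{n-1}\), and \(\phi\) is taken to be \(\phi(z) = \log(1/z + n - 1)\) with \(z = p(x) / p(x|y)\). With a different convention for counting negatives, the constants change accordingly.

\[
\begin{aligned}
L &= -\mathbb E\left[\log \frac{r(x, y)}{r(x, y) + \sum_{i=1}^{n-1} r(\bar x_i, y)}\right],
\qquad r(x, y) = \frac{p(x|y)}{p(x)}, \\
-L &= \mathbb E\left[\log \frac{p(x|y)}{p(x)}\right] - \mathbb E\left[\log\left(\frac{p(x|y)}{p(x)} + \sum_{i=1}^{n-1} \frac{p(\bar x_i|y)}{p(\bar x_i)}\right)\right] \\
&= I(x; y) - \mathbb E\left[\log\left(\frac{p(x|y)}{p(x)} + n - 1\right)\right] \\
&= I(x; y) - \mathbb E\left[\phi(z)\right] \\
&\le I(x; y) - \phi\left(\mathbb E z\right) = I(x; y) - \log n .
\end{aligned}
\]

The last equality uses \(\mathbb E z = \mathbb E_{p(x, y)}\left[p(x) / p(x|y)\right] = \iint p(x)\, p(y)\, dx\, dy = 1\), so that \(\phi(\mathbb E z) = \log(1 + n - 1) = \log n\). Rearranging gives \(I(x; y) \ge \log n - L\).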
The paper also claims that the bound gets tighter with the minibatch size, \(n\). To me it was not obvious why this is the case, because both \(-L\) and \(\log n\) increase with \(n\). I believe the explanation is that Jensen's inequality becomes tighter when the function \(\phi\) is closer to a linear function; in the limit, if the function is linear, we achieve equality: \(\mathbb E\left[\phi(z)\right] = \phi\left(\mathbb E z \right)\). In our case, the function \(\phi\) that we have previously defined is almost constant if \(n\) dominates the other terms, that is, \(n \gg p(x|y) / p(x)\). This happens either when \(n\) is very large or when the two random variables are loosely correlated; a similar observation is made by (Poole et al., 2019):
"[A]ccurately estimating mutual information still needs a large batch size at test time if the mutual information is high."
Code that numerically simulates the estimation of the lower bound is available here.
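A minimal sketch of such a simulation, under assumptions that are mine and not taken from the paper: \(x\) and \(y\) are standard Gaussians with correlation `rho`, so the true mutual information is \(-\tfrac{1}{2}\log(1 - \rho^2)\) and the optimal score function \(p(x|y)/p(x)\) is available in closed form. The function names (`log_ratio`, `sample_pairs`, `infonce_bound`, `jensen_gap`) are hypothetical, and `jensen_gap` assumes \(\phi(z) = \log(1/z + n - 1)\) with \(n - 1\) negatives per minibatch. Unlike the simplification used in the proof, `infonce_bound` evaluates the score function on the sampled negatives instead of replacing each ratio by 1.

```python
import numpy as np


def log_ratio(x, y, rho):
    """log p(x|y) / p(x) for a standard bivariate Gaussian with correlation rho."""
    var = 1.0 - rho**2  # variance of x given y
    return -0.5 * np.log(var) - (x - rho * y) ** 2 / (2 * var) + x**2 / 2


def sample_pairs(rho, size, rng):
    """Draw `size` positive pairs (x, y) from the joint distribution."""
    y = rng.standard_normal(size)
    x = rho * y + np.sqrt(1.0 - rho**2) * rng.standard_normal(size)
    return x, y


def infonce_bound(n, rho=0.9, trials=20_000, seed=0):
    """Monte Carlo estimate of log(n) - L using the optimal score function."""
    rng = np.random.default_rng(seed)
    x, y = sample_pairs(rho, trials, rng)
    x_neg = rng.standard_normal((trials, n - 1))  # negatives drawn from p(x)
    log_pos = log_ratio(x, y, rho)                # shape (trials,)
    log_neg = log_ratio(x_neg, y[:, None], rho)   # shape (trials, n - 1)
    log_all = np.concatenate([log_pos[:, None], log_neg], axis=1)
    # Stable log of the denominator: log(r(x, y) + sum_i r(x_neg_i, y)).
    log_denominator = np.logaddexp.reduce(log_all, axis=1)
    loss = -np.mean(log_pos - log_denominator)
    return np.log(n) - loss


def jensen_gap(n, rho=0.9, trials=20_000, seed=0):
    """E[phi(z)] - phi(E z) for the assumed phi(z) = log(1/z + n - 1), n >= 2."""
    rng = np.random.default_rng(seed)
    x, y = sample_pairs(rho, trials, rng)
    log_r = log_ratio(x, y, rho)  # log(1/z), since z = p(x) / p(x|y)
    return np.mean(np.logaddexp(log_r, np.log(n - 1))) - np.log(n)


if __name__ == "__main__":
    rho = 0.9
    print("true MI:", -0.5 * np.log(1.0 - rho**2))
    for n in (2, 8, 32, 128):
        print(n, "bound:", infonce_bound(n, rho), "Jensen gap:", jensen_gap(n, rho))
```

By construction the estimate can never exceed \(\log n\), so for small \(n\) it is capped below the true mutual information; as \(n\) grows the Jensen gap shrinks and the estimate should approach the true value from below, up to Monte Carlo noise.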