Why we can interpret softmax scores as probabilities

Foreword

I wrote this short essay when I was a postdoc. I started writing it to convince myself that I could interpret softmax scores as probabilities, and not just as scores. I hope you find reading it as useful as I found thinking it through.

Motivation

Suppose we possess data and labels, \(\{(x^{(1)},c^{(1)}),…,(x^{(m)},c^{(m)})\}\), collected from the joint distribution \((x,c) \sim p(x,c)\), with only two possible classes: \(c \in \{A,B\}\).

The loss function for a binary classifier \(f_{\theta}(x)=p(A \mid x)\) is given by the well-known cross-entropy expression1:

\begin{align} \mathcal{L}(\theta) &= -\sum_{c=A}\log [f_\theta(x)] - \sum_{c=B}\log [1- f_\theta(x)] \end{align}

But where does this expression come from? It is clear upon inspection that minimizing Equation 1 will yield an effective classifier. If the neural network \(f_\theta\) is terminated with a sigmoid function2, then examples of class \(A\) will push \(f_\theta(x)\) toward 1 for their values of \(x\), and examples of class \(B\) will drive it toward 0.
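To make Equation 1 concrete, here is a minimal numpy sketch (the toy outputs, labels, and variable names below are my own, not part of the original setup) that evaluates the loss for a handful of sigmoid outputs:

```python
import numpy as np

# Toy sigmoid outputs f_theta(x) and the true class of each of five examples.
f = np.array([0.9, 0.8, 0.3, 0.2, 0.6])   # f_theta(x) = p(A | x)
c = np.array(["A", "A", "B", "B", "A"])   # true class labels

# Equation 1: -log f over class-A examples, -log(1 - f) over class-B examples.
loss = -np.sum(np.log(f[c == "A"])) - np.sum(np.log(1.0 - f[c == "B"]))
print(loss)  # shrinks as class-A outputs approach 1 and class-B outputs approach 0
```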

However, this observation is not enough to say with confidence that the output of the classifier is a good approximation of the conditional probability of the class label, \(P(c \mid x)\).

Derivation

To build such a classifier, we will minimize the KL divergence between our target distribution for a binary classifier, \(P(c \mid x)\), and our model distribution parameterized by \(\theta\), \(f_\theta(c,x)\).

A note on notation: in Equation 1 we used \(f_\theta(x)\) to denote our classifier, but here we want to indicate matching the distribution over all the classes \(\{A,B\}\), and so will use \(f_\theta(c,x)\) instead. There is no real need for the argument \(c\), since we are only interested in the single output \(f_\theta(A,x) = P(A \mid x) = 1 - P(B \mid x)\), and we will make this substitution later to simplify the derived loss.

The KL divergence between conditional distributions is the expectation, taken over the conditioning variable, of the KL divergence between the conditionals at each value of that variable:

\begin{align} D_{KL}(P(c \mid x)||f_\theta(c,x)) = \mathbb{E}_{p(x)}\left[ \mathbb{E}_{P(c \mid x)}\log\left(\frac{P(c \mid x)}{f_\theta(c,x)}\right)\right] \end{align}
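As a sanity check on Equation 2, here is a small numpy sketch (the discrete toy distributions are my own invention) that computes this nested expectation explicitly for an \(x\) taking three values and a binary \(c\):

```python
import numpy as np

p_x = np.array([0.5, 0.3, 0.2])          # p(x) over three discrete values of x
P_c_given_x = np.array([[0.9, 0.1],      # rows: values of x; columns: P(A|x), P(B|x)
                        [0.4, 0.6],
                        [0.2, 0.8]])
f_c_given_x = np.array([[0.8, 0.2],      # the model's conditional distribution
                        [0.5, 0.5],
                        [0.3, 0.7]])

# Inner expectation: KL(P(c|x) || f(c,x)) at each x; outer expectation over p(x).
kl_per_x = np.sum(P_c_given_x * np.log(P_c_given_x / f_c_given_x), axis=1)
kl = np.sum(p_x * kl_per_x)
print(kl)  # zero only when f matches P(c|x) at every x with p(x) > 0
```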

We then choose the parameters that minimize this divergence:

\begin{align} \hat{\theta} &= \arg \min_{\theta} D_{KL}(P(c \mid x)||f_\theta(c,x))\\ &= \arg \min_{\theta} \mathbb{E}_{p(x)}\left[ \mathbb{E}_{P(c \mid x)}\log\left(\frac{P(c \mid x)}{f_\theta(c,x)}\right)\right]\\ &= \arg \min_{\theta} \mathbb{E}_{p(x)}\left[ \mathbb{E}_{P(c \mid x)}\left[\log(P(c \mid x)) - \log(f_\theta(c,x))\right]\right] \end{align}

Taking an expectation over \(p(x)\) and then over \(P(c \mid x)\) is the same as taking an expectation over samples drawn from the joint distribution \(p(c,x)\), so we can write:

\begin{align} \hat{\theta} &= \arg \min_{\theta} \mathbb{E}_{p(c,x)}\left[ \log(P(c \mid x)) - \log(f_\theta(c,x))\right] \end{align}

And after dropping terms not involving \(\theta\):

\begin{align} \hat{\theta} &= \arg \min_{\theta} -\mathbb{E}_{p(c,x)}\log\left(f_\theta(c,x)\right) \end{align}

Interestingly, at this point minimizing the KL divergence of Equation 2 in order to recover \(P(c \mid x)\) looks as if we were trying to match the joint distribution \(p(c,x)\), because we have dropped any terms indicating conditional probability. It must therefore be some feature of the architecture of \(f_{\theta}\) that lets us say it mimics the conditional distribution, and not the joint.
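Before continuing, this step is easy to verify numerically: dropping the \(\log P(c \mid x)\) term shifts the objective only by a constant (the conditional entropy of \(c\) given \(x\)), so minimizing the negative log-likelihood above and minimizing the KL divergence of Equation 2 select the same parameters. Here is a small numpy sketch of that claim, reusing the toy distributions from before and a one-parameter family of models of my own construction:

```python
import numpy as np

p_x = np.array([0.5, 0.3, 0.2])
P_c_given_x = np.array([[0.9, 0.1], [0.4, 0.6], [0.2, 0.8]])
p_joint = p_x[:, None] * P_c_given_x      # p(c, x) = p(x) P(c | x)

def model(theta):
    """Toy one-parameter family: a blend of the uniform and the true conditional."""
    return (1.0 - theta) * 0.5 + theta * P_c_given_x

for theta in [0.0, 0.5, 0.9, 1.0]:
    f = model(theta)
    nll = -np.sum(p_joint * np.log(f))                  # negative log-likelihood above
    kl = np.sum(p_joint * np.log(P_c_given_x / f))      # KL divergence of Equation 2
    print(f"theta={theta:.1f}  NLL={nll:.4f}  KL={kl:.4f}  NLL-KL={nll - kl:.4f}")
# NLL - KL is the same constant for every theta (the conditional entropy of c given x),
# so both objectives are minimized by the same theta.
```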

By the law of large numbers, we can estimate the remaining expectation by summing over samples:

\begin{align} \hat{\theta} &= \arg \min_{\theta} - \frac{1}{N_A} \sum_{A} \log \left(f_\theta(A,x)\right) -\frac{1}{N_B} \sum_{B} \log\left(f_\theta(B,x)\right) \end{align}

where \(\sum_A, \sum_B\) indicate sums taken over examples of classes \(A\) and \(B\), and \(N_A\) and \(N_B\) are the number of samples of each class. Normalizing each sum by its own class count weights the two classes equally, which amounts to assuming a balanced prior. If there are equal numbers of samples of each class, \(N_A=N_B\), then the empirical prior is \(P(c) = P(A) = P(B) = 0.5\) (exactly as for the discriminator of a GAN), the two normalizing constants coincide, and we can drop them:

\begin{align} \hat{\theta} &= \arg \min_{\theta} - \sum_{A} \log \left(f_\theta(A,x)\right) -\sum_{B} \log\left(f_\theta(B,x)\right) \end{align}
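As a quick check of this step (the sampled data and the fixed placeholder model below are my own, chosen only for illustration): when \(N_A = N_B\), the class-normalized form and a plain Monte Carlo average over the pooled samples differ only by a constant factor, so dropping the constants changes nothing about the minimizer:

```python
import numpy as np

rng = np.random.default_rng(0)

# A balanced toy sample: N_A = N_B draws of x from each class-conditional.
x_A = rng.normal(+1.0, 1.0, size=500)     # x | c = A
x_B = rng.normal(-1.0, 1.0, size=500)     # x | c = B

def f_A(x, w=2.0):                        # placeholder model for f_theta(A, x)
    return 1.0 / (1.0 + np.exp(-w * x))

# Class-normalized loss vs. Monte Carlo average over the pooled sample.
per_class = -np.mean(np.log(f_A(x_A))) - np.mean(np.log(1.0 - f_A(x_B)))
pooled    = -np.mean(np.concatenate([np.log(f_A(x_A)), np.log(1.0 - f_A(x_B))]))
print(per_class, 2.0 * pooled)            # equal when N_A = N_B
```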

It is the choice of architecture, not the objective function, that allows us to say \(f_{\theta}\) approximates \(P(c \mid x)\). Here, we make \(f_{\theta}\) a legitimate probability distribution over class labels \(c\) by defining \(f_{\theta}(A,x) = 1 - f_{\theta}(B,x)\), and for simplicity write \(f_{\theta}(A,x)\) as \(f_{\theta}(x)\). If we further constrain the output of \(f_{\theta}\) to lie between 0 and 1, by use of a sigmoid activation function, then we can learn a valid approximator for the conditional probability function \(P(c \mid x)\) by minimizing:

\begin{align} \mathcal{L}(\theta) &= -\sum_{A}\log [f_\theta(x)] - \sum_{B}\log [1- f_\theta(x)] \end{align}
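To see the whole argument end to end, here is a self-contained numpy sketch (entirely my own construction, using a one-dimensional logistic model in place of a deep network) that minimizes this loss by gradient descent on synthetic data for which the true \(P(A \mid x)\) is known in closed form, and then compares the learned \(f_\theta(x)\) against it:

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic data with a known conditional: x | A ~ N(+1, 1), x | B ~ N(-1, 1),
# equal class priors.  Bayes' rule then gives the true P(A | x) = sigmoid(2 x).
n = 5000
x = np.concatenate([rng.normal(+1.0, 1.0, n), rng.normal(-1.0, 1.0, n)])
y = np.concatenate([np.ones(n), np.zeros(n)])   # 1 for class A, 0 for class B

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

# f_theta(x) = sigmoid(w x + b): the simplest sigmoid-terminated "network".
# We minimize the loss above, averaged rather than summed, which only rescales it.
w, b, lr = 0.0, 0.0, 0.1
for _ in range(2000):
    f = sigmoid(w * x + b)
    residual = f - y                      # gradient of the per-example loss w.r.t. the logit
    w -= lr * np.mean(residual * x)
    b -= lr * np.mean(residual)

# The learned f_theta(x) should approximately recover the true conditional probability.
for x0 in [-2.0, -1.0, 0.0, 1.0, 2.0]:
    print(f"x={x0:+.1f}  learned={sigmoid(w * x0 + b):.3f}  true={sigmoid(2 * x0):.3f}")
```

For this well-specified toy problem the fit should land near \(w \approx 2\) and \(b \approx 0\), matching the closed-form conditional, which is exactly the sense in which the sigmoid output can be read as \(P(A \mid x)\).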
  1. This is the same as the discriminator network of a GAN, whose loss function is given by: \(\mathcal{L}(\psi)=-\mathbb{E}_{p(z)}\log\sigma(g_{\psi}(f_\theta(z))) - \mathbb{E}_{q(x)}\log[1 - \sigma(g_{\psi}(x))]\) 

  2. Without loss of generality, we can consider a binary classifier with a single output terminated by a sigmoid instead of a two-class softmax (see the sketch below). 
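The "without loss of generality" in footnote 2 is easy to check: a two-class softmax over logits \((z_A, z_B)\) reduces to a sigmoid of the logit difference, so the binary sigmoid case covers the two-class softmax case. A minimal numpy sketch:

```python
import numpy as np

def softmax(z):
    e = np.exp(z - np.max(z))             # shift logits for numerical stability
    return e / np.sum(e)

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

z_A, z_B = 1.3, -0.4                      # arbitrary two-class logits
print(softmax(np.array([z_A, z_B]))[0])   # P(A) from the two-way softmax
print(sigmoid(z_A - z_B))                 # sigmoid of the logit difference: identical
```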



