Tracking Metrics over Epochs to Understand Labels Better.

I asked a question on Twitter about finding bad labels in training data and got some interesting responses. One suggestion was to track the “variability” and “confidence” of every data-point while training a neural network. The idea is described in this arXiv paper.

The confidence metric \(\hat{\mu}_i\) for a data-point \(i\) is the mean probability that the model assigns to the true label \(y_i^*\), measured across \(E\) epochs.

\[\hat{\mu}_{i}=\frac{1}{E} \sum_{e=1}^{E} p_{\boldsymbol{\theta}^{(e)}}\left(y_{i}^{*} \mid \boldsymbol{x}_{i}\right)\]
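As a small sketch, assume you already recorded the true-label probability \(p_{\boldsymbol{\theta}^{(e)}}(y_i^* \mid \boldsymbol{x}_i)\) for one data-point after each of the \(E\) epochs. The confidence is then just the mean of that list. The function name `confidence` is made up for this example.

```python
from statistics import fmean
from typing import Sequence


def confidence(true_label_probs: Sequence[float]) -> float:
    """Mean probability assigned to the true label y_i* across E epochs.

    true_label_probs[e] is assumed to be p_theta^(e)(y_i* | x_i),
    recorded once per epoch.
    """
    if not true_label_probs:
        raise ValueError("need at least one epoch of probabilities")
    return fmean(true_label_probs)


# A data-point whose true label got probability 0.6, 0.8, 0.9 and 0.9 over four epochs.
print(confidence([0.6, 0.8, 0.9, 0.9]))
```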

The “variability” metric \(\hat{\sigma}_i\) measures the spread of that same probability across training epochs using the standard deviation.

\[\hat{\sigma}_{i}=\sqrt{\frac{\sum_{e=1}^E \left[p_{\boldsymbol{\theta}^{(e)}}\left(y_{i}^{*} \mid \boldsymbol{x}_{i}\right) - \hat{\mu}_i \right]^2}{E}}\]
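The variability can be sketched the same way, again assuming a list of per-epoch true-label probabilities for a single data-point; the function name `variability` is hypothetical. Because the formula divides by \(E\), this is the population standard deviation, which is also what Python's `statistics.pstdev` computes (as opposed to `statistics.stdev`, which divides by \(E-1\)).

```python
from statistics import fmean
from typing import Sequence


def variability(true_label_probs: Sequence[float]) -> float:
    """Standard deviation of p_theta^(e)(y_i* | x_i) across E epochs.

    Divides by E (population standard deviation), as in the formula.
    """
    if not true_label_probs:
        raise ValueError("need at least one epoch of probabilities")
    mu = fmean(true_label_probs)  # the confidence mu_i of this data-point
    squared_deviations = sum((p - mu) ** 2 for p in true_label_probs)
    return (squared_deviations / len(true_label_probs)) ** 0.5


# A model that keeps flip-flopping on this data-point gives a high variability.
print(variability([0.2, 0.9, 0.3, 0.8]))
# A model that is consistently sure gives zero variability.
print(variability([0.9, 0.9, 0.9, 0.9]))
```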

For both of these measures, note that \(p_{\boldsymbol{\theta}^{(e)}}\) is the model's probability under the weights \(\boldsymbol{\theta}^{(e)}\) at epoch \(e\). As we train the system, the weights get updated, and so do the estimated probabilities.
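To see where these per-epoch probabilities come from, here is a minimal NumPy sketch. It swaps the neural network for a softmax regression (a single-layer network) trained with full-batch gradient descent on made-up 2D data, and after every epoch it stores the probability of each data-point's true label under the current weights. The function `train_and_record`, the synthetic blobs and the five deliberately flipped labels are assumptions for illustration, not the paper's setup.

```python
import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    # Subtract the row-wise max for numerical stability.
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


def train_and_record(X, y, n_classes, epochs=20, lr=0.5, seed=0):
    """Train softmax regression and record p_theta^(e)(y_i* | x_i) per epoch.

    Returns an array of shape (epochs, n_samples).
    """
    rng = np.random.default_rng(seed)
    n, d = X.shape
    W = rng.normal(scale=0.01, size=(d, n_classes))
    b = np.zeros(n_classes)
    onehot = np.eye(n_classes)[y]
    history = np.empty((epochs, n))
    for e in range(epochs):
        probs = softmax(X @ W + b)
        # Gradient of the mean cross-entropy loss w.r.t. the logits.
        grad = (probs - onehot) / n
        W -= lr * (X.T @ grad)
        b -= lr * grad.sum(axis=0)
        # theta^(e) are now the weights after epoch e: store the true-label probability.
        history[e] = softmax(X @ W + b)[np.arange(n), y]
    return history


# Two Gaussian blobs; the first five labels are flipped to simulate bad labels.
rng = np.random.default_rng(42)
X = np.vstack([rng.normal(-2, 1, size=(50, 2)), rng.normal(2, 1, size=(50, 2))])
y = np.array([0] * 50 + [1] * 50)
y[:5] = 1

history = train_and_record(X, y, n_classes=2)
conf = history.mean(axis=0)  # confidence per data-point
var = history.std(axis=0)    # variability per data-point (ddof=0, divides by E)
print("flipped:", conf[:5].mean(), "rest:", conf[5:].mean())
```

Because `history` holds one row per epoch, the mean and standard deviation over `axis=0` give the confidence and variability of every data-point at once; NumPy's `std` divides by \(E\) by default (`ddof=0`), matching the formula. In this toy setup you would expect the flipped data-points to end up with a noticeably lower confidence than the rest.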

So what do you get when you train a system this way? You get a confidence and a variability score for every data-point, which you can plot against each other in a chart like this one:

During the training run you kept track of the confidence and the variability of every data-point. Afterwards you can wonder what it means for a data-point to have high confidence and low variability: the model consistently assigned a high probability to the true label. You could argue that these data-points were the ones that were easy to distinguish and learn.

Similarly, you could argue that there is a region of low confidence and low variability with data-points that are hard to learn, and a region of high variability with ambiguous data-points. These groups deserve an extra look: they may contain bad labels, or classes that need a better definition.
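To make these regions concrete, here is a minimal sketch that buckets data-points by their confidence and variability and returns the ones worth a second look. The cut-off values and the function names `region` and `points_to_review` are hypothetical and chosen only for illustration; in practice you would pick thresholds by looking at your own plot.

```python
from typing import List, Sequence

# Hypothetical cut-offs, made up for this example.
HIGH_CONFIDENCE = 0.7
LOW_CONFIDENCE = 0.3
HIGH_VARIABILITY = 0.2


def region(confidence: float, variability: float) -> str:
    """Assign a data-point to a rough region of the confidence/variability plot."""
    if variability >= HIGH_VARIABILITY:
        return "ambiguous"
    if confidence >= HIGH_CONFIDENCE:
        return "easy-to-learn"
    if confidence <= LOW_CONFIDENCE:
        return "hard-to-learn"
    return "unassigned"  # in between; no strong signal either way


def points_to_review(confidences: Sequence[float], variabilities: Sequence[float]) -> List[int]:
    """Indices of data-points in the hard-to-learn or ambiguous regions."""
    return [
        i
        for i, (c, v) in enumerate(zip(confidences, variabilities))
        if region(c, v) in ("ambiguous", "hard-to-learn")
    ]


print(points_to_review([0.95, 0.1, 0.5], [0.05, 0.03, 0.35]))
```

Here the first point is easy to learn, the second is hard to learn and the third is ambiguous, so only the last two would be flagged for a manual check of their labels.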

It’s a neat idea and I gotta try it out some time.