
Recently, I came across the sparse version of the cross entropy loss function of multi-label classification proposed by Su Jianlin and found it very interesting. Moreover, there are few codes on Github, so I used PyTorch to reproduce it and recorded the relevant learning process here.

Su Jianlin's original blog posts:

kexue.fm/archives/73…

kexue.fm/archives/88…

From single-label to multi-label

What is multi-label classification? Softmax and the cross entropy loss function, covered in the previous article, are the standard tools for the ordinary classification problem, i.e., single-label classification.

Single-label classification selects one target category out of n candidate categories. The optimization objective of the loss is to maximize the score of the target category; see the cross entropy loss function in the previous article.

For multi-label classification, we select k target categories out of n candidates as positives, a yes-or-no question for each category. Equivalently, we perform n binary classification tasks simultaneously.

Intuitively, we could simply apply a sigmoid activation and use the sum of the n binary cross entropies as the loss. However, when n >> k there is a severe class-imbalance problem: when k is very small, the network can obtain a small loss just by predicting every class as negative. Single-label classification (k = 1) does not suffer from this imbalance, because softmax couples the classes so that the cross entropy assigns an appropriate loss to each prediction without such a bias.
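To make the baseline concrete, here is a minimal sketch of the naive sigmoid-plus-BCE setup described above (my own illustration; the shapes and numbers are made up). With n = 1000 and k = 3, the averaged loss is dominated by the negative terms:

import torch
import torch.nn.functional as F

# Naive baseline: treat each of the n classes as an independent binary problem.
logits = torch.randn(8, 1000)   # raw scores for n = 1000 candidate classes
targets = torch.zeros(8, 1000)
targets[:, :3] = 1              # only k = 3 positives per sample
naive_loss = F.binary_cross_entropy_with_logits(logits, targets, reduction='mean')
print(naive_loss)               # the average is dominated by the 997 negative classes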

Therefore, an intuitive idea is to generalize the multi-label loss from softmax; in other words, when k = 1 the loss should degenerate back into the softmax cross entropy.

Combinatorial softmax

Su Jianlin first considers the case of fixed k. At inference time we obviously just output the top-k scores, but what should the loss be during training?

By analogy with the 1-out-of-n choice of single-label classification, we can cast the multi-label case as a 1-out-of-$C_n^k$ choice, so the loss should be:


$$-\log \frac{e^{s_{t_1}+s_{t_2}+\dots+s_{t_k}}}{\sum\limits_{1\leq i_1 < i_2 < \cdots < i_k\leq n}e^{s_{i_1}+s_{i_2}+\dots+s_{i_k}}}=\log Z_k - (s_{t_1}+s_{t_2}+\dots+s_{t_k}) \tag{1}$$

Defining $S_k = \sum\limits_{i=1}^{n} e^{k s_i}$, we can obtain $Z_k$ recursively:


$$\begin{aligned} Z_1 &= S_1\\ 2Z_2 &= Z_1 S_1 - S_2\\ 3Z_3 &= Z_2 S_1 - Z_1 S_2 + S_3\\ &\;\;\vdots\\ k Z_k &= Z_{k-1} S_1 - Z_{k-2} S_2 + \dots + (-1)^{k-2} Z_1 S_{k-1} + (-1)^{k-1} S_k \end{aligned}$$
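Below is a minimal PyTorch sketch of this recursion (my own illustration, not from the original post; the helper name combinatorial_log_Z is mine, and it works in plain rather than log space, so it is only numerically safe for small scores):

import torch

def combinatorial_log_Z(s: torch.Tensor, k: int) -> torch.Tensor:
    """log Z_k, where Z_k sums exp(s_{i_1}+...+s_{i_k}) over all k-subsets of the scores,
    computed with the recursion m*Z_m = sum_{j=1..m} (-1)^(j-1) Z_{m-j} S_j."""
    S = [None] + [torch.exp(j * s).sum(dim=-1) for j in range(1, k + 1)]  # S_1 .. S_k
    Z = [torch.ones_like(S[1])]                                           # Z_0 = 1
    for m in range(1, k + 1):
        acc = sum((-1) ** (j - 1) * Z[m - j] * S[j] for j in range(1, m + 1))
        Z.append(acc / m)
    return torch.log(Z[k])

For k = 1 this reduces to the ordinary logsumexp of the scores, matching the single-label case.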

Rather than getting lost in the recursion, let me add something Su Jianlin did not say and go back to the form of the loss itself. It is almost identical to softmax, except that the object has changed from a single $s_t$ to a group $\{s_{t_i}\}$. On closer analysis, a problem emerges:

For softmax, we want the target score $s_t$ to be large enough and the other $s_i$ to be small enough. For the equation above, we want the sum of the group $\{s_{t_i}\}$ to be large enough; but once a single $s_{t_i}$ becomes large enough, the loss already becomes small and optimization effectively stops.

Here I try to prove it:


$$\log(Z_k)=\log\left(\sum\limits_{1\leq i_1 < i_2 < \cdots < i_k\leq n}e^{s_{i_1}+s_{i_2}+\dots+s_{i_k}}\right)$$

Notice that the formula above is a LogSumExp, and LogSumExp is the smooth approximation of the max function, so the loss can be approximated as:


$$L\approx \max_{1\leq m_1 < m_2 < \cdots < m_k\leq n}\left(s_{m_1}+s_{m_2}+\dots+s_{m_k}\right)-(s_{t_1}+s_{t_2}+\dots+s_{t_k})$$

Therefore, once one of the $s_{t_i}$ becomes large enough, the loss can already become small even if the other scores in the group remain small.

Uncertain K

Generally, in multi-label classification the number of target labels is not fixed. We therefore fix a maximum number of target labels K and pad with a dummy label 0, assuming the number of real labels never exceeds K.


$$\log \overline{Z}_K - \bigl(s_{t_1}+s_{t_2}+\dots+s_{t_k}+\underbrace{s_0+\dots+s_0}_{K-k\ \text{copies}}\bigr)$$

For example, suppose we need to output only 2 labels and the maximum number of target labels is 10. When building the label, we add the 2 real labels and fill the remaining 8 slots with label 0, a dummy label (the network still has to predict it, so num_classes becomes num_classes + 1; otherwise the network still cannot output a variable number of labels). Repeated outputs are allowed, and at inference we still take the top-K but discard label 0. $\overline{Z}_K$ can also be computed recursively, which I will not describe further here.
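A minimal sketch of this label construction (my own illustration; K and the class indices are made up, with index 0 reserved for the dummy class):

K = 10                       # maximum number of target labels
real_labels = [3, 7]         # the two real classes for this sample
padded = real_labels + [0] * (K - len(real_labels))
# padded == [3, 7, 0, 0, 0, 0, 0, 0, 0, 0]; label 0 is the dummy padding class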

Unified Loss format

While verifying the effectiveness of the loss above, Su Jianlin consulted other researchers and came across the unified loss form in Circle Loss (worth reading when you have time). He realized that this unified form contains a more concise generalization, and the author of Circle Loss had also pointed out the flaw in the method above: www.zhihu.com/question/38… .

The unified form of the loss is as follows:


$$\begin{aligned} L_{uni} &= \log\left[1+\sum_{i=1}^K\sum_{j=1}^L \exp\bigl(\gamma(s_n^j-s_p^i+m)\bigr)\right]\\ &=\log\left[1+\sum_{j=1}^L \exp\bigl(\gamma(s_n^j+m)\bigr)\sum_{i=1}^K \exp\bigl(-\gamma s_p^i\bigr)\right] \end{aligned}$$

The formula above handles positive and negative cases separately; we can write the cross entropy loss in a similar form:


$$-\log \frac{e^{s_t}}{\sum\limits_{i=1}^n e^{s_i}}=-\log \frac{1}{\sum\limits_{i=1}^n e^{s_i-s_t}}=\log \sum\limits_{i=1}^n e^{s_i-s_t}=\log \left(1 + \sum\limits_{i=1,\,i\neq t}^n e^{s_i-s_t}\right)$$

If this formula looks familiar, that is because it is the LogSumExp function, the smooth approximation of max mentioned earlier. Let's look at how that approximation is derived.

LogSumExp

Reference:

kexue.fm/archives/32…

www.matrix67.com/blog/archiv…

www.johndcook.com/blog/2010/0…

When $x \geq 0,\ y \geq 0$:


$$\max(x,y)=\frac{1}{2}\left(|x+y|+|x-y|\right)\tag{2}$$

To approximate the max function, we can first find an approximation of the absolute value function, whose derivative is:


$$f'(x) = \begin{cases}1, & x > 0\\ -1, & x < 0\end{cases}\tag{3}$$

We can rewrite it in terms of the unit step function:


$$\theta(x) = \begin{cases}1, & x > 0\\ 0, & x < 0\end{cases}\tag{4}$$

$$f'(x)=2\theta(x)-1\tag{5}$$

We can then approximate the max function through an approximation of $\theta(x)$. A common approximation in physics is:


$$\theta(x)=\lim_{k\to +\infty} \frac{1}{1+e^{-k x}}$$

Substituting this into equation (5) and integrating gives:


$$|x|=\lim_{k\to +\infty} \frac{1}{k}\ln(e^{kx}+e^{-kx})$$
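For completeness, here is the integration step filled in (my own addition; the integration constant is taken as 0):

$$\int\left(\frac{2}{1+e^{-kx}}-1\right)\mathrm{d}x=\frac{2}{k}\ln\!\left(1+e^{kx}\right)-x=\frac{1}{k}\ln\!\left(e^{kx}+2+e^{-kx}\right)$$

and as $k\to+\infty$ the constant 2 inside the logarithm becomes negligible, giving the formula above.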

In this way, the approximate formula for Max can be obtained:


$$\max(x,y)=\lim_{k\to +\infty} \frac{1}{2k}\ln(e^{2kx}+e^{-2kx}+e^{2ky}+e^{-2ky})$$

Since $x\geq 0$ and $y\geq 0$, the terms $e^{-2kx}$ and $e^{-2ky}$ tend to 0, so this can be further simplified (renaming $2k$ as $k$) to:


$$\max(x,y)=\lim_{k\to +\infty} \frac{1}{k}\ln(e^{kx}+e^{ky})$$

In fact, this formula holds for arbitrary real numbers and even extends to more variables:


$$\max(x,y,z,\dots)=\lim_{k\to +\infty} \frac{1}{k}\ln(e^{kx}+e^{ky}+e^{kz}+\dots)$$

But here k has to approach infinity; what does that have to do with LogSumExp?

In models, we usually just set k to 1, which is equivalent to fusing k into the scores themselves and letting the model decide the effective scale of k on its own.
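Here is a quick numerical check (my own sketch with made-up scores) that $\frac{1}{k}\operatorname{logsumexp}(k\mathbf{x})$ approaches $\max(\mathbf{x})$ from above as k grows, and that k = 1 already gives a usable smooth upper bound:

import torch

x = torch.tensor([1.0, 2.5, -0.3, 2.4])
for k in (1.0, 5.0, 50.0):
    approx = torch.logsumexp(k * x, dim=0) / k
    print(f"k={k}: {approx.item():.4f}  (true max = {x.max().item():.4f})")
# The approximation tightens toward the true maximum as k increases.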

The cross entropy function in the unified loss form


$$\log \sum\limits_{i=1}^n e^{s_i-s_t}\approx \max\begin{pmatrix}0 \\ s_1 - s_t \\ \vdots \\ s_{t-1} - s_t \\ s_{t+1} - s_t \\ \vdots \\ s_n - s_t\end{pmatrix}$$

This expression alone explains why the softmax + cross entropy loss works.

From the above we already know that this is the smooth approximation of max, so the formula is equivalent to taking the largest difference between any non-target score and the target score, and trying to drive that maximum below zero. Since the target score minus itself is exactly zero, the loss guarantees that the target score ends up greater than every non-target score.
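A quick numeric sanity check of this rewriting (my own sketch with made-up scores), comparing the standard softmax cross entropy, the $\log\sum_i e^{s_i-s_t}$ form, and the max it approximates:

import torch
import torch.nn.functional as F

s = torch.tensor([1.2, -0.7, 3.1, 0.4])   # scores, target class t = 2
t = 2
ce = F.cross_entropy(s.unsqueeze(0), torch.tensor([t]))   # -log softmax(s)[t]
lse = torch.logsumexp(s - s[t], dim=0)                    # log sum_i e^{s_i - s_t}
hard_max = torch.cat([torch.zeros(1), (s - s[t])[torch.arange(4) != t]]).max()
print(ce.item(), lse.item(), hard_max.item())             # ce == lse, both >= hard_max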

Multi-label classification

Having obtained the cross entropy function in the unified form, we now split the categories into positive and negative sets by analogy with that form and obtain:


$$\log \left(1 + \sum\limits_{i\in\Omega_{neg},\,j\in\Omega_{pos}} e^{s_i-s_j}\right)=\log \left(1 + \sum\limits_{i\in\Omega_{neg}} e^{s_i}\sum\limits_{j\in\Omega_{pos}} e^{-s_j}\right)$$

When k is fixed, the formula above can be used directly. If k is uncertain, we add an extra class 0 as before and require that all target scores be greater than $s_0$ and all non-target scores be less than $s_0$, which gives:


$$\begin{aligned} &\log \left(1 + \sum\limits_{i\in\Omega_{neg},\,j\in\Omega_{pos}} e^{s_i-s_j}+\sum\limits_{i\in\Omega_{neg}} e^{s_i-s_0}+\sum\limits_{j\in\Omega_{pos}} e^{s_0-s_j}\right)\\ =&\log \left(e^{s_0} + \sum\limits_{i\in\Omega_{neg}} e^{s_i}\right) + \log \left(e^{-s_0} + \sum\limits_{j\in\Omega_{pos}} e^{-s_j}\right) \end{aligned}$$

If the threshold $s_0$ is set to 0, this simplifies to:


$$\log \left(1 + \sum\limits_{i\in\Omega_{neg}} e^{s_i}\right) + \log \left(1 + \sum\limits_{j\in\Omega_{pos}} e^{-s_j}\right)\tag{6}$$

Therefore, there is no need to add an extra class for training. The code implementation is given below:

import torch

def multilabel_categorical_crossentropy(y_true, y_pred):
    """Multi-label classification cross entropy.

    y_true and y_pred have the same shape. Elements of y_true are 0 or 1:
    1 means the corresponding class is a target class, 0 means it is not.
    Warning: make sure y_pred ranges over all real numbers; in other words,
    y_pred generally should not be activated, in particular not by sigmoid or
    softmax! At prediction time, output the classes whose y_pred is greater
    than 0. If in doubt, please read and understand this article carefully.
    """
    y_pred = (1 - 2 * y_true) * y_pred                   # multiply positives by -1, negatives by 1
    y_pred_neg = y_pred - y_true * 1e12                  # push positives to -inf in the negative term
    y_pred_pos = y_pred - (1 - y_true) * 1e12            # push negatives to -inf in the positive term
    zeros = torch.zeros_like(y_pred[..., :1])
    y_pred_neg = torch.cat([y_pred_neg, zeros], dim=-1)  # append the 0 threshold
    y_pred_pos = torch.cat([y_pred_pos, zeros], dim=-1)
    neg_loss = torch.logsumexp(y_pred_neg, dim=-1)
    pos_loss = torch.logsumexp(y_pred_pos, dim=-1)
    return neg_loss + pos_loss
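A small usage sketch (my own illustration; the shapes and class indices are made up), with raw logits from the model and a multi-hot target matrix:

logits = torch.randn(4, 10)        # raw scores, no sigmoid/softmax applied
targets = torch.zeros(4, 10)
targets[0, [1, 3]] = 1             # sample 0 has target classes 1 and 3
loss = multilabel_categorical_crossentropy(targets, logits).mean()
pred_classes = (logits > 0)        # at inference: classes whose score exceeds 0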

Sparse version of multi-label classification cross entropy

Multi-label classification cross entropy is not limited to multi-label classification; it applies to any "choose k out of n" task. Su Jianlin gives GlobalPointer as an example, and in the CV field it could, for instance, replace Focal Loss in object detection.

When the positives and negatives in a task are extremely unbalanced (here, far fewer positives than negatives) and the label space is very large, we can change strategy:


$$\begin{aligned} &\log \left(1 + \sum\limits_{i\in \mathcal{N}} e^{S_i}\right) = \log \left(1 + \sum\limits_{i\in \mathcal{A}} e^{S_i} - \sum\limits_{i\in \mathcal{P}} e^{S_i}\right) \\ =&\, \log \left(1 + \sum\limits_{i\in \mathcal{A}} e^{S_i}\right) + \log \left(1 - \left(\sum\limits_{i\in \mathcal{P}} e^{S_i}\right)\Bigg/\left(1 + \sum\limits_{i\in \mathcal{A}} e^{S_i}\right)\right) \end{aligned}$$

Here $\mathcal{A}$ denotes all classes, $\mathcal{P}$ the positives and $\mathcal{N}$ the negatives, so the loss over negatives is rewritten as the full set minus the positives. This way, when building labels we only need to store the indices of the positive classes, and during training everything can be computed by indexing directly with those positive indices.

The author gave TensorFlow code, but a PyTorch version was nowhere to be found on the web, so I tried to reproduce it and posted it on my GitHub: github.com/Asthestarsf…

import torch
from torch import Tensor

def sparse_multilabel_categorical_crossentropy(label: Tensor, pred: Tensor, mask_zero=False, reduction='none'):
    """Sparse multi-label categorical cross entropy.

    Reference: https://kexue.fm/archives/8888,
    https://github.com/bojone/bert4keras/blob/4dcda150b54ded71420c44d25ff282ed30f3ea42/bert4keras/backend.py#L272

    Args:
        label: label tensor with shape [batch_size, n, num_positive] or
            [batch_size, num_positive]; it should contain the indices of the
            positive classes rather than a one-hot vector.
        pred: logits tensor with shape [batch_size, m, num_classes] or
            [batch_size, num_classes]; do not apply any activation.
        mask_zero: if the labels are zero-padded for alignment, set
            mask_zero=True; in that case make sure the real labels run from 1
            to num_classes before zero padding.
    """
    zeros = torch.zeros_like(pred[..., :1])
    pred = torch.cat([pred, zeros], dim=-1)            # append the 0 threshold
    if mask_zero:
        infs = torch.ones_like(zeros) * float('inf')
        pred = torch.cat([infs, pred[..., 1:]], dim=-1)
    pos_2 = batch_gather(pred, label)                  # batch_gather is defined below
    pos_1 = torch.cat([pos_2, zeros], dim=-1)
    if mask_zero:
        pred = torch.cat([-infs, pred[..., 1:]], dim=-1)
        pos_2 = batch_gather(pred, label)
    pos_loss = torch.logsumexp(-pos_1, dim=-1)
    all_loss = torch.logsumexp(pred, dim=-1)
    aux_loss = torch.logsumexp(pos_2, dim=-1) - all_loss
    aux_loss = torch.clip(1 - torch.exp(aux_loss), 1e-16, 1)
    neg_loss = all_loss + torch.log(aux_loss)
    loss = pos_loss + neg_loss

    if reduction == 'mean':
        return loss.mean()
    elif reduction == 'sum':
        return loss.sum()
    elif reduction == 'none':
        return loss
    else:
        raise Exception('Unexpected reduction {}'.format(reduction))

To explain the main points:

  1. When the labels need to be aligned with zero padding, add one to every label value so that the real classes start from 1.
  2. In that case set mask_zero to True: the positive scores are gathered from pred by label index, so the padded zeros would otherwise pick up a real class score. An infinity is therefore concatenated at the head of pred, and after negation inside the LogSumExp those padded entries contribute exp(-inf) = 0.
  3. As for whether to change the number of classes from num_classes to num_classes + 1, I think it is unnecessary: the loss already uses 0 explicitly as the score of the extra dummy class, and from the earlier analysis the purpose of equation (6) is simply to push positive scores above 0 and negative scores below 0, so at inference we can directly output the classes whose scores are greater than 0.

In addition, PyTorch does not provide a batch_gather API, so I implemented a simple one tailored to what this loss needs:

def batch_gather(input: Tensor, indices: Tensor):
    """Per-sample gather along the last dimension.

    Args:
        input: tensor with shape [batch_size, n, L] or [batch_size, L].
        indices: index tensor with shape [batch_size, m, l] or [batch_size, l].
    Return:
        If m == n the returned shape is [batch_size, m, l]; when m != n the
        indices are flattened and all values along that dimension of `input`
        are gathered, giving shape [batch_size, n, l*m].
    """
    if indices.dtype != torch.int64:
        indices = indices.to(torch.int64)
    results = []
    for data, indice in zip(input, indices):
        if len(indice) < len(data):          # fewer index rows than data rows: flatten and index
            indice = indice.reshape(-1)
            results.append(data[..., indice])
        else:
            indice_dim = indice.ndim
            results.append(torch.gather(data, dim=indice_dim - 1, index=indice))
    return torch.stack(results)

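Finally, a small usage sketch of the sparse loss (my own illustration; the shapes, the class count and the reserved index-0 column follow the notes above and are assumptions, not part of the original code):

num_classes = 10
pred = torch.randn(4, num_classes + 1)                    # column 0 is reserved for the padding label
label = torch.tensor([[2, 5], [1, 0], [3, 0], [7, 9]])    # positive class indices (1..10), zero-padded
loss = sparse_multilabel_categorical_crossentropy(label, pred, mask_zero=True, reduction='mean')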