SSMix: Saliency-based Span Mixup for Text Classification
Paper: Arxiv.org/pdf/2106.08… | Code: Github.com/clovaai/ssm… | Authors: Soyoung Yoon et al.

Abstract

Data augmentation has proven effective for a variety of computer vision tasks. Despite this success, there have been obstacles to applying mixup to NLP tasks because text consists of discrete tokens of variable length. In this work, the authors propose SSMix, a novel mixup method in which the operation is performed on the input text rather than on hidden vectors as in previous approaches. SSMix synthesizes a sentence through span-based mixing, preserving the locality of the two original texts and retaining the tokens most relevant to prediction according to saliency information. Through extensive experiments, the authors empirically verify that the method outperforms hidden-level mixup on a wide range of text classification benchmarks, including textual entailment, sentiment classification, and question-type classification.

Introduction

Data augmentation is becoming increasingly important in natural language processing (NLP) because data collection and labeling are expensive. Earlier studies generate similar texts with simple rules or models and combine them with the original samples during training, or augment data with mixup by interpolating texts and labels.

Mixup and its variants are regularization methods widely used in computer vision to improve the generalization of neural networks. Depending on where the mixing operation is applied, these methods are divided into input-level mixup and hidden-level mixup. Input-level mixup is the more common of the two because of its simplicity and its ability to capture locality, which generally leads to better accuracy.
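
For reference, standard mixup from computer vision linearly interpolates both the inputs and the labels. A minimal sketch (illustrative names, not from the paper):

```python
import torch

# Vanilla input-level mixup: interpolate inputs and (one-hot) labels.
def vanilla_mixup(x_a, x_b, y_a, y_b, alpha=0.2):
    lam = torch.distributions.Beta(alpha, alpha).sample().item()
    x_mix = lam * x_a + (1 - lam) * x_b   # works for continuous inputs such as images
    y_mix = lam * y_a + (1 - lam) * y_b   # soft label
    return x_mix, y_mix
```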

Because text data is discrete and of variable sequence length, applying mixup in NLP is more challenging than in computer vision. Most previous attempts at text mixup therefore applied the operation to hidden vectors, such as embeddings or intermediate representations. However, following the intuition from computer vision augmentation, input-level mixup generally has an advantage over hidden-level mixup. This motivates the authors to explore input-level mixup methods for textual data.

In this work, the authors propose SSMix (Figure 1), a new input-level, span-based mixup data augmentation algorithm guided by saliency. First, mixing is performed by replacing a span of contiguous tokens with a span from another text, a move inspired by CutMix that preserves the locality of the two source texts in the mixed text. Second, the spans to be replaced and inserted are selected based on saliency information, so that the mixed text keeps the tokens most relevant to the output prediction. This input-level approach differs from hidden-level mixup: while hidden-level methods linearly interpolate hidden vectors, SSMix mixes text tokens at the input level, producing a nonlinear output. In addition, saliency values are used to select the span from each sentence, and the span length and mixing ratio are defined discretely, which further distinguishes it from hidden-level mixup.

SSMix is shown to be effective through extensive experiments on text classification benchmarks. In particular, the paper shows that input-level mixup is generally superior to hidden-level mixup. It also demonstrates the importance of using saliency information and of restricting token selection to spans when performing text mixup.

SSMix algorithm

The basic principle of SSMix is as follows. Given two texts $x^A$ and $x^B$, a new text $\widetilde{x}$ is generated by replacing a span $x_S^A$ of $x^A$ with a span $x_S^B$ selected from the other text $x^B$ according to saliency. A new label $\widetilde{y}$ is then assigned to $\widetilde{x}$ based on the two original labels $y^A$ and $y^B$. Finally, the augmented virtual sample $(\widetilde{x}, \widetilde{y})$ is used to train the model.
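
A toy, pure-Python sketch of this principle (not the authors' implementation): the span positions and length are passed in explicitly here, whereas SSMix chooses them from saliency as described in the following subsections, and the example sentences are illustrative.

```python
# Replace a span of x_a with a span of x_b, then mix the labels by the
# share of tokens that came from x_b.
def ssmix_toy(x_a, x_b, start_a, start_b, span_len, y_a, y_b):
    x_new = x_a[:start_a] + x_b[start_b:start_b + span_len] + x_a[start_a + span_len:]
    lam = span_len / len(x_new)            # mixup ratio lambda
    y_new = [(y_a, 1 - lam), (y_b, lam)]   # soft label as (class, weight) pairs
    return x_new, lam, y_new

x_a = "i am interested in this project".split()
x_b = "it is a transcendent love story".split()
print(ssmix_toy(x_a, x_b, 3, 3, 2, "label_A", "label_B"))
# (['i', 'am', 'interested', 'transcendent', 'love', 'project'], 0.333..., [('label_A', 0.666...), ('label_B', 0.333...)])
```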

Saliency: token importance

Saliency measures how much each token of the input influences the final prediction. Gradient-based methods have been widely used for saliency computation in previous studies. Here, the gradient of the classification loss $\mathcal{L}$ with respect to the input embedding $e$ is computed, and its magnitude is used as the saliency score: $s = \lVert \partial \mathcal{L} / \partial e \rVert_2$. The L2 norm of the gradient vector gives the saliency of each token, similar to PuzzleMix.
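
A minimal PyTorch sketch of this gradient-based saliency, assuming a HuggingFace-style sequence classification model (the repo's own version lives in `saliency.get_saliency`; the names here are placeholders):

```python
import torch
import torch.nn.functional as F

def token_saliency(model, input_ids, attention_mask, labels):
    """Saliency = L2 norm of d(loss)/d(embedding), computed per input token."""
    emb_layer = model.get_input_embeddings()
    embeds = emb_layer(input_ids).detach().requires_grad_(True)
    logits = model(inputs_embeds=embeds, attention_mask=attention_mask).logits
    loss = F.cross_entropy(logits, labels)
    grads, = torch.autograd.grad(loss, embeds)   # (batch, seq_len, hidden)
    return grads.norm(dim=-1)                    # (batch, seq_len) saliency scores
```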

Mixing Text: Text composition

Mixing text refers to how $x^A$ and $x^B$ are combined into a new text. The idea is to obtain a saliency score for every token of the two texts using the gradient-based method above, then select the span $x_S^A$ of length $l^A$ with the lowest saliency in $x^A$ and the span $x_S^B$ of length $l^B$ with the highest saliency in $x^B$. The span length is set to $l^A = l^B = \max(\min(\lfloor \lambda_0 |x^A| \rfloor, |x^B|), 1)$, where $\lambda_0$ is the mixup ratio hyperparameter. The new text is then $\widetilde{x} = (x_L^A; x_S^B; x_R^A)$, where $x_L^A$ and $x_R^A$ are the parts of the original text $x^A$ to the left and right of the replaced span $x_S^A$.
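
The span selection step can be sketched with a sliding-window average over the per-token saliency scores (the released code below does this with `F.avg_pool1d`; this simplified helper is only illustrative):

```python
import torch
import torch.nn.functional as F

def select_spans(saliency_a, saliency_b, span_len):
    """Return start indices of the least-salient span in x^A and the most-salient span in x^B."""
    # Mean saliency of every candidate span of length `span_len`.
    pooled_a = F.avg_pool1d(saliency_a[None, None, :], span_len, stride=1).flatten()
    pooled_b = F.avg_pool1d(saliency_b[None, None, :], span_len, stride=1).flatten()
    start_a = int(torch.argmin(pooled_a))   # least salient span of x^A (to be replaced)
    start_b = int(torch.argmax(pooled_b))   # most salient span of x^B (to be inserted)
    return start_a, start_b
```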

Same Span Length: equal-length spans

The paper sets the length of the original span ($l^A$) and the replacement span ($l^B$) to be the same, mainly because spans of different lengths would lead to redundant and semantically ambiguous mixup variants, and computing the mixup ratio between spans of different lengths would be overly complicated. This same-size substitution strategy has also been used in previous work. With equal span lengths, SSMix can make the most of the saliency information: since it does not restrict token positions, it can simultaneously pick the most salient span to insert and the least salient span to replace. In Figure 1, for example, “in this” is the least salient span of $x^A$ and “transcendent love” is the most salient span of $x^B$, so “in this” is replaced by “transcendent love”.
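
A quick worked example of the span-length rule above (illustrative numbers, not from the paper):

```python
# l_A = l_B = max(min(floor(lambda_0 * |x_A|), |x_B|), 1)
lambda_0, len_a, len_b = 0.25, 12, 2
span_len = max(min(int(lambda_0 * len_a), len_b), 1)   # min(3, 2) -> 2: both spans get length 2
```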

Mixing Labels: label composition

The authors set the mixup ratio to

$\lambda = |x_S^B| / |\widetilde{x}|$

Since $\lambda$ is recomputed by counting the tokens inside the span, it may not equal $\lambda_0$. The label of $\widetilde{x}$ is then

$\widetilde{y} = (1 - \lambda)\, y^A + \lambda\, y^B$
Algorithm 1 shows how an original sample pair is used to compute the mixup loss of the augmented sample. The cross-entropy loss of the augmented output logits is computed against each sample's original target label, and the two losses are combined with weights $(1-\lambda)$ and $\lambda$. As a result, SSMix is independent of the number of labels in the dataset: for any dataset, the output label distribution is a linear combination of the two original labels.
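
A minimal sketch of that mixed loss, assuming the model has already produced logits for the augmented batch; `ratio` here plays the role of $1-\lambda$ per example (the share of tokens kept from $x^A$), matching the `ratio` tensor returned by the implementation later in this post:

```python
import torch.nn.functional as F

def mixup_loss(logits_aug, target_a, target_b, ratio):
    """Weighted cross entropy of the augmented logits against both original targets."""
    loss_a = F.cross_entropy(logits_aug, target_a, reduction='none')   # per-example loss vs. y^A
    loss_b = F.cross_entropy(logits_aug, target_b, reduction='none')   # per-example loss vs. y^B
    return (ratio * loss_a + (1.0 - ratio) * loss_b).mean()            # (1-lambda)*L_A + lambda*L_B
```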

1. Paired sentence tasks

For tasks that take a pair of texts as input, such as textual entailment and similarity classification, SSMix mixes the two parts pairwise and computes the mixup ratio by aggregating token counts over both mixup results. Given samples $x^A = (p^A, q^A)$ and $x^B = (p^B, q^B)$, the augmented sample is $\widetilde{x} = (\widetilde{p}, \widetilde{q}) = (\mathrm{mixup}(p^A, p^B), \mathrm{mixup}(q^A, q^B))$, and the mixup ratio is $\lambda = (|p_S| + |q_S|) / (|\widetilde{p}| + |\widetilde{q}|)$, where $p_S$ and $q_S$ are the replacement spans in each mixup operation.

As shown below:

  • $p^A$ = “Fun for only children.”
  • $q^A$ = “Fun for adults and children.”
  • $p^B$ = “Problems in data synthesis.”
  • $q^B$ = “Issues in data synthesis.”

  • $\lambda = (|p_S| + |q_S|) / (|\widetilde{p}| + |\widetilde{q}|) = (1 + 1) / (5 + 6) = 2/11 \approx 0.18$
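
A toy sketch of this pairwise case (illustrative helper, not the authors' code): the premise and hypothesis are mixed independently, and the ratio is pooled over both parts as in the example above.

```python
def ssmix_pair(p_a, p_b, q_a, q_b, span_p, span_q):
    """span_p / span_q are (start_in_A, start_in_B, length) triples chosen by saliency."""
    def splice(x_a, x_b, start_a, start_b, length):
        return x_a[:start_a] + x_b[start_b:start_b + length] + x_a[start_a + length:]

    p_new = splice(p_a, p_b, *span_p)
    q_new = splice(q_a, q_b, *span_q)
    # lambda = (|p_S| + |q_S|) / (|p~| + |q~|)
    lam = (span_p[2] + span_q[2]) / (len(p_new) + len(q_new))
    return p_new, q_new, lam
```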

The experimental setup

Experimental data set

The experiments cover both single-sentence text classification tasks and sentence-pair classification tasks, drawn from the GLUE benchmark, TREC, and ANLI.

Comparison experiments

The paper compares SSMix with three baselines: (1) standard training without mixup, (2) EmbedMix, and (3) TMix. Results are reported against these baselines along with ablation studies. All accuracy values are the average accuracy (%) over five runs with different seeds. MNLI numbers are for the MNLI-mismatched development set. The paper reports validation accuracy for GLUE, test accuracy for TREC, and validation (upper) / test (lower) accuracy for ANLI. SSMix performs better than the other mixup augmentation algorithms on most datasets.

Paper summary

  • Compared with hidden-level mixup, SSMix is especially effective on datasets with sufficient data. Because SSMix combines two samples discretely rather than linearly, it creates a wider range of samples in the synthetic space than hidden-level mixing. The paper conjectures that large amounts of data help the model learn better representations over this synthetic space.

  • SSMix is particularly effective on datasets with many class labels (TREC, ANLI, MNLI, QNLI). For example, the accuracy gain of SSMix on TREC-Fine (47 labels) was much higher than on TREC-Coarse (6 labels): +3.56 versus +0.52 over training without mixup. Datasets with many class labels increase the likelihood that samples with different labels are paired when mixup sources are drawn at random, so the gain from mixing is expected to be larger on these multi-class datasets.
  • SSMix has a clear advantage on paired sentence tasks, such as textual entailment or similarity classification. Existing hidden-level methods apply mixup on hidden representations regardless of special tokens such as [SEP] and [CLS], and may therefore lose information about the start of the sentence or the proper separation of sentence pairs. SSMix, by contrast, can take individual token characteristics into account when applying mixup.

Ablation studies on SSMix and its variants show that performance improves as the span constraint and the saliency information are added. Adding the span constraint to the mixing operation helps by capturing locality better, and the most salient spans carry more of the relationship with the corresponding label, while the least salient spans, which are semantically unimportant with respect to the original label, are discarded. Of the two, introducing saliency information contributes more to accuracy than the span constraint.

Code implementation

import copy
import random
import torch
import torch.nn.functional as F

from .saliency import get_saliency


class SSMix:
    def __init__(self, args):
        self.args = args

    def __call__(self, input1, input2, target1, target2, length1, length2, max_len):
        batch_size = len(length1)

        if self.args.ss_no_saliency:
            if self.args.ss_no_span:
                inputs_aug, ratio = self.ssmix_nosal_nospan(input1, input2, length1, length2, max_len)
            else:
                inputs_aug, ratio = self.ssmix_nosal(input1, input2, length1, length2, max_len)
        else:
            assert not self.args.ss_no_span

            input2_saliency, input2_emb, _ = get_saliency(self.args, input2, target2)

            inputs_aug, ratio = self.ssmix(batch_size, input1, input2,
                                           length1, length2, input2_saliency, target1, max_len)

        return inputs_aug, ratio

    def ssmix(self, batch_size, input1, input2, length1, length2, saliency2, target1, max_len):
        inputs_aug = copy.deepcopy(input1)
        for i in range(batch_size):  # cut off length bigger than max_len ( nli task )
            if length1[i].item() > max_len:
                length1[i] = max_len
                for key in inputs_aug.keys():
                    inputs_aug[key][i][max_len:] = 0
                inputs_aug['input_ids'][i][max_len - 1] = 102  # 102 is BERT's [SEP] token id
        saliency1, _, _ = get_saliency(self.args, inputs_aug, target1)
        ratio = torch.ones((batch_size,), device=self.args.device)

        for i in range(batch_size):
            l1, l2 = length1[i].item(), length2[i].item()
            limit_len = min(l1, max_len) - 2  # mixup except [CLS] and [SEP]
            mix_size = max(int(limit_len * (self.args.ss_winsize / 100.)), 1)

            if l2 < mix_size:
                ratio[i] = 1
                continue

            saliency1_nopad = saliency1[i, :l1].unsqueeze(0).unsqueeze(0)
            saliency2_nopad = saliency2[i, :l2].unsqueeze(0).unsqueeze(0)

            saliency1_pool = F.avg_pool1d(saliency1_nopad, mix_size, stride=1).squeeze(0).squeeze(0)
            saliency2_pool = F.avg_pool1d(saliency2_nopad, mix_size, stride=1).squeeze(0).squeeze(0)

            # should not select first and last
            saliency1_pool[0], saliency1_pool[-1] = 100, 100
            saliency2_pool[0], saliency2_pool[-1] = -100, -100
            input1_idx = torch.argmin(saliency1_pool)
            input2_idx = torch.argmax(saliency2_pool)
            inputs_aug['input_ids'][i, input1_idx:input1_idx + mix_size] = \
                input2['input_ids'][i, input2_idx:input2_idx + mix_size]

            ratio[i] = 1 - (mix_size / (l1 - 2))

        return inputs_aug, ratio

    def ssmix_nosal(self, input1, input2, length1, length2, max_len):
        inputs_aug = copy.deepcopy(input1)
        ratio = torch.ones((len(length1),), device=self.args.device)

        for idx in range(len(length1)):
            if length1[idx].item() > max_len:
                for key in inputs_aug.keys():
                    inputs_aug[key][idx][max_len:] = 0
                inputs_aug['input_ids'][idx][max_len - 1] = 102  # artificially add EOS token.
            l1, l2 = min(length1[idx].item(), max_len), length2[idx].item()

            if self.args.ss_winsize == -1:
                window_size = random.randrange(0, l1)  # random sampling of window_size
            else:
                # remove EOS & SOS when calculating ratio & window size.
                window_size = int((l1 - 2) *
                                  self.args.ss_winsize / 100.) or 1

            if l2 <= window_size:
                ratio[idx] = 1
                continue

            start_idx = random.randrange(0, l1 - window_size)  # random sampling of starting point
            if (l2 - window_size) < start_idx:  # not enough text for reference.
                ratio[idx] = 1
                continue
            else:
                ref_start_idx = start_idx
            mix_percent = float(window_size) / (l1 - 2)

            for key in input1.keys():
                inputs_aug[key][idx, start_idx:start_idx + window_size] = \
                    input2[key][idx, ref_start_idx:ref_start_idx + window_size]

            ratio[idx] = 1 - mix_percent
        return inputs_aug, ratio

    def ssmix_nosal_nospan(self, input1, input2, length1, length2, max_len):
        batch_size, n_token = input1['input_ids'].shape

        inputs_aug = copy.deepcopy(input1)
        len1 = length1.clone().detach()
        ratio = torch.ones((batch_size,), device=self.args.device)

        for i in range(batch_size): # force augmented output length to be no more than max_len
            if len1[i].item() > max_len:
                len1[i] = max_len
                for key in inputs_aug.keys():
                    inputs_aug[key][i][max_len:] = 0
                inputs_aug['input_ids'][i][max_len - 1] = 102

            mix_len = int((len1[i] - 2) * (self.args.ss_winsize / 100.)) or 1
            if (length2[i] - 2) < mix_len:
                mix_len = length2[i] - 2

            flip_idx = random.sample(range(1, min(len1[i] - 1, length2[i] - 1)), mix_len)
            inputs_aug['input_ids'][i][flip_idx] = input2['input_ids'][i][flip_idx]
            ratio[i] = 1 - (mix_len / (len1[i].item() - 2))

        return inputs_aug, ratio
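
An illustrative smoke test of the `SSMix` class above: the argument object stands in for the repo's argparse namespace (only the fields referenced by this code are set), and saliency is disabled so the snippet does not depend on a model or on `get_saliency`.

```python
import torch
from types import SimpleNamespace

args = SimpleNamespace(ss_no_saliency=True, ss_no_span=False, ss_winsize=20, device='cpu')
ssmix = SSMix(args)

batch, seq = 2, 12
input1 = {'input_ids': torch.randint(1000, 2000, (batch, seq)),
          'attention_mask': torch.ones(batch, seq, dtype=torch.long)}
input2 = {'input_ids': torch.randint(1000, 2000, (batch, seq)),
          'attention_mask': torch.ones(batch, seq, dtype=torch.long)}
target1, target2 = torch.tensor([0, 1]), torch.tensor([1, 0])
length1, length2 = torch.tensor([12, 10]), torch.tensor([11, 12])

inputs_aug, ratio = ssmix(input1, input2, target1, target2, length1, length2, max_len=16)
# `ratio` holds 1 - lambda per example and pairs with the mixed-loss sketch shown earlier.
```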
