Source: ACL 2017

Enhanced LSTM for Natural Language Inference

I don’t know what happened this year, but text-matching problems have been popping up in data-mining competitions one after another: it started with Question Pairs on Kaggle, then came the Tianchi CIKM AnalytiCup 2018, then the ATEC Ant developer contest, and the third Magic Mirror Cup on the PPDAI AI development platform… As the poem goes, it is as if a spring breeze arrived overnight and a thousand pear trees burst into bloom.

Today I’d like to take the opportunity to write about one of the biggest killers in short-text matching: ESIM. It has swept so many of these competitions that most of the winners used it (it is practically a must-have model for ensembling). As before, I’ll attach implementation code, this time using PyTorch.

Let’s get down to business.

ESIM stands for “Enhanced LSTM for Natural Language Inference”. As the name suggests, it is an enhanced version of the LSTM designed for natural language inference. As for how it enhances the LSTM, the paper describes it as follows:

“Unlike the previous top models that use very complicated network architectures, we first demonstrate that carefully designing sequential inference models based on chain LSTMs can outperform all previous models. Based on this, we further show that by explicitly considering recursive architectures in both local inference modeling and inference composition, we achieve additional improvement.”

In short, ESIM improves on other short-text matching models in two main ways:

  1. A carefully designed sequential inference structure.
  2. Considering both local inference and global inference.

The author mainly uses attention between the two sentences (soft alignment) to realize local inference, and builds global inference on top of it.

ESIM is mainly divided into three parts: input encoding, local inference modeling, and inference composition. In the figure from the paper, ESIM is the left part (the right part is the tree-LSTM variant).


input encoding

There is not much to say here: the two sentences are each fed through an embedding layer followed by a BiLSTM. As for why the author did not use the recently popular BiGRU, the explanation is that it did not work as well in experiments. A Tree-LSTM can also be used if the sentences can be syntactically parsed, but that part is not in the basic ESIM.


The BiLSTM learns to represent each word together with its context in the sentence. You can also think of it as re-encoding the word embeddings in their current context to obtain new embedding vectors. This part of the code is shown below and is fairly intuitive.

def forward(self, *input):
    # batch_size * seq_len
    sent1, sent2 = input[0], input[1]
    mask1, mask2 = sent1.eq(0), sent2.eq(0)

    # embeds: batch_size * seq_len => batch_size * seq_len * embeds_dim
    x1 = self.bn_embeds(self.embeds(sent1).transpose(1, 2).contiguous()).transpose(1, 2)
    x2 = self.bn_embeds(self.embeds(sent2).transpose(1, 2).contiguous()).transpose(1, 2)

    # BiLSTM: batch_size * seq_len * embeds_dim => batch_size * seq_len * (2 * hidden_size)
    o1, _ = self.lstm1(x1)
    o2, _ = self.lstm1(x2)
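The forward snippets in this post refer to several layers (self.embeds, self.bn_embeds, self.lstm1, self.lstm2, self.fc) without showing where they are defined. Here is a minimal sketch of the constructor they assume; the sizes and the classifier head are my own guesses inferred from the shape comments, so they may differ from the referenced repository.

import torch
import torch.nn as nn
import torch.nn.functional as F   # used by the forward methods shown in this post

class ESIM(nn.Module):
    def __init__(self, vocab_size, embeds_dim=300, hidden_size=300, num_classes=2):
        super().__init__()
        # word embedding + batch norm over the embedding dimension
        self.embeds = nn.Embedding(vocab_size, embeds_dim)
        self.bn_embeds = nn.BatchNorm1d(embeds_dim)
        # input-encoding BiLSTM: each position becomes a 2 * hidden_size vector
        self.lstm1 = nn.LSTM(embeds_dim, hidden_size, batch_first=True, bidirectional=True)
        # composition BiLSTM: consumes the 8 * hidden_size "enhanced" features
        self.lstm2 = nn.LSTM(8 * hidden_size, hidden_size, batch_first=True, bidirectional=True)
        # classifier over the two pooled sentence vectors (4 * hidden_size each)
        self.fc = nn.Sequential(
            nn.Linear(8 * hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, num_classes),
        )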

local inference modeling

Before local inference, the two sentences need to be aligned. Here soft_align_attention is used.

To do this, first compute the similarity between every pair of words from the two sentences, which gives a 2-dimensional similarity (attention) matrix. In PyTorch this is a single torch.matmul.
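To make the shapes concrete, here is a tiny toy example (made-up dimensions, not part of the original code):

import torch

batch_size, len_a, len_b, dim = 2, 5, 7, 4
x1 = torch.randn(batch_size, len_a, dim)   # encoded sentence A
x2 = torch.randn(batch_size, len_b, dim)   # encoded sentence B

# similarity of every word in A with every word in B
attention = torch.matmul(x1, x2.transpose(1, 2))
print(attention.shape)   # torch.Size([2, 5, 7]) => batch_size * len_a * len_b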


Then local inference is carried out between the two sentences. The similarity matrix obtained above is used to compute, for each word in sentence A, a weighted sum of the words in sentence B (and vice versa), producing “aligned” versions of each sentence whose dimensions are unchanged. It’s a little convoluted to describe, so let’s use the code below to explain.


After local inference comes the enhancement of local inference information. “Enhancement” here means computing the difference and the element-wise product between a and its aligned counterpart ã, and concatenating them with the originals, i.e. [a; ã; a − ã; a ⊙ ã]. This makes the differences explicit and easier for the later layers to exploit.


def soft_align_attention(self, x1, x2, mask1, mask2):
    '''
    x1: batch_size * seq_len * (2 * hidden_size)
    x2: batch_size * seq_len * (2 * hidden_size)
    '''
    # attention: batch_size * seq_len * seq_len
    attention = torch.matmul(x1, x2.transpose(1, 2))
    # padding positions get -inf so that softmax assigns them zero weight
    mask1 = mask1.float().masked_fill_(mask1, float('-inf'))
    mask2 = mask2.float().masked_fill_(mask2, float('-inf'))

    # weight: batch_size * seq_len * seq_len
    weight1 = F.softmax(attention + mask2.unsqueeze(1), dim=-1)
    x1_align = torch.matmul(weight1, x2)
    weight2 = F.softmax(attention.transpose(1, 2) + mask1.unsqueeze(1), dim=-1)
    x2_align = torch.matmul(weight2, x1)

    # x_align: batch_size * seq_len * (2 * hidden_size)
    return x1_align, x2_align

def submul(self, x1, x2):
    # difference and element-wise product of the original and aligned representations
    sub = x1 - x2
    mul = x1 * x2
    return torch.cat([sub, mul], -1)

def forward(self, *input):
    ...
    # Local inference
    # o1, o2: batch_size * seq_len * (2 * hidden_size)
    q1_align, q2_align = self.soft_align_attention(o1, o2, mask1, mask2)

    # Enhancement of local inference information
    # batch_size * seq_len * (8 * hidden_size)
    q1_combined = torch.cat([o1, q1_align, self.submul(o1, q1_align)], -1)
    q2_combined = torch.cat([o2, q2_align, self.submul(o2, q2_align)], -1)
    ...

inference composition

This is the last step. It’s easy.

A BiLSTM is used again to compose the context information, then both MaxPooling and AvgPooling are applied over the sequence, and finally a fully connected layer follows. This part is fairly conventional; there is not much to say.

def apply_multiple(self, x):
    # input: batch_size * seq_len * (2 * hidden_size)
    p1 = F.avg_pool1d(x.transpose(1, 2), x.size(1)).squeeze(-1)
    p2 = F.max_pool1d(x.transpose(1, 2), x.size(1)).squeeze(-1)
    # output: batch_size * (4 * hidden_size)
    return torch.cat([p1, p2], 1)

def forward(self, *input):
    ...
    
    # inference composition
    # batch_size * seq_len * (2 * hidden_size)
    q1_compose, _ = self.lstm2(q1_combined)
    q2_compose, _ = self.lstm2(q2_combined)

    # Aggregate
    # input: batch_size * seq_len * (2 * hidden_size)
    # output: batch_size * (4 * hidden_size)
    q1_rep = self.apply_multiple(q1_compose)
    q2_rep = self.apply_multiple(q2_compose)

    # Classifier
    x = torch.cat([q1_rep, q2_rep], -1)
    sim = self.fc(x)
    return sim
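To tie the pieces together, here is a hypothetical usage sketch. It assumes the snippets above have been assembled into the ESIM class from the earlier constructor sketch (the class name and constructor arguments are my assumptions, not necessarily what the referenced repository uses), and that index 0 is the padding id, matching sent.eq(0) in the forward pass.

# hypothetical usage; class name and constructor args are assumptions
model = ESIM(vocab_size=20000, embeds_dim=300, hidden_size=300, num_classes=2)

# two batches of padded token ids (0 = padding, matching sent.eq(0) above)
sent1 = torch.randint(1, 20000, (32, 20))
sent2 = torch.randint(1, 20000, (32, 20))

logits = model(sent1, sent2)   # batch_size * num_classes
print(logits.shape)            # torch.Size([32, 2])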

thinking

Why does ESIM work so well? Here are a few of my own thoughts. What impresses me most about ESIM is its inter-sentence attention, the soft_align_attention in the code above, where the two sentences being compared interact with each other. In structures I had seen before, such as Siamese networks, there is no such interaction; the two sentences are only compared at the last layer with cosine distance or the like.
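For contrast, here is a minimal sketch of the Siamese-style setup mentioned above (my own illustration, not code from ESIM or the referenced repository): each sentence is encoded completely independently, and the two representations only meet at the very end through cosine similarity.

import torch
import torch.nn as nn
import torch.nn.functional as F

class SiameseBaseline(nn.Module):
    # Toy Siamese baseline: no interaction between the sentences until the final cosine.
    def __init__(self, vocab_size, embeds_dim=300, hidden_size=300):
        super().__init__()
        self.embeds = nn.Embedding(vocab_size, embeds_dim)
        self.lstm = nn.LSTM(embeds_dim, hidden_size, batch_first=True, bidirectional=True)

    def encode(self, sent):
        o, _ = self.lstm(self.embeds(sent))      # batch_size * seq_len * (2 * hidden_size)
        return o.max(dim=1).values               # max-pool into one vector per sentence

    def forward(self, sent1, sent2):
        v1, v2 = self.encode(sent1), self.encode(sent2)
        return F.cosine_similarity(v1, v2, dim=-1)   # similarity score in [-1, 1]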


References: Enhanced LSTM for Natural Language Inference (Chen et al., ACL 2017)

Code address: PengShuang/text-similarity