It is well known that Attention-based Transformer models parallelize well, but their time and space complexity is $\mathcal{O}(n^2)$, where $n$ is the sequence length. When $n$ is large, the computation of a Transformer model therefore becomes prohibitive. Recently there has been a lot of work on reducing the computational complexity of Transformers, either through streamlining techniques such as model pruning, quantization and distillation, or by modifying the Attention structure so that the complexity drops to $\mathcal{O}(n\log n)$ or even $\mathcal{O}(n)$.

I was intrigued by the Linear Attention method in the paper "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention" and read a few blog posts about it. There are some good results, and in this article I summarize my understanding of Linear Attention.

#### Attention

The most popular Attention mechanism at the moment is Scaled-Dot Attention (the scaling factor $1/\sqrt{d_k}$ is omitted here for simplicity):

\begin{aligned}Attention(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V}) = softmax\left(\boldsymbol{Q}\boldsymbol{K}^{\top}\right)\boldsymbol{V}\tag{1}\end{aligned}

where $\boldsymbol{Q}\in \mathbb{R}^{n\times d_k}$, $\boldsymbol{K}\in \mathbb{R}^{m\times d_k}$, $\boldsymbol{V}\in \mathbb{R}^{m\times d_v}$. In this article we focus on the Self Attention scenario, so for ease of presentation, let $\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V}\in \mathbb{R}^{n\times d}$.

#### To remove Softmax

Perhaps surprisingly, the key constraint on Attention's efficiency is the Softmax! The derivation is simple: the step $\boldsymbol{Q}\boldsymbol{K}^{\top}$ produces an $n\times n$ matrix, and we then need to apply a Softmax to it.

Softmax on a $1\times n$ row vector is $\mathcal{O}(n)$, so Softmax on every row of an $n\times n$ matrix is $\mathcal{O}(n^2)$.

If there were no Softmax, the Attention formula would reduce to the triple matrix product $\boldsymbol{Q}\boldsymbol{K}^{\top}\boldsymbol{V}$. Matrix multiplication is associative, so we can first compute $\boldsymbol{K}^{\top}\boldsymbol{V}$, obtaining a $d\times d$ matrix (the time complexity of this step is $\mathcal{O}(d^2 n)$), and then left-multiply it by $\boldsymbol{Q}$ (the time complexity of this step is also $\mathcal{O}(d^2 n)$). Since $d \ll n$, the overall time complexity is effectively $\mathcal{O}(n)$.
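The associativity trick is easy to verify numerically: both multiplication orders give the same result, while the intermediate matrix shrinks from $n\times n$ to $d\times d$. A minimal NumPy sketch (the sizes are illustrative, not from the paper):

```python
import numpy as np

n, d = 1000, 64                       # sequence length and head dimension, d << n
rng = np.random.default_rng(0)
Q = rng.standard_normal((n, d))
K = rng.standard_normal((n, d))
V = rng.standard_normal((n, d))

# Quadratic order: (Q K^T) V materializes an n x n matrix -- O(n^2 d)
out_quadratic = (Q @ K.T) @ V

# Linear order: Q (K^T V) only materializes a d x d matrix -- O(n d^2)
out_linear = Q @ (K.T @ V)

assert np.allclose(out_quadratic, out_linear)
```

The outputs agree up to floating point, which is exactly why removing the Softmax (the only non-linear step in between) unlocks the linear-order computation.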

For BERT base, $d=64$ rather than 768. Why? Because 768 is actually obtained by concatenating multiple heads, and each head has $d=64$.

In other words, removing the Softmax reduces the complexity of Attention to the optimal, linear level $\mathcal{O}(n)$! This is obviously our ultimate pursuit: Linear Attention.

#### General definition

The question is: without Softmax, is it still Attention? Can it still match the effect of standard Attention? To answer this, let us first rewrite the definition of Scaled-Dot Attention in an equivalent form (vectors in this article are column vectors):

\begin{aligned}Attention(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V})_i = \frac{\sum\limits_{j=1}^n e^{\boldsymbol{q}_i^{\top}\boldsymbol{k}_j}\boldsymbol{v}_j}{\sum\limits_{j=1}^n e^{\boldsymbol{q}_i^{\top}\boldsymbol{k}_j}}\tag{2}\end{aligned}
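To convince ourselves that Eq. (2) is just the row-wise form of Eq. (1), we can compare the two computations directly; a small NumPy check (sizes chosen arbitrarily):

```python
import numpy as np

n, d = 5, 4
rng = np.random.default_rng(1)
Q = rng.standard_normal((n, d))
K = rng.standard_normal((n, d))
V = rng.standard_normal((n, d))

# Matrix form of Eq. (1): softmax(Q K^T) V, softmax applied row-wise
S = np.exp(Q @ K.T)
out_matrix = (S / S.sum(axis=1, keepdims=True)) @ V

# Row-wise form of Eq. (2): each output row is a weighted average of the v_j,
# with weights e^{q_i^T k_j}
out_rows = np.stack([
    sum(np.exp(Q[i] @ K[j]) * V[j] for j in range(n))
    / sum(np.exp(Q[i] @ K[j]) for j in range(n))
    for i in range(n)
])

assert np.allclose(out_matrix, out_rows)
```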

where $\boldsymbol{Q},\boldsymbol{K}\in \mathbb{R}^{n\times d}$. Writing $\boldsymbol{M} = \boldsymbol{Q}\boldsymbol{K}^{\top}$, the $i$-th row of $\boldsymbol{M}$ is obtained by multiplying the $i$-th row of $\boldsymbol{Q}$ with every column of $\boldsymbol{K}^{\top}$.

$Attention(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V})_i$ denotes the $i$-th row of the final output matrix.

$\boldsymbol{q}_i, \boldsymbol{k}_j, \boldsymbol{v}_j \in \mathbb{R}^{d}$ are the column-vector transposes of the $i$-th row of $\boldsymbol{Q}$, the $j$-th row of $\boldsymbol{K}$ and the $j$-th row of $\boldsymbol{V}$, respectively.

The scalar $e^{\boldsymbol{q}_i^{\top}\boldsymbol{k}_j}$ is the unnormalized weight attached to $\boldsymbol{v}_j$. So we can give a general definition of Attention:

\begin{aligned}Attention(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V})_i = \frac{\sum\limits_{j=1}^n \text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j)\boldsymbol{v}_j}{\sum\limits_{j=1}^n \text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j)}\tag{3}\end{aligned}

Here $\text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j)$ generalizes $e^{\boldsymbol{q}_i^{\top}\boldsymbol{k}_j}$, and we require $\text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j)\geq 0$ so that the weights remain meaningful. In other words, if we want to define a new Attention, we must retain the form of (3) with $\text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j)\geq 0$.

This general form of Attention is also known in CV as Non-local Networks, from the paper "Non-Local Neural Networks".

#### A few examples

If we simply remove the Softmax, then $\text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j) = \boldsymbol{q}_i^{\top}\boldsymbol{k}_j$, which is not guaranteed to be non-negative. Here are some options that are.

It is worth mentioning that the first two kinds of Linear Attention introduced below come from the CV field, while the third was conceived by Su Jianlin (besides those introduced below, there is also other Attention-improvement work in the CV field, such as EMANet).

#### Kernel form

A natural idea: if every element of $\boldsymbol{q}_i$ and $\boldsymbol{k}_j$ is non-negative, then their inner product is also non-negative. To achieve this, we can apply activation functions $\phi,\varphi$ to $\boldsymbol{q}_i$ and $\boldsymbol{k}_j$ respectively:

\begin{aligned}\text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j) = \phi(\boldsymbol{q}_i)^{\top} \varphi(\boldsymbol{k}_j)\tag{4}\end{aligned}

where $\phi(\cdot),\varphi(\cdot)$ are activation functions with non-negative range. The paper Transformers are RNNs mentioned at the beginning of this article chooses $\phi(x)=\varphi(x)=\text{elu}(x)+1$, where

$\text{elu}(x)=\begin{cases}x & \text{if }\ x>0\\ \alpha (e^x-1) & \text{if }\ x\leq 0\end{cases}$

Common values of $\alpha$ lie in $[0.1, 0.3]$.
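With this feature map, Attention in the form of Eq. (3) can be computed entirely in linear order. A minimal NumPy sketch (using $\alpha=1.0$ and illustrative shapes, both my own choices):

```python
import numpy as np

def elu_plus_one(x, alpha=1.0):
    # elu(x) + 1 is strictly positive, so every similarity score is non-negative
    return np.where(x > 0, x + 1.0, alpha * (np.exp(x) - 1.0) + 1.0)

n, d = 8, 4
rng = np.random.default_rng(2)
Q = rng.standard_normal((n, d))
K = rng.standard_normal((n, d))
V = rng.standard_normal((n, d))

phi_Q, phi_K = elu_plus_one(Q), elu_plus_one(K)

# Linear order: never materialize the n x n score matrix
KV = phi_K.T @ V                      # d x d
Z = phi_K.sum(axis=0)                 # normalizer term, shape (d,)
out = (phi_Q @ KV) / (phi_Q @ Z)[:, None]

# Equivalent quadratic-order computation, for comparison only
S = phi_Q @ phi_K.T
out_ref = (S / S.sum(axis=1, keepdims=True)) @ V

assert np.allclose(out, out_ref)
```

Both paths compute the same output; only the linear path stays $\mathcal{O}(n)$ in the sequence length.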

If you want to tell a story, formula (4) can be connected to the "kernel method", especially when $\phi=\varphi$: $\phi$ is then equivalent to the feature map of a kernel function, and $\langle \phi(\boldsymbol{q}_i), \phi(\boldsymbol{k}_j)\rangle$ is the inner product defined through that kernel. For this perspective, refer to the paper Transformer Dissection: A Unified Understanding of Transformer's Attention via the Lens of Kernel.

#### Use Softmax

An earlier paper, Efficient Attention: Attention with Linear Complexities, offers an even more interesting option. In $\boldsymbol{Q}\boldsymbol{K}^{\top}$ with $\boldsymbol{Q},\boldsymbol{K}\in \mathbb{R}^{n\times d}$, if $\boldsymbol{Q}$ is normalized along the dimension $d$ and $\boldsymbol{K}$ is normalized along the dimension $n$, then $\boldsymbol{Q}\boldsymbol{K}^{\top}$ automatically satisfies the normalization, so the choice it gives is

\begin{aligned}Attention(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V}) = softmax_2\left(\boldsymbol{Q}\right)softmax_1(\boldsymbol{K})^{\top}\boldsymbol{V}\tag{5}\end{aligned}

where $softmax_1$ and $softmax_2$ denote the Softmax operation along the first dimension ($n$) and the second dimension ($d$) respectively. In other words, we apply Softmax to $\boldsymbol{Q}$ and $\boldsymbol{K}$ separately, instead of computing $\boldsymbol{Q}\boldsymbol{K}^{\top}$ and then applying Softmax.

In fact, it can be shown that this form is also a special case of Eq. (4), corresponding to $\phi(\boldsymbol{q}_i)=softmax(\boldsymbol{q}_i)$ and $\varphi(\boldsymbol{k}_j)=e^{\boldsymbol{k}_j}$.
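Eq. (5) can likewise be evaluated in linear order, since $softmax_1(\boldsymbol{K})^{\top}\boldsymbol{V}$ is only $d\times d$. A small NumPy sketch (my own, not the paper's code) that also checks the automatic normalization:

```python
import numpy as np

def softmax(x, axis):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

n, d = 8, 4
rng = np.random.default_rng(3)
Q = rng.standard_normal((n, d))
K = rng.standard_normal((n, d))
V = rng.standard_normal((n, d))

Qs = softmax(Q, axis=1)               # softmax_2: over the feature dimension d
Ks = softmax(K, axis=0)               # softmax_1: over the sequence dimension n

out = Qs @ (Ks.T @ V)                 # d x d intermediate, O(n d^2) overall

# Rows of Qs sum to 1 and columns of Ks sum to 1, so each row of the implicit
# score matrix Qs Ks^T sums to 1 -- no extra normalization needed.
weights = Qs @ Ks.T
assert np.allclose(weights.sum(axis=1), 1.0)
```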

#### Su Jianlin's formulation

Here, Su Jianlin gives an idea whose starting point is no longer Eq. (4), but a Taylor expansion of the original definition (2):

\begin{aligned}e^{\boldsymbol{q}_i^{\top}\boldsymbol{k}_j} \approx 1 + \boldsymbol{q}_i^{\top}\boldsymbol{k}_j\tag{6}\end{aligned}

As long as $\boldsymbol{q}_i^{\top}\boldsymbol{k}_j\geq -1$, the right-hand side is guaranteed to be non-negative, so we can let $\text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j) = 1 + \boldsymbol{q}_i^{\top}\boldsymbol{k}_j$. As you may have guessed, to ensure $\boldsymbol{q}_i^{\top}\boldsymbol{k}_j\geq -1$, we only need to $L_2$-normalize $\boldsymbol{q}_i$ and $\boldsymbol{k}_j$ respectively. The final scheme proposed by Su Jianlin is therefore:

\begin{aligned}\text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j) = 1 + \left( \frac{\boldsymbol{q}_i}{\Vert \boldsymbol{q}_i\Vert}\right)^{\top}\left(\frac{\boldsymbol{k}_j}{\Vert \boldsymbol{k}_j\Vert}\right)\tag{7}\end{aligned}

Here, if $\boldsymbol{x}=[x_1,x_2,\dots,x_n]$, then $\Vert\boldsymbol{x}\Vert=\sqrt{x_1^2+x_2^2+\dots+x_n^2}$.

This is different from the form of Eq. (4), but theoretically it stays closer to the original Scaled-Dot Attention.
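Although Eq. (7) is not of the form (4), it still admits an $\mathcal{O}(n)$ computation: expanding the numerator and denominator of Eq. (3), the sums $\sum_j \boldsymbol{v}_j$, $\sum_j \hat{\boldsymbol{k}}_j$ and $\sum_j \hat{\boldsymbol{k}}_j\boldsymbol{v}_j^{\top}$ can each be precomputed once. A NumPy sketch of this decomposition (my own derivation, not code from the paper):

```python
import numpy as np

n, d = 8, 4
rng = np.random.default_rng(4)
Q = rng.standard_normal((n, d))
K = rng.standard_normal((n, d))
V = rng.standard_normal((n, d))

Qh = Q / np.linalg.norm(Q, axis=1, keepdims=True)   # L2-normalized q_i
Kh = K / np.linalg.norm(K, axis=1, keepdims=True)   # L2-normalized k_j

# Sums shared by every position i, each computed once in O(n d) / O(n d^2)
V_sum = V.sum(axis=0)                # sum_j v_j, shape (d,)
K_sum = Kh.sum(axis=0)               # sum_j k_j, shape (d,)
KV = Kh.T @ V                        # sum_j k_j v_j^T, shape (d, d)

numer = V_sum + Qh @ KV              # sum_j (1 + q_i . k_j) v_j, shape (n, d)
denom = n + Qh @ K_sum               # sum_j (1 + q_i . k_j), shape (n,)
out = numer / denom[:, None]

# Quadratic-order reference using sim(q_i, k_j) = 1 + q_i . k_j directly
S = 1 + Qh @ Kh.T
out_ref = (S / S.sum(axis=1, keepdims=True)) @ V

assert np.allclose(out, out_ref)
```

Since the normalized vectors give $\boldsymbol{q}_i^{\top}\boldsymbol{k}_j \in [-1, 1]$, every score $1 + \boldsymbol{q}_i^{\top}\boldsymbol{k}_j$ is non-negative, as required by Eq. (3).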

#### implementation

This section mainly implements the method proposed by Su Jianlin. However, due to my limited ability, the final implementation code still has some problems, mainly:

1. In my tests, the modified version showed no speed improvement
2. The attention weights could not be guaranteed to sum to one

The implementation is based on a PyTorch BERT implementation; more specifically, it only changes the ScaledDotProductAttention module, so only that part of the code is shown below.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ScaledDotProductAttention(nn.Module):
    def __init__(self):
        super(ScaledDotProductAttention, self).__init__()

    def forward(self, Q, K, V, attn_mask):
        # L2-normalize q_i and k_j along the head dimension, as in Eq. (7)
        Q = F.normalize(Q, dim=3)
        K = F.normalize(K, dim=3)
        # sim(q_i, k_j) = 1 + q_i^T k_j
        # M : [batch_size, n_heads, seq_len, seq_len]
        M = 1 + torch.matmul(Q, K.transpose(-1, -2))
        # attn_mask is assumed boolean, True at positions to be masked out
        M = M.masked_fill(attn_mask, 0)
        M = M / M.sum(dim=3, keepdim=True)        # normalize each row
        context = torch.matmul(M, V)              # [batch_size, n_heads, seq_len, d_v]
        return context
```