LEARNING TO PROPAGATE · TRANSDUCTIVE PROPAGATION NETWORK FOR FEW-SHOT LEARNING

Tags: Few-shot learning

Summary: This paper proposes a Transductive Propagation Network for label propagation. It consists of four parts: feature extraction, graph construction, label propagation, and loss computation, and the whole network is trained end to end.

Problem definition

This paper still adopts the episodic training strategy. Given a training set with plenty of annotated data, the goal is to train a classifier that can recognize samples from new classes, where each new class has only a few annotated samples.

As usual, N classes are sampled from the training set and K labeled samples are taken from each class, forming an N-way K-shot scenario; these samples form the support set $\mathcal{S}$. The remaining sampled examples of those classes form the query set $\mathcal{Q}$, whose labels must be predicted. K is usually very small, i.e. there is very little annotated data, which makes it hard to obtain a good classifier. The authors' idea is to classify the whole query set jointly (transductively), rather than predicting each sample in isolation, which alleviates the low-data problem and improves generalization.
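For concreteness, here is a minimal sketch of how one such episode might be sampled. The `dataset` structure, the helper name `sample_episode`, and the 15-queries-per-class default are assumptions for illustration, not taken from the paper.

```python
import random
from collections import defaultdict

def sample_episode(dataset, n_way=5, k_shot=1, n_query=15):
    """Sample one N-way K-shot episode: a support set and a query set."""
    by_class = defaultdict(list)
    for image, label in dataset:          # dataset: iterable of (image, label) pairs
        by_class[label].append(image)

    episode_classes = random.sample(list(by_class.keys()), n_way)
    support, query = [], []
    for episode_label, cls in enumerate(episode_classes):
        images = random.sample(by_class[cls], k_shot + n_query)
        support += [(img, episode_label) for img in images[:k_shot]]
        query += [(img, episode_label) for img in images[k_shot:]]
    return support, query
```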

Transduction networks

The Transductive Propagation Network consists of four parts: feature extraction, graph construction, label propagation, and loss computation. Feature extraction is performed by a convolutional neural network. Label propagation spreads labels from the support set to the query set. The loss is the cross-entropy between the propagated labels and the true labels on the query set.

Feature extraction

A convolutional neural network $f_{\varphi}$ with four convolution blocks performs feature extraction. Each block contains a 2-D convolution layer with 64 filters of size 3×3, followed by batch normalization, a ReLU activation, and 2×2 max pooling. The parameters $\varphi$ are learned from the final prediction loss.
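A minimal PyTorch sketch of such a feature extractor; details like padding and the input channel count are assumptions, since the text above only fixes the block structure.

```python
import torch.nn as nn

def conv_block(in_channels, out_channels):
    # 3x3 convolution + batch normalization + ReLU + 2x2 max pooling.
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
        nn.MaxPool2d(2),
    )

class FeatureExtractor(nn.Module):
    """Four-block CNN playing the role of f_varphi."""

    def __init__(self, in_channels=3, hidden=64):
        super().__init__()
        self.encoder = nn.Sequential(
            conv_block(in_channels, hidden),
            conv_block(hidden, hidden),
            conv_block(hidden, hidden),
            conv_block(hidden, hidden),
        )

    def forward(self, x):
        # Returns one feature map per sample; downstream code may flatten it.
        return self.encoder(x)
```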

Graph model construction

Nodes of the graph represent samples, and edges represent connections between them. A Gaussian similarity function (Formula 1) determines the edge weights: the similarity between nodes is computed with this function, and each node is connected to its k most similar neighbors. The resulting graph is a k-nearest-neighbor graph.


$$W_{ij}=\exp\left(-\frac{d(x_i, x_j)}{2\sigma^2}\right)\tag{1}$$

In Formula 1, $d(x_i, x_j)$ is a distance function that measures the distance between the two samples; see the original paper for its details.

$\sigma$ is a hyperparameter, and the authors note that there is no good rule for tuning it by hand, so they construct a small convolutional neural network $g_{\phi}(\cdot)$ to predict it. In effect, $g_{\phi}$ adjusts its own parameters $\phi$ so that the $\sigma$ it computes optimizes the final prediction, instead of the $\sigma$ being tuned manually; $\phi$ is updated from the prediction loss. Each sample gets its own $\sigma$.

The convolutional network $g_{\phi}$ consists of two convolution blocks and two fully connected layers. The first convolution block is 64 3×3 filters + batch normalization + ReLU + 2×2 max pooling; the second is a single 3×3 filter + batch normalization + ReLU + 2×2 max pooling. The first fully connected layer has 8 neurons and the second has 1 neuron.
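A possible PyTorch rendering of this $g_{\phi}$; the input feature-map shape and the use of `LazyLinear` to avoid hard-coding the flattened size are assumptions.

```python
import torch.nn as nn

class SigmaNetwork(nn.Module):
    """Predicts a per-sample sigma from the feature maps produced by f_varphi."""

    def __init__(self, in_channels=64):
        super().__init__()
        self.blocks = nn.Sequential(
            # First block: 64 3x3 filters + BN + ReLU + 2x2 max pooling.
            nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(2),
            # Second block: a single 3x3 filter + BN + ReLU + 2x2 max pooling.
            nn.Conv2d(64, 1, kernel_size=3, padding=1),
            nn.BatchNorm2d(1), nn.ReLU(), nn.MaxPool2d(2),
        )
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.LazyLinear(8), nn.ReLU(),  # first fully connected layer: 8 neurons
            nn.Linear(8, 1),              # second fully connected layer: 1 neuron
        )

    def forward(self, feature_maps):
        return self.fc(self.blocks(feature_maps))  # shape (num_samples, 1)
```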

With $\sigma$ computed, the edge weight matrix $W$ can be evaluated. The authors then normalize the matrix with the normalized graph Laplacian, as shown in Formula 2, where $D$ is the diagonal degree matrix of $W$:


$$S=D^{-1/2}WD^{-1/2}\tag{2}$$
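The following sketch builds $W$ and $S$ from flattened features and the per-sample $\sigma$; the choice of squared Euclidean distance, the value of k, and the symmetrization step are assumptions based on the description above.

```python
import torch

def build_graph(features, sigma, k=20):
    """features: (n, d) flattened feature vectors; sigma: (n, 1) per-sample scales."""
    # Scale each sample's features by its own sigma, then take pairwise
    # squared Euclidean distances (the d(.,.) of Formula 1).
    scaled = features / sigma
    w = torch.exp(-torch.cdist(scaled, scaled) ** 2 / 2)   # Formula 1

    # Keep only each node's k strongest edges and symmetrize (k-NN graph).
    _, idx = w.topk(k, dim=-1)
    mask = torch.zeros_like(w).scatter_(-1, idx, 1.0)
    w = w * mask
    w = (w + w.t()) / 2

    # S = D^{-1/2} W D^{-1/2}, with D the diagonal degree matrix (Formula 2).
    d_inv_sqrt = w.sum(dim=-1).clamp(min=1e-12).rsqrt()
    return d_inv_sqrt.unsqueeze(1) * w * d_inv_sqrt.unsqueeze(0)
```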

Label propagation

Let $\mathcal{F}$ denote the set of nonnegative matrices whose rows correspond to samples (support plus query) and whose columns correspond to classes. $Y\in\mathcal{F}$ is the label matrix: $Y_{ij}=1$ if sample $x_i$ comes from the support set and belongs to the j-th class, and $Y_{ij}=0$ otherwise. Iterating Formula 3 starting from $Y$ propagates labels along the constructed graph, where $F_t$ is the label prediction score at iteration $t$. Formula 4 is the closed-form limit of Formula 3 and can be used to compute the predicted label scores directly.


$$F_{t+1}=\alpha SF_t+(1-\alpha)Y\tag{3}$$

$$F^{\star}=(I-\alpha S)^{-1}Y\tag{4}$$
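A small sketch of the closed-form propagation in Formula 4; α = 0.99 is an assumed value, and a linear solve is used instead of an explicit matrix inverse.

```python
import torch

def propagate_labels(s, y, alpha=0.99):
    """F* = (I - alpha * S)^{-1} Y, with S from Formula 2 and Y the one-hot label matrix."""
    identity = torch.eye(s.size(0), device=s.device, dtype=s.dtype)
    # Solving the linear system is equivalent to, but more stable than, inverting.
    return torch.linalg.solve(identity - alpha * s, y)
```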

Loss calculation

The input samples come from both the support set and the query set. A softmax is applied to $F^{\star}$ to obtain class probabilities, and the cross-entropy loss measures the difference between the predicted and true labels. The loss depends on two sets of parameters, $\varphi$ and $\phi$: $\varphi$ are the parameters of the feature-extraction network, and $\phi$ are the parameters of the network that computes $\sigma$ in the Gaussian similarity function.
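A minimal sketch of the loss on the query set, following the description above; the assumption that rows of $F^{\star}$ are ordered support-first, query-second is mine. `F.cross_entropy` applies the softmax internally.

```python
import torch.nn.functional as F

def propagation_loss(f_star, query_labels, num_support):
    """Cross-entropy between propagated scores and true labels on the query set."""
    query_scores = f_star[num_support:]   # drop the support rows
    return F.cross_entropy(query_scores, query_labels)
```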

The overview

  1. The samples in the support set and query set are combined into $X$ and fed into the convolutional network $f_{\varphi}$ to extract features $f_{\varphi}(X)$.
  2. The extracted features $f_{\varphi}(X)$ are fed into the convolutional network $g_{\phi}$, which outputs a $\sigma$ for each sample.
  3. From $\sigma$, the similarity matrix $W$ is computed; each node is connected to its k most similar neighbors, and the normalized graph Laplacian is applied to obtain $S$.
  4. Label propagation is run on the constructed graph to obtain predicted labels on the query set, from which the loss is computed. The gradients of the loss are then back-propagated to update the parameters.

The appendix

Manifold structure: a set of points in a space. For example, a curve in two-dimensional space is a one-dimensional manifold embedded in that space.

Transductive inference: predicting specific test samples directly from the observed specific training samples. You can think of it as going from the specific to the specific.

Inductive inference: learn general rules from the training samples, then apply those rules to judge the test samples. The purpose of induction is to find a general rule, so generalization is necessarily involved.

End-to-end learning: the input data passes through the model to produce a prediction at the output end; comparing the prediction with the ground truth gives an error, which is back-propagated through every layer of the model, and each layer adjusts itself according to this error until the model converges or reaches the desired performance. In this paper, the clearest example is the joint learning of $\phi$ and $\varphi$.