DARTS is a classic NAS method that breaks away from the earlier discrete search paradigm and performs end-to-end architecture search. Because DARTS updates the architecture with gradients, the update direction is comparatively accurate, and the search cost drops sharply compared with earlier methods: a search on CIFAR-10 takes only about 4 GPU days.

Source: WeChat public account [Algorithm Engineering Notes of Xiaofei]

DARTS: Differentiable Architecture Search

  • Paper: Arxiv.org/abs/1806.09…
  • Code: Github.com/quark0/dart…

Introduction


Most popular neural architecture search methods select among discrete candidate architectures, whereas DARTS relaxes the search space to be continuous and optimizes the architecture with gradient descent according to its performance on the validation set. The main contributions of the paper are as follows:

  • Based on bilevel optimization, an innovative gradient-based neural architecture search method, DARTS, is proposed, applicable to both convolutional and recurrent architectures.
  • Experiments show that the gradient-based architecture search method is highly competitive on both the CIFAR-10 and PTB datasets.
  • Search efficiency is very high, requiring only a few GPU days, mainly thanks to the gradient-based optimization.
  • The architectures learned by DARTS on CIFAR-10 and PTB transfer well to the larger ImageNet and WikiText-2 datasets.

Differentiable Architecture Search


Search Space

Like NASNet and other methods, the overall search framework of DARTS searches for a cell as the basic building block and then stacks cells into a convolutional or recurrent network. A cell is a directed acyclic graph containing an ordered sequence of $N$ nodes. Each node $x^{(i)}$ represents an intermediate result (e.g., a feature map in a convolutional network), and each directed edge $(i, j)$ is associated with an operation $o^{(i,j)}$ applied to $x^{(i)}$. Each cell has two inputs and one output: for convolutional cells the inputs are the outputs of the previous two cells, while for recurrent cells they are the input at the current step and the hidden state of the previous step. Each intermediate node is computed from all of its predecessors:

$$x^{(j)} = \sum_{i < j} o^{(i,j)}\big(x^{(i)}\big)$$

A special zero operation indicates that there is no connection between two nodes. DARTS thus turns the learning of a cell into the learning of the operation on each edge. The overall search framework is the same as NASNet and related methods; this article focuses on how DARTS performs the gradient-based search.
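To make the DAG view concrete, here is a minimal sketch (not the authors' code) of how a cell computes its intermediate nodes; `ops` is an assumed mapping from each edge $(i, j)$ to its operation $o^{(i,j)}$:

```python
import torch

def cell_forward(inputs, num_nodes, ops):
    """Sketch of a DARTS cell as a directed acyclic graph.

    inputs:    the two input tensors (outputs of the previous two cells)
    num_nodes: number of intermediate nodes in the cell
    ops:       dict mapping edge (i, j) -> callable operation o^{(i,j)}
    """
    states = list(inputs)  # states[0], states[1] are the cell inputs
    for j in range(len(inputs), len(inputs) + num_nodes):
        # each intermediate node x^{(j)} = sum_{i<j} o^{(i,j)}(x^{(i)})
        states.append(sum(ops[(i, j)](states[i]) for i in range(j)))
    # the cell output concatenates all intermediate nodes along the channel axis
    return torch.cat(states[len(inputs):], dim=1)
```

During the search, each `ops[(i, j)]` is not a single fixed operation but the mixed operation described in the next subsection.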

Continuous Relaxation and Optimization

Let $O$ be the set of candidate operations, where each operation $o(\cdot)$ is a function applied to $x^{(i)}$. To make the search space continuous, DARTS relaxes the categorical choice of an operation on each edge into a softmax over all candidate operations:

$$\bar{o}^{(i,j)}(x) = \sum_{o \in O} \frac{\exp(\alpha_o^{(i,j)})}{\sum_{o' \in O} \exp(\alpha_{o'}^{(i,j)})}\, o(x)$$

The mixing weights for the operations on edge $(i, j)$ are parameterized as a vector $\alpha^{(i,j)}$ of dimension $|O|$, so architecture search reduces to learning the continuous variables $\alpha = \{\alpha^{(i,j)}\}$, as shown in Figure 1. At the end of the search, each mixed operation $\bar{o}^{(i,j)}$ is replaced by the most likely operation, $o^{(i,j)} = \mathrm{argmax}_{o \in O}\, \alpha_o^{(i,j)}$. After this relaxation, DARTS aims to learn the architecture $\alpha$ and all operation weights $w$ jointly. Unlike previous methods, DARTS can optimize the architecture by gradient descent on the validation loss. Define $\mathcal{L}_{train}$ and $\mathcal{L}_{val}$ as the training and validation losses; both are determined by the architecture $\alpha$ and the weights $w$. The goal of the search is to find the $\alpha^{*}$ that minimizes the validation loss $\mathcal{L}_{val}(w^{*}, \alpha^{*})$, where the weights $w^{*}$ are obtained by minimizing the training loss, $w^{*} = \mathrm{argmin}_{w} \mathcal{L}_{train}(w, \alpha^{*})$. DARTS is therefore a bilevel optimization problem that uses the validation set to optimize the architecture and the training set to optimize the weights, with $\alpha$ as the upper-level variable and $w$ as the lower-level variable:

$$\min_{\alpha}\ \mathcal{L}_{val}\big(w^{*}(\alpha), \alpha\big) \quad \text{s.t.} \quad w^{*}(\alpha) = \mathrm{argmin}_{w}\ \mathcal{L}_{train}(w, \alpha)$$
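As an illustration of the relaxation (a simplified sketch under assumed names, not the official implementation), the mixed operation on one edge can be written as a softmax-weighted sum over the candidate operations, with $\alpha^{(i,j)}$ stored as a learnable vector:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MixedOp(nn.Module):
    """Softmax-weighted mixture of the candidate operations on one edge."""

    def __init__(self, candidates):
        super().__init__()
        self.ops = nn.ModuleList(candidates)                             # candidate set O
        self.alpha = nn.Parameter(1e-3 * torch.randn(len(candidates)))   # alpha^{(i,j)}

    def forward(self, x):
        weights = F.softmax(self.alpha, dim=0)   # continuous relaxation of the choice
        return sum(w * op(x) for w, op in zip(weights, self.ops))
```

In the actual bilevel setup the architecture parameters $\alpha$ are kept separate from the operation weights $w$, so that they can be optimized with different optimizers and different data splits.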

Approximate Architecture Gradient

Computing the architecture gradient in Formula 3 exactly is very expensive, mainly because of the inner optimization in Formula 4: every change to the architecture would require retraining the weights to optimality. To avoid this, the paper proposes a simple approximation:

$$\nabla_{\alpha} \mathcal{L}_{val}\big(w^{*}(\alpha), \alpha\big) \approx \nabla_{\alpha} \mathcal{L}_{val}\big(w - \xi \nabla_{w} \mathcal{L}_{train}(w, \alpha),\ \alpha\big)$$

Here $w$ denotes the current weights and $\xi$ is the learning rate of a single step of the inner optimization. The idea is that, after the architecture changes, $w$ is adapted with only one training step to approximate $w^{*}(\alpha)$, instead of the full training to convergence required by Formula 3. Note that when $w$ is already a local optimum of the inner problem, i.e. $\nabla_{w} \mathcal{L}_{train}(w, \alpha) = 0$, the approximation reduces to $\nabla_{\alpha} \mathcal{L}_{val}(w, \alpha)$.

The iterative procedure is shown in Algorithm 1: the architecture and the weights are updated alternately, and each update uses only a small batch of data. By the chain rule, Formula 6 can be expanded as:

$$\nabla_{\alpha} \mathcal{L}_{val}(w', \alpha) - \xi \nabla^{2}_{\alpha, w} \mathcal{L}_{train}(w, \alpha)\, \nabla_{w'} \mathcal{L}_{val}(w', \alpha)$$

where $w' = w - \xi \nabla_{w} \mathcal{L}_{train}(w, \alpha)$ denotes the weights after one step of the inner update. The second term contains an expensive matrix-vector product, so the paper approximates it with a finite difference, which is a key step of the method. Let $\epsilon$ be a small scalar and $w^{\pm} = w \pm \epsilon \nabla_{w'} \mathcal{L}_{val}(w', \alpha)$; then:

$$\nabla^{2}_{\alpha, w} \mathcal{L}_{train}(w, \alpha)\, \nabla_{w'} \mathcal{L}_{val}(w', \alpha) \approx \frac{\nabla_{\alpha} \mathcal{L}_{train}(w^{+}, \alpha) - \nabla_{\alpha} \mathcal{L}_{train}(w^{-}, \alpha)}{2\epsilon}$$

Evaluating this finite difference requires only two extra forward and backward passes, reducing the complexity from $O(|\alpha|\,|w|)$ to $O(|\alpha| + |w|)$.
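The finite-difference trick can be sketched as follows (an illustrative helper modeled on the formula above, not the official code); `train_loss_fn` is an assumed closure that recomputes $\mathcal{L}_{train}(w, \alpha)$ on a mini-batch using whatever values the weight tensors currently hold, and `vector` is $\nabla_{w'} \mathcal{L}_{val}(w', \alpha)$:

```python
import torch

def hessian_vector_product(train_loss_fn, weights, alphas, vector, scale=0.01):
    """Finite-difference approximation of (d^2 L_train / d alpha d w) . vector.

    train_loss_fn: assumed closure returning L_train(w, alpha) on a mini-batch
    weights:       list of weight tensors w
    alphas:        list of architecture parameters alpha
    vector:        list of tensors shaped like `weights`
                   (here, grad_{w'} L_val(w', alpha))
    """
    # epsilon is chosen relative to the norm of the vector
    eps = scale / torch.cat([v.reshape(-1) for v in vector]).norm()

    # w+ = w + eps * v  ->  grad_alpha L_train(w+, alpha)
    with torch.no_grad():
        for w, v in zip(weights, vector):
            w.add_(eps * v)
    grads_pos = torch.autograd.grad(train_loss_fn(), alphas)

    # w- = w - eps * v  ->  grad_alpha L_train(w-, alpha)
    with torch.no_grad():
        for w, v in zip(weights, vector):
            w.sub_(2 * eps * v)
    grads_neg = torch.autograd.grad(train_loss_fn(), alphas)

    # restore the original weights
    with torch.no_grad():
        for w, v in zip(weights, vector):
            w.add_(eps * v)

    # (grad_alpha L_train(w+) - grad_alpha L_train(w-)) / (2 * eps)
    return [(gp - gn) / (2 * eps) for gp, gn in zip(grads_pos, grads_neg)]
```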

  • First-order Approximation

When $\xi = 0$, the second-order term in Formula 7 vanishes and the gradient becomes simply $\nabla_{\alpha} \mathcal{L}_{val}(w, \alpha)$, i.e. the current weights are assumed to already be optimal and the validation loss is optimized by adjusting the architecture alone. Setting $\xi = 0$ speeds up the search but may hurt performance. The case $\xi = 0$ is called the first-order approximation, and $\xi > 0$ the second-order approximation.
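Putting the pieces together, Algorithm 1 alternates an architecture step on validation data with a weight step on training data. Below is a hedged first-order ($\xi = 0$) sketch with illustrative names and hyperparameters (`alphas`, `weights`, the two optimizers); the second-order variant would instead compute the architecture gradient from the one-step-unrolled weights plus the finite-difference correction shown above:

```python
import torch

def search(model, alphas, weights, train_loader, val_loader, criterion,
           epochs=50, w_lr=0.025, alpha_lr=3e-4):
    """First-order (xi = 0) sketch of the alternating search loop.

    alphas / weights: lists of architecture parameters and operation weights.
    Hyperparameters are illustrative, not prescriptive.
    """
    w_optimizer = torch.optim.SGD(weights, lr=w_lr, momentum=0.9)
    alpha_optimizer = torch.optim.Adam(alphas, lr=alpha_lr)

    for _ in range(epochs):
        for (x_train, y_train), (x_val, y_val) in zip(train_loader, val_loader):
            # 1) update the architecture alpha on a validation mini-batch
            alpha_optimizer.zero_grad()
            criterion(model(x_val), y_val).backward()   # grad_alpha L_val(w, alpha)
            alpha_optimizer.step()

            # 2) update the weights w on a training mini-batch
            w_optimizer.zero_grad()
            criterion(model(x_train), y_train).backward()
            w_optimizer.step()
```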

Deriving Discrete Architectures

When deriving the final discrete architecture, each node keeps the top-$k$ strongest non-zero operations coming from distinct predecessor nodes, where the strength of an operation is defined as $\frac{\exp(\alpha_o^{(i,j)})}{\sum_{o' \in O} \exp(\alpha_{o'}^{(i,j)})}$. To improve the performance of the searched network, $k = 2$ is used for convolutional cells and $k = 1$ for recurrent cells. The zero operation is filtered out mainly so that each node has enough inputs and the comparison with current SOTA models stays fair.
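The discretization step can be sketched as follows (an assumed data layout, for illustration only): for each node, compute the softmax strength of every incoming operation, drop the zero operation, and keep the $k$ strongest edges from distinct predecessors:

```python
import torch.nn.functional as F

def derive_cell(alpha, op_names, k=2, zero_name='none'):
    """Keep, for every node, the top-k strongest non-zero incoming operations.

    alpha:    dict mapping edge (i, j) -> tensor of shape [|O|] with that
              edge's architecture parameters (illustrative layout)
    op_names: list of the |O| candidate operation names
    """
    genotype = []
    nodes = sorted({j for (_, j) in alpha})
    for j in nodes:
        candidates = []
        for (i, jj), a in alpha.items():
            if jj != j:
                continue
            probs = F.softmax(a, dim=0)  # exp(alpha_o) / sum_o' exp(alpha_o')
            # strongest non-zero operation on edge (i, j)
            strength, name = max((p.item(), n) for n, p in zip(op_names, probs)
                                 if n != zero_name)
            candidates.append((strength, i, name))
        # keep the k strongest incoming edges (each from a distinct predecessor)
        for strength, i, name in sorted(candidates, reverse=True)[:k]:
            genotype.append((name, i, j))
    return genotype
```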

Experiments and Results

Search cost comparison, where the number of runs indicates that the best result is obtained over multiple searches.

The searched cell structures.

Performance comparison on CIFAR-10.

Performance comparison on PTB.

Performance comparison when transferring to ImageNet.

Conclusion


DARTS is a classic NAS method that breaks away from the earlier discrete search paradigm and performs end-to-end architecture search. Because the architecture is updated with gradients, the update direction is comparatively accurate, and the search cost is greatly reduced compared with earlier methods: searching on CIFAR-10 takes only about 4 GPU days.





If this article was helpful to you, please give it a thumbs up or check it out

For more content, please follow the WeChat official account [Algorithm Engineering Notes of Xiaofei]