Contents

1. BatchNorm principle

2. PyTorch implementation of BatchNorm

2.1 _NormBase class

2.1.1 initialization

2.1.2 Simulate BN Forward

2.1.3 Update of running_mean and running_var

2.1.4 Update of \gamma and \beta

2.1.5 eval mode

2.2 BatchNormNd class

3. PyTorch implementation of SyncBatchNorm

3.1 forward

3.2 backward

1. BatchNorm principle

BatchNorm was first proposed for fully connected networks, to normalize the input of each neuron. Extended to CNNs, it normalizes the input of each convolution kernel, i.e., it normalizes over all dimensions except the channel dimension. BN brings many benefits; here are a few:

  • Prevents over-fitting: the output for a single sample depends on the whole mini-batch, which makes it harder to over-fit to individual samples;
  • Accelerates convergence: during gradient descent, the weights and biases of every layer keep changing, so the distribution of each layer's outputs keeps shifting and the later layers must constantly adapt to these changes. With BN, the input distribution of each layer stays approximately constant.
  • Prevents vanishing gradients: in the forward pass, activations gradually drift toward the upper and lower saturation regions of the nonlinearity (take Sigmoid as an example). There, the gradients flowing back become very small, which is not conducive to training.

The mathematical expression of BN is:

\mu_B = \frac{1}{m} \sum_{i=1}^{m} x_i, \quad \sigma_B^2 = \frac{1}{m} \sum_{i=1}^{m} (x_i - \mu_B)^2

\hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}, \quad y_i = \gamma \hat{x}_i + \beta

There is a scaling factor \gamma and a shift factor \beta here; the authors explain why they are needed:

  • Normalizing alone forces the outputs into a fixed distribution, which causes the new distribution to lose the features and knowledge passed from the previous layer;
  • Take Sigmoid as an example: adding \gamma and \beta prevents most of the values from falling only in the nearly linear middle region, so the nonlinear part of the activation can still be exploited.

2. PyTorch implementation of BatchNorm

PyTorch's BN-related classes live in torch.nn.modules.batchnorm and include the following:

  • _NormBase: a subclass of nn.Module that defines a series of BN attributes and the methods for initializing and reading them;
  • _BatchNorm: a subclass of _NormBase that defines the forward method;
  • BatchNorm1d & BatchNorm2d & BatchNorm3d: subclasses of _BatchNorm that define different _check_input_dim methods.

2.1 _NormBase class

2.1.1 initialization

The _NormBase class defines some of the attributes associated with BN, as shown in the following table:

num_features: number of input channels
track_running_stats: defaults to True; whether to track the running_mean and running_var statistics
running_mean: mean of the inputs, accumulated during training and used later in inference
running_var: variance of the inputs, accumulated during training and used later in inference
momentum: defaults to 0.1; the momentum used when updating running_mean and running_var
num_batches_tracked: added in PyTorch 0.4.1; when momentum is set to None, num_batches_tracked is used to compute the effective momentum of each update
affine: defaults to True, meaning weight and bias are trained; otherwise their values are not updated
weight: \gamma in the formula, initialized to an all-ones tensor
bias: \beta in the formula, initialized to an all-zeros tensor
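These attributes can be inspected directly on a freshly constructed layer; a minimal sketch:

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(num_features=3)
print(bn.weight)               # Parameter of shape (3,), initialized to ones  -> \gamma
print(bn.bias)                 # Parameter of shape (3,), initialized to zeros -> \beta
print(bn.running_mean)         # buffer, initialized to zeros
print(bn.running_var)          # buffer, initialized to ones
print(bn.num_batches_tracked)  # buffer, initialized to 0
print(bn.momentum, bn.eps, bn.affine, bn.track_running_stats)  # 0.1 1e-05 True True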

Here is the source code for PyTorch:

class _NormBase(Module):
    """Common base of _InstanceNorm and _BatchNorm"""
    # used to check whether a checkpoint comes from before PyTorch 0.4.1
    _version = 2
    __constants__ = ['track_running_stats', 'momentum', 'eps',
                     'num_features', 'affine']

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True):
        super(_NormBase, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        if self.affine:
            # if affine is enabled, use the scaling factor and the shift factor
            self.weight = Parameter(torch.Tensor(num_features))
            self.bias = Parameter(torch.Tensor(num_features))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        # if running statistics are tracked, register the corresponding buffers
        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(num_features))
            self.register_buffer('running_var', torch.ones(num_features))
            self.register_buffer('num_batches_tracked',
                                 torch.tensor(0, dtype=torch.long))
        else:
            self.register_parameter('running_mean', None)
            self.register_parameter('running_var', None)
            self.register_parameter('num_batches_tracked', None)
        self.reset_parameters()

    def reset_running_stats(self):
        if self.track_running_stats:
            self.running_mean.zero_()
            self.running_var.fill_(1)
            self.num_batches_tracked.zero_()

    def reset_parameters(self):
        self.reset_running_stats()
        if self.affine:
            init.ones_(self.weight)
            init.zeros_(self.bias)

    def _check_input_dim(self, input):
        raise NotImplementedError

    def extra_repr(self):
        return '{num_features}, eps={eps}, momentum={momentum}, affine={affine}, ' \
               'track_running_stats={track_running_stats}'.format(**self.__dict__)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        version = local_metadata.get('version', None)

        if (version is None or version < 2) and self.track_running_stats:
            # at version 2: added num_batches_tracked buffer
            #               this should have a default value of 0
            num_batches_tracked_key = prefix + 'num_batches_tracked'
            if num_batches_tracked_key not in state_dict:
                # older checkpoints don't have this key; default it to 0
                state_dict[num_batches_tracked_key] = torch.tensor(0, dtype=torch.long)

        super(_NormBase, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)


class _BatchNorm(_NormBase):

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True):
        super(_BatchNorm, self).__init__(
            num_features, eps, momentum, affine, track_running_stats)

    def forward(self, input):
        self._check_input_dim(input)

        # exponential_average_factor is set to self.momentum
        # (when it is available) only so that it gets updated
        # in ONNX graph when this node is exported to ONNX.
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        # only update the running statistics in train mode
        # and when self.track_running_stats is True
        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked = self.num_batches_tracked + 1
                # if momentum is None, use num_batches_tracked to weight the update
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        return F.batch_norm(
            input, self.running_mean, self.running_var, self.weight, self.bias,
            self.training or not self.track_running_stats,
            exponential_average_factor, self.eps)

2.1.2 Simulate BN Forward

The Python portion of BN in PyTorch does initialization, parameter passing, and low-level method calls. Here we use Python to simulate the underlying calculations of BN.

import torch
import torch.nn as nn
import torch.nn.modules.batchnorm


# create random inputs
def create_inputs():
    return torch.randn(8, 3, 20, 20)


# mimic the underlying computation of BN; if mean_val / var_val are not given,
# compute them from the input (as in train mode)
def dummy_bn_forward(x, bn_weight, bn_bias, eps, mean_val=None, var_val=None):
    if mean_val is None:
        mean_val = x.mean([0, 2, 3])
    if var_val is None:
        # torch.var uses the unbiased estimate by default,
        # so unbiased=False must be set here
        var_val = x.var([0, 2, 3], unbiased=False)

    x = x - mean_val[None, ..., None, None]
    x = x / torch.sqrt(var_val[None, ..., None, None] + eps)
    x = x * bn_weight[..., None, None] + bn_bias[..., None, None]
    return mean_val, var_val, x

Verify the correctness of dummy BN output:

bn_layer = nn.BatchNorm2d(num_features=3)
inputs = create_inputs()
# forward through the BN layer
bn_outputs = bn_layer(inputs)
# forward through the dummy implementation
_, _, expected_outputs = dummy_bn_forward(
    inputs, bn_layer.weight, bn_layer.bias, bn_layer.eps)
assert torch.allclose(expected_outputs, bn_outputs)

No exception was reported, so the calculated value is correct.

2.1.3 Update of running_mean and running_var

BatchNorm enables track_running_stats by default, so running_mean and running_var are updated on every forward pass based on the statistics of the current mini-batch.

The default value of momentum is 0.1; it controls the relative weight of the historical statistics and the current mini-batch when running_mean and running_var are updated:

running_mean = (1 - momentum) * running_mean + momentum * \mu_B
running_var = (1 - momentum) * running_var + momentum * \sigma_B^2

where \mu_B and \sigma_B^2 are the mean and variance of the current mini-batch. Note that an unbiased estimate is used for the variance statistic here, which is consistent with the paper. Manually simulating this process looks like the following:

running_mean = torch.zeros(3)
running_var = torch.ones_like(running_mean)
momentum = 0.1  # this is also the default momentum when BN is initialized
bn_layer = nn.BatchNorm2d(num_features=3, momentum=momentum)

# simulate 10 forward passes
for t in range(10):
    inputs = create_inputs()
    bn_outputs = bn_layer(inputs)
    inputs_mean, inputs_var, _ = dummy_bn_forward(
        inputs, bn_layer.weight, bn_layer.bias, bn_layer.eps)
    n = inputs.numel() / inputs.size(1)
    # update running_var and running_mean
    running_var = running_var * (1 - momentum) + momentum * inputs_var * n / (n - 1)
    running_mean = running_mean * (1 - momentum) + momentum * inputs_mean

assert torch.allclose(running_var, bn_layer.running_var)
assert torch.allclose(running_mean, bn_layer.running_mean)
print(f'bn_layer running_mean is {bn_layer.running_mean}')
print(f'dummy bn running_mean is {running_mean}')
print(f'bn_layer running_var is {bn_layer.running_var}')
print(f'dummy bn running_var is {running_var}')

Output result:

bn_layer running_mean is tensor([ 0.0101, -0.0013, 0.0101])
dummy bn running_mean is tensor([ 0.0101, -0.0013, 0.0101])
bn_layer running_var is tensor([0.9857, 0.9883, 1.0205])
dummy bn running_var is tensor([0.9857, 0.9883, 1.0205])

The initial value of running_mean is 0, and it changes after the forward passes. The simulated running_mean and running_var are consistent with the PyTorch implementation.

The above uses momentum. Since PyTorch 0.4.1, BN has the num_batches_tracked attribute, which counts how many mini-batches BN has processed in the forward pass. When momentum is set to None, num_batches_tracked controls the weighting between the historical statistics and the current mini-batch: the update factor becomes 1 / num_batches_tracked, i.e., a cumulative moving average.

Next, simulate the process manually:

running_mean = torch.zeros(3)
running_var = torch.ones_like(running_mean)
num_batches_tracked = 0
# momentum is set to None, so the cumulative moving average is used instead
bn_layer = nn.BatchNorm2d(num_features=3, momentum=None)

# again simulate 10 forward passes
for t in range(10):
    inputs = create_inputs()
    bn_outputs = bn_layer(inputs)
    inputs_mean, inputs_var, _ = dummy_bn_forward(
        inputs, bn_layer.weight, bn_layer.bias, bn_layer.eps)
    num_batches_tracked += 1
    # exponential_average_factor
    eaf = 1.0 / num_batches_tracked
    n = inputs.numel() / inputs.size(1)
    # update running_var and running_mean
    running_var = running_var * (1 - eaf) + eaf * inputs_var * n / (n - 1)
    running_mean = running_mean * (1 - eaf) + eaf * inputs_mean

assert torch.allclose(running_var, bn_layer.running_var)
assert torch.allclose(running_mean, bn_layer.running_mean)

# switch to eval mode and check that the running statistics are used
bn_layer.train(mode=False)
inference_inputs = create_inputs()
bn_outputs = bn_layer(inference_inputs)
_, _, dummy_outputs = dummy_bn_forward(
    inference_inputs, bn_layer.weight, bn_layer.bias,
    bn_layer.eps, running_mean, running_var)
assert torch.allclose(dummy_outputs, bn_outputs)

print(f'bn_layer running_mean is {bn_layer.running_mean}')
print(f'dummy bn running_mean is {running_mean}')
print(f'bn_layer running_var is {bn_layer.running_var}')
print(f'dummy bn running_var is {running_var}')

Output:

bn_layer running_mean is tensor([-0.0040, 0.0074, -0.0162])
dummy bn running_mean is tensor([-0.0040, 0.0074, -0.0162])
bn_layer running_var is tensor([1.0097, 1.0086, 0.9815])
dummy bn running_var is tensor([1.0097, 1.0086, 0.9815])

Manual simulation results are the same as PyTorch.

2.1.4 Update of \gamma and \beta

BatchNorm's weight and bias correspond to \gamma and \beta in the formula, and they are updated by gradient descent.

import torchvision
from torchvision.transforms import Normalize, ToTensor, Compose
import torch.nn.functional as F
from torch.utils.data.dataloader import DataLoader

# use MNIST as a toy dataset
mnist = torchvision.datasets.MNIST(root='mnist', download=True, transform=ToTensor())
dataloader = DataLoader(dataset=mnist, batch_size=8)  # batch size assumed; any value > 1 works

# build a toy model containing a BatchNorm1d layer
toy_model = nn.Sequential(nn.Linear(28 ** 2, 128), nn.BatchNorm1d(128),
                          nn.ReLU(), nn.Linear(128, 10), nn.Sigmoid())
optimizer = torch.optim.SGD(toy_model.parameters(), lr=0.1)

bn_1d_layer = toy_model[1]
print(f'Initial weight is {bn_1d_layer.weight[:4].tolist()}...')
print(f'Initial bias is {bn_1d_layer.bias[:4].tolist()}...\n')

# train for 2 iterations
for (i, data) in enumerate(dataloader):
    output = toy_model(data[0].view(data[0].shape[0], -1))
    (F.cross_entropy(output, data[1])).backward()
    # print the gradients of weight and bias
    print(f'Gradient of weight is {bn_1d_layer.weight.grad[:4].tolist()}...')
    print(f'Gradient of bias is {bn_1d_layer.bias.grad[:4].tolist()}...')
    optimizer.step()
    optimizer.zero_grad()
    if i == 1:
        break

print(f'\nNow weight is {bn_1d_layer.weight[:4].tolist()}...')
print(f'Now bias is {bn_1d_layer.bias[:4].tolist()}...')

inputs = torch.randn(4, 128)
bn_outputs = bn_1d_layer(inputs)
new_bn = nn.BatchNorm1d(128)
bn_outputs_no_weight_bias = new_bn(inputs)

assert not torch.allclose(bn_outputs, bn_outputs_no_weight_bias)

Output:

Initial weight is [0.9999354481697083, 1.0033478736877441, 1.0019147396087646, 0.9986106157302856]...
Initial bias is [-0.0012734815245494246, 0.001349383033812046, 0.0013358002761378884, -0.00148777367547154]...

Gradient of weight is [-0.0004475426103454083, -0.0021388232707977295, -0.0032624618615955114, 0.0009599098702892661]...
Gradient of bias is [0.00011698803427862003, -0.001291472464799881, -0.0023048489820212126, -0.0009493136312812567]...
Gradient of weight is [-0.00035325769567862153, -0.0014295700239017606, -0.002102235099300742, 0.000851186050567776]...
Gradient of bias is [-0.00026844028616324067, -0.00025666248984634876, -0.0017800561618059874, 0.00024933076929301023]...

Now weight is [1.0000154972076416, 1.0037046670913696, 1.0024511814117432, 0.9986214637756348]...
Now bias is [-0.0012583363568410277, 0.0015041964361444116, 0.0017442908138036728, -0.0006448794738389552]...

2.1.5 eval mode

All of the above validates BN's behavior in train mode. In eval mode, several parameters are important:

  • track_running_stats: defaults to True. In train mode, the running_mean and running_var statistics are accumulated; in eval mode, those statistics are used to normalize the input. When set to False, eval mode directly computes the mean and variance of the input instead.
  • running_mean, running_var: the statistics accumulated in train mode.

That is, bn.training is not the only flag that determines BN's behavior. If bn.training or not bn.track_running_stats holds, the mean and variance of the input data are computed directly; otherwise, the running statistics are used.

Bn_layer.train (mode=False) inference_inputs = create_inputs() Print (f'bn_layer running_mean is {bn_layer.running_mean}') print(f'bn_layer running_var is {bn_layer.running_var}') bn_outputs = bn_layer(inference_inputs) print(f'Now bn_layer running_mean is {bn_layer.running_mean}') print(f'Now bn_layer running_var is {bn_layer.running_var}') # Outputs running_mean and running_var instead of running_mean and running_var _, _, dummy_inputs = dummy_BN_forward (Inference_inputs, bn_layer.weight, bn_layer.bias, bn_layer.eps, running_mean, running_var) assert torch.allclose(dummy_outputs, Bn_outputs) # After track_running_stats is turned off, even in eval mode, Track_running_stats = False BN_outputs_notrack = BN_layer (inference_INPUTS) _, _, dummy_outputs_notrack = dummy_bn_forward( inference_inputs, bn_layer.weight, bn_layer.bias, bn_layer.eps) assert torch.allclose(dummy_outputs_notrack, bn_outputs_notrack) assert not torch.allclose(bn_outputs, bn_outputs_notrack)Copy the code

The following output is displayed:

bn_layer running_mean is tensor([-0.0143, 0.0089, -0.0062])
bn_layer running_var is tensor([0.9611, 1.0380, 1.0181])
Now bn_layer running_mean is tensor([-0.0143, 0.0089, -0.0062])
Now bn_layer running_var is tensor([0.9611, 1.0380, 1.0181])

2.2 BatchNormNd class

This includes BatchNorm1d, BatchNorm2d, and BatchNorm3d. They differ only in how they check the validity of the input; here is the BatchNorm2d implementation:

class BatchNorm2d(_BatchNorm):
    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError('expected 4D input (got {}D input)'
                             .format(input.dim()))

BatchNorm1d accepts 2D or 3D input, BatchNorm2d accepts 4D input, and BatchNorm3d accepts 5D input.
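As a quick sanity check (a minimal sketch; the channel count and shapes below are arbitrary), the accepted input shapes are (N, C) or (N, C, L) for BatchNorm1d, (N, C, H, W) for BatchNorm2d, and (N, C, D, H, W) for BatchNorm3d:

import torch
import torch.nn as nn

# BatchNorm1d: 2D (N, C) or 3D (N, C, L) inputs
nn.BatchNorm1d(16)(torch.randn(8, 16))
nn.BatchNorm1d(16)(torch.randn(8, 16, 100))

# BatchNorm2d: 4D (N, C, H, W) inputs
nn.BatchNorm2d(16)(torch.randn(8, 16, 32, 32))

# BatchNorm3d: 5D (N, C, D, H, W) inputs
nn.BatchNorm3d(16)(torch.randn(8, 16, 4, 32, 32))

# anything else fails the _check_input_dim check
try:
    nn.BatchNorm2d(16)(torch.randn(8, 16, 32))
except ValueError as e:
    print(e)  # expected 4D input (got 3D input)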

3. PyTorch implementation of SyncBatchNorm

The performance of BN is closely related to the batch size: the larger the batch size, the more accurate BN's statistics. However, tasks such as detection consume a lot of GPU memory, so a single card can often only hold a few images (say, 2) for training, which leads to poor BN performance. One solution is SyncBN: all cards share the same BN layer and obtain global statistics.
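In practice, SyncBN is typically enabled by converting an existing model's BN layers and training with DistributedDataParallel. A minimal usage sketch (assuming the process group has already been initialized, e.g. with torchrun; the model here is a placeholder):

import torch
import torch.nn as nn

# assume torch.distributed has already been initialized, e.g. via
# torch.distributed.init_process_group(backend='nccl')
model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.ReLU())

# replace every BN layer in the model with SyncBatchNorm
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

# SyncBatchNorm only works inside DistributedDataParallel (one GPU per process)
model = model.cuda()
model = nn.parallel.DistributedDataParallel(
    model, device_ids=[torch.cuda.current_device()])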

PyTorch implements SyncBN in torch/nn/modules/batchnorm.py and torch/nn/modules/_functions.py. The former is responsible for checking the validity of the input and calling the latter according to settings such as momentum; the latter is responsible for computing the single-card statistics and for inter-process communication.

class SyncBatchNorm(_BatchNorm):

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True, process_group=None):
        super(SyncBatchNorm, self).__init__(num_features, eps, momentum, affine,
                                            track_running_stats)
        self.process_group = process_group
        # gpu_size is set through DistributedDataParallel initialization.
        # This is to ensure that SyncBatchNorm is used
        # under supported condition (single GPU per process)
        self.ddp_gpu_size = None

    def _check_input_dim(self, input):
        if input.dim() < 2:
            raise ValueError('expected at least 2D input (got {}D input)'
                             .format(input.dim()))

    def _specify_ddp_gpu_num(self, gpu_size):
        if gpu_size > 1:
            raise ValueError('SyncBatchNorm is only supported for DDP with single GPU per process')
        self.ddp_gpu_size = gpu_size

    def forward(self, input):
        if not input.is_cuda:
            raise ValueError('SyncBatchNorm expected input tensor to be on GPU')

        self._check_input_dim(input)

        # exponential_average_factor is set to self.momentum
        # (when it is available) only so that it gets updated
        # in ONNX graph when this node is exported to ONNX.
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            self.num_batches_tracked = self.num_batches_tracked + 1
            if self.momentum is None:  # use cumulative moving average
                exponential_average_factor = 1.0 / self.num_batches_tracked.item()
            else:  # use exponential moving average
                exponential_average_factor = self.momentum

        # synchronization is needed if track_running_stats is disabled
        # or we are in train mode
        need_sync = self.training or not self.track_running_stats
        if need_sync:
            process_group = torch.distributed.group.WORLD
            if self.process_group:
                process_group = self.process_group
            world_size = torch.distributed.get_world_size(process_group)
            need_sync = world_size > 1

        # if no synchronization is needed, SyncBN behaves the same as ordinary BN
        if not need_sync:
            return F.batch_norm(
                input, self.running_mean, self.running_var, self.weight, self.bias,
                self.training or not self.track_running_stats,
                exponential_average_factor, self.eps)
        else:
            if not self.ddp_gpu_size:
                raise AttributeError('SyncBatchNorm is only supported within torch.nn.parallel.DistributedDataParallel')

            return sync_batch_norm.apply(
                input, self.weight, self.bias, self.running_mean, self.running_var,
                self.eps, exponential_average_factor, process_group, world_size)

    @classmethod
    def convert_sync_batchnorm(cls, module, process_group=None):
        module_output = module
        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
            module_output = torch.nn.SyncBatchNorm(module.num_features,
                                                   module.eps, module.momentum,
                                                   module.affine,
                                                   module.track_running_stats,
                                                   process_group)
            if module.affine:
                with torch.no_grad():
                    module_output.weight.copy_(module.weight)
                    module_output.bias.copy_(module.bias)
                # keep requires_grad unchanged
                module_output.weight.requires_grad = module.weight.requires_grad
                module_output.bias.requires_grad = module.bias.requires_grad
            module_output.running_mean = module.running_mean
            module_output.running_var = module.running_var
            module_output.num_batches_tracked = module.num_batches_tracked
        for name, child in module.named_children():
            module_output.add_module(name, cls.convert_sync_batchnorm(child, process_group))
        del module
        return module_output

3.1 forward

First, recall how the variance can be computed:

Var[x] = E[x^2] - (E[x])^2

BN on a single card computes the mean and variance of that card's input and uses them to normalize. SyncBN needs global statistics, namely the mean and variance over the inputs on all the cards. A simple idea is to do this in two steps:

  1. Each card calculates its mean separately and then synchronizes to get the global mean
  2. The global mean was used to calculate the variance corresponding to each card, and then a synchronization was done to get the global variance

But two synchronizations cost more time. In fact, the formula above shows that the mean and the variance can be obtained with a single synchronization: each card only needs to compute \sum_i x_i and \sum_i x_i^2 (together with its element count) locally; after one synchronization of these sums, the global mean is \sum_i x_i / N and the global variance is \sum_i x_i^2 / N - (\sum_i x_i / N)^2.
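A small single-process sketch of this combination (plain tensors stand in for the per-card inputs; no real distributed communication is involved) shows that merging the per-card sums reproduces the global statistics:

import torch

# pretend these are the inputs held by three different cards
per_card_inputs = [torch.randn(2, 3, 20, 20) for _ in range(3)]

# each "card" computes only sum(x), sum(x^2) and its element count per channel
sums = torch.stack([x.sum([0, 2, 3]) for x in per_card_inputs])            # (cards, C)
sq_sums = torch.stack([(x ** 2).sum([0, 2, 3]) for x in per_card_inputs])  # (cards, C)
counts = torch.tensor([x.numel() / x.size(1) for x in per_card_inputs])    # (cards,)

# one "synchronization": add everything up
total_count = counts.sum()
global_mean = sums.sum(0) / total_count
global_var = sq_sums.sum(0) / total_count - global_mean ** 2  # Var[x] = E[x^2] - E[x]^2

# compare with computing the statistics over all inputs at once
all_inputs = torch.cat(per_card_inputs, dim=0)
assert torch.allclose(global_mean, all_inputs.mean([0, 2, 3]), atol=1e-5)
assert torch.allclose(global_var, all_inputs.var([0, 2, 3], unbiased=False), atol=1e-5)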

In the implementation, batchnorm.SyncBatchNorm prepares the arguments according to its own hyperparameters, train/eval state, and so on, and then calls _functions.SyncBatchNorm, whose signature is forward(self, input, weight, bias, running_mean, running_var, eps, momentum, process_group, world_size). It first computes the mean and the inverse standard deviation on the local card:

# compute the local mean and invstd = 1 / sqrt(var + eps)
mean, invstd = torch.batch_norm_stats(input, eps)

mean_all and invstd_all are then obtained by synchronizing the data from every card. Next, the global statistics are computed and running_mean and running_var are updated:

# compute the global mean and invstd, and update running_mean / running_var
mean, invstd = torch.batch_norm_gather_stats_with_counts(
    input,
    mean_all,
    invstd_all,
    running_mean,
    running_var,
    momentum,
    eps,
    count_all.view(-1).long().tolist()
)

3.2 backward

Since different processes share the same set of BN parameters, the processes also need to communicate during BN's backward pass; this is implemented in _functions.SyncBatchNorm:

# calculate local stats as well as grad_weight / grad_bias
sum_dy, sum_dy_xmu, grad_weight, grad_bias = torch.batch_norm_backward_reduce(
    grad_output,
    saved_input,
    mean,
    invstd,
    weight,
    self.needs_input_grad[0],
    self.needs_input_grad[1],
    self.needs_input_grad[2]
)

This yields the gradients of weight and bias, as well as sum_dy and sum_dy_xmu, which are used to compute the gradient of the input:

# all_reduce to sum the gradients across processes
sum_dy_all_reduce = torch.distributed.all_reduce(
    sum_dy, torch.distributed.ReduceOp.SUM, process_group, async_op=True)
sum_dy_xmu_all_reduce = torch.distributed.all_reduce(
    sum_dy_xmu, torch.distributed.ReduceOp.SUM, process_group, async_op=True)

# ...

# average over the global element count
divisor = count_tensor.sum()
mean_dy = sum_dy / divisor
mean_dy_xmu = sum_dy_xmu / divisor

# backward pass for gradient calculation
grad_input = torch.batch_norm_backward_elemt(
    grad_output,
    saved_input,
    mean,
    invstd,
    weight,
    mean_dy,
    mean_dy_xmu
)
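For reference, the quantity returned by batch_norm_backward_elemt corresponds to the standard BN input gradient; written out under the usual notation (a restatement for clarity, not a quote from the PyTorch source):

\frac{\partial L}{\partial x_i} = \gamma \cdot invstd \cdot \left( \frac{\partial L}{\partial y_i} - mean\_dy - (x_i - \mu) \cdot invstd^2 \cdot mean\_dy\_xmu \right)

where \mu and invstd are the global mean and inverse standard deviation from the forward pass, mean_dy is the mean of the output gradients, and mean_dy_xmu is the mean of dy * (x - \mu), both averaged over all cards.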