0x00 Abstract

NVIDIA Megatron is a PyTorch-based distributed training framework used to train large Transformer language models. By combining data parallelism, tensor parallelism and pipeline parallelism, it can train models at the scale of GPT-3, and the mechanism behind it is worth analyzing.

This series will consist of 6 to 7 articles, studied through the paper and the source code. This article looks at how Megatron handles model parallelism.

Other articles in this series are:

Model Parallel Distributed Training Megatron (1) — Thesis & Foundation

Model parallel distributed training Megatron (2) — Overall architecture

0x01 Parallel Transformer layer

From the paper we learned that as models grow larger, their size far exceeds the memory limit of a single processor, which gives rise to memory management techniques such as activation checkpointing. Model parallelism, by contrast, overcomes the memory limit of a single processor by slicing the model in various ways, so that model weights and their associated optimizer state can be spread across multiple devices.

The ParallelTransformerLayer is a parallel implementation of the Transformer layer, so let’s move on.

1.1 Initialization

The ParallelTransformerLayer initialization method sets up the following:

  • Generate a LayerNorm to process the input data.
  • Generate parallel Attention.
  • Generate a LayerNorm that handles attention output.
  • If it is a decoder, a ParallelAttention is generated.
  • Generate a parallel MLP.
class ParallelTransformerLayer(MegatronModule) :
    """A single transformer layer. Transformer layer takes input with size [b, s, h] and returns an output of the same size. """

    def __init__(self, init_method, output_layer_init_method, layer_number, layer_type=LayerType.encoder, self_attn_mask_type=AttnMaskType.padding) :
        args = get_args()

        super(ParallelTransformerLayer, self).__init__()
        self.layer_number = layer_number
        self.layer_type = layer_type

        self.apply_residual_connection_post_layernorm \
            = args.apply_residual_connection_post_layernorm

        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection

        # Layernorm on the input data.
        self.input_layernorm = LayerNorm(  # generate a LayerNorm to process the input data
            args.hidden_size,
            eps=args.layernorm_epsilon,
            no_persist_layer_norm=args.no_persist_layer_norm)

        # Self attention.
        self.self_attention = ParallelAttention(  # generate a parallel Attention
            init_method,
            output_layer_init_method,
            layer_number,
            attention_type=AttnType.self_attn,
            attn_mask_type=self_attn_mask_type)
        self.hidden_dropout = args.hidden_dropout
        self.bias_dropout_fusion = args.bias_dropout_fusion

        # Layernorm on the attention output
        self.post_attention_layernorm = LayerNorm(  # generate a LayerNorm that handles the attention output
            args.hidden_size,
            eps=args.layernorm_epsilon,
            no_persist_layer_norm=args.no_persist_layer_norm)

        if self.layer_type == LayerType.decoder: # if this layer is decoder
            self.inter_attention = ParallelAttention( # generates a ParallelAttention
                init_method,
                output_layer_init_method,
                layer_number,
                attention_type=AttnType.cross_attn)
            # Layernorm on the attention output.
            self.post_inter_attention_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon,
                no_persist_layer_norm=args.no_persist_layer_norm)

        # MLP
        self.mlp = ParallelMLP(init_method,  # generate a parallel MLP
                               output_layer_init_method)

The corresponding structure is shown in the figure:

1.2 Forward propagation

The forward propagation method is as follows; it simply calls the various member modules to perform the forward operations.

def forward(self, hidden_states, attention_mask,
            encoder_output=None, enc_dec_attn_mask=None,
            inference_params=None) :
    # hidden_states: [b, s, h]

    # Layer norm at the beginning of the transformer layer.
    layernorm_output = self.input_layernorm(hidden_states) # process input
    
    # Self attention operation.
    attention_output, attention_bias = \
        self.self_attention(
            layernorm_output,
            attention_mask,
            inference_params=inference_params)

    # Residual connection
    if self.apply_residual_connection_post_layernorm:
        residual = layernorm_output  # use the post-layernorm output as the residual
    else:
        residual = hidden_states # original input X

    # jit scripting for a nn.module (with dropout) is not
    # triggering the fusion kernel. For now, we use two
    # different nn.functional routines to account for varying
    # dropout semantics during training and inference phases.
    if self.bias_dropout_fusion: # dropout operation
        if self.training:
            bias_dropout_add_func = bias_dropout_add_fused_train
        else:
            bias_dropout_add_func = bias_dropout_add_fused_inference
    else:
        bias_dropout_add_func = get_bias_dropout_add(self.training)

    # re-enable torch grad to enable fused optimization.
    with torch.enable_grad():
        layernorm_input = bias_dropout_add_func( # dropout operation
            attention_output,
            attention_bias.expand_as(residual),
            residual,
            self.hidden_dropout)

    # Layer norm post the self attention.
    layernorm_output = self.post_attention_layernorm(layernorm_input) # Handle attention output

    if self.layer_type == LayerType.decoder:
        attention_output, attention_bias = \
            self.inter_attention(layernorm_output,
                                 enc_dec_attn_mask,
                                 encoder_output=encoder_output)
        # residual connection
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        # re-enable torch grad to enable fused optimization.
        with torch.enable_grad():
            layernorm_input = bias_dropout_add_func(
                attention_output,
                attention_bias.expand_as(residual),
                residual,
                self.hidden_dropout)

        # Layer norm post the decoder attention
        layernorm_output = self.post_inter_attention_layernorm(layernorm_input)

    # MLP.
    mlp_output, mlp_bias = self.mlp(layernorm_output) # MLP operation

    # Second residual connection.
    if self.apply_residual_connection_post_layernorm: # Residual operation
        residual = layernorm_output
    else:
        residual = layernorm_input

    # re-enable torch grad to enable fused optimization.
    with torch.enable_grad():
        output = bias_dropout_add_func( # dropout operation
            mlp_output,
            mlp_bias.expand_as(residual),
            residual,
            self.hidden_dropout)

    return output

0x02 Parallel MLP

The ParallelTransformerLayer contains Attention and MLP; due to space constraints, only the MLP is analyzed here. For Attention we will just briefly touch on its sharding mechanism, since our goal is to understand model parallelism, not the Transformer itself.

Megatron’s parallel MLP consists of two linear layers. The first linear layer projects from hidden size h to 4h, and the second projects from 4h back to h. The specific logic of the MLP is as follows:

Figure: MLP with model parallelism. f and g denote the communication-related operations, which are conjugate: the forward propagation of f is an identity operator and its backward propagation is an all-reduce, while the forward propagation of g is an all-reduce and its backward propagation is an identity operator. Here f comes from ColumnParallelLinear and g comes from RowParallelLinear; that is, the MLP is a combination of ColumnParallelLinear and RowParallelLinear.

So the question is: how do we slice these two linear layers onto different GPUs? As shown above, the second option from the paper is used here.

Another option is to split A along its columns, A = [A_1, A_2]. This partitioning allows the GeLU nonlinearity to be applied independently to the GEMM output of each partition:


[Y_1, Y_2] = [GeLU(XA_1), GeLU(XA_2)]

This approach is better because it removes a synchronization point: the outputs of the two GeLUs are simply concatenated. We therefore partition the first GEMM in this column-parallel fashion and split the second GEMM along its rows, so that it can consume the GeLU output directly without any communication (such as an all-reduce), as shown in the figure.

Let’s take a closer look at why we chose this option.

According to conventional logic, the MLP forward pass would be divided into two phases, corresponding to the two rows in the figure below.

  • The first row splits the parameter A by column and then concatenates the results by column; the result is exactly equivalent to the result obtained without any parallel strategy.
  • The second row splits the activation Y by column and the parameter B by row for parallel computation, and finally adds the partial outputs to get Z.

But each split incurs two additional communications (one for forward propagation and one for backward propagation; only forward propagation is shown here). For the second row, the input Y is essentially XA_1 and XA_2 computed in parallel, so in order to reduce traffic we can delay or simply cancel the data communication: the all-gather at the end of the first row and the split at the beginning of the second row can be omitted. This works because of the associativity of addition (the sum of local partial sums is the global sum). This is exactly the second option in the paper.

Combined with the code, it is:

  • ColumnParallelLinear implements the first half of the MLP, and can also be used as a standalone linear layer.
  • RowParallelLinear implements the second half of the MLP, and can also be used as a standalone linear layer (a small numerical sanity check of this partitioning follows).
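To make this concrete, here is a minimal numerical sketch (plain PyTorch, not Megatron code; all sizes are illustrative) checking that splitting A by column and B by row reproduces the unpartitioned MLP output, with the final addition standing in for the all-reduce performed by g:

import torch

torch.manual_seed(0)
b, h = 4, 8                              # batch and hidden size (illustrative)
X = torch.randn(b, h)
A = torch.randn(h, 4 * h)                # first GEMM weight
B = torch.randn(4 * h, h)                # second GEMM weight

# Unpartitioned reference: Z = GeLU(X A) B
Z_ref = torch.nn.functional.gelu(X @ A) @ B

# Two "GPUs": split A by column, B by row
A1, A2 = A.chunk(2, dim=1)
B1, B2 = B.chunk(2, dim=0)
Y1 = torch.nn.functional.gelu(X @ A1)    # each rank computes only its own partition
Y2 = torch.nn.functional.gelu(X @ A2)
Z = Y1 @ B1 + Y2 @ B2                    # the "+" is what the all-reduce of g performs

print(torch.allclose(Z, Z_ref, atol=1e-5))  # True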

2.1 Naming Conventions

Let’s first look at the naming conventions, which are used as follows:

  • h: hidden size
  • n: number of attention heads
  • p: number of model parallel partitions
  • np: n/p
  • hp: h/p
  • hn: h/n
  • b: batch size
  • s: sequence length
  • l: number of layers
  • The Transformer’s input size is [s, b, h] and it returns a tensor of the same size. These are the hyperparameters used for the Transformer.

2.2 MLP code

2.2.1 Initialization

ParallelMLP is defined in megatron/model/transformer.py as follows:

  • Define a ColumnParallelLinear to perform the first projection from h to 4h.
  • Then a GeLU.
  • Then a RowParallelLinear converts 4h back to h.

The dropout operation is performed in the forward method of the ParallelTransformerLayer shown above.

Therefore, MLP is roughly shown in the figure, where A and B are their respective weight matrices:

That’s the figure in the paper.

Here’s the code.

class ParallelMLP(MegatronModule) :
    """MLP. MLP will take the input with h hidden state, project it to 4*h hidden dimension, perform nonlinear transformation, and project the state back into h hidden dimension. """

    def __init__(self, init_method, output_layer_init_method) :
        super(ParallelMLP, self).__init__()
        args = get_args()

        # Project to 4h.
        self.dense_h_to_4h = mpu.ColumnParallelLinear( # column segmentation
            args.hidden_size,
            args.ffn_hidden_size,
            gather_output=False,  # False here: use the second option from the paper
            init_method=init_method,
            skip_bias_add=True)

        self.bias_gelu_fusion = args.bias_gelu_fusion # gelu
        self.activation_func = F.gelu
        if args.openai_gelu:
            self.activation_func = openai_gelu
        elif args.onnx_safe:
            self.activation_func = erf_gelu

        # Project back to h.
        self.dense_4h_to_h = mpu.RowParallelLinear( # line segmentation
            args.ffn_hidden_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True)

2.2.2 Forward Operation

Here ColumnParallelLinear converts H to 4H, and RowParallelLinear converts 4H to H.

def forward(self, hidden_states) :

    # [s, b, 4hp]
    intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states) # column-parallel (vertical) split

    if self.bias_gelu_fusion:
         intermediate_parallel = \
                 bias_gelu_impl(intermediate_parallel, bias_parallel)
    else:
        intermediate_parallel = \
            self.activation_func(intermediate_parallel + bias_parallel)

    # [s, b, h]
    output, output_bias = self.dense_4h_to_h(intermediate_parallel) # row-parallel (horizontal) split
    return output, output_bias
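To make the shape bookkeeping concrete, here is a tiny sketch using the naming conventions from section 2.1 (the numbers are illustrative, not taken from any real configuration):

# Illustrative shapes only. With s=1024, b=8, h=1024 and tensor-parallel size p=4:
s, b, h, p = 1024, 8, 1024, 4
x_shape = (s, b, h)                      # input to dense_h_to_4h (ColumnParallelLinear)
intermediate_shape = (s, b, 4 * h // p)  # each GPU holds only its own 4h/p slice ([s, b, 4hp])
out_shape = (s, b, h)                    # dense_4h_to_h output after the all-reduce in g
print(x_shape, intermediate_shape, out_shape)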

ColumnParallelLinear and RowParallelLinear are introduced next. ColumnParallelLinear can be used independently or as the first half of ParallelMLP, and RowParallelLinear can be used independently or as the second half of ParallelMLP.

0x03 ColumnParallelLinear

ColumnParallelLinear splits by column, i.e., a vertical cut. Note that it is the weight matrix A that is sharded by column. That is:


Y = XA = X[A_1, A_2] = [XA_1, XA_2]

The specific segmentation is as follows:

3.1 Definition

For Y = XA + b, A is parallelized along its second dimension: A = [A_1, ..., A_p].

class ColumnParallelLinear(torch.nn.Module) :
    """Linear layer with column parallelism. The linear layer is defined as Y = XA + b. A is parallelized along its second dimension as A = [A_1, ..., A_p]. Arguments: input_size: first dimension of matrix A. output_size: second dimension of matrix A. bias: If true, add bias gather_output: If true, call all-gether on output and make Y avaiable to all GPUs, otherwise, every GPU will have its output which is Y_i = XA_i init_method: method to initialize weights. Note that bias is always set to zero. stride: For the strided linear layers. keep_master_weight_for_test: This was added for testing and should be set to False. It returns the master weights used for initialization. skip_bias_add: This was added to enable performance optimations where bias can be fused with other elementwise operations. we skip adding bias but instead return it. """

3.2 Initialization

The initialization code mainly uses sharded information to initialize weights.

def __init__(self, input_size, output_size, bias=True, gather_output=True,
             init_method=init.xavier_normal_, stride=1,
             keep_master_weight_for_test=False,
             skip_bias_add=False) :
    super(ColumnParallelLinear, self).__init__()

    # Keep input parameters
    self.input_size = input_size
    self.output_size = output_size
    self.gather_output = gather_output
    # Divide the weight matrix along the last dimension.
    world_size = get_tensor_model_parallel_world_size()  # get the world size of the tensor-parallel group
    self.output_size_per_partition = divide(output_size, world_size)  # per-partition output size
    self.skip_bias_add = skip_bias_add

    # Parameters.
    # Note: torch.nn.functional.linear performs XA^T + b and as a result
    # we allocate the transpose.
    # Initialize weight.
    args = get_args()
    if args.use_cpu_initialization:
        # initialize the weight with the per-partition size
        self.weight = Parameter(torch.empty(self.output_size_per_partition,
                                            self.input_size,
                                            dtype=args.params_dtype))
        self.master_weight = _initialize_affine_weight_cpu(  # initialize the weight
            self.weight, self.output_size, self.input_size,
            self.output_size_per_partition, 0, init_method,
            stride=stride, return_master_weight=keep_master_weight_for_test)
    else:
        # initialize the weight with the per-partition size
        self.weight = Parameter(torch.empty(
            self.output_size_per_partition, self.input_size,
            device=torch.cuda.current_device(), dtype=args.params_dtype))
        _initialize_affine_weight_gpu(self.weight, init_method,  # initialize the weight
                                      partition_dim=0, stride=stride)

    if bias:
        if args.use_cpu_initialization:
            # initialize the bias with the per-partition size
            self.bias = Parameter(torch.empty(
                self.output_size_per_partition, dtype=args.params_dtype))
        else:
            # initialize the bias with the per-partition size
            self.bias = Parameter(torch.empty(
                self.output_size_per_partition,
                device=torch.cuda.current_device(),
                dtype=args.params_dtype))
        set_tensor_model_parallel_attributes(self.bias, True, 0, stride)
        # Always initialize bias to zero.
        with torch.no_grad():
            self.bias.zero_()
    else:
        self.register_parameter('bias', None)
    self.async_tensor_model_parallel_allreduce = (
            not args.no_async_tensor_model_parallel_allreduce and
            world_size > 1)

3.2.1 Partition size

self.output_size_per_partition = divide(output_size, world_size) performs the size-splitting operation to obtain the weight size that each partition should hold.

def ensure_divisibility(numerator, denominator) :
    """Ensure that numerator is divisible by the denominator."""
    assert numerator % denominator == 0, '{} is not divisible by {}'.format(
        numerator, denominator)


def divide(numerator, denominator) :
    """Ensure that numerator is divisible by the denominator and return the division value."""
    ensure_divisibility(numerator, denominator)
    return numerator // denominator
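As a quick illustration of divide (the numbers below are hypothetical, not taken from the source):

# Hypothetical numbers: ffn_hidden_size = 4096 split across a tensor-parallel world size of 4.
print(divide(4096, 4))   # 1024 output columns per partition
# divide(4096, 3) would raise AssertionError: '4096 is not divisible by 3'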

3.2.2 Initializing weights

The following code implements the initialization weights.

def _initialize_affine_weight_gpu(weight, init_method,
                                  partition_dim, stride=1) :
    """Initialize affine weight for model parallel on GPU."""

    set_tensor_model_parallel_attributes(tensor=weight,
                                         is_parallel=True,
                                         dim=partition_dim,
                                         stride=stride)

    with get_cuda_rng_tracker().fork():
        init_method(weight)


def _initialize_affine_weight_cpu(weight, output_size, input_size,
                                  per_partition_size, partition_dim,
                                  init_method, stride=1,
                                  return_master_weight=False) :
    """Initialize affine weight for model parallel. Build the master weight on all processes and scatter the relevant chunk."""

    set_tensor_model_parallel_attributes(tensor=weight,
                                         is_parallel=True,
                                         dim=partition_dim,
                                         stride=stride)

    # Initialize master weight
    master_weight = torch.empty(output_size, input_size,
                                dtype=torch.float,
                                requires_grad=False)
    init_method(master_weight)
    args = get_args()
    master_weight = master_weight.to(dtype=args.params_dtype)

    # Split and copy
    per_partition_per_stride_size = divide(per_partition_size, stride)
    weight_list = torch.split(master_weight, per_partition_per_stride_size,
                              dim=partition_dim)
    rank = get_tensor_model_parallel_rank()
    world_size = get_tensor_model_parallel_world_size()
    my_weight_list = weight_list[rank::world_size]

    with torch.no_grad():
        torch.cat(my_weight_list, dim=partition_dim, out=weight)
    if return_master_weight:
        return master_weight
    return None

3.3 Sorting out the logic

For better analysis, we introduce the figure below (from Reference 1), which corresponds to the forward and backward propagation of the ColumnParallelLinear class. The f and g operations here are abstractions in the code: f processes the input, and g assembles the final output. This corresponds to the bolded passage in the paper:

Figure 3. Blocks of Transformer with Model Parallelism. f and g are conjugate. f is an identity operator in the forward pass and all reduce in the backward pass while g is an all reduce in the forward pass and identity in the backward pass.

GTC 2020: Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism

Let’s sort out the logic of the figure above.

3.3.1 Forward propagation

We’ll refine it step by step.

First, the overall semantics are: Y = XA + b.

Secondly, the logic of forward propagation is as follows:

  • Input: here A is sharded along its columns, and X is the complete input (every GPU has the same X).
  • Calculation: after the computation, the outputs Y_1, Y_2 are also sharded by column; each GPU holds only its own partition.
  • Output: Y_1, Y_2 must be combined to obtain the final output Y.

Next, we refine this in terms of operators:

  • Input: since each GPU needs the complete input X, X must be distributed to every GPU in the forward pass, which is an identity operation.
  • Calculation: after the computation, the outputs Y_1, Y_2 are sharded by column; each GPU holds only its own partition.
  • Output: Y_1, Y_2 need to be combined into the final output Y, so an all-gather operation is required to aggregate Y = [Y_1, Y_2].

We mark these logical points with red boxes in the figure: the input X is first processed by f, and the output Y is assembled by g.

3.3.2 Backward propagation

Let’s look at backward propagation. In the figure above, backward propagation runs from top to bottom: the gradient first passes through g and is then processed by f.

The logic of back propagation is as follows:

  • We start with the gradient of the overall output, ∂L/∂Y. Since Y = [Y_1, Y_2] is split by column across the GPUs, each GPU only needs the gradient of its own slice, ∂L/∂Y_i; the corresponding operation is a split: ∂L/∂Y_i = split(∂L/∂Y).
  • Each GPU then computes the gradient of X locally, so every GPU holds a gradient of X (with different contents).
  • Finally, the gradients of X on all GPUs must be summed to obtain the complete gradient, which requires an all-reduce: ∂L/∂X = ∂L/∂X|_1 + ∂L/∂X|_2.

So we indicate on the diagram the operators corresponding to backward propagation with blue rounded rectangles.
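A minimal single-process sketch (plain PyTorch, illustrative sizes) of this backward logic: the per-partition gradients of X sum to the full gradient, which is exactly what the backward all-reduce of f computes.

import torch

torch.manual_seed(0)
X = torch.randn(3, 4, requires_grad=True)
A = torch.randn(4, 6)
A1, A2 = A.chunk(2, dim=1)           # column split of the weight

# per-partition gradients of X (what each GPU would compute locally)
(X @ A1).sum().backward()
g1 = X.grad.clone(); X.grad = None
(X @ A2).sum().backward()
g2 = X.grad.clone(); X.grad = None

# gradient from the unpartitioned product
(X @ A).sum().backward()
print(torch.allclose(g1 + g2, X.grad, atol=1e-5))  # True: summation = all-reduce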

3.4 Code Implementation

Let’s look at the code.

3.4.1 ColumnParallelLinear

The forward code of ColumnParallelLinear mainly implements the forward operations of f and g and sets up their backward operations, as follows:

  • If asynchronous operation is configured, ColumnParallelLinearWithAsyncAllreduce implements the f operator: the forward identity, the matrix multiplication, and the construction of the backward all-reduce.
  • If the operation is synchronous, then:
    • copy_to_tensor_model_parallel_region performs the forward identity and sets up the backward all-reduce, i.e., the backward of f in the figure. The identity operation copies the complete input X onto every GPU, similar to X becoming [X, X, ..., X] through the forward operation of f.
    • A matrix multiplication is then performed between [X, X, ..., X] and the weight A via F.linear.
  • If gather_output is True, the forward pass performs an all-gather on Y_i; since the backward pass then needs to scatter the complete gradient back onto the corresponding GPUs, a split operation is also set up. In the MLP implementation this flag is set to False, so each GPU outputs its own 4h/p partition, which is passed directly to the next linear layer.
def forward(self, input_) :
    # If skip_bias_add is set, bias is not added here; it will be returned separately
    bias = self.bias if not self.skip_bias_add else None

    # Below is mainly the f operation in the figure
    if self.async_tensor_model_parallel_allreduce:
        # Create asynchronous all-reduce for backpropagation
        input_shape = input_.shape
        input_ = input_.view(input_shape[0] * input_shape[1],input_shape[2])
        # Maxtrix multiply with asynchronouse all-reduce execution
        output_parallel = ColumnParallelLinearWithAsyncAllreduce.apply(
                input_, self.weight, bias)
        output_parallel = output_parallel.view(
                input_shape[0], input_shape[1], output_parallel.shape[1])
    else:
        # Set up backprop all-reduce.
        # Establish back-propagation all-reduce, which is the BACKWARD of F in the figure
        input_parallel = copy_to_tensor_model_parallel_region(input_) 

        # Matrix multiply.
        output_parallel = F.linear(input_parallel, self.weight, bias) # matrix multiplication operation

    # Below is the g operation in the figure
    if self.gather_output: # Whether the aggregation operation is required
        # All-gather across the partitions.
        # aggregate the outputs: the forward of g in the figure
        output = gather_from_tensor_model_parallel_region(output_parallel)
    else:
        output = output_parallel
        
    output_bias = self.bias if self.skip_bias_add else None  # if skip_bias_add is set, return the bias separately
    return output, output_bias

3.4.2 f operation

The f operation performs the preliminary processing of the input; specifically:

  • Copy directly when propagating forward.
  • Perform all-reduce for back-propagation.

3.4.2.1 Synchronous operation

Here we mainly analyze copy_to_tensor_model_parallel_region, which does forward copy and builds backward all-reduce.

def copy_to_tensor_model_parallel_region(input_) :
    return _CopyToModelParallelRegion.apply(input_)

We still need to look at _CopyToModelParallelRegion. As can be seen, its forward simply passes the input through to the output, corresponding to the forward identity (copy).

class _CopyToModelParallelRegion(torch.autograd.Function) :
    """Pass the input to the model parallel region."""

    @staticmethod
    def symbolic(graph, input_) :
        return input_
    
    @staticmethod
    def forward(ctx, input_) :
        return input_ # Simply transfer input to output, which corresponds to forward copy identity

    @staticmethod
    def backward(ctx, grad_output) :
        return _reduce(grad_output) # in backpropagation, the gradients from the GPUs are combined by all-reduce

The corresponding backward propagation uses all-reduce: each GPU holds its own partial gradient of X, and all-reduce sums them into the complete gradient.

def _reduce(input_) :
    """All-reduce the input tensor across model parallel group."""

    # Bypass the function if we are using only 1 GPU.
    if get_tensor_model_parallel_world_size()==1:
        return input_

    # All-reduce.
    torch.distributed.all_reduce(input_, group=get_tensor_model_parallel_group())

    return input_
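For reference, here is a minimal runnable illustration of the torch.distributed.all_reduce call used above (single process, gloo backend, world size 1, so it only demonstrates the call shape; with more ranks the tensor would be summed in place across them):

import os
import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

t = torch.tensor([1.0, 2.0, 3.0])
dist.all_reduce(t)      # sums the tensor in place across all ranks (default op is SUM)
print(t)                # unchanged here since world_size == 1

dist.destroy_process_group()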
3.4.2.2 Asynchronous all-reduce

ColumnParallelLinearWithAsyncAllreduce also folds in the matrix multiplication that the synchronous path performs separately.

class ColumnParallelLinearWithAsyncAllreduce(torch.autograd.Function) :
    """ Column-parallel linear layer execution with asynchronous all-reduce execution in backprop. """
    @staticmethod
    def forward(ctx, input, weight, bias) :
        ctx.save_for_backward(input, weight)
        ctx.use_bias = bias is not None
        output = torch.matmul(input, weight.t())  # the matrix multiplication is performed here as well
        if bias is not None:
            output = output + bias
        return output

    @staticmethod
    def backward(ctx, grad_output) :
        input, weight = ctx.saved_tensors
        use_bias = ctx.use_bias
        grad_input = grad_output.matmul(weight)
        # Asyncronous all-reduce
        handle = torch.distributed.all_reduce( # Backpropagation operation
                grad_input, group=get_tensor_model_parallel_group(), async_op=True)
        # Delay the start of weight gradient computation shortly (3us) to have
        # all-reduce scheduled first and have GPU resources allocated
        _ = torch.empty(1, device=grad_output.device) + 1
        grad_weight = grad_output.t().matmul(input)
        grad_bias = grad_output.sum(dim=0) if use_bias else None
        handle.wait()
        return grad_input, grad_weight, grad_bias

3.4.3 g operation

The following corresponds to the g operation in the figure. The g operation produces the final output Y; its logic is:

  • In the forward pass, perform an all-gather.
  • In the backward pass, split the gradient and scatter the slices to the corresponding GPUs.

def gather_from_tensor_model_parallel_region(input_) :
    return _GatherFromModelParallelRegion.apply(input_)

The specific code is as follows:

class _GatherFromModelParallelRegion(torch.autograd.Function) :
    """Gather the input from model parallel region and concatinate."""

    @staticmethod
    def symbolic(graph, input_) :
        return _gather(input_)
    
    @staticmethod
    def forward(ctx, input_) :
        return _gather(input_)

    @staticmethod
    def backward(ctx, grad_output) :
        return _split(grad_output)

3.4.4 Basic functions

Let’s now look at some of the basic functions used above.

3.4.4.1 gather

_gather collects the tensors from all ranks and concatenates them along the last dimension.

def _gather(input_) :
    """Gather tensors and concatinate along the last dimension."""

    world_size = get_tensor_model_parallel_world_size()
    # Bypass the function if we are using only 1 GPU.
    if world_size==1:
        return input_

    # Size and dimension.
    last_dim = input_.dim() - 1
    rank = get_tensor_model_parallel_rank()  # get the rank of this worker in the tensor-parallel group

    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
    tensor_list[rank] = input_
    # do an all-gather within the tensor-parallel group
    torch.distributed.all_gather(tensor_list, input_, group=get_tensor_model_parallel_group())

    # Note: torch.cat already creates a contiguous tensor.
    output = torch.cat(tensor_list, dim=last_dim).contiguous()

    return output
3.4.4.2 split

_split completes the tensor split operation.

def split_tensor_along_last_dim(tensor, num_partitions,
                                contiguous_split_chunks=False) :
    """Split a tensor along its last dimension. Arguments: tensor: input tensor. num_partitions: number of partitions to split the tensor contiguous_split_chunks: If True, make each chunk contiguous in memory. """
    # Get the size and dimension.
    last_dim = tensor.dim() - 1
    last_dim_size = divide(tensor.size()[last_dim], num_partitions)  # get the size of each shard
    # Split.
    tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)  # split the tensor
    # Note: torch.split does not create contiguous tensors by default.
    if contiguous_split_chunks:
        return tuple(chunk.contiguous() for chunk in tensor_list)

    return tensor_list

def _split(input_) :
    """Split the tensor along its last dimension and keep the corresponding slice."""

    world_size = get_tensor_model_parallel_world_size()  # get the world size of the tensor-parallel group
    # Bypass the function if we are using only 1 GPU.
    if world_size==1:
        return input_

    # Split along last dimension.
    input_list = split_tensor_along_last_dim(input_, world_size)

    # Note: torch.split does not create contiguous tensors by default.
    rank = get_tensor_model_parallel_rank()  # get this worker's rank
    output = input_list[rank].contiguous()  # keep only the slice for this rank

    return output
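As a single-process illustration (plain torch.split/torch.cat, no process group) of why _split and _gather undo each other:

import torch

x = torch.arange(12.).reshape(2, 6)
world_size = 3
chunks = torch.split(x, x.size(-1) // world_size, dim=-1)   # three [2, 2] slices, one per "rank"
print(torch.equal(torch.cat(chunks, dim=-1), x))            # True: gather reverses split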

get_tensor_model_parallel_rank is used to get the rank of this process within the tensor-parallel group.

def get_tensor_model_parallel_rank() :
    """Return my rank for the tensor model parallel group."""
    global _MPU_TENSOR_MODEL_PARALLEL_RANK
    if _MPU_TENSOR_MODEL_PARALLEL_RANK is not None:
        return _MPU_TENSOR_MODEL_PARALLEL_RANK
    return torch.distributed.get_rank(group=get_tensor_model_parallel_group())

0x04 RowParallelLinear

RowParallelLinear splits by row, i.e., a horizontal cut; note that it is the weight A that is split by row. Let Y = XA, where X is the input, A is the weight and Y is the output. Row splitting partitions the first dimension of A, and X is split along its columns accordingly, so that the last dimension of each X_i matches the first dimension of the corresponding A_i.


XA = \begin{bmatrix} X_1, X_2 \end{bmatrix} \begin{bmatrix} A_1 \\ A_2 \end{bmatrix} = X_1 A_1 + X_2 A_2 = Y_1 + Y_2 = Y
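A small numerical sketch of this identity (plain PyTorch, illustrative sizes); the final addition stands in for the forward all-reduce performed by g:

import torch

torch.manual_seed(0)
X = torch.randn(4, 6)
A = torch.randn(6, 8)
X1, X2 = X.chunk(2, dim=1)    # X split along its columns
A1, A2 = A.chunk(2, dim=0)    # A split along its rows
Y = X1 @ A1 + X2 @ A2         # the sum is what the all-reduce computes across GPUs
print(torch.allclose(Y, X @ A, atol=1e-5))  # True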

Details are as follows:

4.1 Definition

The definition itself contains only the docstring, which shows how the sharding is done.

class RowParallelLinear(torch.nn.Module) :
    """Linear layer with row parallelism. The linear layer is defined as Y = XA + b. A is parallelized along its first dimension and X along its second dimension as: - - | A_1 | | . | A = | . | X = [X_1, ..., X_p] | . | | A_p | - - Arguments: input_size: first dimension of matrix A. output_size: second dimension of matrix A. bias: If true, add bias. Note that bias is not parallelized. input_is_parallel: If true, we assume that the input is already split across the GPUs and we do not split again. init_method: method to initialize weights. Note that bias is always set to zero. stride: For the strided linear layers. keep_master_weight_for_test: This was added for testing and should be set to False. It returns the master weights used for initialization. skip_bias_add: This was added to enable performance optimization where bias can be fused with other elementwise operations. We skip adding bias but instead return it. """

4.2 Initialization

Similar to column sharding, initialization is about getting the size of each weight partition and then sharding the weights accordingly.

def __init__(self, input_size, output_size, bias=True,
             input_is_parallel=False,
             init_method=init.xavier_normal_, stride=1,
             keep_master_weight_for_test=False,
             skip_bias_add=False) :
    super(RowParallelLinear, self).__init__()

    # Keep input parameters
    self.input_size = input_size
    self.output_size = output_size
    self.input_is_parallel = input_is_parallel
    # Divide the weight matrix along the last dimension.
    world_size = get_tensor_model_parallel_world_size()
    self.input_size_per_partition = divide(input_size, world_size)  # get the size of each weight partition
    self.skip_bias_add = skip_bias_add

    # Parameters.
    # Note: torch.nn.functional.linear performs XA^T + b and as a result
    # we allocate the transpose.
    # Initialize weight.
    args = get_args()
    if args.use_cpu_initialization:
        self.weight = Parameter(torch.empty(self.output_size,
                                            self.input_size_per_partition,
                                            dtype=args.params_dtype))
        # partition the weights
        self.master_weight = _initialize_affine_weight_cpu(
            self.weight, self.output_size, self.input_size,
            self.input_size_per_partition, 1, init_method,
            stride=stride, return_master_weight=keep_master_weight_for_test)
    else:
        self.weight = Parameter(torch.empty(
            self.output_size, self.input_size_per_partition,
            device=torch.cuda.current_device(), dtype=args.params_dtype))
        # partition the weights
        _initialize_affine_weight_gpu(self.weight, init_method,
                                      partition_dim=1, stride=stride)
    if bias:
        if args.use_cpu_initialization:
            self.bias = Parameter(torch.empty(self.output_size,
                                              dtype=args.params_dtype))
        else:
            self.bias = Parameter(torch.empty(
                self.output_size, device=torch.cuda.current_device(),
                dtype=args.params_dtype))
        # Always initialize bias to zero.
        with torch.no_grad():
            self.bias.zero_()
    else:
        self.register_parameter('bias', None)

4.3 Sorting out the logic

For better analysis, we introduce the figure below (from Reference 1), which corresponds to the forward and backward propagation of the RowParallelLinear class. The f and g operations here are abstractions in the code: f processes the input, and g assembles the final output.

Let’s sort out the logic of the figure above.

4.3.1 Forward propagation

We’ll refine it step by step.

First, the overall semantics are: Y = XA + b.

Secondly, the logic of forward propagation is as follows:

  • Input: here A is split along its rows. Since the dimensions of A have changed, X must change accordingly: X is split by column, so that each slice of X can be multiplied with the corresponding slice of A. If the input is already split (input_is_parallel is True), there is no need to split it again.
  • Calculation: the computation is Y_1 = X_1 A_1 and Y_2 = X_2 A_2. After the computation, Y_1 and Y_2 each already have the final shape of Y; each GPU holds only its own partial result.
  • Output: Y_1 and Y_2 must be combined to obtain the final output Y. Since both already have the shape of Y, a simple matrix addition suffices.

Next, we refine this in terms of operators:

  • Input: X needs to be split by column; this is a split operation that yields [X_1, X_2], and the two partitions are placed on the two GPUs respectively.
  • Calculation: after the computation, each GPU holds only its own partial result.
  • Output: Y_1 and Y_2 need to be combined into the final output Y. This requires adding Y_1 and Y_2 across the two GPUs (which involves synchronization between them), i.e., an all-reduce operation.

We mark these logical points with red boxes in the figure: the input X is first processed by f, and the output Y is assembled by g.

4.3.2 Backward propagation

Let’s look at backward propagation. In the figure above, backward propagation runs from top to bottom: the gradient first passes through g and is then processed by f.

The logic of back propagation is as follows:

  • We start with the gradient ∂L/∂Y. Since Y_1 and Y_2 have the same shape as Y, each GPU can use ∂L/∂Y directly: ∂L/∂Y_i = ∂L/∂Y (identity). To explain: in forward propagation the partial results X_i A_i are summed by all-reduce, so in backward propagation the gradient can simply be copied to every GPU.
  • Each GPU then computes the gradient of its X_i locally, so every GPU holds a gradient of its own slice of X (with different contents).
  • Finally, the gradients of X on all GPUs need to be gathered to obtain the complete gradient. This is the reverse of the forward split: the slices are concatenated along the last dimension, i.e., an all-gather operation.

So we indicate on the diagram the operators corresponding to backward propagation with blue rounded rectangles.
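A small single-process sketch of this backward logic (plain PyTorch, illustrative sizes): the gradient of each X_i is ∂L/∂Y multiplied by the corresponding A_i^T, and concatenating the slices along the last dimension (the all-gather) recovers the full gradient of X.

import torch

torch.manual_seed(0)
dY = torch.randn(3, 8)        # gradient of the output, identical on every GPU
A = torch.randn(6, 8)
A1, A2 = A.chunk(2, dim=0)    # row split of the weight

dX = torch.cat([dY @ A1.t(), dY @ A2.t()], dim=-1)   # per-GPU slices, then all-gather
print(torch.allclose(dX, dY @ A.t(), atol=1e-5))     # True: matches the full gradient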

4.4 Code Implementation

Let’s look at how the code works.

4.4.1 RowParallelLinear

The forward code of RowParallelLinear mainly implements the forward operations of f and g and sets up their backward operations, as follows:

def forward(self, input_) :
    # the input tensor has already been divided across the GPUs; the output tensor is the complete result after all-reduce
    # Set up backprop all-reduce.
    if self.input_is_parallel:  # is already the input to split
        # the Transformer MLP takes this path: the input is already split, so no scatter is needed
        input_parallel = input_
    else:
        # otherwise split in the forward pass and all-gather in the backward pass
        input_parallel = scatter_to_tensor_model_parallel_region(input_)
        
    # Matrix multiply.
    # multiply X_i and A_i
    output_parallel = F.linear(input_parallel, self.weight)
    
    # All-reduce across all the partitions.
    # Carry out the forward all-reduce operation, so that the latest complete results are on each GPU, and the backward identity operation is built at the same time.
    output_ = reduce_from_tensor_model_parallel_region(output_parallel)
    if not self.skip_bias_add:
        # add the bias
        output = output_ + self.bias if self.bias is not None else output_
        output_bias = None
    else:
        output = output_
        output_bias = self.bias
    return output, output_bias

4.4.2 f operation

scatter_to_tensor_model_parallel_region corresponds to the f operation, which does the following:

  • Splits the input in the forward pass, while setting up the backward all-gather.
  • The backward pass performs an all-gather on the gradient.

The code is:

def scatter_to_tensor_model_parallel_region(input_) :
    return _ScatterToModelParallelRegion.apply(input_)

_ScatterToModelParallelRegion does the actual work; the _split and _gather operations it uses were introduced earlier.

class _ScatterToModelParallelRegion(torch.autograd.Function) :
    """Split the input and keep only the corresponding chuck to the rank."""

    @staticmethod
    def symbolic(graph, input_) :
        return _split(input_)

    @staticmethod
    def forward(ctx, input_) :
        return _split(input_)

    @staticmethod
    def backward(ctx, grad_output) :
        return _gather(grad_output)

4.4.3 g operation

reduce_from_tensor_model_parallel_region corresponds to the g operation, which works as follows:

  • The forward pass performs an all-reduce to obtain the final output.

  • The backward pass simply copies the gradient (identity).

The code is:

def reduce_from_tensor_model_parallel_region(input_) :
    return _ReduceFromModelParallelRegion.apply(input_)

The specific implementation is as follows:

class _ReduceFromModelParallelRegion(torch.autograd.Function) :
    """All-reduce the input from the model parallel region."""

    @staticmethod
    def symbolic(graph, input_) :
        return _reduce(input_)
    
    @staticmethod
    def forward(ctx, input_) :
        return _reduce(input_) # Introduced earlier

    @staticmethod
    def backward(ctx, grad_output) :
        return grad_output  # identity operation: the gradient is passed through unchanged to each GPU

0x05 Embedding

Let’s look at the embedding next. To balance memory usage, the embedding table is sharded along the vocab dimension, and the partitions are placed on multiple GPUs, so each card holds a part of the embedding table.

class VocabParallelEmbedding(torch.nn.Module) :
    """Embedding parallelized in the vocabulary dimension. This is mainly adapted from torch.nn.Embedding and all the default values are kept. Arguments: num_embeddings: vocabulary size. embedding_dim: size of hidden state. init_method: method to initialize weights. """

    def __init__(self, num_embeddings, embedding_dim, init_method=init.xavier_normal_) :
        super(VocabParallelEmbedding, self).__init__()
        # Keep the input dimensions.
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        # Set the detauls for compatibility.
        self.padding_idx = None
        self.max_norm = None
        self.norm_type = 2.
        self.scale_grad_by_freq = False
        self.sparse = False
        self._weight = None
        self.tensor_model_parallel_size = get_tensor_model_parallel_world_size()
        # Divide the weight matrix along the vocaburaly dimension.
        # get the start and end positions of this partition in the vocabulary
        self.vocab_start_index, self.vocab_end_index = \
            VocabUtility.vocab_range_from_global_vocab_size(
                self.num_embeddings, get_tensor_model_parallel_rank(),
                self.tensor_model_parallel_size)
        # number of embeddings held by this partition
        self.num_embeddings_per_partition = self.vocab_end_index - \
            self.vocab_start_index

        # Allocate weights and initialize.
        args = get_args()
        if args.use_cpu_initialization:
            self.weight = Parameter(torch.empty(
                self.num_embeddings_per_partition, self.embedding_dim,
                dtype=args.params_dtype))
            _initialize_affine_weight_cpu(  # partition the weights
                self.weight, self.num_embeddings, self.embedding_dim,
                self.num_embeddings_per_partition, 0, init_method)
        else:
            self.weight = Parameter(torch.empty(
                self.num_embeddings_per_partition, self.embedding_dim,
                device=torch.cuda.current_device(), dtype=args.params_dtype))
            _initialize_affine_weight_gpu(self.weight, init_method,  # partition the weights
                                          partition_dim=0, stride=1)

Since each GPU only holds part of the overall embedding table, some input tokens may not be found in a given worker’s local partition. Therefore an all-reduce is performed on the embedding output so that every GPU ends up with the complete embedding.

def forward(self, input_) :
        if self.tensor_model_parallel_size > 1:
            # Build the mask.
            # input_mask marks the tokens that are not in this worker's local partition; their output will be zeroed
            input_mask = (input_ < self.vocab_start_index) | \
                         (input_ >= self.vocab_end_index)
            # Mask the input.
            masked_input = input_.clone() - self.vocab_start_index
            masked_input[input_mask] = 0
        else:
            masked_input = input_

        # Get the embeddings.
        output_parallel = F.embedding(masked_input, self.weight,
                                      self.padding_idx, self.max_norm,
                                      self.norm_type, self.scale_grad_by_freq,
                                      self.sparse)
        # Mask the output embedding.
        if self.tensor_model_parallel_size > 1:
            output_parallel[input_mask, :] = 0.0
        # Reduce across all the model parallel GPUs.
        output = reduce_from_tensor_model_parallel_region(output_parallel)
        return output
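To make the masking logic concrete, here is a tiny standalone sketch with made-up numbers (a vocabulary of 8 split over 2 ranks, where this rank owns indices [4, 8)):

import torch

vocab_start_index, vocab_end_index = 4, 8             # hypothetical range owned by this rank
input_ = torch.tensor([1, 5, 7, 3])

input_mask = (input_ < vocab_start_index) | (input_ >= vocab_end_index)
masked_input = input_.clone() - vocab_start_index
masked_input[input_mask] = 0                           # out-of-range tokens point at local row 0

print(masked_input)   # tensor([0, 1, 3, 0])
print(input_mask)     # tensor([ True, False, False,  True])
# rows where input_mask is True are zeroed after the embedding lookup,
# and the cross-rank all-reduce fills in the values from the rank that owns them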

0x06 Summary

6.1 Parallel MLP

Let’s summarize the parallel implementation of MLP as shown in the figure below, where the logic is as follows:

  • The gray in the middle is the concept map from the paper.
  • Relating it to the code, the MLP is implemented as a ColumnParallelLinear followed by a RowParallelLinear, so the concept diagram is expanded into the two boxes on the left.
  • ColumnParallelLinear partitions the weights by column; RowParallelLinear partitions the weights by row.
  • The outputs Y_1, Y_2 of ColumnParallelLinear are fed directly into RowParallelLinear without an all-gather, so RowParallelLinear receives them as its X_1, X_2; that is, the RowParallelLinear has no f operation.
  • The f in the concept diagram is the ColumnParallelLinear f, and the g is the RowParallelLinear g. The specific logic is shown in the figure.

6.2 Conjugate Function

Conjugate functions are mentioned in this paper.

f and g are conjugate. f is an identity operator in the forward pass and all reduce in the backward pass while g is an all reduce in the forward pass and identity in the backward pass.

We used these in the code above as well; sorting them out, they form two pairs of mutually conjugate functions.

  • copy_to_tensor_model_parallel_region: the forward operation is a copy (identity), the backward operation is an all-reduce.

  • reduce_from_tensor_model_parallel_region: the forward operation is an all-reduce, the backward operation is a copy (identity).

These are in fact the f and g operations in the MLP, and they are conjugate functions of each other.

Similarly, gather_from_tensor_model_parallel_region performs an all-gather in the forward pass and a split in the backward pass; it is the conjugate of scatter_to_tensor_model_parallel_region.

The code for these functions is as follows:

def copy_to_tensor_model_parallel_region(input_) :
    return _CopyToModelParallelRegion.apply(input_)


def reduce_from_tensor_model_parallel_region(input_) :
    return _ReduceFromModelParallelRegion.apply(input_)


def scatter_to_tensor_model_parallel_region(input_) :
    return _ScatterToModelParallelRegion.apply(input_)


def gather_from_tensor_model_parallel_region(input_) :
    return _GatherFromModelParallelRegion.apply(input_)

Now that we have completed the analysis of the parallel implementation of the model, the next article will look at how to set up various parallel configurations in the source code.

0xEE Personal information

★★★★ Thoughts on life and technology ★★★★★

Wechat official account: Rosie’s Thoughts

0xFF Reference

developer.nvidia.com/gtc/2020/sl…

Megatron Papers and Code Analysis (2)

Megatron Papers and Code Analysis (1)

Megatron-lm megatron-LM

Megatron-lm megatron-LM

Megatron learning summary

GTC 2020: Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism

Tensor model parallelism in large-scale training Transformer