0x00 Overview

PyTorch's ZeroRedundancyOptimizer is a class of algorithms designed to address the tradeoff between data-parallel and model-parallel training. The Zero Redundancy Optimizer idea comes from Microsoft's ZeRO and is implemented on top of Fairscale's OSS.

Fairscale implements ZeRO's three-stage algorithm. Fairscale is an open-source project of Facebook AI Research (FAIR); my personal understanding is that it serves as a testing ground for Facebook's large-scale distributed deep learning training. Once a module matures, it gets merged into PyTorch.

OSS is Fairscale's ZeRO-1 implementation, which shards the optimizer state. PyTorch's ZeroRedundancyOptimizer is implemented based on Fairscale's OSS.

Note: This article is based on PyTorch 1.9.0.

0x01 History

1.1 GitHub description

ZeroRedundancyOptimizer was introduced in github.com/pytorch/pyt… Let's look at its description.

ZeroRedundancyOptimizer: an implementation of a standalone sharded optimizer wrapper #46750

Implement the first stage of ZeRO, sharding of the optimizer state, as described in this blog post and this paper. This implementation is completely independent from the DeepSpeed framework, and aims at providing ZeRO-compliant building blocks within the PyTorch scheme of things.

This works by:

  • acting as a wrapper to a pytorch optimizer. ZeROptimizer does not optimize anything by itself, it only shards optimizers for distributed jobs
  • each rank distributes parameters according to a given partitioning scheme (could be updated), and owns the update of a given shard only
  • the .step() is called on each rank as expected, the fact that the optimizer actually works on a shard of the model is not visible from the outside
  • when the update is completed, each rank broadcasts the updated model shard to all the other ranks

This can be used with DDP, although some communications are wasted in that case (gradients are all-reduced to all ranks). This implementation was initially developed in Fairscale, and can also be used with an optimized DDP which only reduces to the relevant ranks. More context on ZeRO and PyTorch can be found in this RFC

The API with respect to loading and saving the state is a known pain point and should probably be discussed an updated. Other possible follow ups include integrating more closely to a modularized DDP, making the checkpoints partition-agnostic, exposing a gradient clipping option and making sure that mixed precision states are properly handled.

original authors include @msbaines, @min-xu-ai and myself(blefaudeux )

1.2 Analysis

From the above, we can learn the following:

  • The Zero Redundancy Optimizer idea is derived from Microsoft's ZeRO.
  • Fairscale implements ZeRO's three-stage algorithm. Fairscale is an open-source project of Facebook AI Research (FAIR); it can be seen as a testing ground for Facebook's large-scale distributed deep learning training, and modules that mature are merged into PyTorch.
  • OSS is Fairscale's ZeRO-1 implementation; it implements optimizer state sharding.
  • PyTorch's ZeroRedundancyOptimizer is implemented based on Fairscale's OSS.

We need to look at it in detail.

0x02 Background

2.1 ZeRO

ZeRO (Zero Redundancy Optimizer) is part of DeepSpeed, Microsoft's open-source framework for optimizing large-scale training. ZeRO is a memory-optimization approach for deep learning models that seeks a middle ground between model parallelism and data parallelism in order to maximize model scalability.

ZeRO's optimizations cover many aspects of deep learning memory usage, including activation memory, fragmented memory, and model state memory.

  • Model State Memory: the state of a deep learning model consists of three basic parts: optimizer states, gradients, and parameters.
  • Activation Memory: after the model state memory is optimized, activations were found to be another bottleneck. Activations are computed during forward propagation and are kept to support backward propagation.
  • Fragmented Memory: the inefficiency of deep learning training is sometimes caused by fragmented memory. Tensors have different lifetimes, and the interleaving of tensors with different lifetimes fragments the memory. Because of these fragments, a memory allocation can fail for lack of contiguous memory even when enough total memory is available. ZeRO proactively manages memory based on the lifetime of tensors to prevent fragmentation.
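
As a rough worked example of why model states dominate (the numbers follow the mixed-precision Adam accounting in the ZeRO paper, not this article): a model with Ψ parameters needs about 2Ψ bytes for FP16 parameters, 2Ψ bytes for FP16 gradients, and 12Ψ bytes for the optimizer states (the FP32 parameter copy, momentum, and variance), i.e. roughly 16Ψ bytes in total.

# Bytes per model state for mixed-precision Adam, per the ZeRO paper's accounting.
psi = 1.5e9                          # e.g. a 1.5B-parameter model (GPT-2 scale)
fp16_params = 2 * psi                # 2 bytes per parameter
fp16_grads = 2 * psi                 # 2 bytes per gradient
optim_states = (4 + 4 + 4) * psi     # fp32 param copy + momentum + variance
total_gb = (fp16_params + fp16_grads + optim_states) / 1e9
print(total_gb)                      # ~24 GB of training state for a ~3 GB fp16 model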

For example, these optimizations are illustrated in a figure from www.microsoft.com/en-us/resea…

2.2 ZeRO implementation of Fairscale

Let’s take a look at Fairscale’s guide.

The guide is essentially an overview of distributed and large-scale machine learning techniques. From it we can see that Fairscale implements three different algorithms based on ZeRO (https://arxiv.org/pdf/1910.02054.pdf), corresponding to ZeRO's three stages:

  • Optimizer State Sharding (OSS) implements optimizer sharding, reducing memory usage by partitioning the optimizer state.
  • Sharded Data Parallel (SDP) is responsible for Optimizer + Gradient State Sharding.
  • Fully Sharded Data Parallel (FSDP) implements Optimizer + Gradient + Horizontal Model Sharding.

2.3 Optimizer State Sharding (OSS)

Since OSS is the origin of ZeroRedundancyOptimizer, let's look at its design. OSS implements optimizer-memory-related optimizations. An optimizer like Adam typically maintains momentum and variance buffers. Even when training with FP16 parameters and gradients, the optimizer needs to keep FP32 copies of them. Therefore, when every rank updates the full model, a significant portion of memory is occupied by redundant copies of the optimizer state. To remove this redundancy, optimizer state sharding divides the optimization step across ranks, so that each rank is only responsible for updating its own shard of the model. This in turn ensures that the optimizer state on each rank is much smaller and contains no information that is redundant across ranks.
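
As a back-of-the-envelope sketch of this saving (my own illustrative numbers, not from the Fairscale docs): with plain FP32 Adam, the extra optimizer state is roughly twice the parameter memory on every rank, while OSS divides that part by the number of ranks.

# Estimated Adam optimizer-state memory per rank, with and without sharding.
n_params = 1_000_000_000                      # a hypothetical 1B-parameter model
param_gb = n_params * 4 / 1e9                 # fp32 parameters: ~4 GB
adam_state_gb = 2 * param_gb                  # exp_avg + exp_avg_sq: ~8 GB per replica

for world_size in (1, 4, 8):
    per_rank = adam_state_gb / world_size     # optimizer state sharded across ranks
    print(f"world_size={world_size}: optimizer state per rank ~{per_rank:.1f} GB")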

2.3.1 Training process

The OSS training process can be modified from the DDP execution process as follows:

  1. The wrapped optimizer shards the optimizer state using a greedy algorithm based on parameter sizes rather than usage order. This ensures that each rank holds a nearly equal amount of optimizer memory.

  2. The training process is similar to PyTorch’s distributed Data parallelism (DDP) process. Forward propagation is done on each rank, followed by backward propagation. Allreduce is used to synchronize gradients during backward propagation.

  3. Each rank updates only the optimizer state parameters for which it is responsible, and discards the remaining optimizer parameters.

  4. After the update, broadcast or AllGather is executed to ensure that all ranks receive the latest parameter values.

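To make the four steps above concrete, here is a minimal conceptual sketch of one OSS-style training step (my own illustration, not Fairscale's source). It assumes the process group is already initialized, that local_optimizer was built over only this rank's shard (step 1), and that owner_of is a hypothetical mapping from each parameter to the rank that owns it:

import torch.distributed as dist

def oss_train_step(model, local_optimizer, owner_of, loss_fn, inputs, labels, world_size):
    # Step 2: usual forward/backward; gradients are synchronized with all-reduce.
    loss = loss_fn(model(inputs), labels)
    loss.backward()
    for p in model.parameters():
        if p.grad is not None:
            dist.all_reduce(p.grad)        # sum gradients across ranks
            p.grad /= world_size           # average them

    # Step 3: the local optimizer only holds this rank's shard, so step()
    # updates that shard and keeps optimizer state only for it.
    local_optimizer.step()

    # Step 4: every rank broadcasts the parameters it owns so all replicas match.
    handles = [dist.broadcast(p.data, src=owner_of[p], async_op=True)
               for p in model.parameters()]
    for h in handles:
        h.wait()

    model.zero_grad()
    return loss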

2.3.2 Best practices

Some best practices are as follows:

  • OSS exposes a broadcast_fp16 flag that you should probably use in multi-node jobs; it is usually unnecessary in single-node experiments.
  • If your model is extremely uneven in parameter sizes (for example, one huge tensor dominates), this method will not help much; tensor-sharding options such as fairscale.nn.FullyShardedDataParallel would be preferable.
  • OSS remains compatible with most DDP functions.
  • OSS should be a temporary solution in the DDP environment.

2.3.3 Performance Description

Here are some performance notes.

  • OSS should always be faster than Vanilla PyTorch on a single node, and memory savings will vary depending on the optimizer used.

  • OSS is useful when you use an optimizer with additional state, such as Adam.

  • If you are using SGD or any optimizer with limited memory footprint, you may see slowdowns when using multiple nodes due to the extra communication in Step 4 of the process above. In the AllReduce process at step 2, there is also some wasted memory for storing gradients, which is then discarded.

  • OSS can also be faster or slower than Vanilla PyTorch when using multiple nodes, depending on the optimizer used and optional flags (broadcast_fp16, gradient compression, gradient accumulation mentioned above)

  • If you can use a larger batch size, it is better to take a larger batch size and reduce the number of ranks involved, or use gradient accumulation, as this can reduce communication costs.

Let’s move on to ZeroRedundancyOptimizer.

0x03 How to Use it

Starting from pytorch.org/tutorials/r… , let's see how to use ZeroRedundancyOptimizer.

3.1 Ideas behind

The idea of ZeroRedundancyOptimizer comes from DeepSpeed/ZeRO and Marian, which shard the optimizer state across distributed data-parallel processes to reduce the per-process memory footprint. ZeRO's optimization strategy is mainly to reduce GPU memory usage by partitioning the model states, which include optimizer states, gradients, and model parameters.

ZeroRedundancyOptimizer implements sharding of the optimizer states, i.e., the parameters and local state that the optimizer needs to run. For example, SGD with momentum needs momentum buffers as large as the model parameters, and the Adam optimizer stores exp_avg and exp_avg_sq for each parameter, so Adam's state consumes at least twice the model size. Therefore, when the model is large, the optimizer state is not a small GPU memory overhead.
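
This can be verified directly in a single process (a small sketch of my own, not from the tutorial): after one step, torch.optim.Adam keeps an exp_avg and an exp_avg_sq tensor per parameter, so its state holds roughly twice as many elements as the model itself.

import torch

model = torch.nn.Linear(1000, 1000)
opt = torch.optim.Adam(model.parameters(), lr=0.01)
model(torch.randn(4, 1000)).sum().backward()
opt.step()

param_numel = sum(p.numel() for p in model.parameters())
state_numel = sum(t.numel() for state in opt.state.values()
                  for t in state.values() if torch.is_tensor(t))
print(param_numel, state_numel)   # the state is ~2x the parameter count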

In Getting Started With DistributedDataParallel, we showed how to use DistributedDataParallel (DDP) to train models. In DDP:

  • Each worker process (rank, node, or device) keeps a dedicated copy of the optimizer.
  • Because DDP has synchronized gradients with all-reduce in backpropagation, all copies of the optimizer will run on the same parameters and gradients in each iteration.
  • These optimizers update model parameters with gradients after all-reduce, which is why DDP can keep all copies of the model (rank) in the same parameter state.

Based on this observation, we can reduce optimizer memory footprint by splitting optimizer state between DDP processes. To be more specific,

  • The optimizer state is divided among the workers: the optimizer instance on each worker only keeps the portion (1/world_size) of the optimizer state corresponding to its shard of the model parameters, instead of creating state for all parameters.
  • The optimizer's step() function is only responsible for updating the parameters in its own shard. When a worker finishes its parameter update, it broadcasts the updated parameters to all other peer DDP processes, so that all model replicas remain in the same state.

3.2 How to Use it

ZeroRedundancyOptimizer can be used in combination with torch.nn.parallel.DistributedDataParallel to reduce the peak memory consumption of each rank. The following code demonstrates how to use ZeroRedundancyOptimizer. Most of the code is similar to the simple DDP example given in the Distributed Data Parallel notes. The main difference is the if-else clause in the example function, which wraps the optimizer construction and switches between ZeroRedundancyOptimizer and Adam. We simply wrap the regular optimizer with ZeroRedundancyOptimizer.

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.nn.parallel import DistributedDataParallel as DDP

def print_peak_memory(prefix, device):
    if device == 0:
        print(f"{prefix}: {torch.cuda.max_memory_allocated(device) // 1e6}MB ")

def example(rank, world_size, use_zero):
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    # create default process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # create local model
    model = nn.Sequential(*[nn.Linear(2000, 2000).to(rank) for _ in range(20)])
    print_peak_memory("Max memory allocated after creating local model", rank)

    # construct DDP model
    ddp_model = DDP(model, device_ids=[rank])
    print_peak_memory("Max memory allocated after creating DDP", rank)

    # define loss function and optimizer
    loss_fn = nn.MSELoss()
    if use_zero:
        optimizer = ZeroRedundancyOptimizer( # ZeroRedundancyOptimizer is used here
            ddp_model.parameters(),
            optimizer_class=torch.optim.Adam, # Wrap Adam
            lr=0.01
        )
    else:
        optimizer = torch.optim.Adam(ddp_model.parameters(), lr=0.01)

    # forward pass
    outputs = ddp_model(torch.randn(20, 2000).to(rank))
    labels = torch.randn(20, 2000).to(rank)
    # backward pass
    loss_fn(outputs, labels).backward()

    # update parameters
    print_peak_memory("Max memory allocated before optimizer step()", rank)
    optimizer.step()
    print_peak_memory("Max memory allocated after optimizer step()", rank)

    print(f"params sum is: {sum(model.parameters()).sum()}")



def main():
    world_size = 2
    print("=== Using ZeroRedundancyOptimizer ===")
    mp.spawn(example,
        args=(world_size, True),
        nprocs=world_size,
        join=True)

    print("=== Not Using ZeroRedundancyOptimizer ===")
    mp.spawn(example,
        args=(world_size, False),
        nprocs=world_size,
        join=True)

if __name__=="__main__":
    main()

The output is shown below.

With or without ZeroRedundancyOptimizer, the model parameters end up with the same values after one iteration, so the printed parameter sums are identical. When ZeroRedundancyOptimizer wraps Adam, the peak memory consumption of optimizer step() is half that of plain Adam. This is as expected, because we shard the Adam optimizer state across two processes.

=== Using ZeroRedundancyOptimizer ===
Max memory allocated after creating local model: 335.0MB
Max memory allocated after creating DDP: 656.0MB
Max memory allocated before optimizer step(): 992.0MB
Max memory allocated after optimizer step(): 1361.0MB
params sum is: -3453.6123046875
params sum is: -3453.6123046875
=== Not Using ZeroRedundancyOptimizer ===
Max memory allocated after creating local model: 335.0MB
Max memory allocated after creating DDP: 656.0MB
Max memory allocated before optimizer step(): 992.0MB
Max memory allocated after optimizer step(): 1697.0MB
params sum is: -3453.6123046875
params sum is: -3453.6123046875
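
These numbers can be roughly cross-checked (my own estimate, not part of the tutorial): the model is 20 Linear(2000, 2000) layers, and Adam keeps two FP32 buffers per parameter.

# Rough estimate of the Adam state created by optimizer.step() above.
n_params = 20 * (2000 * 2000 + 2000)   # 20 x nn.Linear(2000, 2000)
param_mb = n_params * 4 / 1e6          # fp32 parameters: ~320 MB
adam_state_mb = 2 * param_mb           # exp_avg + exp_avg_sq: ~640 MB
print(param_mb, adam_state_mb, adam_state_mb / 2)
# Without ZeRO, step() raises peak memory by ~705 MB (close to the full ~640 MB state);
# with ZeRO over 2 ranks, it raises it by ~369 MB (close to half, ~320 MB).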

3.3 summary

The ZeroRedundancyOptimizer class wraps an arbitrary optim.Optimizer and shards its state across ranks. The local optimizer instance on each rank is only responsible for updating roughly 1/world_size of the parameters, and therefore only needs to keep 1/world_size of the optimizer state.

So the focus of our analysis is:

  • How are the optimizer's parameters partitioned?
  • How does each rank know which parameters it is responsible for?

0x04 Initialization

Let's first look at how the object is built in __init__, which does three main things:

  • Initialize the base class.
  • Initialize various member variables.
  • Call _update_trainable for internal synchronization and buffer construction; internally, _optim_constructor is called to build the wrapped local optimizer.
    def __init__(
        self,
        params,
        optimizer_class: Type[Optimizer], # is the wrapped native optimizer type
        group: Optional[Any] = None,
        parameters_as_bucket_view: bool = False,
        **default: Any,
    ):
        # Hold all the model params in the root .param_groups
        # NOTE: the default constructor uses `add_param_group` which is partially overloaded here
        # we introduce the `initialized` flag for be able to dissociate the behaviour of
        # `add_param_group` in between super() and ZeroRedundancyOptimizer
        self.initialized = False
        super().__init__(params, default)  # Initialize the base class

        # Partition information. lazy evaluation, computed if requested
        self._per_device_params_cache: "OrderedDict[torch.device, List[List[Parameter]]]" = (
            OrderedDict()
        )  # device, rank, params

        # Build the wrapped optimizer, responsible for a shard of the params
        self._param_rank_cache: Dict[torch.Tensor, int] = {}  # Initialize various member variables
        self._param_to_index_cache: Dict[int, int] = {}
        self._partition_parameters_cache: List[List[Dict]] = []
        self._index_to_param_cache: Dict[int, torch.Tensor] = {}
        self._all_params = params
        self._reference_is_trainable_mask = list(map(_is_trainable, self._all_params))

        self.group = group if group is not None else dist.group.WORLD
        self.world_size = dist.get_world_size(self.group)
        self.rank = dist.get_rank(self.group) 
        # global is used to synchronize between processes
        self.global_rank = _get_global_rank(self.group, self.rank)
        self.parameters_as_bucket_view = parameters_as_bucket_view

        self._optim_defaults = default
        self._optim_constructor = optimizer_class  # Used later to construct the wrapped local optimizer

        # Optional consolidated optimizer state
        self._all_states: List[Dict[str, Any]] = []
        # Current default device is set by the parameters allocated to this rank
        self._device = list(self._per_device_params.keys())[0]
        self.buckets: Dict[torch.device, List[torch.Tensor]] = {}

        self._update_trainable()  # Calls _optim_constructor to build the internal optimizer
        self.initialized = True

Because many members are built lazily, their caches are only filled the first time they are accessed during execution. Therefore, we will not analyze them in the order they are actually initialized at runtime, but in the order in which they are logically initialized.

The functions and member variables examined below are all called or initialized, directly or indirectly, from __init__.

4.1 Partition Parameters

The partition_parameters method partitions the parameters and returns _partition_parameters_cache.

The wrapped optimizer shards the optimizer state using a greedy algorithm sorted by parameter size (rather than by usage order), packing parameters onto each rank so that every parameter belongs to exactly one rank and is not split across ranks. The partition is arbitrary and may not match the order in which parameters are registered or used. This ensures that each rank holds a nearly equal amount of optimizer memory.

def partition_parameters(self) -> List[List[Dict]]:
    r"""Partitions parameters across distributed data parallel ranks.

    Returns: a list of ``param_groups`` (which is a list of dict) where each
    element of the list contains the param_groups for a rank. Element 0
    corresponds to rank 0, etc. We need all the ranks for the broadcast
    inside ``step()``.
    """
    if len(self._partition_parameters_cache) == 0:
        self._partition_parameters_cache = [list() for _ in range(self.world_size)]
        # create an array to record the accumulated size on each rank
        sizes = [0] * self.world_size

        for param_group in self.param_groups:  # iterate over the parameter groups
            param_lists: List[List] = [list() for _ in range(self.world_size)]

            for param in param_group["params"]:  # Add this param to the rank with the smallest size.
                rank = sizes.index(min(sizes))  # find the rank with the smallest total so far
                param_lists[rank].append(param)  # put the parameter on that rank
                sizes[rank] += param.numel()  # increase that rank's size

            for rank, params in enumerate(param_lists): # traversing the list
                param_group_rank = copy.copy(param_group)
                param_group_rank["params"] = params
                self._partition_parameters_cache[rank].append(param_group_rank)

    return self._partition_parameters_cache

The result is a list of param_groups (each of which is a list of dicts). Each element of the list contains one rank's param_groups; element 0 corresponds to rank 0, and so on, and the total parameter size of each rank's groups is roughly equal. In step(), we need the information of all ranks in order to broadcast. The following figure shows the param_groups of rank 0 and rank 5.

_partition_parameters_cache

          +
          |
          |
          v                +---------------+
  +-------+---------+      | param_group   |
  |       0         +----> |               |      <-------+  100 M   +------------->
  +-----------------+      +---------------+
  |       1         |      |               |     +--------+---------+------+--------+
  +-----------------+      |   "params" +------> |param 1 | param 2|... | param6|
  |       2         |      |               |     |        |         |      |        |
  +-----------------+      +---------------+     +--------+---------+------+--------+
  |      ...        |
  +-----------------+      +---------------+
  |                 |      | param_group   |      <-------+  105 M  +----------------->
  |       5         +----> |               |
  +-----------------+      +---------------+     +--------+---------+-------+---------+
                           |               |     |        |         |       |         |
                           |  "params"  +------> | param 7| param 8|... | param11|
                           |               |     |        |         |       |         |
                           +---------------+     +--------+---------+-------+---------+
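
The greedy assignment can also be reproduced in isolation. Here is a standalone sketch with hypothetical parameter sizes (not part of the source):

# Greedy partitioning by size: each parameter goes to the rank with the
# smallest accumulated size so far.
world_size = 2
param_numels = [100, 40, 35, 30, 20]   # hypothetical parameter sizes (elements)

sizes = [0] * world_size
assignment = []
for numel in param_numels:
    rank = sizes.index(min(sizes))     # rank with the least total so far
    assignment.append(rank)
    sizes[rank] += numel

print(assignment, sizes)               # [0, 1, 1, 1, 0] [120, 105]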

4.2 Assigning Parameters to ranks

Now that the parameters have been divided into groups of roughly equal total size, we need to record which rank each parameter belongs to.

The _param_to_rank method generates a table that records the rank of each parameter. That is, which parameter is in which rank.

@property
def _param_to_rank(self) -> Dict[torch.Tensor, int]:
    r"""Look up table to match a given param with a data parallel rank"""
    if len(self._param_rank_cache) == 0:
        for rank, param_groups in enumerate(self.partition_parameters()):
            for param_group in param_groups:
                for param in param_group["params"]:
                    self._param_rank_cache[param] = rank
    return self._param_rank_cache

Params 1, 2, and 6 are on rank 0; params 8 and 11 are on rank 5; and so on, as follows:

_param_rank_cache

      +
      |
      |
      |
      v
 +----+--------------+------------+
 |                   |            |
 |   param 1         |     0      |
 +--------------------------------+
 |                   |            |
 |   param 2         |     0      |
 +--------------------------------+
 |                   |            |
 |   param 6         |     0      |
 +--------------------------------+
 |                   |            |
 |   param 8         |     5      |
 +--------------------------------+
 |                   |            |
 |   param 11        |     5      |
 +--------------------------------+
 |                   |            |
 |   param n         |     n      |
 |                   |            |
 +-------------------+------------+

4.3 _per_device_params

Now that the parameters have been assigned to ranks, it is time to organize them by device; each device may hold the parameter groups of several ranks. The _per_device_params method divides the optimizer's param_groups among devices and returns _per_device_params_cache.

Note that _per_device_params contains all of the model parameters, merely grouped by device, so it is identical in every ZeRO optimizer instance. This makes it possible to synchronize these parameters between the ZeRO optimizers with broadcasts.

@property
def _per_device_params(self) -> Dict[torch.device, List[List[Parameter]]]:
    r"""Sorted list of all the params, first per device then per rank.

    Within a list, params are sorted per number of elements to allow for an easy bucketing.
    """
    if len(self._per_device_params_cache) == 0:
        # Go through all params, log them per device
        # The ordering is important here, needs to be the same on all ranks
        # So that ulterior broadcast calls are matching
        for param_group in self.param_groups:  # iterate over the parameter groups
            for param in param_group["params"]:
                device = param.device # Find its device
                if self._per_device_params_cache.get(device) is None:
                    self._per_device_params_cache[device] = [[] for _ in range(self.world_size)]
                # Each device also needs to be separated by rank
                self._per_device_params_cache[device][self._param_to_rank[param]] += [param]

        # Sort param_lists by size
        for k in self._per_device_params_cache.keys():
            for r in self._per_device_params_cache[k]:
                r.sort(key=lambda x: x.numel())

    return self._per_device_params_cache

For example, the CPU, GPU 1 (whose details are omitted in the figure), and GPU 2 each have their own parameter lists, and each list is sorted by parameter size.

_per_device_params_cache

      +
      |                                      +--------+--------+-------+--------+
      |                                      |        |        |       |        |
      |                     +---------+      | param1 | param3 |param5 | param6 |
      v                     |         |      |        |        |       |        |
 +----+--------------+      | rank 0  +----> |  1k    |  2k    |  3k   |   7k   |
 |                   |      |         |      +--------+--------+-------+--------+
 |     "CPU"         +----> +---------+
 |                   |      |         |
 +-------------------+      | rank 1  |      +--------+--------+-------+--------+
 |                   |      |         +----> |        |        |       |        |
 |     "GPU 1"       |      +---------+      | param9 | param2 | param4| param8 |
 |                   |                       |        |        |       |        |
 +-------------------+                       |  0.5k  |  1k    |  4k   |   8k   |
 |                   |                       +--------+--------+-------+--------+
 |     "GPU 2"       |      +---------+
 |                   +----> |         |      +---------+------------+-----------+
 +-------------------+      |         |      |         |            |           |
                            | rank 5  +----> | param 11|  param 13  | param 15  |
                            |         |      |         |            |           |
                            +---------+      +---------+------------+-----------+
                            |         |
                            | rank 6  |      +---------+------------+-----------+
                            |         +----> |         |            |           |
                            |         |      | param 19|  param 12  | param 14  |
                            +---------+      |         |            |           |
                                             +---------+------------+-----------+


4.4 _update_trainable

Because the trainability of some parameters may change, the local optimizer and the ZeroRedundancyOptimizer need to be kept in sync with each other.

  • First you get self._default_device as “CPU” or “GPU #”.
  • _optim_constructor is then called to build the internal optimizer. Note that this tells the local optimizer that you are responsible for tuning these parameters, regardless of the other shards. The partition_parameters method, as mentioned earlier, partitions the parameters and returns _partition_parameters_cache.
# select the parameters corresponding to your rank to optimize
self.optim = self._optim_constructor(self.partition_parameters()[self.rank], **self._optim_defaults)

The runtime variables are as follows:
#_optim_constructor = {type} <class 'torch.optim.adam.Adam'>
#_optim_defaults = {dict: 1} {'lr': 0.01}
  • Next, call _sync_param_groups to synchronize the parameters.

  • Finally, the flat buffer is established.

The specific code is as follows:

def _update_trainable(self) -> None:
    r""" Updates the partitioning and communication patterns if the trainability (``requires_grad``) of some parameters changed. """

    # Create the optim which will work on the param shard
    if not hasattr(self, "optim"):
        self._clear_cache()
        # Get the default device
        self._default_device = list(self._per_device_params.keys())[0]
        # build local optimizer, just select the parameters corresponding to this rank
        self.optim = self._optim_constructor(self.partition_parameters()[self.rank], **self._optim_defaults)
        # Call _sync_param_groups to synchronize the parameters, self.optim is the wrapped optimizer
        self._sync_param_groups(self.optim.param_groups, self.param_groups)

    if self.parameters_as_bucket_view:
        self._setup_flat_buffers() # Build the flat buffer

For example, the local optimizer of Rank 5 only points to the parameters corresponding to _partition_parameters_cache[5]. The local optimizer only optimizes these parameters.

This enables optimizer parameter partitioning. Parameters such as _partition_parameters_cache[5] can later be placed on gpus so that each GPU contains only a partial partition of the optimizer.

Note that the model parameters and gradients do not change; the local optimizer inside ZeroRedundancyOptimizer simply points to only the subset of parameters that this rank needs to optimize, so the optimizer state shrinks accordingly.

As shown in the figure below, the original optimizer would need to optimize all parameters, which might be 100 M + 105 M + ..., whereas the local optimizer of this ZeroRedundancyOptimizer (rank 5) only needs to optimize 105 M.

 _partition_parameters_cache

        +
        |
        |
        v                +---------------+
+-------+---------+      | param_group   |
|       0         +----> |               |      <-------+  100 M   +------------->
+-----------------+      +---------------+
|       1         |      |               |     +--------+---------+------+--------+
+-----------------+      |   "params" +------> |param 1 | param 2|... | param6|
|       2         |      |               |     |        |         |      |        |
+-----------------+      +---------------+     +--------+---------+------+--------+
|      ...        |
+-----------------+      +---------------+
|                 |      | param_group   |      <-------+  105 M  +----------------->
|       5         +----> |               |
+-----------------+      +---------------+     +--------+---------+-------+---------+
                         |               |     |        |         |       |         |
                    +--> |  "params"  +------> | param 7| param 8|... | param11|
                    |    |               |     |        |         |       |         |
                    |    +---------------+     +--------+---------+-------+---------+
                    |
                    |
                    |
+-----------------------+
| Local Optimizer   |   |
|                   |   |
|                   |   |
|                   +   |
|                       |
|                       |
|                       |
|                       |
+-----------------------+

We need to refine this further by looking at the _sync_param_groups and _setup_flat_buffers functions.

4.4.1 Synchronizing Parameter Groups

_sync_param_groups is used to synchronize the internal optimizer’s parameter groups to the local Zero optimizer’s parameter groups.

    @staticmethod
    def _sync_param_groups(source: List[Dict[Any, Any]], destination: List[Dict[Any, Any]]) -> None:
        r"""Sync learning rate and other optimizer attributes (needed to support schedulers)."""

        for source_group, destination_group in zip(source, destination):
            # Sync everything but the parameters
            for k in filter(lambda x: x != "params", source_group.keys()):
                destination_group[k] = source_group[k]

4.4.2 Creating a single buffer

If parameters_as_bucket_view is set, _setup_flat_buffers is called to build several buckets: the tensors that belong to the same rank on the same device are merged into one flat buffer, following the layout of _per_device_params.

def _setup_flat_buffers(self) -> None:
    r""" Make all params which are on the same device and tied to the same rank views of a single buffer. This is used at construction time, and anytime parameter trainability is changed (frozen or unfrozen) and ``_update_trainable`` is called. """

    for device, per_rank_params in self._per_device_params.items():
        # Only wipe the existing buckets if there are none
        # (could be that this is called twice, when trainability changes)
        if device not in self.buckets.keys():
            self.buckets[device] = []

        # Make parameters a view of the bucket
        for dst_rank, params in enumerate(per_rank_params):
            if len(params) > 0:

                # Clone the non-trainable params, if in a bucket it will get destroyed
                for param in filter(lambda x: not x.requires_grad, params):
                    param.data = param.data.detach().clone()

                # Merge all the trainable params in a single bucket
                trainable_params = list(filter(_is_trainable, params))
                buffer_size = sum(map(lambda x: x.numel(), trainable_params))
                bucket = torch.empty(buffer_size, dtype=params[0].dtype, device=device)
                offset = 0

                for param in trainable_params:
                    offset_next = offset + param.numel()
                    bucket[offset:offset_next].copy_(param.data.flatten())
                    param.data = bucket[offset:offset_next].view_as(param.data)
                    offset = offset_next

                # Either replace the existing bucket, or create it
                if len(self.buckets[device]) == dst_rank:
                    self.buckets[device].append(bucket)
                else:
                    self.buckets[device][dst_rank] = bucket
            else:
                self.buckets[device].append(torch.zeros(1, device=device))

To illustrate: the tensors of the same rank on the same device become views of a single buffer.

buckets
     +
     |
     |               +---------------------------------------+
     v               | Tensor                                |
+----+-------+       | +-----------------------------------+ |
|            |       | |                                   | |
|  "CPU"     +-----> | | Param 1, param 2,  Param 3.. | | | | | +-----------------------------------+ | +------------+ +---------------------------------------+ | | |"GPU 1"   +-----> +---------------------------------------+
|            |       | Tensor                                |
+------------+       | +-----------------------------------+ |
|            |       | |                                   | |
|            |       | | Param 6, Param 7,  Param 8.. | | | | | +-----------------------------------+ | | | +---------------------------------------+ | | +------------+Copy the code
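
The effect of _setup_flat_buffers can be illustrated in isolation (a minimal standalone sketch, not the source): once parameters become views of one contiguous bucket, a single operation on the bucket, such as the broadcast in step(), also updates every parameter in place.

import torch

p1 = torch.nn.Parameter(torch.randn(3, 4))
p2 = torch.nn.Parameter(torch.randn(5))
params = [p1, p2]

# Merge both parameters into one flat bucket and turn them into views of it.
bucket = torch.empty(sum(p.numel() for p in params))
offset = 0
for p in params:
    end = offset + p.numel()
    bucket[offset:end].copy_(p.data.flatten())
    p.data = bucket[offset:end].view_as(p.data)   # p.data is now a view of the bucket
    offset = end

bucket.zero_()   # stands in for whatever a broadcast would write into the bucket
print(p1.data.abs().sum(), p2.data.abs().sum())   # both parameters were updated: 0, 0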

0x05 Updating the parameters

Let’s look at how the optimizer updates parameters, using the following logic:

  • If the trainability of some parameters has changed, the partition needs to be recomputed.
  • Call _sync_param_groups to copy the ZeRO optimizer's param_group attributes into the local optimizer, in case a scheduler has updated them.
  • Call self.optim.step so that the local optimizer updates its own shard of the parameters.
  • Call dist.broadcast to synchronize parameters between the ranks.
  • Call _sync_param_groups again, this time to copy the local optimizer's param_group attributes back to the ZeRO optimizer, since they may have been updated.
def step(self, closure: Optional[Callable[[], float]] = None, **kwargs: Any) -> Optional[float]:
    r"""Performs a single optimization step (parameter update).

    Arguments:
        closure (callable): A closure that reevaluates the model and returns the loss.
            Optional for most optimizers.

    Returns:
        optional loss, depends on the underlying optimizer

    .. note: Any extra parameter is passed to the base optimizer as-is
    """

    # Check whether the model trainability graph changed
    # If the trainability graph changed, the partition needs to be recomputed
    trainable_mask = list(map(_is_trainable, self._all_params))
    if trainable_mask != self._reference_is_trainable_mask:
        self._update_trainable()
        self._reference_is_trainable_mask = trainable_mask

    # Sync oss param_groups attributes in case they've been updated by a scheduler.
    self._sync_param_groups(self.param_groups, self.optim.param_groups)

    # Run the optimizer step on this shard only:
    # i.e., update only the parameters of the local shard
    if closure is not None:
        loss = self.optim.step(closure=closure, **kwargs)  # type: ignore[call-arg]
    else:
        loss = self.optim.step(**kwargs)

    # Sync all the updated shards in between the ranks
    handles = []
    if self.parameters_as_bucket_view:
        for device in self.buckets.keys():
            for src_rank, bucket in enumerate(self.buckets[device]):
                global_src_rank = _get_global_rank(self.group, src_rank)
                handles.append(dist.broadcast(tensor=bucket, src=global_src_rank, group=self.group, async_op=True))
    else:        
        for device, per_rank_params in self._per_device_params.items():  # iterate over devices and their parameters
            for dst_rank, params in enumerate(per_rank_params):  # iterate over ranks
                global_dst_rank = _get_global_rank(self.group, dst_rank)
                for param in params:  # broadcast each parameter from the rank that owns it
                    handles.append(
                        dist.broadcast(tensor=param.data, src=global_dst_rank, group=self.group, async_op=True)
                    )

    _ = list(map(lambda x: x.wait(), handles))  # wait for all async broadcasts to finish

    # Sync hypothethical new results from the wrapped optimizer to the exposed param_groups
    self._sync_param_groups(self.optim.param_groups, self.param_groups)

    return loss

5.1 Local update

The first is to update model parameters locally.

# Update the local parameters
if closure is not None:
    loss = self.optim.step(closure=closure, **kwargs)  # type: ignore[call-arg]
else:
    loss = self.optim.step(**kwargs)

Assume that the model has eight parameters in total and is trained on two nodes, each of which has its own optimizer. To make the illustration easier to read, the parameters and their rank numbers are drawn at the top of both optimizers in the figures below.

Once again: the model parameters and gradients are unchanged, but the local optimizer inside each ZeroRedundancyOptimizer only points to the subset of parameters it needs to optimize, so the optimizer state is reduced accordingly.

So, the model (parameters to be optimized) is the same size in both optimizers, but:

  • ZeroRedundancyOptimizer 0 is rank 0; it locally optimizes parameters 0 to 3, which are now the most up-to-date copies across both nodes.

  • ZeroRedundancyOptimizer 1 is rank 1; it locally optimizes parameters 4 to 7, which are now the most up-to-date copies across both nodes.

+--------------------------------------------------------------------------------+
|                                                     ZeroRedundancyOptimizer 0  |
|                                                                                |
|   _per_device_params_cache                                                     |
|       +                                                                        |
|       |                                                                        |
|       v          +--------+           +--------+--------+-------+--------+     |
|   +---+-----+    | rank 1 |           |        |        |       |        |     |
|   |         |    |        +---------> | param4 | param5 | param6| param7 |     |
|   | "GPU"1" +--> +--------+ | | | | | | | | | | | +--------+--------+-------+--------+ | | +---------+ | rank 0 | | | | | +--------+--------+-------+--------+ | | | +---------> | | | | | | | +--------+ | param0 | param1 |param2 | param3 | NEW  | | +----> | | | | | | | +----------------+ | +--------+--------+-------+--------+ | | |Local Optimizer | | | | | +----------+ | | | | | | +----------------+ | | | Node 0 +--------------------------------------------------------------------------------+ +--------------------------------------------------------------------------------+ | | Node 1 | | | _per_device_params_cache | | + | | | +--------+--------+-------+--------+ | | v +--------+ +---> | | | | | | | +---+-----+ | rank 1 | | | param4 | param5 | param6| param7 | NEW | | | | | +---------> | | | | | | | | "GPU"1" +--> +--------+     |     +--------+--------+-------+--------+     |
|   |         |    |        |     |                                              |
|   +---------+    | rank 0 |     |     +--------+--------+-------+--------+     |
|                  |        +---------> |        |        |       |        |     |
|                  |        |     |     | param0 | param1 |param2 | param3 |     |
|                  +--------+     |     |        |        |       |        |     |
|                                 |     +--------+--------+-------+--------+     |
|   +----------------+            |                                              |
|   |Local Optimizer |            |                                              |
|   |                +------------+                                              |
|   |                |                                                           |
|   +----------------+                                 ZeroRedundancyOptimizer 1 |
|                                                                                |
+--------------------------------------------------------------------------------+

5.2 Broadcast

Notice first that _per_device_params includes all of the model parameters here, although it has been sorted by device.

At this point, each rank has updated the parameters it owns, i.e., its part of the model. To keep all model replicas up to date, the ranks need to broadcast to each other.

After the parameters are updated locally, each rank broadcasts its parameters to all other peers to keep all model copies in the same state.

+--------------------------------------------------------------------------------+
|                                                     ZeroRedundancyOptimizer 0  |
|                                                                                |
|   _per_device_params_cache                                                     |
|       +                                                                        |
|       |                                                                        |
|       v          +--------+           +--------+--------+-------+--------+     |
|   +---+-----+    | rank 1 |           |        |        |       |        |     |
|   |         |    |        +---------> | param4 | param5 | param6| param7 |     |
|   | "GPU"1" +--> +--------+ | | | | | | | | | | | +--------+--------+-------+--------+ | | +---------+ | rank 0 | | | | | +--------+--------+-------+--------+ | | | +---------> | | | | | | | +--------+ | param0 | param1 |param2 | param3 | NEW  | | +----> | | | | | | | +----------------+ | +---+----+---+----+-+-----+--+-----+ | | |Local Optimizer | | | | | | | |  | +----------+ | | | | | | | | | ^ | ^ | ^ | ^ | | +----------------+ | | | | | | | | | | | | | | | | | | | Node 0 +--------------------------------------------------------------------------------+ | | | | | | | | | | | | | | | | | | |  | | | | | +--------------------------------------------------------------------------------+ | | | | | | | | | | Node 1  | v | v | v | v | | | _per_device_params_cache | | | | | | + | | | | | | | +------+-+------+-+----+--+------+-+ | | v +--------+ +---> | | | | | | | +---+-----+ | rank 1 | | | param4 | param5 | param6| param7 | NEW | | | | | +---------> | | | | | | | |"GPU"1" +--> +--------+     |     +--------+--------+-------+--------+     |
|   |         |    |        |     |                                              |
|   +---------+    | rank 0 |     |     +--------+--------+-------+--------+     |
|                  |        +---------> |        |        |       |        |     |
|                  |        |     |     | param0 | param1 |param2 | param3 |     |
|                  +--------+     |     |        |        |       |        |     |
|                                 |     +--------+--------+-------+--------+     |
|   +----------------+            |                                              |
|   |Local Optimizer |            |                                              |
|   |                +------------+                                              |
|   |                |                                                           |
|   +----------------+                                 ZeroRedundancyOptimizer 1 |
|                                                                                |
+--------------------------------------------------------------------------------+
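
The broadcast pattern itself can be seen in a tiny two-process example (my own sketch, not the ZeroRedundancyOptimizer source; the port number is arbitrary): each rank updates only the shard it owns and then broadcasts it, after which both ranks hold identical values.

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def run(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29501"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Two "shards"; shard i is owned (and updated) by rank i.
    shards = [torch.zeros(3) for _ in range(world_size)]
    shards[rank] += rank + 1   # stands in for the local optimizer step on the owned shard

    # Every rank broadcasts its own shard and receives the other ranks' shards.
    handles = [dist.broadcast(shards[r], src=r, async_op=True) for r in range(world_size)]
    for h in handles:
        h.wait()

    print(f"rank {rank}: {[s.tolist() for s in shards]}")   # identical on both ranks
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(run, args=(2,), nprocs=2, join=True)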

5.3 Synchronizing Local Parameters

Finally, _sync_param_groups needs to be called again to synchronize the local optimizer's param_groups back to the ZeRO optimizer, because they may have been updated.

# Sync hypothethical new results from the wrapped optimizer to the exposed param_groups
self._sync_param_groups(self.optim.param_groups, self.param_groups)

Let’s go over the specific functions again.

@staticmethod
def _sync_param_groups(source: List[Dict[Any, Any]], destination: List[Dict[Any, Any]]) -> None:
    r"""Sync learning rate and other optimizer attributes (needed to support schedulers)."""

    for source_group, destination_group in zip(source, destination):
        # Sync everything but the parameters
        for k in filter(lambda x: x != "params", source_group.keys()):
            destination_group[k] = source_group[k]

0xEE Personal information

★★★★ Thoughts on life and technology ★★★★★

Wechat official account: Rosie’s Thoughts

0xFF References

Talk about ZeroRedundancyOptimizer and Join in torch1.10

pytorch.org/tutorials/r…

pytorch.org/docs/master…

medium.com/swlh/inside…

www.microsoft.com/en-us/resea…