0x00 Summary

Now that we know how distributed autograd sends and receives messages, this article looks at the other supporting pieces: how the send and receive actions are coordinated, how each send/recv node is identified, and how each session of message interaction is identified.

Through this article, we can understand that: AutogradMetadata is used to pass autograd metadata between nodes; DistAutogradContext represents the information related to one distributed autograd pass; and DistAutogradContainer is responsible for storing the DistAutogradContext objects on a worker.

The other articles in this PyTorch distributed series are as follows:

PyTorch distributed (1) —— history and overview

PyTorch: how to use GPU

PyTorch distributed (2) —— DataParallel (part 1)

PyTorch distributed (3) —— DataParallel (part 2)

PyTorch distributed (4) —— distributed application concepts

PyTorch distributed (5) —— DistributedDataParallel — overview and how to use

PyTorch distributed (6) —— DistributedDataParallel — initialization and store

PyTorch distributed (7) —— DistributedDataParallel — process groups

PyTorch distributed (8) —— DistributedDataParallel — the paper

PyTorch distributed (9) —— DistributedDataParallel — initialization

PyTorch distributed (10) —— DistributedDataParallel — Reducer static architecture

PyTorch distributed (11) —— DistributedDataParallel — constructing the Reducer and the Join operation

PyTorch distributed (12) —— DistributedDataParallel — forward propagation

PyTorch distributed (13) —— DistributedDataParallel — backward propagation

PyTorch Distributed Autograd (1) —— design

PyTorch Distributed Autograd (2) —— RPC foundation

For better illustration, the code in this article will be streamlined accordingly.

0x01 Design Context

1.1 Review

In the previous article, we saw that when sending a message, sendMessageWithAutograd obtains a message of type FORWARD_AUTOGRAD_REQ via getMessageWithAutograd.

c10::intrusive_ptr<JitFuture> sendMessageWithAutograd(
    RpcAgent& agent,
    const WorkerInfo& dst,
    torch::distributed::rpc::Message&& wrappedRpcMsg,
    bool forceGradRecording,
    const float rpcTimeoutSeconds,
    bool forceDisableProfiling) {
    
  auto msg = getMessageWithAutograd( // This will interact with the context and build FORWARD_AUTOGRAD_REQ
      dst.id_,
      std::move(wrappedRpcMsg),
      MessageType::FORWARD_AUTOGRAD_REQ,
      forceGradRecording,
      agent.getDeviceMap(dst));

  c10::intrusive_ptr<JitFuture> fut;
  if (!forceDisableProfiling && torch::autograd::profiler::profilerEnabled()) {
    auto profilerConfig = torch::autograd::profiler::getProfilerConfig();
    auto msgWithProfiling = getMessageWithProfiling(
        std::move(msg),
        rpc::MessageType::RUN_WITH_PROFILING_REQ, // Build the message
        std::move(profilerConfig));
    // Send a message
    fut = agent.send(dst, std::move(msgWithProfiling), rpcTimeoutSeconds);
  } else {
    // Send a message
    fut = agent.send(dst, std::move(msg), rpcTimeoutSeconds);
  }

  return fut;
}

getMessageWithAutograd interacts with the context; its code is located in torch/csrc/distributed/autograd/utils.cpp.

Message getMessageWithAutograd(
    const rpc::worker_id_t dstId,
    torch::distributed::rpc::Message&& wrappedRpcMsg,
    MessageType msgType,
    bool forceGradRecording,
    const std::unordered_map<c10::Device, c10::Device>& deviceMap) {
  
  // Get the DistAutogradContainer
  auto& autogradContainer = DistAutogradContainer::getInstance();

  // If there is no valid context and no tensor requires grads, send original
  // rpc message. otherwise, attach grad info and grad functions and send
  // rpcWithAutograd message.
  auto tensorsRequireGrad =
      torch::autograd::compute_requires_grad(wrappedRpcMsg.tensors());
  if (!autogradContainer.hasValidContext() ||
      (!forceGradRecording && !tensorsRequireGrad)) {
    return std::move(wrappedRpcMsg);
  }

  // Retrieve the appropriate context to modify.
  auto autogradContext = autogradContainer.currentContext(); // Get the context. Each worker has its own context

  // Wrap the original rpc with autograd information.
  // newAutogradMessageId generates a messageID
  AutogradMetadata autogradMetadata( // AutogradMetadata is built
      autogradContext->contextId(), autogradContainer.newAutogradMessageId());
  auto rpcWithAutograd = std::make_unique<RpcWithAutograd>(
      RpcAgent::getCurrentRpcAgent()->getWorkerInfo().id_,
      msgType,
      autogradMetadata,
      std::move(wrappedRpcMsg),
      deviceMap);

  if (tensorsRequireGrad) {
    // Record autograd information for 'send'.
    addSendRpcBackward( // The local context, the autograd meta information, and so on are packaged together
        autogradContext, autogradMetadata, rpcWithAutograd->tensors());
  }
  // Record the workerID
  autogradContext->addKnownWorkerId(dstId);

  return std::move(*rpcWithAutograd).toMessage(); // Finally build a message
}

Therefore, AutogradMetadata, DistAutogradContainer, and DistAutogradContext are the basic classes that we’ll examine in detail.

1.2 General Ideas

Let’s outline the general idea.

First, the problem: suppose a system includes nodes A, B, and C, each running a worker. A single propagation pass then involves communication among all three nodes. We therefore need a mechanism that uniquely identifies the propagation pass across the three nodes, and that also identifies each send/recv on each node, so that a node can support multiple passes in parallel.

Look at the solution:

  • Use a context to uniquely identify a propagation pass. DistAutogradContext stores the information of one distributed autograd pass on a worker; it encapsulates forward and backward propagation in distributed autograd and accumulates gradients, which keeps the gradients of different passes and different workers from interfering with each other. Each autograd pass is assigned a unique autograd_context_id, by which its DistAutogradContext is uniquely identified in the container.
  • Use autogradMessageId to identify a pair of send/recv autograd functions. Every send-recv pair is assigned a globally unique autograd_message_id that uniquely identifies the pair. This is useful for finding the corresponding function on the remote node during backward propagation.
  • Finally, each worker needs a place to keep its contexts and message ids, hence the DistAutogradContainer class. Each worker has a unique singleton DistAutogradContainer, which is responsible for:
    • Storing the distributed context of each autograd pass.
    • Cleaning up the data of an autograd pass once it is complete.

Thus, during forward propagation, PyTorch stores the send and recv functions of each autograd pass in the context. This ensures that we hold references to the appropriate nodes in the autograd graph to keep them alive. It also makes it easy to find the corresponding send and recv functions during backward propagation.
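To make this concrete, here is a minimal sketch (not taken from the PyTorch source; it assumes torch.distributed.rpc.init_rpc has already been called on this worker, since distributed autograd requires the RPC framework to be initialized) showing that each pass gets its own context and keeps its gradients there instead of in .grad:

import torch
import torch.distributed.autograd as dist_autograd

def run_two_passes():
    # Assumes rpc.init_rpc(...) was already called on this worker.
    ids = []
    for _ in range(2):
        with dist_autograd.context() as context_id:
            t = torch.rand((3, 3), requires_grad=True)
            loss = (t * t).sum()
            dist_autograd.backward(context_id, [loss])
            # Gradients are accumulated per context, not in t.grad,
            # so concurrent passes cannot clobber each other.
            grads = dist_autograd.get_gradients(context_id)
            ids.append(context_id)
    assert ids[0] != ids[1]  # each pass gets a fresh, globally unique context id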

0x02 AutogradMetadata

2.1 Definition

The AutogradMetadata class is used to pass autograd metadata between nodes; it encapsulates the context id and other information. That is, the sender notifies the receiver of its own context information, and the receiver acts on the context information it receives.

As a preview, the receiver uses autogradContextId and autogradMessageId as the unique identifiers of the context and the message, respectively, as the comments below explain.

  • autogradContextId is a globally unique integer that identifies a single distributed autograd pass (covering both forward and backward propagation). A pass involves multiple pairs of send/recv autograd functions along the backward propagation chain.
  • autogradMessageId is a globally unique integer that identifies a pair of send/recv autograd functions. Every send-recv pair is assigned a globally unique autograd_message_id that uniquely identifies the pair. This is useful for finding the corresponding function on the remote node during backward propagation.
// This structure represents autograd metadata that we need to pass across
// different nodes when we call an RPC which needs autograd computation.
struct TORCH_API AutogradMetadata {
  AutogradMetadata(int64_t autogradContextId, int64_t autogradMessageId);

  // autogradContextId_ is a globally unique integer that identifies a
  // particular distributed autograd pass.
  int64_t autogradContextId;
  // autogradMessageId_ is a globally unique integer that identifies a pair
  // of send/recv autograd functions.
  int64_t autogradMessageId;
};

How can autogradContextId and autogradMessageId be globally unique?

2.2 autogradMessageId

In short, autogradMessageId is derived from the rank and then incremented internally, which makes it globally unique.

Let’s work backwards.

  • First, newAutogradMessageId generates a message id by incrementing the next_autograd_message_id_ member variable of DistAutogradContainer.
int64_t DistAutogradContainer::newAutogradMessageId() {
  // Check for overflow into workerId_ section.
  TORCH_INTERNAL_ASSERT(next_autograd_message_id_ < max_id_);
  return next_autograd_message_id_++;
}
  • Then, how is next_autograd_message_id_ initialized? As you can see from the init function of DistAutogradContainer, next_autograd_message_id_ is derived from worker_id, which is the parameter passed to init.
DistAutogradContainer& DistAutogradContainer::init(int64_t worker_id) {
  std::lock_guard<std::mutex> guard(dist_container_init_lock_);

  auto& container = getInstanceInternal();
  container.worker_id_ = worker_id;
  container.next_context_id_ = static_cast<int64_t>(worker_id)
      << kAutoIncrementBits;
  container.next_autograd_message_id_ = static_cast<int64_t>(worker_id)
      << kAutoIncrementBits;
  container.max_id_ =
      (kAutoIncrementMask |
       (static_cast<int64_t>(worker_id) << kAutoIncrementBits));
  container.initialized_ = true;
  return container;
}
  • Finally, let's see where the worker id is set; we find the following binding:
module.def(
    "_init",
    [](int64_t worker_id) { DistAutogradContainer::init(worker_id); },
    py::call_guard<py::gil_scoped_release>());

In the Python world, you can see that rank is used as the parameter, and rank is unique to each worker. This ensures that the worker ID is unique, and therefore the message ID is unique.

def init_rpc(
    name,
    backend=None,
    rank=-1,
    world_size=None,
    rpc_backend_options=None,
):
    dist_autograd._init(rank)  # rank is globally unique

Let’s summarize these logical relationships:

worker_id = rank;

container.worker_id_ = worker_id;

container.next_autograd_message_id_ = static_cast<int64_t>(worker_id) << kAutoIncrementBits

After that, next_autograd_message_id_ is incremented internally:

int64_t DistAutogradContainer::newAutogradMessageId() {
  // Check for overflow into workerId_ section.
  TORCH_INTERNAL_ASSERT(next_autograd_message_id_ < max_id_);
  return next_autograd_message_id_++;
}
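Putting the pieces together: the high 16 bits of an id come from the rank, and the low 48 bits are a per-worker counter. The following is a small illustrative sketch in plain Python (not PyTorch source; it assumes kAutoIncrementBits = 48, matching the constant used in container.cpp):

K_AUTO_INCREMENT_BITS = 48
K_AUTO_INCREMENT_MASK = (1 << K_AUTO_INCREMENT_BITS) - 1

def first_message_id(worker_id):
    # next_autograd_message_id_ starts at worker_id << 48
    return worker_id << K_AUTO_INCREMENT_BITS

def split_id(autograd_id):
    # Recover (worker_id, local counter) from a message id or context id
    return autograd_id >> K_AUTO_INCREMENT_BITS, autograd_id & K_AUTO_INCREMENT_MASK

# The third message id generated on worker 2:
msg_id = first_message_id(2) + 2
assert split_id(msg_id) == (2, 2)  # ids from different workers can never collide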

Therefore, autogradMessageId is globally unique. Let's illustrate this with a diagram:

+---------------------------------------------------------------------------+
| worker
|
|    rank
|      |
|      |  1: DistAutogradContainer::init(worker_id = rank)
|      v
|    DistAutogradContainer
|      worker_id_
|      next_autograd_message_id_
|      |
|      |  2: newAutogradMessageId()
|      v
|    getMessageWithAutograd
|      |
|      |  3: MessageId()        4: contextId()
|      v
|    AutogradMetadata autogradMetadata(contextId(), MessageId())
|
+---------------------------------------------------------------------------+

To see how the autogradContextId can be guaranteed to be unique, we need to analyze DistAutogradContainer and DistAutogradContext first.

0x03 DistAutogradContainer

Each worker has a unique singleton DistAutogradContainer that is responsible for:

  • Storing the distributed context of each autograd pass.
  • Cleaning up the data of an autograd pass once it is complete.

Each autograd pass is assigned a unique autograd_context_id, and within each container the corresponding DistAutogradContext is uniquely identified by that autograd_context_id. autograd_context_id is a 64-bit globally unique id: the first 16 bits are the worker_id, and the last 48 bits auto-increment within each worker. So a container can hold multiple contexts.

The container is also responsible for maintaining globally unique message ids, which are used to associate send/recv autograd function pairs. The format is similar to autograd_context_id: a 64-bit integer whose first 16 bits are the worker id and whose last 48 bits auto-increment within the worker.

Because the first 16 bits of both the message id and the context id are the worker_id (the rank), and the last 48 bits are an internally incremented counter, both ids are guaranteed to be globally unique.

3.1 Definition

DistAutogradContainer is defined as follows:

  • worker_id_: the id of the worker, i.e., the worker's rank.
  • next_context_id_: an auto-incrementing counter used to assign each autograd pass a unique autograd_context_id. On a propagation chain, only the DistAutogradContainer of the first node uses next_context_id_ to generate a context; the containers of subsequent nodes create local contexts keyed by the context id received from the first node.
  • next_autograd_message_id_: maintains a globally unique message id used to associate send/recv autograd function pairs. This counter is used when a message is sent from this node.
// Singleton class per worker which is responsible for storing the distributed
// autograd context for each autograd pass and also cleans up data for an
// autograd pass once its done.
//
// Each autograd pass is assigned a unique autograd_context_id and all data for
// that pass (DistAutogradContext) is stored in this container indexed by the
// autograd_context_id. The autograd_context_id itself is a 64 bit globally
// unique id. The first 16 bits is the worker_id and the next 48 bits is an
// auto-incrementing id for each worker.
//
// This container is also responsible for maintaining a globally unique message
// id, which is used to associate send/recv autograd function pairs. The format
// is similar to the autograd_context_id where we have a 64 bit integer with
// first 16 bits being the worker id and next 48 bits are auto-incrementing.
class TORCH_API DistAutogradContainer {

 private:
  // Number of shards for the map storing autograd contexts. We'd like this
  // to be a power of 2 and we don't expect a value much higher than the
  // number of cores would provide much benefit.
  static constexpr uint32_t kNumDefaultShards = 128;

  // Use cache line size for alignment.
  static constexpr int kCacheLineSize = 64;

  // Structure holding one shard of the sharded autograd context map with its
  // associated lock. Align to cache line size to avoid contention between
  // adjacent entries.
  struct alignas(kCacheLineSize) ContextsShard {
    // Lock for this shard.
    mutable std::mutex lock;

    // Map storing autograd contexts for this shard.
    std::unordered_map<int64_t, ContextPtr> contexts; // The context pointer is stored here
  };

  // Auto incrementing context id used to identify unique autograd passes.
  // Initialized with the first 16 bits being the worker_id.
  std::atomic<int64_t> next_context_id_; // Add a context ID

  // Unique id to identify a worker in the distributed setting.
  int16_t worker_id_;

  // Whether or not the container has been initialized appropriately.
  bool initialized_;

  // Sharded autograd context map.
  std::vector<ContextsShard> autograd_contexts_; // Store the context list

  // Number of shards for the sharded autograd_contexts_ map.
  uint32_t num_shards_;

  // Autograd message id to identify unique send/recv autograd function pairs.
  std::atomic<int64_t> next_autograd_message_id_;

  // Maximum allowed value for autograd_context_id or autograd_message_id.
  int64_t max_id_;
};

3.2 Construction

The init method builds the DistAutogradContainer, using worker_id to initialize the member variables.

DistAutogradContainer& DistAutogradContainer::init(int64_t worker_id) {
  std::lock_guard<std::mutex> guard(dist_container_init_lock_);

  TORCH_CHECK(
      worker_id >= 0 && worker_id <= kMaxWorkerId,
      "worker_id needs to be in the range [0, 65535]");

  auto& container = getInstanceInternal();
  TORCH_CHECK(
      !container.initialized_ || (worker_id == container.worker_id_),
      "Container is already initialized with worker_id: ",
      container.worker_id_,
      ", cannot initialize with different worker_id: ",
      worker_id);

  if (container.initialized_) {
    return container;
  }

  container.worker_id_ = worker_id;
  container.next_context_id_ = static_cast<int64_t>(worker_id)
      << kAutoIncrementBits;
  container.next_autograd_message_id_ = static_cast<int64_t>(worker_id)
      << kAutoIncrementBits;
  container.max_id_ =
      (kAutoIncrementMask |
       (static_cast<int64_t>(worker_id) << kAutoIncrementBits));
  container.initialized_ = true;
  return container;
}

0x04 DistAutogradContext

DistAutogradContext stores the information of a single distributed autograd pass on a worker. It encapsulates forward and backward propagation in distributed autograd and accumulates gradients, which keeps the gradients of different passes and different workers from interfering with each other.

As you know, contextId_ is globally unique.

4.1 Definition

Only the member variables of DistAutogradContext are shown here; its member functions are omitted. The three main member variables are:

  • contextId_ is the context id.
  • sendAutogradFunctions_ is a map that collects the SendRpcBackward functions corresponding to all send requests.
  • recvAutogradFunctions_ is a map that collects the RecvRpcBackward functions corresponding to all receive requests.

We will analyze SendRpcBackward and RecvRpcBackward later, together with the engine.

// DistAutogradContext which stores information for a single distributed
// autograd pass on a worker.
class TORCH_API DistAutogradContext {
 private:
  friend class BackwardPassCleanupGuard;
  friend class DistEngine;
  friend class RecvRpcBackward;
  friend class DistAccumulateGradCaptureHook;

  const int64_t contextId_;

  // Set containing known worker IDs, used in cleaning up autograd context.
  // Whenever a sendRpcBackward is attached to the autograd graph for this
  // context, the destination is added here.
  std::unordered_set<rpc::worker_id_t> knownWorkerIds_;

  // Map from autograd_message_id to appropriate 'send' autograd function.
  std::unordered_map<int64_t, std::shared_ptr<SendRpcBackward>>
      sendAutogradFunctions_;

  // Map from autograd_message_id to appropriate 'recv' autograd function.
  std::unordered_map<int64_t, std::shared_ptr<RecvRpcBackward>>
      recvAutogradFunctions_;

  // Gradients accumulated in this context so far. The key is the variable on
  // which the gradient needs to be accumulated and the value is the gradient
  // that needs to be accumulated on that variable..
  c10::Dict<torch::Tensor, torch::Tensor> accumulatedGrads_;

  // See comments for recordGradEvent(c10::Device device);
  std::unordered_map<c10::Device, c10::Event> gradReadyEvents_;
  const c10::impl::VirtualGuardImpl impl_;

  // The autograd GraphTask for the backward pass on this node for this context.
  std::shared_ptr<torch::autograd::GraphTask> graphTask_;

  // List of futures for RPCs initiated by this node to propagate gradients to
  // other nodes. The distributed autograd engine on this node can return
  // successfully only if all these futures are done and are successful.
  std::vector<c10::intrusive_ptr<rpc::JitFuture>> outStandingRpcs_;

  // Lock to protect concurrent modification of the context.
  mutable std::mutex lock_;
};


4.2 Message types

Context interaction mainly involves the following message types:

// Messages with autograd info
FORWARD_AUTOGRAD_REQ = 0x0f | MessageTypeFlags::REQUEST_TYPE,
FORWARD_AUTOGRAD_RESP = 0x10 | MessageTypeFlags::RESPONSE_TYPE,

// Messages to propagate gradients on the backward pass.
BACKWARD_AUTOGRAD_REQ = 0x11 | MessageTypeFlags::REQUEST_TYPE,
BACKWARD_AUTOGRAD_RESP = 0x12 | MessageTypeFlags::RESPONSE_TYPE,

4.3 Construction

Let's first look at how a context is built.

4.3.1 getOrCreateContext

The getOrCreateContext function gets the context if it already exists, or builds a new one if it does not. This is a passive call; the recv side uses it.

ContextPtr DistAutogradContainer::getOrCreateContext(int64_t context_id) {
  auto& shard = getShard(context_id);
  std::lock_guard<std::mutex> guard(shard.lock);
  auto it = shard.contexts.find(context_id); // Use the context ID
  if (it != shard.contexts.end()) {
    return it->second; // Return if found
  }

  auto& context = // If not, build a context
      shard.contexts
          .emplace(
              std::piecewise_construct,
              std::forward_as_tuple(context_id),
              std::forward_as_tuple(
                  std::make_shared<DistAutogradContext>(context_id)))
          .first->second;
  return context;
}

4.3.2 newContext

This is an active call; the send side calls this method.

4.3.2.1 Python

When a distributed call is made, the Python world generates a context.

            with dist_autograd.context() as context_id:
                output = model(indices, offsets)
                loss = criterion(output, target)

                # Run distributed backward pass
                dist_autograd.backward(context_id, [loss])

                # Run distributed optimizer. Gradients propagated all the way to the parameter servers
                opt.step(context_id)

When generated, __enter__ will call _new_context() to generate a context in C++.

class context(object):
    '''
    Context object to wrap forward and backward passes when using
    distributed autograd. The ``context_id`` generated in the ``with``
    statement is required to uniquely identify a distributed backward pass
    on all workers. Each worker stores metadata associated with this
    ``context_id``, which is required to correctly execute a distributed
    autograd pass.

    Example::
        >>> import torch.distributed.autograd as dist_autograd
        >>> with dist_autograd.context() as context_id:
        >>>     t1 = torch.rand((3, 3), requires_grad=True)
        >>>     t2 = torch.rand((3, 3), requires_grad=True)
        >>>     loss = rpc.rpc_sync("worker1", torch.add, args=(t1, t2)).sum()
        >>>     dist_autograd.backward(context_id, [loss])
    '''
    def __enter__(self):
        self.autograd_context = _new_context()  # A context is generated here
        return self.autograd_context._context_id()

    def __exit__(self, type, value, traceback):
        _release_context(self.autograd_context._context_id())

Through the following pybind mapping, we can see that the corresponding method in the C++ world is DistAutogradContainer::getInstance().newContext().

  module.def(
      "_new_context",
      []() -> const ContextPtr {
        return DistAutogradContainer::getInstance().newContext();
      },
      py::return_value_policy::reference);
4.3.2.2 C++

Here we are in the C++ world. Each thread has an autograd_context_id.

constexpr int64_t kInvalidContextId = -1;

// Each thread has a single autograd_context_id valid at any point in time.
static thread_local int64_t current_context_id_ = kInvalidContextId;

newContext generates a DistAutogradContext; the id of the new context is produced by incrementing the container's next_context_id_ member variable.

const ContextPtr DistAutogradContainer::newContext() {

  auto context_id = next_context_id_++; // Increment
  current_context_id_ = context_id;  // Here the current_context_id_ for the local thread is set

  // Check for overflow into workerId_ section.
  TORCH_INTERNAL_ASSERT(context_id < max_id_);

  auto& shard = getShard(context_id);
  std::lock_guard<std::mutex> guard(shard.lock);
  auto& context =
      shard.contexts
          .emplace(
              std::piecewise_construct,
              std::forward_as_tuple(context_id),
              std::forward_as_tuple(
                  std::make_shared<DistAutogradContext>(context_id)))
          .first->second;

  return context;
}

4.4 How to Share a Context

Specifically, the context_id generated in the with statement uniquely identifies a distributed autograd pass (both forward and backward propagation) across all workers. Each worker stores the metadata associated with this context_id, which is required to execute the distributed autograd pass correctly.

Since the metadata associated with a context_id needs to be stored on multiple workers, a mechanism is needed to encapsulate, send, and receive this metadata between workers. The encapsulation mechanism is the AutogradMetadata discussed above. Let's look at how context metadata is sent and received.
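From the user's perspective this sharing is transparent. Below is a minimal sketch (assuming two workers named "worker0" and "worker1" that have already been connected via rpc.init_rpc) of a training step running on worker0; the FORWARD_AUTOGRAD_REQ sent by rpc_sync is what causes worker1 to create a DistAutogradContext with the same context id:

import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc

def train_step():
    # Runs on "worker0"; "worker1" is a peer set up elsewhere via rpc.init_rpc.
    with dist_autograd.context() as context_id:
        t1 = torch.rand((3, 3), requires_grad=True)
        t2 = torch.rand((3, 3), requires_grad=True)
        # The RPC carries AutogradMetadata(context_id, message_id); worker1 uses
        # the received context_id to build or look up its own DistAutogradContext.
        loss = rpc.rpc_sync("worker1", torch.add, args=(t1, t2)).sum()
        # The backward pass is addressed by the same context_id on every worker involved.
        dist_autograd.backward(context_id, [loss])
        return dist_autograd.get_gradients(context_id)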

4.4.1 The sender

When sending a message, getMessageWithAutograd uses autogradContainer.currentContext() to obtain the current context for sending.

Message getMessageWithAutograd(
    const rpc::worker_id_t dstId,
    torch::distributed::rpc::Message&& wrappedRpcMsg,
    MessageType msgType,
    bool forceGradRecording,
    const std::unordered_map<c10::Device, c10::Device>& deviceMap) {
  auto& autogradContainer = DistAutogradContainer::getInstance();

  // If there is no valid context and no tensor requires grads, send original
  // rpc message. otherwise, attach grad info and grad functions and send
  // rpcWithAutograd message.
  auto tensorsRequireGrad =
      torch::autograd::compute_requires_grad(wrappedRpcMsg.tensors());
  if (!autogradContainer.hasValidContext() ||
      (!forceGradRecording && !tensorsRequireGrad)) {
    return std::move(wrappedRpcMsg);
  }

  // Retrieve the appropriate context to modify.
  auto autogradContext = autogradContainer.currentContext(); // Get the current context

  // Wrap the original rpc with autograd information.
  AutogradMetadata autogradMetadata( // Use context IDS and message ids to build metadata
      autogradContext->contextId(), autogradContainer.newAutogradMessageId());
  auto rpcWithAutograd = std::make_unique<RpcWithAutograd>(
      RpcAgent::getCurrentRpcAgent()->getWorkerInfo().id_,
      msgType,
      autogradMetadata,
      std::move(wrappedRpcMsg),
      deviceMap);

  if (tensorsRequireGrad) {
    // Record autograd information for 'send'.
    addSendRpcBackward(
        autogradContext, autogradMetadata, rpcWithAutograd->tensors());
  }
  // Record the workerID
  autogradContext->addKnownWorkerId(dstId);

  return std::move(*rpcWithAutograd).toMessage();
}

Our previous diagram can now be extended to include context ids.

+---------------------------------------------------------------------------+
| worker
|
|    rank
|      |
|      |  init()
|      v
|    DistAutogradContainer
|      worker_id_
|      next_context_id_             ( --> contextId() )
|      next_autograd_message_id_    ( --> newAutogradMessageId() )
|      |
|      v
|    getMessageWithAutograd
|      AutogradMetadata autogradMetadata(contextId(), MessageId())
|
+---------------------------------------------------------------------------+

addSendRpcBackward is then called with the current context; the SendRpcBackward recorded there will be retrieved during later backpropagation.

void addSendRpcBackward(
    const ContextPtr& autogradContext,
    const AutogradMetadata& autogradMetadata,
    std::vector<torch::Tensor>& tensors) {
  // Attach autograd information only for tensors requiring grad.
  std::vector<torch::Tensor> tensors_with_grad;
  std::copy_if(
      tensors.begin(),
      tensors.end(),
      std::back_inserter(tensors_with_grad),
      [](const torch::Tensor& t) { return t.requires_grad(); });

  // Attach the appropriate autograd edges.
  auto grad_fn = std::make_shared<SendRpcBackward>();
  grad_fn->set_next_edges(
      torch::autograd::collect_next_edges(tensors_with_grad));

  // Add the appropriate input metadata for the grad_fn.
  for (const auto& tensor : tensors_with_grad) {
    grad_fn->add_input_metadata(tensor);
  }

  // Record the send autograd function in our current context.
  autogradContext->addSendFunction(grad_fn, autogradMetadata.autogradMessageId);
}

4.4.2 The receiver

In addRecvRpcBackward, a context is built based on the autogradMetadata.autogradContextId passed in.

ContextPtr addRecvRpcBackward(
    const AutogradMetadata& autogradMetadata,
    std::vector<torch::Tensor>& tensors,
    rpc::worker_id_t fromWorkerId,
    const std::unordered_map<c10::Device, c10::Device>& deviceMap) {
  // Initialize autograd context if necessary.
  auto& autogradContainer = DistAutogradContainer::getInstance();
  // Generate or get a context: pass in the sender's autogradContextId, that is, use the autogradContextId as the key to find the context
  auto autogradContext = 
      autogradContainer.getOrCreateContext(autogradMetadata.autogradContextId);

  if (!tensors.empty() && torch::autograd::compute_requires_grad(tensors)) {
    // Attach the tensors as inputs to the autograd function.
    auto grad_fn = std::make_shared<RecvRpcBackward>(
        autogradMetadata, autogradContext, fromWorkerId, deviceMap);
    for (auto& tensor : tensors) {
      if (tensor.requires_grad()) {
        torch::autograd::set_history(tensor, grad_fn);
      }
    }

    // Now update the autograd context with the necessary information.
    autogradContext->addRecvFunction(
        grad_fn, autogradMetadata.autogradMessageId);
  }

  return autogradContext;
}

In this way, the sender and receiver share a context, and the context ID is globally unique.

The logic is as follows: the upper part is the sending end and the lower part is the receiving end.

  • The sender:
    • AutogradMetadata(context_id, message_id) is constructed from the context id and a new message id.
    • A Message is built using AutogradMetadata.
    • The Message is sent via agent.send.
  • The receiver:
    • The Message is received.
    • AutogradMetadata is parsed from the Message.
    • context_id is extracted from AutogradMetadata.
    • A local DistAutogradContext is built (or looked up) with context_id.
  • The sender and receiver thus share a context whose id is globally unique.
+-----------------------------------------------------------------------------+
| Sender
|
|   sendMessageWithAutograd
|      |
|      |  addSendRpcBackward:
|      |    autogradMetadata = AutogradMetadata(context_id, message_id)
|      v
|   agent.send(message(autogradMetadata))
|      |
+------+----------------------------------------------------------------------+
       |
       |  message
       v
+------+----------------------------------------------------------------------+
| Receiver
|
|   processForwardAutogradReq
|      |
|      |  message.autogradMetadata
|      v
|   addRecvRpcBackward:
|      autogradContext = getOrCreateContext(autogradMetadata.autogradContextId)
|
+-----------------------------------------------------------------------------+

0x05 Forward propagation interaction

The sharing process above was brief, so let's take a closer look at the complete send/receive process.

5.1 Send

This corresponds to the following part of the design:

During forward propagation, we store the send and recv functions of each autograd pass in the context. This ensures we hold references to the appropriate nodes in the autograd graph to keep them alive. It also makes it easy to look up the corresponding send and recv functions during backward propagation.

5.1.1 Sending Logic

The code logic is as follows:

  • Generate a grad_fn of type SendRpcBackward.
  • Add next edges to SendRpcBackward by calling collect_next_edges and set_next_edges, functions we analyzed earlier in this series.
  • Call add_input_metadata to add input metadata.
  • Call addSendFunction to add grad_fn to the context.
void addSendRpcBackward(
    const ContextPtr& autogradContext,
    const AutogradMetadata& autogradMetadata,
    std::vector<torch::Tensor>& tensors) {
  // Attach autograd information only for tensors requiring grad.
  std::vector<torch::Tensor> tensors_with_grad;
  std::copy_if(
      tensors.begin(),
      tensors.end(),
      std::back_inserter(tensors_with_grad),
      [](const torch::Tensor& t) { return t.requires_grad(); });

  // Attach the appropriate autograd edges.
  auto grad_fn = std::make_shared<SendRpcBackward>();
  grad_fn->set_next_edges( // The output edge is set here
      torch::autograd::collect_next_edges(tensors_with_grad));

  // Add the appropriate input metadata for the grad_fn.
  for (const auto& tensor : tensors_with_grad) {
    grad_fn->add_input_metadata(tensor);
  }

  // Record the send autograd function in our current context.
  autogradContext->addSendFunction(grad_fn, autogradMetadata.autogradMessageId);
}

5.1.2 Setting the Context

Recall the DistAutogradContext definition once again; only some of its member variables are shown.

  • contextId_ is the context id.
  • sendAutogradFunctions_ is a map that collects the SendRpcBackward functions corresponding to all send requests.
  • recvAutogradFunctions_ is a map that collects the RecvRpcBackward functions corresponding to all receive requests.
// DistAutogradContext which stores information for a single distributed
// autograd pass on a worker.
class TORCH_API DistAutogradContext {

  const int64_t contextId_;

  // Map from autograd_message_id to appropriate 'send' autograd function.
  std::unordered_map<int64_t, std::shared_ptr<SendRpcBackward>>
      sendAutogradFunctions_;

  // Map from autograd_message_id to appropriate 'recv' autograd function.
  std::unordered_map<int64_t, std::shared_ptr<RecvRpcBackward>>
      recvAutogradFunctions_;
};

addSendFunction adds a SendRpcBackward to sendAutogradFunctions_, from which it can later be retrieved by message id.

void DistAutogradContext::addSendFunction(
    const std::shared_ptr<SendRpcBackward>& func,
    int64_t autograd_message_id) {

  std::lock_guard<std::mutex> guard(lock_);
  TORCH_INTERNAL_ASSERT(
      sendAutogradFunctions_.find(autograd_message_id) ==
      sendAutogradFunctions_.end());
  sendAutogradFunctions_.emplace(autograd_message_id, func);
}

Previously we looked at this from the perspective of building the context; now let's look at it from the perspective of the context's contents.

The sending logic is as follows:

+--------------------------------------------------------------+    +-------------------+
| worker                                                       |    |SendRpcBackward    |
| +---------------------------------------------------------+  |    |                   |
| | DistAutogradContext                                     |  |    |   input_metadata_ |
| |                                                 +-------------> |                   |
| |  contextId_ = context_id_1                      |       |  |    |   next_edges_     |
| |                                                 +       |  |    |                   |
| |  sendAutogradFunctions_ = [msg_id_1, SendRpcBackward_1] |  |    +-------------------+
| |                                                         |  |
| |                                                         |  |
| |  recvAutogradFunctions_                                 |  |
| |                                                         |  |
| +---------------------------------------------------------+  |
|                                                              |
+--------------------------------------------------------------+

                                                                                  sender
+---------------------------------------------------------------------------------------+


5.2 Receive

Let's skip the internal processing of agent.send and look at the business flow of FORWARD_AUTOGRAD_REQ.

5.2.1 Receiving messages on the receiver side

When the TensorPipeAgent is created, RequestCallbackImpl is configured as its callback. This is the agent's unified response function.

We mentioned the agent's receive logic earlier; execution enters the following function, where you can see the processForwardAutogradReq handling logic.

void RequestCallbackNoPython::processRpc(
    RpcCommandBase& rpc,
    const MessageType& messageType,
    const int64_t messageId,
    const c10::intrusive_ptr<JitFuture>& responseFuture,
    std::shared_ptr<LazyStreamContext> ctx) const {

    case MessageType::FORWARD_AUTOGRAD_REQ: {
      // Will come here
      processForwardAutogradReq(rpc, messageId, responseFuture, std::move(ctx));
      return;
    }
    case MessageType::BACKWARD_AUTOGRAD_REQ: {
      processBackwardAutogradReq(rpc, messageId, responseFuture);
      return;
    };  
  
}  

5.2.2 Processing messages

processForwardAutogradReq is responsible for handling the message; its processing logic is as follows:

  • Although this is a forward propagation request, the deviceMap is reversed here because the receiving end needs it for the later backward pass.
  • addRecvRpcBackward is called to attach the RPC message to the context.
  • The wrapped RPC may be a nested command, so processRpc needs to be called again.
  • Mark the original message as completed and perform the related operations.
void RequestCallbackNoPython::processForwardAutogradReq(
    RpcCommandBase& rpc,
    const int64_t messageId,
    const c10::intrusive_ptr<JitFuture>& responseFuture,
    std::shared_ptr<LazyStreamContext> ctx) const {
  
  auto& rpcWithAutograd = static_cast<RpcWithAutograd&>(rpc);

  // Need to reverse the device map for the backward pass of distributed
  // autograd.
  std::unordered_map<c10::Device, c10::Device> reverseDeviceMap;
  // Transpose deviceMap
  for (const auto& mapEntry : rpcWithAutograd.deviceMap()) {
    reverseDeviceMap.insert({mapEntry.second, mapEntry.first});
  }

  // Attach 'recv' autograd function.
  auto autogradContext = addRecvRpcBackward( // addRecvRpcBackward is called to add context
      rpcWithAutograd.autogradMetadata(),
      rpcWithAutograd.tensors(),
      rpcWithAutograd.fromWorkerId(),
      reverseDeviceMap);
  // For this recv thread on server side, before processRpc(),
  // set current_context_id_ to be context_id passed from client.
  // In this way, if there is nested rpc call in python rpc call, original
  // context_id from client can be passed in the chain calls.
  DistAutogradContextGuard ctxGuard(autogradContext->contextId());

  // Process the original RPC.
  auto wrappedMessageType = rpcWithAutograd.wrappedMessageType();

  // Make an overall future for the wrapped response.
  auto wrappedRpcResponseFuture =
      c10::make_intrusive<JitFuture>(at::AnyClassType::get());
  // Kick off processing for the nested RPC command.
  // wrappedRpcResponseFuture will be a Future<T> to the result.
  processRpc( // There may be the possibility of nested commands, so it needs to be processed again
      rpcWithAutograd.wrappedRpc(),
      wrappedMessageType,
      messageId,
      wrappedRpcResponseFuture,
      std::move(ctx));

  auto fromWorkerId = rpcWithAutograd.fromWorkerId();

  // The original future needs to be marked as completed when the wrapped
  // one completes, with the autograd context information wrapped.
  wrappedRpcResponseFuture->addCallback(
      [responseFuture,
       messageId,
       fromWorkerId,
       ctxId =
           autogradContext->contextId()](JitFuture& wrappedRpcResponseFuture) {
        // As this callback can be invoked by a different thread, we have to
        // make sure that the thread_local states in the previous thread is
        // correctly propagated.
        // NB: The execution of TorchScript functions can also run on a
        // different thread, which is addressed by
        // https://github.com/pytorch/pytorch/pull/36395
        // NB: when adding async UDF support, we should also propagate
        // thread_local states there.
        // TODO: Land on a general solution for RPC ThreadLocalState. See
        // https://github.com/pytorch/pytorch/issues/38510
        DistAutogradContextGuard cbCtxGuard(ctxId);

        if (wrappedRpcResponseFuture.hasError()) {
          // Propagate error to responseFuture if we had one.
          responseFuture->setError(wrappedRpcResponseFuture.exception_ptr());
        } else {
          auto msg = getMessageWithAutograd(
              fromWorkerId,
              std::move(
                  *wrappedRpcResponseFuture.value().toCustomClass<Message>()),
              MessageType::FORWARD_AUTOGRAD_RESP);
          msg.setId(messageId);
          responseFuture->markCompleted(
              IValue(c10::make_intrusive<Message>(std::move(msg))));
        }
      });
}

5.2.3 Context Interaction

In torch/csrc/distributed/autograd/utils.cpp, the addRecvRpcBackward function operates on the context.

This corresponds to the design:

During forward propagation, we store the send and recv functions of each autograd pass in the context. This ensures we hold references to the appropriate nodes in the autograd graph to keep them alive. It also makes it easy to look up the corresponding send and recv functions during backward propagation.

Its specific logic is as follows:

  • Get the local context according to the autogradContextId in the RPC message.
  • Generate a RecvRpcBackward.
  • Configure RecvRpcBackward with the tensors from the RPC message, including torch::autograd::set_history(tensor, grad_fn).
  • Call addRecvFunction to add RecvRpcBackward to the context.
ContextPtr addRecvRpcBackward(
    const AutogradMetadata& autogradMetadata,
    std::vector<torch::Tensor>& tensors,
    rpc::worker_id_t fromWorkerId,
    const std::unordered_map<c10::Device, c10::Device>& deviceMap) {
  // Initialize autograd context if necessary.
  auto& autogradContainer = DistAutogradContainer::getInstance();
  auto autogradContext =
      autogradContainer.getOrCreateContext(autogradMetadata.autogradContextId);

  if (!tensors.empty() && torch::autograd::compute_requires_grad(tensors)) {
    // Attach the tensors as inputs to the autograd function.
    auto grad_fn = std::make_shared<RecvRpcBackward>(
        autogradMetadata, autogradContext, fromWorkerId, deviceMap);
    for (auto& tensor : tensors) {
      if (tensor.requires_grad()) {
        torch::autograd::set_history(tensor, grad_fn);
      }
    }

    // Now update the autograd context with the necessary information.
    autogradContext->addRecvFunction(
        grad_fn, autogradMetadata.autogradMessageId);
  }

  return autogradContext;
}

addRecvFunction first checks that the operator does not already exist in recvAutogradFunctions_, then adds it.

void DistAutogradContext::addRecvFunction(
    std::shared_ptr<RecvRpcBackward>& func,
    int64_t autograd_message_id) {
  TORCH_INTERNAL_ASSERT(func != nullptr);
  std::lock_guard<std::mutex> guard(lock_);
  TORCH_INTERNAL_ASSERT(
      recvAutogradFunctions_.find(autograd_message_id) ==
      recvAutogradFunctions_.end());
  recvAutogradFunctions_.emplace(autograd_message_id, func);
}

At this point, the logic is extended as follows: both the sender and the receiver have a DistAutogradContext whose id is context_id_1.

Within each DistAutogradContext, msg_id_1 is used as the key: on the sender it maps to a SendRpcBackward, and on the receiver it maps to a RecvRpcBackward.

This corresponds to the design mentioned:

Each autograd pass is assigned a unique autograd_context_id, by which its DistAutogradContext is uniquely identified in the container. autograd_context_id is a 64-bit globally unique id: the first 16 bits are the worker_id, and the last 48 bits auto-increment within each worker. So a container can hold multiple contexts.

The container is also responsible for maintaining globally unique message ids, which are used to associate send/recv autograd function pairs. The format is similar to autograd_context_id: a 64-bit integer whose first 16 bits are the worker id and whose last 48 bits auto-increment within the worker.

+------------------------------------------------------------------+
| worker (Sender)
|
|   DistAutogradContext
|     contextId_             = context_id_1
|     sendAutogradFunctions_ = [msg_id_1 : SendRpcBackward_1]
|     recvAutogradFunctions_
|
|   SendRpcBackward_1
|     input_metadata_
|     next_edges_
+------------------------------------------------------------------+
                               |
                               v
+------------------------------------------------------------------+
| worker (Receiver)
|
|   DistAutogradContext
|     contextId_             = context_id_1
|     sendAutogradFunctions_
|     recvAutogradFunctions_ = [msg_id_1 : RecvRpcBackward_1]
|
|   RecvRpcBackward_1
|     input_metadata_
|     next_edges_
+------------------------------------------------------------------+

Let’s add Container and expand the current logic as follows:

  • Each worker includes a DistAutogradContainer.
  • Each DistAutogradContainer contains several DistAutogradContexts; a DistAutogradContext is extracted by its context id.
  • Each DistAutogradContext includes sendAutogradFunctions_ and recvAutogradFunctions_; a SendRpcBackward or RecvRpcBackward is obtained by its message id.

In this way, the backward propagation chain is constructed.

+---------------------------------------------------------------------------+
| worker (Sender)
|
|   DistAutogradContainer
|     worker_id_
|     next_context_id_
|     next_autograd_message_id_
|     autograd_contexts_[ctx_id_1 : ctx]
|        |
|        v
|   DistAutogradContext
|     contextId_             = ctx_id_1
|     sendAutogradFunctions_ = [msg_id_1 : SendRpcBackward_1]
|     recvAutogradFunctions_
|
|   SendRpcBackward_1
|     input_metadata_
|     next_edges_
+---------------------------------------------------------------------------+
                               |
                               v
+---------------------------------------------------------------------------+
| worker (Receiver)
|
|   DistAutogradContainer
|     worker_id_
|     next_context_id_
|     next_autograd_message_id_
|     autograd_contexts_[ctx_id_1 : ctx]
|        |
|        v
|   DistAutogradContext
|     contextId_             = ctx_id_1
|     sendAutogradFunctions_
|     recvAutogradFunctions_ = [msg_id_1 : RecvRpcBackward_1]
|
|   RecvRpcBackward_1
|     input_metadata_
|     next_edges_
+---------------------------------------------------------------------------+


Now that we have looked at the context-related classes, let's put together what we have analyzed so far and take a systematic look at the business logic.

0xEE Personal information

★★★★ Thoughts on life and technology ★★★★★

Wechat official account: Rosie’s Thoughts

0xFF Reference