0x00 Summary

Previously, we presented the design of distributed autograd. Starting with this article, we turn to concrete source-code analysis. Since both forward propagation and backward propagation rely on RPC, let's first look at some of the basic functionality encapsulated on top of RPC, such as initialization, the agent (RPC functionality is built on agents), message sending and receiving, and so on.

Through this article, you can learn: how the RPC backend is initialized, how the RPC agent is generated, how the RPC agent sends and receives messages, and how this connects to the remote dist.autograd automatic differentiation engine.

Other articles in this PyTorch distributed series are as follows:

PyTorch distributed (1)—— History and Overview

PyTorch how to use GPU

PyTorch distributed (2) —— DataParallel (part 1)

PyTorch distributed (3) —— DataParallel (part 2)

PyTorch distributed (4) —— Distributed application concepts

PyTorch distributed (5) —— DistributedDataParallel — overview & how to use

PyTorch distributed (6) —— DistributedDataParallel — initialization & store

PyTorch distributed (7) —— DistributedDataParallel — process groups

PyTorch distributed (8) —— DistributedDataParallel — the paper

PyTorch distributed (9) —— DistributedDataParallel — initialization

PyTorch distributed (10) —— DistributedDataParallel — Reducer static architecture

PyTorch distributed (11) —— DistributedDataParallel — building Reducer and the Join operation

PyTorch distributed (12) —— DistributedDataParallel — forward propagation

PyTorch distributed (13) —— DistributedDataParallel — backward propagation

PyTorch distributed Autograd (1) —- design

For better illustration, the code in this article will be streamlined accordingly.

0x01 Example

We took some code from the PyTorch examples and modified it so that two workers collaborate via RPC. The worker logic in the example is divided into two stages:

  • RPC operations, building the dependency base.
  • Performing backward propagation.
import torch
import torch.distributed.rpc as rpc
import torch.distributed.autograd as dist_autograd
from torch.distributed.rpc import TensorPipeRpcBackendOptions

def my_add(t1, t2):
    return torch.add(t1, t2)

def worker0() :
    # On worker 0:

    # Setup the autograd context. Computations that take
    # part in the distributed backward pass must be within
    # the distributed autograd context manager.
    with dist_autograd.context() as context_id:
      t1 = torch.rand((3, 3), requires_grad=True)
      t2 = torch.rand((3, 3), requires_grad=True)

      # Stage 1: RPC operations, building the dependency base
      
      # Perform some computation remotely.
      t3 = rpc.rpc_sync("worker1", my_add, args=(t1, t2))

      # Perform some computation locally based on remote result.
      t4 = torch.rand((3, 3), requires_grad=True)
      t5 = torch.mul(t3, t4)

      # Compute some loss.
      loss = t5.sum()

      # Stage 2: perform backward propagation
      
      # Run the backward pass.
      dist_autograd.backward(context_id, [loss])

      # Retrieve the gradients from the context.
      dist_autograd.get_gradients(context_id)

      print(loss)  
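
As a side note, dist_autograd.get_gradients returns a dict mapping each participating Tensor to the gradient accumulated for it in the given context. A minimal sketch of inspecting the result (the grads variable is illustrative, not part of the original example):

grads = dist_autograd.get_gradients(context_id)
# grads maps each leaf tensor to the gradient of loss w.r.t. it in this context,
# e.g. grads[t1] and grads[t2].
print(grads[t1])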

Both workers are started by calling rpc.init_rpc to initialize RPC. worker0 starts and then performs some operations on worker1 via RPC.

def run_worker(rank, world_size):
    r""" A wrapper function that initializes RPC, calls the function, and shuts down RPC. """

    # We need to use different port numbers in TCP init_method for init_rpc and
    # init_process_group to avoid port conflicts.
    rpc_backend_options = TensorPipeRpcBackendOptions()
    rpc_backend_options.init_method = "tcp://localhost:29501"

    # Rank 0 and 1 are trainers.
    if rank == 0:
        rpc.init_rpc(
            "worker0",
            rank=rank,
            world_size=world_size,
            rpc_backend_options=rpc_backend_options,
        )
        worker0()

    elif rank == 1:
        rpc.init_rpc(
            "worker1",
            rank=rank,
            world_size=world_size,
            rpc_backend_options=rpc_backend_options,
        )

    # block until all rpcs finish
    rpc.shutdown()
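
For completeness, the two processes can be launched with torch.multiprocessing; this launcher is a sketch following the standard PyTorch tutorial pattern and is not part of the original text:

import torch.multiprocessing as mp

if __name__ == "__main__":
    world_size = 2
    # Spawn one process per rank; each runs run_worker(rank, world_size).
    mp.spawn(run_worker, args=(world_size,), nprocs=world_size, join=True)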

0x02 RPC Basics

2.1 Initialization

Looking back at the sample code: when the script starts, rpc.init_rpc is called to initialize RPC. Two concepts appear in the RPC comments, the familiar rank and world_size.

rank (int): a globally unique id/rank of this node.
world_size (int): The number of workers in the group.

The specific initialization code is:

def init_rpc(
    name,
    backend=None,
    rank=-1,
    world_size=None,
    rpc_backend_options=None,
):
    dist_autograd._init(rank)  # We will discuss the distributed autograd engine later
    _set_profiler_node_id(rank)
    # Initialize RPC.
    _init_rpc_backend(backend, store, name, rank, world_size, rpc_backend_options)

_init_rpc_backend sets the backend.

2.1.1 Initializing the Backend

_init_rpc_backend generates the agent based on the configuration and sets the agent into the current context. RPC has two backend types, TENSORPIPE and PROCESS_GROUP; PROCESS_GROUP has been deprecated and is gradually being migrated to TENSORPIPE.

def _init_rpc_backend(
    backend=BackendType.TENSORPIPE,  # The default backend is TENSORPIPE
    store=None,
    name=None,
    rank=-1,
    world_size=-1,
    rpc_backend_options=None,
):

    _validate_rpc_args(backend, store, name, rank, world_size, rpc_backend_options)

    if _is_current_rpc_agent_set():
        raise RuntimeError("RPC is already initialized")

    # Initialize RPC.
    rpc_agent = backend_registry.init_backend(  # Create an agent
        backend,
        store=store,
        name=name,
        rank=rank,
        world_size=world_size,
        rpc_backend_options=rpc_backend_options,
    )

    api._init_rpc_states(rpc_agent)  # Set the agent into the current context

As you can see, TensorPipeAgent is generated by default.
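
The backend can also be selected explicitly when calling rpc.init_rpc; a minimal sketch (the address and sizes are illustrative):

import torch.distributed.rpc as rpc

rpc.init_rpc(
    "worker0",
    rank=0,
    world_size=2,
    backend=rpc.BackendType.TENSORPIPE,  # the default; PROCESS_GROUP is deprecated
    rpc_backend_options=rpc.TensorPipeRpcBackendOptions(
        init_method="tcp://localhost:29501",
    ),
)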

2.1.2 Generating the Agent

We then look at how the TensorPipeAgent is generated; the concrete code is in torch/csrc/distributed/rpc/init.cpp. When the TensorPipeAgent is generated, RequestCallbackImpl is configured as its callback function. The agent uses this callback internally to process received requests.

shared_ptr_class_<TensorPipeAgent>(module, "TensorPipeAgent", rpcAgent)
    .def(
        py::init([] (const c10::intrusive_ptr<::c10d::Store>& store,
                    std::string selfName,
                    worker_id_t selfId,
                    int worldSize,
                    c10::intrusive_ptr<::c10d::ProcessGroup> processGroup,
                    TensorPipeRpcBackendOptions opts) {
          return std::shared_ptr<TensorPipeAgent>(
              new TensorPipeAgent(
                  store,
                  std::move(selfName),
                  selfId,
                  worldSize,
                  std::move(processGroup),
                  std::move(opts),
                  std::make_unique<RequestCallbackImpl>()), // RequestCallbackImpl is configured on the Agent
              impl::destroy_without_gil<TensorPipeAgent>);
        })

Details are as follows:

+-----------------+        +-----------------------+
| TensorPipeAgent |        | RequestCallbackImpl   |
|                 |        |                       |
|         cb_ +----------> |                       |
|                 |        |                       |
+-----------------+        +-----------------------+

2.1.3 Setting the Agent

_init_rpc_states sets the agent into the PyTorch environment; it is defined in torch/distributed/rpc/api.py.

def _init_rpc_states(agent):
    worker_infos = agent.get_worker_infos()
    global _ALL_WORKER_NAMES
    _ALL_WORKER_NAMES = {worker_info.name for worker_info in worker_infos}

    # NB: backend implementation might have already set the rpc_agent.
    if not _is_current_rpc_agent_set():
        _set_and_start_rpc_agent(agent)

The next step enters the C++ world. _set_and_start_rpc_agent is located in torch/csrc/distributed/rpc/init.cpp, and its role is to:

  • Call RpcAgent::setCurrentRpcAgent to set the agent.
  • Call rpcAgent->start() to start the agent.
module.def(
    "_set_and_start_rpc_agent"[] (const std::shared_ptr<RpcAgent>& rpcAgent) {
        
      RpcAgent::setCurrentRpcAgent(rpcAgent); // Set the Agent
        
      // Initializing typeResolver inside RpcAgent constructor will make
      // RpcAgent have python dependency. To avoid RpcAgent to have python
      // dependency, setTypeResolver() here.
        
      std::shared_ptr<TypeResolver> typeResolver =
          std::make_shared<TypeResolver>([&](const c10::QualifiedName& qn) {
            auto typePtr = PythonRpcHandler::getInstance().parseTypeFromStr(
                qn.qualifiedName());
            return c10::StrongTypePtr(
                PythonRpcHandler::getInstance().jitCompilationUnit(),
                std::move(typePtr));
          });
      rpcAgent->setTypeResolver(typeResolver);
      rpcAgent->start(); // Start the agent
    },
    py::call_guard<py::gil_scoped_release>());

setCurrentRpcAgent is defined in torch/csrc/distributed/rpc/rpc_agent.cpp.

2.1.4 Static Class Variables

In RpcAgent, there is a static member variable currentRpcAgent_.

class TORCH_API RpcAgent {
     // We omit other member variables and functions
     private:
      static std::shared_ptr<RpcAgent> currentRpcAgent_;
}

In C++, static member variables have the following characteristics:

  • It belongs to the entire class rather than to any single object.
  • Its lifetime does not depend on any object; it lasts for the whole program.
  • A public static member variable can be accessed directly through the class name.
  • A public static member variable can also be accessed through an object name.
  • All objects of the class (and its derived classes) share the class's static member variables.
  • A static member variable must be defined (allocated) separately outside the class.
  • Static member variables reside in the program's global data area.

So RpcAgent::currentRpcAgent_ can be thought of as a global variable that RPC uses for unified coordination. It is manipulated through several public member functions of RpcAgent.

std::shared_ptr<RpcAgent> RpcAgent::currentRpcAgent_ = nullptr;

bool RpcAgent::isCurrentRpcAgentSet() {
  return std::atomic_load(&currentRpcAgent_) != nullptr;
}

std::shared_ptr<RpcAgent> RpcAgent::getCurrentRpcAgent() {
  std::shared_ptr<RpcAgent> agent = std::atomic_load(&currentRpcAgent_);
  return agent;
}

void RpcAgent::setCurrentRpcAgent(std::shared_ptr<RpcAgent> rpcAgent) {
  if (rpcAgent) {
    std::shared_ptr<RpcAgent> previousAgent;
    // Use compare_exchange so that we don't actually perform the exchange if
    // that would trigger the assert just below. See:
    // https://en.cppreference.com/w/cpp/atomic/atomic_compare_exchange
    std::atomic_compare_exchange_strong(
        &currentRpcAgent_, &previousAgent, std::move(rpcAgent));
  } else {
    // We can't use compare_exchange (we don't know what value to expect) but we
    // don't need to, as the only case that would trigger the assert is if we
    // replaced nullptr with nullptr, which we can just do as it has no effect.
    std::shared_ptr<RpcAgent> previousAgent =
        std::atomic_exchange(&currentRpcAgent_, std::move(rpcAgent));
  }
}

The picture now expands as follows; subsequent RPC operations all go through the global variable RpcAgent::currentRpcAgent_.

RpcAgent::currentRpcAgent_
      +
      |
      |
      |
      v
+-----+-----------+        +-----------------------+
| TensorPipeAgent |        | RequestCallbackImpl   |
|                 |        |                       |
|         cb_ +----------> |                       |
|                 |        |                       |
+-----------------+        +-----------------------+

2.2 RPC Agent

The functionality of dist.autograd is built on top of the RPC agent, so we need to take a closer look at the agent.

2.2.1 RpcAgent

RpcAgent is the base class for sending and receiving RPC messages. It:

  • provides the send API to handle requests and responses.
  • is configured with cb_ to handle incoming requests.

WorkerInfo is the globally unique identifier of the worker that owns the agent instance. It contains the member variables name_ and id_: name_ is a globally unique name and id_ is a globally unique ID.

class TORCH_API RpcAgent {
 public:
  RpcAgent(
      WorkerInfo id,
      std::unique_ptr<RequestCallback> cb,
      std::chrono::milliseconds rpcTimeout);
  
  // Send a message to the RpcAgent represented by "to", returning a JitFuture. The implementation is asynchronous.
  virtual c10::intrusive_ptr<JitFuture> send(
      const WorkerInfo& to,
      Message&& message,
      const float rpcTimeoutSeconds = kUnsetRpcTimeout,
      const std::unordered_map<c10::Device, c10::Device>& deviceMap = {}) = 0;

 protected:
  const WorkerInfo workerInfo_; // Globally unique identifier for the proxy instance
  const std::unique_ptr<RequestCallback> cb_; // The callback function
  std::atomic<std::chrono::milliseconds> rpcTimeout_;
  std::atomic<bool> profilingEnabled_;
  std::shared_ptr<TypeResolver> typeResolver_;
  std::atomic<bool> rpcAgentRunning_;

 private:
  static std::shared_ptr<RpcAgent> currentRpcAgent_; // Global proxy
  // Add GIL wait time data point to metrics
  virtual void addGilWaitTime(const std::chrono::microseconds gilWaitTime) = 0;
  friend class PythonRpcHandler;
  // Condition Variable to signal when the rpcRetryMap_ has been populated.
  std::condition_variable rpcRetryMapCV_;
  // Mutex to protect RpcRetryMap_.
  std::mutex rpcRetryMutex_;
};
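
At the Python level, the agent's asynchronous send surfaces as futures, and WorkerInfo can be looked up by name. A small usage sketch, assuming RPC is already initialized and a peer named worker1 exists:

# WorkerInfo lookup; name_ and id_ surface as info.name and info.id.
info = rpc.get_worker_info("worker1")
print(info.name, info.id)

# rpc_async returns a future immediately; wait() blocks until the response arrives.
fut = rpc.rpc_async("worker1", torch.add, args=(torch.ones(2), torch.ones(2)))
print(fut.wait())  # tensor([2., 2.])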

2.2.2 ProcessGroupAgent

ProcessGroupAgent is a derived class of RpcAgent. It was used previously; PyTorch now provides the better TensorPipeAgent. We list only some of its member variables.

class TORCH_API ProcessGroupAgent : public RpcAgent {
 public:

  c10::intrusive_ptr<::c10d::ProcessGroup> pg_;
  // worker name -> rank
  std::unordered_map<std::string, worker_id_t> nameMap_;
  std::vector<WorkerInfo> allWorkerInfo_;

  MessageCounter sendCounts_;
  MessageCounter recvCounts_;

  std::atomic<int64_t> nextId_;

  std::thread listenerThread_;
  std::thread futureTimeoutThread_;
  c10::intrusive_ptr<c10d::ProcessGroup::Work> recvWork_;

  std::unordered_map<
      worker_id_t,
      std::set<c10::intrusive_ptr<c10d::ProcessGroup::Work>>>
      currentPendingSends_;

  ThreadPool threadPool_;

  // Mapping of request id to FutureInfo struct.
  std::unordered_map<int64_t, FutureInfo> futures_;
};

2.2.3 TensorPipeAgent

TensorPipeAgent is defined in torch/csrc/distributed/rpc/tensorpipe_agent.h; it is what is used now and going forward. TensorPipeAgent leverages TensorPipe to transparently move tensors and payloads across the fastest available transport or channel. It acts like a hybrid RPC transport, providing shared memory (Linux) and TCP (Linux & Mac) support. PyTorch is developing a CUDA-enabled version.

We only selected some of the member variables.

// TensorPipeAgent leverages TensorPipe (https://github.com/pytorch/tensorpipe)
// to transparently move tensors and payloads through the fastest available
// transport or channel. It acts like a hybrid RPC transport, providing shared
// memory (linux) and TCP (linux & mac) support. CUDA support is in progress.
class TensorPipeAgent : public RpcAgent {
 public:
  TensorPipeAgent(
      const c10::intrusive_ptr<::c10d::Store>& store,
      std::string selfName,
      worker_id_t selfId,
      int worldSize,
      c10::intrusive_ptr<::c10d::ProcessGroup> processGroup,
      TensorPipeRpcBackendOptions opts,
      std::unique_ptr<RequestCallback> cb);

  const TensorPipeRpcBackendOptions opts_;
  std::unordered_map<std::string, DeviceMap> reverseDeviceMaps_;
  std::vector<c10::Device> devices_;

  ThreadPool threadPool_;
  std::shared_ptr<tensorpipe::Context> context_;
  std::shared_ptr<tensorpipe::Listener> listener_;

  mutable std::mutex connectedPipesMutex_;
  std::unordered_map<worker_id_t, ClientPipe> connectedPipes_;

  // Maps keyed on name and id for easy WorkerInfo lookup.
  std::unordered_map<worker_id_t, WorkerInfo> workerIdToInfo_;
  std::unordered_map<std::string, WorkerInfo> workerNameToInfo_;
  std::unordered_map<std::string, std::string> workerNameToURL_;

  ::c10d::PrefixStore rankToNameStore_;
  ::c10d::PrefixStore nameToAddressStore_;
  const int worldSize_;

  // The join method is required to behave like a barrier and perform collective
  // operations. For simplicity and reliability, we offload this to a process
  // group, but probably one day we might want to re-implement them using RPCs.
  const c10::intrusive_ptr<::c10d::ProcessGroup> processGroup_;

  std::atomic<uint64_t> nextMessageID_{0};

  // Thread that will poll the timeoutMap_ for timed out messages and mark them
  // with an error accordingly
  std::thread timeoutThread_;

  // Function run by the timeoutThread_ to check for timed out RPCs
  void pollTimeoutRpcs();
};
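
On the user side, TensorPipeAgent is configured through TensorPipeRpcBackendOptions; a small sketch with illustrative values:

opts = rpc.TensorPipeRpcBackendOptions(
    num_worker_threads=8,                 # size of the agent's thread pool
    rpc_timeout=20,                       # default RPC timeout, in seconds
    init_method="tcp://localhost:29501",
)
rpc.init_rpc("worker0", rank=0, world_size=2, rpc_backend_options=opts)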

2.2.4 Callback functions

The agent invokes the callback function when it receives a message; RequestCallbackImpl implements the callback logic. RequestCallbackImpl is a derived class: above it sits the base class RequestCallbackNoPython, and above that the interface RequestCallback, so RequestCallback is the root of this hierarchy.

class TORCH_API RequestCallbackNoPython : public RequestCallback
  
class TORCH_API RequestCallbackImpl : public RequestCallbackNoPython   
2.2.4.1 RequestCallback

RequestCallback is an abstract class that processes RPC messages.

// Functor which is invoked to process an RPC message. This is an abstract class
// with some common functionality across all request handlers. Users need to
// implement this interface to perform the actual business logic.
class TORCH_API RequestCallback {
 public:
  // Invoke the callback.
  c10::intrusive_ptr<JitFuture> operator()(
      Message& request,
      std::shared_ptr<LazyStreamContext> ctx) const;

  // NOLINTNEXTLINE(modernize-use-equals-default)
  virtual ~RequestCallback() {}

 protected:
  // RpcAgent implementation should invoke ``RequestCallback`` to process
  // received requests. There is no restriction on the implementation's
  // threading model. This function takes an rvalue reference of the Message
  // object. It is expected to return the future to a response message or
  // message containing an exception. Different rpc agent implementations are
  // expected to ensure delivery of the response/exception based on their
  // implementation specific mechanisms.
  virtual c10::intrusive_ptr<JitFuture> processMessage(
      Message& request,
      std::shared_ptr<LazyStreamContext> ctx) const = 0;
};
2.2.4.2 RequestCallbackNoPython

RequestCallbackNoPython is defined in torch/csrc/distributed/rpc/request_callback_no_python.h and implements part of the processing mechanism. Because it contains so many methods, we extract only a part; interested readers can dig further.

// RequestCallback implementation with no Python dependencies.
class TORCH_API RequestCallbackNoPython : public RequestCallback {
 public:
  c10::intrusive_ptr<JitFuture> processMessage(
      Message& request,
      std::shared_ptr<LazyStreamContext> ctx) const override;

 protected:

  void processForwardAutogradReq(
      RpcCommandBase& rpc,
      const int64_t messageId,
      const c10::intrusive_ptr<JitFuture>& responseFuture,
      std::shared_ptr<LazyStreamContext> ctx) const;

  void processBackwardAutogradReq(
      RpcCommandBase& rpc,
      const int64_t messageId,
      const c10::intrusive_ptr<JitFuture>& responseFuture) const;

  void processRpc(
      RpcCommandBase& rpc,
      const MessageType& messageType,
      const int64_t messageId,
      const c10::intrusive_ptr<JitFuture>& responseFuture,
      std::shared_ptr<LazyStreamContext> ctx) const;

  virtual void processRpcWithErrors(
      RpcCommandBase& rpc,
      const MessageType& messageType,
      const int64_t messageId,
      const c10::intrusive_ptr<JitFuture>& responseFuture,
      std::shared_ptr<LazyStreamContext> ctx) const;

  virtual void processRRefBackward(
      RpcCommandBase& rpc,
      const int64_t messageId,
      const c10::intrusive_ptr<JitFuture>& responseFuture) const;
};

We will see how the callback function is called later when we examine the receiving logic.

0x03 Send Logic

Let's look at the send logic first, i.e., what rpc.rpc_sync accomplishes: creating the root, adding a send, and so on.

3.1 Python

Let’s start with the Python section.

# Perform some computation remotely.
t3 = rpc.rpc_sync("worker1", my_add, args=(t1, t2))

First we enter rpc_sync, which invokes _invoke_rpc.

@_require_initialized
def rpc_sync(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT):
    fut = _invoke_rpc(to, func, RPCExecMode.SYNC, args, kwargs, timeout)
    return fut.wait()

Next comes _invoke_rpc; you can see that the function selects a different path depending on the call type (built-in operation, script, or UDF).

def _invoke_rpc(to, func, rpc_type, args=None, kwargs=None, rpc_timeout=UNSET_RPC_TIMEOUT):
    qualified_name = torch.jit._builtins._find_builtin(func)
    dst_worker_info = _to_worker_info(to)
    should_profile = torch.autograd._profiler_enabled()
    ctx_manager = _enable_rpc_profiler(should_profile, qualified_name, func, rpc_type, dst_worker_info)

    with ctx_manager as rf:
        args = args if args else ()
        kwargs = kwargs if kwargs else {}

        is_async_exec = hasattr(func, "_wrapped_async_rpc_function")

        if is_async_exec:
            wrapped = func._wrapped_async_rpc_function
            if isinstance(wrapped, torch.jit.ScriptFunction):
                func = wrapped

        if qualified_name is not None:
            fut = _invoke_rpc_builtin( # built-in RPC
                dst_worker_info,
                qualified_name,
                rpc_timeout,
                *args,
                **kwargs
            )
        elif isinstance(func, torch.jit.ScriptFunction): # script
            fut = _invoke_rpc_torchscript( 
                dst_worker_info.name,
                torch._jit_internal._qualified_name(func),
                args,
                kwargs,
                rpc_timeout,
                is_async_exec
            )
        else:
            (pickled_python_udf, tensors) = _default_pickler.serialize(
                PythonUDF(func, args, kwargs)
            )
            fut = _invoke_rpc_python_udf( # the user udf
                dst_worker_info,
                pickled_python_udf,
                tensors,
                rpc_timeout,
                is_async_exec
            )
        if should_profile:
            fut = rf._call_end_callbacks_on_future(fut)
    return fut
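
From the caller's point of view, the three dispatch paths look like this; a sketch assuming RPC is initialized and a peer named worker1 exists:

# 1) Built-in operator: found by torch.jit._builtins._find_builtin, goes through _invoke_rpc_builtin.
rpc.rpc_sync("worker1", torch.add, args=(torch.ones(2), torch.ones(2)))

# 2) TorchScript function: goes through _invoke_rpc_torchscript.
@torch.jit.script
def script_add(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return a + b

rpc.rpc_sync("worker1", script_add, args=(torch.ones(2), torch.ones(2)))

# 3) Plain Python function (UDF): pickled and sent through _invoke_rpc_python_udf.
def udf_add(a, b):
    return a + b

rpc.rpc_sync("worker1", udf_add, args=(torch.ones(2), torch.ones(2)))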

From here we enter the C++ world, in torch/csrc/distributed/rpc/init.cpp.

3.2 C++

You can see that _invoke_rpc_builtin corresponds to pyRpcBuiltin, and _invoke_rpc_python_udf corresponds to pyRpcPythonUdf.

PyObject* rpc_init(PyObject* _unused, PyObject* noargs) {
  module.def(
      "_invoke_rpc_builtin",
      [](const WorkerInfo& dst,
         const std::string& opName,
         const float rpcTimeoutSeconds,
         const py::args& args,
         const py::kwargs& kwargs) {
        return std::make_shared<jit::PythonFutureWrapper>(
            pyRpcBuiltin(dst, opName, args, kwargs, rpcTimeoutSeconds)); // Built-in functions
      },
      py::call_guard<py::gil_scoped_acquire>());

  module.def(
      "_invoke_rpc_python_udf",
      [](const WorkerInfo& dst,
         std::string& pickledPythonUDF,
         std::vector<torch::Tensor>& tensors,
         const float rpcTimeoutSeconds,
         const bool isAsyncExecution) {
        return std::make_shared<jit::PythonFutureWrapper>(pyRpcPythonUdf(
            dst,
            pickledPythonUDF, // The pickled UDF
            tensors,
            rpcTimeoutSeconds,
            isAsyncExecution));
      },
      py::call_guard<py::gil_scoped_release>());  
  
  // Omit others
}

As an example, let's follow _invoke_rpc_builtin into pyRpcBuiltin.

3.2.1 pyRpcBuiltin

In torch/csrc/distributed/rpc/python_functions.cpp, you can see that pyRpcBuiltin calls sendMessageWithAutograd.

c10::intrusive_ptr<JitFuture> pyRpcBuiltin(
    const WorkerInfo& dst,
    const std::string& opName,
    const py::args& args,
    const py::kwargs& kwargs,
    const float rpcTimeoutSeconds) {
  DCHECK(PyGILState_Check());
  Stack stack;
  auto op = matchBuiltinOp(opName, args, kwargs, stack);
  // Release GIL since args and kwargs processing is done.
  py::gil_scoped_release release;
  auto scriptCall = std::make_unique<ScriptCall>(op, std::move(stack));
  auto agent = RpcAgent::getCurrentRpcAgent(); // Get the current agent
  return toPyJitFuture(sendMessageWithAutograd( // Send the request
      *agent,
      dst,
      std::move(*scriptCall).toMessage(),
      false,
      rpcTimeoutSeconds));
}

3.2.2 sendMessageWithAutograd

In torch/csrc/distributed/autograd/utils.cpp, sendMessageWithAutograd uses the agent to send FORWARD_AUTOGRAD_REQ.

Later, on the receiver side, we will see FORWARD_AUTOGRAD_REQ being processed, so sending and receiving can be roughly linked together.

c10::intrusive_ptr<JitFuture> sendMessageWithAutograd(
    RpcAgent& agent,
    const WorkerInfo& dst,
    torch::distributed::rpc::Message&& wrappedRpcMsg,
    bool forceGradRecording,
    const float rpcTimeoutSeconds,
    bool forceDisableProfiling) {
  auto msg = getMessageWithAutograd( // This will interact with the context and build FORWARD_AUTOGRAD_REQ
      dst.id_,
      std::move(wrappedRpcMsg),
      MessageType::FORWARD_AUTOGRAD_REQ,
      forceGradRecording,
      agent.getDeviceMap(dst));

  c10::intrusive_ptr<JitFuture> fut;
  // If profiler is enabled, wrap this message with profiling metadata that will
  // tell the remote end to process this request with the profiler enabled.
  if (!forceDisableProfiling && torch::autograd::profiler::profilerEnabled()) {
    auto profilerConfig = torch::autograd::profiler::getProfilerConfig();
    auto msgWithProfiling = getMessageWithProfiling(
        std::move(msg),
        rpc::MessageType::RUN_WITH_PROFILING_REQ, // Build the message
        std::move(profilerConfig));
    // Send a message
    fut = agent.send(dst, std::move(msgWithProfiling), rpcTimeoutSeconds);
  } else {
    fut = agent.send(dst, std::move(msg), rpcTimeoutSeconds);
  }

  return fut;
}

The send process is as follows: sendMessageWithAutograd uses RpcAgent::getCurrentRpcAgent() to obtain RpcAgent::currentRpcAgent_, i.e., the globally configured agent, and then sends the message through that agent.

 rpc.rpc_sync
      +
      |
      |
      v
 _invoke_rpc_builtin
      +
      |                     Python
+---------------------------------------------------------------+
      |                     C++
      |
      v
 pyRpcBuiltin
      +
      |
      |
      v
 sendMessageWithAutograd(RpcAgent::getCurrentRpcAgent())
      +
      |
      |    RpcAgent::currentRpcAgent_
      |          +
      |          |
      |          |
      |          v
      |    +-----+-----------+
      |    | TensorPipeAgent |     +-----------------------+
      |    |                 |     | RequestCallbackImpl   |
      |    |        cb_ +------------>                     |
      |    |                 |     +-----------------------+
      |    |                 |
      +----------> send +-----------> Will send message to other worker
           |                 |
           +-----------------+

0x04 Receive Logic

4.1 Callback

When the agent receives a message, it calls RequestCallback::operator(); this is the callback function we discussed earlier. The code is located in torch/csrc/distributed/rpc/tensorpipe_agent.cpp.

void TensorPipeAgent::respond(std::shared_ptr<tensorpipe::Pipe>& pipe) {
  pipeRead(
      pipe,
      [this, pipe](
          const tensorpipe::Error& error,
          Message&& requestMessage,
          std::shared_ptr<LazyStreamContext> ctx) mutable {

        // Arm for next read
        respond(pipe);

        uint64_t messageId = requestMessage.id();
        increaseCallCount(serverActiveCalls_);

        // Defer user RPC UDF run to thread pool
        threadPool_.run([this,
                         pipe,
                         messageId,
                         requestMessage{std::move(requestMessage)},
                         ctx{std::move(ctx)}]() mutable {

          c10::intrusive_ptr<JitFuture> futureResponseMessage;
          try {
              
            // RequestCallback is called to handle the callback logic
              
            futureResponseMessage = cb_->operator()(requestMessage, ctx);
            
          } catch (const std::exception& /* unused */) {
            futureResponseMessage =
                c10::make_intrusive<JitFuture>(at::AnyClassType::get());
            futureResponseMessage->setError(std::current_exception());
          }

          // Shortcut if immediately done
          if (futureResponseMessage->completed()) {
            decreaseCallCount(serverActiveCalls_);
            sendCompletedResponseMessage(
                pipe, *futureResponseMessage, messageId, std::move(ctx));
          } else {
            // Not complete yet
            increaseCallCount(serverActiveAsyncCalls_);
            futureResponseMessage->addCallback([this, pipe, messageId, ctx{std::move(ctx)}](
                    JitFuture& futureResponseMessage) mutable {
                  decreaseCallCount(serverActiveCalls_);
                  decreaseCallCount(serverActiveAsyncCalls_);
                  sendCompletedResponseMessage(
                      pipe, futureResponseMessage, messageId, std::move(ctx));
                });
          }
        });
      });
}

4.2 operator()

operator() calls processMessage to process the message.

c10::intrusive_ptr<JitFuture> RequestCallback::operator()(
    Message& request,
    std::shared_ptr<LazyStreamContext> ctx) const {
  // NB: cannot clear autograd context id here because the processMessage method
  // might pause waiting for all RRefs in the arguments to be confirmed by their
  // owners and resume processing in a different thread. Hence, the
  // thread_local context id needs to be set and cleared in the thread that
  // indeed carries out the processing logic.
  return processMessage(request, std::move(ctx));
}

Then RequestCallbackNoPython::processMessage is invoked.

  • It first calls deserializePythonRpcCommand, which RequestCallbackImpl implements, to deserialize the PythonUDF.
  • It then calls processRpcWithErrors to process the message.
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processMessage(
    Message& request,
    std::shared_ptr<LazyStreamContext> ctx) const {
  // We need two futures here because it could pause twice when processing a
  // RPC message:
  // 1) waiting for all RRefs in the arguments to become confirmed;
  // 2) waiting for processRpc to finish.
  auto retFuture = c10::make_intrusive<JitFuture>(at::AnyClassType::get());
  auto& rrefContext = RRefContext::getInstance();
  try {
    rrefContext.recordThreadLocalPendingRRefs();
    // Deserialize PythonUDF here to trigger RRef unpickling
    // Call deserializePythonRpcCommand, implemented by RequestCallbackImpl, to deserialize the PythonUDF
    std::unique_ptr<RpcCommandBase> rpc = deserializePythonRpcCommand(
        deserializeRequest(request), request.type()); // Parse the request
    auto rrefsReadyFuture = rrefContext.waitForThreadLocalPendingRRefs();
    rrefsReadyFuture->addCallback(
        [this,
         retFuture,
         // std::function must be copyable, hence hae to cast the unique_ptr to
         // a shared_ptr here.
         rpc = (std::shared_ptr<RpcCommandBase>)std::move(rpc),
         messageType = request.type(),
         id = request.id(),
         ctx = std::move(ctx)](JitFuture& /* unused */) mutable {
          c10::MultiStreamGuard guard(
              ctx ? ctx->getReservedStreams() : ArrayRef<Stream>({}));
          // The cost of pre-request check is minimal thanks to
          // std::shared_lock. The cost is in magnitude
          // of 10us.
          auto serverProcessGlobalProfilerStateStackEntryPtr =
              profiler::processglobal::StateStackEntry::current();
          // If the server global profiler is enabled, we further pay the
          // cost of thread local profiler state initialization.
          if (serverProcessGlobalProfilerStateStackEntryPtr) {
            // Initialize thread-local profiler state from process-global
            // profiler state.
            ::torch::autograd::profiler::enableProfilerLegacy(
                serverProcessGlobalProfilerStateStackEntryPtr->statePtr()->config());
          }

          // Here
          processRpcWithErrors(
              *rpc, messageType, id, retFuture, std::move(ctx));

          // Response message has been sent at this moment, this post-response
          // work doesn't affect RPC trip time.
          if (serverProcessGlobalProfilerStateStackEntryPtr) {
            // Restore thread-local profiler state.
            ::torch::autograd::profiler::thread_event_lists event_lists =
                ::torch::autograd::profiler::disableProfilerLegacy();
            // Put thread_local event_lists into the process-global profiler
            // state.
            profiler::processglobal::pushResultRecursive(
                serverProcessGlobalProfilerStateStackEntryPtr, event_lists);
          }
        });
  } catch (std::exception& e) {
    retFuture->markCompleted(handleError(e, request.type(), request.id()));
    rrefContext.clearRecordedPendingRRefsOnError();
  }
  return retFuture;
}

Then call processRpcWithErrors.

void RequestCallbackNoPython::processRpcWithErrors(
    RpcCommandBase& rpc,
    const MessageType& messageType,
    const int64_t messageId,
    const c10::intrusive_ptr<JitFuture>& responseFuture,
    std::shared_ptr<LazyStreamContext> ctx) const {
  try {
    processRpc(rpc, messageType, messageId, responseFuture, std::move(ctx));
  } catch (std::exception& e) {
    responseFuture->markCompleted(handleError(e, messageType, messageId));
  }
}

Next up is processRpc, where you can see the handling of FORWARD_AUTOGRAD_REQ.

void RequestCallbackNoPython::processRpc(
    RpcCommandBase& rpc,
    const MessageType& messageType,
    const int64_t messageId,
    const c10::intrusive_ptr<JitFuture>& responseFuture,
    std::shared_ptr<LazyStreamContext> ctx) const {

  switch (messageType) {
    case MessageType::FORWARD_AUTOGRAD_REQ: { // This corresponds to the request sent earlier
      processForwardAutogradReq(rpc, messageId, responseFuture, std::move(ctx));
      return;
    }
    case MessageType::BACKWARD_AUTOGRAD_REQ: {
      processBackwardAutogradReq(rpc, messageId, responseFuture);
      return;
    }
    // ... other cases omitted ...
  }
}

Details are as follows:

 TensorPipeAgent      RequestCallback  RequestCallbackNoPython     RequestCallbackImpl
        +                   +                 +                          +
        |                   |                 |                          |
        |                   |                 |                          |
        v                   |                 |                          |
    respond                 |                 |                          |
        +                   |                 |                          |
        |                   |                 |                          |
        |                   |                 |                          |
        v                   v                 v                          |
cb_->operator()  +-->   operator()  +-->  processMessage                 |
                                              +                          |
                                              |                          |
                                              |                          v
                                              +--------------->  deserializePythonRpcCommand
                                              |
                                              |
                                              |
                                              v

                                      processRpcWithErrors
                                              +
                                              |
                                              |
                                              v
                                          processRpc
                                              +
                                              |
                                              |
                                              v
                                    processForwardAutogradReq


4.3 RequestCallbackImpl

At this point the reader may wonder: TensorPipeAgent clearly configured RequestCallbackImpl as its callback earlier, so why have we only seen its deserializePythonRpcCommand called? deserializeXXX looks like serialization work; shouldn't some business-processing functions such as processXXX be called as well? Let's look at RequestCallbackImpl next.

RequestCallbackImpl is defined in torch/csrc/distributed/rpc/request_callback_impl.h.

class TORCH_API RequestCallbackImpl : public RequestCallbackNoPython {
 public:
  std::unique_ptr<RpcCommandBase> deserializePythonRpcCommand(
      std::unique_ptr<RpcCommandBase> rpc,
      const MessageType& messageType) const override;

  void processPythonCall(
      RpcCommandBase& rpc,
      const std::function<void(Message)>& markComplete,
      const int64_t messageId,
      const c10::intrusive_ptr<JitFuture>& responseFuture) const override;

  void processScriptCall(
      RpcCommandBase& rpc,
      const std::function<void(Message)>& markComplete,
      const int64_t messageId,
      const c10::intrusive_ptr<JitFuture>& responseFuture) const override;

  void processScriptRemoteCall(
      ScriptRemoteCall& scriptRemoteCall,
      const std::function<void(void)>& postProcessing,
      std::vector<at::IValue>& stack,
      const c10::intrusive_ptr<OwnerRRef>& ownerRRef) const override;

  void processPythonRemoteCall(
      RpcCommandBase& rpc,
      const std::function<void(Message)>& markComplete,
      const int64_t messageId,
      const c10::intrusive_ptr<JitFuture>& responseFuture,
      std::shared_ptr<LazyStreamContext> ctx) const override;

  void processRpcWithErrors(
      RpcCommandBase& rpc,
      const MessageType& messageType,
      const int64_t messageId,
      const c10::intrusive_ptr<JitFuture>& responseFuture,
      std::shared_ptr<LazyStreamContext> ctx) const override;

  void processRRefBackward(
      RpcCommandBase& rpc,
      const int64_t messageId,
      const c10::intrusive_ptr<JitFuture>& responseFuture) const override;
};

Because the actual instance is a RequestCallbackImpl, the processRpcWithErrors step in the middle of the figure actually calls RequestCallbackImpl's override of processRpcWithErrors, which merely adds some exception-handling logic.

void RequestCallbackImpl::processRpcWithErrors(
    RpcCommandBase& rpc,
    const MessageType& messageType,
    const int64_t messageId,
    const c10::intrusive_ptr<JitFuture>& responseFuture,
    std::shared_ptr<LazyStreamContext> ctx) const {
  try {
    processRpc(rpc, messageType, messageId, responseFuture, std::move(ctx));
  } catch (py::error_already_set& e) {
    responseFuture->markCompleted(handleError(e, messageType, messageId));
    py::gil_scoped_acquire acquire;
    e.restore(); // Release ownership on py::objects and also restore
                 // the Python Error Indicator.
    PyErr_Clear(); // Clear the Python Error Indicator, as we have
                   // recorded the exception in the response message.
  } catch (std::exception& e) {
    responseFuture->markCompleted(handleError(e, messageType, messageId));
  }
}

The logical diagram is modified as follows:

 TensorPipeAgent      RequestCallback  RequestCallbackNoPython     RequestCallbackImpl
        +                   +                 +                          +
        |                   |                 |                          |
        |                   |                 |                          |
        v                   |                 |                          |
    respond                 |                 |                          |
        +                   |                 |                          |
        |                   |                 |                          |
        |                   |                 |                          |
        v                   v                 v                          |
cb_->operator()  +-->   operator()  +-->  processMessage                 |
                                              +                          |
                                              |                          |
                                              |                          v
                                              +----------------> deserializePythonRpcCommand
                                              |                          +
                                              |                          |
                                              |                          |
                                              |                          v
                                              |
                                              +----------------> processRpcWithErrors
                                              |                          +
                                              |                          |
                                              |                          |
                                              | <------------------------+
                                              |
                                              |
                                              v
                                          processRpc
                                              +
                                              |
                                              |
                                              v
                                    processForwardAutogradReq


Combining this with the earlier sending flow, we expand the diagram as follows:

  1. rpc.rpc_sync is called when the sender needs to run automatic gradient computation remotely.
  2. Crossing from Python into the C++ world, the function is pyRpcBuiltin.
  3. sendMessageWithAutograd is called to notify the receiver.
  4. RpcAgent::getCurrentRpcAgent() is invoked to get the local agent.
  5. The send function of the current agent is called.
  6. send sends FORWARD_AUTOGRAD_REQ to the receiver worker.
  7. On the receiver, the respond function invokes the agent's cb_ callback.
  8. This reaches RequestCallbackImpl::processRpcWithErrors.
  9. processRpc is then called.
  10. Finally processForwardAutogradReq is called, completing one RPC-based distributed autograd startup flow.
                                                             +
 rpc.rpc_sync                                 Sender         |     Receiver
        +                                                    |
        |                                                    |
        | 1                                                  |
        v                                                    |
 _invoke_rpc_builtin                                         |
        +                                                    |
        |                                      Python        |
+----------------------------------------------------------+ |
        |                                      C++           |      +----------------------------+
        |  2                                                 |      | RequestCallbackImpl        |
        v                                                    |      |                            |
                                                             |   +----> processRpcWithErrors     |
   pyRpcBuiltin                                              |   |  |             +              |
        +                                                    |   |  |             | 9            |
        |  3                                                 |   |  |             |              |
        |                                                    |   |  |             v              |
        v                                                    |   |  |         processRpc         |
                                     4                       |   |  |             +              |
sendMessageWithAutograd(RpcAgent::getCurrentRpcAgent())      |   |  |             | 10           |
        +                                                    |   |  |             |              |
        |                                                    |   |  |             v              |
        |                                                    |   |  |  processForwardAutogradReq |
        |   RpcAgent::currentRpcAgent_                       |   |  |                            |
        |           +                                        |   |  +----------------------------+
        |           |                                        |   |
        | 5         |                                        |   |8     +-----------------+
        |           v                                        |   |      | TensorPipeAgent |
        |    +------+--------+                               |   |      |                 |
        |    |TensorPipeAgent|   +-------------------+       |   +------------+ cb_       |
        |    |               |   |RequestCallbackImpl|       |          |        ^        |
        |    |      cb_ +------->+                   |       |          |      7 |        |
        |    |               |   +-------------------+       |          |        |        |
        |    |               |                          6  |          |        |        |
        +    |               |                             |          |        |        |
        +--------> send +----------------------------------+--------------> respond    |
             |               |    FORWARD_AUTOGRAD_REQ     |          |                 |
             +---------------+                             |          +-----------------+
                                                           +


At this point, the introduction of the RPC basics is complete. In the following articles we will introduce the context and other management classes, so stay tuned.

0xEE Personal information

★★★★ Thoughts on life and technology ★★★★★

Wechat official account: Rosie’s Thoughts

0xFF Reference