0x00 Abstract

In this series on PyTorch elastic training, we have introduced the Agent and rendezvous separately, but some parts, such as monitoring, have not yet been explored in depth. This article ties them together and gives an overall logical review of elastic training.

The elastic training series is as follows:

PyTorch Distributed Elastic Training (1) – General Idea

PyTorch Distributed Elastic Training (2) – Startup & Single-Node Flow

PyTorch Distributed Elastic Training (3) – Agent

PyTorch Distributed Elastic Training (4) – Rendezvous Architecture and Logic

PyTorch Distributed Elastic Training (5) – Rendezvous Engine

0x01 General Logic

We need to look at the logic of the system from several angles, roughly from top to bottom, from whole to part.

1.1 Node Cluster Perspective

Let's start with a top-down, bird's-eye view of the elastic system at the level of the node cluster. Each node runs an Agent. The Agent contains a rendezvous handler that is responsible for distributed negotiation, and the Agent is also responsible for starting and monitoring the workers.

1.2 Overall Agent Logic

Next we go inside the Agent. From the previous articles, its overall logic is as follows:

  • 1) Call _initialize_workers to start the worker processes, i.e. launch multiple processes that execute the user program in parallel for training.
    • 2) Call _rendezvous, within which:
      • Call next_rendezvous to handle membership changes.
      • Call _assign_worker_ranks to assign ranks to the workers.
    • 3) Call _start_workers to start the workers.
  • 4) Call _monitor_workers to monitor the results of these processes.

1.3 Monitoring Perspective

The core of elastic training is monitoring and dynamically handling changes, so we now dig into the monitoring module. From the monitoring perspective, the Agent's main loop, _invoke_run, works as follows:

  • Call _initialize_workers to start the workers.
    • Call _rendezvous, within which:
      • Call next_rendezvous to handle membership changes.
      • Call _assign_worker_ranks to assign ranks to the workers.
    • Call _start_workers to start the workers.
  • Enter a while loop that periodically checks the user program through _monitor_workers and decides what to do based on the result.
  • If all workers have SUCCEEDED, call _exit_barrier and return the result.
  • If a worker process has FAILED or is UNHEALTHY, the loop enters the branch elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED}:.
    • If restarts remain, call _restart_workers, which runs a new rendezvous and restarts the worker processes.
    • If the maximum number of restarts has been reached, stop the workers and shut the task down.
  • If the workers are running normally, the loop enters the branch state == WorkerState.HEALTHY.
    • If there is a scale-up, i.e. a new node is waiting to join, restart all workers.

The specific code is as follows:

    def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult:
        # NOTE: currently only works for a single role

        spec = self._worker_group.spec
        role = spec.role

        self._initialize_workers(self._worker_group)  # start the workers
        monitor_interval = spec.monitor_interval
        rdzv_handler = spec.rdzv_handler

        while True:
            assert self._worker_group.state != WorkerState.INIT
            # periodic monitoring
            time.sleep(monitor_interval)
            # check the running status of the workers
            run_result = self._monitor_workers(self._worker_group)
            state = run_result.state  # health of the worker processes
            self._worker_group.state = state

            if state == WorkerState.SUCCEEDED:
                # the program finished normally
                self._exit_barrier()  # all local workers succeeded, so this run is over
                return run_result
            elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED}:
                # program error
                if self._remaining_restarts > 0:  # retry
                    self._remaining_restarts -= 1
                    self._restart_workers(self._worker_group)  # restart
                else:
                    # retry limit reached, terminate the workers
                    self._stop_workers(self._worker_group)
                    self._worker_group.state = WorkerState.FAILED
                    self._exit_barrier()
                    return run_result
            elif state == WorkerState.HEALTHY:
                # the program is running normally
                # node membership changes, such as scale up;
                # membership changes do not count as retries
                num_nodes_waiting = rdzv_handler.num_nodes_waiting()
                group_rank = self._worker_group.group_rank
                # restart all workers if a new node is waiting
                if num_nodes_waiting > 0:
                    self._restart_workers(self._worker_group)
            else:
                raise Exception(f"[{role}] Worker group in {state.name} state")
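To make the decision logic of _invoke_run easier to follow, here is a minimal, self-contained sketch that replays it without real processes, sleeping, or rendezvous. It is an illustration under stated assumptions, not TorchElastic code: WorkerState is a hypothetical stand-in enum, and the hypothetical callables monitor and nodes_waiting play the roles of _monitor_workers and rdzv_handler.num_nodes_waiting. The function records which action the loop would take at each step.

```python
from enum import Enum, auto
from typing import Callable, List, Tuple


class WorkerState(Enum):
    # Hypothetical stand-in for TorchElastic's WorkerState (only the states used here)
    HEALTHY = auto()
    UNHEALTHY = auto()
    FAILED = auto()
    SUCCEEDED = auto()


def invoke_run(
    monitor: Callable[[], WorkerState],
    nodes_waiting: Callable[[], int],
    max_restarts: int,
) -> Tuple[WorkerState, List[str]]:
    """Replay the _invoke_run decisions; return the final state and the actions taken."""
    remaining_restarts = max_restarts
    actions = ["initialize"]  # _initialize_workers: rendezvous, assign ranks, start workers
    while True:
        state = monitor()  # plays the role of _monitor_workers
        if state == WorkerState.SUCCEEDED:
            actions.append("exit_barrier")
            return state, actions
        elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED}:
            if remaining_restarts > 0:
                remaining_restarts -= 1  # failures consume the restart budget
                actions.append("restart(failure)")
            else:
                actions += ["stop", "exit_barrier"]
                return WorkerState.FAILED, actions
        elif state == WorkerState.HEALTHY:
            # membership changes (scale-up) restart the workers without using a retry
            if nodes_waiting() > 0:
                actions.append("restart(scale-up)")
        else:
            raise RuntimeError(f"Worker group in {state.name} state")


if __name__ == "__main__":
    script = iter([WorkerState.HEALTHY, WorkerState.FAILED,
                   WorkerState.HEALTHY, WorkerState.SUCCEEDED])
    waiting = iter([0, 1])
    print(invoke_run(lambda: next(script), lambda: next(waiting), max_restarts=1))
```

In this model a failure decrements the restart budget while a scale-up restart does not, which mirrors the comment in the original code that membership changes do not count as retries.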

Refining this further gives the detailed diagram below:

  _initialize_workers  <---------------------------------+                 Node 1    +   Node 2                  _initialize_workers
           +                                             |                           |                                   +
           |                                             |                           |                                   |
           |                                             |  +-----------------+      |      +-----------------+          |
           v                                             |  |RendezvousHandler|    sync     |RendezvousHandler|          v
      _rendezvous +---------------------------------------->+                 | <----+----> |                 +<---+ _rendezvous
           +                          next_rendezvous    |  |                 |      |      |                 |          +
           |                                             |  |                 |      |      |                 |          |
    _assign_worker_ranks                                 |  |                 |  heartbeat  |                 |          |
           |                                             |  |                 | <----+----> |                 |
           v                                             |  +-----------------+      |      +-----------------+          v
     _start_workers                                      |                           |                              _start_workers
           +                                             |                           |                                   +
           |                                             |                           |                                   |
           |                                             |                           |                                   |
           v                                             |                           |                                   v
     +-----+-------------------------------------------------------+                 |                          +--------+---------+
     |                                                   |         |                 |                          |                  |
     |state = _monitor_workers                           |         |                 |                          |                  |
     |   +                                               |         |                 |                          |                  |
     |   |                                               |         |                 |                          |                  |
     |   | UNHEALTHY,FAILED   1. Process fail            |         |                 |                          |                  |
+--> |   +-----------------> _restart_workers +--+       |         +-->              |                          |                  |
|    |   |                                       |       +         |  |              |                          |                  |
|    |   |                                       +--> _stop_workers|  |              |                          |  LOOP Every 30S  |
|    |   | HEALTHY            2. Node change     |                 |  |              |                          |                  |
|    |   +-----------------> _restart_workers +--+                 |  |              |                          |                  |
|    |   |                                                         |  |              |                          |                  |
|    |   |                                                         |  |              |                          |                  |
|    |   | SUCCEEDED                                               |  |              |                          |                  |
|    |   |                                                         |  |              |                          |                  |
|    |   | 3. exit                                                 |  |              |                          |                  |
|    |   |                                                         |  |              |                          |                  |
|    +-------------------------------------------------------------+  |              |                          |                  |
|        |                                                            |              |                          |                  |
<---------------------------------------------------------------------+              |                          +--------+---------+
         |        LOOP  Every 30S                                                    |                                   |
         |                                                                           |                                   |
         v                                                                           |                                   v
       _exit_barrier                                                                 +                             _exit_barrier


Alternatively, see the figure below, from zhuanlan.zhihu.com/p/408382623 images…

0x02 Multiprocessing

The monitoring mechanism watches multiple running training workers, which involves starting and monitoring multiple processes, so we first need to look at multiprocessing. The natural starting point is the entry that launches the worker processes.

2.1 Starting the Workers

_start_workers calls start_processes to start the worker processes; _start_method defaults to "spawn". In other words, multiple processes are started to execute the user program in parallel, and their results are then monitored. Among the parameters of start_processes, entrypoint and args are the user's command and its arguments; entrypoint can be either a function or a string.

_start_workers then stores the return value of start_processes in _pcontext and uses _pcontext to control the processes. For example, to stop the workers, it simply calls the close method of _pcontext.

    @prof
    def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]:
        spec = worker_group.spec
        store = worker_group.store
        assert store is not None
        master_addr, master_port = super()._get_master_addr_port(store)
        restart_count = spec.max_restarts - self._remaining_restarts

        use_agent_store = spec.rdzv_handler.get_backend() == "static"

        args: Dict[int, Tuple] = {}
        envs: Dict[int, Dict[str, str]] = {}
        for worker in worker_group.workers:
            local_rank = worker.local_rank
            worker_env = {
                "LOCAL_RANK": str(local_rank),
                "RANK": str(worker.global_rank),
                "GROUP_RANK": str(worker_group.group_rank),
                "ROLE_RANK": str(worker.role_rank),
                "ROLE_NAME": spec.role,
                "LOCAL_WORLD_SIZE": str(spec.local_world_size),
                "WORLD_SIZE": str(worker.world_size),
                "GROUP_WORLD_SIZE": str(worker_group.group_world_size),
                "ROLE_WORLD_SIZE": str(worker.role_world_size),
                "MASTER_ADDR": master_addr,
                "MASTER_PORT": str(master_port),
                "TORCHELASTIC_RESTART_COUNT": str(restart_count),
                "TORCHELASTIC_MAX_RESTARTS": str(spec.max_restarts),
                "TORCHELASTIC_RUN_ID": spec.rdzv_handler.get_run_id(),
                "TORCHELASTIC_USE_AGENT_STORE": str(use_agent_store),
                "NCCL_ASYNC_ERROR_HANDLING": str(1),
            }
            if "OMP_NUM_THREADS" in os.environ:
                worker_env["OMP_NUM_THREADS"] = os.environ["OMP_NUM_THREADS"]
            envs[local_rank] = worker_env
            worker_args = list(spec.args)
            worker_args = macros.substitute(worker_args, str(local_rank))
            args[local_rank] = tuple(worker_args)

        # scaling events do not count towards restarts (gets same attempt #)
        # remove existing log dir if this restart is due to a scaling event
        attempt_log_dir = os.path.join(self._log_dir, f"attempt_{restart_count}")
        shutil.rmtree(attempt_log_dir, ignore_errors=True)
        os.makedirs(attempt_log_dir)

        assert spec.entrypoint is not None
        # save the result of starting the worker processes in _pcontext
        self._pcontext = start_processes(
            name=spec.role,
            entrypoint=spec.entrypoint,  # training code entry point
            args=args,  # per-worker arguments, keyed by local rank
            envs=envs,
            log_dir=attempt_log_dir,
            start_method=self._start_method,
            redirects=spec.redirects,
            tee=spec.tee,
        )

        return self._pcontext.pids()
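The per-rank args and envs dictionaries built above can be illustrated with a small sketch. It assumes a homogeneous cluster where every node runs the same number of workers, so a worker's global rank is simply group_rank * local_world_size + local_rank; in TorchElastic the ranks are actually computed by _assign_worker_ranks. The function build_worker_specs is hypothetical and only fills a subset of the variables, and the ${local_rank} substitution merely mimics what macros.substitute does to the user arguments.

```python
from string import Template
from typing import Dict, List, Tuple


def build_worker_specs(
    group_rank: int,
    local_world_size: int,
    num_nodes: int,
    user_args: List[str],
) -> Tuple[Dict[int, Tuple[str, ...]], Dict[int, Dict[str, str]]]:
    """Return (args, envs) keyed by local rank, assuming equal workers per node."""
    world_size = num_nodes * local_world_size
    args: Dict[int, Tuple[str, ...]] = {}
    envs: Dict[int, Dict[str, str]] = {}
    for local_rank in range(local_world_size):
        # homogeneous cluster: global rank = node offset + local rank
        global_rank = group_rank * local_world_size + local_rank
        envs[local_rank] = {
            "LOCAL_RANK": str(local_rank),
            "RANK": str(global_rank),
            "GROUP_RANK": str(group_rank),
            "LOCAL_WORLD_SIZE": str(local_world_size),
            "WORLD_SIZE": str(world_size),
        }
        # mimic macros.substitute: replace ${local_rank} in each user argument
        args[local_rank] = tuple(
            Template(a).safe_substitute(local_rank=str(local_rank)) for a in user_args
        )
    return args, envs


if __name__ == "__main__":
    worker_args, worker_envs = build_worker_specs(
        group_rank=1, local_world_size=2, num_nodes=2,
        user_args=["train.py", "--device=cuda:${local_rank}"],
    )
    print(worker_args[1], worker_envs[1]["RANK"])  # second worker on the second node
```

Keying both dictionaries by local rank is what lets start_processes later validate that every rank in {0, ..., nprocs-1} is present.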

2.1.1 start_processes

Note that this start_processes is the one in torch/distributed/elastic/multiprocessing/api.py, which is different from the mp.start_processes used further down. It extracts the local ranks from args and performs per-rank operations based on them, such as creating log files for each process. The idea is to associate each worker process with a local_rank: one local_rank corresponds to one worker process.

def start_processes(
    name: str,
    entrypoint: Union[Callable, str],
    args: Dict[int, Tuple],
    envs: Dict[int, Dict[str, str]],
    log_dir: str,
    start_method: str = "spawn",
    redirects: Union[Std, Dict[int, Std]] = Std.NONE,
    tee: Union[Std, Dict[int, Std]] = Std.NONE,
) -> PContext:
    """
    Starts ``n`` copies of ``entrypoint`` processes with the provided options.
    ``entrypoint`` is either a ``Callable`` (function) or a ``str`` (binary).
    The number of copies is determined by the number of entries for ``args`` and
    ``envs`` arguments, which need to have the same key set.

    ``args`` and ``env`` parameters are the arguments and environment variables
    to pass down to the entrypoint mapped by the replica index (local rank).
    All local ranks must be accounted for.
    That is, the keyset should be ``{0,1,...,(nprocs-1)}``.

    Args:
        name: a human readable short name that describes what the processes are
              (used as header when tee'ing stdout/stderr outputs)
        entrypoint: either a ``Callable`` (function) or ``cmd`` (binary)
        args: arguments to each replica
        envs: env vars to each replica
        log_dir: directory used to write log files
        nprocs: number of copies to create (one on each process)
        start_method: multiprocessing start method (spawn, fork, forkserver)
                      ignored for binaries
        redirects: which std streams to redirect to a log file
        tees: which std streams to redirect + print to console

    """

    # listdir raises FileNotFound or NotADirectoryError so no need to check manually
    if os.listdir(log_dir):
        raise RuntimeError(
            f"log_dir: {log_dir} is not empty, please provide an empty log_dir"
        )

    nprocs = len(args)
    _validate_full_rank(args, nprocs, "args")
    _validate_full_rank(envs, nprocs, "envs")

    # create subdirs for each local rank in the logs_dir
    redirs = to_map(redirects, nprocs)
    ts = to_map(tee, nprocs)

    # to tee stdout/stderr we first redirect into a file
    # then tail -f stdout.log/stderr.log so add tee settings to redirects
    for local_rank, tee_std in ts.items():
        redirect_std = redirs[local_rank]
        redirs[local_rank] = redirect_std | tee_std

    stdouts = {local_rank: "" for local_rank in range(nprocs)}
    stderrs = {local_rank: "" for local_rank in range(nprocs)}
    tee_stdouts: Dict[int, str] = {}
    tee_stderrs: Dict[int, str] = {}
    error_files = {}

    # heavy use of local_rank
    for local_rank in range(nprocs):
        clogdir = os.path.join(log_dir, str(local_rank))
        os.mkdir(clogdir)

        rd = redirs[local_rank]
        if (rd & Std.OUT) == Std.OUT:
            stdouts[local_rank] = os.path.join(clogdir, "stdout.log")
        if (rd & Std.ERR) == Std.ERR:
            stderrs[local_rank] = os.path.join(clogdir, "stderr.log")

        t = ts[local_rank]
        if t & Std.OUT == Std.OUT:
            tee_stdouts[local_rank] = stdouts[local_rank]
        if t & Std.ERR == Std.ERR:
            tee_stderrs[local_rank] = stderrs[local_rank]

        error_file = os.path.join(clogdir, "error.json")
        error_files[local_rank] = error_file
        envs[local_rank]["TORCHELASTIC_ERROR_FILE"] = error_file

    context: PContext
    if isinstance(entrypoint, str):
        context = SubprocessContext(
            name=name,
            entrypoint=entrypoint,
            args=args,
            envs=envs,
            stdouts=stdouts,
            stderrs=stderrs,
            tee_stdouts=tee_stdouts,
            tee_stderrs=tee_stderrs,
            error_files=error_files,
        )
    else:
        context = MultiprocessContext(
            name=name,
            entrypoint=entrypoint,
            args=args,
            envs=envs,
            stdouts=stdouts,
            stderrs=stderrs,
            tee_stdouts=tee_stdouts,
            tee_stderrs=tee_stderrs,
            error_files=error_files,
            start_method=start_method,
        )

    try:
        context.start()
        return context
    except Exception:
        context.close()
        raise
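The per-rank log layout that start_processes creates can be approximated with the standard library alone. This is a sketch, not the TorchElastic code: Std is a hypothetical mirror of the flag enum used above, and plan_log_files is an illustrative helper. It shows two points from the code: every local rank gets its own subdirectory with an error.json path, and tee'd streams are always redirected too, because tee is implemented as a redirect plus tail.

```python
import os
import tempfile
from enum import IntFlag
from typing import Dict


class Std(IntFlag):
    # Hypothetical mirror of the Std flag used by start_processes
    NONE = 0
    OUT = 1
    ERR = 2
    ALL = OUT | ERR


def plan_log_files(log_dir: str, nprocs: int, redirects: Std, tee: Std) -> Dict[int, Dict[str, str]]:
    """Create one subdirectory per local rank and return its log file paths."""
    if os.listdir(log_dir):
        raise RuntimeError(f"log_dir: {log_dir} is not empty")
    plan: Dict[int, Dict[str, str]] = {}
    for local_rank in range(nprocs):
        clogdir = os.path.join(log_dir, str(local_rank))
        os.mkdir(clogdir)
        # tee = redirect + tail, so every tee'd stream must also be redirected
        rd = redirects | tee
        files = {"error": os.path.join(clogdir, "error.json")}
        if rd & Std.OUT:
            files["stdout"] = os.path.join(clogdir, "stdout.log")
        if rd & Std.ERR:
            files["stderr"] = os.path.join(clogdir, "stderr.log")
        plan[local_rank] = files
    return plan


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        print(plan_log_files(tmp, nprocs=2, redirects=Std.OUT, tee=Std.ERR))
```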

2.1.2 RunResult

The result of a worker run is represented by RunResult, which holds what the worker executions return. Run results follow an "all-or-nothing" policy: a run is successful if and only if all local workers managed by the agent complete successfully.

As mentioned above, each worker process is associated with a local_rank, which makes sense: with 5 GPUs, 5 worker processes are started for training, and they correspond to local ranks 0 to 4.

However, the RunResult docstring states that if the result is successful (e.g. is_failed() = False), the return_values field contains the outputs (return values) of the workers managed by this agent, keyed by their GLOBAL ranks. That is, result.return_values[0] is the return value of global rank 0. So _monitor_workers has to map local ranks to global ranks, which we will see later.

from dataclasses import dataclass, field
from typing import Any, Dict

@dataclass
class RunResult:
    """
    Results returned by the worker executions. Run results follow an "all-or-nothing" policy
    where the run is successful if and only if ALL local workers managed by this agent
    complete successfully.

    If the result is successful (e.g. ``is_failed() = False``) then the ``return_values``
    field contains the outputs (return values) of the workers managed by THIS agent mapped
    by their GLOBAL ranks. That is ``result.return_values[0]`` is the return value of
    global rank 0.

    .. note:: ``return_values`` are only meaningful for when the worker entrypoint
              is a function. Workers specified as a binary entrypoint do not canonically
              have a return value and the ``return_values`` field is meaningless and
              may be empty.

    If ``is_failed()`` returns ``True`` then the ``failures`` field contains the
    failure information, again, mapped by the GLOBAL rank of the worker that failed.

    The keys in ``return_values`` and ``failures`` are mutually exclusive, that is,
    a worker's final state can only be one of: succeeded, failed. Workers intentionally
    terminated by the agent according to the agent's restart policy, are not represented
    in either ``return_values`` nor ``failures``.
    """

    state: WorkerState
    return_values: Dict[int, Any] = field(default_factory=dict)
    failures: Dict[int, ProcessFailure] = field(default_factory=dict)

    def is_failed(self) -> bool:
        return self.state == WorkerState.FAILED
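The all-or-nothing policy can be shown with a reduced RunResult. In this sketch (a simplification under assumptions, not TorchElastic's code), WorkerState is a hypothetical three-state enum, failures are reduced to exit codes, and the helper aggregate is hypothetical: a single non-zero exit code is enough to make the whole run FAILED, and in that case no return values are reported, so the two dictionaries never share keys.

```python
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict


class WorkerState(Enum):
    # Hypothetical stand-in for TorchElastic's WorkerState
    HEALTHY = "HEALTHY"
    FAILED = "FAILED"
    SUCCEEDED = "SUCCEEDED"


@dataclass
class RunResult:
    state: WorkerState
    return_values: Dict[int, Any] = field(default_factory=dict)
    failures: Dict[int, int] = field(default_factory=dict)  # rank -> exit code (simplified)

    def is_failed(self) -> bool:
        return self.state == WorkerState.FAILED


def aggregate(exit_codes: Dict[int, int], outputs: Dict[int, Any]) -> RunResult:
    """All-or-nothing: SUCCEEDED only if every worker exited with code 0."""
    failures = {rank: code for rank, code in exit_codes.items() if code != 0}
    if failures:
        # one failing worker fails the whole run; no return values are reported
        return RunResult(state=WorkerState.FAILED, failures=failures)
    return RunResult(state=WorkerState.SUCCEEDED, return_values=dict(outputs))


if __name__ == "__main__":
    print(aggregate({0: 0, 1: 1}, {0: "loss=0.1"}))
```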

2.1.3 How TE Uses Multiprocessing

TE (TorchElastic) uses the torch.multiprocessing and subprocess packages to manage multiple processes. After the processes are started, the result is saved in _pcontext, an instance of PContext.

    # save the result of starting the worker processes in _pcontext
    self._pcontext = start_processes(
        name=spec.role,
        entrypoint=spec.entrypoint,
        args=args,
        envs=envs,
        log_dir=attempt_log_dir,
        start_method=self._start_method,
        redirects=spec.redirects,
        tee=spec.tee,
    )

Where start_processes and PContext come from the following:

from torch.distributed.elastic.multiprocessing import start_processes, PContext

_monitor_workers monitors the workers through _pcontext. Based on the process results, it returns WorkerState.FAILED, WorkerState.HEALTHY, or WorkerState.SUCCEEDED to the upper layer.

@prof
def _monitor_workers(self, worker_group: WorkerGroup) -> RunResult:
    role = worker_group.spec.role
    worker_pids = {w.id for w in worker_group.workers}
    assert self._pcontext is not None
    pc_pids = set(self._pcontext.pids().values())
    
    result = self._pcontext.wait(0) # Monitor the running results
    if result:
        if result.is_failed():
            # map local rank failure to global rank
            worker_failures = {}
            for local_rank, failure in result.failures.items():
                worker = worker_group.workers[local_rank]
                worker_failures[worker.global_rank] = failure
            return RunResult(
            state=WorkerState.FAILED,  # process error: return WorkerState.FAILED
                failures=worker_failures,
            )
        else:
            # copy ret_val_queue into a map with a global ranks
            workers_ret_vals = {}
            for local_rank, ret_val in result.return_values.items():
                worker = worker_group.workers[local_rank]
                workers_ret_vals[worker.global_rank] = ret_val
            return RunResult(
                state=WorkerState.SUCCEEDED,
                return_values=workers_ret_vals,
            )
    else:
        return RunResult(state=WorkerState.HEALTHY)
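The local-to-global rank remapping performed by _monitor_workers boils down to one dictionary comprehension. The Worker record and the to_global helper below are hypothetical; the sketch only assumes, as the code above does, that worker_group.workers is indexed by local rank.

```python
from dataclasses import dataclass
from typing import Dict, List, TypeVar

T = TypeVar("T")


@dataclass
class Worker:
    # Hypothetical minimal worker record; workers[i] has local_rank == i
    local_rank: int
    global_rank: int


def to_global(workers: List[Worker], by_local_rank: Dict[int, T]) -> Dict[int, T]:
    """Re-key a {local_rank: value} mapping by global rank, as _monitor_workers does."""
    return {workers[local_rank].global_rank: value for local_rank, value in by_local_rank.items()}


if __name__ == "__main__":
    # node with group_rank 1 and two workers per node -> global ranks 2 and 3
    node_workers = [Worker(local_rank=0, global_rank=2), Worker(local_rank=1, global_rank=3)]
    print(to_global(node_workers, {1: "exit code 1"}))
```

The same helper serves both branches of _monitor_workers: failures and return values are re-keyed identically.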

PContext is therefore the key, so let's look at this class.

2.2 PContext

PContext is an abstract base class that mainly holds basic configuration.

class PContext(abc.ABC):
    """
    The base class that standardizes operations over a set of processes
    that are launched via different mechanisms. The name ``PContext``
    is intentional to disambiguate with ``torch.multiprocessing.ProcessContext``.

    .. warning:: stdouts and stderrs should ALWAYS be a superset of
                 tee_stdouts and tee_stderrs (respectively) this is b/c
                 tee is implemented as a redirect + tail -f
    """

    def __init__(
        self,
        name: str,
        entrypoint: Union[Callable, str],
        args: Dict[int, Tuple],
        envs: Dict[int, Dict[str, str]],
        stdouts: Dict[int, str],
        stderrs: Dict[int, str],
        tee_stdouts: Dict[int, str],
        tee_stderrs: Dict[int, str],
        error_files: Dict[int, str],
    ):
        self.name = name
        # validate that all mappings have the same number of keys and
        # all local ranks are accounted for
        nprocs = len(args)
        _validate_full_rank(stdouts, nprocs, "stdouts")
        _validate_full_rank(stderrs, nprocs, "stderrs")

        self.entrypoint = entrypoint
        self.args = args
        self.envs = envs
        self.stdouts = stdouts
        self.stderrs = stderrs
        self.error_files = error_files
        self.nprocs = nprocs

        self._stdout_tail = TailLog(name, tee_stdouts, sys.stdout)
        self._stderr_tail = TailLog(name, tee_stderrs, sys.stderr)    

What matters are its two derived classes: MultiprocessContext and SubprocessContext. As mentioned earlier, the entrypoint and args parameters of start_processes are the user's command and its arguments, and entrypoint can be a function or a string. If entrypoint is a function, MultiprocessContext is used; if it is a string, SubprocessContext is used.

def start_processes(
    name: str,
    entrypoint: Union[Callable, str],
    args: Dict[int, Tuple],
    envs: Dict[int, Dict[str, str]],
    log_dir: str,
    start_method: str = "spawn",
    redirects: Union[Std, Dict[int, Std]] = Std.NONE,
    tee: Union[Std, Dict[int, Std]] = Std.NONE,
) -> PContext:
  
    context: PContext
    if isinstance(entrypoint, str):  # string (binary) entrypoint
        context = SubprocessContext(
            name=name,
            entrypoint=entrypoint,
            args=args,
            envs=envs,
            stdouts=stdouts,
            stderrs=stderrs,
            tee_stdouts=tee_stdouts,
            tee_stderrs=tee_stderrs,
            error_files=error_files,
        )
    else:
        context = MultiprocessContext(  # function entrypoint
            name=name,
            entrypoint=entrypoint,
            args=args,
            envs=envs,
            stdouts=stdouts,
            stderrs=stderrs,
            tee_stdouts=tee_stdouts,
            tee_stderrs=tee_stderrs,
            error_files=error_files,
            start_method=start_method,
        )

    try:
        context.start()  # start the worker processes
        return context
    except Exception:
        context.close()
        raise  

The two derived classes are built on different mechanisms:

  • MultiprocessContext uses torch.multiprocessing.start_processes to start the processes.
  • SubprocessContext uses subprocess.Popen to start the processes.
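The dispatch rule can be illustrated with the standard library. The hypothetical launch function below starts one process per local rank: a callable goes through multiprocessing (in the spirit of MultiprocessContext), and a string is treated as a binary and goes through subprocess.Popen (in the spirit of SubprocessContext). It is a sketch under assumptions: it uses the "fork" start method, so it only runs on POSIX systems, whereas TorchElastic defaults to "spawn", and it blocks until all processes exit instead of returning a context object.

```python
import multiprocessing as mp
import subprocess
import sys
from typing import Callable, Dict, Tuple, Union


def launch(entrypoint: Union[Callable, str], args: Dict[int, Tuple]) -> Dict[int, int]:
    """Start one process per local rank and return {local_rank: exit code}."""
    if isinstance(entrypoint, str):
        # binary entrypoint: one OS process per local rank via subprocess.Popen
        popens = {r: subprocess.Popen([entrypoint, *map(str, a)]) for r, a in args.items()}
        return {r: p.wait() for r, p in popens.items()}
    # function entrypoint: one Python process per local rank via multiprocessing
    ctx = mp.get_context("fork")  # assumption: POSIX; TorchElastic defaults to "spawn"
    procs = {r: ctx.Process(target=entrypoint, args=a) for r, a in args.items()}
    for p in procs.values():
        p.start()
    for p in procs.values():
        p.join()
    return {r: p.exitcode for r, p in procs.items()}


if __name__ == "__main__":
    # string entrypoint: run the Python interpreter as a binary, once per local rank
    print(launch(sys.executable, {0: ("-c", "print('rank 0')"), 1: ("-c", "print('rank 1')")}))
```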

In the rest of this article, we focus on MultiprocessContext.

2.3 MultiprocessContext

MultiprocessContext is defined as follows. Its most interesting member is _pc, which holds a torch.multiprocessing ProcessContext.

import torch.multiprocessing as mp

class MultiprocessContext(PContext):
    """ ``PContext`` holding worker processes invoked as a function. """

    def __init__(
        self,
        name: str,
        entrypoint: Callable,
        args: Dict[int, Tuple],
        envs: Dict[int, Dict[str, str]],
        stdouts: Dict[int, str],
        stderrs: Dict[int, str],
        tee_stdouts: Dict[int, str],
        tee_stderrs: Dict[int, str],
        error_files: Dict[int, str],
        start_method: str,
    ):
        super().__init__(
            name,
            entrypoint,
            args,
            envs,
            stdouts,
            stderrs,
            tee_stdouts,
            tee_stderrs,
            error_files,
        )

        self.start_method = start_method
        # each ret_val queue will always contain a single element.
        self._ret_vals = {
            local_rank: mp.get_context(self.start_method).SimpleQueue()
            for local_rank in range(self.nprocs)
        }

        # see comments in ``join()`` for what this is
        self._return_values: Dict[int, Any] = {}
        self._pc: Optional[mp.ProcessContext] = None # Here is the key
        self._worker_finished_event = mp.get_context(self.start_method).Event()

2.3.1 start

The _start method of MultiprocessContext calls mp.start_processes and saves the result.

import torch.multiprocessing as mp

    def _start(self):
        if self._pc:
            raise ValueError(
                "The process context already initialized."
                " Most likely the start method got called twice."
            )
        self._pc = mp.start_processes(  # returns an mp.ProcessContext
            fn=_wrap,
            args=(
                self.entrypoint,
                self.args,
                self.envs,
                self.stdouts,
                self.stderrs,
                self._ret_vals,
                self._worker_finished_event,
            ),
            nprocs=self.nprocs,
            join=False,
            daemon=False,
            start_method=self.start_method,
        )

2.3.2 wait

The wait method is defined in the base class PContext(abc.ABC). It calls _poll in a loop to check the processes periodically.

    def wait(self, timeout: float = -1, period: float = 1) -> Optional[RunProcsResult]:
        """
        Waits for the specified ``timeout`` seconds, polling every ``period`` seconds
        for the processes to be done. Returns ``None`` if the processes are still running
        on timeout expiry. Negative timeout values are interpreted as "wait-forever".
        A timeout value of zero simply queries the status of the processes (e.g. equivalent
        to a poll).
        """
        if timeout == 0:
            return self._poll()
        if timeout < 0:
            timeout = sys.maxsize

        expiry = time.time() + timeout
        while time.time() < expiry:  # poll periodically until the timeout expires
            pr = self._poll() # poll
            if pr:
                return pr
            time.sleep(period)

        return None
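The division of labor between the base class and its subclasses can be sketched with a small abstract class: wait is written once in the base class, and each subclass only supplies _poll. MiniPContext and CountdownContext are hypothetical names, and _poll returns a plain string instead of a RunProcsResult; the timeout handling follows the wait code above.

```python
import abc
import sys
import time
from typing import Optional


class MiniPContext(abc.ABC):
    """Hypothetical reduction of PContext: the base owns wait(), subclasses implement _poll()."""

    @abc.abstractmethod
    def _poll(self) -> Optional[str]:
        """Return a result once all processes are done, else None."""

    def wait(self, timeout: float = -1, period: float = 1) -> Optional[str]:
        if timeout == 0:
            return self._poll()  # a single non-blocking check
        if timeout < 0:
            timeout = sys.maxsize  # negative timeout means "wait forever"
        expiry = time.time() + timeout
        while time.time() < expiry:
            result = self._poll()
            if result:
                return result
            time.sleep(period)
        return None  # still running when the timeout expired


class CountdownContext(MiniPContext):
    """Fake context that reports 'done' after a fixed number of polls."""

    def __init__(self, polls_needed: int) -> None:
        self.polls = 0
        self.polls_needed = polls_needed

    def _poll(self) -> Optional[str]:
        self.polls += 1
        return "done" if self.polls >= self.polls_needed else None


if __name__ == "__main__":
    print(CountdownContext(polls_needed=3).wait(timeout=5, period=0.1))
```

Note that _monitor_workers calls wait(0), so each monitoring round costs exactly one _poll and never blocks the agent's loop.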

2.3.3 _poll

_poll does the actual checking by calling torch.mp.ProcessContext.join. torch.mp.ProcessContext raises an exception when some or all worker processes fail. With a negative timeout, join checks the worker status and returns immediately. Because the workers block on a synchronize.Event until the parent has collected their results, this join never reports success.

The code uses a multiprocessing queue to carry return values from the worker processes back to the parent process, and the returned result contains the run result of each process.

def _poll(self) -> Optional[RunProcsResult]:

    try:
        # torch.mp.ProcessContext Throws an Exception if some/all of
        # worker processes failed
        # timeout < 0 checks worker status and return immediately
        # Join will never return success since we use synchronize.Event to wait
        # for all processes to finish.
        self._pc.join(-1)

        # IMPORTANT: we use multiprocessing.Queue to carry worker return values
        # back to the parent, the worker process will wait before terminating
        # until all the buffered items are fed by the feeder thread to the underlying
        # pipe. Hence to prevent deadlocks on large return values,
        # we opportunistically try queue.get on each join call
        # See: https://docs.python.org/2/library/multiprocessing.html#all-platforms
        
        for local_rank in range(0, self.nprocs):  # collect each worker's return value
            return_queue = self._ret_vals[local_rank]
            if not return_queue.empty():
                # save the return values temporarily into a member var
                self._return_values[local_rank] = return_queue.get()  # result of this process

        if self._is_done():
            # we should ALWAYS have ALL the return values when all the processes are done
            self._worker_finished_event.set()
            # Wait until all processes are finished. At this point workers finished executing user function
            self._pc.join()
            self.close()
            return RunProcsResult(
                return_values=self._return_values,  # return the process results
                stdouts=self.stdouts,
                stderrs=self.stderrs,
            )
        else:
            return None
          
    except (mp.ProcessRaisedException, mp.ProcessExitedException) as e:
        failed_local_rank = e.error_index

        # entrypoint for MultiprocessContext will always be a Callable
        fn_name = self.entrypoint.__qualname__  # type: ignore[union-attr]
        failed_proc = self._pc.processes[failed_local_rank]
        error_filepath = self.error_files[failed_local_rank]

        self.close()
        return RunProcsResult( Return the process result
            failures={
                failed_local_rank: ProcessFailure(
                    local_rank=failed_local_rank,
                    pid=e.pid,
                    exitcode=failed_proc.exitcode,
                    error_file=error_filepath,
                )
            },
            stdouts=self.stdouts,
            stderrs=self.stderrs,
        )
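The deadlock note in _poll is easier to follow with a toy model. The sketch below is not PyTorch code. It swaps processes for threads and multiprocessing.Queue for queue.Queue, and every name in it (MiniProcContext, _worker, wait) is hypothetical. It reproduces only the control flow described above: each poll drains every return queue, workers stay alive until the parent holds all values, and in this sketch "done" means every return value has been collected. Threads cannot reproduce the pipe-buffer deadlock itself.

```python
import queue
import threading
import time
from typing import Dict, List, Optional


def _worker(local_rank: int, ret_q: queue.Queue, finished: threading.Event) -> None:
    # Stand-in for the user function: publish a return value...
    ret_q.put(f"result-from-rank-{local_rank}")
    # ...then stay alive until the parent has collected every value.
    finished.wait()


class MiniProcContext:
    """Toy, thread-based analogue of MultiprocessContext (hypothetical)."""

    def __init__(self, nprocs: int) -> None:
        self.nprocs = nprocs
        self._ret_vals: List[queue.Queue] = [queue.Queue() for _ in range(nprocs)]
        self._return_values: Dict[int, object] = {}
        self._finished = threading.Event()
        self._workers = [
            threading.Thread(
                target=_worker, args=(i, self._ret_vals[i], self._finished), daemon=True
            )
            for i in range(nprocs)
        ]
        for w in self._workers:
            w.start()

    def _poll(self) -> Optional[Dict[int, object]]:
        # Opportunistically drain every return queue on each poll.
        for local_rank, q in enumerate(self._ret_vals):
            try:
                self._return_values[local_rank] = q.get_nowait()
            except queue.Empty:
                pass
        if len(self._return_values) < self.nprocs:
            return None  # some workers have not reported yet
        # Every value is in: release the workers, then reap them.
        self._finished.set()
        for w in self._workers:
            w.join()
        return dict(self._return_values)

    def wait(self, timeout: float, period: float = 0.01) -> Optional[Dict[int, object]]:
        if timeout == 0:
            return self._poll()  # single non-blocking check
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            result = self._poll()
            if result is not None:
                return result
            time.sleep(period)
        return None
```

Here ctx.wait(0) plays the role of the non-blocking _pcontext.wait(0) call that _monitor_workers makes.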

2.4 ProcessContext

The key member of MultiprocessContext is _pc: Optional[mp.ProcessContext]. It is created by start_processes, so we need to look at torch.mp.ProcessContext.

2.4.1 start_processes

start_processes, in torch/multiprocessing/spawn.py, returns a ProcessContext. Note that from this point on, each training process runs its own training code as if there were no agent, because the agent has already done the work that torch.distributed.launch used to do.

import multiprocessing

# _wrap and ProcessContext are defined in the same module (torch/multiprocessing/spawn.py)
def start_processes(fn, args=(), nprocs=1, join=True, daemon=False, start_method='spawn'):
    mp = multiprocessing.get_context(start_method)
    error_queues = []
    processes = []
    for i in range(nprocs):
        error_queue = mp.SimpleQueue()
        process = mp.Process(
            target=_wrap,
            args=(fn, i, args, error_queue),  # the child runs the training code through _wrap
            daemon=daemon,
        )
        process.start()
        error_queues.append(error_queue)
        processes.append(process)

    context = ProcessContext(processes, error_queues)
    if not join:
        return context

    # Loop on join until it returns True or raises an exception.
    while not context.join():
        pass

2.4.2 ProcessContext

torch.mp.ProcessContext is the class that ultimately does the work. We do not need to care about its internal implementation or how it is started, since start_processes takes care of that. What we really care about is how torch.mp.ProcessContext is used for monitoring.

Its comments tell us that ProcessContext.join throws an exception when some or all worker processes fail. With timeout < 0, join checks the worker status and returns immediately. Because the workers wait on a synchronize.Event until all processes have finished, join never returns success here.

# torch.mp.ProcessContext Throws an Exception if some/all of
# worker processes failed
# timeout < 0 checks worker status and return immediately
# Join will never return success since we use synchronize.Event to wait
# for all processes to finish.
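Setting real processes aside, the contract in those comments can be sketched as one non-blocking status check. ProcStatus, ProcessExitedError and join_nonblocking are hypothetical names. The real ProcessContext.join also waits on process sentinels and reads each child's error queue, which this sketch leaves out.

```python
from dataclasses import dataclass
from typing import List, Optional


class ProcessExitedError(Exception):
    """Hypothetical stand-in for mp.ProcessExitedException."""

    def __init__(self, error_index: int, exitcode: int) -> None:
        super().__init__(f"process {error_index} exited with code {exitcode}")
        self.error_index = error_index
        self.exitcode = exitcode


@dataclass
class ProcStatus:
    exitcode: Optional[int]  # None while the process is still running


def join_nonblocking(procs: List[ProcStatus]) -> bool:
    """Check every process once and return immediately (like join(-1))."""
    for index, proc in enumerate(procs):
        if proc.exitcode is not None and proc.exitcode != 0:
            # Report the failing process by its index (its local rank).
            raise ProcessExitedError(index, proc.exitcode)
    # True only when every process has exited cleanly.
    return all(proc.exitcode == 0 for proc in procs)
```

Because elastic workers block on an event after the user function returns, their exit code stays None until the parent releases them. In the elastic flow this check therefore never returns True, and success is detected from the collected return values instead.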

2.5 Summary

The current relationship is as follows:

  • When workers are started, LocalElasticAgent creates a MultiprocessContext, which in turn creates a ProcessContext.
  • LocalElasticAgent._pcontext holds the MultiprocessContext, and MultiprocessContext._pc holds the ProcessContext.
  • LocalElasticAgent._monitor_workers calls MultiprocessContext.wait, which calls _poll, which calls ProcessContext.join. ProcessContext.join checks the running state of the processes, which completes the overall monitoring logic.
  • When a child process finishes or fails, _poll turns the outcome of ProcessContext.join into a process result, MultiprocessContext.wait passes that result back, and _monitor_workers converts it into WorkerState.SUCCEEDED or WorkerState.FAILED (or WorkerState.HEALTHY if nothing has finished yet).

The relationships are shown in the figure below:

+--------------------------------------------------------------------------------------+   +------------------------------------+   +----------------+
| LocalElasticAgent                                                                    |   | MultiprocessContext                |   | ProcessContext |
|                                                                                      |   |                                    |   |                |
|                                                                                      |   |                                    |   |                |
|  +----------------------------------------+     MultiprocessContext _pcontext        |   |       ProcessContext _pc           |   |                |
|  | _invoke_run                            |                                          |   |                                    |   |                |
|  |                                        |                                          |   |                                    |   |                |
|  |   _initialize_workers +--------------------> _pcontext = start_processes +-----------------> start():                      |   |                |
|  |                                        |                                          |   |        _pc = mp.start_processes +--------------->       |
|  |                                        |                                          |   |                                    |   |                |
|  |   while True:                          |      +--------------------------------+  |   |                                    |   |                |
|  |       _monitor_workers(_worker_group)+------> | _monitor_workers               |  |   |                                    |   |                |
|  |                                        |      |                                |  |   |                                    |   |                |
|  |                                        |      |             _pcontext.wait +--------------->  wait +---> poll:             |   |                |
|  |                                        |      |                                |  |   |                    _pc.join  +--------------->          |
|  +----------------------------------------+      +--------------------------------+  |   |                                    |   |                |
|                                                                                      |   |                                    |   |                |
+--------------------------------------------------------------------------------------+   +------------------------------------+   +----------------+


0x03 Monitoring Mechanism

As the earlier _monitor_workers code shows, _monitor_workers converts the run results of the child processes into a specific WorkerState. When the agent receives this monitoring result, it acts accordingly:

            # Monitor the running state of the workers
            run_result = self._monitor_workers(self._worker_group)
            state = run_result.state  # health of the worker processes
            self._worker_group.state = state

            if state == WorkerState.SUCCEEDED:
                # The program ended normally
                self._exit_barrier()  # wait for all agents to reach the exit barrier
                return run_result
            elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED}:
                # program error
                if self._remaining_restarts > 0:  # retry
                    self._remaining_restarts -= 1
                    self._restart_workers(self._worker_group)  # restart
                else:
                    self._stop_workers(self._worker_group)  # retry limit reached: terminate the workers
                    self._worker_group.state = WorkerState.FAILED
                    self._exit_barrier()
                    return run_result
            elif state == WorkerState.HEALTHY:
                # The program is running properly
                # Node membership changes, such as scale up
                # membership changes do not count as retries
                num_nodes_waiting = rdzv_handler.num_nodes_waiting()
                group_rank = self._worker_group.group_rank
                # Restart all workers if a new node is waiting
                if num_nodes_waiting > 0:
                    self._restart_workers(self._worker_group)
            else:
                raise Exception(f"[{role}] Worker group in {state.name} state")
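The branches above amount to a small decision table. The sketch below restates them as a pure function plus a loop that tracks the restart budget. WorkerState here is a local stand-in for the real enum (which has more members), and next_action and run_loop are hypothetical helpers, not agent methods.

```python
from enum import Enum
from typing import List, Tuple


class WorkerState(Enum):
    # Hypothetical local mirror of the states handled in the loop above.
    HEALTHY = "HEALTHY"
    UNHEALTHY = "UNHEALTHY"
    SUCCEEDED = "SUCCEEDED"
    FAILED = "FAILED"
    UNKNOWN = "UNKNOWN"


def next_action(state: WorkerState, remaining_restarts: int, num_nodes_waiting: int) -> str:
    """Return what the monitoring loop would do with one monitoring result."""
    if state == WorkerState.SUCCEEDED:
        return "exit_barrier_then_return"
    if state in (WorkerState.UNHEALTHY, WorkerState.FAILED):
        # A failure consumes one restart while any remain.
        return "restart" if remaining_restarts > 0 else "stop_and_fail"
    if state == WorkerState.HEALTHY:
        # Membership changes also restart, but do not consume a restart.
        return "restart" if num_nodes_waiting > 0 else "keep_running"
    raise ValueError(f"worker group in {state.name} state")


def run_loop(results: List[Tuple[WorkerState, int]], max_restarts: int) -> Tuple[List[str], int]:
    """Feed (state, num_nodes_waiting) pairs through next_action."""
    remaining = max_restarts
    actions = []
    for state, waiting in results:
        action = next_action(state, remaining, waiting)
        if action == "restart" and state != WorkerState.HEALTHY:
            remaining -= 1  # only failure-driven restarts count
        actions.append(action)
        if action in ("exit_barrier_then_return", "stop_and_fail"):
            break
    return actions, remaining
```

With max_restarts=1, the sequence FAILED, HEALTHY with one node waiting, FAILED yields restart, restart, stop_and_fail: the membership-driven restart in the middle leaves the budget untouched.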

3.1 Monitoring

_monitor_workers calls _pcontext.wait(0) to get the current state of the worker child processes, converts it into a WorkerState, and returns the result to the caller. The RunResult must be keyed by global rank, so _monitor_workers maps each local rank to its global rank.

Why key the process status by global rank? Because nodes communicate with each other: a local rank is only unique within one node, while the global rank identifies a worker across the whole job.

    @prof
    def _monitor_workers(self, worker_group: WorkerGroup) -> RunResult:
        role = worker_group.spec.role
        worker_pids = {w.id for w in worker_group.workers}  # PIDs of all workers of this agent
        pc_pids = set(self._pcontext.pids().values())
        if worker_pids != pc_pids:
            return RunResult(state=WorkerState.UNKNOWN)

        result = self._pcontext.wait(0)  # poll the worker processes without blocking
        if result:
            if result.is_failed():  # at least one process failed
                # map local rank failure to global rank
                worker_failures = {}
                # the result contains the outcome of each process
                for local_rank, failure in result.failures.items():  # local_rank is the process index
                    worker = worker_group.workers[local_rank]  # the corresponding worker
                    worker_failures[worker.global_rank] = failure  # re-key the failure by global rank
                return RunResult(
                    state=WorkerState.FAILED,
                    failures=worker_failures,  # failures keyed by global rank
                )
            else:
                # copy ret_val_queue into a map with a global ranks
                workers_ret_vals = {}
                for local_rank, ret_val in result.return_values.items():
                    worker = worker_group.workers[local_rank]
                    workers_ret_vals[worker.global_rank] = ret_val
                return RunResult(
                    state=WorkerState.SUCCEEDED,
                    return_values=workers_ret_vals,  # return values keyed by global rank
                )
        else:
            return RunResult(state=WorkerState.HEALTHY)
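The re-keying step can be sketched on its own with plain dataclasses. Worker, ProcResult, RunResult and to_run_result are simplified, hypothetical stand-ins for the real types, assuming the workers list is ordered by local rank, as the worker_group.workers[local_rank] lookup above implies.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class Worker:
    local_rank: int
    global_rank: int


@dataclass
class ProcResult:
    # Per-process outcome keyed by local rank (plays the role of RunProcsResult).
    return_values: Dict[int, object] = field(default_factory=dict)
    failures: Dict[int, object] = field(default_factory=dict)


@dataclass
class RunResult:
    state: str
    return_values: Dict[int, object] = field(default_factory=dict)
    failures: Dict[int, object] = field(default_factory=dict)


def to_run_result(workers: List[Worker], result: Optional[ProcResult]) -> RunResult:
    """Re-key per-process results by global rank, as _monitor_workers does."""
    if result is None:
        return RunResult("HEALTHY")  # nothing has finished yet
    if result.failures:
        return RunResult(
            "FAILED",
            failures={workers[lr].global_rank: f for lr, f in result.failures.items()},
        )
    return RunResult(
        "SUCCEEDED",
        return_values={workers[lr].global_rank: v for lr, v in result.return_values.items()},
    )
```

For example, assuming two workers per node, the workers of the node with group_rank 1 would have global ranks 2 and 3. A failure of local rank 1 on that node is therefore reported under global rank 3.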

3.2 Handling

The agent handles each returned state differently:

  • WorkerState.SUCCEEDED: training has completed, and the agent returns normally.
  • WorkerState.HEALTHY: training is running normally, and the agent checks whether new nodes are waiting to join, which we explain later.
  • WorkerState.UNHEALTHY or WorkerState.FAILED: training has a problem. There are two cases:
    • The program failed, and TE retries.
    • A node exited, which we examine later; it is handled the same way as a program error.

Let’s take a look at how to handle the end of training and program errors.

0x04 Training Ends

        if state == WorkerState.SUCCEEDED:
            # The program ended normally
            self._exit_barrier()  # wait for all agents to reach the exit barrier
            return run_result

This is the path taken when training ends normally. Note in particular the call to _exit_barrier.

4.1 Unified Completion

TorchElastic currently supports DDP-style applications, so TE expects all workers to finish at about the same time. In practice it is almost impossible to guarantee that all DDP workers terminate at exactly the same time, so TE provides a finalization barrier with a timeout (5 minutes by default) for worker finalization. In other words, once one worker finishes training, TE (TorchElastic) expects all the other workers to finish within 5 minutes.

def _exit_barrier(self):
    """
    Wait for ``exit_barrier_timeout`` seconds for all agents to finish
    executing their local workers (either successfully or not). This
    acts as a safety guard against user scripts that terminate at different
    times. This barrier keeps the agent process alive until all workers finish.
    """

    start = time.time()
    try:
        store_util.barrier(
            self._store,
            self._worker_group.group_rank,
            self._worker_group.group_world_size,
            key_prefix=_TERMINAL_STATE_SYNC_ID,
            barrier_timeout=self._exit_barrier_timeout,
        )
    except Exception:
        log.exception(
            f"Error waiting on exit barrier. Elapsed: {time.time() - start} seconds"
        )

The default value of exit_barrier_timeout is 300 seconds, or 5 minutes.

exit_barrier_timeout: float = 300

4.2 Synchronization

In torch/distributed/elastic/utils/store.py, barrier calls synchronize to perform the synchronization.

def barrier(
    store, rank: int, world_size: int, key_prefix: str, barrier_timeout: float = 300
) -> None:
    """
    A global lock between agents.

    Note: Since the data is not removed from the store, the barrier can be used
        once per unique ``key_prefix``.
    """
    data = f"{rank}".encode(encoding="UTF-8")
    synchronize(store, data, rank, world_size, key_prefix, barrier_timeout)

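synchronize does its work through the store. Each agent writes the key {key_prefix}{rank}, then reads the keys of all world_size agents. Because store.get blocks until a key exists (or until the store timeout, set to barrier_timeout, expires), no agent gets past get_all until every agent has written its key, which is what makes this a barrier.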

from datetime import timedelta
from typing import List


def get_all(store, prefix: str, size: int):
    r"""
    Given a store and a prefix, the method goes through the array of keys
    of the following format: ``{prefix}{idx}``, where idx is in a range
    from 0 to size, and tries to retrieve the data. Usage::

        values = get_all(store, 'torchelastic/data', 3)
        value1 = values[0]  # retrieves the data for key torchelastic/data0
        value2 = values[1]  # retrieves the data for key torchelastic/data1
        value3 = values[2]  # retrieves the data for key torchelastic/data2
    """
    data_arr = []
    for idx in range(size):
        data = store.get(f"{prefix}{idx}")  # blocks until the key exists or the store times out
        data_arr.append(data)
    return data_arr


def synchronize(
    store,
    data: bytes,
    rank: int,
    world_size: int,
    key_prefix: str,
    barrier_timeout: float = 300,
) -> List[bytes]:
    """
    Synchronizes ``world_size`` agents between each other using the underlying c10d store.
    The ``data`` will be available on each of the agents.
    Note: The data on the path is not deleted, as a result there can be stale data if
    you use the same key_prefix twice.
    """
    store.set_timeout(timedelta(seconds=barrier_timeout))
    store.set(f"{key_prefix}{rank}", data)
    agent_data = get_all(store, key_prefix, world_size)
    return agent_data
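To see why set-then-get-all behaves as a barrier, here is a self-contained sketch with an in-memory store whose get blocks until the key exists or the timeout expires. InMemoryStore is hypothetical and only imitates the blocking set/get/set_timeout behavior of a c10d store. synchronize is restated from the code above in compact form.

```python
import threading
import time
from datetime import timedelta
from typing import Dict, List


class InMemoryStore:
    """Hypothetical stand-in for a c10d store: get() blocks until the key exists."""

    def __init__(self) -> None:
        self._data: Dict[str, bytes] = {}
        self._cond = threading.Condition()
        self._timeout = timedelta(seconds=300)

    def set_timeout(self, timeout: timedelta) -> None:
        self._timeout = timeout

    def set(self, key: str, value: bytes) -> None:
        with self._cond:
            self._data[key] = value
            self._cond.notify_all()  # wake every agent blocked in get()

    def get(self, key: str) -> bytes:
        deadline = time.monotonic() + self._timeout.total_seconds()
        with self._cond:
            while key not in self._data:
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    raise TimeoutError(f"timed out waiting for {key!r}")
                self._cond.wait(remaining)
            return self._data[key]


def synchronize(store, data: bytes, rank: int, world_size: int,
                key_prefix: str, barrier_timeout: float = 300) -> List[bytes]:
    store.set_timeout(timedelta(seconds=barrier_timeout))
    store.set(f"{key_prefix}{rank}", data)  # announce this agent
    # Blocks until every agent has announced itself (or the timeout fires).
    return [store.get(f"{key_prefix}{idx}") for idx in range(world_size)]
```

If one agent never arrives, every other agent's get raises once barrier_timeout elapses. _exit_barrier catches that exception and only logs it.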

0x05 Error handling

5.1 Error Types

Each host in a distributed PyTorch job runs one TorchElastic agent and multiple workers (child processes of the TorchElastic agent). Since the workers are user-provided (the PyTorch script/job), TorchElastic propagates errors from the trainers through the agent up to the scheduler, which ultimately informs the end user of the job's state and applies any retry policy.

TE classifies errors into the following categories.

+----------------+----------------+--------------------------------------------------------------+
| Category       | Sub-Category   |  Description                                                 |
+================+================+==============================================================+
| User Error     | Input Error    | invalid inputs to TorchElastic APIs (e.g. min > max nodes)   |
|                +----------------+--------------------------------------------------------------+
|                | Worker Failure | any failures on the worker child process                     |
+----------------+----------------+--------------------------------------------------------------+
| Platform Error | n/a            | failures caused by the agent                                 |
+----------------+----------------+--------------------------------------------------------------+
| Infra Error    | n/a            | failures outside the domain of the agent and workers         |
|                |                | (e.g. host failures)                                         |
+----------------+----------------+--------------------------------------------------------------+

5.2 Error Handling Modes

The corresponding error handling modes are as follows:

  • User Error: handled as follows:
    • Input Error: for example, invalid input to the TorchElastic APIs; the program can catch it directly.
    • Worker Failure:
      • Worker failures are special because the exception/failure originates in a different process from the agent, so the error has to be propagated across processes (for example, the agent cannot simply try-catch an exception raised in a worker process).
        • The TorchElastic agent starts workers with torch.distributed.elastic.multiprocessing.start_processes, which has simple file-based inter-process error propagation built in.
        • Any function or binary entry point decorated with record writes uncaught exceptions (with traceback information) to the file named by the environment variable TORCHELASTIC_ERROR_FILE. The parent process (e.g. the agent) sets this environment variable on each child process it launches, then aggregates the error files of all child processes and propagates the one with the smallest timestamp (i.e. the first error).
      • The documentation says that for a training job with N workers, if worker k (k <= N) fails, all workers are stopped and restarted, up to max_restarts times. In other words, if a worker fails and the maximum number of restarts has not been reached, TE starts a new rendezvous and restarts all workers. Because it is a new rendezvous, the other TE agents also restart their workers.
      • The failure of one worker can bring down the whole cluster: if a single worker keeps failing, the TE agent's remaining restarts (max_restarts) drop to zero. The agent then finishes its work and shuts down the rendezvous, and any workers on other agents are terminated as well.
  • Platform Error (agent failure):
    • All errors other than worker failures are raised from the agent process itself, implicitly or explicitly, and crash the agent process. The standard exception-handling strategies of the language (Python) therefore apply.
    • An agent failure also takes down its local worker group. What to do next is up to the job manager, for example failing the whole job (gang semantics) or trying to replace the node. The agent supports both behaviors.
  • Infra Error: handled the same way as agent failures.
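The file-based propagation described above can be sketched with the standard library alone. record_errors, make_error_paths and first_error are hypothetical simplifications of what the record decorator and the agent do. Only the environment variable name TORCHELASTIC_ERROR_FILE comes from the text, and the JSON layout of the error file is an assumption.

```python
import functools
import json
import os
import tempfile
import time
import traceback
from typing import Callable, List, Optional

ERROR_FILE_ENV = "TORCHELASTIC_ERROR_FILE"  # variable name from the text above


def record_errors(fn: Callable) -> Callable:
    """Hypothetical, simplified analogue of the ``record`` decorator."""

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except Exception as exc:
            path = os.environ.get(ERROR_FILE_ENV)
            if path:  # only write if the parent said where
                with open(path, "w") as f:
                    json.dump(
                        {
                            "message": repr(exc),
                            "traceback": traceback.format_exc(),
                            "timestamp": time.time(),
                        },
                        f,
                    )
            raise  # the child still fails; the file is only a side channel

    return wrapper


def make_error_paths(nprocs: int) -> List[str]:
    """Parent side: pick one error file per child (to be set as its env var)."""
    directory = tempfile.mkdtemp(prefix="errors-")
    return [os.path.join(directory, f"error_{i}.json") for i in range(nprocs)]


def first_error(paths: List[str]) -> Optional[dict]:
    """Parent side: read the error files that exist and keep the earliest one."""
    errors = []
    for path in paths:
        if os.path.exists(path):
            with open(path) as f:
                errors.append(json.load(f))
    return min(errors, key=lambda e: e["timestamp"], default=None)
```

In a real agent each child gets its own environment. The sketch only models that by setting the variable before each call.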

Let’s look at how to handle “Worker Failure” in detail.

5.3 Handling Mechanism

The error handling mechanism is as follows: if the number of restarts has not reached the maximum, the agent tries to restart the workers; once the maximum has been reached, it stops the workers.

        elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED}:
            # program error
            if self._remaining_restarts > 0:  # retry
                self._remaining_restarts -= 1
                self._restart_workers(self._worker_group)  # restart
            else:
                self._stop_workers(self._worker_group)  # retry limit reached: terminate the workers
                self._worker_group.state = WorkerState.FAILED
                self._exit_barrier()
                return run_result

5.3.1 Restart

_restart_workers stops all local workers and then calls _initialize_workers, which starts a new rendezvous and launches the workers again.

@prof
def _restart_workers(self, worker_group: WorkerGroup) -> None:
    """
    Restarts (stops, rendezvous, starts) all local workers in the group.
    """

    role = worker_group.spec.role
    self._stop_workers(worker_group)
    worker_group.state = WorkerState.STOPPED
    self._initialize_workers(worker_group)

5.3.2 Stop

Stopping the workers simply closes the process context (_pcontext).

def _shutdown(self) -> None:
    if self._pcontext:
        self._pcontext.close()

@prof
def _stop_workers(self, worker_group: WorkerGroup) -> None:
    self._shutdown()

In MultiprocessContext, _close terminates every child process and waits for each one to exit.

    def _close(self) -> None:
        if self._pc:
            for proc in self._pc.processes:
                proc.terminate()  # ask the child to exit (SIGTERM on POSIX)
                proc.join()       # wait until it has exited
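The terminate-then-wait pattern of _close can be tried with plain subprocesses. This sketch uses subprocess.Popen (terminate() and wait()) as a stand-in for multiprocessing.Process (terminate() and join()), and start_children and close_all are hypothetical helpers. Unlike _close, it signals every child first and only then waits, so the children shut down in parallel rather than one after another.

```python
import subprocess
import sys
from typing import List


def start_children(n: int) -> List[subprocess.Popen]:
    # Each child just sleeps, standing in for a long-running worker.
    return [
        subprocess.Popen([sys.executable, "-c", "import time; time.sleep(60)"])
        for _ in range(n)
    ]


def close_all(procs: List[subprocess.Popen], timeout: float = 10.0) -> List[int]:
    """Ask every child to terminate, then wait for each one to exit."""
    for p in procs:
        p.terminate()
    return [p.wait(timeout=timeout) for p in procs]
```

On POSIX the returned exit codes are negative signal numbers (for example -15 for SIGTERM).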

5.4 Other Agents Restart

A new rendezvous will cause other agents to restart their workers as well.

When a worker fails, TE checks the number of restarts available. If more than 0 restarts remain, TE starts a new rendezvous round and restarts the worker processes. The new rendezvous round causes the other TE agents to terminate their workers.

How is this done? Details are as follows:

  1. Agent 0 (the faulty agent) detects the failure through monitoring.
  2. Agent 0 calls _restart_workers to restart its workers.
  3. Agent 0 calls next_rendezvous to start a new rendezvous.
  4. Before performing any operation (such as a keep-alive), Agent 0 calls sync to fetch the cluster information from the KV store, so that it has the latest cluster state.
  5. Agent 0 adds itself to its local waiting_list.
  6. Agent 0 calls mark_dirty, meaning "my state has changed and must be written to the KV store".
  7. Agent 0 calls sync to write its waiting_list to the KV store.
  8. Agent 1 (any other working agent) likewise calls sync to fetch the latest information from the KV store before doing anything, such as a keep-alive.
  9. Agent 1 uses this information to update its own state, so its local waiting_list is updated.
  10. Agent 1's training loop runs its monitoring check every 30 seconds; since its own workers are fine, the result is HEALTHY.
  11. Agent 1 calls num_nodes_waiting() to check the waiting_list.
  12. Agent 1 gets the size of its local waiting_list.
  13. If the waiting_list is not empty, Agent 1 also calls _restart_workers.
  14. This eventually calls next_rendezvous.

Details are as follows:

 Agent 0                                      Agent 1
+---------------------------+                 +--------------------------------------------+
|     _invoke_run           |                 |                      _invoke_run           |
|          +                |                 |                           +                |
|          |                |                 |                           |                |
|          | 1              |                 |                           |                |
|          v                |                 |                           |                |
| Worker Process Error      |                 |                           |                |
|          +                |                 |                           |                |
|          |                |                 |                           | 10             |
|          | 2              |                 |                           v                |
|          v                |                 |                        HEALTHY             |
|  _restart_workers         |                 |                           +                |
|          +                |                 |                           | 11             |
|          |                |                 |                           |                |
|          | 3              |                 |                           v                |
|          v                |                 |              +-->  num_nodes_waiting() > 0 |
|   next_rendezvous         |                 |              |            +                |
|          +                |                 |              |            |                |
|          | 4              |                 |              | 12         | 13             |
|          |                +   +----------+  |              |            v                |
|          v      cluster info  |          |  |              |       _restart_workers      |
|        sync  <------------+-> | KV Store |  |              |            +                |
|          +                |   |          |  |              |            |                |
|          | 5              |   |          |  |              |            | 14             |
|          v                |   |          |  |              |            v                |
|  Add to local waiting_list|   |          |  |              |        next_rendezvous      |
|          +                |   |          |  |              |                             |
|          |                |   |          |  |              |                             |
|          | 6              |   |          |  |              v                             |
|          v                |   |          |  |                                            |
|     mark_dirty            |   |          |  |  Add to local waiting_list                 |
|          +                |   |          |  |              ^                             |
|          |                |   |          |  |              |                             |
|          | 7              |   |          |  |            9 | waiting_list                |
|          v         7      |   |          |  |    8         +                             |
|        sync +---------------> |          +--------------> sync                           |
|              waiting_list |   |          |  |waiting_list                                |
|                           |   +----------+  |                                            |
+---------------------------+                 +--------------------------------------------+
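Steps 4 to 14 can be modelled with a toy shared store. This is a deliberately simplified, hypothetical sketch: the real dynamic rendezvous keeps a richer state and writes it back with compare-and-set, whereas here a set union stands in for that merge, and KVStore, Agent and on_healthy are invented names.

```python
from typing import Set


class KVStore:
    """Hypothetical shared store; holds only the cluster-wide waiting list."""

    def __init__(self) -> None:
        self.waiting_list: Set[str] = set()


class Agent:
    """Hypothetical agent keeping a local copy of the waiting list."""

    def __init__(self, name: str, store: KVStore) -> None:
        self.name = name
        self.store = store
        self.waiting_list: Set[str] = set()
        self.dirty = False

    def mark_dirty(self) -> None:
        self.dirty = True

    def sync(self) -> None:
        # Write local changes if dirty, then refresh from the store.
        if self.dirty:
            self.store.waiting_list |= self.waiting_list
            self.dirty = False
        self.waiting_list = set(self.store.waiting_list)

    def next_rendezvous(self) -> None:
        self.sync()                       # step 4: fetch the latest state
        self.waiting_list.add(self.name)  # step 5: join the local waiting_list
        self.mark_dirty()                 # step 6
        self.sync()                       # step 7: publish it

    def num_nodes_waiting(self) -> int:
        self.sync()                       # steps 8-9: refresh before reading
        return len(self.waiting_list)

    def on_healthy(self) -> bool:
        """Steps 10-14: restart (and re-rendezvous) if any node is waiting."""
        if self.num_nodes_waiting() > 0:
            self.next_rendezvous()
            return True
        return False
```

Running Agent 0's next_rendezvous and then Agent 1's on_healthy reproduces the figure: Agent 1 sees one waiting node, restarts, and joins the waiting list itself.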



This concludes our first look at the monitoring mechanism. For reasons of space, the handling of scale up/down is covered in the next article.

0xEE Personal information

★★★★ Thoughts on life and technology ★★★★★

Wechat official account: Rosie’s Thoughts
