0x00 Summary

In previous articles, we studied the basic modules of PyTorch distributed and introduced some official examples. From here on we cover PyTorch elastic training; in this third article, we look at the basic functions of the elastic agent.

Other articles in this series are as follows:

PyTorch distributed elastic training (1) — the general idea

PyTorch Distributed Elastic Training (2)– Startup & Single node flow

0x01 General Background

To recap, the two most important concepts of TE (TorchElastic) are the agent and rendezvous.

  • The agent is an independent background process running on a single node, which can be thought of as a worker manager or process supervisor. It is responsible for starting the workers, monitoring their execution, catching worker anomalies, using rendezvous to let the workers discover each other, and synchronizing the group through rendezvous when membership changes.
  • To achieve elastic training, there needs to be a mechanism for nodes/processes to discover each other. Rendezvous is that discovery mechanism, or synchronization component. All workers rendezvous to establish a new process group when the system starts or when membership changes.

1.1 Functional Separation

TE is a combination of multiple Elastic agents based on Rendezvous, which is a separation of functions. Let’s compare.

  • The Agent focuses on the logic of a specific node.
    • The agent is responsible for the concrete, business-logic-related operations on a specific node, such as starting the processes that execute the user program, monitoring the running status of the user program, and notifying Rendezvous of anomalies.
    • The agent is a worker manager: it starts/manages the worker processes, forms a worker group, monitors the running status of the workers and captures failed workers. If a worker fails or a new worker joins, it restarts the worker group.
    • The agent is responsible for maintaining the WORLD_SIZE and RANK information; users do not need to provide these values manually, the agent handles them automatically.
    • Agent is an independent background process on a specific node. Agents cannot achieve overall elastic training on their own, so a mechanism is needed to realize mutual discovery and change synchronization among workers (WORLD_SIZE and RANK also require synchronization of multiple nodes), which is the following Rendezvous concept.
  • Rendezvous is responsible for the cluster logic that ensures a strong consensus among nodes on “which nodes participate in training.”
    • Each Agent contains an internal Rendezvous handler, which collectively constitutes a Rendezvous cluster and thus an Agent cluster.
    • When a rendezvous completes, a shared key-value store is created that implements the torch.distributed.Store API. This store is shared only by the members that completed the rendezvous, and Torch Distributed Elastic uses it to exchange the control and data information needed during initialization (see the short sketch after this list).
    • Rendezvous is responsible for maintaining all relevant information about the current group on each agent. Each agent will have a rendezvous, which will communicate with each other and generally maintain a set of information stored in the aforementioned stores.
    • Rendezvous is responsible for cluster logic, such as adding new nodes, removing nodes, assigning ranks, and so on.
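
To make the role of this shared store concrete, here is a tiny, hypothetical sketch of the torch.distributed.Store API that rendezvous hands back. It uses an in-process HashStore as a stand-in; in TE the store is actually created by rendezvous (for example a TCPStore backed by c10d or etcd), and the key name below is made up.

from torch.distributed import HashStore  # in-process stand-in for the store created by rendezvous

store = HashStore()
store.set("status/node0", "ready")   # publish a value under a key
store.wait(["status/node0"])         # block until the key exists
print(store.get("status/node0"))     # -> b'ready' (values come back as bytes)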

Let’s take a look at the schematic diagram from the source code, so you have a general idea.

1.2 Rendezvous

We will briefly introduce Rendezvous in this article, focusing on the Agent.

In the Torch Distributed Elastic context, we use the term rendezvous to refer to a specific feature: a distributed synchronization primitive combined with peer discovery.

Rendezvous is used by Torch Distributed Elastic to gather the participants (nodes) of a training job, so that the participants can negotiate the list of participants and each participant's role, and make a consistent, collective decision about when training should start or resume.

Rendezvous decouples its functionality by abstracting the business logic into a series of operators, such as _RendezvousJoinOp. Rendezvous maintains a state machine, and the operators decide what to do next; an executor such as _RendezvousOpExecutor runs the operators and, based on their results, determines the next Action to perform.

In _DistributedRendezvousOpExecutor, for example, if the current action is found to be ADD_TO_WAIT_LIST, the executor runs _add_to_wait_list, which calls self._state.wait_list.add(self._node):

if action == _Action.KEEP_ALIVE:
    self._keep_alive()
elif action == _Action.ADD_TO_PARTICIPANTS:
    self._add_to_participants()
elif action == _Action.ADD_TO_WAIT_LIST:  # the current action is ADD_TO_WAIT_LIST
    self._add_to_wait_list()              # execute the corresponding operation
elif action == _Action.REMOVE_FROM_PARTICIPANTS:
    self._remove_from_participants()
elif action == _Action.REMOVE_FROM_WAIT_LIST:
    self._remove_from_wait_list()
elif action == _Action.MARK_RENDEZVOUS_COMPLETE:
    self._mark_rendezvous_complete()
elif action == _Action.MARK_RENDEZVOUS_CLOSED:
    self._mark_rendezvous_closed()

0x02 Agent General logic

2.1 Features

The Elastic Agent is torchelastic’s control plane. It is an independent process that starts and manages the underlying worker process. The agent is responsible for:

  • Working with native PyTorch distributed: each worker gets all the information it needs to successfully call torch.distributed.init_process_group() (a minimal worker sketch follows this list).
  • Fault tolerance: Monitor every worker, terminate all workers and restart them in time when errors or anomalies occur.
  • Flexibility: Responds to member changes and restarts all workers with new members.
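
As a concrete illustration of the first point, a worker script usually only has to read the environment variables that the agent sets (we will see them being populated in _start_workers later) and call init_process_group. A minimal sketch, assuming the script is launched by an elastic agent:

import os
import torch.distributed as dist

def main():
    # These variables are set by the elastic agent for every worker it starts.
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    local_rank = int(os.environ["LOCAL_RANK"])

    # MASTER_ADDR / MASTER_PORT are also provided, so no extra arguments are needed.
    dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)
    print(f"worker {rank}/{world_size} (local rank {local_rank}) initialized")
    dist.destroy_process_group()

if __name__ == "__main__":
    main()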

The following image is from Zhihu, which is a refinement of the previous one.

2.2 Work Basis

The torchelastic agent and the user workers operate according to the following failover contract:

  • TE (torchelastic) expects the user workers to finish within 5 minutes of each other (drift).
  • When designing a DDP application, it is best to have all workers fail, not just one worker.
  • TE does not synchronize restart times between agents.
  • TE Re-rendezvous will not reduce the number of restarts.
  • When an individual agent finishes its work (successfully or not), it closes the rendezvous. If other agents still have workers running, those workers will be terminated.
  • Based on the above, scaling down does not work if at least one agent completes the task.
  • When an agent detects scale-up, it does not decrease max_restarts.
  • Torchelastic agents cooperate through etcd or a similar backend.

2.3 Deployment

Simple Agents are deployed on each node and work with local processes. More advanced agents can start and manage workers remotely. Agents can achieve complete decentralization, communicate and coordinate with other agents (workers managing the same job) to make a collective decision based on the situation of the workers they manage.

As for configuration, the source code also provides an example: if you start a training job with 8 trainers (one trainer per GPU), you can use any of the following configurations (a sketch of option 4 follows the list).

1. Use 8 x single GPU instances, place an agent per instance, managing 1 worker per agent.
2. Use 4 x double GPU instances, place an agent per instance, managing 2 workers per agent.
3. Use 2 x quad GPU instances, place an agent per instance, managing 4 workers per agent.
4. Use 1 x 8 GPU instance, place an agent per instance, managing 8 workers per agent.
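
Expressed with the WorkerSpec / LocalElasticAgent API discussed later in this article, option 4 roughly corresponds to a single agent whose spec has local_world_size=8. This is only a sketch: trainer and rdzv_handler are assumed to be defined elsewhere, and the remaining parameter values are illustrative.

from torch.distributed.elastic.agent.server.api import WorkerSpec
from torch.distributed.elastic.agent.server.local_elastic_agent import LocalElasticAgent

spec = WorkerSpec(
    role="trainer",
    local_world_size=8,          # 8 workers on this host, one per GPU
    entrypoint=trainer,          # assumed to be defined elsewhere
    args=(),
    rdzv_handler=rdzv_handler,   # assumed to be built via torch.distributed.elastic.rendezvous
    max_restarts=3,
    monitor_interval=30,
)
agent = LocalElasticAgent(spec, start_method="spawn")
result = agent.run()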

2.4 Base Class

The base ElasticAgent Class is an Abstract Class from which all real running agents derive. As you can see from the ElasticAgent comment, the agent process is responsible for managing one or more worker processes. The worker process is assumed to be a regular distributed PyTorch script. When the worker process is created by the agent, the agent provides the worker process with the necessary information to properly initialize the Torch process group. At deployment time, the exact topology and agent-to-worker ratio depends on the specific implementation of the agent and user job placement preferences.

class ElasticAgent(abc.ABC) :
    """ Agent process responsible for managing one or more worker processes. The worker processes are assumed to be regular distributed PyTorch scripts. When the worker process is created by the agent, the agent provides the necessary information for the worker processes to properly initialize a torch process group. The exact deployment topology and ratio of agent-to-worker is dependent on the specific implementation of the agent and the user's job placement preferences. Usage :: group_result = agent.run() if group_result.is_failed(): # workers failed failure = group_result.failures[0] log.exception(f"worker 0 failed with exit code : {failure.exit_code}") else: return group_result.return_values[0] # return rank 0's results """

    @abc.abstractmethod
    def run(self, role: str = DEFAULT_ROLE) -> RunResult:
        """ Runs the agent, retrying the worker group on failures up to ``max_restarts``. Returns: The result of the execution, containing the return values or failure details for each worker mapped by the worker's global rank. Raises: Exception - any other failures NOT related to worker process """
        raise NotImplementedError()

    @abc.abstractmethod
    def get_worker_group(self, role: str = DEFAULT_ROLE) -> WorkerGroup:
        """ Returns: The ``WorkerGroup`` for the given ``role``. Note that the worker group is a mutable object and hence in a multi-threaded/process environment it may change state. Implementors are encouraged (but not required) to return a defensive read-only copy. """
        raise NotImplementedError()

ElasticAgent has two derived classes:

  • SimpleElasticAgent implements some functions of the base class so that it is easy to extend and implement a new agent.
  • LocalElasticAgent derives from SimpleElasticAgent and is the agent actually used in elastic training. It is mainly used for local operation and is responsible for managing all worker processes on a single machine.

0x03 Worker

Let’s first look at worker, which is the subject managed by Agent.

3.1 Worker Definition

The Worker class represents a worker instance. We introduced WorkerSpec earlier in this series; a Worker is constructed from a WorkerSpec, and its key member variables are as follows:

  • id (Any): uniquely identifies a worker; it is interpreted by the specific implementation of ElasticAgent. For a local agent it can be the worker's pid (int); for a remote agent it can be encoded as host:port (string).

  • local_rank: local rank of the worker.

  • global_rank: global rank of the worker.

  • role_rank: rank of the worker among all workers that have the same role.

  • world_size: total number of workers (globally).

  • role_world_size: number of workers that have the same role.

class Worker:
    """ Represents a worker instance. Contrast this with ``WorkerSpec`` that represents the specifications of a worker. A ``Worker`` is created from a ``WorkerSpec``. A ``Worker`` is to a ``WorkerSpec`` as an object is to a class. The ``id`` of the worker is interpreted by the specific implementation of ``ElasticAgent``. For a local agent, it could be the ``pid (int)`` of the worker, for a remote agent it could be encoded as ``host:port (string)``. Args: id (Any): uniquely identifies a worker (interpreted by the agent) local_rank (int): local rank of the worker global_rank (int): global rank of the worker role_rank (int): rank of the worker across all workers that have the same role world_size (int): number of workers (globally) role_world_size (int): number of workers that have the same role """

    def __init__(
        self,
        local_rank: int,
        global_rank: int = -1,
        role_rank: int = -1,
        world_size: int = -1,
        role_world_size: int = -1,
    ):
        # unique identifier for this worker
        self.id: Any = None

        # rank of the worker among workers with the same role being monitored
        # by the same ``agent`` instance.
        self.local_rank: int = local_rank

        # rank of the worker among all the workers across all roles
        # across all ``agent`` instances.
        # Global rank is not stable between re-rendezvous.
        self.global_rank: int = global_rank

        # rank of the worker among all the workers with the same role
        # across all ``agent`` instances.
        # Global rank is not stable between re-rendezvous.
        self.role_rank: int = role_rank

        # total number of workers (globally). Due to elasticity
        # the world size may change between re-rendezvous.
        self.world_size: int = world_size

        # total number of workers that share the same role. Due to elasticity
        # the role world size may change between re-rendezvous.
        self.role_world_size: int = role_world_size

3.2 WorkerGroup

WorkerGroup represents a worker group; the agent manages the set of workers in a group as a single unit.

class WorkerGroup:
    """ Represents the set of ``Worker`` instances for the given ``WorkerSpec`` managed by ``ElasticAgent``. Whether the worker group contains cross instance workers or not depends on the implementation of the agent. """
    def __init__(self, spec: WorkerSpec) :
        self.spec = spec
        self.workers = [Worker(local_rank=i) for i in range(self.spec.local_world_size)]

        # assigned after rdzv
        self.store = None
        self.group_rank = None
        self.group_world_size = None

        self.state = WorkerState.INIT

During the SimpleElasticAgent initialization, a WorkerGroup is created.

class SimpleElasticAgent(ElasticAgent) :
    """ An ``ElasticAgent`` that manages workers (``WorkerGroup``) for a single ``WorkerSpec`` (e.g. one particular type of worker role). """

    def __init__(self, spec: WorkerSpec, exit_barrier_timeout: float = 300) :
        self._worker_group = WorkerGroup(spec)
        self._remaining_restarts = self._worker_group.spec.max_restarts
        self._store = None
        self._exit_barrier_timeout = exit_barrier_timeout
        self._total_execution_time = 0

3.3 WorkerState

WorkerState represents the state of a WorkerGroup. All workers in a worker group change state as a unit; if one worker in the group fails, the whole group is considered failed:

  UNKNOWN - agent lost track of worker group state, unrecoverable
  INIT - worker group object created not yet started
  HEALTHY - workers running and healthy
  UNHEALTHY - workers running and unhealthy
  STOPPED - workers stopped (interrupted) by the agent
  SUCCEEDED - workers finished running (exit 0)
  FAILED - workers failed to successfully finish (exit !0)

The meanings of these states are as follows:

  • UNKNOWN - the agent lost track of the worker group state; unrecoverable

  • INIT - the worker group object has been created but not yet started

  • HEALTHY - the workers are running and healthy

  • UNHEALTHY - the workers are running but unhealthy

  • STOPPED - the workers were stopped (interrupted) by the agent

  • SUCCEEDED - the workers finished running (exit code 0)

  • FAILED - the workers failed to finish successfully (exit code != 0)

A worker group starts in the initial INIT state, then moves to HEALTHY or UNHEALTHY, and finally reaches the terminal SUCCEEDED or FAILED state. A worker group can be interrupted by the agent and temporarily put into the STOPPED state; workers in the STOPPED state are scheduled to be restarted in the near future. Examples of workers being put into the STOPPED state are:

  • A worker group failure or an unhealthy worker group is observed
  • A membership change is detected

When an operation on a worker group (start, stop, rdzv, retry, etc.) fails and is only partially applied to the worker group, the state becomes UNKNOWN. This typically happens when an exception occurs during a state change and is not caught/handled. The agent does not recover a worker group in the UNKNOWN state, so it is best to terminate the job and have the job manager retry the node.

WorkerState is defined as follows:

class WorkerState(str, Enum) :
    """ State of the ``WorkerGroup``. Workers in a worker group change state as a unit. If a single worker in a worker group  fails the entire set is considered failed:: UNKNOWN - agent lost track of worker group state, unrecoverable INIT - worker group object created not yet started HEALTHY - workers running and healthy UNHEALTHY - workers running and unhealthy STOPPED - workers stopped (interruped) by the agent SUCCEEDED - workers finished running (exit 0) FAILED - workers failed to successfully finish (exit ! 0) A worker group starts from an initial ``INIT`` state, then progresses to ``HEALTHY`` or ``UNHEALTHY`` states, and finally reaches a terminal ``SUCCEEDED`` or ``FAILED`` state. Worker groups can be interrupted and temporarily put into ``STOPPED`` state by the agent. Workers in ``STOPPED`` state are scheduled to be restarted in the near future by the agent. Some examples of workers being put into ``STOPPED`` state are: 1. Worker group failure|unhealthy observed 2. Membership change detected When actions (start, stop, rdzv, retry, etc) on worker group fails and results in the action being partially applied to the worker group the state will be ``UNKNOWN``. Typically this happens on uncaught/unhandled exceptions during state change events on the agent. The agent is not expected to recover worker groups in ``UNKNOWN`` state and is better off self terminating and allowing the job manager to retry the node. """

    UNKNOWN = "UNKNOWN"
    INIT = "INIT"
    HEALTHY = "HEALTHY"
    UNHEALTHY = "UNHEALTHY"
    STOPPED = "STOPPED"
    SUCCEEDED = "SUCCEEDED"
    FAILED = "FAILED"

    @staticmethod
    def is_running(state: "WorkerState") -> bool:
        """ Returns: True if the worker state represents workers still running (e.g. that the process exists but not necessarily healthy). """
        return state in {WorkerState.HEALTHY, WorkerState.UNHEALTHY}
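
A quick illustration of how these states behave: WorkerState is a str-based Enum, and is_running only treats HEALTHY and UNHEALTHY as "still running".

from torch.distributed.elastic.agent.server.api import WorkerState

print(WorkerState.is_running(WorkerState.UNHEALTHY))  # True: the processes still exist
print(WorkerState.is_running(WorkerState.SUCCEEDED))  # False: terminal state
print(WorkerState.FAILED == "FAILED")                 # True: members compare equal to plain strings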

0x04 SimpleElasticAgent

SimpleElasticAgent is one of the agent implementation classes. This abstraction is intended to make it easy to extend new agent implementations. The built-in LocalElasticAgent is responsible for managing all worker processes on a single machine; if a user wants one agent to manage all workers across multiple machines instead of only the local ones, a custom agent can be implemented by extending SimpleElasticAgent.
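
A minimal skeleton of such a custom agent might look like the following. It is only a sketch of which hooks a subclass needs to provide (the class name and the method bodies are placeholders); the real work of launching and polling remote workers is up to the implementation.

from typing import Any, Dict

from torch.distributed.elastic.agent.server.api import (
    RunResult,
    SimpleElasticAgent,
    WorkerGroup,
    WorkerState,
)

class MyRemoteElasticAgent(SimpleElasticAgent):
    """Hypothetical agent that manages workers on remote hosts (sketch only)."""

    def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]:
        # Launch worker_group.spec.local_world_size workers (e.g. on remote hosts)
        # and return a map of local_rank -> worker id (a pid, "host:port", ...).
        raise NotImplementedError

    def _stop_workers(self, worker_group: WorkerGroup) -> None:
        # Tear down everything started by _start_workers.
        raise NotImplementedError

    def _monitor_workers(self, worker_group: WorkerGroup) -> RunResult:
        # Poll the workers and report their collective state.
        return RunResult(state=WorkerState.HEALTHY)

    def _shutdown(self) -> None:
        # Release any resources held by the agent itself.
        pass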

class SimpleElasticAgent(ElasticAgent) :
    """ An ``ElasticAgent`` that manages workers (``WorkerGroup``) for a single ``WorkerSpec`` (e.g. one particular type of worker role). """

    def __init__(self, spec: WorkerSpec, exit_barrier_timeout: float = 300) :
        self._worker_group = WorkerGroup(spec)
        self._remaining_restarts = self._worker_group.spec.max_restarts
        self._store = None
        self._exit_barrier_timeout = exit_barrier_timeout
        self._total_execution_time = 0

4.1 Overall Operation

The SimpleElasticAgent main loop _invoke_run is the core logic (where by default the agent and worker are on the same machine) and does the following:

  • Use self._initialize_workers(self._worker_group) to complete initialization, such as starting the workers and assigning a rank to each worker.
  • Then enter the while True loop, in which _monitor_workers regularly rotates the user program to obtain the running results of the worker process, and then performs different processing according to the situation.
    • Returns if the program ends normally.
    • If the program fails, retry; if the retry limit is reached, terminate the workers.
    • If the node membership changes, for example, scale up, there will be a new node in waiting, and then all workers will be restarted.
    def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult:
        # NOTE: currently only works for a single role

        spec = self._worker_group.spec
        role = spec.role

        self._initialize_workers(self._worker_group) # start the worker
        monitor_interval = spec.monitor_interval
        rdzv_handler = spec.rdzv_handler

        while True:
            assert self._worker_group.state != WorkerState.INIT
            # monitor at regular intervals
            time.sleep(monitor_interval)
            # monitor the running status of the worker processes
            run_result = self._monitor_workers(self._worker_group)  # get the run result of the processes
            state = run_result.state
            self._worker_group.state = state

            put_metric(f"workers.{role}.remaining_restarts", self._remaining_restarts)
            put_metric(f"workers.{role}.{state.name.lower()}", 1)

            if state == WorkerState.SUCCEEDED:
                # the program finished normally
                self._exit_barrier()
                return run_result
            elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED}:
                # program error
                if self._remaining_restarts > 0:  # retry
                    self._remaining_restarts -= 1
                    self._restart_workers(self._worker_group)
                else:
                    self._stop_workers(self._worker_group)  # retry limit reached, terminate the workers
                    self._worker_group.state = WorkerState.FAILED
                    self._exit_barrier()
                    return run_result
            elif state == WorkerState.HEALTHY:
                # if node membership changes (e.g. scale up), there will be a new node waiting
                # membership changes do not count as retries
                num_nodes_waiting = rdzv_handler.num_nodes_waiting()
                group_rank = self._worker_group.group_rank
                # restart all workers if a new node is waiting
                if num_nodes_waiting > 0:
                    self._restart_workers(self._worker_group)
            else:
                raise Exception(f"[{role}] Worker group in {state.name} state")

The above is only the general flow; next we analyze each step of the overall process in turn.

4.2 Initializing Workers

In the agent main loop, the worker is first started using self._initialize_workers(self._worker_group). In _initialize_workers:

  • First, self._rendezvous(worker_group) performs the synchronization/consensus operation between nodes, rank handling, and so on.
  • Then _start_workers is called to start the workers. Here _start_workers is an abstract method that must be implemented by a derived class.
    @prof
    def _initialize_workers(self, worker_group: WorkerGroup) -> None:
        r""" Starts a fresh set of workers for the worker_group. Essentially a rendezvous followed by a start_workers. The caller should first call ``_stop_workers()`` to stop running workers prior to calling this method. Optimistically sets the state of the worker group that just started as ``HEALTHY`` and delegates the actual monitoring of state to ``_monitor_workers()`` method """
        role = worker_group.spec.role

        # TODO after stopping workers, wait at least monitor_interval*2 for
        # workers on different nodes to fail on a collective op before waiting
        # on the rdzv barrier, this way we ensure that nodes enter rdzv
        # at around the same time and reduce false positive rdzv timeout errors
        self._rendezvous(worker_group)  # synchronization / consensus operation between nodes

        worker_ids = self._start_workers(worker_group) # start the worker
        for local_rank, w_id in worker_ids.items():
            worker = worker_group.workers[local_rank]
            worker.id = w_id

        worker_group.state = WorkerState.HEALTHY

4.2.2 _rendezvous

We’ll start with _rendezvous, which will do the following:

  • Call next_rendezvous() to handle membership changes; it returns the world size, the store, and so on.
  • The store is saved into the worker group, so the workers can later use this KV store to communicate with each other.
  • Call _assign_worker_ranks to create the Worker instances and assign them ranks; the returned workers are assigned to the agent's worker_group.workers.

The rendezvous information is then used for further processing, for example to extract ranks.

    @prof
    def _rendezvous(self, worker_group: WorkerGroup) -> None:
        r""" Runs rendezvous for the workers specified by worker spec. Assigns workers a new global rank and world size. Updates  the rendezvous store for the worker group. """

        spec = worker_group.spec

        # next_rendezvous returns the store, the group rank and the group world size
        store, group_rank, group_world_size = spec.rdzv_handler.next_rendezvous()
        self._store = store  # the store is kept on the agent; it can be thought of as remote KV storage

        # Ranks the worker based on group rank
        workers = self._assign_worker_ranks(store, group_rank, group_world_size, spec)
        worker_group.workers = workers
        worker_group.store = store
        worker_group.group_rank = group_rank
        worker_group.group_world_size = group_world_size

        if group_rank == 0:
            self._set_master_addr_port(store, spec.master_addr, spec.master_port)
        master_addr, master_port = self._get_master_addr_port(store)
        restart_count = spec.max_restarts - self._remaining_restarts
4.2.2.1 Handling Membership Changes

rdzv_handler.next_rendezvous() is called to handle membership changes and start the next round of rendezvous (since the worker is already started and needs to join the cluster).

Note that next_rendezvous is an internal function of RendezvousHandler. This function call is blocked until the required number of workers is reached. This function is called when the worker is initialized or restarted. When the function returns, different worker groups will be uniquely identified by the returned rank. Its internal logic is:

  • First, the node exits via _RendezvousExitOp.
  • Then the node rejoins via _RendezvousJoinOp.
  • Finally the heartbeat is started, and the world size, store, etc. are returned.
    def next_rendezvous(self) -> Tuple[Store, int, int]:
        """See base class."""

        self._stop_heartbeats()

        # Delay the execution for a small random amount of time if this is our
        # first run. This will slightly skew the rendezvous attempts across the
        # nodes and reduce the load on the backend.
        if self._state_holder.state.round == 0:
            _delay(seconds=(0, 0.3))

        exit_op = _RendezvousExitOp()
        join_op = _RendezvousJoinOp()

        deadline = self._get_deadline(self._settings.timeout.join)

        self._op_executor.run(exit_op, deadline)
        self._op_executor.run(join_op, deadline)

        self._start_heartbeats()

        rank, world_size = self._get_world()
        store = self._get_store()

        return store, rank, world_size # return the rank of worker group
4.2.2.2 Assigning Ranks to Workers

This is followed by a call to _assign_worker_ranks to set up the ranks for the worker. The rank assignment algorithm is as follows:

  1. Each agent writes its configuration (group_rank, group_world_size, num_workers) to the common store.
  2. Each agent retrieves the configuration of all agents and performs a two-level sort using role and rank.
  3. Determine the global rank: the global ranks of the current agent's workers are offset by the position of the agent's group_rank in the role_infos array. The offset is the sum of local_world_size of all agents with a lower group_rank, so the workers get the ranks [offset, offset + local_world_size).
  4. Determine the role rank: the role rank is determined with the algorithm from point 3, except that the offset starts from the first agent that has the same role as the current one and the minimum group rank.
  5. Because all agents use the same algorithm, the rank arrays they compute are identical.

Then generate workers and assign all worker values to worker_group.workers.

@prof
def _assign_worker_ranks(
    self, store, group_rank: int, group_world_size: int, spec: WorkerSpec
) -> List[Worker]:
    """ Determines proper ranks for worker processes. The rank assignment is done according to the following algorithm: 1. Each agent writes its configuration(group_rank, group_world_size , num_workers) to the common store. 2. Each agent retrieves configuration for all agents and performs two level sort using  role and rank. 3. Determine the global rank: the global rank of the workers for the current agent is the offset of the infos array up to group_rank of the agent. The  offset is computed as a sum of local_world_size of all agents that have rank less than the group_rank. The workers would have the ranks: [offset, offset+local_world_size) 4. Determine the role rank: The role rank is determined using the algorithms in the point 3 with the exception that the offset is done from the first agent that has the same role as current one and has the minimum group rank. """

    # Each agent writes its configuration (group_rank, group_world_size, num_workers) to public storage.
    role_infos = self._share_and_gather(store, group_rank, group_world_size, spec)
    # Each agent retrieves the configuration of all agents and performs a two-level sort using role and rank.
    my_role_info = role_infos[group_rank]
    # The global ranks of the current agent's workers are offset by the position of its group_rank in the
    # role_infos array. The offset is the sum of local_world_size of all agents with a lower group_rank,
    # so the workers get the ranks [offset, offset + local_world_size).
    worker_world_size, worker_global_ranks = self._get_ranks(role_infos, group_rank)
    role_infos = sorted(
        role_infos, key=functools.cmp_to_key(_RoleInstanceInfo.compare)
    )
    role_start_idx, role_end_idx = _RoleInstanceInfo.find_role_boundaries(
        role_infos, my_role_info.role
    )
    role_pos = next(
        idx
        for idx, role_info in enumerate(role_infos)
        if _RoleInstanceInfo.compare(role_info, my_role_info) == 0
    )
    # Determine the role rank with the algorithm from point 3, except that the offset starts
    # from the first agent that has the same role as the current one.
    role_world_size, role_ranks = self._get_ranks(
        role_infos, role_pos, role_start_idx, role_end_idx + 1
    )
    # Build the Worker list and assign it to worker_group.workers
    workers = []
    for ind in range(spec.local_world_size):
        worker = Worker(
            local_rank=ind,
            global_rank=worker_global_ranks[ind],
            role_rank=role_ranks[ind],
            world_size=worker_world_size,
            role_world_size=role_world_size,
        )
        workers.append(worker)
    return workers

4.2.4 Starting the Worker Processes

The worker processes are started by calling _start_workers of the derived class, so the base class does not implement it; we will look at the implementation later.

    @abc.abstractmethod
    def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]:
        r""" Starts ``worker_group.spec.local_world_size`` number of workers according to worker spec for the worker group . Returns a map of ``local_rank`` to worker ``id``. """
        raise NotImplementedError()
Copy the code

The current logic is as follows:

  1. Call rdzv_handler.next_rendezvous to synchronize with the other nodes.
  2. rdzv_handler.next_rendezvous returns rank-related information, which is passed to _assign_worker_ranks.
  3. _assign_worker_ranks generates the Workers, and each Worker is automatically assigned a rank. These workers are referenced by the agent's worker_group.workers.
+--------------------------------------------------+
| LocalElasticAgent                                |         _initialize_workers
|                                                  |                 +
|                                                  |                 |
|                                                  |                 |
|   +----------------------+                       |                 v
|   |WorkerGroup           |                       |         _rendezvous(worker_group)
|   |                      |                       |                 +
|   |     spec             |                       |                 |
|   |                      |                       |                 | 1
|   |     group_world_size |                       |                 v
|   |                      |                       |        rdzv_handler.next_rendezvous()
|   |     store            |                       |                 +
|   |                      |    +----------------+ |                 |
|   |     group_rank       |    | Worker0(rank 0) | |2 | ranks
|   |                      |    | Worker1(rank 1)| |  Workers        v
|   |     workers  +----------> | ...            | | <----+ _assign_worker_ranks
|   |                      |    | Workern(rank n)| |    3
|   +----------------------+    +----------------+ |
|                                                  |
+--------------------------------------------------+

Next, the functions related to rank and worker will be listed separately for better understanding.

4.3 Rank-related Functions

The previous _assign_worker_ranks assigns ranks to the workers, but there are some internal details worth sorting out.

4.3.1 _RoleInstanceInfo

Here’s an introduction to the _RoleInstanceInfo data structure. Agents use this class to exchange information with other agents. This information is used to determine the rank of the agent workers. These agents work in a heterogeneous environment, and different agents may have different numbers of workers. Its construction parameters are:

  • role (str): user-defined role of the workers.
  • rank (int): rank of the agent.
  • local_world_size (int): number of local workers.
class _RoleInstanceInfo:
    """ The class is used by the agent to exchange the information with other agents. The information is used to determine the rank of the workers that agent manages in heterogeneous environments, where different agents can have different number of workers. """

    __slots__ = ["role", "rank", "local_world_size"]

    def __init__(self, role: str, rank: int, local_world_size: int) :
        r""" Args: role (str): user-defined role for the workers with this spec rank (int): the rank of the agent local_world_size (int): number of local workers to run """

        self.role = role
        self.rank = rank
        self.local_world_size = local_world_size

    def serialize(self) -> bytes:
        dict_data = {
            "role": self.role,
            "rank": self.rank,
            "local_world_size": self.local_world_size,
        }
        return json.dumps(dict_data).encode(encoding="UTF-8")

    @staticmethod
    def deserialize(data: bytes) :
        dict_data = json.loads(data.decode(encoding="UTF-8"))
        return _RoleInstanceInfo(
            dict_data["role"], dict_data["rank"], dict_data["local_world_size"]
        )

    @staticmethod
    def compare(obj1, obj2) -> int:
        if obj1.role == obj2.role:
            return obj1.rank - obj2.rank
        elif obj1.role > obj2.role:
            return 1
        else:
            return -1

    @staticmethod
    def find_role_boundaries(roles_infos: List, role: str) -> Tuple[int, int]:
        start_idx, end_idx = -1, -1
        for idx, role_info in enumerate(roles_infos):
            if role_info.role == role:
                if start_idx == -1:
                    start_idx = idx
                end_idx = idx
        return (start_idx, end_idx)
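
A quick illustration of how this structure is exchanged and ordered, assuming the _RoleInstanceInfo class above is in scope (the role names and sizes are made up):

import functools

# Hypothetical info from three agents: (role, group_rank, local_world_size).
infos = [
    _RoleInstanceInfo("trainer", 1, 4),
    _RoleInstanceInfo("ps", 0, 1),
    _RoleInstanceInfo("trainer", 0, 4),
]

# Serialization round trip, as done when writing to / reading from the store.
blob = infos[0].serialize()
restored = _RoleInstanceInfo.deserialize(blob)
print(restored.role, restored.rank, restored.local_world_size)   # trainer 1 4

# Two-level (role, rank) sort plus role boundaries, as used by _assign_worker_ranks.
infos = sorted(infos, key=functools.cmp_to_key(_RoleInstanceInfo.compare))
print([(i.role, i.rank) for i in infos])                          # [('ps', 0), ('trainer', 0), ('trainer', 1)]
print(_RoleInstanceInfo.find_role_boundaries(infos, "trainer"))   # (1, 2)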

4.3.2 _share_and_gather

The _share_and_gather function synchronizes between agents to obtain the overall information about roles. Each agent writes its configuration (group_rank, group_world_size, num_workers) to the common store; the store returned by rendezvous is what is used here to share the information.

    def _share_and_gather(
        self, store, group_rank: int, group_world_size: int, spec: WorkerSpec
    ) -> List:
        agent_role_info = _RoleInstanceInfo(
            spec.role, group_rank, spec.local_world_size
        )
        key_prefix = "torchelastic/role_info"
        agent_config_enc = agent_role_info.serialize()
        role_infos_bytes = store_util.synchronize(
            store, agent_config_enc, group_rank, group_world_size, key_prefix
        )
        role_infos = [
            _RoleInstanceInfo.deserialize(role_info_bytes)
            for role_info_bytes in role_infos_bytes
        ]
        return role_infos
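
Under the hood, the gather boils down to each agent writing its blob under a per-rank key and then reading everyone else's blobs back. The following is a simplified, single-process sketch of that pattern using an in-process HashStore; it is not the actual store_util.synchronize implementation, and the keys and payloads are made up.

from torch.distributed import HashStore

def toy_gather(store, data: bytes, rank: int, world_size: int, key_prefix: str):
    # Each participant publishes its own blob ...
    store.set(f"{key_prefix}{rank}", data)
    # ... then waits for and reads everybody's blob.
    keys = [f"{key_prefix}{i}" for i in range(world_size)]
    store.wait(keys)
    return [store.get(k) for k in keys]

# Single-process demo: pretend a second agent (rank 1) has already written its info.
store = HashStore()
store.set("torchelastic/role_info1", b'{"role": "trainer", "rank": 1, "local_world_size": 4}')
blobs = toy_gather(store, b'{"role": "trainer", "rank": 0, "local_world_size": 4}',
                   rank=0, world_size=2, key_prefix="torchelastic/role_info")
print([b.decode() for b in blobs])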

4.3.3 _get_ranks

Determine the global rank based on role_infos: the global ranks of the current agent's workers are offset by the position of its group_rank in the role_infos array. The offset is the sum of local_world_size of all agents with a lower group_rank, so the workers get the ranks [offset, offset + local_world_size).

def _get_ranks(
    self,
    role_infos: List[_RoleInstanceInfo],
    role_idx: int,
    start_idx: int = 0,
    end_idx: int = -1,
) -> Tuple[int, List[int]]:
    if end_idx == -1:
        end_idx = len(role_infos)
    prefix_sum = 0
    total_sum = 0
    for idx in range(start_idx, end_idx):
        if role_idx > idx:
            prefix_sum += role_infos[idx].local_world_size
        total_sum += role_infos[idx].local_world_size
    return (
        total_sum,
        list(range(prefix_sum, prefix_sum + role_infos[role_idx].local_world_size)),
    )
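
To make the offset arithmetic concrete, here is a small standalone re-implementation of the same prefix-sum logic, applied to a made-up cluster of three agents with 2, 2 and 4 local workers:

from typing import List, Tuple

def get_ranks(local_sizes: List[int], idx: int) -> Tuple[int, List[int]]:
    # Same idea as _get_ranks: the offset is the sum of local_world_size
    # of all agents that come before `idx`.
    offset = sum(local_sizes[:idx])
    world_size = sum(local_sizes)
    return world_size, list(range(offset, offset + local_sizes[idx]))

local_sizes = [2, 2, 4]   # three hypothetical agents
for idx in range(len(local_sizes)):
    print(idx, get_ranks(local_sizes, idx))
# 0 (8, [0, 1])
# 1 (8, [2, 3])
# 2 (8, [4, 5, 6, 7])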

The logic so far, expanded, is as follows:

  1. Call rdzv_handler.next_rendezvous() to synchronize with the other nodes and obtain information.
  2. The returned store (which can be thought of as remote KV storage), group_world_size and group_rank are saved on the agent.
  3. This rank-related information is passed to the _assign_worker_ranks method.
  4. In _assign_worker_ranks, _share_and_gather is called to synchronize between agents and obtain the overall role information. Each agent writes its configuration (group_rank, group_world_size, num_workers) to the public KV store.
  5. Determine the global rank based on role_infos: the global ranks of the current agent's workers are offset by the position of its group_rank in the role_infos array. The offset is the sum of local_world_size of all agents with a lower group_rank.
  6. Build a series of Workers using this information.
  7. The Workers are assigned to the agent's WorkerGroup.
                                                              _initialize_workers
                                                                      +
                                                                      |
                                                                      |
                                                                      v
                                                              _rendezvous(worker_group)
                                                                      +
+----------------------------------------------+                      |
| LocalElasticAgent                            |                      | 1
|                                              |   2                  v
|                                         +--------------+  rdzv_handler.next_rendezvous()
| +--------------------+                  |    |                      +
| | WorkerGroup        |                  |    |                      |
| |                    |                  |    |                    3 | ranks
| |                    |                  |    |                      v
| |  spec              |                  |    |       +--------------+------------------+
| |                    |                  |    |       | _assign_worker_ranks            |
| |                    |                  |    |       |                                 |
| |  store   <----------------------------+    |       |                        4        |
| |                    |                  |    |       | role_infos = _share_and_gather( |
| |                    |                  |    |       |               +          store) |
| |  group_world_size<--------------------+    |       |               | 5               |
| |                    |                  |    |       |               |                 |
| |                    |                  |    |       |               v                 |
| |  group_rank <-------------------------+    |       |          _get_ranks(world...)   |
| |                    |                       |       |          _get_ranks(role...)    |
| |                    |   +----------------+  |       |               +                 |
| |  workers  +----------->+ Worker0(rank 0)|  |       |               |                 |
| |                    |   | Worker1(rank 1)|  |       |               | 6               |
| |                    |   | ...            |  |Workers|               v                 |
| |                    |   | Workern(rank n)+<------------+ new Worker(local_rank,       |
| | +--------------------+ +----------------+  |    7  |               global_rank,      |
|                                              |       |               role_rank,        |
+----------------------------------------------+       |               world_size,       |
                                                       |               role_world_size)  |
                                                       |                                 |
                                                       +---------------------------------+

With the Worker instance generated after the _rendezvous operation, let’s look at how to generate the Worker process. But because these methods are not implemented in SimpleElasticAgent, we need to expand our diagrams in the analysis section of its derived class LocalElasticAgent.

4.4 Worker-related Functions

Let’s look at the remaining two worker-related functions of SimpleElasticAgent.

4.4.1 Restart

_restart_workers restarts the workers.

# pyre-fixme[56]: Pyre was not able to infer the type of the decorator
# `torch.distributed.elastic.metrics.prof`.
@prof
def _restart_workers(self, worker_group: WorkerGroup) -> None:
    """ Restarts (stops, rendezvous, starts) all local workers in the group. """

    role = worker_group.spec.role
    self._stop_workers(worker_group)
    worker_group.state = WorkerState.STOPPED
    self._initialize_workers(worker_group)

4.4.2 Barrier

In fact, it is almost impossible to guarantee that all workers in DDP can end at the same time, so TE provides a finalization barrier, which is used to implement a timeout (5 minutes) for worker finalization.

    def _exit_barrier(self) :
        """ Wait for ``exit_barrier_timeout`` seconds for all agents to finish executing their local workers (either successfully or not). This acts as a safety guard against user scripts that terminate at different times. This barrier keeps the agent process alive until all workers finish. """
        start = time.time()
        try:
            store_util.barrier(
                self._store,
                self._worker_group.group_rank,
                self._worker_group.group_world_size,
                key_prefix=_TERMINAL_STATE_SYNC_ID,
                barrier_timeout=self._exit_barrier_timeout,
            )
        except Exception:
            log.exception(
                f"Error waiting on exit barrier. Elapsed: {time.time() - start} seconds"
            )
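
As a toy illustration of this store-based barrier, here is a single-participant call using an in-process HashStore and the same store_util.barrier signature used in the listing above; with several agents, each one would call it with its own rank and block until everybody arrives (the key prefix is made up).

from torch.distributed import HashStore
import torch.distributed.elastic.utils.store as store_util

store = HashStore()
# world_size=1, rank=0: the barrier is satisfied immediately.
store_util.barrier(store, 0, 1, key_prefix="toy_exit_barrier", barrier_timeout=10)
print("passed the exit barrier")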

0x05 LocalElasticAgent

LocalElasticAgent is the final agent used for elasticity training. It is mainly used for local operations and is responsible for managing all worker processes on a single machine. It is derived from SimpleElasticAgent.

This agent is deployed on every host and is configured to spawn n worker processes; when GPUs are used, n is the number of GPUs available on the host. The local agent does not communicate with local agents deployed on other hosts, even though the workers may communicate across hosts. A worker id is interpreted as a local process. The agent starts and stops all worker processes on the machine as a single unit.

The worker function and the arguments passed to it must be compatible with Python multiprocessing. To pass multiprocessing data structures to the workers, the user can create the data structures in the same multiprocessing context as the specified start_method and pass them as function arguments.

exit_barrier_timeout specifies the amount of time, in seconds, to wait for the other agents to finish. This acts as a safety net to handle cases where workers finish at different times, and prevents agents from treating workers that finished early as a scale-down event. It is strongly recommended that user code ensure the workers terminate synchronously rather than relying on exit_barrier_timeout.

SimpleElasticAgent mainly provides the initialization and the general operation flow, but leaves some abstract functions unimplemented, such as _start_workers, _stop_workers, _monitor_workers and _shutdown. LocalElasticAgent completes these functions.

class LocalElasticAgent(SimpleElasticAgent) :
    """ An implementation of :py:class:`torchelastic.agent.server.ElasticAgent` that handles host-local workers. This agent is deployed per host and is configured to spawn ``n`` workers. When using GPUs, ``n`` maps to the number of GPUs available on the host. The local agent does not communicate to other local agents deployed on other hosts, even if the workers may communicate inter-host. The worker id is interpreted to be a local process. The agent starts and  stops all worker processes as a single unit. The worker function and argument passed to the worker function must be python multiprocessing compatible. To pass multiprocessing data structures to the workers you may create the data structure in the same multiprocessing context as the specified ``start_method`` and pass it as a function argument. The ``exit_barrier_timeout`` specifies the amount of time (in seconds) to wait for other agents to finish. This acts as a safety net to handle cases where workers finish at different times, to prevent agents from viewing workers that finished early as a scale-down event. It is strongly advised that the user code deal with ensuring that workers are terminated in a synchronous manner rather than relying on the exit_barrier_timeout. """

    def __init__(
        self,
        spec: WorkerSpec,
        start_method="spawn",
        exit_barrier_timeout: float = 300,
        log_dir: Optional[str] = None,
    ):
        super().__init__(spec, exit_barrier_timeout)
        self._start_method = start_method
        self._pcontext: Optional[PContext] = None
        rdzv_run_id = spec.rdzv_handler.get_run_id()
        self._log_dir = self._make_log_dir(log_dir, rdzv_run_id)

    def _make_log_dir(self, log_dir: Optional[str], rdzv_run_id: str) :
        base_log_dir = log_dir or tempfile.mkdtemp(prefix="torchelastic_")
        os.makedirs(base_log_dir, exist_ok=True)
        dir = tempfile.mkdtemp(prefix=f"{rdzv_run_id}_", dir=base_log_dir)
        return dir

5.1 Usage

Let's first extract the example code from its docstring to see how it is used. Here is how to launch a function as the entrypoint.

    def trainer(args) -> str:
        return "do train"

    def main():
        start_method="spawn"
        shared_queue= multiprocessing.get_context(start_method).Queue()
        spec = WorkerSpec(
                    role="trainer",
                    local_world_size=nproc_per_process,
                    entrypoint=trainer,
                    args=("foobar",),
                    ...<OTHER_PARAMS...>)
        agent = LocalElasticAgent(spec, start_method)
        results = agent.run()

        if results.is_failed():
            print("trainer failed")
        else:
            print(f"rank 0 return value: {results.return_values[0]}")
            # prints -> rank 0 return value: do train

Here is how to launch a binary as the entrypoint.

    def main():
        spec = WorkerSpec(
                    role="trainer",
                    local_world_size=nproc_per_process,
                    entrypoint="/usr/local/bin/trainer",
                    args=("--trainer_args", "foobar"),
                    ...<OTHER_PARAMS...>)
        agent = LocalElasticAgent(spec)
        results = agent.run()

        if not results.is_failed():
            print("binary launches do not have return values")


5.2 Stopping Workers

The following function stops workers.

    @prof
    def _stop_workers(self, worker_group: WorkerGroup) -> None:
        self._shutdown()

    def _shutdown(self) -> None:
        if self._pcontext:
            self._pcontext.close()

5.3 Initialization

Now that the Worker instances have been generated by the _rendezvous operation, let's look at how the worker processes are created. Since these methods are not implemented in SimpleElasticAgent, we continue to expand our diagram in this section.

Let’s first look at initializing workers. Within _initialize_workers, the workers instance is first created using _rendezvous, followed by a call to _start_workers to start workers.

    @prof
    def _initialize_workers(self, worker_group: WorkerGroup) -> None:
        r""" Starts a fresh set of workers for the worker_group. Essentially a rendezvous followed by a start_workers. The caller should first call ``_stop_workers()`` to stop running workers prior to calling this method. Optimistically sets the state of the worker group that just started as ``HEALTHY`` and delegates the actual monitoring of state to ``_monitor_workers()`` method """
        role = worker_group.spec.role

        # TODO after stopping workers, wait at least monitor_interval*2 for
        # workers on different nodes to fail on a collective op before waiting
        # on the rdzv barrier, this way we ensure that nodes enter rdzv
        # at around the same time and reduce false positive rdzv timeout errors
        self._rendezvous(worker_group)  # the Worker instances have been created at this point

        worker_ids = self._start_workers(worker_group)  # start the worker processes
        for local_rank, w_id in worker_ids.items():
            worker = worker_group.workers[local_rank]
            worker.id = w_id  # record the process id

        worker_group.state = WorkerState.HEALTHY

5.4 Starting Worker Processes

The _start_workers method calls start_processes to start the worker processes; the default _start_method is "spawn". That is, multiple processes are started to execute the user program in parallel, and their results are monitored. In the start_processes arguments, entrypoint and args are the user command and its parameters; entrypoint can be a function or a string.

_start_workers stores the handle returned by start_processes (which launches the processes) in _pcontext and then uses _pcontext to control them; for example, terminating the workers simply calls the close method of _pcontext (a short standalone sketch of start_processes follows the listing below).

    @prof
    def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]:
        spec = worker_group.spec
        store = worker_group.store
        assert store is not None
        master_addr, master_port = super()._get_master_addr_port(store)
        restart_count = spec.max_restarts - self._remaining_restarts

        use_agent_store = spec.rdzv_handler.get_backend() == "static"

        args: Dict[int, Tuple] = {}
        envs: Dict[int, Dict[str, str]] = {}
        for worker in worker_group.workers:
            local_rank = worker.local_rank
            worker_env = {
                "LOCAL_RANK": str(local_rank),
                "RANK": str(worker.global_rank),
                "GROUP_RANK": str(worker_group.group_rank),
                "ROLE_RANK": str(worker.role_rank),
                "ROLE_NAME": spec.role,
                "LOCAL_WORLD_SIZE": str(spec.local_world_size),
                "WORLD_SIZE": str(worker.world_size),
                "GROUP_WORLD_SIZE": str(worker_group.group_world_size),
                "ROLE_WORLD_SIZE": str(worker.role_world_size),
                "MASTER_ADDR": master_addr,
                "MASTER_PORT": str(master_port),
                "TORCHELASTIC_RESTART_COUNT": str(restart_count),
                "TORCHELASTIC_MAX_RESTARTS": str(spec.max_restarts),
                "TORCHELASTIC_RUN_ID": spec.rdzv_handler.get_run_id(),
                "TORCHELASTIC_USE_AGENT_STORE": str(use_agent_store),
                "NCCL_ASYNC_ERROR_HANDLING": str(1),
            }
            if "OMP_NUM_THREADS" in os.environ:
                worker_env["OMP_NUM_THREADS"] = os.environ["OMP_NUM_THREADS"]
            envs[local_rank] = worker_env
            worker_args = list(spec.args)
            worker_args = macros.substitute(worker_args, str(local_rank))
            args[local_rank] = tuple(worker_args)

        # scaling events do not count towards restarts (gets same attempt #)
        # remove existing log dir if this restart is due to a scaling event
        attempt_log_dir = os.path.join(self._log_dir, f"attempt_{restart_count}")
        shutil.rmtree(attempt_log_dir, ignore_errors=True)
        os.makedirs(attempt_log_dir)

        # save the handle returned by start_processes in self._pcontext
        self._pcontext = start_processes(
            name=spec.role,
            entrypoint=spec.entrypoint,
            args=args,
            envs=envs,
            log_dir=attempt_log_dir,
            start_method=self._start_method,
            redirects=spec.redirects,
            tee=spec.tee,
        )

        return self._pcontext.pids()
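
To see what _pcontext wraps, here is a hedged, minimal use of start_processes outside of the agent, using the same (older) log_dir-based signature that appears in the listing above. The entrypoint and values are illustrative; the pids()/wait()/close() calls are the same ones the agent relies on.

import tempfile
from torch.distributed.elastic.multiprocessing import start_processes

def trainer(msg: str) -> str:
    return f"done: {msg}"

if __name__ == "__main__":
    pc = start_processes(
        name="trainer",
        entrypoint=trainer,                    # a function, or a path to a binary
        args={0: ("hello",), 1: ("world",)},   # per-local-rank arguments
        envs={0: {}, 1: {}},                   # per-local-rank environment variables
        log_dir=tempfile.mkdtemp(),
        start_method="spawn",
    )
    print(pc.pids())        # {local_rank: pid}
    result = pc.wait()      # block until all workers exit
    print(result.return_values)
    pc.close()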

5.5 Monitoring

While the workers run, TE calls _monitor_workers to monitor them. The handle returned when the processes were started was saved in _pcontext, and _pcontext is now used to monitor their running status.

    @prof
    def _monitor_workers(self, worker_group: WorkerGroup) -> RunResult:
        role = worker_group.spec.role
        worker_pids = {w.id for w in worker_group.workers}
        assert self._pcontext is not None
        pc_pids = set(self._pcontext.pids().values())
        if worker_pids != pc_pids:
            return RunResult(state=WorkerState.UNKNOWN)

        result = self._pcontext.wait(0)  # poll the run result
        if result:
            if result.is_failed():  # the processes failed
                # map local rank failure to global rank
                worker_failures = {}
                # the returned result contains the result of every process
                for local_rank, failure in result.failures.items():
                    worker = worker_group.workers[local_rank]
                    worker_failures[worker.global_rank] = failure
                return RunResult(
                    state=WorkerState.FAILED,
                    failures=worker_failures,  # return the run result
                )
            else:
                # copy ret_val_queue into a map with a global ranks
                workers_ret_vals = {}
                for local_rank, ret_val in result.return_values.items():
                    worker = worker_group.workers[local_rank]
                    workers_ret_vals[worker.global_rank] = ret_val
                return RunResult(
                    state=WorkerState.SUCCEEDED,
                    return_values=workers_ret_vals,  # return the run result
                )
        else:
            return RunResult(state=WorkerState.HEALTHY)

Because startup and monitoring involve the overall system logic and are best understood together with Rendezvous, we postpone that part of the analysis until the Rendezvous articles, where we do the overall analysis.

The current general logic is as follows:

  1. Call rdzv_handler.next_rendezvous() to synchronize with the other nodes and obtain information.
  2. The returned store (which can be thought of as remote KV storage), group_world_size and group_rank are saved on the agent.
  3. This rank-related information is passed to the _assign_worker_ranks method.
  4. In _assign_worker_ranks, _share_and_gather is called to synchronize between agents and obtain the overall role information. Each agent writes its configuration (group_rank, group_world_size, num_workers) to the public KV store.
  5. Determine the global rank based on role_infos: the global ranks of the current agent's workers are offset by the position of its group_rank in the role_infos array. The offset is the sum of local_world_size of all agents with a lower group_rank.
  6. Build a series of Workers using this information.
  7. The Workers are assigned to the agent's WorkerGroup.
  8. Use _start_workers to start the worker process.
  9. Assign the worker process ID to the Agent’s worker.id so that the worker.id can be used to operate the process.
  10. Use _monitor_workers to monitor worker processes.
  11. Use _exit_barrier to wait for the worker process to end.
                                                              _initialize_workers
                                                                      +
                                                                      |
                                                                      |
                                                                      v
                                                              _rendezvous(worker_group)
                                                                      +
+----------------------------------------------+                      |
| LocalElasticAgent                            |                      | 1
|                                              |   2                  v
|                                         +--------------+  rdzv_handler.next_rendezvous()
| +--------------------+                  |    |                      +
| | WorkerGroup        |                  |    |                      |
| |                    |                  |    |                    3 | ranks
| |                    |                  |    |                      v
| |  spec              |                  |    |       +--------------+------------------+
| |                    |                  |    |       | _assign_worker_ranks            |
| |                    |                  |    |       |                                 |
| |  store   <----------------------------+    |       |                        4        |
| |                    |                  |    |       | role_infos = _share_and_gather( |
| |                    |                  |    |       |               +          store) |
| |  group_world_size<--------------------+    |       |               | 5               |
| |                    |                  |    |       |               |                 |
| |                    |                  |    |       |               v                 |
| |  group_rank <-------------------------+    |       |          _get_ranks(world...)   |
| |                    |                       |       |          _get_ranks(role...)    |
| |                    |   +----------------+  |       |               +                 |
| |  workers  +----------->+ Worker0(rank 0)|  |       |               |                 |
| |                    |   | Worker1(rank 1)|  |       |               | 6               |
| |                    |   | ...            |  |Workers|               v                 |
| |                    |   | Workern(rank n)+<------------+ new Worker(local_rank,       |
| | +--------------------+ +---------+------+  |    7  |               global_rank,      |
|                                    ^         |       |               role_rank,        |
|                                    |         |       |               world_size,       |
|                                    |         |       |               role_world_size)  |
+----------------------------------------------+       |                                 |
                                     |                 +---------------+-----------------+
                                     |                                 |
                                     |                                 | 8
                                     |              9                  v
                                     +-----------------------+   _start_workers
                                                                       +
                                                                        | 10
                                                                        |
                                                                        v
                                                        +---------------+--------------+
                                                        |  state = _monitor_workers    |
                                              +-------> |                              | +----+
                                              |         +---------------+--------------+     |
                                              |                         |                    |
                                              +<---------------------------------------------+
                                                LOOP Every 30S          |
                                                                        | 11
                                                                       v
                                                                    _exit_barrier


0xEE Personal information

★★★★ Thoughts on life and technology ★★★★★

Wechat official account: Rosie’s Thoughts

0xFF Reference

TorchElastic – Flexible, fault-tolerant distributed training