0x00 Abstract

Horovod is an easy-to-use, high-performance distributed training framework released by Uber in 2017 that has been widely used in the industry.

This series takes you through the source code analysis of Horovod. This article, the 17th in a series, looks at Horovod’s fault-tolerant mechanisms.

As usual, we use a few questions to guide our reading.

The questions are:

  • Are these exceptions automatically emitted by each worker?
  • Do workers throw exceptions together?
  • How are these exceptions notified to the Driver?

We will analyze these one by one below (so that this article can be read on its own, some of the background material overlaps with the previous article).

Links to other articles in this series are as follows:

Horovod (1) — Basics

Horovod (2) — A distributed training framework for deep learning — from the user’s perspective

Horovod (3) — What’s behind Horovodrun

Horovod (4) — Network Basics & Driver

Horovod (5) — fusion framework

Horovod (6) — Background thread architecture

Horovod (7) — DistributedOptimizer

Horovod (8) — on Spark

Horovod (9) — Start on Spark

Horovod (10) — Run on Spark

Horovod (11) — on Spark — GLOO scheme

Horovod (12) — Overall architecture of elastic training

Horovod (13) — Elastic training Driver

Horovod (14) — Elastic training node discovery & State

Horovod (15) — Broadcast & Notification

Horovod (16) — Elastic training Worker lifecycle

0x01 General idea

First of all, note that fault tolerance and elastic scheduling are, to some extent, each other's cause and effect.

  • Fault tolerance means that a job is not interrupted when the number of processes in it changes.
  • In elastic scheduling, the number of processes in a job increases or decreases with the cluster workload, so a job must be fault-tolerant in order to cooperate with the scheduling system and achieve elastic scheduling.

Secondly, the documentation in the source code contains the following comment, which lays out the specific idea behind fault tolerance.

The reset process following a ``HorovodInternalError`` (failure) or ``HostsUpdatedInterrupt`` (add/remove request) is as follows:

1. Catch exception within the ``hvd.elastic.run`` decorator.
2. Restore last committed state if ``HorovodInternalError`` was raised.
3. Reinitialize Horovod context performing a new round of rendezvous.
4. Synchronize state among the workers by broadcasting from the new worker-0.
5. Resume training by executing the underlying training function.

During rendezvous, older workers will take priority in being assigned worker-0 status to ensure that the state that is broadcast is up to date.

A rough paraphrase is as follows:

When a worker process raises HorovodInternalError (failure) or HostsUpdatedInterrupt (hosts added or removed), the exception is caught and reset is invoked for fault-tolerance handling:

  • Catch the exception inside the hvd.elastic.run decorator;
  • If it is a HorovodInternalError, restore the state to the last commit;
  • Reinitialize the Horovod context, after which the driver triggers a new round of rendezvous among the currently running nodes. During rendezvous, older workers are elected as the new rank-0 with priority, because they hold the most up-to-date state;
  • Once the new communication domain is successfully constructed, the rank-0 worker broadcasts its model (state) to the other workers;
  • Training resumes from the iteration where it stopped, continuing to run the code in the train function.

Let's see how this is done.

0x02 Throwing exceptions

2.1 Sample Code

Let’s first review the sample code.

import tensorflow as tf
import horovod.tensorflow as hvd

hvd.init()

@tf.function
def train_one_batch(data, target, allreduce=True):
    with tf.GradientTape() as tape:
        probs = model(data, training=True)
        loss = tf.losses.categorical_crossentropy(target, probs)
    if allreduce:
        tape = hvd.DistributedGradientTape(tape)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

....

@hvd.elastic.run
def train(state):
    for state.epoch in range(state.epoch, epochs):
        for state.batch in range(state.batch, batches_per_epoch):
            data, target = get_random_batch()
            train_one_batch(data, target)
            if state.batch % batches_per_commit == 0:
                state.commit()
        state.batch = 0

state = hvd.elastic.TensorFlowKerasState(model, optimizer, batch=0, epoch=0)
state.register_reset_callbacks([on_state_reset])
train(state)

The key point is that train(state) is wrapped by the @hvd.elastic.run decorator, so let's start from there.

2.2 HorovodInternalError

The following code shows that hvd.elastic.run is the run function in horovod/tensorflow/elastic.py.

import horovod.tensorflow as hvd
@hvd.elastic.run

So we come to horovod/tensorflow/elastic.py.

Here func is the user training function. When the user training function raises an error, the wrapper inspects the captured exception message; if it is related to ring allreduce (HorovodAllreduce / HorovodAllgather / HorovodBroadcast), the exception is re-raised as HorovodInternalError(e).

def run(func):
    from tensorflow.python.framework.errors_impl import UnknownError

    def wrapper(state, *args, **kwargs):
        try:
            return func(state, *args, **kwargs)
        except UnknownError as e:
            if 'HorovodAllreduce' in e.message or \
                    'HorovodAllgather' in e.message or \
                    'HorovodBroadcast' in e.message:
                raise HorovodInternalError(e)
    return run_fn(wrapper, _reset)

2.3 HostsUpdatedInterrupt

As we know from the previous article:

When the driver process finds, via the node discovery script, that a node has been marked as new or removed, it sends a notification to all workers, and each worker handles the notification accordingly.

The details are as follows:

  • The driver (background discovery) process obtains a WorkerNotificationClient and uses it to send a HostsUpdatedRequest.
  • WorkerNotificationService inherits from network.BasicService, and WorkerNotificationClient is its client-side interface, so the HostsUpdatedRequest is delivered to WorkerNotificationService.
  • WorkerNotificationService responds to HostsUpdatedRequest by calling handle_hosts_updated, which notifies every listener registered with WorkerNotificationManager (that is, the State objects in user code).
  • Each worker has its own State, which sits in WorkerNotificationManager._listeners.
  • When a worker receives the notification, its State records the host change by putting a "hosts changed" message into its _host_messages queue (see the sketch after this list).
  • The next time state.commit() or the lighter-weight state.check_host_updates() is called, check_host_updates reads the messages from _host_messages and accumulates the updates. As the comment in that method notes, the result is synchronized across workers (via _bcast_object, which uses MPI underneath) so that they all throw the HostsUpdatedInterrupt exception at the same time.
  • state.check_host_updates() then throws the HostsUpdatedInterrupt exception.
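To make this flow concrete, here is a minimal sketch of the listener side based on the description above. It is only an illustration of the mechanism, not a verbatim copy of horovod/common/elastic.py; the actual method names and signatures may differ slightly.

import queue

class WorkerNotificationManager:
    """Receives host-update notifications from the driver and fans them out."""
    def __init__(self):
        self._listeners = []

    def register_listener(self, listener):
        self._listeners.append(listener)

    def handle_hosts_updated(self, timestamp, update_res):
        # Invoked when a HostsUpdatedRequest arrives from the driver.
        for listener in self._listeners:
            listener.on_hosts_updated(timestamp, update_res)

class State:
    """Each worker's State registers itself as a listener."""
    def __init__(self):
        self._host_messages = queue.Queue()

    def on_hosts_updated(self, timestamp, update_res):
        # Only record the change here; the exception is raised later, inside
        # commit()/check_host_updates(), so that all workers can raise
        # HostsUpdatedInterrupt at the same point in training.
        self._host_messages.put((timestamp, update_res))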

The specific code is as follows:

When the user calls commit(), check_host_updates is invoked to check for updates. This is a mild intrusion into user code: the user knows nothing about the Driver, but does have to use framework objects such as state. Committing frequently also costs time, which is why the lighter-weight check_host_updates() path exists as well.

def commit(self):
    self.save()
    self.check_host_updates()

check_host_updates, shown below, checks for updates; if a host change is detected, it raises a HostsUpdatedInterrupt exception.

def check_host_updates(self):
    # Iterate through the update messages sent from the server. If the update timestamp
    # is greater than the last update timestamp, then trigger a HostsUpdatedException.
    last_updated_timestamp = prev_timestamp = self._last_updated_timestamp
    all_update = HostUpdateResult.no_update
    while not self._host_messages.empty():
        timestamp, update = self._host_messages.get()
        if timestamp > last_updated_timestamp:
            last_updated_timestamp = timestamp
            all_update |= update

    prev_timestamp, self._last_updated_timestamp, all_update = \
        self._bcast_object((prev_timestamp, last_updated_timestamp, all_update))

    # At this point, updated state is globally consistent across all ranks.
    if self._last_updated_timestamp > prev_timestamp:
        raise HostsUpdatedInterrupt(all_update == HostUpdateResult.removed)
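The |= accumulation and the final equality test are easier to read if HostUpdateResult is thought of as a small flag type. The sketch below only illustrates that idea (the real definition lives in Horovod's runner utilities and may be shaped differently): when every pending update was a removal, the argument passed to HostsUpdatedInterrupt is True, which presumably becomes the skip_sync flag that run_fn later reads back as e.skip_sync.

from enum import IntFlag

class HostUpdateResult(IntFlag):
    # Hypothetical flag values; the real enum in Horovod may differ.
    no_update = 0
    added = 1
    removed = 2
    mixed = added | removed

# Accumulating two pending messages, one addition and one removal:
all_update = HostUpdateResult.no_update
for update in (HostUpdateResult.added, HostUpdateResult.removed):
    all_update |= update

# True only if every pending update was a removal; in check_host_updates this
# decides the argument passed to HostsUpdatedInterrupt.
print(all_update == HostUpdateResult.removed)  # False, because a host was also added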

2.4 Summary

Now we can answer the first two questions posed at the beginning of this article:

  • Are these exceptions automatically emitted by each worker?

    • Yes, they are thrown automatically.
    • When an error occurs in the user training function, the captured exception message is inspected; if it is related to ring allreduce, it is re-raised as HorovodInternalError(e).
    • When a host change is detected, a HostsUpdatedInterrupt exception is raised.
  • Do workers throw exceptions together?

    • Yes, they are thrown together.
    • If training fails, an exception is thrown.
    • When the driver finds, via the node discovery script, that a node has been marked as new or removed, it notifies all workers; the next time state.commit() or the lighter-weight state.check_host_updates() is called, a HostsUpdatedInterrupt exception is thrown.

The logic for throwing exceptions is as follows:

+-----------------------------------------------------------------+
| Worker                                                          |
|                                                                 |
|  HostsUpdatedInterrupt                    HorovodInternalError  |
|     ^                                             ^             |
|     |                                             |             |
|     |    +----------------------------------+     |             |
|     |    | train                            |     |             |
|     |    |                                  |     |             |
|     |    |    optimizer.apply_gradients +---------+             |
|     |    |                                  |                   |
|     +-------+ state.commit()                                    |
|          |                                  |                   |
|          +----------------------------------+                   |
|                                                                 |
|                                                                 |
+-----------------------------------------------------------------+

0x03 Handling exceptions

3.1 Overall Logic

The overall architecture is in run_fn.

Recall where run_fn is called: it is invoked inside run, where it wraps wrapper, and wrapper in turn wraps the user training function.

def run(func):
    from tensorflow.python.framework.errors_impl import UnknownError
​
    def wrapper(state, *args, **kwargs):
        try:
            return func(state, *args, **kwargs)
        except UnknownError as e:
            if 'HorovodAllreduce' in e.message or \
                    'HorovodAllgather' in e.message or \
                    'HorovodBroadcast' in e.message:
                raise HorovodInternalError(e)
                
    return run_fn(wrapper, _reset) 

The general logic is as follows:

+----------------------------------------------------------------------------+
| Worker                                                                     |
|                                                                            |
|  +----------------------------------------------------------------------+  |
|  | run_fn                                                               |  |
|  |                                                                      |  |
|  |                                                                      |  |
|  |                                                                      |  |
|  |                                                                      |  |
|  |                                                                      |  |
|  |                                                                      |  |
|  |                                                                      |  |
|  |    HostsUpdatedInterrupt                    HorovodInternalError     |  |
|  |       ^                                             ^                |  |
|  |       |                                             |                |  |
|  |       |    +----------------------------------+     |                |  |
|  |       |    | train                            |     |                |  |
|  |       |    |                                  |     |                |  |
|  |       |    |    optimizer.apply_gradients +---------+                |  |
|  |       |    |                                  |                      |  |
|  |       +-------+ state.commit()                                       |  |
|  |            |                                  |                      |  |
|  |            +----------------------------------+                      |  |
|  |                                                                      |  |
|  |                                                                      |  |
|  |                                                                      |  |
|  +----------------------------------------------------------------------+  |
+----------------------------------------------------------------------------+

The run_fn logic is as follows:

  • When a HorovodInternalError is raised, state.restore() is called to restore the state;
  • When a HostsUpdatedInterrupt is caught, skip_sync is taken from the exception;
  • reset() and state.on_reset() are called to reset;
  • In the next loop iteration, skip_sync decides whether state.sync() is executed.

The specific code is as follows:

def run_fn(func, reset):
    @functools.wraps(func)
    def wrapper(state, *args, **kwargs):
        notification_manager.init()
        notification_manager.register_listener(state)
        skip_sync = False
​
        try:
            while True:
                if not skip_sync:
                    state.sync()
​
                try:
                    return func(state, *args, **kwargs)
                except HorovodInternalError:
                    state.restore()
                    skip_sync = False
                except HostsUpdatedInterrupt as e:
                    skip_sync = e.skip_sync
​
                reset()
                state.on_reset()
        finally:
            notification_manager.remove_listener(state)
    return wrapper

So let’s extend the logic as follows:

[ASCII diagram flattened during extraction. It extends the previous figure with the run_fn loop: while True → state.sync() → train (optimizer.apply_gradients, state.commit()) → HorovodInternalError leads to state.restore(), HostsUpdatedInterrupt skips the restore → reset() and state.on_reset() → back to the top of the loop.]

3.2 Restore

state.restore() performs the restoration.

In TensorFlowKerasState, restore is implemented as follows.

def restore(self):
    self._load_model()
    super(TensorFlowKerasState, self).restore()

Concretely, the restore reloads the model and optimizer from the weights previously saved in the TensorFlowKerasState.

def _load_model(self):
    if _executing_eagerly():
        for var, saved_var in zip(self.model.variables, self._saved_model_state):
            var.assign(saved_var)
        for var, saved_var in zip(self.optimizer.variables(), self._saved_optimizer_state):
            var.assign(saved_var)
    else:
        self.model.set_weights(self._saved_model_state)
        self.optimizer.set_weights(self._saved_optimizer_state)
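_load_model reads from self._saved_model_state and self._saved_optimizer_state, so the saving side, invoked from state.save()/commit(), must capture the current weights into those fields. The following is only a hypothetical sketch of that idea as a free function, inferred from the fields above rather than copied from Horovod:

import tensorflow as tf

def snapshot_weights(model, optimizer):
    # Capture the current model/optimizer weights so that a later restore
    # can roll back to the last committed state after a failure.
    if tf.executing_eagerly():
        saved_model_state = [tf.identity(var) for var in model.variables]
        saved_optimizer_state = [tf.identity(var) for var in optimizer.variables()]
    else:
        saved_model_state = model.get_weights()
        saved_optimizer_state = optimizer.get_weights()
    return saved_model_state, saved_optimizer_state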

The diagram expands as follows:

[ASCII diagram flattened during extraction. Same flow as the previous figure, with state.restore() expanded into _load_model: model.set_weights, optimizer.set_weights, var.assign(saved_var).]


3.3 reset

The reset is performed by the following code.

reset()
state.on_reset()

3.3.1 reset

The specific reset function is:

def _reset():
    shutdown()
    init()
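shutdown() tears down the old Horovod context and init() performs a new round of rendezvous, so the process ends up in a freshly built communication domain. A tiny, hypothetical usage illustration (not part of the article's code) of why nothing from the old domain should be cached:

import horovod.tensorflow as hvd

# After reset(), the world may have grown or shrunk, so rank/size must be
# re-queried from the new context instead of reusing values cached earlier.
print('rank %d of %d after re-initialization' % (hvd.rank(), hvd.size()))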

3.3.2 _HorovodBasics

The functions used here come from _HorovodBasics.

_basics = _HorovodBasics(__file__, 'mpi_lib')
​
init = _basics.init
shutdown = _basics.shutdown

Concretely, init and shutdown re-establish the MPI-related context.

def init(self, comm=None):
​
    if comm is None:
        comm = []
​
    atexit.register(self.shutdown)
​
    if not isinstance(comm, list):
        mpi_built = self.MPI_LIB_CTYPES.horovod_mpi_built()
​
        from mpi4py import MPI
        if MPI._sizeof(MPI.Comm) == ctypes.sizeof(ctypes.c_int):
            MPI_Comm = ctypes.c_int
        else:
            MPI_Comm = ctypes.c_void_p
            self.MPI_LIB_CTYPES.horovod_init_comm.argtypes = [MPI_Comm]
​
        comm_obj = MPI_Comm.from_address(MPI._addressof(comm))
        self.MPI_LIB_CTYPES.horovod_init_comm(comm_obj)
    else:
        comm_size = len(comm)
        self.MPI_LIB_CTYPES.horovod_init(
            (ctypes.c_int * comm_size)(*comm), ctypes.c_int(comm_size))
​
def shutdown(self):
    self.MPI_LIB_CTYPES.horovod_shutdown()

3.3.3 on_reset

on_reset re-creates the _host_messages queue (discarding stale host updates), calls self.reset(), and then executes the reset callbacks registered by the user.

def on_reset(self):
    self._host_messages = queue.Queue()
    self.reset()
    for callback in self._reset_callbacks:
        callback()

For example, the user may register the following callback, which rescales the learning rate because hvd.size() may have changed after the new rendezvous:

def on_state_reset():
    optimizer.lr.assign(lr * hvd.size())

The logic is as follows:

[ASCII diagram flattened during extraction. Same flow as the previous figure, additionally showing reset() calling horovod_init / horovod_init_comm through _HorovodBasics, and state.on_reset() invoking the user callback.]


3.3.4 sync

When a reset occurs, the necessary synchronization is also performed, which involves broadcasting variables and saving the model.

def sync(self):
    if self.session is not None:
        self.session.run(self._bcast_op)
    self._save_model()
    super(TensorFlowState, self).sync()
3.3.4.1 Broadcast

The broadcast op was set up during the earlier initialization:

self._bcast_op = broadcast_variables(self.variables, root_rank=0)

Therefore, when the new communication domain is successfully constructed, the worker with rank=0 will broadcast its model to other workers.
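In eager mode the same broadcast can also be issued directly through the public API; a minimal sketch, assuming the model and optimizer objects from the earlier sample code:

import horovod.tensorflow as hvd

# After the new communication domain is built, rank 0's values overwrite
# everyone else's so that all workers resume from an identical state.
hvd.broadcast_variables(model.variables, root_rank=0)
hvd.broadcast_variables(optimizer.variables(), root_rank=0)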

3.3.4.2 Saving the model

Saving the model means calling _eval_fn to dump the model variables into memory.

def _save_model(self):
    self._values = [self._eval_fn(var) for var in self.variables]

_eval_fn was set during the earlier initialization:

self._eval_fn = self._to_numpy if _executing_eagerly() else self._eval_var

The specific functions are:

def _eval_var(self, var):
    return var.eval(self.session)
​
def _to_numpy(self, var):
    return var.numpy()

So our logic extends as follows:

[ASCII diagram flattened during extraction. The complete flow: state.sync() → broadcast_variables(self.variables, root_rank=0) and _save_model; train → state.commit() or an exception; HorovodInternalError → state.restore() → _load_model; HostsUpdatedInterrupt → skip the restore; reset() → _HorovodBasics horovod_init / horovod_init_comm; state.on_reset() → user callback; then back to the top of the loop.]


This concludes our analysis of elastic training. The next two or three articles will look at Horovod on Kubernetes.

0xEE Personal information

Thoughts on life and technology

Wechat public account: Rosie’s Thinking

If you want timely updates on future articles, or want to see the technical material I recommend, please follow this account.

0xFF References

Horovod for ElasticDL in Kubernetes

Kubernetes Training: Distributed deep learning training using Horovod on Kubernetes

The Elastic Training Operator is an Elastic deep learning Training tool on Kubernetes

Elastic and Fault-tolerant Distributed Training for ElasticHorovod

Horovod elastic training

Kubernetes-native elastic distributed deep learning system