0x00 Summary

Horovod is an easy-to-use, high-performance distributed training framework released by Uber in 2017 that has been widely used in the industry.

This series walks you through the Horovod source code. This is the 15th article in the series; it looks at how Horovod elastic training broadcasts state and sends notifications.

Links to other articles in this series are as follows:

Horovod (1) — Basics

Horovod (2) — A distributed training framework for deep learning — from the user’s perspective

Horovod (3) — What’s behind Horovodrun

Horovod (4) — Network Basics & Driver

Horovod (5) — fusion framework

Horovod (6) — background architecture

Horovod (6) — A distributed training framework for deep learning — thread implementation

Horovod (7) — DistributedOptimizer

Horovod (8) — on Spark

Horovod (9) — Start on Spark

Horovod (10) — Run on Spark

Horovod (11) — on Spark — GLOO scheme

Horovod (12) — Overall architecture of elastic training

Horovod (13) — The Driver for elastic training

Horovod (14) — Elastic training discovery node & State

0x01 The question

First, let's ask the question: why does elastic training need broadcasting?

The answer: after either of two kinds of exceptions is caught, Horovod needs to broadcast to every worker.

1.1 HorovodInternalError

HorovodInternalError exception handling:

  • The hvd.elastic.run decorator catches the exception;
  • If it is a HorovodInternalError, the state is restored from the latest commit. At this point all workers are stopped because of an exception such as a failed AllReduce;
  • The Driver performs a new rendezvous based on the currently running nodes in order to re-initialize the Horovod context;
  • When the new communication domain is constructed successfully, the worker with rank = 0 broadcasts its model to the other workers;
  • All workers resume training from the iteration step at which they last stopped.

Since variables must be broadcast from rank 0 to the other processes, a broadcast mechanism is required. A minimal usage sketch is given below.
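To make this concrete, here is a minimal usage sketch of an elastic TensorFlow 2.x training loop (this sketch is not taken from the article's code; the toy model, data and hyper-parameters are hypothetical placeholders). On every (re)entry into the wrapped function, state.sync() broadcasts the model from rank 0, and state.commit() periodically saves the state and checks for host changes:

import numpy as np
import tensorflow as tf
import horovod.tensorflow as hvd

hvd.init()

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
optimizer = tf.optimizers.SGD(0.01 * hvd.size())
data = tf.data.Dataset.from_tensor_slices(
    (np.random.rand(256, 4).astype('float32'),
     np.random.rand(256, 1).astype('float32'))).batch(32)

def train_step(x, y):
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.square(model(x) - y))
    tape = hvd.DistributedGradientTape(tape)   # AllReduce; a failure here surfaces as HorovodInternalError
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))

@hvd.elastic.run
def train(state):
    # state.sync() has already broadcast rank 0's model to every worker at this point
    for state.epoch in range(state.epoch, 3):
        for x, y in data:
            train_step(x, y)
        state.commit()   # save state and check for host changes (may raise HostsUpdatedInterrupt)

state = hvd.elastic.TensorFlowKerasState(model, optimizer, epoch=0)
train(state)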

1.2 HostsUpdatedInterrupt

For the handling of HostsUpdatedInterrupt, let's look at the specific causes.

  • When the driver learns, via the node discovery script, that a node has been marked as added or removed, it sends a notification to all workers. The next time state.commit() or the lighter state.check_host_updates() is called, a HostsUpdatedInterrupt exception is raised. This exception is similar to HorovodInternalError, except that parameters and other state are not recovered from the latest commit but kept from the current live values;
  • The check_host_updates method reads messages from _host_messages and accumulates the updates. As the comment in that method states, the state is synchronized across workers so that they all raise the exception at the same time;
  • The synchronization itself uses _bcast_object, which in turn calls MPI internally.

So a broadcast mechanism is needed to synchronize state across the workers (these workers are still training normally, so something must interrupt their training in order to rebuild the communication ring), such that all workers raise HostsUpdatedInterrupt at the same time.

0x02 Broadcast mechanism

We now analyze the broadcast mechanism in detail. Because broadcasting is tightly coupled with the specific framework, we take TensorFlow as the example; the relevant code is in horovod/tensorflow/elastic.py.

2.1 Broadcast Implementation

horovod/tensorflow/elastic.py contains the TensorFlow-specific implementation. The handling differs depending on the TensorFlow version.

2.1.1 TensorFlowKerasState

Take TensorFlowKerasState as an example. During initialization, because various things need to be broadcast, TensorFlowKerasState configures _bcast_model for broadcasting the model and bcast_object for broadcasting objects; broadcast_variables is used to broadcast variables.

The sync function takes care of the broadcast; you can see that it calls _bcast_model.

class TensorFlowKerasState(ObjectState):
    def __init__(self, model, optimizer=None, backend=None, **kwargs):

        if not backend or _executing_eagerly():
            # Set the broadcast functions here
            self._bcast_model = lambda: _broadcast_model(self.model, self.optimizer, backend=self.backend)
            bcast_object = broadcast_object
        else:
            # For TensorFlow v1, we need to reuse the broadcast op to prevent incrementing the uids
            # Set the broadcast functions here
            bcast_op = broadcast_variables(_global_variables(), root_rank=0)
            self._bcast_model = lambda: self.backend.get_session().run(bcast_op)
            bcast_object = broadcast_object_fn(session=self.backend.get_session())

    def sync(self):
        self._bcast_model()   # Broadcast the model
        self._save_model()
        super(TensorFlowKerasState, self).sync()

2.1.2 Broadcast model

The _broadcast_model function broadcasts model variables and optimizer variables.

def _broadcast_model(model, optimizer, backend):
    if _executing_eagerly():
        # TensorFlow 2.0 or TensorFlow eager
        broadcast_variables(model.variables, root_rank=0)        # Broadcast model variables
        broadcast_variables(optimizer.variables(), root_rank=0)  # Broadcast optimizer variables
    else:
        bcast_op = broadcast_variables(_global_variables(), root_rank=0)
        backend.get_session().run(bcast_op)

2.1.3 Broadcast variables

Variable broadcasting is implemented in horovod/tensorflow/functions.py. broadcast_variables broadcasts variables from the root rank (i.e., rank 0) to all other processes.

Again, the behavior differs by TensorFlow version.

def _make_subgraph(f):
    return tf.function(f)

@_cache
def _make_broadcast_group_fn():
    if _executing_eagerly():
        # Eager mode will parallelize independent control flow
        def broadcast_group(variables, root_rank):  # defined here
            for var in variables:
                var.assign(broadcast(var, root_rank))  # broadcast() ends up calling into the MPI layer

        return _make_subgraph(broadcast_group)
    else:
        # Graph mode requires an Op
        def broadcast_group(variables, root_rank):  # defined here
            # tf.group() combines all the ops passed in; the grouped op completes only when
            # every input op has completed, and it produces no output.
            return tf.group(*[var.assign(broadcast(var, root_rank))  # broadcast() ends up calling into the MPI layer
                              for var in variables])

        return broadcast_group

def broadcast_variables(variables, root_rank):
    """Broadcasts variables from root rank to all other processes."""
    broadcast_group = _make_broadcast_group_fn()
    return broadcast_group(variables, root_rank)  # defined above
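As a hedged usage sketch (assuming an initialized Horovod/TensorFlow 2.x eager environment; this example is not taken from the article), broadcast_variables can be called directly to force every rank's variables to rank 0's values:

import tensorflow as tf
import horovod.tensorflow as hvd

hvd.init()

# Each rank starts with a different value...
w = tf.Variable([float(hvd.rank())])

# ...and after the broadcast every rank holds rank 0's value.
hvd.broadcast_variables([w], root_rank=0)
print(hvd.rank(), w.numpy())   # every rank prints [0.]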

2.1.4 Broadcast objects

broadcast_object broadcasts an object from the root rank (i.e., rank 0) to all other processes. The difference between broadcasting an object and broadcasting a variable is that the object needs to be serialized and deserialized.

def broadcast_object(obj, root_rank=0, session=None, name=None):
    """
    Serializes and broadcasts an object from root rank to all other processes.

    Arguments:
        obj: An object capable of being serialized without losing any context.
        root_rank: The rank of the process from which parameters will be
                   broadcasted to all other processes.
        session: Session for TensorFlow v1 compatibility.
        name: Optional name to use during broadcast, will default to the class type.
    Returns:
        The object that was broadcast from the `root_rank`.
    """
    if name is None:
        name = type(obj).__name__

    def to_numpy(v):  # Different handling depending on the TF version
        if not _executing_eagerly():
            sess = session or ops.get_default_session()
            return sess.run(v)
        else:
            return v.numpy()

    if rank() == root_rank:
        b = io.BytesIO()  # BytesIO reads and writes bytes in memory
        cloudpickle.dump(obj, b)  # Serialize: encode into a binary buffer
        t = tf.convert_to_tensor(bytearray(b.getvalue()), dtype=tf.uint8)
        sz = tf.convert_to_tensor([t.shape[0]], dtype=tf.int32)  # Size of the serialized tensor
        to_numpy(broadcast(sz, root_rank, name + '.sz'))  # Broadcast the size
    else:
        sz = tf.convert_to_tensor([0], dtype=tf.int32)
        sz = to_numpy(broadcast(sz, root_rank, name + '.sz'))  # Receive the size
        t = tf.zeros(sz.tolist()[0], dtype=tf.uint8)

    t = to_numpy(broadcast(t, root_rank, name + '.t'))  # Broadcast the object content

    if rank() != root_rank:
        buf = io.BytesIO(t.tobytes())
        obj = cloudpickle.load(buf)  # Deserialize back into the original object

    return obj
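A hedged usage sketch (assuming an initialized Horovod/TensorFlow environment; the dictionary below is a hypothetical example): because the payload is pickled with cloudpickle, arbitrary Python objects can be broadcast, not just tensors:

import horovod.tensorflow as hvd

hvd.init()

meta = {'epoch': 3, 'lr': 1e-3} if hvd.rank() == 0 else None
meta = hvd.broadcast_object(meta, root_rank=0)
print(hvd.rank(), meta)   # every rank now prints rank 0's dict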

2.1.5 HVD C++

At the bottom, the Python broadcast op calls into the Horovod C++ library, which uses MPI to complete the broadcast.

def broadcast(tensor, root_rank, name=None, ignore_name_scope=False):
    """An op which broadcasts the input tensor on root rank to the same input tensor
    on all other Horovod processes. The broadcast operation is keyed by the name of
    the op. The tensor type and shape must be the same on all Horovod processes for
    a given name. The broadcast will not start until all processes are ready to send
    and receive the tensor.

    Returns:
        A tensor of the same shape and type as `tensor`, with the value broadcasted
        from root rank.
    """
    if name is None and not _executing_eagerly():
        name = 'HorovodBroadcast_%s' % _normalize_name(tensor.name)
    return MPI_LIB.horovod_broadcast(tensor, name=name, root_rank=root_rank,
                                     ignore_name_scope=ignore_name_scope)

2.1.6 MPI

MPI_Bcast broadcasts a message from the process whose rank is root to every process in the group, including itself.

Because root_rank is specified, even though all workers call the same code, the message is copied only from root_rank's communication buffer to all other processes.

void MPIController::Bcast(void* buffer, size_t size, int root_rank,
                          Communicator communicator) {
  MPI_Comm comm = mpi_ctx_.GetMPICommunicator(communicator);
  int ret_code = MPI_Bcast(buffer, size, MPI_BYTE, root_rank, comm);
  if (ret_code != MPI_SUCCESS) {
    throw std::runtime_error(
        "MPI_Broadcast failed, see MPI output for details.");
  }
}

2.1.7 Summary

Let’s summarize each function:

  • _bcast_model is used to broadcast the model;
  • bcast_object is used to broadcast objects;
  • broadcast_variables is used to broadcast variables;
  • The difference between broadcasting an object and broadcasting a variable is that the object needs to be serialized and deserialized;
  • _broadcast_model calls broadcast_variables to broadcast the model parameters;
  • broadcast_variables calls broadcast_group, whose main job is to combine the individual broadcast operations with tf.group().

2.2 Usage

2.2.1 HorovodInternalError

When a HorovodInternalError is caught, broadcast synchronization is performed. The purpose is that once the new communication domain is constructed successfully, the worker with rank = 0 broadcasts its model to the other workers.

def run_fn(func, reset):
    @functools.wraps(func)
    def wrapper(state, *args, **kwargs):
        notification_manager.init()
        notification_manager.register_listener(state)
        skip_sync = False

        try:
            while True:
                if not skip_sync:
                    state.sync()  # Broadcast synchronization happens here, i.e. TensorFlowKerasState.sync

                try:
                    return func(state, *args, **kwargs)
                except HorovodInternalError:
                    state.restore()  # Restore the state, then continue the while loop
                    skip_sync = False
                except HostsUpdatedInterrupt as e:
                    skip_sync = e.skip_sync

                reset()
                state.on_reset()
        finally:
            notification_manager.remove_listener(state)
    return wrapper

The details are as follows:

  Worker rank 0                               Worker rank n
        +                                         +
        |                                         |
        |                                         |
        |                                         |
        v                                         |
 Catch HorovodInternalError                       |
        +                                         |
        |                                         |
        |                                         |
        |                                         |
       sync                                       |
        |                                         |
        |                                         |    
        v                                         |
_broadcast_model(model)                           |
        +                                         |
        |                                         |
        |                                         |
        |                                         |
        v                                         |
 broadcast_variables(model.variables)             |
                                                  |
 broadcast_variables(optimizer.variables)         |
                                                  |
        +                                         |
        |                                         |
        |                                         |
        |                                         |
        v                                         |
  broadcast_group                                 |
        +                                         |
        |                                         |
        |                                         |
        |                                         |
        v                                         |
 MPI_LIB.horovod_broadcast  +-------------------> |
        +                                         |
        |                                         |
        |                                         |
        v                                         v

2.2.2 HostsUpdatedInterrupt

The purpose of broadcasting the object is to synchronize state across the workers so that all of them raise HostsUpdatedInterrupt at the same time.

How to use it?

In the WorkerNotificationService._handle method, self._manager.handle_hosts_updated(req.timestamp, req.res) is called to notify about the update.

The WorkerNotificationManager.handle_hosts_updated method then iterates over the registered states and notifies each of them of the update.

def handle_hosts_updated(self, timestamp, update_res):
    for listener in self._listeners:
        listener.on_hosts_updated(timestamp, update_res)

This is seen in several methods of State.

  • on_hosts_updated: called when a host changes; it puts a message into the _host_messages queue;
  • commit: called periodically by user code; it stores the state and checks for host changes;
  • check_host_updates: reads messages from _host_messages and accumulates the updates. As the comment in the method states, the state is synchronized across workers so that they can all raise the exception at the same time. The synchronization uses _bcast_object.

The check_host_updates code is as follows:

def check_host_updates(self):
    """
    Checks that a notification has been sent indicating that hosts can be added
    or will be removed.

    Raises a `HostsUpdatedInterrupt` if such a notification has been received.
    """
    # Iterate through the update messages sent from the server. If the update timestamp
    # is greater than the last update timestamp, then trigger a HostsUpdatedException.
    last_updated_timestamp = prev_timestamp = self._last_updated_timestamp
    all_update = HostUpdateResult.no_update
    while not self._host_messages.empty():
        timestamp, update = self._host_messages.get()
        if timestamp > last_updated_timestamp:
            last_updated_timestamp = timestamp
            all_update |= update

    # In order to ensure all workers raise the exception at the same time, we need to sync
    # the updated state across all the workers.
    # TODO(travis): this should be a max allreduce to account for changes in rank 0
    # The broadcast happens here
    prev_timestamp, self._last_updated_timestamp, all_update = \
        self._bcast_object((prev_timestamp, last_updated_timestamp, all_update))

    # At this point, updated state is globally consistent across all ranks.
    if self._last_updated_timestamp > prev_timestamp:
        raise HostsUpdatedInterrupt(all_update == HostUpdateResult.removed)

The details are as follows:

 Worker rank 0 (coordinator)                       Worker rank 1 ... Worker rank n
                 +                                        +
                 |                                        |
                 v                                        |
   WorkerNotificationService._handle                      |
                 +                                        |
                 |                                        |
                 v                                        |
   manager.handle_hosts_updated                           |
                 +                                        |
                 |                                        |
                 v                                        |
   State.on_hosts_updated                                 |
                 +                                        |
                 | (user calls state.commit)              |
                 v                                        |
   check_host_updates                                     |
                 +                                        |
                 |                                        |
                 v                                        |
   broadcast_object                                       |
                 +                                        |
                 |                                        |
                 v                                        |
   MPI_LIB.horovod_broadcast +--------------------------> |
                 +                                        |
                 |                                        |
                 v                                        v
   raise HostsUpdatedInterrupt               raise HostsUpdatedInterrupt

0x03 Notification mechanism

Above, manager.handle_hosts_updated was used; that manager is a WorkerNotificationManager.

So let's follow the WorkerNotificationManager; this is Horovod's notification mechanism.

3.1 Creating the WorkerNotificationManager

Each host has only one WorkerNotificationManager and only one WorkerNotificationService.

Note: the ElasticDriver acts as the client. It sends messages to these WorkerNotificationService instances, which in turn trigger the corresponding operations on the WorkerNotificationManager.

The instance is created in horovod/common/elastic.py with the following code.

notification_manager = WorkerNotificationManager()

WorkerNotificationManager is defined as follows:

class WorkerNotificationManager(object):
    def __init__(self):
        self._lock = threading.Lock()
        self._service = None
        self._listeners = set()

3.2 Initialization

Before the user code starts, the WorkerNotificationManager is initialized first.

def run_fn(func, reset):
    @functools.wraps(func)
    def wrapper(state, *args, **kwargs):
        # Initialize the WorkerNotificationManager
        notification_manager.init()
        # Register the state with the notification_manager
        notification_manager.register_listener(state)

The WorkerNotificationManager initialization code is as follows; the logic is:

  • If _service has already been created, return directly; this ensures there is only one WorkerNotificationService per host.
  • Get the rendezvous information (address, port, network interface, hostname, local rank, and so on) from environment variables;
  • Create a WorkerNotificationService and assign it to _service;
  • Use put_data_into_kvstore to send the worker's addresses and its local rank to the rendezvous server (this information is later used to create the WorkerNotificationClient).
  • Note: the rendezvous server stores each worker's address and its assigned rank in the logical communication ring. Worker processes use this rendezvous to construct new communication domains.
def init(self, rendezvous_addr=None, rendezvous_port=None,
         nic=None, hostname=None, local_rank=None):
    with self._lock:
        if self._service:
            return

        # Get rendezvous information from environment variables: address, port, NIC, hostname, local rank
        rendezvous_addr = rendezvous_addr or os.environ.get(HOROVOD_GLOO_RENDEZVOUS_ADDR)
        rendezvous_port = rendezvous_port if rendezvous_port is not None else \
            int(os.environ.get(HOROVOD_GLOO_RENDEZVOUS_PORT))
        nic = nic or os.environ.get(HOROVOD_GLOO_IFACE)
        hostname = hostname or os.environ.get(HOROVOD_HOSTNAME)
        local_rank = local_rank if local_rank is not None else \
            int(os.environ.get(HOROVOD_LOCAL_RANK))

        secret_key = secret.make_secret_key()
        self._service = WorkerNotificationService(secret_key, nic, self)

        value = (self._service.addresses(), secret_key)
        # Send this worker's service addresses and secret key to the rendezvous server
        put_data_into_kvstore(rendezvous_addr,
                              rendezvous_port,
                              PUT_WORKER_ADDRESSES,
                              self._create_id(hostname, local_rank),
                              value)

The details of put_data_into_kvstore are as follows:

def put_data_into_kvstore(addr, port, scope, key, value):
    try:
        url = "http://{addr}:{port}/{scope}/{key}".format(
            addr=addr, port=str(port), scope=scope, key=key
        )
        req = Request(url, data=codec.dumps_base64(value, to_ascii=False))
        req.get_method = lambda: "PUT"  # for urllib2 compatibility
        urlopen(req)
    except (HTTPError, URLError) as e:
        raise RuntimeError("Put data input KVStore server failed.", e)

3.3 Registering the State

The user code also registers its state with the notification_manager before it starts.

def run_fn(func, reset):
    @functools.wraps(func)
    def wrapper(state, *args, **kwargs):
        # Initialize the WorkerNotificationManager
        notification_manager.init()
        # Register the state with the notification_manager
        notification_manager.register_listener(state)

The specific code is as follows:

def register_listener(self, listener):
    self._listeners.add(listener)

def remove_listener(self, listener):
    self._listeners.remove(listener)
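Any object that implements on_hosts_updated(timestamp, update_res) can be registered this way. As a hedged illustration (the PrintingListener below is a hypothetical example, not part of Horovod; normally the run_fn wrapper registers the State for you), an extra listener could be added alongside the State:

class PrintingListener(object):
    def on_hosts_updated(self, timestamp, update_res):
        # Called by WorkerNotificationManager.handle_hosts_updated for every registered listener
        print('hosts updated at %s: %s' % (timestamp, update_res))

notification_manager.register_listener(PrintingListener())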

3.4 WorkerNotificationService

There is also only one WorkerNotificationService per host. It accepts HostsUpdatedRequest messages from its client and processes them. As you can see, it inherits from network.BasicService, which means WorkerNotificationService is itself a server that interacts with a client; if you think of the various driver/client pairs introduced in earlier articles, the mechanism should be familiar.

class WorkerNotificationService(network.BasicService):
    NAME = 'worker notification service'

    def __init__(self, key, nic, manager):
        super(WorkerNotificationService, self).__init__(WorkerNotificationService.NAME,
                                                        key,
                                                        nic)
        self._manager = manager

    def _handle(self, req, client_address):
        if isinstance(req, HostsUpdatedRequest):
            self._manager.handle_hosts_updated(req.timestamp, req.res)
            return network.AckResponse()

        return super(WorkerNotificationService, self)._handle(req, client_address)

The logic is as follows:

 +-------------------------------+                          +---------------------------+
 | WorkerNotificationManager     |                          | rendezvous                |
 |                               +------------------------> |                           |
 |                               |  put_data_into_kvstore   |                           |
 |                               |                          |                           |
 |                               |                          +---------------------------+
 | _listeners                    |
 |      +                        |                          +---------------------------+
 |      |         _service  +-----------------------------> | WorkerNotificationService |
 |      |                        |                          |                           |
 +-----------------------+-------+                          |                           |
        |                ^                                  |                           |
        |                |                                  |                           |
        |                |                                  |                           |
        |                +----------------------------------------+ _manager            |
        |                                                   |                           |
        v                                                   |                           |
                                                            +---------------------------+
[State 1, State 2, ... , State n]

3.5 WorkerNotificationClient

WorkerNotificationClient is the interface used to send messages to the WorkerNotificationService.

In ElasticDriver, a WorkerNotificationClient is generated for each worker for notification.

class WorkerNotificationClient(network.BasicClient):
    def __init__(self, addresses, key, verbose, match_intf=False):
        super(WorkerNotificationClient, self).__init__(WorkerNotificationService.NAME,
                                                       addresses,
                                                       key,
                                                       verbose,
                                                       match_intf=match_intf)

    def notify_hosts_updated(self, timestamp, update_res):
        self._send(HostsUpdatedRequest(timestamp, update_res))

3.6 Creating the client

3.6.1 Registration timing

To recap: in the WorkerNotificationManager init function, a PUT request is sent to the rendezvous server to register the worker.

The registration information is used to generate the client.

put_data_into_kvstore(rendezvous_addr,
                      rendezvous_port,
                      PUT_WORKER_ADDRESSES,
                      self._create_id(hostname, local_rank),
                      value)
Copy the code

3.6.2 Registering the worker

ElasticRendezvousHandler has a _put_value method that handles PUT_WORKER_ADDRESSES and calls the driver for processing.

# Note: this runs inside the Rendezvous Server
def _put_value(self, scope, key, value):
    if scope == PUT_WORKER_ADDRESSES:
        host, local_rank = key.split(':')
        addresses, secret_key = codec.loads_base64(value)
        self._put_worker_addresses(host, int(local_rank), addresses, secret_key)

    super(RendezvousHandler, self)._put_value(scope, key, value)

def _put_worker_addresses(self, host, local_rank, addresses, secret_key):
    # Call the driver for processing
    driver.register_worker_server(host, local_rank, addresses, secret_key)

3.6.3 Creating the WorkerNotificationClient

In ElasticDriver, a WorkerNotificationClient is generated for each worker for notification.

Note: the ElasticDriver is the user of the WorkerNotificationClient. When workers need to be notified, the ElasticDriver calls the WorkerNotificationClient, which sends a message to the WorkerNotificationService on the corresponding host, and that in turn causes the WorkerNotificationManager to react accordingly.

# In ElasticDriver
def register_worker_server(self, host, slot, addresses, secret_key):
    self._worker_clients[(host, slot)] = WorkerNotificationClient(
        addresses, secret_key, self._verbose)

The logic is as follows:

+-----------------------------+  1 put_data_into_kvstore   +--------------------------+
| WorkerNotificationManager   +--------------------------> | rendezvous               |
|                             |                            | ElasticRendezvousHandler |
|   init                      |                            +------------+-------------+
|                             |                                         |
|   _listeners --> [State 1, State 2, ... , State n]                    |
|                             |                                         |
|   _service ----> WorkerNotificationService                            |  2 register_worker_server
|                     (its _manager points back)                        |
+-----------------------------+                                         |
                                                                        v
+-------------------------------+   3 new instance   +----------------------------+     +----------------------------+
| ElasticDriver                 |                     | WorkerNotificationClient 1 |     | WorkerNotificationClient n |
|                               |                     |                            |     |                            |
|        _worker_clients +------------------------->  |    (host 1, slot 1)        | ... |    (host n, slot n)        |
|                               |                     |     For worker 1           |     |     For worker n           |
+-------------------------------+                     +----------------------------+     +----------------------------+


3.7 Usage

3.7.1 Discovering Updates

If host changes are detected in ElasticDriver._discovery_thread, self._notify_workers_host_changes is called.

def _notify_workers_host_changes(self, current_hosts, update_res):
    next_host_assignments = {}
    if current_hosts.count_available_slots() >= self._min_np:
        # Assignments are required to be stable via contract
        next_host_assignments, _ = self._get_host_assignments(current_hosts)

    if next_host_assignments == self.host_assignments:
        # Skip notifying workers when host changes would not result in changes of host assignments
        return

    coordinator_slot_info = self.get_coordinator_info()
    coordinator_client = self.get_worker_client(coordinator_slot_info)

    timestamp = _epoch_time_s()
    coordinator_client.notify_hosts_updated(timestamp, update_res)

3.7.2 Getting the client

The get_worker_client function returns the WorkerNotificationClient: it looks up the client corresponding to a worker by its host and slot information.

def get_worker_client(self, slot_info):
    return self._worker_clients.get((slot_info.hostname, slot_info.local_rank))

3.7.3 Sending HostsUpdatedRequest

notify_hosts_updated sends a HostsUpdatedRequest:

class WorkerNotificationClient(network.BasicClient):
    def __init__(self, addresses, key, verbose, match_intf=False):
        super(WorkerNotificationClient, self).__init__(WorkerNotificationService.NAME,
                                                       addresses,
                                                       key,
                                                       verbose,
                                                       match_intf=match_intf)

    def notify_hosts_updated(self, timestamp, update_res):
        self._send(HostsUpdatedRequest(timestamp, update_res))

3.7.4 Processing HostsUpdatedRequest

WorkerNotificationService processes the HostsUpdatedRequest and calls the WorkerNotificationManager to handle it.

class WorkerNotificationService(network.BasicService):
    NAME = 'worker notification service'

    def __init__(self, key, nic, manager):
        super(WorkerNotificationService, self).__init__(WorkerNotificationService.NAME,
                                                        key,
                                                        nic)
        self._manager = manager

    def _handle(self, req, client_address):
        if isinstance(req, HostsUpdatedRequest):
            self._manager.handle_hosts_updated(req.timestamp, req.res)
            return network.AckResponse()

        return super(WorkerNotificationService, self)._handle(req, client_address)

3.7.5 WorkerNotificationManager

So, when hosts are updated, WorkerNotificationManager.handle_hosts_updated runs as follows and finally calls on_hosts_updated on each registered state.

def handle_hosts_updated(self, timestamp, update_res):
    for listener in self._listeners:  # iterate over the registered states
        listener.on_hosts_updated(timestamp, update_res)

The implementation of State is as follows:

def on_hosts_updated(self, timestamp, update_res):
    self._host_messages.put((timestamp, update_res))

The logic is as follows:

                        +---------------------------------------+
                        |  ElasticDriver._discovery_thread      | <--+
                        |                                       |    |  thread loop
                        |  HostManager.update_available_hosts   +----+
                        +-------------------+-------------------+
                                            |
                                            |  1 _notify_workers_host_changes
                                            v
+---------------------------+   2 HostsUpdatedRequest   +----------------------------+   handle_hosts_updated   +----------------------------+
|                           |                           |                            |                          |                            |
| WorkerNotificationClient  +-------------------------> |  WorkerNotificationService | +----------------------> |  WorkerNotificationManager |
|                           |                           |                            |                          |                            |
+---------------------------+                           +----------------------------+                          +-------------+--------------+
                                                                                                                               |
                                                                                                                               |  on_hosts_updated
                                                                                                                               v
                                                                                                                 +-----------------------+
                                                                                                                 |  State                |
                                                                                                                 |       |  put          |
                                                                                                                 |       v               |
                                                                                                                 |   _host_messages      |
                                                                                                                 +-----------------------+


3.7.6 Processing updates

check_host_updates is called to check for updates only when the user calls commit().

def commit(self):
    self.save()
    self.check_host_updates()
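To illustrate the difference between the two calls, here is a hedged sketch of a training loop (train_step, num_epochs and num_batches are hypothetical placeholders): commit() saves the state and then checks for host changes, while check_host_updates() only performs the check, so it can be called more often at lower cost:

@hvd.elastic.run
def train(state):
    for state.epoch in range(state.epoch, num_epochs):
        for state.batch in range(state.batch, num_batches):
            train_step(state)
            if state.batch % 100 == 0:
                # Cheaper: only drains _host_messages and may raise HostsUpdatedInterrupt,
                # without saving the model
                state.check_host_updates()
        # More expensive: saves the state, then performs the same check
        state.commit()
        state.batch = 0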

Checking for updates means looking at whether there are new messages in _host_messages; if the hosts have changed, a HostsUpdatedInterrupt exception is raised.

def check_host_updates(self):
    # Iterate through the update messages sent from the server. If the update timestamp
    # is greater than the last update timestamp, then trigger a HostsUpdatedException.
    last_updated_timestamp = prev_timestamp = self._last_updated_timestamp
    all_update = HostUpdateResult.no_update
    while not self._host_messages.empty():
        timestamp, update = self._host_messages.get()
        if timestamp > last_updated_timestamp:
            last_updated_timestamp = timestamp
            all_update |= update

    # In order to ensure all workers raise the exception at the same time, we need to sync
    # the updated state across all the workers.
    # TODO(travis): this should be a max allreduce to account for changes in rank 0
    prev_timestamp, self._last_updated_timestamp, all_update = \
        self._bcast_object((prev_timestamp, last_updated_timestamp, all_update))

    # At this point, updated state is globally consistent across all ranks.
    if self._last_updated_timestamp > prev_timestamp:
        raise HostsUpdatedInterrupt(all_update == HostUpdateResult.removed)

When workers are added or removed, HorovodInternalError or HostsUpdatedInterrupt is raised in the worker process; these two exceptions are caught, and reset is called for fault-tolerant handling. This is how the whole process is tied together.


So far, we have finished sorting out the broadcast notification mechanism. The next article will introduce how the worker works.

0xEE Personal information

Thoughts on life and technology

Wechat public account: Rosie’s Thinking

If you want timely notifications when new articles are published, or want to see the technical material I recommend, please follow the account.