0x00 Abstract

Horovod is an easy-to-use, high-performance distributed training framework released by Uber in 2017 that has been widely used in the industry.

This series walks through the Horovod source code. This article, the fourteenth in the series, looks at how Horovod dynamically discovers nodes and how it handles State information.

Links to other articles in this series are as follows:

Horovod (1) — Basics

Horovod (2) — A distributed training framework for deep learning — from the user’s perspective

Horovod (3) — What’s behind Horovodrun

Horovod (4) — Network Basics & Driver

Horovod (5) — fusion framework

Horovod (6) — background thread architecture

Horovod (7) — DistributedOptimizer

Horovod (8) — on Spark

Horovod (9) — Start on Spark

Horovod (10) — Run on Spark

Horovod (11) — on Spark — GLOO scheme

Horovod (12) — overall architecture for elastic training

Horovod (13) — the Driver for elastic training

0x01 Design points

This article covers the Host Discovery part of the architecture diagram. Host Discovery is called by Driver Main, so the two parts are shown together.

The key design points of the node discovery mechanism are as follows:

  • How is a node change detected promptly? Horovod polls periodically.
  • How is each worker notified once a node change is discovered? Horovod builds a notification mechanism: each worker registers itself with WorkerNotificationManager, and when nodes change, WorkerNotificationManager notifies the workers one by one.
  • What does a worker do when notified? Horovod wraps the worker's state in State classes specific to each deep learning framework; when notified, the corresponding State callbacks run, either to synchronize the state or to do other processing.

0x02 Discovery Mechanism

The code is mainly in horovod/runner/elastic/discovery.py.

2.1 Discovery Script

HostDiscoveryScript stores the discovery script (configured when the program starts) and runs it each time find_available_hosts_and_slots is called, to obtain the host information.

The script's output uses the same format as the hosts argument of horovodrun, for example:

$ sh ./discover_hosts.sh    # run the script to print the node information
10.68.32.2:4
10.68.32.3:4
10.68.32.4:4

The definition is as follows:

class HostDiscoveryScript(HostDiscovery):

    def __init__(self, discovery_script, slots):
        self._discovery_script = discovery_script  # the discovery script to run
        self._default_slots = slots  # default slots for hosts that do not specify any
        super(HostDiscoveryScript, self).__init__()

    def find_available_hosts_and_slots(self):
        stdout = io.StringIO()
        # Execute the discovery script, capturing its output
        exit_code = safe_shell_exec.execute(self._discovery_script, stdout=stdout)
        if exit_code != 0:
            raise RuntimeError('Failed to execute discovery script: {}. Exit code: {}'
                               .format(self._discovery_script, exit_code))

        # Read the script output and parse the host information
        host_slots = {}
        lines = set(stdout.getvalue().strip().split('\n'))
        for line in lines:
            host = line
            if ':' in line:
                host, slots = line.split(':')
                host_slots[host] = int(slots)
            else:
                host_slots[host] = self._default_slots
        return host_slots
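The parsing rule is easy to try in isolation. The following is a minimal sketch, not Horovod code: the hypothetical helper parse_host_slots applies the same host[:slots] rules to a string instead of running a script. Unlike the loop above, it also strips whitespace and skips blank lines (an added assumption), so an empty script output yields an empty dict.

```python
from typing import Dict


def parse_host_slots(output: str, default_slots: int) -> Dict[str, int]:
    """Hypothetical helper: parse 'host[:slots]' lines like find_available_hosts_and_slots."""
    host_slots = {}
    # A set removes duplicate lines, just like the original code
    lines = set(output.strip().split('\n'))
    for line in lines:
        line = line.strip()
        if not line:
            continue  # skip blank lines (not done by the original loop)
        if ':' in line:
            host, slots = line.split(':')
            host_slots[host] = int(slots)
        else:
            host_slots[line] = default_slots  # no slot count given: use the default
    return host_slots


sample = "10.68.32.2:4\n10.68.32.3:4\n10.68.32.4\n"
print(sorted(parse_host_slots(sample, default_slots=2).items()))
# [('10.68.32.2', 4), ('10.68.32.3', 4), ('10.68.32.4', 2)]
```

Because the lines pass through a set, duplicate lines collapse into one entry; if the same host appears twice with different slot counts, which value wins depends on set iteration order.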

2.2 HostManager

HostManager is the core of host discovery; it maintains the current hosts and their state. Its main member variables are:

  • self._current_hosts: information about the current hosts, including their slots, the host assignment order, and so on;
  • self._hosts_state: the state of each host, including whether it is blacklisted and related events;
  • self._discovery: essentially a wrapper around the discovery script, used to run the script dynamically and obtain host information.

class HostManager(object):
    def __init__(self, discovery):
        self._current_hosts = DiscoveredHosts(host_slots={}, host_assignment_order=[])
        self._hosts_state = defaultdict(HostState)
        self._discovery = discovery

    def update_available_hosts(self):
        # TODO(travis): also check for hosts removed from the blacklist in the future
        # Check for updates: report whether hosts (or slots) were added or removed
        def check_update(cur_host_slots, prev_host_slots):
            res = HostUpdateResult.no_update

            for prev_h in prev_host_slots:
                if prev_h not in cur_host_slots:
                    # prev_h is a removed host
                    res |= HostUpdateResult.removed

            for h in cur_host_slots:
                if h not in prev_host_slots:
                    # h is an added host
                    res |= HostUpdateResult.added
                elif cur_host_slots[h] > prev_host_slots[h]:
                    # h has more slots added
                    res |= HostUpdateResult.added
                elif cur_host_slots[h] < prev_host_slots[h]:
                    # h has removed some slots
                    res |= HostUpdateResult.removed
            return res

        prev_host_slots = self._current_hosts.host_slots
        prev_host_assignment_order = self._current_hosts.host_assignment_order
        host_slots = self._discovery.find_available_hosts_and_slots()

        if prev_host_slots != host_slots:  # the hosts have changed
            # Keep only hosts that are not blacklisted
            available_hosts = set([host for host in host_slots.keys() if not self._hosts_state[host].is_blacklisted()])
            # Determine the host assignment order
            host_assignment_order = HostManager.order_available_hosts(available_hosts, prev_host_assignment_order)
            self._current_hosts = DiscoveredHosts(host_slots=host_slots,
                                                  host_assignment_order=host_assignment_order)
            # Report what kind of update occurred
            return check_update(self._current_hosts.host_slots, prev_host_slots)
        else:  # no change, no update
            return HostUpdateResult.no_update
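The nested check_update reduces two snapshots of host slots to a bit flag. Here is a self-contained sketch of the same comparison; UpdateResult is a hypothetical stand-in for HostUpdateResult, assumed here to be an IntFlag so that `|=` can combine `added` and `removed` into one value.

```python
from enum import IntFlag


class UpdateResult(IntFlag):
    """Hypothetical stand-in for HostUpdateResult, assumed to be a bit flag."""
    no_update = 0
    removed = 1
    added = 2


def check_update(cur_host_slots, prev_host_slots):
    res = UpdateResult.no_update
    for prev_h in prev_host_slots:
        if prev_h not in cur_host_slots:
            res |= UpdateResult.removed  # a host disappeared
    for h, slots in cur_host_slots.items():
        if h not in prev_host_slots or slots > prev_host_slots[h]:
            res |= UpdateResult.added  # a new host, or more slots on a known host
        elif slots < prev_host_slots[h]:
            res |= UpdateResult.removed  # fewer slots on a known host
    return res


prev = {'host-a': 4, 'host-b': 4}
cur = {'host-a': 2, 'host-c': 4}
res = check_update(cur, prev)
print(UpdateResult.added in res, UpdateResult.removed in res)  # True True
```

A single discovery round can therefore report both kinds of change at once, which is why the result is a flag rather than a plain enum value.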

The core HostManager logic is the update_available_hosts method, which is used to find available hosts.

2.2.1 order_available_hosts

order_available_hosts ensures that the oldest hosts get the lowest ranks (the oldest one gets rank 0). The oldest host is the most likely to hold the original model and training state, and that information must be broadcast to all workers before the next iteration begins.

    @staticmethod
    def order_available_hosts(available_hosts, prev_host_assignment_order):
        # We need to ensure this list preserves relative order to ensure the oldest hosts are assigned lower ranks.
        host_assignment_order = [host for host in prev_host_assignment_order if host in available_hosts]
        known_hosts = set(host_assignment_order)
        for host in available_hosts:
            if host not in known_hosts:
                host_assignment_order.append(host)
        return host_assignment_order
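Running a standalone copy of order_available_hosts on a small example shows why surviving hosts keep their low ranks: they stay in their previous relative order, and new hosts are appended at the end. The host names are made up for illustration.

```python
def order_available_hosts(available_hosts, prev_host_assignment_order):
    # Keep the surviving hosts in their previous relative order...
    host_assignment_order = [h for h in prev_host_assignment_order if h in available_hosts]
    known_hosts = set(host_assignment_order)
    # ...then append newly discovered hosts at the end
    for host in available_hosts:
        if host not in known_hosts:
            host_assignment_order.append(host)
    return host_assignment_order


prev_order = ['host-a', 'host-b', 'host-c']
# host-b has left and host-d has joined
new_order = order_available_hosts({'host-a', 'host-c', 'host-d'}, prev_order)
print(new_order)  # ['host-a', 'host-c', 'host-d']
```

When several hosts join at the same time, they are appended in set iteration order, so their relative order among themselves is not deterministic; only the survivors' order is guaranteed.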

2.3 configuration

Let’s see how the script is configured into HostManager.

First, the discovery script is configured in _run_elastic.

def _run_elastic(args):
    # construct host discovery component
    if args.host_discovery_script:
        # If a discovery script is given, wrap it in HostDiscoveryScript
        discover_hosts = discovery.HostDiscoveryScript(args.host_discovery_script, args.slots)
    elif args.hosts:  # If a fixed host list is given, use FixedHosts
        _, available_host_slots = hosts.parse_hosts_and_slots(args.hosts)
        if len(available_host_slots) < 2:
            raise ValueError('Cannot run in fault tolerance mode with fewer than 2 hosts.')
        discover_hosts = discovery.FixedHosts(available_host_slots)
    else:  # Neither is given: raise an exception
        raise ValueError('One of --host-discovery-script, --hosts, or --hostnames must be provided')

    # Pass the discovery component into the settings
    settings = elastic_settings.ElasticSettings(discovery=discover_hosts,
                                                .....)

    env = os.environ.copy()
    config_parser.set_env_from_args(env, args)
    gloo_run_elastic(settings, env, args.command)

Second, the discovery component is stored in ElasticSettings.

class ElasticSettings(BaseSettings):
    def __init__(self, discovery, min_np, max_np, elastic_timeout, reset_limit, **kwargs):
        self.discovery = discovery

On start, it is passed to ElasticDriver.

def start(self):
    """Starts the Horovod driver and services."""
    self.rendezvous = RendezvousServer(self.settings.verbose)
    self.driver = ElasticDriver(
        rendezvous=self.rendezvous,
        discovery=self.settings.discovery,  # the discovery component is passed in here
        min_np=self.settings.min_np,
        max_np=self.settings.max_np,
        timeout=self.settings.elastic_timeout,
        reset_limit=self.settings.reset_limit,
        verbose=self.settings.verbose)

Finally, ElasticDriver hands the discovery component to HostManager when constructing it.

class ElasticDriver(object):
    def __init__(self, rendezvous, discovery, min_np, max_np, timeout=None, reset_limit=None, verbose=0):
        self._rendezvous = rendezvous
        self._host_manager = HostManager(discovery)  # hand the discovery component to HostManager

0x03 How to Call

3.1 Infinite loop thread

The calling logic for HostManager is in the ElasticDriver class.

When ElasticDriver is initialized, a background thread _discovery_thread is generated.

self._discovery_thread = threading.Thread(target=self._discover_hosts)

3.1.1 Periodic discovery

In _discovery_thread, _discover_hosts is run.

ElasticDriver._discover_hosts does the following:

  • First, it calls self._host_manager.update_available_hosts() to get the latest host status;
  • Second, if the hosts have changed, it calls _notify_workers_host_changes and _wait_hosts_cond.notify_all to tell everyone that the hosts have changed.

def _discover_hosts(self):
    first_update = True
    while not self._shutdown.is_set():
        self._wait_hosts_cond.acquire()
        try:
            # Get the latest host status
            update_res = self._host_manager.update_available_hosts()
            if update_res != HostUpdateResult.no_update:
                self._notify_workers_host_changes(self._host_manager.current_hosts, update_res)
                self._wait_hosts_cond.notify_all()  # Notify everyone of host changes
        except RuntimeError as e:
            if first_update:
                # Misconfiguration, fail the job immediately
                self._shutdown.set()
                self._wait_hosts_cond.notify_all()  # Wake up everyone waiting on the condition
                raise
            # Transient error, retry until timeout
            logging.warning(str(e))
        finally:
            self._wait_hosts_cond.release()
        first_update = False
        self._shutdown.wait(DISCOVER_HOSTS_FREQUENCY_SECS)
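The shape of this loop (a background thread, a condition variable that is notified on change, and an Event used as an interruptible sleep) can be sketched without Horovod. PollingDiscovery and fake_discover are hypothetical names; the sketch leaves out the first-update error handling shown above.

```python
import threading


class PollingDiscovery:
    """Hypothetical class sketching the shape of ElasticDriver._discover_hosts."""

    def __init__(self, discover_fn, interval_secs):
        self._discover_fn = discover_fn  # returns True when the hosts changed
        self._interval = interval_secs
        self._shutdown = threading.Event()
        self._cond = threading.Condition()
        self.changes = 0
        self._thread = threading.Thread(target=self._loop, daemon=True)

    def start(self):
        self._thread.start()

    def stop(self):
        self._shutdown.set()  # also interrupts the wait between polls
        self._thread.join()

    def _loop(self):
        while not self._shutdown.is_set():
            with self._cond:  # like _wait_hosts_cond.acquire()/release()
                if self._discover_fn():
                    self.changes += 1
                    self._cond.notify_all()  # wake threads waiting for host changes
            # Event.wait() is an interruptible sleep: it returns early on shutdown
            self._shutdown.wait(self._interval)


calls = []
done = threading.Event()


def fake_discover():
    calls.append(1)
    if len(calls) >= 4:
        done.set()
    return len(calls) in (2, 4)  # pretend the hosts changed on polls 2 and 4


poller = PollingDiscovery(fake_discover, interval_secs=0.01)
poller.start()
done.wait(timeout=5)
poller.stop()
print(poller.changes)  # 2
```

Using `Event.wait(timeout)` instead of `time.sleep` lets shutdown take effect immediately rather than after a full polling interval.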

The logic is as follows: A thread loop runs periodically:

 <--------------------^
+                     |
|       thread loop   |
|                     |
|    +----------------+----------------------+
|    |  ElasticDriver._discovery_thread      |
|    |                                       |
|    |                                       |
|    |                                       |
|    |                                       |
|    |   HostManager.update_available_hosts  |
|    |                                       |
|    +----------------+----------------------+
|                     ^
|                     |
v                     |
+-------------------->+


3.1.2 Notification of changes

If a host change is detected, self._notify_workers_host_changes is called to send the notification.

In other words, when the Driver's periodic thread finds, through the node discovery script, that a node has been added or removed, it calls _notify_workers_host_changes to notify the workers.

The logic is as follows:

[Diagram: in the ElasticDriver._discovery_thread loop, HostManager.update_available_hosts is called; if update_res != no_update, _notify_workers_host_changes is called; either way the thread loops again.]

The details are as follows:

def _notify_workers_host_changes(self, current_hosts, update_res):
    next_host_assignments = {}
    if current_hosts.count_available_slots() >= self._min_np:
        # Assignments are required to be stable via contract
        next_host_assignments, _ = self._get_host_assignments(current_hosts)

    if next_host_assignments == self.host_assignments:
        # Skip notifying workers when host changes would not result in changes of host assignments
        return

    coordinator_slot_info = self.get_coordinator_info()
    # get WorkerNotificationClient
    coordinator_client = self.get_worker_client(coordinator_slot_info)

    timestamp = _epoch_time_s()
    coordinator_client.notify_hosts_updated(timestamp, update_res) # notice
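The early return matters because a host change does not always change the rank assignments. _get_host_assignments is not shown in this article, so the sketch below uses a hypothetical, simplified rule (fill slots host by host in host_assignment_order, up to max_np) purely to illustrate the skip; get_host_assignments and should_notify are illustrative names, not Horovod APIs.

```python
def get_host_assignments(host_slots, host_assignment_order, max_np):
    """Hypothetical, simplified assignment rule: fill slots host by host, up to max_np."""
    assignments = {}
    rank = 0
    for host in host_assignment_order:
        for local_rank in range(host_slots[host]):
            if rank >= max_np:
                return assignments
            assignments.setdefault(host, []).append((rank, local_rank))
            rank += 1
    return assignments


def should_notify(host_slots, order, min_np, max_np, current_assignments):
    """Mirrors the early return in _notify_workers_host_changes."""
    next_assignments = {}
    if sum(host_slots[h] for h in order) >= min_np:
        next_assignments = get_host_assignments(host_slots, order, max_np)
    return next_assignments != current_assignments


current = get_host_assignments({'host-a': 2, 'host-b': 2}, ['host-a', 'host-b'], max_np=4)
# host-c joins, but max_np=4 is already reached, so nobody's rank would change
grown = {'host-a': 2, 'host-b': 2, 'host-c': 2}
print(should_notify(grown, ['host-a', 'host-b', 'host-c'], 2, 4, current))  # False
# host-b leaves: the assignments shrink, so the workers must be told
print(should_notify({'host-a': 2}, ['host-a'], 2, 4, current))  # True
```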

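get_worker_client looks up the WorkerNotificationClient for a given slot, and _notify_workers_host_changes uses that client to send the notification. Note that the code above only fetches the client for the coordinator slot (get_coordinator_info); the other workers learn about the change later, when check_host_updates broadcasts rank 0's view through _bcast_object. Let's look at WorkerNotificationClient next.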

def get_worker_client(self, slot_info):
    return self._worker_clients.get((slot_info.hostname, slot_info.local_rank))

The details are as follows:

[Diagram: in the ElasticDriver._discovery_thread loop, HostManager.update_available_hosts is called; if update_res != no_update, _notify_workers_host_changes calls WorkerNotificationClient.notify_hosts_updated; the thread then loops again.]


3.2 How To Notify

Notification works by using WorkerNotificationClient to send a HostsUpdatedRequest.

3.2.1 WorkerNotificationClient

WorkerNotificationService inherits from network.BasicService, and WorkerNotificationClient (a network.BasicClient) is its client-side interface: the driver uses WorkerNotificationClient to send a HostsUpdatedRequest to WorkerNotificationService.

class WorkerNotificationClient(network.BasicClient):
    def __init__(self, addresses, key, verbose, match_intf=False):
        super(WorkerNotificationClient, self).__init__(WorkerNotificationService.NAME,
                                                       addresses,
                                                       key,
                                                       verbose,
                                                       match_intf=match_intf)

    def notify_hosts_updated(self, timestamp, update_res):
        self._send(HostsUpdatedRequest(timestamp, update_res))

3.2.2 WorkerNotificationService

WorkerNotificationService responds to HostsUpdatedRequest.

class WorkerNotificationService(network.BasicService):
    NAME = 'worker notification service'

    def __init__(self, key, nic, manager):
        super(WorkerNotificationService, self).__init__(WorkerNotificationService.NAME,
                                                        key,
                                                        nic)
        self._manager = manager

    def _handle(self, req, client_address):
        if isinstance(req, HostsUpdatedRequest):
            self._manager.handle_hosts_updated(req.timestamp, req.res)
            return network.AckResponse()

        return super(WorkerNotificationService, self)._handle(req, client_address)

3.2.3 WorkerNotificationManager

handle_hosts_updated notifies the listeners registered with WorkerNotificationManager (that is, the State objects in the user's code).

WorkerNotificationManager is instantiated in horovod/common/elastic.py and runs on each host:

notification_manager = WorkerNotificationManager()

It is defined as follows:

class WorkerNotificationManager(object):
    def __init__(self):
        self._lock = threading.Lock()
        self._service = None
        self._listeners = set()

    def init(self, rendezvous_addr=None, rendezvous_port=None,
             nic=None, hostname=None, local_rank=None):
        with self._lock:
            if self._service:
                return

            rendezvous_addr = rendezvous_addr or os.environ.get(HOROVOD_GLOO_RENDEZVOUS_ADDR)
            if not rendezvous_addr:
                return

            rendezvous_port = rendezvous_port if rendezvous_port is not None else \
                int(os.environ.get(HOROVOD_GLOO_RENDEZVOUS_PORT))
            nic = nic or os.environ.get(HOROVOD_GLOO_IFACE)
            hostname = hostname or os.environ.get(HOROVOD_HOSTNAME)
            local_rank = local_rank if local_rank is not None else \
                int(os.environ.get(HOROVOD_LOCAL_RANK))

            secret_key = secret.make_secret_key()
            self._service = WorkerNotificationService(secret_key, nic, self)

            value = (self._service.addresses(), secret_key)
            put_data_into_kvstore(rendezvous_addr,
                                  rendezvous_port,
                                  PUT_WORKER_ADDRESSES,
                                  self._create_id(hostname, local_rank),
                                  value)

    def register_listener(self, listener):
        self._listeners.add(listener)

    def remove_listener(self, listener):
        self._listeners.remove(listener)

    def handle_hosts_updated(self, timestamp, update_res):
        for listener in self._listeners:
            listener.on_hosts_updated(timestamp, update_res)
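Stripped of networking and the rendezvous key-value store, the listener part of WorkerNotificationManager is an observer pattern. The sketch below uses hypothetical class names. Unlike the original, it also takes the lock in register_listener and remove_listener and iterates over a snapshot, so a listener can be removed concurrently without breaking the notification loop.

```python
import queue
import threading


class NotificationManager:
    """Hypothetical, network-free stand-in for WorkerNotificationManager's listener part."""

    def __init__(self):
        self._lock = threading.Lock()
        self._listeners = set()

    def register_listener(self, listener):
        with self._lock:
            self._listeners.add(listener)

    def remove_listener(self, listener):
        with self._lock:
            self._listeners.remove(listener)

    def handle_hosts_updated(self, timestamp, update_res):
        with self._lock:
            listeners = list(self._listeners)  # iterate over a snapshot
        for listener in listeners:
            listener.on_hosts_updated(timestamp, update_res)


class QueueingListener:
    """Mimics State.on_hosts_updated: only enqueue the message, process it later."""

    def __init__(self):
        self.messages = queue.Queue()

    def on_hosts_updated(self, timestamp, update_res):
        self.messages.put((timestamp, update_res))


manager = NotificationManager()
a, b = QueueingListener(), QueueingListener()
manager.register_listener(a)
manager.register_listener(b)
manager.handle_hosts_updated(100, 2)
print(a.messages.get_nowait(), b.messages.qsize())  # (100, 2) 1
```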

3.2.4 Notifying the State

Let’s go through the following process:

  • When the Driver's periodic thread finds, through the node discovery script, that a node has been added or removed, it sends a notification to the workers.
  • Each worker has its own State, which is stored in WorkerNotificationManager._listeners.
  • Each State records the host change by putting a "host changed" message into its _host_messages queue.
  • Because this message does not need to be processed immediately, it is simply queued in the State.

The logic is as follows:

[Diagram: ElasticDriver._discovery_thread -> HostManager.update_available_hosts -> (if update_res != no_update) _notify_workers_host_changes -> WorkerNotificationClient.notify_hosts_updated -> HostsUpdatedRequest -> WorkerNotificationService -> handle_hosts_updated -> WorkerNotificationManager -> on_hosts_updated on each State (State 1 ... State n) in _listeners.]


3.2.5 When to Process

When is this notification processed? The next time state.commit() or the lighter state.check_host_updates() is called, check_host_updates reads the messages from _host_messages and accumulates the updates.

As the comment in check_host_updates explains, the state is synchronized across workers so that they all raise HostsUpdatedInterrupt at the same time. The synchronization uses _bcast_object (which in turn calls MPI internally).

We’ll talk about check_host_updates in the introduction to State.

0x04 State abstraction

Horovod implements a State object, which adds one more layer of abstraction over the model being trained.

Each Worker has a State object.

  • Horovod places all variables that need to be synchronized between workers (such as model parameters, optimizer state, the current epoch and batch progress) into an hvd.elastic.State object.

  • The purpose of the State object is to save the training state periodically and to restore the training state from it when needed. This way, when some workers fail unexpectedly, the job is not left unable to recover because its state has been corrupted.

  • For example, suppose a worker dies suddenly while parameters are being updated. At that moment part of the gradient update may have been only half applied. This state is irreversible, so training cannot continue from it, and the corrupted parameters cannot be used to resume training.

4.1 State

State tracks the in-memory state across different workers.

Key variables & methods are:

  • on_reset: called when the state needs to be reset (after Horovod restarts);
  • on_hosts_updated: called when hosts change; it puts a message into the _host_messages queue;
  • commit: called periodically by the user; it saves the state in memory and checks for host changes.
    • When an exception occurs, a HorovodInternalError is raised. When hvd.elastic.run catches this exception, it restores all state from the latest commit.
    • Committing state is expensive (for example, a model with many parameters takes a long time to copy), so you need to balance the processing time per batch against how far training must roll back if something goes wrong. For example, committing once every 10 batches cuts the copying cost by a factor of 10, but an error then rolls training back by up to 10 batches.
  • check_host_updates: reads messages from _host_messages and accumulates the updates. As the comment in the method explains, the state is synchronized across workers so that they all raise the exception at the same time. The synchronization uses _bcast_object (which in turn calls MPI internally);
    • If the node discovery script can predict that a node will be removed or added, Elastic Horovod can avoid a rollback. When the Driver's periodic thread detects through the discovery script that a node is being added or removed, it notifies the workers, so the next time state.commit() or the lighter state.check_host_updates() is called, a HostsUpdatedInterrupt is raised. This exception is handled like HorovodInternalError, except that the parameters and other state are not restored from the latest commit; the current live values are kept.
    • In general, if your hardware is reliable and stable and your orchestration system gives adequate warning before task nodes are removed, you can call state.commit() at a low frequency and call only the relatively cheap state.check_host_updates() at the end of each batch to check for node changes.
  • _reset_callbacks: users can register callbacks with the hvd.elastic.State object to respond to changes in worker membership.
    • For example, a callback can handle the following cases:
      1. When the number of workers changes, adjust the learning rate according to the new world size.
      2. Repartition the dataset.
    • These callbacks are invoked after Horovod restarts and before the state is synchronized between nodes.
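The commit/check trade-off described above can be sketched as a training loop. CountingState is a hypothetical stand-in that only counts calls, so the sketch runs without Horovod; in real code the object would be an hvd.elastic.State subclass whose commit() saves the state and then calls check_host_updates().

```python
class CountingState:
    """Hypothetical stand-in for an hvd.elastic.State that only counts calls."""

    def __init__(self):
        self.commits = 0
        self.checks = 0

    def commit(self):
        # In Horovod, commit() = save() + check_host_updates()
        self.commits += 1

    def check_host_updates(self):
        self.checks += 1


def train(state, num_batches, commit_every):
    for batch in range(num_batches):
        # ... forward pass, backward pass and optimizer step would go here ...
        if (batch + 1) % commit_every == 0:
            state.commit()  # expensive: snapshot the state and check hosts
        else:
            state.check_host_updates()  # cheap: only check for host changes


state = CountingState()
train(state, num_batches=100, commit_every=10)
print(state.commits, state.checks)  # 10 90
```

Every batch still reaches a host check, so a planned host change is noticed within one batch, while the commit interval bounds how much work a HorovodInternalError can roll back.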

It is defined as follows:

class State(object):
    """State representation used for tracking in memory state across workers.

    Args:
        bcast_object: Function used to broadcast a variable from rank 0 to the other workers.
        get_rank: Function that returns the current rank of this worker.
    """
    def __init__(self, bcast_object, get_rank):
        self._bcast_object = bcast_object
        self._rank = get_rank
        self._host_messages = queue.Queue()
        self._last_updated_timestamp = 0
        self._reset_callbacks = []

    def on_reset(self):
        self._host_messages = queue.Queue()
        self.reset()
        for callback in self._reset_callbacks:
            callback()

    def on_hosts_updated(self, timestamp, update_res):
        self._host_messages.put((timestamp, update_res))

    def commit(self):
        self.save()
        self.check_host_updates()

    def check_host_updates(self):
        """Checks that a notification has been sent indicating that hosts can be added or will be removed.

        Raises a `HostsUpdatedInterrupt` if such a notification has been received.
        """
        # Iterate through the update messages sent from the server. If the update timestamp
        # is greater than the last update timestamp, then trigger a HostsUpdatedException.
        # (Only updates newer than the last processed timestamp are accumulated.)
        last_updated_timestamp = prev_timestamp = self._last_updated_timestamp
        all_update = HostUpdateResult.no_update
        while not self._host_messages.empty():
            timestamp, update = self._host_messages.get()
            if timestamp > last_updated_timestamp:
                last_updated_timestamp = timestamp
                all_update |= update

        # In order to ensure all workers raise the exception at the same time, we need to sync
        # the updated state across all the workers.
        # TODO(travis): this should be a max allreduce to account for changes in rank 0
        # reads messages from '_host_messages', accumulates updates, and, as described in the method comment, synchronizes state between each worker so that they all throw exceptions at the same time. Using '_bcast_object' synchronously (and then calling MPI internally)
        prev_timestamp, self._last_updated_timestamp, all_update = \
            self._bcast_object((prev_timestamp, last_updated_timestamp, all_update))

        # At this point, updated state is globally consistent across all ranks.
        if self._last_updated_timestamp > prev_timestamp:
            raise HostsUpdatedInterrupt(all_update == HostUpdateResult.removed)
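To see the timestamp filtering and the broadcast step in action, here is a single-process sketch of check_host_updates. MiniState, UpdateResult and HostsChanged are hypothetical stand-ins; the default bcast_object is the identity function, which is what a broadcast from rank 0 amounts to when there is only one worker.

```python
import queue
from enum import IntFlag


class UpdateResult(IntFlag):
    """Hypothetical stand-in for HostUpdateResult."""
    no_update = 0
    removed = 1
    added = 2


class HostsChanged(Exception):
    """Hypothetical stand-in for HostsUpdatedInterrupt."""
    def __init__(self, skip_sync):
        super().__init__(skip_sync)
        self.skip_sync = skip_sync


class MiniState:
    def __init__(self, bcast_object=lambda obj: obj):
        # With a single worker, "broadcast from rank 0" is just the identity
        self._bcast_object = bcast_object
        self._host_messages = queue.Queue()
        self._last_updated_timestamp = 0

    def on_hosts_updated(self, timestamp, update_res):
        # Called by the notification manager: only enqueue, handle later
        self._host_messages.put((timestamp, update_res))

    def check_host_updates(self):
        last = prev = self._last_updated_timestamp
        all_update = UpdateResult.no_update
        while not self._host_messages.empty():
            timestamp, update = self._host_messages.get()
            if timestamp > last:  # ignore notifications that are not newer
                last = timestamp
                all_update |= update
        # Every worker adopts rank 0's view, so they all raise (or not) together
        prev, self._last_updated_timestamp, all_update = \
            self._bcast_object((prev, last, all_update))
        if self._last_updated_timestamp > prev:
            raise HostsChanged(all_update == UpdateResult.removed)


state = MiniState()
state.on_hosts_updated(10, UpdateResult.added)
state.on_hosts_updated(12, UpdateResult.removed)
try:
    state.check_host_updates()
except HostsChanged as e:
    print(e.skip_sync)  # False, because the accumulated result is added | removed
state.check_host_updates()  # the queue is empty now, so nothing is raised
```

As in the original, the interrupt carries `skip_sync=True` only when the accumulated result is exactly `removed`.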

Therefore, once commit is added, the logic looks like this:

[Diagram: ElasticDriver._discovery_thread -> HostManager.update_available_hosts -> (if update_res != no_update) _notify_workers_host_changes -> WorkerNotificationClient.notify_hosts_updated -> HostsUpdatedRequest -> WorkerNotificationService -> handle_hosts_updated -> WorkerNotificationManager -> on_hosts_updated on each State (State 1 ... State n) in _listeners. Separately, the user's Python xxx.py calls commit / check_host_updates on its State, and the States synchronize via _bcast_object.]


Next, let's look at the derived classes, level by level.

4.2 ObjectState

ObjectState holds state made up of simple Python objects.

class ObjectState(State):
    """State for simple Python objects.

    Every object is specified as a keyword argument, and will be assigned as an attribute.

    Args:
        bcast_object: Horovod broadcast object function used to sync state dictionary.
        get_rank: Horovod rank function used to identify if this process is the coordinator.
        kwargs: Properties to sync, will be exposed as attributes of the object.
    """
    def __init__(self, bcast_object, get_rank, **kwargs):
        self._bcast_object = bcast_object
        self._saved_state = kwargs
        self._set_attrs()
        super(ObjectState, self).__init__(bcast_object=bcast_object, get_rank=get_rank)

    def save(self):
        new_state = {}
        for attr in self._saved_state.keys():
            new_state[attr] = getattr(self, attr)
        self._saved_state = new_state

    def restore(self):
        self._set_attrs()

    def sync(self):
        if self._saved_state:
            self._saved_state = self._bcast_object(self._saved_state)
            self._set_attrs()

    def _set_attrs(self):
        for attr, value in self._saved_state.items():
            setattr(self, attr, value)
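The save/restore semantics of ObjectState can be exercised with a small re-implementation. MiniObjectState is a hypothetical class that copies the logic above but drops the State base class and uses an identity broadcast, so it runs in a single process.

```python
class MiniObjectState:
    """Hypothetical re-implementation of ObjectState's save/restore/sync logic."""

    def __init__(self, bcast_object=lambda obj: obj, **kwargs):
        self._bcast_object = bcast_object
        self._saved_state = kwargs
        self._set_attrs()

    def save(self):
        # Snapshot the current attribute values (stores references, not copies)
        self._saved_state = {attr: getattr(self, attr) for attr in self._saved_state}

    def restore(self):
        self._set_attrs()  # roll attributes back to the last snapshot

    def sync(self):
        if self._saved_state:
            self._saved_state = self._bcast_object(self._saved_state)
            self._set_attrs()

    def _set_attrs(self):
        for attr, value in self._saved_state.items():
            setattr(self, attr, value)


state = MiniObjectState(epoch=0, batch=0)
state.epoch, state.batch = 3, 50
state.save()                      # commit point
state.epoch, state.batch = 3, 57  # progress made after the commit...
state.restore()                   # ...is rolled back on failure
print(state.epoch, state.batch)   # 3 50
```

Because save() only stores references, a mutable attribute such as a list that is modified in place after a commit is not protected by restore(); rebinding the attribute to a new object is.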

4.3 TensorFlowKerasState

By default, Horovod already provides standard State save and restore implementations for TensorFlow, Keras and PyTorch; if you need customization for a particular scenario, you can subclass hvd.elastic.State.

TensorFlowKerasState is the State abstraction for a TensorFlow Keras model and optimizer.

The initializer sets up the related variables, such as the broadcast functions.

class TensorFlowKerasState(ObjectState):

    def __init__(self, model, optimizer=None, backend=None, **kwargs):
        self.model = model
        if not _model_built(model):
            raise ValueError('Model must be built first. Run `model.build(input_shape)`.')

        self.optimizer = optimizer or model.optimizer
        self.backend = backend
        self._save_model()

        if not backend or _executing_eagerly():
            self._bcast_model = lambda: _broadcast_model(self.model, self.optimizer, backend=self.backend)
            bcast_object = broadcast_object
        else:
            # For TensorFlow v1, we need to reuse the broadcast op to prevent incrementing the uids
            bcast_op = broadcast_variables(_global_variables(), root_rank=0)
            self._bcast_model = lambda: self.backend.get_session().run(bcast_op)
            bcast_object = broadcast_object_fn(session=self.backend.get_session())

        super(TensorFlowKerasState, self).__init__(bcast_object=bcast_object,
                                                   get_rank=rank,
                                                   **kwargs)

The remaining methods essentially save, restore and synchronize the state.

def save(self):
    self._save_model()
    super(TensorFlowKerasState, self).save()

def restore(self):
    self._load_model()
    super(TensorFlowKerasState, self).restore()

def sync(self):
    self._bcast_model()
    self._save_model()
    super(TensorFlowKerasState, self).sync()

def _save_model(self):
    if _executing_eagerly():
        self._saved_model_state = [tf.identity(var) for var in self.model.variables]
        self._saved_optimizer_state = [tf.identity(var) for var in self.optimizer.variables()]
    else:
        self._saved_model_state = self.model.get_weights()
        self._saved_optimizer_state = self.optimizer.get_weights()

def _load_model(self):
    if _executing_eagerly():
        for var, saved_var in zip(self.model.variables, self._saved_model_state):
            var.assign(saved_var)
        for var, saved_var in zip(self.optimizer.variables(), self._saved_optimizer_state):
            var.assign(saved_var)
    else:
        self.model.set_weights(self._saved_model_state)
        self.optimizer.set_weights(self._saved_optimizer_state)

4.4 Restore

As you can see, restore reloads the model from the in-memory snapshot.

def restore(self):
    self._load_model()
    super(TensorFlowKerasState, self).restore()

So the question is: when is restore called?

restore is called when Horovod catches a HorovodInternalError:

def run_fn(func, reset):
    @functools.wraps(func)
    def wrapper(state, *args, **kwargs):
        notification_manager.init()
        notification_manager.register_listener(state)
        skip_sync = False

        try:
            while True:
                if not skip_sync:
                    state.sync()

                try:
                    return func(state, *args, **kwargs)
                except HorovodInternalError:
                    state.restore() # called here
                    skip_sync = False
                except HostsUpdatedInterrupt as e:
                    skip_sync = e.skip_sync

                reset()
                state.on_reset()
        finally:
            notification_manager.remove_listener(state)
    return wrapper
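The control flow of the wrapper can be traced with stand-ins. This is a sketch under stated assumptions: InternalError, HostsUpdated, run_elastic and RecordingState are hypothetical names, and notification-manager registration is omitted. It shows which State methods run, in which order, when one attempt fails and the next sees a host change.

```python
class InternalError(Exception):
    """Hypothetical stand-in for HorovodInternalError."""


class HostsUpdated(Exception):
    """Hypothetical stand-in for HostsUpdatedInterrupt."""
    def __init__(self, skip_sync=False):
        super().__init__(skip_sync)
        self.skip_sync = skip_sync


def run_elastic(func, state, reset):
    """Sketch of the retry loop in run_fn's wrapper, without notification setup."""
    skip_sync = False
    while True:
        if not skip_sync:
            state.sync()  # bring every worker to the same state
        try:
            return func(state)
        except InternalError:
            state.restore()  # roll back to the last commit
            skip_sync = False
        except HostsUpdated as e:
            skip_sync = e.skip_sync  # keep live state; maybe skip the next sync
        reset()  # re-form the worker group
        state.on_reset()  # run the user's reset callbacks


class RecordingState:
    """Records which State methods the loop calls, in order."""
    def __init__(self):
        self.events = []

    def sync(self):
        self.events.append('sync')

    def restore(self):
        self.events.append('restore')

    def on_reset(self):
        self.events.append('on_reset')


# First attempt fails, second sees a host change, third succeeds
outcomes = iter([InternalError(), HostsUpdated(skip_sync=True), None])


def train(state):
    outcome = next(outcomes)
    if outcome is not None:
        raise outcome
    return 'done'


st = RecordingState()
print(run_elastic(train, st, reset=lambda: st.events.append('reset')))  # done
print(st.events)
# ['sync', 'restore', 'reset', 'on_reset', 'sync', 'reset', 'on_reset']
```

Note that the third attempt starts without a sync because the host-change interrupt asked to skip it, and restore runs only after the internal error.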

0x05 Summary

The simplified overall logic is shown below. The key design points are:

  • How is a node change detected promptly? Horovod polls periodically.
  • How is each worker notified once a node change is discovered? Horovod builds a notification mechanism: each worker registers itself with WorkerNotificationManager, and when nodes change, WorkerNotificationManager notifies the workers one by one.
  • What does a worker do when notified? Horovod wraps the worker's state in State classes specific to each deep learning framework; when notified, the corresponding State callbacks run, either to synchronize the state or to do other processing.
[Diagram: the ElasticDriver._discovery_thread loop runs HostManager.update_available_hosts; on a change, _notify_workers_host_changes -> WorkerNotificationClient -> HostsUpdatedRequest -> WorkerNotificationService -> handle_hosts_updated -> WorkerNotificationManager -> on_hosts_updated -> State.]


This concludes the introduction to node discovery. This article only used WorkerNotificationService to deliver notifications without examining it further, so the next article covers the internal broadcast and notification mechanisms.

0xEE Personal information

Thoughts on life and technology

Wechat public account: Rosie’s Thinking

If you would like to receive updates on my articles or see my technical recommendations, please follow.