0x00 Abstract

Horovod is an easy-to-use, high-performance distributed training framework released by Uber in 2017, which has been widely used in the industry.

This series takes you through the source code of Horovod. This group of articles explains how Horovod runs on Spark; this ninth article covers how Horovod is started on Spark.

Other articles in this series are as follows:

Deep learning distributed training framework Horovod (1) — basic knowledge

Deep learning distributed training framework Horovod (2) — from the user’s perspective

Deep learning distributed training framework Horovod (3) — What is behind Horovodrun

Deep learning distributed training framework Horovod (4) — Network foundation & Driver

Deep learning distributed training framework Horovod (5) — fusion framework

Deep learning distributed training framework Horovod (6) — background architecture

Deep learning distributed training framework Horovod (6) — thread implementation

Deep learning distributed training framework Horovod (7) — DistributedOptimizer

Deep learning distributed training framework Horovod (8) — on Spark

0x01 Overall Architecture Diagram

First of all, we need to lay out the architecture so that you can follow it.

In general, the overall logic of Horovod on Spark is divided into the following stages:

  • Start SparkDriverService, start the Spark tasks with _make_spark_thread, and then Horovod waits for the start to finish;
  • Multiple threads start Spark tasks in the Spark Executors, and each task runs a SparkTaskService. Each SparkTaskService registers with the SparkDriverService in the main Horovod process and waits for the next start command;
  • Once all tasks have registered, Horovod notifies each task to proceed to the next phase;
  • Horovod calls mpi_run (again using mpirun_rsh.py) to start an orted process on each Spark Executor, forming the MPI cluster;
  • orted runs the training code on top of each Executor.

Let’s see how it works.
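
Before diving into the phases, here is a minimal usage sketch of the entry point this series analyzes. It assumes an already active Spark session (e.g., a script submitted via spark-submit); fn, args and num_proc are the essential parameters of horovod.spark.run, and the rest of the signature is omitted here:

import horovod.spark

def train(lr):
    # real training code would initialize Horovod (e.g. hvd.init()) and train a model
    return 'trained with lr={}'.format(lr)

# launches num_proc Spark tasks, runs `train` in each of them, and returns
# a list with one return value per rank
results = horovod.spark.run(train, args=(0.01,), num_proc=4)
print(results)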

0x02 Phase 1: Horovod starts

The logic of this part is as follows: start SparkDriverService, start the Spark tasks with _make_spark_thread, and then Horovod waits for the start to finish.

2.1 DriverService: SparkDriverService

SparkDriverService inherits from driver_service.BasicDriverService, so it runs a socket server internally for network interaction.

Horovod uses SparkDriverService to interact with the Spark Executors (through the SparkTaskService running within each of them), for example to gather information or to have Spark start a training job. This is an RPC mechanism.

For details about the functions of SparkDriverService, see its internal handling of the various requests, such as:

  • CodeRequest: used by SparkTaskService to request the user code;
  • TaskHostHashIndicesRequest: gets the task host addresses;
  • TaskIndexByRankRequest: obtains the task index from a rank;
  • SetLocalRankToRankRequest: sets rank information from a local rank;
  • WaitForTaskShutdownRequest: waits for shutdown.

This is similar to the HorovodRunDriverService described earlier.

The member variable _fn is the training function that will be sent back directly via CodeResponse when SparkTaskService requests code. This solves the code distribution problem.
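
To make the idea concrete, here is a toy sketch of that code-distribution pattern. Horovod serializes objects with cloudpickle (via its codec helpers); the framing below is illustrative only, not Horovod's actual wire format:

import base64
import cloudpickle

def train_fn(lr):
    return 'trained with lr={}'.format(lr)

# driver side: pickle the function plus its arguments into a transportable blob
payload = base64.b64encode(cloudpickle.dumps((train_fn, (0.01,), {})))

# task side: unpickle and run the user code received from the driver
fn, args, kwargs = cloudpickle.loads(base64.b64decode(payload))
print(fn(*args, **kwargs))  # -> trained with lr=0.01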

class SparkDriverService(driver_service.BasicDriverService):
    NAME = 'driver service'

    def __init__(self, initial_np, num_proc, fn, args, kwargs, key, nics):
        super(SparkDriverService, self).__init__(num_proc,
                                                 SparkDriverService.NAME,
                                                 key, nics)
        self._initial_np = initial_np
        self._fn = fn              # save the user code
        self._args = args          # user arguments
        self._kwargs = kwargs      # user keyword arguments
        self._key = key
        self._nics = nics          # NIC information
        self._ranks_to_indices = {}
        self._spark_job_failed = False
        self._lock = threading.Lock()
        self._task_shutdown = threading.Event()

    def _handle(self, req, client_address):

        if isinstance(req, TaskHostHashIndicesRequest):  # get task host addresses
            return TaskHostHashIndicesResponse(self._task_host_hash_indices[req.host_hash])

        if isinstance(req, SetLocalRankToRankRequest):  # set rank information from local rank
            self._lock.acquire()

            try:
                # get index for host and local_rank
                indices = self._task_host_hash_indices[req.host]
                index = indices[req.local_rank]

                values = list(self._ranks_to_indices.values())
                prev_pos = values.index(index) if index in values else None
                if prev_pos is not None:
                    prev_rank = list(self._ranks_to_indices.keys())[prev_pos]
                    del self._ranks_to_indices[prev_rank]

                # memorize rank's index
                self._ranks_to_indices[req.rank] = index
            finally:
                self._lock.release()
            return SetLocalRankToRankResponse(index)

        if isinstance(req, TaskIndexByRankRequest):  # get task index from rank
            self._lock.acquire()
            try:
                return TaskIndexByRankResponse(self._ranks_to_indices[req.rank])
            finally:
                self._lock.release()

        if isinstance(req, CodeRequest):  # SparkTaskService requests the user code
            return CodeResponse(self._fn, self._args, self._kwargs)

        if isinstance(req, WaitForTaskShutdownRequest):  # wait for the task to end
            self._task_shutdown.wait()
            return network.AckResponse()

        return super(SparkDriverService, self)._handle(req, client_address)

2.2 Starting spark Task: _make_spark_thread

In horovod.spark.run, _make_spark_thread creates the thread. The key code here is:

mapper = _make_mapper(driver.addresses(), settings, use_gloo, is_elastic)
result = procs.mapPartitionsWithIndex(mapper).collect()

The mapPartitionsWithIndex call causes Spark to run the mapper function on multiple Executors and collect the results.

This creates settings.num_proc Spark tasks. Each task runs mapper (_task_fn), while the outer run function waits for the results. In fact, if you needed to use an RDD, you could perhaps use foreachPartition instead, so that each node would hold one partition of the RDD in memory.

def _make_spark_thread(spark_context, spark_job_group, driver, result_queue, settings, use_gloo, is_elastic):
    """Creates `settings.num_proc` Spark tasks in a parallel thread."""

    def run_spark():
        """Creates `settings.num_proc` Spark tasks, each executing `_task_fn` and waits for them to terminate."""
        try:
            spark_context.setJobGroup(spark_job_group, "Horovod Spark Run", interruptOnCancel=True)
            procs = spark_context.range(0, numSlices=settings.max_np if settings.elastic else settings.num_proc)
            # We assume that folks caring about security will enable Spark RPC encryption,
            # thus ensuring that key that is passed here remains secret.
            mapper = _make_mapper(driver.addresses(), settings, use_gloo, is_elastic)
            # make Spark run the mapper function on multiple executors and collect the results
            result = procs.mapPartitionsWithIndex(mapper).collect()
            result_queue.put(result)
        except:
            driver.notify_spark_job_failed()
            raise

    spark_thread = in_thread(target=run_spark, daemon=False)
    return spark_thread
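
The launch trick is worth a toy illustration: sc.range(0, numSlices=n) yields an RDD with n empty partitions, and mapPartitionsWithIndex runs the mapper once per partition, so each partition becomes one "task slot". A minimal sketch, assuming a local Spark session:

from pyspark.sql import SparkSession

spark = SparkSession.builder.master('local[2]').getOrCreate()
sc = spark.sparkContext

num_proc = 2
procs = sc.range(0, numSlices=num_proc)  # num_proc empty partitions

def mapper(index, _iterator):
    # in Horovod this is where _task_fn starts a SparkTaskService
    yield 'task {} ran'.format(index)

print(procs.mapPartitionsWithIndex(mapper).collect())
# -> ['task 0 ran', 'task 1 ran']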

2.3 Waiting until the Spark Task is started

After starting the Spark tasks, the main Horovod process blocks, waiting for the tasks to start up.

# wait for all tasks to register, notify them and initiate task-to-task address registration
_notify_and_register_task_addresses(driver, settings)

That is, in run, after _make_spark_thread, the main Horovod process calls _notify_and_register_task_addresses, which calls driver.wait_for_initial_registration(settings.start_timeout) for an overall wait.

It waits for all num_proc tasks to register; only when all Spark threads are ready does the main Horovod process resume.

2.3.1 _notify_and_register_task_addresses

In the main Horovod process, _notify_and_register_task_addresses is used to wait for the Spark tasks to register, calling driver.wait_for_initial_registration(settings.start_timeout) for an overall wait.

Note that right after sending its registration request, each Spark task itself calls task.wait_for_initial_registration, waiting for Horovod to signal the start of the next phase.

_notify_and_register_task_addresses in the main Horovod process does quite a bit itself:

  • It calls driver.wait_for_initial_registration to wait for the tasks to register;
  • notify_and_register then registers each task and notifies it to start the next step.

The specific code is as follows:

def _notify_and_register_task_addresses(driver, settings, notify=True):
    # wait for num_proc tasks to register
    driver.wait_for_initial_registration(settings.start_timeout)

    def notify_and_register(index):
        # register the task and notify it to proceed to the next step
        task_client = task_service.SparkTaskClient(index,
                                                   driver.task_addresses_for_driver(index),
                                                   settings.key, settings.verbose)

        if notify:
            task_client.notify_initial_registration_complete()

        next_task_index = (index + 1) % settings.num_proc
        next_task_addresses = driver.all_task_addresses(next_task_index)
        task_to_task_addresses = task_client.get_task_addresses_for_task(next_task_index, next_task_addresses)
        driver.register_task_to_task_addresses(next_task_index, task_to_task_addresses)

    for index in driver.task_indices():
        in_thread(notify_and_register, (index,))  # handle each task in its own thread

    driver.wait_for_task_to_task_address_updates(settings.start_timeout)
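
Note the (index + 1) % settings.num_proc arithmetic above: each task probes the addresses of its successor, so the task-to-task address exchange forms a ring. A quick worked example for num_proc = 4:

num_proc = 4
for index in range(num_proc):
    next_task_index = (index + 1) % num_proc  # same expression as in notify_and_register
    print('task {} probes task {}'.format(index, next_task_index))
# task 0 probes task 1
# task 1 probes task 2
# task 2 probes task 3
# task 3 probes task 0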

For now, let's only look at the first step, "waiting for registration".

2.3.2 driver.wait_for_initial_registration

Here SparkDriverService first waits for all Spark Executors to register.

As the following code from class BasicDriverService(network.BasicService) shows, the main Horovod process only resumes once all _num_proc tasks have registered, i.e., once all Spark threads are ready.

while len(self._all_task_addresses) < self._num_proc waits for the number of entries in self._all_task_addresses to reach _num_proc.

class BasicDriverService(network.BasicService):

  def wait_for_initial_registration(self, timeout):
      self._wait_cond.acquire()
      try:
          # wait for self._all_task_addresses to reach _num_proc
          while len(self._all_task_addresses) < self._num_proc:
              self._wait_cond.wait(timeout.remaining())
              timeout.check_time_out_for('tasks to start')
      finally:
          self._wait_cond.release()
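
The wait/notify pattern above is a classic condition-variable barrier. Here is a self-contained sketch of the same idea (names modeled on, but not taken from, the Horovod source):

import threading

class ToyDriver:
    def __init__(self, num_proc):
        self._num_proc = num_proc
        self._all_task_addresses = {}
        self._wait_cond = threading.Condition()

    def register(self, index, addresses):
        # called once per task, possibly from many threads
        with self._wait_cond:
            self._all_task_addresses[index] = addresses
            self._wait_cond.notify_all()

    def wait_for_initial_registration(self, timeout_s=5.0):
        # block until every task has registered, like the Horovod code above
        with self._wait_cond:
            while len(self._all_task_addresses) < self._num_proc:
                if not self._wait_cond.wait(timeout_s):
                    raise TimeoutError('tasks to start')

driver = ToyDriver(num_proc=2)
for i in range(2):
    threading.Thread(target=driver.register, args=(i, 'host{}:1234'.format(i))).start()
driver.wait_for_initial_registration()
print('all tasks registered')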

2.4 Waiting

The wait logic deserves special explanation; see the figure below.

There are two wait_for_initial_registration calls. Think of them as two barriers:

  • Barrier 1: SparkDriverService waits for all SparkTaskServices to be ready;
  • Barrier 2: all SparkTaskServices need to start running together, so each SparkTaskService waits at barrier 2, and SparkDriverService tells them all to start together.

2.4.1 Barrier 1 in the Driver

In the run function, after _make_spark_thread, the main Horovod process calls _notify_and_register_task_addresses, which calls driver.wait_for_initial_registration(settings.start_timeout) for an overall wait.

It waits for all num_proc tasks to register; only when all Spark threads are ready does the main Horovod process resume. The key here is:

while len(self._all_task_addresses) < self._num_proc

i.e., waiting for self._all_task_addresses to reach _num_proc entries.

def wait_for_initial_registration(self, timeout):
    self._wait_cond.acquire()
    try:
        while len(self._all_task_addresses) < self._num_proc:
            self._wait_cond.wait(timeout.remaining())
            timeout.check_time_out_for('tasks to start')
    finally:
        self._wait_cond.release()

In BasicDriverService, when a registration request from a Spark Executor is received, it is processed. The most important part is:

self._all_task_addresses[req.index] = req.task_addresses

Once all Spark Executors have registered, the wait completes successfully.

2.4.2 Barrier 2 in the Task

Each Spark thread runs _task_fn, i.e., runs inside a Spark task. The general flow of a Spark task can also be seen here:

  • First it calls register_task;
  • Then it calls task.wait_for_initial_registration(settings.start_timeout);
  • Finally it calls wait_for_command_termination to wait for the end.

task.wait_for_initial_registration waits on the condition self._initial_registration_complete == True, i.e., it waits until register_task has completed.

Each Spark Executor has a SparkTaskService, so each Spark task has its own _initial_registration_complete.

The horovod.run main process calls notify_initial_registration_complete on each SparkTaskService one by one.

That is, whichever SparkTaskService gets notified has its _initial_registration_complete set; that SparkTaskService is then ready to run.

2.4.3 Overall waiting process

The overall waiting process is shown in the figure below; the numbers indicate the execution order:

  1. SparkDriverService calls driver.wait_for_initial_registration to wait for the SparkTaskServices to register; this is barrier 1;
  2. SparkTaskService 1 registers, then itself calls task.wait_for_initial_registration to wait for Horovod to start the next phase; this is barrier 2;
  3. SparkTaskService 2 registers, then itself calls task.wait_for_initial_registration to wait for Horovod to start the next phase; this is barrier 2;
  4. After discovering that all tasks have registered, the horovod.run main process ends the barrier 1 wait and notifies each SparkTaskService's _initial_registration_complete one by one. Only after step 4 completes can the two SparkTaskServices continue with steps 5 and 6;
  5. SparkTaskService 1 passes barrier 2 and continues;
  6. SparkTaskService 2 passes barrier 2 and continues.
  SparkTaskService 1           SparkTaskService 2           SparkDriverService
          +                            +                            +
          |                            |                            |  1 barrier 1:
          |                            |                            |  driver.wait_for_initial_registration
          |        2 register          |                            |
          +------------------------------------------------------> |
          |                            |        3 register          |
          |                            +--------------------------> |
          |                            |                            |
          |  barrier 2:                |  barrier 2:                |
          |  task.wait_for_initial_    |  task.wait_for_initial_    |
          |  registration              |  registration              |
          |                            |                            |
          |          4 notify_initial_registration_complete         |
          | <------------------------------------------------------+
          |                            | <-------------------------+
          |                            |                            |
          |  5 continue                |  6 continue                |
          v                            v                            v

Let’s go into the details of task startup and driver follow-up.

0x03 Phase 2: Spark Task starts

This section describes the startup process of Spark Task in detail.

The main work of this part is as follows: multiple threads start Spark tasks in the Spark Executors. Each Spark task runs the _task_fn function, and _task_fn runs a SparkTaskService. The SparkTaskService registers with the SparkDriverService in the main Horovod process and waits for the next start command.

At this point, the program (not the trainer, but the SparkTaskService) is already running inside the Spark Executor. Let’s look at how SparkTaskService is started and run in Spark Executor.

3.1 Spark startup logic: _task_fn

Horovod has Spark run _task_fn via the mapper created by _make_mapper.

def _make_mapper(driver_addresses, settings, use_gloo, is_elastic):
    key = settings.key  # abridged: the full source takes the key out of settings here

    def _mapper(index, _):
        yield _task_fn(index, driver_addresses, key, settings, use_gloo, is_elastic)

    return _mapper

The _task_fn function registers Horovod into the Spark task; that is, it starts a SparkTaskService for each Spark task (Executor).

It is important to note that these SparkTaskServices run inside the Spark Executors and interact with the SparkDriverService in Horovod over the network.

As you can see, the overall logic of _task_fn is:

  • Start the SparkTaskService;
  • Register with the Horovod driver via driver_service.SparkDriverClient.register_task;
  • Wait for the start indication of the next phase via task.wait_for_initial_registration(settings.start_timeout);
  • Once the next phase starts, call task.wait_for_command_termination() and wait for the end.

Details are as follows:

def _task_fn(index, driver_addresses, key, settings, use_gloo, is_elastic):
    settings.key = key
    hosthash = host_hash(salt='{} - {}'.format(index, time.time()) if is_elastic else None)
    os.environ['HOROVOD_HOSTNAME'] = hosthash
    # SparkTaskService contains a socket server that interacts with the driver
    task = task_service.SparkTaskService(index, settings.key, settings.nics,...)
    try:
        driver_client = driver_service.SparkDriverClient(driver_addresses, settings.key, settings.verbose)
        # register with the driver in Horovod
        driver_client.register_task(index, task.addresses(), hosthash)

        # this still runs in the Spark task (not in SparkTaskService); it just assists and waits
        if not is_elastic:
            # wait for the start indication of the next phase
            task.wait_for_initial_registration(settings.start_timeout)
            task_indices_on_this_host = driver_client.task_host_hash_indices(hosthash)
            local_rank_zero_index = task_indices_on_this_host[0]
        else:
            local_rank_zero_index = None

        if is_elastic:
            ......  # covered in a future article
        elif use_gloo or index == local_rank_zero_index:
            # Either Gloo, or the first task with MPI: let this task do the work
            task.wait_for_command_start(settings.start_timeout)
            # wait for the end
            task.wait_for_command_termination()
        else:
            # The other tasks with MPI need to wait for the first task to finish.
            first_task_addresses = driver_client.all_task_addresses(local_rank_zero_index)
            first_task_client = \
                task_service.SparkTaskClient(local_rank_zero_index,
                                             first_task_addresses, settings.key,
                                             settings.verbose)
            # wait for the first task's command to terminate
            first_task_client.wait_for_command_termination()

        return task.fn_result()
    finally:
        task.shutdown()

3.2 SparkTaskService

Again, the following code:

task = task_service.SparkTaskService(index, settings.key, settings.nics,...)

SparkTaskService is instantiated in each _task_fn. That is, each Spark Executor generates one or more SparkTaskServices that run and serve inside the Spark task.

3.2.1 SparkTaskService definition

SparkTaskService is defined as follows. Since it inherits from BasicTaskService, it eventually starts a socket server internally to interact with the SparkDriverService in Horovod:

class SparkTaskService(task_service.BasicTaskService):
    NAME_FORMAT = 'task service #%d'

    def __init__(self, index, key, nics, minimum_command_lifetime_s, verbose=0):
        # on a Spark cluster we need our train function to see the Spark worker environment
        # this includes PYTHONPATH, HADOOP_TOKEN_FILE_LOCATION and _HOROVOD_SECRET_KEY
        env = os.environ.copy()

        # we inject the secret key here
        env[secret.HOROVOD_SECRET_KEY] = codec.dumps_base64(key)

        # we also need to provide the current working dir to mpirun_exec_fn.py
        env['HOROVOD_SPARK_WORK_DIR'] = os.getcwd()

        super(SparkTaskService, self).__init__(SparkTaskService.NAME_FORMAT % index,
                                               index, key, nics, env, verbose)
        self._key = key
        self._minimum_command_lifetime_s = minimum_command_lifetime_s
        self._minimum_command_lifetime = None

3.2.2 Basic Functions

The basic functions of SparkTaskService are as follows:

  • _run_command is used to start the training job in Spark;
  • _handle handles GetTaskToTaskAddressesRequest, used to retrieve a task's addresses; it also handles ResourcesRequest, returning resources;
  • _get_resources returns the Spark resources;
  • wait_for_command_termination waits for the command execution to end.

The specific code is as follows:

def _run_command(self, command, env, event,
                 stdout=None, stderr=None, index=None,
                 prefix_output_with_timestamp=False):
    # start the training job in Spark
    super(SparkTaskService, self)._run_command(command, env, event,
                                               stdout, stderr, index,
                                               prefix_output_with_timestamp)

    if self._minimum_command_lifetime_s is not None:
        self._minimum_command_lifetime = timeout.Timeout(self._minimum_command_lifetime_s,
                                                         message='Just measuring runtime')

def _handle(self, req, client_address):
    # return resources
    if isinstance(req, ResourcesRequest):
        return ResourcesResponse(self._get_resources())

    # get the task address
    if isinstance(req, GetTaskToTaskAddressesRequest):
        next_task_index = req.task_index
        next_task_addresses = req.all_task_addresses
        # We request interface matching to weed out all the NAT'ed interfaces.
        next_task_client = \
            SparkTaskClient(next_task_index, next_task_addresses,
                            self._key, self._verbose,
                            match_intf=True)
        return GetTaskToTaskAddressesResponse(next_task_client.addresses())

    return super(SparkTaskService, self)._handle(req, client_address)

def _get_resources(self):
    # return Spark resources
    if LooseVersion(pyspark.__version__) >= LooseVersion('3.0.0'):
        task_context = pyspark.TaskContext.get()
        if task_context:
            return task_context.resources()
        else:
            print("Not running inside Spark worker, no resources available")
    return dict()

def wait_for_command_termination(self):
    """
    Waits for command termination. Ensures this method takes at least
    self._minimum_command_lifetime_s seconds to return after the command started.
    """
    try:
        # wait until the command has finished executing
        return super(SparkTaskService, self).wait_for_command_termination()
    finally:
        # command terminated, make sure this method takes at least
        # self._minimum_command_lifetime_s seconds after command started
        # the client that started the command needs some time to connect again
        # to wait for the result (see horovod.spark.driver.rsh).
        if self._minimum_command_lifetime is not None:
            time.sleep(self._minimum_command_lifetime.remaining())

3.3 Registering the Task

The next step is to register the task with the Driver.

driver_client.register_task(index, task.addresses(), hosthash)

3.3.1 Sending a Registration Request

Registration is done by calling the _send function of network.py, which performs the network interaction between the Spark Executor and the Horovod driver over a socket:

class BasicDriverClient(network.BasicClient):

    def register_task(self, index, task_addresses, host_hash):
        self._send(RegisterTaskRequest(index, task_addresses, host_hash))
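
Under the hood, _send writes a serialized request object over a socket and reads back a response object. Here is a toy sketch of that request/response idea (not Horovod's actual wire format, which also authenticates messages with the shared key):

import pickle
import socket
import struct

def send_obj(sock, obj):
    # length-prefixed pickle, a common minimal framing scheme
    data = pickle.dumps(obj)
    sock.sendall(struct.pack('!I', len(data)) + data)

def recv_obj(sock):
    size = struct.unpack('!I', sock.recv(4))[0]
    return pickle.loads(sock.recv(size))

# demo with a connected socket pair standing in for task <-> driver
driver_side, task_side = socket.socketpair()
send_obj(task_side, ('RegisterTaskRequest', 0, ['host0:1234'], 'host_hash_0'))
print(recv_obj(driver_side))  # -> ('RegisterTaskRequest', 0, ['host0:1234'], 'host_hash_0')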

3.3.2 Driver processing

Let's take a look at the driver running in Horovod (this belongs to the next section, but we look ahead here).

In BasicDriverService, when a RegisterTaskRequest is received, it is processed. The key statement is:

self._all_task_addresses[req.index] = req.task_addresses

This adds an entry to self._all_task_addresses.

As we mentioned earlier, horovod is waiting on driver.wait_for_initial_registration, and the key is:

while len(self._all_task_addresses) < self._num_proc

Once the number of entries in self._all_task_addresses reaches _num_proc, driver.wait_for_initial_registration is satisfied and execution continues.

The code that handles RegisterTaskRequest is as follows. BasicDriverService has various member variables to maintain the required information; as we explained in detail in Horovod (4), the RegisterTaskRequest branch of _handle updates these member variables:

class BasicDriverService(network.BasicService):

    def _handle(self, req, client_address):
        if isinstance(req, RegisterTaskRequest):
            self._wait_cond.acquire()
            try:
                self._all_task_addresses[req.index] = req.task_addresses

                # Just use source address for service for fast probing.
                self._task_addresses_for_driver[req.index] = \
                    self._filter_by_ip(req.task_addresses, client_address[0])

                # Remove host hash earlier registered under this index.
                if req.index in self._task_index_host_hash:
                    earlier_host_hash = self._task_index_host_hash[req.index]
                    if earlier_host_hash != req.host_hash:
                        self._task_host_hash_indices[earlier_host_hash].remove(req.index)

                # Make index -> host hash map.
                self._task_index_host_hash[req.index] = req.host_hash

                # Make host hash -> indices map.
                if req.host_hash not in self._task_host_hash_indices:
                    self._task_host_hash_indices[req.host_hash] = []
                self._task_host_hash_indices[req.host_hash].append(req.index)
                # TODO: this sorting is a problem in elastic horovod
                self._task_host_hash_indices[req.host_hash].sort()
            finally:
                self._wait_cond.notify_all()
                self._wait_cond.release()

            return network.AckResponse()
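
A small worked example of the bookkeeping above, assuming two tasks land on the same host and a third on another:

_task_index_host_hash = {}
_task_host_hash_indices = {}

for index, host_hash in [(0, 'hostA'), (1, 'hostA'), (2, 'hostB')]:
    # index -> host hash map
    _task_index_host_hash[index] = host_hash
    # host hash -> indices map
    _task_host_hash_indices.setdefault(host_hash, []).append(index)
    _task_host_hash_indices[host_hash].sort()

print(_task_index_host_hash)    # {0: 'hostA', 1: 'hostA', 2: 'hostB'}
print(_task_host_hash_indices)  # {'hostA': [0, 1], 'hostB': [2]}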

3.4 Task Waiting for next Notification

As mentioned earlier, after a Spark task sends its registration request to the driver, it waits for the next-phase start indication via task.wait_for_initial_registration(settings.start_timeout). Only then does the driver consider the task's registration finished and ready for the next step.

task.wait_for_initial_registration waits on the condition self._initial_registration_complete == True, i.e., it waits until register_task has completed.

class BasicTaskService(network.BasicService):

  def wait_for_initial_registration(self, timeout):
        self._wait_cond.acquire()
        try:
            while not self._initial_registration_complete:
                self._wait_cond.wait(timeout.remaining())
                timeout.check_time_out_for('tasks to start')
        finally:
            self._wait_cond.release()

Each Spark Executor has a SparkTaskService, so each Spark task has its own _initial_registration_complete.

The horovod.run main process calls notify_initial_registration_complete on each SparkTaskService one by one; that is, whichever SparkTaskService gets notified has its _initial_registration_complete set.

The horovod.run main process performs this step by sending a NotifyInitialRegistrationCompleteRequest.

def notify_initial_registration_complete(self):
    self._send(NotifyInitialRegistrationCompleteRequest())

BasicTaskService waits for a NotifyInitialRegistrationCompleteRequest. When one is received, _initial_registration_complete is set to True, and the wait in wait_for_initial_registration ends.

if isinstance(req, NotifyInitialRegistrationCompleteRequest):
    self._wait_cond.acquire()
    try:
        self._initial_registration_complete = True
    finally:
        self._wait_cond.notify_all()
        self._wait_cond.release()
    return network.AckResponse()

Once a Spark task has registered with Horovod like this, it has been started successfully.

  Horovod Main thread                             Spark Executor
  SparkDriverService                              _task_fn -> SparkTaskService
          +                                                   +
          |                  1 register                       |
          |  self._all_task_addresses  <----------------------+
          |                                                   |
          |                                                   |  2
          |  3                                                v
          v                             task.wait_for_initial_registration
  self._wait_cond.notify_all()


0x04 Phase 3: The Driver notifies the Task that it is successfully registered

Once all tasks are ready, Horovod notifies each task to proceed to the next phase.

4.1 _notify_and_register_task_addresses

As mentioned earlier, in the main Horovod process _notify_and_register_task_addresses is used to wait for the Spark tasks to register, calling driver.wait_for_initial_registration(settings.start_timeout) for an overall wait.

Note that right after sending its registration request, each Spark task itself calls task.wait_for_initial_registration, waiting for Horovod to signal the next phase.

_notify_and_register_task_addresses does several things:

  • Calls driver.wait_for_initial_registration to wait for the tasks to register (this step is now complete);
  • notify_and_register registers each task and notifies it to start the next step (the remaining steps happen here);
  • Uses driver.wait_for_task_to_task_address_updates to confirm that all tasks are OK.

def _notify_and_register_task_addresses(driver, settings, notify=True):
    # wait for num_proc tasks to register
    driver.wait_for_initial_registration(settings.start_timeout)

    def notify_and_register(index):
        # register the task and notify it to proceed to the next step
        task_client = task_service.SparkTaskClient(index,
                                                   driver.task_addresses_for_driver(index),
                                                   settings.key, settings.verbose)

        if notify:
            task_client.notify_initial_registration_complete()

        next_task_index = (index + 1) % settings.num_proc
        next_task_addresses = driver.all_task_addresses(next_task_index)
        task_to_task_addresses = task_client.get_task_addresses_for_task(next_task_index, next_task_addresses)
        driver.register_task_to_task_addresses(next_task_index, task_to_task_addresses)

    for index in driver.task_indices():
        in_thread(notify_and_register, (index,))  # register and notify each task in its own thread

    # check that all tasks are OK
    driver.wait_for_task_to_task_address_updates(settings.start_timeout)

4.2 notify_and_register

As you can see, notify_and_register does the following:

  • Calls task_client.notify_initial_registration_complete() to notify the Spark task that registration succeeded, so that all Spark Executors waiting in task.wait_for_initial_registration can run the next phase together;
  • Calls driver.register_task_to_task_addresses(next_task_index, task_to_task_addresses) to let the driver complete the task-to-task address registration.

4.3 wait_for_task_to_task_address_updates

This reconfirms that all Spark tasks are OK:

def wait_for_task_to_task_address_updates(self, timeout):
    self._wait_cond.acquire()
    try:
        while len(self._task_addresses_for_tasks) < self._initial_np:
            self.check_for_spark_job_failure()
            self._wait_cond.wait(timeout.remaining())
            timeout.check_time_out_for('Spark tasks to update task-to-task addresses')
    finally:
        self._wait_cond.release()

4.4 Waiting in the Task

In the Spark task, once the next-phase start instruction is received, wait_for_command_termination is called to wait.

At this point the Spark Executor's own logical task is essentially over, because the SparkTaskService is now responsible for launching the training code. Since the logical task of _task_fn has finished, it just waits quietly.

4.4.1 wait_for_command_termination

In horovod-master/horovod/spark/task/task_service.py:

def wait_for_command_termination(self):
    """
    Waits for command termination. Ensures this method takes at least
    self._minimum_command_lifetime_s seconds to return after the command started.
    """
    try:
        return super(SparkTaskService, self).wait_for_command_termination()
    finally:
        # command terminated, make sure this method takes at least
        # self._minimum_command_lifetime_s seconds after command started
        # the client that started the command needs some time to connect again
        # to wait for the result (see horovod.spark.driver.rsh).
        if self._minimum_command_lifetime is not None:
            time.sleep(self._minimum_command_lifetime.remaining())

In horovod-master/horovod/runner/common/service/task_service.py you can see that, in the end, this waits for the thread running the training code to finish:

def wait_for_command_termination(self):
    self._command_thread.join()  # _command_thread is explained in a moment

4.4.2 _command_thread

_command_thread is explained briefly here.

When SparkTaskService handles a RunCommandRequest, the thread that runs the command is assigned to _command_thread.

class BasicTaskService(network.BasicService):
    def _handle(self, req, client_address):

        if isinstance(req, RunCommandRequest):  # request to run the command
            self._wait_cond.acquire()
            try:
                if self._command_thread is None:

                    if self._command_env:
                        env = self._command_env.copy()
                        self._add_envs(env, req.env)
                        req.env = env

                    self._command_abort = threading.Event()
                    self._command_stdout = Pipe() if req.capture_stdout else None
                    self._command_stderr = Pipe() if req.capture_stderr else None
                    # configure the various parameters
                    args = (req.command, req.env, self._command_abort,
                            self._command_stdout, self._command_stderr,
                            self._index,
                            req.prefix_output_with_timestamp)
                    # start a new thread to run the command
                    self._command_thread = in_thread(self._run_command, args)
            finally:
                self._wait_cond.notify_all()
                self._wait_cond.release()
            return network.AckResponse()
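
For reference, in_thread (from horovod.runner.util.threads) is essentially a thin wrapper that starts a threading.Thread; a minimal stand-in under that assumption looks like this, which shows why _command_thread.join() in wait_for_command_termination simply blocks until the command finishes:

import threading

def in_thread(target, args=(), daemon=True):
    # minimal stand-in for horovod.runner.util.threads.in_thread (sketch only)
    thread = threading.Thread(target=target, args=args, daemon=daemon)
    thread.start()
    return thread

# the _command_thread pattern: run the command in a thread, join() to wait
command_thread = in_thread(print, ('running the training command',))
command_thread.join()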

The logic is as follows:

  Horovod Main thread                             Spark Executor
  SparkDriverService                              _task_fn -> SparkTaskService
          +                                                   +
          |                  1 register                       |
          |  self._all_task_addresses  <----------------------+
          |                                                   |
          |                                                   |  2
          |  3                                                v
          |  self._wait_cond.notify_all()                     |
          |                                                   |
          |            4 RegistrationComplete                 |
          +---------------------->  task.wait_for_initial_registration
          |                                                   |
          |                                                   |  5
          |            6 RunCommand                           v
          +---------------------->  wait_for_command_termination
          |                                                   |
          |                                                   |  7
          v                                                   v
                                     self._command_thread.join()


At this point the start-up stage is complete. We will continue in the next article; stay tuned.

0x05 Summary

In general, the overall logic of Horovod on Spark is divided into the following stages:

  • Start SparkDriverService, start the Spark tasks with _make_spark_thread, and then Horovod waits for the start to finish;
  • Multiple threads start Spark tasks in the Spark Executors, and each task runs a SparkTaskService. Each SparkTaskService registers with the SparkDriverService in the main Horovod process and waits for the next start command;
  • Once all tasks have registered, Horovod notifies each task to proceed to the next phase;
  • Horovod calls mpi_run (again using mpirun_rsh.py) to start an orted process on each Spark Executor, forming the MPI cluster;
  • orted runs the training code on top of each Executor.

This article covered the first three phases, i.e., the start-up process. Stay tuned for the remaining two phases.

0xEE Personal information

★★★★ Thoughts on life and technology ★★★★★

Wechat official account: Rosie’s Thoughts

If you want timely news of my articles, or want to see the technical materials I recommend, please follow this account.