Let’s look at the caching mechanism. Why cache? Because a cluster contains many workers, and because the Master interacts with Workers and Workers interact with each other, it is necessary to cache each Worker together with its gRPC channel. It is fair to say that caching is used everywhere in the TensorFlow distributed environment.

The other articles in this series are:

Heterogeneous Distributed Learning based on the TensorFlow distributed paper [translation]

Implementation of Control Flow in TensorFlow

TensorFlow distributed environment (1) - Overall architecture

TensorFlow distributed environment (2) - Master static logic

TensorFlow distributed environment (3) - Worker static logic

1. WorkerCache

The WorkerCache is used to obtain WorkerInterface instances, through which the remote WorkerService can be accessed. A concrete example of a WorkerInterface instance is GrpcRemoteWorker.

1.1 Usage

When MasterEnv was initialized earlier, WorkerCacheFactory was assigned to master_env_.worker_cache_factory.

master_env_.worker_cache_factory =
    [this] (const WorkerCacheFactoryOptions& options,
           WorkerCacheInterface** worker_cache) {
      return WorkerCacheFactory(options, worker_cache);
    };

Master::CreateSession contains the following abbreviated code, which shows how to obtain worker_cache (an instance of WorkerCacheInterface) from the factory and how to use worker_cache for subsequent operations.

void Master::CreateSession(const CreateSessionRequest* req,
                           CreateSessionResponse* resp, MyClosure done) {
  SchedClosure([this, req, resp, done]() {
      // Configure the options.
      WorkerCacheFactoryOptions worker_cache_factory_options;
      worker_cache_factory_options.protocol = &grpc_protocol;
      worker_cache_factory_options.rpc_options = &req->config().rpc_options();

      // Create the worker cache from the computed server_def.
      status = env_->worker_cache_factory(worker_cache_factory_options,
                                          &worker_cache);

      // Use worker_cache to do the rest
      status =
          DeviceFinder::GetRemoteDevices(req->config().device_filters(), env_,
                                         worker_cache, remote_devices.get());

  });
}

1.2 Configuration

WorkerCacheFactoryOptions is equivalent to ServerDef: it contains the ClusterDef, job_name, task_index, and other information.

// Options passed to the worker_cache_factory function.
struct WorkerCacheFactoryOptions {
  const ClusterDef* cluster_def = nullptr;
  const string* job_name = nullptr;
  int task_index;
  const string* protocol = nullptr;
  const RPCOptions* rpc_options = nullptr;

  WorkerCacheFactoryOptions() {}

  // Construct from a ServerDef proto.
  //
  // Note: server_def must outlive WorkerCacheFactoryOptions!
  WorkerCacheFactoryOptions(const ServerDef& server_def) {
    if (server_def.has_cluster() && !server_def.job_name().empty()) {
      cluster_def = &server_def.cluster();
      job_name = &server_def.job_name();
      task_index = server_def.task_index();
      protocol = &server_def.protocol();
      rpc_options = &server_def.default_session_config().rpc_options();
    }
  }
};
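As a concrete illustration, here is a minimal sketch (the job name, addresses, and ports are hypothetical) of building a ServerDef and converting it into WorkerCacheFactoryOptions. Note that server_def must outlive options, because options only stores pointers into it.

#include "tensorflow/core/protobuf/tensorflow_server.pb.h"

tensorflow::ServerDef server_def;
auto* job = server_def.mutable_cluster()->add_job();
job->set_name("worker");
(*job->mutable_tasks())[0] = "worker0:2222";  // hypothetical addresses
(*job->mutable_tasks())[1] = "worker1:2222";
server_def.set_job_name("worker");
server_def.set_task_index(0);
server_def.set_protocol("grpc");

// The options now hold raw pointers into server_def.
WorkerCacheFactoryOptions options(server_def);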

1.3 Factory class

WorkerCacheFactory is a function that does the following:

  • ParseChannelSpec is used to obtain a GrpcChannelSpec instance. GrpcChannelSpec is equivalent to ClusterSpec and contains the basic cluster configuration information.
  • NewGrpcChannelCache is called to build a GrpcChannelCache instance channel_cache; GetChannelCreationFunction is used here to supply the channel-creation function.
  • NewGrpcWorkerCacheWithLocalWorker is called to build worker_cache from channel_cache.
Status GrpcServer::WorkerCacheFactory(const WorkerCacheFactoryOptions& options,
                                      WorkerCacheInterface** worker_cache) {

  // Get the GrpcChannelSpec.
  GrpcChannelSpec channel_spec;
  TF_RETURN_IF_ERROR(ParseChannelSpec(options, &channel_spec));

  // Get the GrpcChannelCache.
  std::shared_ptr<GrpcChannelCache> channel_cache(NewGrpcChannelCache(
      channel_spec, GetChannelCreationFunction(), *options.rpc_options));

  string name_prefix = strings::StrCat("/job:", *options.job_name,
                                       "/replica:0", "/task:",
                                       options.task_index);

  const string host_port = channel_cache->TranslateTask(name_prefix);
  int requested_port;

  auto colon_index = host_port.find_last_of(':');
  if (!strings::safe_strto32(host_port.substr(colon_index + 1),
                             &requested_port)) {
    return errors::Internal("Could not parse port for local server from \"",
                            host_port, "\".");
  }
  if (requested_port != bound_port_) {
    return errors::InvalidArgument("Requested port ", requested_port,
                                   " differs from expected port ", bound_port_);
  }
  // Get the worker cache.
  *worker_cache = NewGrpcWorkerCacheWithLocalWorker(
      channel_cache, grpc_worker_env(), worker_impl(), name_prefix);
  return Status::OK();
}

1.3.1 ParseChannelSpec

ParseChannelSpec is used to obtain GrpcChannelSpec instances. GrpcChannelSpec is equivalent to ClusterSpec and contains basic cluster configuration information.

Status GrpcServer::ParseChannelSpec(const WorkerCacheFactoryOptions& options,
                                    GrpcChannelSpec* channel_spec) {
  for (const auto& job : options.cluster_def->job()) {
    std::map<int, string> host_ports;
    for (const auto& task : job.tasks()) {
      string& host_port = host_ports[task.first];
      if (!host_port.empty()) {
        return errors::InvalidArgument("JobDef for job \"", job.name(),
                                       "\" specified two addresses for task \"",
                                       task.first, "\": ", host_port, " and ",
                                       task.second);
      }
      if (job.name() == *options.job_name && task.first == options.task_index) {
        host_port = strings::StrCat(host_name_, ":", bound_port_);
      } else {
        host_port = task.second;
      }
    }
    TF_RETURN_IF_ERROR(channel_spec->AddHostPortsJob(job.name(), host_ports));
  }
  return Status::OK();
}

1.3.2 NewGrpcChannelCache

NewGrpcChannelCache is used to create a GrpcChannelCache instance. You can see that each Job gets its own SparseGrpcChannelCache. If there is only one SparseGrpcChannelCache, it is returned directly; otherwise, the SparseGrpcChannelCaches are combined into a MultiGrpcChannelCache. The channel_func passed in is GetChannelCreationFunction, which we will discuss later.

GrpcChannelCache* NewGrpcChannelCache(const GrpcChannelSpec& spec,
                                      ChannelCreationFunction channel_func,
                                      const RPCOptions& options) {
  const int num_jobs = spec.host_ports_jobs().size();
  if (!num_jobs) {
    return nullptr;
  }
  std::vector<GrpcChannelCache*> caches;
  caches.reserve(num_jobs);
  for (auto& job : spec.host_ports_jobs()) {
    caches.push_back(
        new SparseGrpcChannelCache(job.job_id, job.host_ports, channel_func,
                                   options.num_channels_per_target()));
  }
  return caches.size() == 1 ? caches[0]
                            : new MultiGrpcChannelCache(
                                  caches, options.num_channels_per_target());
}

1.3.3 NewGrpcWorkerCacheWithLocalWorker

The NewGrpcWorkerCacheWithLocalWorker method creates a GrpcWorkerCache instance.

WorkerCacheInterface* NewGrpcWorkerCacheWithLocalWorker(
    std::shared_ptr<GrpcChannelCache> cc, GrpcWorkerEnv* worker_env,
    WorkerInterface* local_worker, const string& local_target) {
  return new GrpcWorkerCache(cc, local_worker, local_target, worker_env);
}

The local_worker argument is obtained via worker_impl(); it is the local GrpcWorker generated in GrpcServer::Init.

GrpcWorker* worker_impl() const { return worker_impl_.get(); }

std::unique_ptr<GrpcWorker> NewGrpcWorker(WorkerEnv* env,
                                          const ConfigProto& config) {
  return std::unique_ptr<GrpcWorker>(new GrpcWorker(env, config));
}

Status GrpcServer::Init(const GrpcServerOptions& opts) {

    // Omitted.

    worker_impl_ = opts.worker_func ? opts.worker_func(&worker_env_, config)
                                    : NewGrpcWorker(&worker_env_, config);

    // Omitted.
}

Let's summarize the factory-class flow so far: the initial input is WorkerCacheFactoryOptions, which is then processed step by step by the various functions above until a GrpcWorkerCache is produced.

Figure 1 Factory class process

1.4 WorkerCacheInterface

1.4.1 Interface

WorkerCacheInterface is the interface class from which GrpcWorkerCache is derived in the figure above.

class WorkerCacheInterface {
 public:
  virtual ~WorkerCacheInterface() {}

  // Updates *workers with strings naming the remote worker tasks to
  // which open channels have been established.
  virtual void ListWorkers(std::vector<string>* workers) const = 0;
  virtual void ListWorkersInJob(const string& job_name,
                                std::vector<string>* workers) const = 0;

  // If "target" names a remote task for which an RPC channel exists
  // or can be constructed, returns a pointer to a WorkerInterface object
  // wrapping that channel. The returned value must be destroyed by
  // calling `this->ReleaseWorker(target, ret)`
  virtual WorkerInterface* GetOrCreateWorker(const string& target) = 0;

  // Release a worker previously returned by this->GetOrCreateWorker(target).
  //
  // TODO(jeff,sanjay): Consider moving target into WorkerInterface.
  // TODO(jeff,sanjay): Unify all worker-cache impls and factor out a
  // per-rpc-subsystem WorkerInterface creator.
  virtual void ReleaseWorker(const string& target, WorkerInterface* worker) {
    // Subclasses may override to reuse worker objects.
    delete worker;
  }

  // Set *locality with the DeviceLocality of the specified remote device
  // within its local environment. Returns true if *locality
  // was set, using only locally cached data. Returns false
  // if status data for that device was not available. Never blocks.
  virtual bool GetDeviceLocalityNonBlocking(const string& device,
                                            DeviceLocality* locality) = 0;

  // Set *locality with the DeviceLocality of the specified remote device
  // within its local environment. Callback gets Status::OK if *locality
  // was set.
  virtual void GetDeviceLocalityAsync(const string& device,
                                      DeviceLocality* locality,
                                      StatusCallback done) = 0;

  // TODO(b/189159585): Define a general client cache maker function to
  // construct client cache of different types sharing the same underling RPC
  // channels, to replace the eager and coordination cache function.
  // Build and return a EagerClientCache object wrapping that channel.
  virtual Status GetEagerClientCache(
      std::unique_ptr<eager::EagerClientCache>* eager_client_cache) = 0;

  // Build and return a CoordinationClientCache object wrapping that channel.
  virtual Status GetCoordinationClientCache(
      std::unique_ptr<CoordinationClientCache>* coordination_client_cache) = 0;

  // Start/stop logging activity.
  virtual void SetLogging(bool active) {}

  // Discard any saved log data.
  virtual void ClearLogs() {}

  // Return logs for the identified step in *ss. Any returned data will no
  // longer be stored.
  virtual bool RetrieveLogs(int64_t step_id, StepStats* ss) { return false; }
};
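The ownership contract of GetOrCreateWorker/ReleaseWorker is easy to get wrong, so here is a hedged usage sketch (ListAndRelease is a hypothetical helper, not part of TensorFlow):

// Hypothetical helper: every worker obtained from GetOrCreateWorker must be
// handed back to the same cache via ReleaseWorker.
void ListAndRelease(WorkerCacheInterface* cache) {
  std::vector<string> workers;
  cache->ListWorkers(&workers);  // e.g. "/job:worker/replica:0/task:0"
  for (const string& target : workers) {
    WorkerInterface* wi = cache->GetOrCreateWorker(target);
    if (wi == nullptr) continue;  // no channel could be constructed
    // ... issue RPCs through wi, e.g. wi->GetStatusAsync(...) ...
    cache->ReleaseWorker(target, wi);  // subclasses may delete or recycle it
  }
}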

WorkerCachePartial inherits WorkerCacheInterface.

// Implements the part of the interface that caches and returns remote
// device status attributes.
class WorkerCachePartial : public WorkerCacheInterface {
 public:
  bool GetDeviceLocalityNonBlocking(const string& device,
                                    DeviceLocality* locality) override;

  void GetDeviceLocalityAsync(const string& device, DeviceLocality* locality,
                              StatusCallback) override;

  ~WorkerCachePartial() override {}

  // Clear all entries from the DeviceStatus cache.
  void FlushStatusCache();

 private:
  mutex mu_;

  // Initiate a GetStatusAsync to the remote task named by "task", and
  // update the cache with all the DeviceAttributes reported.
  Status RefreshDeviceStatus(const string& device_name);

  typedef std::unordered_map<string, DeviceAttributes> StatusMap;
  StatusMap device_status_cache_ TF_GUARDED_BY(mu_);
};

1.4.2 GrpcWorkerCache

GrpcWorkerCache inherits WorkerCachePartial.

class GrpcWorkerCache : public WorkerCachePartial {
 public:
  explicit GrpcWorkerCache(std::shared_ptr<GrpcChannelCache> channel_cache,
                           WorkerInterface* local_worker,
                           const string& local_target,
                           GrpcWorkerEnv* worker_env)
      : local_target_(local_target),
        local_worker_(local_worker),
        channel_cache_(channel_cache),
        worker_env_(worker_env),
        next_round_robin_assignment_(0) {}

  const string local_target_;
  WorkerInterface* const local_worker_;  // Not owned.
  std::shared_ptr<GrpcChannelCache> channel_cache_;
  WorkerCacheLogger logger_;
  GrpcWorkerEnv* worker_env_;  // Not owned

  mutex assignment_mu_;
  std::unordered_map<std::string, size_t> target_assignments_
      TF_GUARDED_BY(assignment_mu_);
  size_t next_round_robin_assignment_ TF_GUARDED_BY(assignment_mu_);
};

Its main responsibilities include ListWorkers, which lists the names of all Workers in the cluster by delegating to the channel cache.

void ListWorkers(std::vector<string>* workers) const override {
  channel_cache_->ListWorkers(workers);
}

void ListWorkersInJob(const string& job_name,
                      std::vector<string>* workers) const override {
  channel_cache_->ListWorkersInJob(job_name, workers);
}

GetOrCreateWorker builds a Worker on top of the Worker's RPC channel. If the target is local, it returns local_worker_, the local GrpcWorker we set up earlier.

WorkerInterface* GetOrCreateWorker(const string& target) override {
  if (target == local_target_) {
    return local_worker_;
  } else {
    SharedGrpcChannelPtr channel = channel_cache_->FindWorkerChannel(target);
    if (!channel) {
      return nullptr;
    }
    size_t index = AssignWorkerToThread(target);
    return NewGrpcRemoteWorker(
        channel, worker_env_->GetCompletionQueue(index),
        worker_env_->GetThreadPool(), &logger_, target);
  }
}
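A sketch of what this means for a caller, assuming cache points to the GrpcWorkerCache above and that task 0 of job worker is the local process (the task names are hypothetical):

// Local lookup: returns local_worker_ directly; no RPC channel is involved.
WorkerInterface* local =
    cache->GetOrCreateWorker("/job:worker/replica:0/task:0");

// Remote lookup: finds (or creates) a cached grpc::Channel for the target
// and wraps it in a new GrpcRemoteWorker.
WorkerInterface* remote = cache->GetOrCreateWorker("/job:ps/replica:0/task:1");
// When done, hand the worker back to the cache.
cache->ReleaseWorker("/job:ps/replica:0/task:1", remote);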

2. RPC channel

Workers run on top of RPC channels, so let's see how these channels are set up. Just as Workers are cached, RPC channels are cached too: GrpcChannelCache is the cache used to get or create RPC channels to the remote workers in the cluster.

2.1 GrpcChannelCache interface

GrpcChannelCache is an interface class that defines a series of interfaces, such as:

  • ListWorkers: returns the names of the Workers in the cluster.
  • TranslateTask: converts a Worker name into address information in host:port format.
  • FindWorkerChannel: finds a grpc::Channel instance in the cache. If it is not cached, an instance is dynamically generated from the address information and put into the cache.
class GrpcChannelCache {
 public:
  virtual ~GrpcChannelCache() {}

  // Populates *workers with names of all workers which this object
  // was created to handle. Worker names are in the format
  // /job:<job identifier>/task:<task id>
  // e.g. /job:mnist/task:2
  virtual void ListWorkers(std::vector<string>* workers) = 0;
  virtual void ListWorkersInJob(const string& job_name,
                                std::vector<string>* workers) = 0;

  // If found, returns a gRPC channel that is connected to the remote
  // worker named by 'target'. 'target' is of the following
  // format: /job:<job identifier>/task:<task id>
  // E.g., /job:mnist/task:2
  virtual SharedGrpcChannelPtr FindWorkerChannel(const string& target) = 0;

  // Translates a string in the form `/job:X/task:Z` into a host_port.
  virtual string TranslateTask(const string& task) = 0;
};

2.2 Cache Mechanism

CachingGrpcChannelCache is a cache class that avoids the overhead of creating a grpc::Channel on every call. It is defined as follows: GenericCachingChannelCache instantiated with GrpcChannelCache, from which it thus derives.

// GrpcChannelCache that caches results to FindWorkerChannel() calls.
using CachingGrpcChannelCache = GenericCachingChannelCache<GrpcChannelCache>;

GenericCachingChannelCache caches the results of FindWorkerChannel() calls. It first looks for a grpc::Channel instance in the cache; if there is none, it calls FindChannelOnce to dynamically generate an instance from the address information, and places the result in the cache.

To improve throughput, GenericCachingChannelCache allows multiple channels to communicate with the same target. When multiple channels exist for a target, the channels are selected in round-robin fashion on each call to FindWorkerChannel, as shown in the sketch after the class definition below.

Note that absl::flat_hash_map<string, ChannelState> channels_ is the cache, and the cached ::grpc::Channel type is defined as follows.

typedef std::shared_ptr<::grpc::Channel> SharedGrpcChannelPtr;

The specific code is:

template <typename ChannelCacheT>
class GenericCachingChannelCache : public ChannelCacheT {
 public:
  explicit GenericCachingChannelCache(int num_channels_per_target)
      : num_channels_per_target_(
            num_channels_per_target > 0 ? num_channels_per_target : 1) {}

  ~GenericCachingChannelCache() override {}

  SharedGrpcChannelPtr FindWorkerChannel(const string& target) override {
    {
      mutex_lock l(mu_);
      auto iter = channels_.find(target);
      if (iter != channels_.end()) {
        return GetNextChannelPtrAndUpdateState(iter->second);
      }
    }
    ChannelState new_chan_state;
    for (int indx = 0; indx < num_channels_per_target_; indx++) {
      auto ch = FindChannelOnce(target);
      if (!ch) return nullptr;
      new_chan_state.channels.push_back(ch);
    }
    new_chan_state.last_used = num_channels_per_target_ - 1;

    {
      mutex_lock l(mu_);
      typename absl::flat_hash_map<string, ChannelState>::iterator iter;
      bool was_inserted;
      std::tie(iter, was_inserted) = channels_.insert({target, new_chan_state});
      return GetNextChannelPtrAndUpdateState(iter->second);
    }
  }

 protected:
  // Find the ClientChannel for "target". Only called when no channel was
  // found in the channels_ cache for "target". A non nullptr result will be
  // cached in channels_.
  virtual SharedGrpcChannelPtr FindChannelOnce(const string& target) = 0;

 private:
  struct ChannelState {
    std::vector<SharedGrpcChannelPtr> channels;
    int last_used;
  };

  // Should be called with mu_ held.
  SharedGrpcChannelPtr GetNextChannelPtrAndUpdateState(
      ChannelState& chan_state) {
    // Following statement is marked as Crash OK as this is an invariant of
    // code flow in this class.
    CHECK_EQ(chan_state.channels.size(), num_channels_per_target_);  // Crash OK
    chan_state.last_used =
        (chan_state.last_used + 1) % num_channels_per_target_;
    return chan_state.channels[chan_state.last_used];
  }

  const int num_channels_per_target_;
  // TODO(zhifengc): Eviction when the map becomes too big.
  mutex mu_;
  absl::flat_hash_map<string, ChannelState> channels_ TF_GUARDED_BY(mu_);
};
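To make the round-robin selection concrete, suppose cache is an instance of a concrete subclass constructed with num_channels_per_target == 2, and that the target below resolves successfully. A caller would then observe the following (a sketch):

// The first call creates both channels for the target, then rotation begins.
auto c0 = cache.FindWorkerChannel("/job:ps/replica:0/task:0");  // channel 0
auto c1 = cache.FindWorkerChannel("/job:ps/replica:0/task:0");  // channel 1
auto c2 = cache.FindWorkerChannel("/job:ps/replica:0/task:0");  // channel 0 again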

2.3 Concrete derived classes

Two more classes are derived from CachingGrpcChannelCache as follows:

2.3.1 Leaf node

SparseGrpcChannelCache is the leaf node. Each Job in the cluster has its own SparseGrpcChannelCache, whose set of grpc::Channels corresponds to the Job's Tasks: each Task has one grpc::Channel.

The main member variables of SparseGrpcChannelCache are as follows:

  • const string job_id_ : the Job this cache corresponds to.
  • const std::map<int, string> host_ports_ : the host:port addresses of the Job's Tasks, keyed by task index.
  • const ChannelCreationFunction channel_func_ : the method used to generate a grpc::Channel.

SparseGrpcChannelCache has the following functions:

  • ListWorkers: returns the list of Task names for this Job.
  • TranslateTask: gets the address information (in host:port format) for a Task name; for example, /job:ps/replica:0/task:1 might translate to ps1:1111.
  • FindChannelOnce: creates a grpc::Channel for a Task name. Specifically, it first obtains the address information via TranslateTask, then uses that address to construct the grpc::Channel.
class SparseGrpcChannelCache : public CachingGrpcChannelCache {
 public:
  SparseGrpcChannelCache(const string& job_id,
                         const std::map<int, string>& host_ports,
                         ChannelCreationFunction channel_func,
                         int num_channels_per_target)
      : CachingGrpcChannelCache(num_channels_per_target),
        job_id_(job_id),
        host_ports_(host_ports),
        channel_func_(std::move(channel_func)) {}
  ~SparseGrpcChannelCache() override {}

  void ListWorkers(std::vector<string>* workers) override {
    workers->reserve(workers->size() + host_ports_.size());
    for (const auto& id_host_port : host_ports_) {
      workers->emplace_back(MakeAddress(job_id_, id_host_port.first));
    }
  }

  void ListWorkersInJob(const string& job_name,
                        std::vector<string>* workers) override {
    if (job_name == job_id_) {
      ListWorkers(workers);
    }
  }

  string TranslateTask(const string& target) override {
    DeviceNameUtils::ParsedName parsed;
    if (!DeviceNameUtils::ParseFullName(target, &parsed)) {
      return "";
    }
    if (!parsed.has_job || parsed.job != job_id_) {
      return "";
    }
    if (!parsed.has_replica || parsed.replica != 0) {
      return "";
    }
    int32_t task = parsed.has_task ? parsed.task : -1;
    auto iter = host_ports_.find(task);
    if (iter == host_ports_.end()) {
      return "";
    }
    return iter->second;
  }

 protected:
  SharedGrpcChannelPtr FindChannelOnce(const string& target) override {
    const string host_port = TranslateTask(target);
    if (host_port.empty()) {
      return nullptr;
    }
    auto chan_ptr = channel_func_(host_port);
    return chan_ptr;
  }

 private:
  const string job_id_;
  const std::map<int, string> host_ports_;
  const ChannelCreationFunction channel_func_;
  TF_DISALLOW_COPY_AND_ASSIGN(SparseGrpcChannelCache);
};
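A hypothetical usage sketch (the job name, addresses, and ports are made up; channel_func would come from GetChannelCreationFunction, covered in section 2.4):

std::map<int, string> host_ports = {{0, "ps0:2222"}, {1, "ps1:2222"}};
SparseGrpcChannelCache cache("ps", host_ports, channel_func,
                             /*num_channels_per_target=*/1);

string addr = cache.TranslateTask("/job:ps/replica:0/task:1");      // "ps1:2222"
string none = cache.TranslateTask("/job:worker/replica:0/task:0");  // ""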

2.3.2 Non-leaf nodes

To speed up lookups across the SparseGrpcChannelCaches and to manage all Worker nodes in the cluster as a whole, TF combines the cluster's SparseGrpcChannelCaches into a MultiGrpcChannelCache. MultiGrpcChannelCache remembers which SparseGrpcChannelCache served each target that has been accessed.

// A ChannelCache that is the union of multiple ChannelCaches.
// Takes ownership of the caches passed to the constructor.
class MultiGrpcChannelCache : public CachingGrpcChannelCache {
 public:
  explicit MultiGrpcChannelCache(const std::vector<GrpcChannelCache*>& caches,
                                 int num_channels_per_target)
      : CachingGrpcChannelCache(num_channels_per_target), caches_(caches) {}

  ~MultiGrpcChannelCache() override {
    for (GrpcChannelCache* cache : caches_) {
      delete cache;
    }
  }

  void ListWorkers(std::vector<string>* workers) override {
    for (GrpcChannelCache* cache : caches_) {
      cache->ListWorkers(workers);
    }
  }

  void ListWorkersInJob(const string& job_name,
                        std::vector<string>* workers) override {
    for (GrpcChannelCache* cache : caches_) {
      cache->ListWorkersInJob(job_name, workers);
    }
  }

  string TranslateTask(const string& target) override {
    mutex_lock l(mu_);  // could use reader lock
    GrpcChannelCache* cache = gtl::FindPtrOrNull(target_caches_, target);
    if (cache == nullptr) {
      for (GrpcChannelCache* c : caches_) {
        string r = c->TranslateTask(target);
        if (!r.empty()) {
          target_caches_.insert({target, c});
          cache = c;
          break;
        }
      }
    }
    return cache->TranslateTask(target);
  }

 protected:
  SharedGrpcChannelPtr FindChannelOnce(const string& target) override {
    for (GrpcChannelCache* cache : caches_) {
      SharedGrpcChannelPtr ch(cache->FindWorkerChannel(target));
      if (ch) {
        mutex_lock l(mu_);
        target_caches_.insert({target, cache});
        return ch;
      }
    }
    return nullptr;
  }

 private:
  // List of channels used by this MultiGrpcChannelCache.
  const std::vector<GrpcChannelCache*> caches_;

  mutex mu_;
  // Cache of channels keyed by the target they are handling.
  // The same GrpcChannelCache can appear multiple times in the cache.
  std::unordered_map<string, GrpcChannelCache*> target_caches_
      TF_GUARDED_BY(mu_);
};
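A sketch of how the leaf and non-leaf classes compose, mirroring NewGrpcChannelCache from section 1.3.2 (the job names, host maps, and channel_func are hypothetical):

std::vector<GrpcChannelCache*> caches;
caches.push_back(new SparseGrpcChannelCache("worker", worker_hosts,
                                            channel_func, 1));
caches.push_back(new SparseGrpcChannelCache("ps", ps_hosts, channel_func, 1));

// MultiGrpcChannelCache takes ownership of the leaf caches.
GrpcChannelCache* cache = new MultiGrpcChannelCache(caches, 1);
// The lookup is routed to (and remembered for) the ps leaf cache.
string addr = cache->TranslateTask("/job:ps/replica:0/task:0");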

The current structure is as follows:

Figure 2 Cache logic

2.4 Generating the GrpcChannelCache

When the GrpcChannelCache was generated earlier, GetChannelCreationFunction was passed in without explanation; let's go through it now.

  // Get the GrpcChannelCache.
  std::shared_ptr<GrpcChannelCache> channel_cache(NewGrpcChannelCache(
      channel_spec, GetChannelCreationFunction(), *options.rpc_options));

2.4.1 Objectives & Usage

Let's first look at the goal and the usage: generating a SharedGrpcChannelPtr from a target (a host:port string). SharedGrpcChannelPtr is a shared pointer to grpc::Channel.

SharedGrpcChannelPtr FindChannelOnce(const string& target) override {
  const string host_port = TranslateTask(target);
  if (host_port.empty()) {
    return nullptr;
  }
  auto chan_ptr = channel_func_(host_port);
  VLOG(5) << "Channel created for: job: " << job_id_
          << " host_port: " << host_port << " target : " << target
          << " Ptr: " << chan_ptr.get();
  return chan_ptr;
}

2.4.2 NewHostPortGrpcChannel

We start with NewHostPortGrpcChannel, an existing TF API. Its main job is to call ::grpc::CreateCustomChannel (a gRPC API) to obtain a grpc::Channel, store it in the SharedGrpcChannelPtr* channel_pointer output parameter, and return a Status. The result is exactly what we need, but the signature does not match what the factory expects, so it must be wrapped or converted.

Status NewHostPortGrpcChannel(const string& target,
                              const RPCOptions* rpc_options,
                              SharedGrpcChannelPtr* channel_pointer) {
  // Minimally ensure that the target is valid
  TF_RETURN_IF_ERROR(ValidateHostPortPair(target));

  ::grpc::ChannelArguments args = GetChannelArguments(rpc_options);
  *channel_pointer = ::grpc::CreateCustomChannel(
      "dns:///" + target, ::grpc::InsecureChannelCredentials(), args);
  return Status::OK();
}

2.4.3 ConvertToChannelCreationFunction

The ConvertToChannelCreationFunction method converts the incoming new_channel_func_ptr into a function that takes a const string& target and produces a SharedGrpcChannelPtr.

ChannelCreationFunction ConvertToChannelCreationFunction(
    const std::function<Status(string, const RPCOptions*,
                               SharedGrpcChannelPtr*)>& new_channel_func_ptr) {
  return [new_channel_func_ptr](const string& target) -> SharedGrpcChannelPtr {
    SharedGrpcChannelPtr channel_ptr;
    if (new_channel_func_ptr(target, /*rpc_options=*/nullptr, &channel_ptr)
            .ok()) {
      return channel_ptr;
    } else {
      return nullptr;
    }
  };
}
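Putting the two pieces together (a sketch; the address is hypothetical), the following is equivalent to what GetChannelCreationFunction returns in the next subsection:

ChannelCreationFunction channel_func =
    ConvertToChannelCreationFunction(NewHostPortGrpcChannel);

// Calling it directly yields a grpc::Channel, or nullptr on failure.
SharedGrpcChannelPtr ch = channel_func("localhost:2222");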

2.4.4 GetChannelCreationFunction

GetChannelCreationFunction passes NewHostPortGrpcChannel into ConvertToChannelCreationFunction to obtain a ChannelCreationFunction; this is the function that the WorkerCache factory class can use.

ChannelCreationFunction GrpcServer::GetChannelCreationFunction() const {
  // We can do this because SparseGrpcChannelCache is robust to nullptr being
  // returned by the channel creation function
  return ConvertToChannelCreationFunction(NewHostPortGrpcChannel);
}

2.4.5 Usage analysis

Let's go back to our call site. channel_func_ is the function returned by GetChannelCreationFunction, so calling it directly yields a grpc::Channel.

SharedGrpcChannelPtr FindChannelOnce(const string& target) override {
  const string host_port = TranslateTask(target);
  // channel_func_ is the converted NewHostPortGrpcChannel.
  auto chan_ptr = channel_func_(host_port);
  return chan_ptr;
}

The overall conversion process that finally produces the grpc::Channel is as follows:

Figure 3 The conversion process

3. Position of the Cache in the system

We have summarized how the caches are initialized and used, but not yet where they sit in the system, so let's look at that now. The GrpcChannelCache inside the GrpcWorkerCache is used to obtain cached gRPC channels, and local_worker_ stores the local Worker.

Figure 4 Cache location

When GrpcWorkerCache::GetOrCreateWorker is called, it returns local_worker_ if the target is local; otherwise it generates a GrpcRemoteWorker on top of the Worker's RPC channel.

Figure 5 Worker generation

WorkerCacheInterface (GrpcWorkerCache) can be found everywhere: Master, Worker, MasterSession, and WorkerSession all have member variables pointing to a WorkerCacheInterface, so it is used quite widely.

4. Finding the device set

In order to create WorkerSessions, the MasterSession needs to know the set of devices on all remote workers, so before creating the MasterSession the Master traverses all workers to collect their device information. Because this leverages the capabilities of GrpcWorkerCache, we explain it here. The basic logic is as follows (a miniature sketch follows the list):

  • Use GrpcWorkerCache::ListWorkers to obtain the names of all Workers in the cluster.
  • For each worker_name, call GetOrCreateWorker to find the WorkerInterface object in worker_cache.
  • Build a GetStatusRequest and send it to the found Worker via GetStatusAsync.
  • When the Worker returns a GetStatusResponse, the callback cb (the WhenFound method) is invoked to collect the Worker's device information, and worker_name is added to the device names obtained.
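The flow in miniature (a sketch only; DeviceFinder below implements it in full, asynchronously and with device filters):

std::vector<string> worker_names;
worker_cache->ListWorkers(&worker_names);
for (const string& name : worker_names) {
  WorkerInterface* wi = worker_cache->GetOrCreateWorker(name);
  // Send a GetStatusRequest via wi->GetStatusAsync(...); the callback
  // receives a GetStatusResponse carrying the worker's DeviceAttributes,
  // whose device names are rewritten to include `name`.
}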

Figure 6 Obtaining the device

4.1 DeviceFinder

4.1.1 Definition

DeviceFinder is a function object that implements the algorithm for finding remote worker devices. We first list its member variables:

class DeviceFinder {
  ~DeviceFinder() {
    for (Device* dev : found_) delete dev;
  }

  typedef DeviceFinder ME;
  const MasterEnv* env_;
  WorkerCacheInterface* worker_cache_;
  std::vector<DeviceNameUtils::ParsedName> filters_;

  mutex mu_;
  int num_pending_ TF_GUARDED_BY(mu_);
  condition_variable pending_zero_;
  std::vector<Device*> found_ TF_GUARDED_BY(mu_);
  // List of targets to be contacted by this DeviceFinder. The
  // respective `bool` in `seen_targets_` indicates whether we have
  // heard from this target or not.
  std::vector<string> targets_;
  std::vector<bool> seen_targets_ TF_GUARDED_BY(mu_);
  Status status_;

  TF_DISALLOW_COPY_AND_ASSIGN(DeviceFinder);
};

4.1.2 Initialization

The main logic is to obtain the names of all Workers in the cluster via GrpcWorkerCache::ListWorkers.

explicit DeviceFinder(
    const protobuf::RepeatedPtrField<string>& device_filters, MasterEnv* env,
    WorkerCacheInterface* worker_cache)
    : env_(env), worker_cache_(worker_cache) {
  CHECK(worker_cache) << "Worker cache was null!";
  auto process_filter = [this](const string& filter) {
    DeviceNameUtils::ParsedName parsed;
    if (DeviceNameUtils::ParseFullName(filter, &parsed)) {
      filters_.push_back(parsed);
    } else {
      LOG(FATAL) << "Skipping invalid filter: " << filter;
    }
  };
  for (const string& filter : device_filters) {
    process_filter(filter);
  }
  // Enumerates all known workers' target. A target name is a
  // prefix of a device name. E.g., /job:mnist/replica:0/task:10.
  if (filters_.empty()) {
    // If no filters were specified, we list all known workers in
    // `worker_cache`.
    std::vector<string> workers;
    worker_cache->ListWorkers(&workers);
    std::swap(workers, targets_);
  } else {
    // When applying filters, we must include the local worker, even if it
    // does not match any of the filters.
    CHECK_GT(env_->local_devices.size(), 0) << "No local devices provided.";
    const string& local_device_name = env_->local_devices[0]->name();
    DeviceNameUtils::ParsedName local_parsed_name;
    CHECK(DeviceNameUtils::ParseFullName(local_device_name,
                                         &local_parsed_name));
    bool all_filters_have_job = true;
    std::unordered_set<string> filter_job_names({local_parsed_name.job});
    for (const DeviceNameUtils::ParsedName& filter : filters_) {
      all_filters_have_job = all_filters_have_job && filter.has_job;
      if (filter.has_job) {
        filter_job_names.insert(filter.job);
      }
    }

    std::vector<string> workers;
    if (all_filters_have_job) {
      // If all of the device filters have a job specified, then we only need
      // to list the workers in the jobs named in the filter, because a worker
      // in any other job would not match any filter.
      for (const string& job_name : filter_job_names) {
        VLOG(2) << "Selectively listing workers in job: " << job_name;
        std::vector<string> workers_in_job;
        worker_cache->ListWorkersInJob(job_name, &workers_in_job);
        workers.insert(workers.end(), workers_in_job.begin(),
                       workers_in_job.end());
      }
    } else {
      // If any of the device filters does not have a job specified, then we
      // must list the workers from all jobs.
      VLOG(2) << "Listing workers in all jobs because some device "
              << "filter has no job specified. Filters were:";
      if (device_filters.empty()) {
        VLOG(2) << "- <NO FILTERS>";
      } else {
        for (const string& filter : device_filters) {
          VLOG(2) << "- " << filter;
        }
      }
      worker_cache->ListWorkers(&workers);
    }
    for (const string& name : workers) {
      if (MatchFilters(name) ||
          DeviceNameUtils::IsSameAddressSpace(name, local_device_name)) {
        targets_.push_back(name);
      }
    }
  }
  seen_targets_.assign(targets_.size(), false);
}

4.1.3 GetRemoteDevices

The GetRemoteDevices method retrieves the remote devices as follows:

  • finder.Start() broadcasts a GetStatusRequest to every worker in the cluster.
  • finder.Wait() collects the GetStatusResponse messages returned by all workers.
  • finder.GetRemoteDevices retrieves the query results and returns them to the caller.
static Status GetRemoteDevices(
    const protobuf::RepeatedPtrField<string>& device_filters, MasterEnv* env,
    WorkerCacheInterface* worker_cache,
    std::vector<std::unique_ptr<Device>>* out_remote) {
  DeviceFinder finder(device_filters, env, worker_cache);
  finder.Start();
  TF_RETURN_IF_ERROR(finder.Wait());
  finder.GetRemoteDevices(env->local_devices, out_remote);
  return Status::OK();
}

4.1.3.1 Start

The Start method initializes the counter num_pending_ as the number of workers and then iterates through the workers, calling NewRemoteDevices one by one.

void Start() {
  {
    mutex_lock l(mu_);
    num_pending_ = targets_.size();
    if (num_pending_ == 0) {
      pending_zero_.notify_all();
    }
  }
  // Talk to all workers to get the list of available devices.
  using std::placeholders::_1;
  using std::placeholders::_2;
  for (size_t i = 0; i < targets_.size(); ++i) {
    // TODO(mrry): Propagate a timeout here, since `this->WhenFound()` may
    // never be called.
    NewRemoteDevices(env_->env, worker_cache_, targets_[i],
                     std::bind(&ME::WhenFound, this, i, _1, _2));
  }
}

The NewRemoteDevices logic is as follows:

  • Call GetOrCreateWorker with worker_name to find the WorkerInterface object in worker_cache.
  • Build a GetStatusRequest and send it to the found Worker via GetStatusAsync.
  • When the Worker returns a GetStatusResponse, the callback cb (the WhenFound method) is invoked to collect the Worker's device information, and worker_name is added to the device names obtained.
void NewRemoteDevices(Env* env, WorkerCacheInterface* worker_cache,
                      const string& worker_name, NewRemoteDevicesDone done) {
  WorkerInterface* wi = worker_cache->GetOrCreateWorker(worker_name);
  if (wi == nullptr) {
    std::vector<Device*> empty;
    done(errors::NotFound("Device ", worker_name, " is not found."), &empty);
    return;
  }
  struct Call {
    GetStatusRequest req;    // Request message
    GetStatusResponse resp;  // Response message
  };
  Call* call = new Call;
  // The callback function.
  auto cb = [env, worker_cache, worker_name, done, wi,
             call](const Status& status) {
    Status s = status;
    std::vector<Device*> remote_devices;
    auto cleanup = gtl::MakeCleanup(
        [&worker_cache, &worker_name, &wi, &done, &remote_devices, &s, call] {
          worker_cache->ReleaseWorker(worker_name, wi);
          done(s, &remote_devices);
          delete call;
        });
    if (s.ok()) {
      DeviceNameUtils::ParsedName worker_name_parsed;
      if (!DeviceNameUtils::ParseFullName(worker_name, &worker_name_parsed) ||
          !worker_name_parsed.has_job || !worker_name_parsed.has_replica ||
          !worker_name_parsed.has_task) {
        s = errors::InvalidArgument("Could not parse worker name: ",
                                    worker_name);
        return;
      }
      remote_devices.reserve(call->resp.device_attributes_size());
      for (const DeviceAttributes& da : call->resp.device_attributes()) {
        DeviceNameUtils::ParsedName device_name_parsed;
        CHECK(DeviceNameUtils::ParseFullName(da.name(), &device_name_parsed))
            << "Device attribute name '" << da.name() << "' could not be "
            << "parsed. Device Attribute: " << da.DebugString();
        // Preserve the exact name, if possible.
        if (device_name_parsed.job == worker_name_parsed.job &&
            device_name_parsed.replica == worker_name_parsed.replica &&
            device_name_parsed.task == worker_name_parsed.task) {
          auto d = new RemoteDevice(env, da);
          remote_devices.push_back(d);
        } else {
          DeviceAttributes da_rewritten = da;
          da_rewritten.set_name(DeviceNameUtils::FullName(
              worker_name_parsed.job, worker_name_parsed.replica,
              worker_name_parsed.task, device_name_parsed.type,
              device_name_parsed.id));
          auto d = new RemoteDevice(env, da_rewritten);

          // Experimental: Skipping over adding any TPU-type devices that
          // aren't on the job called "worker" (but still adds the CPUs of
          // other jobs).
          if (getenv("TPU_NO_POPULATE_DEVICE_LIST_FROM_CLUSTER_SPEC") !=
              nullptr) {
            if (worker_name_parsed.job == "worker" ||
                device_name_parsed.type.find("TPU") == std::string::npos) {
              remote_devices.push_back(d);
            }
          } else {
            remote_devices.push_back(d);
          }
        }
      }
    }
  };
  wi->GetStatusAsync(/*opts=*/nullptr, &call->req, &call->resp,
                     /*fail_fast=*/false, cb);
}

4.1.3.2 Wait

In the Wait method, while the counter is not zero, pending_zero_.wait_for is called with a timeout so that the main thread wakes up periodically (every 10 seconds) and logs the workers that have not responded yet.

Status Wait() {
  mutex_lock l(mu_);
  // TODO(mrry): Propagate a timeout here, since `num_pending_` may
  // never become zero.
  while (num_pending_ != 0) {
    pending_zero_.wait_for(l, std::chrono::milliseconds(kLoggingPeriodMs));
    if (num_pending_ != 0) {
      for (size_t i = 0; i < targets_.size(); ++i) {
        if (!seen_targets_[i]) {
          LOG(INFO)
              << "CreateSession still waiting for response from worker: "
              << targets_[i];
        }
      }
    }
  }
  return status_;
}

4.1.3.3 Callback functions

The callback passed by Start is shown below; it is invoked whenever a Worker's GetStatusResponse is received. WhenFound decrements the counter by 1; when the counter reaches 0, pending_zero_.notify_all() is called, which wakes up pending_zero_.wait_for in Wait. The GetRemoteDevices method then uses finder.GetRemoteDevices to retrieve the query results and return them to the caller.

void WhenFound(int target_index, const Status& s,
               std::vector<Device*>* devices) {
  mutex_lock l(mu_);
  seen_targets_[target_index] = true;
  if (!s.ok()) {
    LOG(ERROR) << "CreateSession failed because worker "
               << targets_[target_index] << " returned error: " << s;
    status_.Update(s);
  } else {
    found_.insert(found_.end(), devices->begin(), devices->end());
    devices->clear();
  }
  --num_pending_;
  if (num_pending_ == 0) {
    pending_zero_.notify_all();
  }
}

4.2 Worker interaction

NewRemoteDevices constructs a GetStatusRequest and sends it to the found Worker via GetStatusAsync.

WorkerInterface* wi = worker_cache->GetOrCreateWorker(worker_name);
wi->GetStatusAsync(/*opts=*/nullptr, &call->req, &call->resp,
                     /*fail_fast=*/false, cb);

4.2.1 GrpcRemoteWorker

Here wi is the found WorkerInterface, which is actually a GrpcRemoteWorker. It is the gRPC client, and it invokes the remote WorkerService through a stub.

void GetStatusAsync(CallOptions* call_opts, const GetStatusRequest* request,
                    GetStatusResponse* response, bool fail_fast,
                    StatusCallback done) override {
  IssueRequest(request, response, getstatus_, std::move(done), call_opts,
               fail_fast);
}

4.2.2 GrpcWorkerService

On the remote Worker side, the message is received by GrpcWorkerService. When a GetStatusRequest arrives, it is handled by GetStatusHandler, which is generated by a macro.

#define HANDLE_CALL(method, may_block_on_compute_pool)                        \
  void method##Handler(WorkerCall<method##Request, method##Response>* call) { \
    auto closure = [this, call]() {                                           \
      Status s = worker_->method(&call->request, &call->response);            \
      if (!s.ok()) {                                                          \
        VLOG(3) << "Bad response from " << #method << ": " << s;              \
      }                                                                       \
      call->SendResponse(ToGrpcStatus(s));                                    \
    };                                                                        \
    if ((may_block_on_compute_pool)) {                                        \
      worker_->env()->env->SchedClosure(std::move(closure));                  \
    } else {                                                                  \
      worker_->env()->compute_pool->Schedule(std::move(closure));             \
    }                                                                         \
    ENQUEUE_REQUEST(method, false);                                           \
  }

  HANDLE_CALL(GetStatus, false);
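For clarity, here is roughly what HANDLE_CALL(GetStatus, false) expands to, with whitespace added; this is a sketch, not a verbatim preprocessor dump:

void GetStatusHandler(WorkerCall<GetStatusRequest, GetStatusResponse>* call) {
  auto closure = [this, call]() {
    Status s = worker_->GetStatus(&call->request, &call->response);
    if (!s.ok()) {
      VLOG(3) << "Bad response from " << "GetStatus" << ": " << s;
    }
    call->SendResponse(ToGrpcStatus(s));
  };
  // may_block_on_compute_pool is false, so the closure is scheduled on the
  // compute thread pool rather than via SchedClosure.
  worker_->env()->compute_pool->Schedule(std::move(closure));
  ENQUEUE_REQUEST(GetStatus, false);
}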

4.2.3 Worker

Finally we come to the Worker class, which simply delegates the work to DeviceMgr and eventually returns the result to the remote caller in a GetStatusResponse message.

void Worker::GetStatusAsync(CallOptions* opts, const GetStatusRequest* request,
                            GetStatusResponse* response, bool fail_fast,
                            StatusCallback done) {
  const DeviceMgr* dm = env_->device_mgr;
  std::vector<DeviceAttributes> devices;
  dm->ListDeviceAttributes(&devices);
  response->mutable_device_attributes()->Reserve(devices.size());
  for (auto& d : devices) {
    response->add_device_attributes()->Swap(&d);
  }
  done(Status::OK());
}

4.2.4 DeviceMgr

ListDeviceAttributes, which aggregates local device information, has two implementations. Implementation 1 (StaticDeviceMgr) is as follows.

void StaticDeviceMgr::ListDeviceAttributes(
    std::vector<DeviceAttributes>* devices) const {
  devices->reserve(devices_.size());
  for (const auto& dev : devices_) {
    devices->emplace_back(dev->attributes());
  }
}

Implementation 2 (DynamicDeviceMgr) is as follows:

void DynamicDeviceMgr::ListDeviceAttributes(
    std::vector<DeviceAttributes>* devices) const {
  tf_shared_lock l(devices_mu_);
  devices->reserve(dynamic_devices_.size());
  for (const auto& d : dynamic_devices_) {
    devices->emplace_back(d->attributes());
  }
}

Now that we have analyzed the cache and the device-set lookup, let's look next at how the business logic is handled.

0xEE Personal information

★★★★ Thoughts on life and technology ★★★★★

Wechat official account: Rosie’s Thoughts
