In TensorFlow, distributed variables are variables created on multiple devices; MirroredVariable and SyncOnReadVariable are two examples. This article analyzes distributed variables, guided by a series of questions:

  • How does Strategy invoke variable creation?
  • How is a MirroredVariable created?
  • How are tensors distributed to devices?
  • How is a unified view presented to the outside world?
  • How are the variables kept consistent?

As before, I recommend two excellent resources:

[TensorFlow Internals] (github.com/horance-liu… If you are interested in how TensorFlow is implemented internally, it is well worth reading, and you will definitely learn a lot. Home.cnblogs.com/u/deep-lear… covers not just TensorFlow but also many other areas at the forefront of the industry.

The other articles in this series are:

Heterogeneous Distributed Learning based on TensorFlow (distributed paper) [Translation]

Implementation of Control Flow in TensorFlow

TensorFlow distributed environment (1): overall architecture

TensorFlow distributed environment (2): Master static logic

TensorFlow distributed environment (3): Worker static logic

TensorFlow distributed environment (4): WorkerCache

TensorFlow distributed environment (5): Session

TensorFlow distributed environment (6): Master dynamic logic

TensorFlow distributed environment (7): Worker dynamic logic

TensorFlow distributed environment (8): communication mechanism

Distributed training using TensorFlow

TensorFlow DistributedStrategy basics

1. MirroredVariable

tf.distribute.MirroredStrategy supports synchronous distributed training on multiple GPUs on a single machine. The strategy creates one replica per GPU device, and each variable in the model is mirrored across all replicas. Together, these copies form a single conceptual variable called a MirroredVariable. They are kept in sync with each other by applying identical updates.

Figure 1 MirroredVariable

Here is a concrete example:

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
# Variable created inside scope:
with strategy.scope():
  mirrored_variable = tf.Variable(1.)

# Variable created outside scope:
regular_variable = tf.Variable(1.)

The printed output is:

>>> mirrored_variable
MirroredVariable:{
  0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>,
  1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=1.0>
}
>>> regular_variable
<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>

See also the example in tensorflow/python/module/module_test.py:

def test_supports_distributed_variables(self):
  mirrored = distributed_values.MirroredVariable(
      None, [variables.Variable(1.)], variables.VariableAggregation.SUM)
  tpu = tpu_values.TPUMirroredVariable(
      strategy=None, values=[variables.Variable(42.)], aggregation=None)
  aggregating = ps_values.AggregatingVariable(
      strategy=None, v=variables.Variable(1.), aggregation=None)

  m = module.Module()
  m.a = mirrored

1.1 Definition

As its docstring states, MirroredVariable holds a map from replicas to variables whose values are kept in sync. It adds no new member variables, only a few member functions.

class MirroredVariable(DistributedVariable, Mirrored):
  """Holds a map from replica to variables whose values are kept in sync."""

  def _update_replica(self, update_fn, value, **kwargs):
    return _on_write_update_replica(self, update_fn, value, **kwargs)

  def scatter_min(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_min(*args, **kwargs)
    return super(MirroredVariable, self).scatter_min(*args, **kwargs)

  def scatter_max(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_max(*args, **kwargs)
    return super(MirroredVariable, self).scatter_max(*args, **kwargs)

  def scatter_update(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():  # non-distributed saving
      # operate on the primary component directly
      return self._primary.scatter_update(*args, **kwargs)
    # otherwise, perform the distributed update via the base class
    return super(MirroredVariable, self).scatter_update(*args, **kwargs)

  def _get_cross_replica(self):
    # Return identity, to avoid directly exposing the variable to the user and
    # allowing it to be modified by mistake.
    return array_ops.identity(Mirrored._get_cross_replica(self))

Take scatter_update as an example. When is_saving_non_distributed() is true, it calls the method on _primary directly; otherwise it calls the base-class method. In addition, _update_replica calls _on_write_update_replica to synchronize the replicas, and _on_write_update_replica chooses how to update based on the current context. It is defined in tensorflow/python/distribute/values.py.

def _on_write_update_replica(var, update_fn, value, **kwargs):
  """Updates variables with ON_WRITE synchronization in replica context."""
  if var.aggregation == vs.VariableAggregation.NONE:
    # No aggregation: update only this replica's component.
    return update_fn(var._get_on_device_or_primary(), value, **kwargs)

  if not ds_context.get_strategy().extended._use_merge_call():
    # Aggregate in replica context, then update all components as a group.
    aggregated_value = apply_aggregation_replica_context(
        value, var.aggregation, var)
    values_util.mark_as_unsaveable()

    return ds_context.get_replica_context()._update(
        var,
        update_fn,
        args=(aggregated_value,),
        kwargs=kwargs,
        group=True)

  else:

    def merge_fn(strategy, value, **kwargs):
      """Aggregate values and update all variables in cross replica context."""
      v = values_util.apply_aggregation(strategy, value, var.aggregation, var)
      return var._update_cross_replica(update_fn, v, **kwargs)

    return ds_context.get_replica_context().merge_call(
        merge_fn, args=(value,), kwargs=kwargs)
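Apart from the NONE case, both paths come down to one rule: aggregate the values proposed by the replicas, then apply the same update to every component. The following TensorFlow-free sketch models one synchronous step under that simplification. It collapses the two aggregating branches into one, and ToyMirroredVar, apply_aggregation, on_write_update and assign_add are hypothetical names, not TensorFlow APIs.

```python
from dataclasses import dataclass
from typing import Callable, List


@dataclass
class ToyMirroredVar:
    """Hypothetical stand-in for a MirroredVariable: one float per replica."""
    values: List[float]
    aggregation: str = "SUM"  # "NONE", "SUM" or "MEAN"


def apply_aggregation(per_replica: List[float], aggregation: str) -> float:
    """Combine the values proposed by all replicas into a single value."""
    total = sum(per_replica)
    if aggregation == "SUM":
        return total
    if aggregation == "MEAN":
        return total / len(per_replica)
    raise ValueError(f"unsupported aggregation: {aggregation}")


def on_write_update(var: ToyMirroredVar,
                    update_fn: Callable[[float, float], float],
                    per_replica_values: List[float]) -> None:
    """One synchronous step in which replica i proposes per_replica_values[i]."""
    if var.aggregation == "NONE":
        # Like the NONE branch: each replica updates only its own component.
        for i, value in enumerate(per_replica_values):
            var.values[i] = update_fn(var.values[i], value)
        return
    # Like the aggregating branches: aggregate first, then apply the same
    # update to every component, which keeps the copies identical.
    aggregated = apply_aggregation(per_replica_values, var.aggregation)
    var.values = [update_fn(v, aggregated) for v in var.values]


def assign_add(current: float, delta: float) -> float:
    return current + delta


v = ToyMirroredVar(values=[1.0, 1.0], aggregation="MEAN")
on_write_update(v, assign_add, [0.5, 1.5])
print(v.values)  # [2.0, 2.0]
```

The NONE branch in the sketch shows why real mirrored variables normally use an aggregation. Without one, replicas that receive different inputs drift apart.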

These member methods alone do not give a real sense of MirroredVariable. To understand it, we need to look at its class hierarchy.

1.2 Related classes

1.2.1 Class hierarchy

The class hierarchy of MirroredVariable is shown below. We analyze each class in turn and then put the pieces together.

Figure 2 MirroredVariable's class hierarchy

1.2.2 DistributedValues

Let’s look at DistributedValues first.

Figure 3 DistributedValues

Distributed values are represented by the base class tf.distribute.DistributedValues. Conceptually, tf.distribute.DistributedValues represents a value spread across multiple devices, and it contains a mapping from replica ID to value.

tf.distribute.DistributedValues contains one value per replica. Depending on the subclass, these values may be synced on update, synced on demand, or never synced. A tf.distribute.DistributedValues can be reduced to obtain a single value across replicas, passed as input to tf.distribute.Strategy.run, or have its per-replica values inspected with tf.distribute.Strategy.experimental_local_results.

As a base class, DistributedValues should never be instantiated directly. Instances of its subclasses are created when variables are created within a distribution strategy, when iterating a tf.distribute.DistributedDataset, or through tf.distribute.Strategy.run.

Two representative kinds of tf.distribute.DistributedValues are "PerReplica" and "Mirrored" values.

  • "PerReplica" values exist on the worker devices, with a different value for each replica. They are produced by iterating the distributed datasets returned by tf.distribute.Strategy.experimental_distribute_dataset and tf.distribute.Strategy.distribute_datasets_from_function. They are also the typical result returned by tf.distribute.Strategy.run.

  • "Mirrored" values are like "PerReplica" values, except that the value on every replica is the same. A "Mirrored" value can therefore be read safely in cross-replica context by using the value on any replica.

Definition

The two most important DistributedValues members are _values and _primary. The constructor stores the values passed in as the _values tuple, and _primary is a property that returns the first element of _values.

Derived classes rely on several DistributedValues member functions, so we look at them here:

  • _get_on_device_or_primary: returns the value for the current replica. Outside a replica, it returns the value on the current device, falling back to _primary.
  • _get_cross_replica: returns the cross-replica value. Its implementation is left to derived classes.
  • _get: if there is no current replica ID (cross-replica context), it calls _get_cross_replica to return the cross-replica value. Otherwise it returns the local value of the current replica.

The concept map is as follows:

Figure 4 DistributedValues

The code for DistributedValues is as follows:

@tf_export("distribute.DistributedValues", v1=[])
class DistributedValues(object):
  """Base class for representing distributed values. A subclass instance of tf.distribute.DistributedValues is created when creating variables within a distribution strategy, iterating a tf.distribute.DistributedDataset or through tf.distribute.Strategy.run . This base class should never be instantiated directly. tf.distribute.DistributedValues contains a value per replica. Depending on the subclass, the values could either be synced on update, synced on demand, or never synced. tf.distribute.DistributedValues can be reduced to obtain single value across replicas, as input into tf.distribute.Strategy.run or the per-replica values inspected using tf.distribute.Strategy.experimental_local_results . """

  def __init__(self, values):
    """Should only be called by subclass __init__."""
    self._values = tuple(values)

  def _get(self):
    """Returns the value for the current device or raises a ValueError."""
    replica_id = values_util.get_current_replica_id_as_int()
    if replica_id is None:
      return self._get_cross_replica()  # return the cross-replica value
    else:
      return self._values[replica_id]  # return this replica's local value

  def _get_cross_replica(self):
    raise NotImplementedError(
        "DistributedValues._get_cross_replica should be implemented by "
        "sub-classes which support cross-replica accesses.")

  def _get_on_device_or_primary(self):
    """Returns value in same replica or device if possible, else the _primary."""
    # Get the current replica ID
    replica_id = values_util.get_current_replica_id_as_int()
    if replica_id is None:  # no replica ID: look at the devices on this machine
      # Try to find a value on the current device.
      # current_device is a canonicalized device string
      current_device = device_util.canonicalize(device_util.current())
      for value in self._values:  # iterate over the components
        if device_util.canonicalize(value.device) == current_device:
          return value  # found a component on the current device
      return self._primary  # fall back to _primary
    else:
      # return the value of this replica
      return self._values[replica_id]

  @property
  def _primary(self):
    """Returns a representative component."""
    return self._values[0]

  @property
  def _devices(self):
    return tuple(v.device for v in self._values)
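The following simplified pure-Python model shows how _get, _primary and _get_on_device_or_primary interact. It assumes a single global object, CTX, that stands in for TensorFlow's replica context and device scope, and it compares plain device strings instead of canonicalized ones. CTX, Component and ToyDistributedValues are hypothetical names.

```python
from typing import Optional, Sequence


class _Context:
    """Hypothetical stand-in for TensorFlow's replica context and device scope."""
    replica_id: Optional[int] = None  # None means cross-replica context
    device: str = "/cpu:0"


CTX = _Context()


class Component:
    def __init__(self, value, device):
        self.value = value
        self.device = device


class ToyDistributedValues:
    def __init__(self, values: Sequence[Component]):
        self._values = tuple(values)

    @property
    def _primary(self):
        # The representative component is simply the first one.
        return self._values[0]

    def _get_cross_replica(self):
        # Left to subclasses, as in DistributedValues.
        raise NotImplementedError("subclasses decide what a cross-replica read means")

    def _get(self):
        if CTX.replica_id is None:
            return self._get_cross_replica()
        return self._values[CTX.replica_id]

    def _get_on_device_or_primary(self):
        if CTX.replica_id is not None:
            return self._values[CTX.replica_id]
        # Outside a replica: prefer the component placed on the current device.
        for component in self._values:
            if component.device == CTX.device:
                return component
        return self._primary


dv = ToyDistributedValues([Component(10, "/gpu:0"), Component(20, "/gpu:1")])
CTX.replica_id = 1
print(dv._get().value)                       # 20
CTX.replica_id, CTX.device = None, "/gpu:1"
print(dv._get_on_device_or_primary().value)  # 20
CTX.device = "/cpu:0"
print(dv._get_on_device_or_primary().value)  # 10, falls back to _primary
```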

Much of the code above uses get_current_replica_id_as_int. This function is defined in tensorflow/python/distribute/values_util.py and returns the ID of the current replica.

def get_current_replica_id_as_int():
  """Returns the current replica ID as an integer, or None."""
  replica_context = ds_context.get_replica_context()
  if replica_context:
    replica_id = replica_context._replica_id
    if not isinstance(replica_id, int):
      replica_id = tensor_util.constant_value(replica_id)
  else:
    replica_id = distribute_lib.get_update_replica_id()
  return replica_id

Usage

Below are some examples from the source code. All of them obtain DistributedValues with MirroredStrategy.

import tensorflow as tf

# 1. Created from a tf.distribute.DistributedDataset:
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset))
distributed_values = next(dataset_iterator)

# 2. Returned by run:
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
@tf.function
def run():
  ctx = tf.distribute.get_replica_context()
  return ctx.replica_id_in_sync_group
distributed_values = strategy.run(run)

# 3. As input into run:
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset))
distributed_values = next(dataset_iterator)
@tf.function
def run(input):
  return input + 1.0
updated_value = strategy.run(run, args=(distributed_values,))

# 4. Reduce value:
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset))
distributed_values = next(dataset_iterator)
reduced_value = strategy.reduce(tf.distribute.ReduceOp.SUM,
                                distributed_values,
                                axis=0)

# 5. Inspect local replica values:
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset))
distributed_values = next(dataset_iterator)
per_replica_values = strategy.experimental_local_results(distributed_values)
print(per_replica_values)

# output:
# (<tf.Tensor: shape=(1,), dtype=float32, numpy=array([5.], dtype=float32)>,
#  <tf.Tensor: shape=(1,), dtype=float32, numpy=array([6.], dtype=float32)>)
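The arithmetic behind examples 4 and 5 can be checked without TensorFlow. This sketch assumes that the first global batch [5., 6.] is split evenly and in order across the two replicas, so each replica sees a batch of one element. split_batch and reduce_sum_axis0 are illustrative helpers, not TensorFlow functions, and the real input pipeline handles more cases, such as batches that do not divide evenly.

```python
def split_batch(global_batch, num_replicas):
    """Split one global batch into equal, ordered per-replica slices."""
    per_replica = len(global_batch) // num_replicas
    return [global_batch[i * per_replica:(i + 1) * per_replica]
            for i in range(num_replicas)]


def reduce_sum_axis0(per_replica_values):
    # ReduceOp.SUM with axis=0: sum over the batch dimension inside each
    # replica, then sum the per-replica results.
    return sum(sum(values) for values in per_replica_values)


dataset = [5.0, 6.0, 7.0, 8.0]
first_global_batch = dataset[:2]        # what .batch(2) yields first
local_results = split_batch(first_global_batch, num_replicas=2)
print(local_results)                    # [[5.0], [6.0]]
print(reduce_sum_axis0(local_results))  # 11.0
```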

1.2.3 DistributedDelegate

Next, let’s look at the DistributedDelegate.

Figure 5 DistributedDelegate

DistributedDelegate adds computation capabilities on top of DistributedValues. Its _get_as_operand method calls the base-class method DistributedValues._get to obtain the value, and that value is then used in the computation.

Figure 6 how to calculate

The DistributedDelegate is defined as follows, with some code omitted.

class DistributedDelegate(DistributedValues):
  """A map from device to values; acts as the same type as the values."""

  def __getattr__(self, name):
    # The '_use_resource_variables' and the attrs starts with '_self' are used
    # for restoring the saved_model proto, and '_attribute_sentinel' is used for
    # Layer tracking. At the point these attrs are queried, the variable has not
    # been initialized. Thus it should not query those of the underlying
    # components.
    if name.startswith("_self_") or name in ("_use_resource_variables",
                                             "_attribute_sentinel",
                                             "_distributed_container"):
      return super(DistributedDelegate, self).__getattr__(name)

    # This allows copy.copy(DistributedDelegate). When copying an object,
    # copy.copy doesn't invoke its __init__ method, instead it makes a new
    # empty object, then copies the attributes over. copy.copy looks for
    # attributes like "__getstate__" in case the object implements its custom
    # copying. Since DistributedDelegate doesn't have those attributes defined,
    # __getattr__ will be invoked, which tries to access "_values" attributes,
    # but that doesn't exist either because this is an empty object, and again
    # __getattr__ is invoked, leading to an infinite recursion.
    if name == "_values":
      raise AttributeError()

    # TODO(priyag): This needs to be made robust against pitfalls from mix use
    # __getattr__ and @property. See b/120402273.
    return getattr(self._get(), name)

  @property
  def values(self):
    """Returns the per replica values."""
    return self._values

  def _get_as_operand(self):
    """Returns the value for operations for the current device. Some implementations, e.g. TPUMirroredVariable , are not able to return the value type within a replica context. They can, however, return a value that can be used by the operations below. """
    return self._get()

  def __add__(self, o):
    return self._get_as_operand() + o

  def __radd__(self, o):
    return o + self._get_as_operand()

  def __sub__(self, o):
    return self._get_as_operand() - o

  def __rsub__(self, o):
    return o - self._get_as_operand()

  # omit most of the code
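The delegation trick is easier to see in isolation. The sketch below defines a hypothetical ToyDelegate, not TensorFlow code. It forwards arithmetic and unknown attributes to the current replica's value, which it selects through an assumed CTX object. It also keeps the "_values" guard in __getattr__, so copy.copy does not recurse forever.

```python
import copy


class _Context:
    """Hypothetical stand-in for the current replica context."""
    replica_id = 0


CTX = _Context()


class ToyDelegate:
    def __init__(self, values):
        self._values = tuple(values)

    def _get(self):
        # Value belonging to the current replica.
        return self._values[CTX.replica_id]

    def _get_as_operand(self):
        return self._get()

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails. Raising for
        # "_values" prevents infinite recursion on an object created without
        # __init__, which is exactly what copy.copy does.
        if name == "_values":
            raise AttributeError(name)
        return getattr(self._get(), name)

    def __add__(self, other):
        return self._get_as_operand() + other

    def __radd__(self, other):
        return other + self._get_as_operand()

    def __sub__(self, other):
        return self._get_as_operand() - other

    def __rsub__(self, other):
        return other - self._get_as_operand()


d = ToyDelegate([1.5, 2.5])
CTX.replica_id = 1
print(d + 1, 10 - d)   # 3.5 7.5
print(d.is_integer())  # False, forwarded to the float 2.5
clone = copy.copy(d)   # works thanks to the "_values" guard
print(clone._values)   # (1.5, 2.5)
```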

1.2.4 PerReplica

PerReplica holds a map from replicas to unsynchronized values.

class PerReplica(DistributedValues, composite_tensor.CompositeTensor):
  """Holds a map from replica to unsynchronized values."""

  @property
  def _type_spec(self):
    return PerReplicaSpec(
        *(type_spec.type_spec_from_value(v) for v in self._values))

  @property
  def values(self):
    """Returns the per replica values."""
    return self._values

1.2.5 Mirrored

Next comes Mirrored.

Figure 7 Mirrored

Mirrored represents variables that are created on multiple devices and kept in sync by applying the same updates to every replica. Mirrored variables are created with tf.Variable(…, synchronization=tf.VariableSynchronization.ON_WRITE, …). Normally they are used only in synchronous training.

Recall that DistributedValues holds a mapping from replicas to values but leaves _get_cross_replica unimplemented. Mirrored values are meant to be usable in cross-replica mode, so Mirrored implements _get_cross_replica. It calls the base-class method DistributedValues._get_on_device_or_primary (see the DistributedValues section), which returns the value on the current replica or device, or otherwise the value of _primary.

The concept map is as follows:

Figure 8 How does Mirrored compute

Mirrored is defined as follows:

# Note that unlike PerReplica, Mirrored values inherit from
# DistributedDelegate and so can be used directly in cross-replica mode.
class Mirrored(DistributedDelegate):
  """Holds a map from replica to values which are kept in sync."""

  def _get_cross_replica(self):
    # call the base-class method DistributedValues._get_on_device_or_primary
    return self._get_on_device_or_primary()

  def _as_graph_element(self):
    obj = self._get()  # call the base-class method DistributedValues._get
    conv_fn = getattr(obj, "_as_graph_element", None)
    if conv_fn and callable(conv_fn):
      return conv_fn()
    return obj
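The following sketch contrasts PerReplica and Mirrored in cross-replica mode and shows the duck-typed conversion in _as_graph_element. ToyPerReplica, ToyMirrored and GraphNode are hypothetical classes. The sketch simplifies the cross-replica read to "return the first component" and skips the device matching done by _get_on_device_or_primary.

```python
class ToyPerReplica:
    """Hypothetical stand-in for PerReplica: one unsynchronized value per replica."""

    def __init__(self, values):
        self._values = tuple(values)

    def _get_cross_replica(self):
        # Replicas may disagree, so there is no single value to hand out.
        raise NotImplementedError("PerReplica has no cross-replica value")


class ToyMirrored(ToyPerReplica):
    """Hypothetical stand-in for Mirrored: all replicas hold the same value."""

    def _get_cross_replica(self):
        # Any replica's value is valid; this sketch takes the first (primary).
        return self._values[0]

    def _as_graph_element(self):
        # Same duck typing as Mirrored._as_graph_element: convert the component
        # if it knows how, otherwise return it unchanged. (The real method goes
        # through self._get(), which also covers replica context.)
        obj = self._get_cross_replica()
        conv_fn = getattr(obj, "_as_graph_element", None)
        if conv_fn and callable(conv_fn):
            return conv_fn()
        return obj


class GraphNode:
    """Hypothetical component that can convert itself into a graph element."""

    def __init__(self, name):
        self.name = name

    def _as_graph_element(self):
        return f"tensor:{self.name}"


print(ToyMirrored([2.0, 2.0])._as_graph_element())                          # 2.0
print(ToyMirrored([GraphNode("v0"), GraphNode("v1")])._as_graph_element())  # tensor:v0
try:
    ToyPerReplica([1.0, 2.0])._get_cross_replica()
except NotImplementedError as err:
    print("PerReplica:", err)
```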

1.2.6 Policy

Let's look at the policies used by distributed variables.

Figure 9 distributed policy

VariablePolicy

VariablePolicy is the base class for variable policies, which define how a distributed variable is synchronized and aggregated. When a tf.Variable is created within a tf.distribute scope with given synchronization and aggregation parameters, tf.distribute creates an appropriate policy object and assigns it to the distributed variable. All variable operations are then delegated to that policy object.

class VariablePolicy(object):
  """Policy defining synchronization and aggregation of a distributed variable. Given synchronization and aggregation parameters set on a tf.Variable during variable creation within tf.distribute scope, tf.distribute creates an appropriate policy object and assigns it to the distributed variable. All variable operations are delegated to the respective policy object. """

  def __init__(self, aggregation):
    self._aggregation = aggregation

  def value(self):
    raise NotImplementedError(
        "VariablePolicy.value should be overriden by sub-classes.")

  def _is_mirrored(self):
    raise NotImplementedError(
        "VariablePolicy._is_mirrored should be overriden by sub-classes.")

  def _as_graph_element(self, _):
    raise NotImplementedError(
        "VariablePolicy._as_graph_element should be overriden by sub-classes.")

  def _get_cross_replica(self, var):
    raise NotImplementedError(
        "VariablePolicy._get_cross_replica should be overriden by sub-classes.")

  def _update_replica(self, var, update_fn, value, **kwargs):
    raise NotImplementedError(
        "VariablePolicy._update_replica should be overriden by sub-classes.")
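The policy pattern itself is small. The variable picks a policy object from its synchronization setting once, then forwards policy-dependent calls to it. This sketch uses hypothetical Toy* classes and a string-keyed mapping that is an assumption for illustration. Following the OnWritePolicy docstring, AUTO is treated like ON_WRITE, and only _is_mirrored is modeled.

```python
class ToyPolicy:
    """Hypothetical base policy: stores the aggregation, subclasses decide the rest."""

    def __init__(self, aggregation):
        self._aggregation = aggregation

    def _is_mirrored(self):
        raise NotImplementedError("ToyPolicy._is_mirrored should be overridden")


class ToyOnWritePolicy(ToyPolicy):
    def _is_mirrored(self):
        return True


class ToyOnReadPolicy(ToyPolicy):
    def _is_mirrored(self):
        return False


# Assumed mapping for this sketch only.
POLICY_FOR_SYNC = {
    "ON_WRITE": ToyOnWritePolicy,
    "AUTO": ToyOnWritePolicy,
    "ON_READ": ToyOnReadPolicy,
}


class ToyDistributedVariable:
    def __init__(self, values, synchronization, aggregation):
        self._values = list(values)
        # Choose the policy once, at creation time...
        self._policy = POLICY_FOR_SYNC[synchronization](aggregation)

    def _is_mirrored(self):
        # ...and delegate every policy-dependent question to it.
        return self._policy._is_mirrored()


for sync in ("ON_WRITE", "AUTO", "ON_READ"):
    var = ToyDistributedVariable([0.0, 0.0], sync, aggregation="SUM")
    print(sync, type(var._policy).__name__, var._is_mirrored())
```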
OnReadPolicy

OnReadPolicy implements ON_READ synchronization, where the per-replica values are aggregated when the variable is read. For example, its member function _get_cross_replica calls var.distribute_strategy.reduce to perform the read.

class OnReadPolicy(VariablePolicy):
  """Policy defined for tf.VariableSynchronization.ON_READ synchronization. This policy is created when synchronization is  set to tf.VariableSynchronization.ON_READ and aggregation is set to any of the values allowed by the tf.VariableAggregation enum such as NONE , SUM , MEAN or ONLY_FIRST_REPLICA when creating a tf.Variable in tf.distribute scope. """

  def _is_mirrored(self):
    return False

  def value(self, var):
    with ds_context.enter_or_assert_strategy(var.distribute_strategy):
      if (ds_context.in_cross_replica_context() and
          not values_util.in_replica_update_context()):
        if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
          return var._get_replica(0).value()
        return var._get_cross_replica()
      else:
        return var._get_on_device_or_primary().value()

  def _as_graph_element(self, var):
    with ds_context.enter_or_assert_strategy(var.distribute_strategy):
      if ds_context.in_cross_replica_context():
        return ops.convert_to_tensor(var._get_cross_replica())
    return var._get()._as_graph_element()

  def _get_cross_replica(self, var):
    if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
      return var._get_replica(0)  # read from the first replica
    if self._aggregation == vs.VariableAggregation.SUM:
      values_util.mark_as_unsaveable()  # mark the result as unsaveable
    with ds_context.enter_or_assert_strategy(var.distribute_strategy):
      # call distribute_strategy to perform the reduction
      return var.distribute_strategy.reduce(
          reduce_util.ReduceOp.from_variable_aggregation(self._aggregation),
          var,
          axis=None)

  def _update_replica(self, var, update_fn, value, **kwargs):
    return update_fn(var._get_on_device_or_primary(), value, **kwargs)

  def assign_add(self,
                 var,
                 value,
                 use_locking=False,
                 name=None,
                 read_value=True):
    """Adds a value to this variable."""
    with ds_context.enter_or_assert_strategy(var.distribute_strategy):
      if (ds_context.in_cross_replica_context() and
          not values_util.in_replica_update_context()):
        values_util.mark_as_unsaveable()
        return values_util.on_read_assign_add_cross_replica(
            var, value, read_value=read_value)
      else:
        return values_util.on_write_assign_add(
            var,
            value,
            use_locking=use_locking,
            name=name,
            read_value=read_value)

  def assign(self, var, value, use_locking=False, name=None, read_value=True):
    with ds_context.enter_or_assert_strategy(var.distribute_strategy):
      if (ds_context.in_cross_replica_context() and
          not values_util.in_replica_update_context()):
        values_util.mark_as_unsaveable()
        return values_util.on_read_assign_cross_replica(
            var, value, read_value=read_value)
      else:
        return values_util.on_write_assign(
            var,
            value,
            use_locking=use_locking,
            name=name,
            read_value=read_value)

  # omit most of the code
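For ON_READ variables, each replica keeps its own local value, and the aggregation only matters on a cross-replica read. The sketch below models that rule with plain Python numbers. on_read_cross_replica_value and on_read_value are hypothetical helpers. Real TensorFlow performs the reduction through var.distribute_strategy.reduce rather than Python's sum.

```python
def on_read_cross_replica_value(local_values, aggregation):
    """Hypothetical reduction of per-replica values on a cross-replica read."""
    if aggregation == "ONLY_FIRST_REPLICA":
        return local_values[0]  # read from the first replica only
    if aggregation == "SUM":
        return sum(local_values)
    if aggregation == "MEAN":
        return sum(local_values) / len(local_values)
    raise ValueError(f"no reduction defined for aggregation={aggregation!r} in this sketch")


def on_read_value(local_values, aggregation, replica_id=None):
    """replica_id=None models cross-replica context; otherwise replica context."""
    if replica_id is not None:
        # Inside a replica: no synchronization, just the local copy.
        return local_values[replica_id]
    return on_read_cross_replica_value(local_values, aggregation)


# For example, a per-replica counter on two replicas.
counts = [3, 5]
print(on_read_value(counts, "SUM"))                 # 8
print(on_read_value(counts, "MEAN"))                # 4.0
print(on_read_value(counts, "ONLY_FIRST_REPLICA"))  # 3
print(on_read_value(counts, "SUM", replica_id=1))   # 5
```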
OnWritePolicy

OnWritePolicy implements ON_WRITE synchronization, where the replicas are synchronized whenever the variable is written. Most of its methods work through var._get_on_device_or_primary(). For example, _get_cross_replica returns an identity of var._get_on_device_or_primary(). It also calls various basic operations in values_util.

class OnWritePolicy(VariablePolicy):
  """Policy defined for tf.VariableSynchronization.ON_WRITE synchronization. This policy is created when the following synchronization and aggregation parameters are specified when creating a tf.Variable in tf.distribute scope and synchronization is equal to tf.VariableSynchronization.ON_WRITE or tf.VariableSynchronization.AUTO . """

  def _is_mirrored(self):
    return True

  def value(self, var):
    return var._get_on_device_or_primary().value()

  def _as_graph_element(self, var):
    return var._get_on_device_or_primary()._as_graph_element()

  def _get_cross_replica(self, var):
    # Return identity, to avoid directly exposing the variable to the user and
    # allowing it to be modified by mistake.
    return array_ops.identity(var._get_on_device_or_primary())

  # call update_fn directly or _on_write_update_replica to perform the update
  def _update_replica(self, var, update_fn, value, **kwargs):
    if var.aggregation == variables_lib.VariableAggregation.NONE:
      return update_fn(var._get_on_device_or_primary(), value, **kwargs)
    return _on_write_update_replica(var, update_fn, value, **kwargs)

  def assign(self, var, value, use_locking=False, name=None, read_value=True):
    return values_util.on_write_assign(
        var, value, use_locking=use_locking, name=name, read_value=read_value)

  def assign_add(self,
                 var,
                 value,
                 use_locking=False,
                 name=None,
                 read_value=True):
    # delegate to values_util
    return values_util.on_write_assign_add(
        var, value, use_locking=use_locking, name=name, read_value=read_value)

  # discussed later
  def scatter_update(self, var, sparse_delta, use_locking=False, name=None):
    return values_util.scatter_update(
        var, sparse_delta, use_locking=use_locking, name=name)

  def get_saveable(self, var, primary_var, name):
    """Saveable ops for AUTO variables."""
    return values_util.get_on_write_saveable(var, primary_var, name)

  def get_restore_ops(self, var, tensor):
    return values_util.get_on_write_restore_ops(var, tensor)

  # omit most of the code
values_util

Both policies above use on_write_assign_add, which is defined in tensorflow/python/distribute/values_util.py.

def on_write_assign_add(var, value, use_locking=False, name=None,
                        read_value=True):
  assign_add_fn = lambda var, *a, **kw: var.assign_add(*a, **kw)
  return var._update(
      update_fn=assign_add_fn,
      value=value,
      use_locking=use_locking,
      name=name,
      read_value=read_value)

OnWritePolicy also uses the scatter_update defined in values_util, which likewise calls back into var._update.

def scatter_update(var, sparse_delta, use_locking=False, name=None):
  scatter_update_fn = lambda var, *a, **kw: var.scatter_update(*a, **kw)
  return var._update(
      update_fn=scatter_update_fn,
      value=sparse_delta,
      use_locking=use_locking,
      name=name)

1.2.7 DistributedVariable

Following the class hierarchy brings us to DistributedVariable, where most of MirroredVariable's key functionality actually lives.

Figure 10 DistributedVariable

DistributedVariable holds the mapping from replicas to variables. In the mirrored (ON_WRITE) case, self._policy is an OnWritePolicy, and updates to the variable are made through _policy.

class DistributedVariable(DistributedDelegate, variables_lib.Variable, core.Tensor):
  """Holds a map from replica to variables."""

  def __init__(self, strategy, values, aggregation, var_policy=None):
    if (aggregation == variables_lib.VariableAggregation.MEAN and
        not values[0].dtype.is_floating):
      raise ValueError(
          "creating distributed tf.Variable with aggregation=MEAN and a "
          "non-floating dtype is not supported, please use a different "
          "aggregation or dtype")
    self._distribute_strategy = strategy
    self._aggregation = aggregation
    super(DistributedVariable, self).__init__(values)
    self._common_name = self._primary.name.split(":")[0]
    # Use a weakref to make it easy to map from the contained values
    # to the container without introducing a reference cycle.
    for v in values:
      v._distributed_container = weakref.ref(self)  # pylint: disable=protected-access

    # Packed variable is used to reduce the overhead of function execution.
    # For a DistributedVariable, only one variable handle is captured into a
    # function graph. It's only supported in eager mode.
    if ops.executing_eagerly_outside_functions() and getattr(
        strategy, "_enable_packed_variable_in_eager_mode", False):
      name = "%s/packed/" % self._common_name
      self._packed_var = packed.PackedDistributedVariable(values, name=name)
    else:
      self._packed_var = None

    # tf.keras keeps track of variables initialized using this attribute. When
    # tf.keras gets the default session, it initializes all uninitialized vars.
    # We need to make _keras_initialized a member of DistributedVariable because
    # without this it will use __getattr__ which will delegate to a component
    # variable.
    self._keras_initialized = False
    # Typically, a DistributedVariable 's initializer is composed of the
    # initializers of the components variables. However, in some cases, such as
    # when restoring from a checkpoint, we may set the _initializer_op
    # property on the entire DistributedVariable .
    self._initializer_op = None
    # Set a VariablePolicy which decides how we replicate/aggregate the given
    # variable.
    self._policy = var_policy
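Two details of this constructor are worth isolating. One is the check that MEAN aggregation requires a floating-point first component. The other is the weak back-pointer from each component to its container. This hypothetical sketch uses plain Python objects in place of tf.Variable and an isinstance(..., float) check in place of dtype.is_floating.

```python
import weakref


class Component:
    """Hypothetical per-replica component holding a value and a back-pointer."""

    def __init__(self, value):
        self.value = value
        self._distributed_container = None


class ToyDistributedVariable:
    def __init__(self, components, aggregation):
        # Same idea as the constructor above: check only the first component.
        if aggregation == "MEAN" and not isinstance(components[0].value, float):
            raise ValueError("aggregation=MEAN requires a floating-point value")
        self._values = tuple(components)
        self._aggregation = aggregation
        for c in components:
            # A weak back-pointer lets a component find its container without
            # creating a cycle (container -> component -> container).
            c._distributed_container = weakref.ref(self)


parts = [Component(1.0), Component(1.0)]
var = ToyDistributedVariable(parts, "SUM")
print(parts[0]._distributed_container() is var)  # True
del var
print(parts[0]._distributed_container())         # None in CPython, container freed
```

Because the components never hold a strong reference to the container, dropping the last reference to the container frees it immediately under CPython's reference counting.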

How each operation is handled depends on the context, but ultimately the work is delegated to the strategy or to strategy.extended.

Read

On a cross-replica read, _get_cross_replica is called, which delegates to the policy. The policy may in turn call distribute_strategy to perform the reduction, as OnReadPolicy does.

def _get_cross_replica(self):
  if values_util.is_saving_non_distributed():
    return self._primary  # non-distributed saving: return the primary directly
  if self._policy:
    # delegate the cross-replica read to the policy
    return self._policy._get_cross_replica(self)

  raise NotImplementedError(
      "DistributedVariable._get_cross_replica requires a valid "
      "VariablePolicy. Please set the policy via the var_policy argument "
      "in the constructor, or override this method in sub-classes which "
      "support cross-replica accesses.")

Details are as follows:

Figure 11 DistributedVariable reading

scatter_update

scatter_update likewise calls _policy to perform the update.

def scatter_update(self, sparse_delta, use_locking=False, name=None):
  if values_util.is_saving_non_distributed():
    return self._primary.scatter_update(sparse_delta, use_locking, name)
  if self._policy:
    return self._policy.scatter_update(
        self, sparse_delta, use_locking=use_locking, name=name)
  return values_util.scatter_update(
      self, sparse_delta, use_locking=use_locking, name=name)

As seen in the OnWritePolicy section, scatter_update eventually calls back into DistributedVariable's own _update method.

def scatter_update(var, sparse_delta, use_locking=False, name=None):
  scatter_update_fn = lambda var, *a, **kw: var.scatter_update(*a, **kw)
  return var._update(
      update_fn=scatter_update_fn,
      value=sparse_delta,
      use_locking=use_locking,
      name=name)

var._update has several execution paths; we analyze only some of them.

def _update(self, update_fn, value, **kwargs) :
  """Applies updates depending on the context. The method calls _update_replica in replica context, _update_cross_replica in cross replica context, and update_fn in update context. If read_value is True, the method returns the updated Variable. If read_value is False, the method returns the update tf.Operation . Args: update_fn: A callable to pass to strategy.extended.update to update the variable. It should have the same signature as Variable.assign() . value: value to be passed to update_fn . **kwargs: keyword arguments to update_fn . Returns: Updated variable or tf.Operation . """
  if values_util.is_saving_non_distributed():
    return update_fn(self._primary, value, **kwargs) # non-distributed

  with ds_context.enter_or_assert_strategy(self.distribute_strategy):
    if ds_context.in_cross_replica_context():
      update_replica_id = distribute_lib.get_update_replica_id()
      if update_replica_id is not None:
        replica_value = self._get_replica(update_replica_id)
        return update_fn(replica_value, value, **kwargs)
      return self._update_cross_replica(update_fn, value, **kwargs)  # update across replicas
    else:
      values_util.assert_replica_context(self.distribute_strategy)
      return self._update_replica(update_fn, value, **kwargs)
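The branching in _update can be summarized as a small decision function. In this hypothetical sketch, the three boolean/int arguments stand in for values_util.is_saving_non_distributed(), ds_context.in_cross_replica_context() and distribute_lib.get_update_replica_id(); the returned strings only label which branch would run.

```python
# Hypothetical model of the branch selection in DistributedVariable._update.


def route_update(saving_non_distributed, in_cross_replica, update_replica_id):
    """Return a label for the branch _update would take."""
    if saving_non_distributed:
        return "update_fn(primary)"
    if in_cross_replica:
        if update_replica_id is not None:
            # Roughly the "update context": already inside an update for one replica.
            return "update_fn(replica %d)" % update_replica_id
        return "_update_cross_replica"
    return "_update_replica"


cases = [(True, False, None), (False, True, 1), (False, True, None), (False, False, None)]
for case in cases:
    print(case, "->", route_update(*case))
```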

_update_cross_replica is then called for cross-replica updates.

def _update_cross_replica(self, update_fn, value, **kwargs):
  """Applies updates across replicas. Args: update_fn: A callable to pass to strategy.extended.update to update the variable. It should have the same signature as Variable.assign(). value: value to be passed to update_fn. **kwargs: remaining arguments to update_fn. Returns: Updated variable or tf.Operation."""
  values_util.mark_as_unsaveable()
  return self.distribute_strategy.extended.update(
      self, update_fn, args=(value,), kwargs=kwargs, group=True)

We show it as follows:

Figure 12 DistributedVariable update

1.2.8 Saving

Next, let's look at how a MirroredVariable is saved. Its _saveable_factory uses _MirroredSaveable.

class MirroredVariable(DistributedVariable, Mirrored):

  def _gather_saveables_for_checkpoint(self):
    """Overrides Trackable method.

    This allows both name-based and object-based save and restore of
    MirroredVariables.

    Returns:
      A dictionary mapping attribute names to SaveableObject factories.
    """

    def _saveable_factory(name=self._common_name):
      return _MirroredSaveable(self, self._primary, name)

    return {trackable.VARIABLE_VALUE_KEY: _saveable_factory}

_MirroredSaveable defines how a MirroredVariable is saved and restored.

class _MirroredSaveable(saveable_object.SaveableObject):
  """Class for defining how to restore a MirroredVariable."""

  def __init__(self, mirrored_variable, primary_variable, name):
    self._mirrored_variable = mirrored_variable
    # Build the save tensor and SaveSpec for this on-write variable.
    tensor, spec = values_util.get_on_write_saveable(self._mirrored_variable,
                                                     primary_variable, name)
    super(_MirroredSaveable, self).__init__(tensor, spec, name)

  def restore(self, restored_tensors, restored_shapes):
    """Restore the same value into all variables."""
    tensor, = restored_tensors
    return values_util.get_on_write_restore_ops(self._mirrored_variable, tensor)
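The save/restore asymmetry is easy to miss: one value is written for the whole MirroredVariable, and on restore that single value is copied into every component. The sketch below models only that behavior; ToyMirrored, save and restore are hypothetical, and the CHECKPOINT_KEY string is an illustrative stand-in for trackable.VARIABLE_VALUE_KEY.

```python
# Hypothetical model of _MirroredSaveable semantics.

CHECKPOINT_KEY = "VARIABLE_VALUE"  # illustrative stand-in for trackable.VARIABLE_VALUE_KEY


class ToyMirrored:
    def __init__(self, num_replicas, value):
        self.components = [value] * num_replicas  # one copy per device

    @property
    def primary(self):
        return self.components[0]


def save(mirrored):
    # Like get_on_write_saveable: one tensor per MirroredVariable.
    return {CHECKPOINT_KEY: mirrored.primary}


def restore(mirrored, checkpoint):
    # Like _MirroredSaveable.restore: restore the same value into all components.
    tensor = checkpoint[CHECKPOINT_KEY]
    mirrored.components = [tensor] * len(mirrored.components)


src = ToyMirrored(num_replicas=3, value=1.5)
ckpt = save(src)
dst = ToyMirrored(num_replicas=3, value=0.0)
restore(dst, ckpt)
print(ckpt)            # {'VARIABLE_VALUE': 1.5}
print(dst.components)  # [1.5, 1.5, 1.5]
```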

get_on_write_saveable is defined as follows:

def get_on_write_saveable(var, primary_var, name):
  """Return saveable spec for AUTO and ON_WRITE variables."""
  # We use a callable so that we don't have to evaluate this expression
  # in the case where we are trying to restore instead of save.
  def tensor():
    if context.executing_eagerly() and not primary_var.is_initialized():
      # A SaveSpec tensor value of None indicates that the variable is
      # uninitialized.
      return None
    strategy = var.distribute_strategy
    return strategy.extended.read_var(var) # Get tensor

  spec = saveable_object.SaveSpec(
      tensor=tensor,
      slice_spec="",
      name=name,
      dtype=var.dtype,
      device=primary_var.device)

  return tensor, [spec]

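read_var in tensorflow/python/distribute/mirrored_strategy.py reads the value across replicas: for a SyncOnRead variable it returns the aggregated value from _get_cross_replica, and otherwise it returns an identity of the value obtained through _get().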

def read_var(self, replica_local_var):
  """Read the aggregate value of a replica-local variable."""
  if distribute_utils.is_sync_on_read(replica_local_var):
    return replica_local_var._get_cross_replica()
  return array_ops.identity(replica_local_var._get())
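The two branches of read_var can be modeled with plain numbers. This is a hypothetical sketch: only SUM and MEAN aggregation are modeled, and "components" stands for the per-replica values of one variable.

```python
# Hypothetical model of the two branches of read_var.


def read_var(components, sync_on_read, aggregation="SUM"):
    if sync_on_read:
        # Stands in for _get_cross_replica(): aggregate all replicas' values.
        total = sum(components)
        return total / len(components) if aggregation == "MEAN" else total
    # Stands in for identity(_get()): mirrored copies are identical, so one copy suffices.
    return components[0]


print(read_var([1.0, 2.0, 3.0], sync_on_read=True))                      # 6.0
print(read_var([1.0, 2.0, 3.0], sync_on_read=True, aggregation="MEAN"))  # 2.0
print(read_var([4.0, 4.0], sync_on_read=False))                          # 4.0
```

Saving a SyncOnRead variable therefore writes the aggregated value, while saving a mirrored variable writes just one of its identical copies.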

1.2.9 Summary

Based on the analysis above, the MirroredVariable inheritance hierarchy is shown below. Many of its capabilities are ultimately delegated to tf.distribute.Strategy.

Figure 13 MirroredVariable inheritance Annotations

1.3 Building Variables

Variables created under MirroredStrategy are MirroredVariables. If no devices are passed to the strategy's constructor, it uses all available GPUs; if no GPU is found, it falls back to the available CPU. Note that TensorFlow treats all CPUs on a machine as a single device and parallelizes internally using threads. Let's look at how a MirroredVariable is built.
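The default device rule can be written as a tiny helper. default_mirrored_devices is hypothetical and works on device-name strings; it is not TensorFlow's device-discovery code, only an illustration of the "all GPUs, else one CPU" rule described above.

```python
# Hypothetical helper illustrating the default device rule: all GPUs if any,
# otherwise a single CPU device.


def default_mirrored_devices(available_devices):
    gpus = [d for d in available_devices if "GPU" in d]
    if gpus:
        return gpus
    # All CPUs on the machine count as one device.
    return ["/device:CPU:0"]


print(default_mirrored_devices(["/device:CPU:0", "/device:GPU:0", "/device:GPU:1"]))
# ['/device:GPU:0', '/device:GPU:1']
print(default_mirrored_devices(["/device:CPU:0"]))  # ['/device:CPU:0']
```

With real TensorFlow, strategy.num_replicas_in_sync on a MirroredStrategy created without arguments reports how many devices were picked.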

1.3.1 StrategyBase

First, tensorflow/python/distribute/distribute_lib.py contains the following code, which shows that scope() actually delegates to _extended.

def scope(self):
  """Returns a context manager selecting this Strategy as current. Inside a with strategy.scope(): code block, this thread will use a variable creator set by strategy, and will enter its "cross-replica context". Returns: A context manager."""
  return self._extended._scope(self)

1.3.2 StrategyExtendedV2

This brings us to StrategyExtendedV2. Its _scope method defines creator_with_resource_vars and registers it as a variable creator, which provides the mechanism for creating variables. Inside creator_with_resource_vars, the derived class's _create_variable is called to create the variable.

def _scope(self, strategy):
  """Implementation of tf.distribute.Strategy.scope()."""

  def creator_with_resource_vars(next_creator, **kwargs):
    """Variable creator to use in _CurrentDistributionContext."""
    _require_strategy_scope_extended(self)
    kwargs["use_resource"] = True
    kwargs["distribute_strategy"] = strategy

    # Unwrap initial_value if it is a CheckpointInitialValue to avoid
    # dereferencing a Tensor that is without a name. We still need to
    # propagate the metadata it's holding.
    if isinstance(kwargs["initial_value"], trackable.CheckpointInitialValue):
      checkpoint_restore_uid = kwargs[
          "initial_value"].checkpoint_position.restore_uid
      kwargs["initial_value"] = kwargs["initial_value"].wrapped_value
    elif isinstance(kwargs["initial_value"],
                    trackable.CheckpointInitialValueCallable):
      checkpoint_restore_uid = kwargs[
          "initial_value"].checkpoint_position.restore_uid
    elif (isinstance(kwargs["initial_value"], functools.partial) and
          isinstance(kwargs["initial_value"].func,
                     trackable.CheckpointInitialValueCallable)):
      # Some libraries (e.g, Keras) create partial function out of initializer
      # to bind shape/dtype, for example:
      # initial_val = functools.partial(initializer, shape, dtype=dtype)
      # Therefore to get the restore_uid we need to examine the "func" of
      # the partial function.
      checkpoint_restore_uid = kwargs[
          "initial_value"].func.checkpoint_position.restore_uid
    else:
      checkpoint_restore_uid = None

    created = self._create_variable(next_creator, **kwargs)

    if checkpoint_restore_uid is not None:
      # Let the checkpointing infrastructure know that the variable was
      # already restored so it doesn't waste memory loading the value again.
      # In this case of CheckpointInitialValueCallable this may already be
      # done by the final variable creator, but it doesn't hurt to do it
      # again.
      created._maybe_initialize_trackable()
      created._update_uid = checkpoint_restore_uid
    return created

  def distributed_getter(getter, *args, **kwargs):
    return getter(*args, **kwargs)

  # creator_with_resource_vars is installed as the variable creator here.
  return _CurrentDistributionContext(
      strategy,
      variable_scope.variable_creator_scope(creator_with_resource_vars),  # configures how variables are created
      variable_scope.variable_scope(
          variable_scope.get_variable_scope(),
          custom_getter=distributed_getter), self._default_device)

The logic is as follows: entering the scope goes through a series of steps and returns a _CurrentDistributionContext, which performs further setup internally. Let's continue with it.

Figure 14. How to create a variable

1.3.3 _CurrentDistributionContext

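_CurrentDistributionContext holds the strategy, enters the various scopes (variable scope, variable creator scope, resource creator scopes and device scope) in __enter__, exits them in __exit__, and returns the strategy from __enter__.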

class _CurrentDistributionContext(object):
  """Context manager setting the current tf.distribute.Strategy.

  Also: overrides the variable creator and optionally the current device.
  """

  def __init__(self,
               strategy,
               var_creator_scope,
               var_scope=None,
               resource_creator_scope=None,
               default_device=None):
    self._context = distribution_strategy_context._CrossReplicaThreadMode( 
        strategy)
    self._var_creator_scope = var_creator_scope
    self._var_scope = var_scope
    self._resource_creator_scope = resource_creator_scope
    if default_device:
      self._device_scope = ops.device(default_device)
    else:
      self._device_scope = None
    self._same_scope_again_count = 0

  def __enter__(self):
    # Allow this scope to be entered if this strategy is already in scope.
    if distribution_strategy_context.has_strategy():
      _require_cross_replica_or_default_context_extended(
          self._context.strategy.extended)
      self._same_scope_again_count += 1
    else:
      _push_per_thread_mode(self._context)
      if self._var_scope:
        self._var_scope.__enter__()
      self._var_creator_scope.__enter__()
      if self._resource_creator_scope:
        nest.map_structure(lambda scope: scope.__enter__(),
                           self._resource_creator_scope)
      if self._device_scope:
        self._device_scope.__enter__()
    return self._context.strategy

  def __exit__(self, exception_type, exception_value, traceback):
    if self._same_scope_again_count > 0:
      self._same_scope_again_count -= 1
      return
    if self._device_scope:
      try:
        self._device_scope.__exit__(exception_type, exception_value, traceback)
      except RuntimeError as e:
        six.raise_from(
            RuntimeError("Device scope nesting error: move call to "
                         "tf.distribute.set_strategy() out of with scope."),
            e)

    try:
      self._var_creator_scope.__exit__(
          exception_type, exception_value, traceback)
    except RuntimeError as e:
      six.raise_from(
          RuntimeError("Variable creator scope nesting error: move call to "
                       "tf.distribute.set_strategy() out of with scope."),
          e)

    if self._resource_creator_scope:
      try:
        if isinstance(self._resource_creator_scope, list):
          reversed_resource_creator_scope = self._resource_creator_scope[::-1]
          nest.map_structure(
              lambda scope: scope.__exit__(exception_type, exception_value,  
                                           traceback),
              reversed_resource_creator_scope)

        else:
          self._resource_creator_scope.__exit__(exception_type, exception_value,
                                                traceback)
      except RuntimeError as e:
        six.raise_from(
            RuntimeError("Resource creator scope nesting error: move call "
                         "to tf.distribute.set_strategy() out of with "
                         "scope."), e)

    if self._var_scope:
      try:
        self._var_scope.__exit__(exception_type, exception_value, traceback)
      except RuntimeError as e:
        six.raise_from(
            RuntimeError("Variable scope nesting error: move call to "
                         "tf.distribute.set_strategy() out of with scope."),
            e)
    _pop_per_thread_mode()
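One detail worth isolating is the re-entry bookkeeping: entering the same scope again while a strategy is already active only bumps _same_scope_again_count, and only the outermost exit pops the per-thread mode. ToyDistributionContext and _mode_stack below are hypothetical and model nothing else (no variable, resource or device scopes).

```python
# Hypothetical model of the re-entry logic in _CurrentDistributionContext.

_mode_stack = []  # stands in for the per-thread strategy stack


class ToyDistributionContext:
    def __init__(self, strategy):
        self._strategy = strategy
        self._same_scope_again_count = 0

    def __enter__(self):
        if _mode_stack:  # like distribution_strategy_context.has_strategy()
            self._same_scope_again_count += 1
        else:
            _mode_stack.append(self._strategy)  # like _push_per_thread_mode
        return self._strategy

    def __exit__(self, exc_type, exc, tb):
        if self._same_scope_again_count > 0:
            self._same_scope_again_count -= 1
            return False
        _mode_stack.pop()  # like _pop_per_thread_mode
        return False


ctx = ToyDistributionContext("mirrored")
with ctx as s1:
    with ctx as s2:
        depth_inner = len(_mode_stack)
    depth_after_inner_exit = len(_mode_stack)
depth_after_outer_exit = len(_mode_stack)
print(s1 == s2, depth_inner, depth_after_inner_exit, depth_after_outer_exit)  # True 1 1 0
```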

1.3.4 MirroredStrategy

From the analysis above, when a Strategy is in use, variables are created by the strategy's _create_variable.

_create_variable does the actual work. It takes the target devices from self._devices (or from colocate_with, if given) and calls distribute_utils.create_mirrored_variable, passing in _real_mirrored_creator together with VARIABLE_CLASS_MAPPING and VARIABLE_POLICY_MAPPING. _real_mirrored_creator creates one component variable per device, each inside an ops.device(d) scope, and gives each component a distinct name: the component on the first device keeps the original name, while the component on device i (i > 0) is named after the first variable with /replica_i/ appended, which distinguishes it from the original variable. The initial value of each later component is derived from the primary (first) variable through _get_variable_creator_initial_value, so all copies start with the same value.

def _create_variable(self, next_creator, **kwargs):
  """Create a mirrored variable. See DistributionStrategy.scope."""
  colocate_with = kwargs.pop("colocate_with", None)
  if colocate_with is None:
    devices = self._devices
  elif isinstance(colocate_with, numpy_dataset.SingleDevice):
    with ops.device(colocate_with.device):
      return next_creator(**kwargs)
  else:
    devices = colocate_with._devices  

  def _real_mirrored_creator(**kwargs):
    value_list = []
    for i, d in enumerate(devices):
      with ops.device(d):
        kwargs["initial_value"] = self._get_variable_creator_initial_value(
            replica_id=i,
            device=d,
            primary_var=value_list[0] if value_list else None,
            **kwargs)
        if i > 0:
          # Give replicas meaningful distinct names:
          var0name = value_list[0].name.split(":")[0]
          # We append a / to variable names created on replicas with id > 0 to
          # ensure that we ignore the name scope and instead use the given
          # name as the absolute name of the variable.
          kwargs["name"] = "%s/replica_%d/" % (var0name, i)
        with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
          # Don't record operations (e.g. other variable reads) during
          # variable creation.
          with tape.stop_recording():
            v = next_creator(**kwargs)
        assert not isinstance(v, values.DistributedVariable)
        value_list.append(v)
    return value_list

  return distribute_utils.create_mirrored_variable(
      self._container_strategy(), _real_mirrored_creator,
      distribute_utils.VARIABLE_CLASS_MAPPING,
      distribute_utils.VARIABLE_POLICY_MAPPING, **kwargs)
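The naming rule in _real_mirrored_creator can be reproduced on plain strings. replica_creation_names is a hypothetical helper that returns the name passed to next_creator for each device; it does not claim what final name TensorFlow assigns to the created variable.

```python
# Hypothetical helper reproducing the naming rule in _real_mirrored_creator.
# Replica 0 keeps the requested name; replica i > 0 is created with
# "<primary name without :N>/replica_<i>/" (the trailing "/" makes it absolute).


def replica_creation_names(requested_name, primary_full_name, num_devices):
    names = [requested_name]
    var0name = primary_full_name.split(":")[0]
    for i in range(1, num_devices):
        names.append("%s/replica_%d/" % (var0name, i))
    return names


print(replica_creation_names("kernel", "dense/kernel:0", 3))
# ['kernel', 'dense/kernel/replica_1/', 'dense/kernel/replica_2/']
```

Deriving later names from the first variable's full name (rather than the requested name) keeps the replicas next to the primary even if the current name scope differs.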

VARIABLE_CLASS_MAPPING is used to specify which type of variable to generate. VARIABLE_POLICY_MAPPING sets the policy used to handle read/write synchronization.

# The following mapping indicates the policy that you must use for a given
# variable synchronization and aggregation pair.
# OnWritePolicy is used for:
# (synchronization=Auto, aggregation=NONE,SUM,MEAN,ONLY_FIRST_REPLICA)
# (synchronization=ON_WRITE, aggregation=NONE,SUM,MEAN,ONLY_FIRST_REPLICA)
# OnReadPolicy is used for:
# (synchronization=ON_READ, aggregation=NONE,SUM,MEAN,ONLY_FIRST_REPLICA)
VARIABLE_POLICY_MAPPING = {
    vs.VariableSynchronization.ON_WRITE: values_lib.OnWritePolicy,
    vs.VariableSynchronization.ON_READ: values_lib.OnReadPolicy,
}

VARIABLE_CLASS_MAPPING = {
    "VariableClass": values_lib.DistributedVariable,
    vs.VariableSynchronization.ON_WRITE: values_lib.MirroredVariable,  # the path followed in this article
    vs.VariableSynchronization.ON_READ: values_lib.SyncOnReadVariable,
}
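How the two mappings are consulted can be modeled with strings in place of classes. pick_wrapper is hypothetical; it follows the selection logic in create_mirrored_variable shown in the next subsection, including the conversion of AUTO to ON_WRITE.

```python
# Hypothetical, string-based model of how create_mirrored_variable picks a
# wrapper class (and optional policy) from the two mappings.

VARIABLE_POLICY_MAPPING = {"ON_WRITE": "OnWritePolicy", "ON_READ": "OnReadPolicy"}
VARIABLE_CLASS_MAPPING = {
    "VariableClass": "DistributedVariable",
    "ON_WRITE": "MirroredVariable",
    "ON_READ": "SyncOnReadVariable",
}


def pick_wrapper(synchronization, use_var_policy):
    if synchronization == "AUTO":
        synchronization = "ON_WRITE"  # AUTO is converted to ON_WRITE
    if use_var_policy:
        return (VARIABLE_CLASS_MAPPING["VariableClass"],
                VARIABLE_POLICY_MAPPING[synchronization])
    return (VARIABLE_CLASS_MAPPING[synchronization], None)


print(pick_wrapper("AUTO", use_var_policy=False))    # ('MirroredVariable', None)
print(pick_wrapper("ON_READ", use_var_policy=True))  # ('DistributedVariable', 'OnReadPolicy')
```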

1.3.5 distribute_utils

create_mirrored_variable in tensorflow/python/distribute/distribute_utils.py is where the variable is actually built. In our example (ON_WRITE synchronization, no variable policy), the class taken from class_mapping is values_lib.MirroredVariable.

def create_mirrored_variable(strategy, real_mirrored_creator, class_mapping,
                             policy_mapping, **kwargs):
  """Create distributed variables with given synchronization and aggregation."""
  # Figure out what collections this variable should be added to.
  # We'll add the MirroredVariable to those collections instead.
  var_collections = kwargs.pop("collections", None)
  if var_collections is None:
    var_collections = [ops.GraphKeys.GLOBAL_VARIABLES]
  kwargs["collections"] = []

  synchronization = _validate_synchronization(kwargs)
  # Update synchronization in kwargs in case it's AUTO, which is converted to
  # ON_WRITE.
  kwargs["synchronization"] = synchronization
  aggregation = _validate_aggregation(kwargs)
  use_var_policy = getattr(strategy.extended, "_use_var_policy", False)

  # Ignore user-specified caching device, not needed for mirrored variables.
  kwargs.pop("caching_device", None)

  with tape.stop_recording():
    # Build the list of per-device component variables.
    value_list = real_mirrored_creator(**kwargs)
    # MirroredVariable is recreated during saved_model loading, and its
    # component variables (value_list) will have None initializer. We
    # set their initializers to no_op so that consumer like
    # global_variables_initializer wouldn't complain, as it groups all
    # variables' initializers thus all variables have to have initializers.
    for v in value_list:
      if hasattr(v, "_initializer_op") and v._initializer_op is None:
        v._initializer_op = control_flow_ops.no_op()
    if use_var_policy:
      # Look up the policy and the variable class, then build the variable.
      var_policy_cls = policy_mapping.get(synchronization)
      var_policy = var_policy_cls(aggregation=aggregation)
      var_cls = class_mapping.get("VariableClass")
      result = var_cls(strategy, value_list, aggregation, var_policy=var_policy)
    else:
      var_cls = class_mapping.get(synchronization)
      result = var_cls(strategy, value_list, aggregation)

  # Add the wrapped variable to the requested collections.
  # The handling of eager mode and the global step matches
  # ResourceVariable._init_from_args().
  if not context.executing_eagerly():
    g = ops.get_default_graph()
    # If "trainable" is True, next_creator() will add the member variables
    # to the TRAINABLE_VARIABLES collection, so we manually remove
    # them and replace with the MirroredVariable. We can't set
    # "trainable" to False for next_creator() since that causes functions
    # like implicit_gradients to skip those variables.
    if kwargs.get("trainable", True):
      var_collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
      l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
      for value in value_list:
        for i, trainable_variable in enumerate(l):
          if value is trainable_variable:
            del l[i]
            break

    g.add_to_collections(var_collections, result)
  elif ops.GraphKeys.GLOBAL_STEP in var_collections:
    ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, result)

  return result
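The graph-mode collection handling at the end of create_mirrored_variable swaps the per-device components out of TRAINABLE_VARIABLES and registers the wrapper instead. In this hypothetical sketch, collections are a dict of lists keyed by name and the "variables" are plain objects, compared by identity as in the original loop.

```python
# Hypothetical model of the graph-mode collection bookkeeping in
# create_mirrored_variable.


def swap_into_collections(graph_collections, value_list, wrapper, trainable=True):
    var_collections = ["GLOBAL_VARIABLES"]  # the default when none are given
    if trainable:
        var_collections.append("TRAINABLE_VARIABLES")
        trainable_list = graph_collections.setdefault("TRAINABLE_VARIABLES", [])
        # Remove each component (matched by identity) that next_creator added.
        for value in value_list:
            for i, existing in enumerate(trainable_list):
                if existing is value:
                    del trainable_list[i]
                    break
    # Register the wrapper instead of its components.
    for key in var_collections:
        graph_collections.setdefault(key, []).append(wrapper)
    return graph_collections


components = [object(), object()]  # per-device variables
graph_collections = {"TRAINABLE_VARIABLES": list(components)}
wrapper = object()                  # the MirroredVariable
swap_into_collections(graph_collections, components, wrapper)
print(graph_collections["TRAINABLE_VARIABLES"] == [wrapper],
      graph_collections["GLOBAL_VARIABLES"] == [wrapper])  # True True
```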

The overall construction logic is as follows: the _var_creator_scope member of _CurrentDistributionContext wraps creator_with_resource_vars. When a variable is created, the creator chain is invoked layer by layer through creator_with_resource_vars, eventually producing a MirroredVariable.

Figure 15 creating a variable

1.4 Summary

Here’s what we’ve answered so far:

  • How do calls reach the Strategy?
    • Reads and writes of variables are ultimately delegated to strategy or strategy.extended.
  • How is a MirroredVariable created?
    • Entering strategy.scope() gives the user a context that provides a way to create variables; variables created in that context are MirroredVariables.
  • How are tensors distributed to devices?
    • When a Strategy is in use, variables are generated by the strategy's _create_variable, which eventually calls _real_mirrored_creator.
    • _real_mirrored_creator gives each component variable a specific name and creates it under ops.device(d) for its target device. The first device keeps the original name; for each subsequent device, /replica_<i> is appended to the original variable name so the copy can be distinguished from the original variable.
    • During placement, each component variable is then put on its corresponding device.
  • How is a unified view presented to the outside world?
    • Within the context, the user gets a MirroredVariable, which hides the per-device component variables behind a single view. For example, a cross-replica read calls _get_cross_replica, which delegates to the policy; the policy in turn calls distribute_strategy to perform the reduction.
  • How are variables kept consistent?
    • As the scatter_update analysis above shows, updates are dispatched to strategy.extended, where the component variables are kept consistent with one another, for example through all-reduce. We will look at this in more detail later.

To illustrate, consider a MirroredVariable A made up of three component tensors. Each worker believes it is updating MirroredVariable A, but it is actually updating its own component variable; the components are kept consistent through mechanisms such as all-reduce.
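A small simulation shows why this works. Each of A's three copies sees a different local gradient, but after an all-reduce (modeled here as a plain mean over Python floats) every replica applies the identical update, so the copies stay in sync. all_reduce_mean and mirrored_sgd_step are hypothetical names; real all-reduce runs across devices, not in a list.

```python
# Hypothetical simulation of keeping mirrored copies consistent via all-reduce.


def all_reduce_mean(per_replica_values):
    mean = sum(per_replica_values) / len(per_replica_values)
    return [mean] * len(per_replica_values)  # every replica receives the same value


def mirrored_sgd_step(components, local_grads, lr=0.5):
    reduced = all_reduce_mean(local_grads)
    return [w - lr * g for w, g in zip(components, reduced)]


components = [10.0, 10.0, 10.0]  # A's three copies, one per device
local_grads = [1.0, 2.0, 3.0]    # each replica saw different data
components = mirrored_sgd_step(components, local_grads)
print(components)  # [9.0, 9.0, 9.0]
```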

Figure 16 how to update

2. ShardedVariable

In machine learning training, if the variable is too large to fit on a single device (such as a large embedding), it may need to be sharded on multiple devices. In TensorFlow, the corresponding concept to this idea is ShardedVariable.

Figure 17 ShardedVariable

Variable sharding means splitting a variable into many smaller variables, called shards. ShardedVariable can be thought of as a container whose "variables" should be treated as shards. The ShardedVariable class maintains a list of smaller variables that can be stored independently on different devices (for example, multiple parameter servers), and it is responsible for saving and restoring these variables as if they were one larger variable. Variable sharding helps spread network load by distributing accesses across the shards, and it helps distribute computation and storage when one large variable is spread across multiple parameter servers.
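The container idea can be sketched without TensorFlow. ToyShardedTable is hypothetical: it splits the rows of an embedding-like table along axis 0 into contiguous shards (earlier shards get one extra row when the split is uneven, an assumption made for this sketch) and still answers lookups by global row id.

```python
# Hypothetical container illustrating the ShardedVariable idea.


class ToyShardedTable:
    def __init__(self, rows, num_shards):
        # Split rows along the first axis into contiguous, near-equal shards.
        base, extra = divmod(len(rows), num_shards)
        self.shards, start = [], 0
        for s in range(num_shards):
            size = base + (1 if s < extra else 0)
            self.shards.append(rows[start:start + size])
            start += size

    def lookup(self, row_id):
        # Find the shard that owns the global row id, then index into it.
        for shard in self.shards:
            if row_id < len(shard):
                return shard[row_id]
            row_id -= len(shard)
        raise IndexError("row id out of range")

    def full_value(self):
        return [row for shard in self.shards for row in shard]


table = ToyShardedTable([[i, i] for i in range(5)], num_shards=2)
print([len(s) for s in table.shards])  # [3, 2]
print(table.lookup(3))                 # [3, 3]
```

Each shard could live on a different parameter server, so a lookup only touches the server that owns the requested row.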

Figure 18 ShardedVariable container

ShardedVariable objects can be saved with a given number of shards and then restored from the checkpoint into a different number of shards. They can also be saved in SavedModel format with tf.saved_model.save; the resulting SavedModel can be used by programs such as the TF Serving APIs, but loading it with tf.saved_model.load is not yet supported. Because a ShardedVariable may be restored into a different number of shards depending on the restore environment (for example, the TF Serving APIs restore it into a single shard for serving efficiency), code that uses a ShardedVariable inside a tf.function should generally not assume that the number of shards is the same at save and load time.
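Why the shard count can change is easiest to see if the saved value is thought of as the full variable concatenated along axis 0. The sketch below is conceptual only: split_rows, save_shards and restore_shards are hypothetical and do not reflect TensorFlow's actual checkpoint format.

```python
# Hypothetical, conceptual sketch of saving N shards and restoring M shards.


def split_rows(rows, num_shards):
    base, extra = divmod(len(rows), num_shards)
    out, start = [], 0
    for s in range(num_shards):
        size = base + (1 if s < extra else 0)
        out.append(rows[start:start + size])
        start += size
    return out


def save_shards(shards):
    return [row for shard in shards for row in shard]  # concatenate along axis 0


def restore_shards(saved_rows, num_shards):
    return split_rows(saved_rows, num_shards)


trained = split_rows(list(range(6)), 3)  # saved with 3 shards
ckpt = save_shards(trained)
serving = restore_shards(ckpt, 1)        # e.g. one shard for serving
print(trained, serving)  # [[0, 1], [2, 3], [4, 5]] [[0, 1, 2, 3, 4, 5]]
```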

2.1 Questions

For ShardedVariable, we still use several questions to guide the analysis.

  • How are parameters stored on the parameter servers?
  • How are parameters sharded?
  • How are computations (the operations that apply gradient updates to parameters) placed on the parameter servers? (This will be analyzed in a later article.)
  • Does the coordinator assign computations randomly? (This will be analyzed in a later article.)

2.2 Definition

ShardedVariable itself does not define much; the core of its logic lives in the base class ShardedVariableMixin, which we will analyze later.

Figure 19 ShardedVariable definition

The specific definition code is as follows:

class ShardedVariable(ShardedVariableMixin, composite_tensor.CompositeTensor):
  """A container for Variables that should be treated as shards. """

  @property
  def _type_spec(self):
    return ShardedVariableSpec(
        *(resource_variable_ops.VariableSpec(v.shape, v.dtype)
          for v in self._variables))

  @classmethod
  def _overload_all_operators(cls):
    """Register overloads for all operators."""
    for operator in ops.Tensor.OVERLOADABLE_OPERATORS:
      if operator == '__getitem__':
        continue

      cls._overload_operator(operator)

  @classmethod
  def _overload_operator(cls, operator):
    """Delegate an operator overload to ops.Tensor."""
    tensor_operator = getattr(ops.Tensor, operator)

    def _operator(v, *args, **kwargs):
      return tensor_operator(_var_to_tensor(v), *args, **kwargs)

    setattr(cls, operator, _operator)
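The operator-overloading trick above can be reproduced in miniature. In this hypothetical sketch, Vec stands in for ops.Tensor and OVERLOADABLE is an illustrative subset of its overloadable operators; each installed operator first concatenates the shards into one full value (the role _var_to_tensor plays) and then delegates to that value's own operator.

```python
# Hypothetical model of _overload_all_operators / _overload_operator.


class Vec:
    """Tiny stand-in for ops.Tensor supporting a few operators."""

    def __init__(self, data):
        self.data = list(data)

    def __add__(self, other):
        return Vec(a + b for a, b in zip(self.data, other.data))

    def __mul__(self, scalar):
        return Vec(a * scalar for a in self.data)

    def __getitem__(self, i):
        return self.data[i]


OVERLOADABLE = ["__add__", "__mul__", "__getitem__"]  # illustrative subset


class ToyShardedVec:
    def __init__(self, shards):
        self.shards = shards  # list of lists, split along axis 0

    @classmethod
    def _overload_all_operators(cls):
        for operator in OVERLOADABLE:
            if operator == "__getitem__":
                continue  # skipped exactly as in the code above; slicing is not modeled
            cls._overload_operator(operator)

    @classmethod
    def _overload_operator(cls, operator):
        tensor_operator = getattr(Vec, operator)

        def _operator(v, *args, **kwargs):
            full = Vec(x for shard in v.shards for x in shard)  # like _var_to_tensor
            return tensor_operator(full, *args, **kwargs)

        setattr(cls, operator, _operator)


ToyShardedVec._overload_all_operators()
s = ToyShardedVec([[1, 2], [3]])
print((s + Vec([10, 20, 30])).data)  # [11, 22, 33]
print((s * 2).data)                  # [2, 4, 6]
```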

2.3 Partitioning

Partitioning is the heart of ShardedVariable, so let's explore how it works. Note that ShardedVariable only supports partitioning along the first dimension.

2.3.1 Base Class

The base Partitioner class contains little logic; its derived classes must implement __call__.

@tf_export('distribute.experimental.partitioners.Partitioner', v1=[])
class Partitioner(object):
  """Partitioner base class: all partitioners inherit from this class. Partitioners should implement a __call__ method with the signature def __call__(self, shape, dtype, axis=0), which partitions the given shape and returns the partition results (see the docstring of the __call__ method for the format of the partition results)."""

  def __call__(self, shape, dtype, axis=0):
    """Partitions the given shape and returns the partition results.

    Example of a partitioner that allocates a fixed number of shards:
    partitioner = FixedShardsPartitioner(num_shards=2);
    partitioner(tf.TensorShape([10, 3]), tf.float32, axis=0) returns [2, 1].

    Args: shape: a tf.TensorShape, the shape to partition. dtype: a
    tf.dtypes.Dtype indicating the type of the partition value. axis: The axis
    to partition along. Default: outermost axis. Returns: A list of integers
    representing the number of partitions on each axis, where the i-th value
    corresponds to the i-th axis.
    """
    raise NotImplementedError


2.3.2 Fixed Partitions

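FixedShardsPartitioner splits a variable into a fixed number of shards. In the docstring example below (shape [10, 3], num_shards=2, axis=0), result[axis] = min(self._num_shards, shape.dims[axis].value) = min(2, 10) = 2, so the variable is split into two shards along axis 0 and the result is [2, 1].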

@tf_export('distribute.experimental.partitioners.FixedShardsPartitioner', v1=[])
class FixedShardsPartitioner(Partitioner):
  """Partitioner that allocates a fixed number of shards.

  Examples:
  >>> # standalone usage:
  >>> partitioner = FixedShardsPartitioner(num_shards=2)
  >>> partitions = partitioner(tf.TensorShape([10, 3]), tf.float32)
  >>> partitions
  [2, 1]

  >>> # use in ParameterServerStrategy
  >>> # strategy = tf.distribute.experimental.ParameterServerStrategy(
  >>> #   cluster_resolver=cluster_resolver, variable_partitioner=partitioner)
  """

  def __init__(self, num_shards):
    """Creates a new FixedShardsPartitioner. Args: num_shards: int, number of shards to partition."""
    self._num_shards = num_shards

  def __call__(self, shape, dtype, axis=0):
    del dtype
    result = [1] * len(shape)
    result[axis] = min(self._num_shards, shape.dims[axis].value)
    return result
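The arithmetic is simple enough to check without TensorFlow. fixed_shards_partition is a hypothetical pure-Python mirror of FixedShardsPartitioner.__call__ that takes a plain tuple instead of a tf.TensorShape; dtype is dropped, just as the original deletes it.

```python
# Hypothetical pure-Python mirror of FixedShardsPartitioner.__call__.


def fixed_shards_partition(shape, num_shards, axis=0):
    result = [1] * len(shape)
    # Never ask for more shards than there are entries along the axis.
    result[axis] = min(num_shards, shape[axis])
    return result


print(fixed_shards_partition((10, 3), num_shards=2))  # [2, 1]
print(fixed_shards_partition((1, 3), num_shards=2))   # [1, 1]
```

The min() guard is what keeps a tiny variable (here only one row) from being split into empty shards.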

2.3.3 Minimum Partition

MinSizePartitioner enforces a minimum size per shard: it ensures that each shard has at least min_shard_bytes bytes and tries to allocate as many shards as possible, i.e., it keeps each shard as small as possible. The maximum number of shards (upper bound) is given by max_shards.

@tf_export('distribute.experimental.partitioners.MinSizePartitioner', v1=[])
class MinSizePartitioner(Partitioner):
  """Partitioner that allocates a minimum size per shard.

  This partitioner ensures each shard has at least min_shard_bytes, and tries
  to allocate as many shards as possible, i.e., keeping shard size as small as
  possible. The maximum number of such shards (upper bound) is given by
  max_shards.

  Examples:
  >>> partitioner = MinSizePartitioner(min_shard_bytes=4, max_shards=2)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> partitions
  [2, 1]
  >>> partitioner = MinSizePartitioner(min_shard_bytes=4, max_shards=10)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> partitions
  [6, 1]

  >>> # use in ParameterServerStrategy
  >>> # strategy = tf.distribute.experimental.ParameterServerStrategy(
  >>> #   cluster_resolver=cluster_resolver, variable_partitioner=partitioner)
  """

  def __init__(self,
               min_shard_bytes=256 << 10,
               max_shards=1,
               bytes_per_string=16):
    """Creates a new MinSizePartitioner. Args: min_shard_bytes: Minimum bytes of each shard. Defaults to 256K. max_shards: Upper bound on the number of shards. Defaults to 1. bytes_per_string: If the partition value is of type string, this provides an estimate of how large each string is."""
    self._min_shard_bytes = min_shard_bytes
    self._max_shards = max_shards
    self._bytes_per_string = bytes_per_string

  def __call__(self, shape, dtype, axis=0):
    return partitioned_variables.min_max_variable_partitioner(
        max_partitions=self._max_shards,
        axis=axis,
        min_slice_size=self._min_shard_bytes,
        bytes_per_string_element=self._bytes_per_string)(shape, dtype)


min_max_variable_partitioner does the actual work. It returns a partitioner function that partitions a variable of a given shape and dtype so that each partition holds at least min_slice_size bytes of the variable. The maximum number of partitions (upper bound) is given by max_partitions.

@tf_export(v1=["min_max_variable_partitioner"])
def min_max_variable_partitioner(max_partitions=1, axis=0,
                                 min_slice_size=256 << 10,
                                 bytes_per_string_element=16):
  """Partitioner to allocate minimum size per slice. Returns a partitioner that partitions the variable of given shape and dtype such that each partition has a minimum of min_slice_size slice of the variable. The maximum number of such partitions (upper bound) is given by max_partitions. Args: max_partitions: Upper bound on the number of partitions. Defaults to 1. axis: Axis along which to partition the variable. Defaults to 0. min_slice_size: Minimum size of the variable slice per partition. Defaults to 256K. bytes_per_string_element: If the Variable is of type string, this provides an estimate of how large each scalar in the Variable is. Returns: A partition function usable as the partitioner argument to variable_scope and get_variable."""
  def _partitioner(shape, dtype):
    """Partitioner that partitions list for a variable of given shape and type.

    Ex: Consider partitioning a variable of type float32 with shape=[1024, 1024].
    If max_partitions >= 16, this function would return
    [(1024 * 1024 * 4) / (256 * 1024), 1] = [16, 1]. If max_partitions < 16,
    this function would return [max_partitions, 1].

    Args: shape: Shape of the variable. dtype: Type of the variable.
    Returns: List of partitions for each axis (currently only one axis can be
    partitioned).
    Raises: ValueError: If axis to partition along does not exist for the variable.
    """
    if axis >= len(shape):
      raise ValueError("Can not partition variable along axis %d when shape is "
                       "only %s" % (axis, shape))
    if dtype.base_dtype == dtypes.string:
      bytes_per_element = bytes_per_string_element
    else:
      bytes_per_element = dtype.size
    total_size_bytes = shape.num_elements() * bytes_per_element
    partitions = total_size_bytes / min_slice_size
    partitions_list = [1] * len(shape)
    # We can not partition the variable beyond what its shape or
    # max_partitions allows.
    partitions_list[axis] = max(1, min(shape.dims[axis].value,
                                       max_partitions,
                                       int(math.ceil(partitions))))
    return partitions_list
  return _partitioner
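The same arithmetic can be checked in plain Python. min_max_partition is a hypothetical mirror of the _partitioner closure above that takes a tuple shape and an explicit bytes_per_element instead of a dtype (math.prod needs Python 3.8+). It reproduces both the [1024, 1024] example from the inner docstring and the MinSizePartitioner docstring examples.

```python
import math

# Hypothetical pure-Python mirror of min_max_variable_partitioner's arithmetic.


def min_max_partition(shape, bytes_per_element, max_partitions=1, axis=0,
                      min_slice_size=256 << 10):
    if axis >= len(shape):
        raise ValueError("Can not partition variable along axis %d when shape is "
                         "only %s" % (axis, shape))
    total_size_bytes = math.prod(shape) * bytes_per_element
    partitions = total_size_bytes / min_slice_size
    partitions_list = [1] * len(shape)
    # Bounded by the axis length, by max_partitions, and by the size-based count.
    partitions_list[axis] = max(1, min(shape[axis], max_partitions,
                                       int(math.ceil(partitions))))
    return partitions_list


print(min_max_partition((1024, 1024), 4, max_partitions=16))  # [16, 1]
print(min_max_partition((1024, 1024), 4, max_partitions=8))   # [8, 1]
print(min_max_partition((6, 1), 4, max_partitions=2, min_slice_size=4))   # [2, 1]
print(min_max_partition((6, 1), 4, max_partitions=10, min_slice_size=4))  # [6, 1]
```

The outer max(1, ...) guarantees at least one partition even for variables far smaller than min_slice_size.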

2.3.4 Maximum Partition

MaxSizePartitioner ensures that each shard is at most max_shard_bytes bytes and tries to allocate as few shards as possible, i.e., it keeps each shard as large as possible. If the partitioner hits the max_shards limit, each shard may end up larger than max_shard_bytes. By default max_shards is None, and the number of shards is not limited.

@tf_export('distribute.experimental.partitioners.MaxSizePartitioner', v1=[])
class MaxSizePartitioner(Partitioner):
  """Partitioner that keeps shards below max_shard_bytes.

  This partitioner ensures each shard has at most max_shard_bytes, and tries to
  allocate as few shards as possible, i.e., keeping shard size as large as
  possible. If the partitioner hits the max_shards limit, then each shard may
  end up larger than max_shard_bytes. By default max_shards equals None and no
  limit on the number of shards is enforced.

  Examples:
  >>> partitioner = MaxSizePartitioner(max_shard_bytes=4)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> partitions
  [6, 1]
  >>> partitioner = MaxSizePartitioner(max_shard_bytes=4, max_shards=2)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> partitions
  [2, 1]
  >>> partitioner = MaxSizePartitioner(max_shard_bytes=1024)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> partitions
  [1, 1]

  >>> # use in ParameterServerStrategy
  >>> # strategy = tf.distribute.experimental.ParameterServerStrategy(
  >>> #   cluster_resolver=cluster_resolver, variable_partitioner=partitioner)
  """

  def __init__(self, max_shard_bytes, max_shards=None, bytes_per_string=16):
    """Creates a new MaxSizePartitioner.

    Args:
      max_shard_bytes: The maximum size any given shard is allowed to be.
      max_shards: The maximum number of shards in int created taking
        precedence over max_shard_bytes.
      bytes_per_string: If the partition value is of type string, this provides
        an estimate of how large each string is.
    """
    if max_shard_bytes < 1:
      raise ValueError('max_shard_bytes must be positive, got: %r' %
                       max_shard_bytes)
    if max_shards and max_shards < 1:
      raise ValueError('max_shards must be positive, got: %r' % max_shards)
    if bytes_per_string < 1:
      raise ValueError('bytes_per_string must be positive, got: %r' %
                       bytes_per_string)

    self._max_shard_bytes = max_shard_bytes
    self._max_shards = max_shards
    self._bytes_per_string = bytes_per_string

  def __call__(self, shape, dtype, axis=0):
    return partitioned_variables.variable_axis_size_partitioner(
        max_shard_bytes=self._max_shard_bytes,
        max_shards=self._max_shards,
        bytes_per_string_element=self._bytes_per_string,
        axis=axis)(shape, dtype)

variable_axis_size_partitioner is the function that does the actual work. It splits a variable along one axis, trying to keep the maximum shard size below max_shard_bytes. If it hits the max_shards limit, each shard may end up larger than max_shard_bytes. By default max_shards is None, which means the number of shards is not limited.

A reasonable value for max_shard_bytes is (64 << 20) - 1, i.e., just under 64MB, which keeps each shard below the protobuf byte limit.
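As a worked check, assume a hypothetical float32 variable of shape [262144, 1024] (1 GiB) split along axis 0. Applying the slices-per-shard rule from the variable_axis_size_partitioner code gives:

```latex
\begin{aligned}
\text{max\_shard\_bytes} &= (64 \ll 20) - 1 = 2^{26} - 1 = 67{,}108{,}863 \\
\text{bytes\_per\_slice} &= 1024 \times 4 = 4096 \\
\text{slices\_per\_shard} &= \left\lfloor 67{,}108{,}863 / 4096 \right\rfloor = 16383 \\
\text{axis\_shards} &= \left\lceil 262{,}144 / 16383 \right\rceil = 17
\end{aligned}
```

Because of the "- 1", each shard holds 16383 rows rather than 16384. The variable therefore needs 17 shards instead of 16, and each shard stays under 64MB.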

@tf_export(v1=["variable_axis_size_partitioner"])
def variable_axis_size_partitioner(
    max_shard_bytes, axis=0, bytes_per_string_element=16, max_shards=None):
  """Get a partitioner for VariableScope to keep shards below max_shard_bytes.

  This partitioner will shard a Variable along one axis, attempting to keep
  the maximum shard size below max_shard_bytes. In practice, this is not
  always possible when sharding along only one axis. When this happens,
  this axis is sharded as much as possible (i.e., every dimension becomes
  a separate shard).

  If the partitioner hits the max_shards limit, then each shard may end up
  larger than max_shard_bytes. By default max_shards equals None and no
  limit on the number of shards is enforced.

  One reasonable value for max_shard_bytes is (64 << 20) - 1, or almost
  64MB, to keep below the protobuf byte limit.

  Args:
    max_shard_bytes: The maximum size any given shard is allowed to be.
    axis: The axis to partition along. Default: outermost axis.
    bytes_per_string_element: If the Variable is of type string, this provides
      an estimate of how large each scalar in the Variable is.
    max_shards: The maximum number of shards in int created taking precedence
      over max_shard_bytes.

  Returns:
    A partition function usable as the partitioner argument to
    variable_scope and get_variable.

  Raises:
    ValueError: If any of the byte counts are non-positive.
  """

  def _partitioner(shape, dtype):
    """Partitioner that partitions shards to have max_shard_bytes total size.

    Args:
      shape: A TensorShape.
      dtype: A DType.

    Returns:
      A tuple representing how much to slice each axis in shape.

    Raises:
      ValueError: If shape is not a fully defined TensorShape or dtype is not
        a DType.
    """
    if dtype.base_dtype == dtypes.string:
      element_size = bytes_per_string_element
    else:
      element_size = dtype.size

    partitions = [1] * shape.ndims
    bytes_per_slice = 1.0 * (
        shape.num_elements() / shape.dims[axis].value) * element_size
    # How many slices can we fit on one shard of size at most max_shard_bytes?
    # At least one slice is required.
    slices_per_shard = max(1, math.floor(max_shard_bytes / bytes_per_slice))
    # How many shards do we need for axis given that each shard fits
    # slices_per_shard slices from a total of shape[axis] slices?
    axis_shards = int(math.ceil(
        1.0 * shape.dims[axis].value / slices_per_shard))
    if max_shards:
      axis_shards = min(max_shards, axis_shards)

    partitions[axis] = axis_shards

    return partitions

  return _partitioner
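The shard-count rule can be reproduced in plain Python. The sketch below mirrors the arithmetic of variable_axis_size_partitioner on a shape given as a list of ints. The name axis_size_partitions and the element_size argument (standing in for dtype.size or bytes_per_string_element) are hypothetical. The three calls reproduce the MaxSizePartitioner docstring examples.

```python
import math


def axis_size_partitions(shape, element_size, max_shard_bytes, axis=0, max_shards=None):
    """Hypothetical plain-Python mirror of variable_axis_size_partitioner."""
    partitions = [1] * len(shape)
    num_elements = math.prod(shape)
    # Bytes in one "slice", i.e. one index along `axis`.
    bytes_per_slice = 1.0 * (num_elements / shape[axis]) * element_size
    # At least one slice per shard, even if a single slice exceeds the limit.
    slices_per_shard = max(1, math.floor(max_shard_bytes / bytes_per_slice))
    # Shards needed to hold shape[axis] slices, slices_per_shard at a time.
    axis_shards = int(math.ceil(1.0 * shape[axis] / slices_per_shard))
    if max_shards:
        axis_shards = min(max_shards, axis_shards)
    partitions[axis] = axis_shards
    return partitions


# The three cases from the MaxSizePartitioner docstring (float32 = 4 bytes).
print(axis_size_partitions([6, 1], 4, max_shard_bytes=4))                # [6, 1]
print(axis_size_partitions([6, 1], 4, max_shard_bytes=4, max_shards=2))  # [2, 1]
print(axis_size_partitions([6, 1], 4, max_shard_bytes=1024))             # [1, 1]
```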

2.4 ShardedVariableMixin

As mentioned earlier, ShardedVariableMixin is the core class, so let's analyze it. The main member variables of ShardedVariableMixin are:

  • _variables: Partition variables.

  • _var_offsets: the offset of each partition variable within the ShardedVariableMixin. The _variables are treated as one whole, and these offsets locate the data that belongs to each partition.

  • _shape: ShardedVariableMixin’s shape.

  • _name: name of ShardedVariableMixin.

class ShardedVariableMixin(trackable.Trackable):
  """Mixin for ShardedVariable."""

  def __init__(self,
               variables: Sequence[variables_lib.Variable],
               name='ShardedVariable'):
    """Treats variables as shards of a larger Variable.

    Args:
      variables: A list of ResourceVariables that comprise this sharded
        variable. Variables should not be shared between different
        ShardedVariableMixin objects.
      name: String. Name of this container. Defaults to "ShardedVariable".
    """
    super(ShardedVariableMixin, self).__init__()
    self._variables = variables
    self._name = name

    var_dtypes = {v.dtype for v in variables}
    first_var = variables[0]
    self._dtype = first_var.dtype

    # All variables must have the same shape for axes > 0.
    # Calculate the overall shape
    higher_dim_shapes = {tuple(v.shape.as_list()[1:]) for v in variables}
    first_dim = sum(int(v.shape.as_list()[0]) for v in variables)
    self._shape = tensor_shape.TensorShape([first_dim] +
                                           first_var.shape.as_list()[1:])
    # Calculate the offset of each partition within the whole
    self._var_offsets = [
        [0 for _ in range(len(first_var.shape))] for _ in range(len(variables))
    ]
    for i in range(1, len(variables)):
      # Always partition on the first axis. Offsets on other axes are 0.
      self._var_offsets[i][0] += (
          self._var_offsets[i - 1][0] + variables[i - 1].shape.as_list()[0])

    save_slice_info = [v._get_save_slice_info() for v in variables]  

    # We create an uninitialized saving_variable with the full shape, which can
    # be later captured in signatures so that the signatures can treat this
    # ShardedVariable as one single variable.
    self._saving_variable = resource_variable_ops.UninitializedVariable(
        shape=self._shape, dtype=self._dtype, name=self._name)

2.4.1 Usage

Let’s use the following example to see how it works.

import numpy as np
import tensorflow as tf
from tensorflow.python.distribute.sharded_variable import ShardedVariableMixin

variables = [
  tf.Variable(np.array([[3, 2]]), shape=(1, 2), dtype=tf.float32),
  tf.Variable(np.array([[3, 2], [0, 1]]), shape=(2, 2), dtype=tf.float32),
  tf.Variable(np.array([[3, 2]]), shape=(1, 2), dtype=tf.float32)
]
sharded_variable = ShardedVariableMixin(variables)

Printing the internal member variables of sharded_variable gives the following. As you can see, _var_offsets treats all the parameter partitions as one whole and records where each partition starts within it.

_shape = {TensorShape: 2} (4, 2)
_var_offsets = {list: 3} [[0, 0], [1, 0], [3, 0]]
first_dim = {int} 4

In the example above, the three variables are packed together into the following logical tensor, and the offsets tell which partition holds each row:

[[3, 2], [3, 2], [0, 1], [3, 2]]
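The offset bookkeeping can be reproduced with plain Python lists. This sketch applies the same rule as ShardedVariableMixin.__init__; compute_offsets is a hypothetical helper, not a TensorFlow API. Only the first axis is partitioned, so each shard's offset is the running sum of the first dimensions of the shards before it.

```python
def compute_offsets(shapes):
    """Hypothetical mirror of the _var_offsets / _shape logic.

    shapes: one list of dimensions per shard, e.g. [[1, 2], [2, 2], [1, 2]].
    """
    rank = len(shapes[0])
    # One offset vector per shard; only the first axis is ever non-zero.
    offsets = [[0] * rank for _ in shapes]
    for i in range(1, len(shapes)):
        offsets[i][0] = offsets[i - 1][0] + shapes[i - 1][0]
    # First dimension is the sum over shards; the others must match.
    total_shape = [sum(s[0] for s in shapes)] + list(shapes[0][1:])
    return offsets, total_shape


offsets, total_shape = compute_offsets([[1, 2], [2, 2], [1, 2]])
print(offsets)      # [[0, 0], [1, 0], [3, 0]]
print(total_shape)  # [4, 2]
```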

Here is another illustration. If a parameter has four partitions, the layout is as follows:

Figure 20 partitions

If the variables are placed on parameter servers, the layout is as follows:

Figure 21 Partition and parameter server

2.4.2 Obtaining a Partition

Let's look at how a partition is read, i.e., how a specified part of the sharded variable is extracted as a tensor. The logic is: analyze the incoming slice spec, map it onto the individual partitions of the sharded variable, and assemble the result from the matching partitions.

  def __getitem__(self, slice_spec):
    """Extracts the specified region as a Tensor from the sharded variable.

    The API contract is identical to Tensor.__getitem__. Assignment to the
    sliced range is not yet supported.

    Args:
      slice_spec: The arguments to __getitem__, specifying the global slicing
        of the sharded variable.

    Returns:
      The appropriate slice of tensor based on slice_spec.

    Raises:
      IndexError: If a slice index is out of bound.
      TypeError: If spec_spec contains Tensor.
    """

    # Parse the slice spec; a boolean mask is applied to the full tensor
    if (isinstance(slice_spec, bool) or (isinstance(slice_spec, ops.Tensor) and
                                         slice_spec.dtype == dtypes.bool) or
        (isinstance(slice_spec, np.ndarray) and slice_spec.dtype == bool)):
      tensor = _var_to_tensor(self)
      return array_ops.boolean_mask(tensor=tensor, mask=slice_spec)

    if not isinstance(slice_spec, (list, tuple)):
      slice_spec = (slice_spec,)

    s = slice_spec[0]
    if isinstance(s, slice):
      # If it is a slice, decompose it into per-partition slices
      first_dim_slice_specs = self._decompose_slice_spec(s)
      values = []
      for i, var in enumerate(self._variables):
        if first_dim_slice_specs[i] is not None:
          all_dim_slice_spec = (first_dim_slice_specs[i],) + slice_spec[1:]
          values.append(var[all_dim_slice_spec])
      if s.step is not None and s.step < 0:
        values.reverse()
      if not values:
        return constant_op.constant([],
                                    dtype=self._dtype,
                                    shape=((0,) + self._shape[1:]))
      return array_ops.concat(values, axis=0)
    elif s is Ellipsis:
      return array_ops.concat([var[slice_spec] for var in self._variables],
                              axis=0)
    elif s is array_ops.newaxis:
      return array_ops.concat([var[slice_spec[1:]] for var in self._variables],
                              axis=0)[array_ops.newaxis]
    else:
      if isinstance(s, ops.Tensor):
        raise TypeError(
            'ShardedVariable: using Tensor for indexing is not allowed.')
      if s < 0:
        s += self._shape[0]
        
      # Walk through the partitions, using the offsets to locate the data
      for i in range(len(self._variables)):
        if i == len(self._variables) - 1 or (s >= self._var_offsets[i][0] and
                                             s < self._var_offsets[i + 1][0]):
          return self._variables[i][(s - self._var_offsets[i][0],) +
                                    slice_spec[1:]]

What does a global slice spec decompose into? The following example makes it clear.

    For example, given component variables:
      v0 = [0, 1, 2]
      v1 = [3, 4, 5]
      v2 = [6, 7, 8, 9]

    If slice_spec is slice(start=None, stop=None, step=None), we will have:
      v0[returned[0]] = [0, 1, 2]
      v1[returned[1]] = [3, 4, 5]
      v2[returned[2]] = [6, 7, 8, 9]
    If slice_spec is slice(start=2, stop=8, step=3), we will have:
      v0[returned[0]] = [2]
      v1[returned[1]] = [5]
      returned[2] == None
    If slice_spec is slice(start=9, stop=3, step=-2), we will have:
      returned[0] == None
      v1[returned[1]] = [5]
      v2[returned[2]] = [9, 7]

The code that parses a global slice spec into per-partition specs is as follows:

  def _decompose_slice_spec(self, slice_spec):
    """Decompose a global slice_spec into a list of per-variable slice_spec.

    ShardedVariable only supports first dimension partitioning, thus
    slice_spec must be for first dimension.

    Args:
      slice_spec: A python slice object that specifies the global slicing.

    Returns:
      A list of python slice objects or None specifying the local slicing for
      each component variable. None means no slicing.
    """
    result = []
    # Normalize start, end and stop.
    slice_step = slice_spec.step if slice_spec.step is not None else 1
    if slice_step == 0:
      raise ValueError('slice step cannot be zero')
    slice_start = slice_spec.start
    if slice_start is None:
      slice_start = 0 if slice_step > 0 else self._shape[0] - 1
    elif slice_start < 0:
      slice_start += self._shape[0]
    slice_end = slice_spec.stop
    if slice_end is None:
      # After the normalization, we no longer interpret negative index, thus
      # "-1" conceptually refers to the element before the first one, which
      # doesn't exist. This is to ease the decomposition code.
      slice_end = self._shape[0] if slice_step > 0 else -1
    elif slice_end < 0:
      slice_end += self._shape[0]

    # To find the local slice_spec of each component variable, we start from
    # the start of the global slice, and iterate through each variable.
    # When iterating on a variable, we move the cursor ( cur ) to the first
    # index that falls into the variable's range, which becomes the start of
    # the variable's local slice_spec. The end of the local_spec is determined
    # by using whatever is smaller between global slice end and variable range
    # end.
    cur = slice_start
    if slice_step > 0:
      for i in range(len(self._var_offsets)):
        var_start = self._var_offsets[i][0]
        var_end = (
            self._var_offsets[i + 1][0]
            if i < len(self._var_offsets) - 1 else self._shape[0])
        if cur < var_start:
          cur += slice_step * int(math.ceil((var_start - cur) / slice_step))
        if cur >= var_end or cur >= slice_end:
          result.append(None)
        else:
          start = cur - var_start
          end = min(slice_end, var_end) - var_start
          result.append(slice(start, end, slice_step))
    else:  # slice_step < 0
      for i in range(len(self._var_offsets) - 1, -1, -1):
        var_start = self._var_offsets[i][0]
        var_end = (
            self._var_offsets[i + 1][0]
            if i < len(self._var_offsets) - 1 else self._shape[0])
        if cur >= var_end:
          cur += slice_step * int(math.ceil((var_end - cur - 1) / slice_step))
        if cur < var_start or cur <= slice_end:
          result.append(None)
        else:
          start = cur - var_start
          if slice_end >= var_start:
            end = slice_end - var_start
          else:
            end = None  # no explicit end: slice until hitting the boundary.
          result.append(slice(start, end, slice_step))

      result.reverse()

    return result
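The decomposition can be checked on plain lists. The sketch below is a hypothetical plain-Python mirror of _decompose_slice_spec. It takes each shard's first-axis length instead of the _var_offsets of a real ShardedVariable. The usage loop reproduces the three cases from the example above.

```python
import math


def decompose_slice(lengths, spec):
    """Split a global first-axis slice into one local slice (or None) per shard.

    Hypothetical mirror of _decompose_slice_spec; `lengths` holds each shard's
    size along the first axis.
    """
    total = sum(lengths)
    offsets = [sum(lengths[:i]) for i in range(len(lengths))]
    step = spec.step if spec.step is not None else 1
    if step == 0:
        raise ValueError("slice step cannot be zero")
    start = spec.start
    if start is None:
        start = 0 if step > 0 else total - 1
    elif start < 0:
        start += total
    end = spec.stop
    if end is None:
        end = total if step > 0 else -1  # -1 means "before the first element"
    elif end < 0:
        end += total

    result = []
    cur = start
    if step > 0:
        for var_start, length in zip(offsets, lengths):
            var_end = var_start + length
            if cur < var_start:  # move the cursor onto the first index in this shard
                cur += step * int(math.ceil((var_start - cur) / step))
            if cur >= var_end or cur >= end:
                result.append(None)
            else:
                result.append(slice(cur - var_start, min(end, var_end) - var_start, step))
    else:
        for var_start, length in reversed(list(zip(offsets, lengths))):
            var_end = var_start + length
            if cur >= var_end:  # move the cursor down into this shard
                cur += step * int(math.ceil((var_end - cur - 1) / step))
            if cur < var_start or cur <= end:
                result.append(None)
            else:
                local_end = end - var_start if end >= var_start else None
                result.append(slice(cur - var_start, local_end, step))
        result.reverse()
    return result


shards = [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]
for spec in (slice(None), slice(2, 8, 3), slice(9, 3, -2)):
    local = decompose_slice([len(s) for s in shards], spec)
    print([s[ls] if ls is not None else None for s, ls in zip(shards, local)])
# [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]
# [[2], [5], None]
# [None, [5], [9, 7]]
```

As in __getitem__, a negative step also needs the per-shard pieces reversed before they are concatenated.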

2.4.3 Embedding

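Now let's look at embedding lookup. ShardedVariable overrides embedding_lookup through dispatch. The override passes the list of shards (params.variables), together with partition_strategy, name, validate_indices and max_norm, to embedding_ops.embedding_lookup. The default partitioning strategy here is 'mod'.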

# Override the behavior of embedding_lookup(sharded_variable, ...)
@dispatch.dispatch_for_types(embedding_ops.embedding_lookup, ShardedVariable)
def embedding_lookup(params,
                     ids,
                     partition_strategy='mod',
                     name=None,
                     validate_indices=True,
                     max_norm=None):
  if isinstance(params, list):
    params = params[0]
  return embedding_ops.embedding_lookup(params.variables, ids,
                                        partition_strategy, name,
                                        validate_indices, max_norm)


The call then goes to embedding_lookup (tensorflow/python/ops/embedding_ops.py), which in turn calls _embedding_lookup_and_transform.

@tf_export(v1=["nn.embedding_lookup"])
@dispatch.add_dispatch_support
def embedding_lookup(
    params,
    ids,
    partition_strategy="mod",
    name=None,
    validate_indices=True,  # pylint: disable=unused-argument
    max_norm=None):
  """Looks up embeddings for the given ids from a list of tensors.

  This function is used to perform parallel lookups on the list of tensors in
  params. It is a generalization of tf.gather, where params is interpreted
  as a partitioning of a large embedding tensor. params may be a
  PartitionedVariable as returned by using tf.compat.v1.get_variable() with a
  partitioner.

  If len(params) > 1, each element id of ids is partitioned between the
  elements of params according to the partition_strategy. In all strategies,
  if the id space does not evenly divide the number of partitions, each of the
  first (max_id + 1) % len(params) partitions will be assigned one more id.

  If the input ids are ragged tensors, partition variables are not supported
  and the partition strategy and the max_norm are ignored. The results of the
  lookup are concatenated into a dense tensor. The returned tensor has shape
  shape(ids) + shape(params)[1:].

  Args:
    params: A single tensor representing the complete embedding tensor, or a
      list of P tensors all of same shape except for the first dimension,
      representing sharded embedding tensors. Alternatively, a
      PartitionedVariable, created by partitioning along dimension 0. Each
      element must be appropriately sized for the given partition_strategy.
    ids: A Tensor or a 'RaggedTensor' with type int32 or int64 containing the
      ids to be looked up in params.
    partition_strategy: A string specifying the partitioning strategy, relevant
      if len(params) > 1. Currently "div" and "mod" are supported. Default
      is "mod".
    name: A name for the operation (optional).
    validate_indices: DEPRECATED. If this operation is assigned to CPU, values
      in indices are always validated to be within range. If assigned to GPU,
      out-of-bound indices result in safe but unspecified behavior, which may
      include raising an error.
    max_norm: If not None, each embedding is clipped if its l2-norm is larger
      than this value.

  Returns:
    A Tensor or a 'RaggedTensor', depending on the input, with the same type
    as the tensors in params.

  Raises:
    ValueError: If params is empty.
  """
  if isinstance(ids, ragged_tensor.RaggedTensor):
    return embedding_lookup_ragged(params, ids,
                                   partition_strategy=partition_strategy,
                                   max_norm=max_norm,
                                   name=name)

  return _embedding_lookup_and_transform(
      params=params,
      ids=ids,
      partition_strategy=partition_strategy,
      name=name,
      max_norm=max_norm,
      transform_fn=None)

_embedding_lookup_and_transform contains the partitioning logic. Let's start with an example.

  • If partition_strategy is "mod", each id is assigned to partition p = id % len(params). For example, 13 ids split across 5 partitions are grouped as follows: [[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]].
  • If partition_strategy is "div", ids are assigned to partitions contiguously. In this case, 13 ids split across 5 partitions are grouped as follows: [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]].

The specific code is as follows:

def _embedding_lookup_and_transform(params,
                                    ids,
                                    partition_strategy="mod",
                                    name=None,
                                    max_norm=None,
                                    transform_fn=None):
  """Helper function for embedding_lookup and _compute_sampled_logits.

  This function is a generalization of embedding_lookup that optionally
  applies a caller-specified transformation to each embedding. This is
  done through the transform_fn argument. If provided, the function is
  applied to each partitioned tensor of retrieved embeddings, colocated
  with the embeddings. This function will be called with a single Tensor
  argument of the same type as the params tensor and should return a
  Tensor. The shape of the argument will be the same as params except
  for the size of the first dimension. The first dimension of the result's
  shape must be the same size as the argument's.

  Args:
    params: See embedding_lookup.
    ids: See embedding_lookup.
    partition_strategy: See embedding_lookup.
    name: See embedding_lookup.
    max_norm: See embedding_lookup.
    transform_fn: An optional function to apply to each retrieved embedding.
      If max_norm is provided, transform_fn is applied to the norm-limited
      embeddings.

  Returns:
    See embedding_lookup for details.

  Raises:
    ValueError: If params is empty.
  """

  with ops.name_scope(name, "embedding_lookup", params + [ids]) as name:
    # omit code (including np = len(params), the number of partitions)
    else:
      # Flatten the ids. There are two cases where we need to do this.
      # - There is more than one params tensor.
      # - There is a transform_fn and ids is not statically known to be 1-D.
      # We must flatten in this case because transform_fn expects a flat
      # tensor of embeddings.
      flat_ids = array_ops.reshape(ids, [-1])
      original_indices = math_ops.range(array_ops.size(flat_ids))

      # Create p_assignments and set new_ids depending on the strategy.
      if partition_strategy == "mod":
        p_assignments = flat_ids % np
        new_ids = flat_ids // np
      elif partition_strategy == "div":
        # Compute num_total_ids as the sum of dim-0 of params, then assign to
        # partitions based on a constant number of ids per partition. Optimize
        # if we already know the full shape statically.
        dim_0_size = tensor_shape.Dimension(
            tensor_shape.dimension_value(params[0].get_shape()[0]))
        for p in xrange(1, np):
          dim_0_size += tensor_shape.Dimension(
              tensor_shape.dimension_value(params[p].get_shape()[0]))
        if dim_0_size.value:
          num_total_ids = constant_op.constant(dim_0_size.value, flat_ids.dtype)
        else:
          dim_0_sizes = []
          for p in xrange(np):
            param_p_dim = tensor_shape.dimension_value(params[p].get_shape()[0])
            if param_p_dim is not None:
              dim_0_sizes.append(param_p_dim)
            else:
              with ops.colocate_with(params[p]):
                dim_0_sizes.append(array_ops.shape(params[p])[0])
          num_total_ids = math_ops.reduce_sum(
              math_ops.cast(array_ops.stack(dim_0_sizes), flat_ids.dtype))
        ids_per_partition = num_total_ids // np
        extras = num_total_ids % np

        p_assignments = math_ops.maximum(flat_ids // (ids_per_partition + 1),
                                         (flat_ids - extras) //
                                         ids_per_partition)

        # Emulate a conditional using a boolean indicator tensor
        new_ids = array_ops.where(p_assignments < extras,
                                  flat_ids % (ids_per_partition + 1),
                                  (flat_ids - extras) % ids_per_partition)
      else:
        raise ValueError("Unrecognized partition strategy: " +
                         partition_strategy)

  # omit other code

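You can check both strategies with a plain-Python sketch of the p_assignments/new_ids arithmetic above. The helpers assign, group and lookup are hypothetical names. lookup fetches each id from its shard in the original order. It is an illustrative stand-in for the omitted code, which gathers from each partition and restores the original order. Python's // and % round toward negative infinity, like the integer floor division and modulo applied to the id tensors here. That matters for (i - extras) when i is small.

```python
def assign(ids, num_partitions, total_ids, strategy="mod"):
    """Hypothetical mirror of p_assignments / new_ids: (partition, local id) per id."""
    pairs = []
    if strategy == "mod":
        for i in ids:
            pairs.append((i % num_partitions, i // num_partitions))
    elif strategy == "div":
        ids_per_partition = total_ids // num_partitions
        extras = total_ids % num_partitions  # first `extras` partitions hold one more id
        for i in ids:
            p = max(i // (ids_per_partition + 1), (i - extras) // ids_per_partition)
            if p < extras:
                new_id = i % (ids_per_partition + 1)
            else:
                new_id = (i - extras) % ids_per_partition
            pairs.append((p, new_id))
    else:
        raise ValueError("Unrecognized partition strategy: " + strategy)
    return pairs


def group(ids, num_partitions, total_ids, strategy):
    """List the ids that land in each partition."""
    groups = [[] for _ in range(num_partitions)]
    for i, (p, _) in zip(ids, assign(ids, num_partitions, total_ids, strategy)):
        groups[p].append(i)
    return groups


def lookup(shards, ids, strategy="mod"):
    """Fetch every id from its shard, keeping the original id order."""
    total_ids = sum(len(s) for s in shards)
    return [shards[p][new_id]
            for p, new_id in assign(ids, len(shards), total_ids, strategy)]


ids = list(range(13))
print(group(ids, 5, 13, "mod"))  # [[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]]
print(group(ids, 5, 13, "div"))  # [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]
```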

How is this used? The following usage is extracted from the docstring: we build a ShardedVariable, and the model reads from it via embedding_lookup.

  >>> class Model(tf.Module):
  ...   def __init__(self):
  ...     self.sharded_variable = ShardedVariable([
  ...       tf.Variable([3.0], dtype=tf.float32),
  ...       tf.Variable([2.0], dtype=tf.float32)
  ...     ])
  ...
  ...   @tf.function(input_signature=[tf.TensorSpec([], dtype=tf.int32)])
  ...   def fn(self, x):
  ...     return tf.nn.embedding_lookup(self.sharded_variable.variables, x)
  ...
  ...   @tf.function(input_signature=[tf.TensorSpec([], dtype=tf.int32)])
  ...   def serve_fn(self, x):
  ...     return tf.nn.embedding_lookup(self.sharded_variable.variables, x)
  >>>
  >>> model = Model()
  >>> model.fn(1).numpy()
  2.0
  >>> tf.saved_model.save(model, export_dir='/tmp/saved_model',
  ...   signatures=model.serve_fn)

As the figure shows, the worker reads from both parameter servers in parallel to extract the embeddings.

Figure 22 Handling embedding

2.5 Building

To see how a ShardedVariable is built, we look directly at how ParameterServerStrategyV2 builds it.

2.5.1 Variable Sharding

To enable variable sharding, you can pass a variable_partitioner when constructing the ParameterServerStrategy object. Each time a variable is created, the variable_partitioner is called and is expected to return the number of shards along each dimension of the variable. TensorFlow provides some out-of-the-box variable_partitioners, such as tf.distribute.experimental.partitioners.MinSizePartitioner. A size-based partitioner such as tf.distribute.experimental.partitioners.MinSizePartitioner is recommended because it avoids partitioning small variables, which may hurt model training speed.

When a variable_partitioner is passed and you create a variable directly under Strategy.scope(), the variable becomes a container type with a variables property, which gives access to the list of shards. In most cases, the container is automatically converted to a tensor by concatenating all the shards, so it can be used like a normal variable. On the other hand, some TensorFlow methods, such as tf.nn.embedding_lookup, provide efficient implementations for this container type that avoid the automatic concatenation.

2.5.2 Initialization

During ParameterServerStrategyV2Extended initialization, the incoming variable_partitioner is stored in _variable_partitioner, and the number of parameter servers and the number of workers are also recorded.

class ParameterServerStrategyV2Extended(
    parameter_server_strategy.ParameterServerStrategyExtended):
  """Extended class for ParameterServerStrategyV2.

  Please see tf.distribute.StrategyExtended doc for more information.
  """

  def __init__(self, container_strategy, cluster_resolver, variable_partitioner):
    """Initialization of ParameterServerStrategyV2Extended."""
    super(ParameterServerStrategyV2Extended, self).__init__(container_strategy)
    self._num_ps = len(cluster_resolver.cluster_spec().as_dict().get("ps", []))
    self._num_workers = len(cluster_resolver.cluster_spec().as_dict().get(
        "worker", []))
    self._variable_count = 0

    self._variable_partitioner = variable_partitioner


2.5.3 Building

Let's look at the creation process, i.e., how variables are sharded across different parameter servers. The logic is as follows:

  • If no partitioner is configured, the round-robin (RR) policy (_create_variable_round_robin) is used to assign variables to parameter servers.
  • If a partitioner is configured, the following steps are taken:
    • Rank-0 (scalar) variables are not partitioned.
    • _variable_partitioner is called to get the number of partitions; if the first axis gets only one partition, the variable is not partitioned.
    • The number of partitions cannot exceed the size of the first dimension; if it does, the size of the first dimension is used instead.
    • Compute the offset of each shard along the first axis.
    • Generate the smaller shard variables.
    • Place each shard with _create_variable_round_robin and collect the shards in a list.
    • Use the list of shards to build a ShardedVariable.

  def _create_variable(self, next_creator, **kwargs):
    """Implements StrategyExtendedV2._create_variable.

    Creates a Variable or a ShardedVariable. A ShardedVariable will be
    created if satisfying all the following criteria:
      1. self._variable_partitioner results in more than one partition on the
         first axis.
      2. variable's rank is greater than 0.
      3. variable is not colocated with another variable.
    Otherwise a Variable will be created.

    Args:
      next_creator: See variable_scope.variable_creator_scope; the next
        creator in the chain.
      **kwargs: Passed through to the next creator.

    Returns:
      A Variable or ShardedVariable.
    """

    var_creator = self._create_var_creator(next_creator, **kwargs)
    if "colocate_with" in kwargs:  # Never partition colocated_with variables.
      colocate_with = kwargs["colocate_with"]
      # Clear the variable scope to avoid possible conflicts between device
      # scope and colocation scope.
      with ops.device(None):
        with ops.colocate_with(colocate_with):
          var = var_creator(**kwargs)
          return var

    # If no partitioner is configured, use the round-robin policy to assign the variable to a parameter server
    if self._variable_partitioner is None:
      return self._create_variable_round_robin(var_creator, **kwargs)

    # A partitioner is configured from here on
    name = kwargs.get("name", None)
    initial_value = kwargs.get("initial_value", None)

    # Two cases where initial_value can be a callable:
    # 1. initial_value is passed as a callable, e.g, an initializer class.
    # 2. restoring from checkpoint, initial_value is a
    # "CheckpointInitialValueCallable".
    init_from_fn = callable(initial_value)

    dtype = kwargs.get("dtype", None)
    shape = kwargs.get("shape", None)
    if init_from_fn and (shape is None or dtype is None):
      init_from_fn = False
      initial_value = initial_value()
    if not init_from_fn:
      # The initial_value is created on coordinator, it will need to be sent to
      # ps for variable initialization, which can be inefficient and can
      # potentially hit the 2GB limit on protobuf serialization.
      initial_value = ops.convert_to_tensor(initial_value, dtype=dtype)
      dtype = initial_value.dtype
      shape = initial_value.shape
    else:
      shape = tensor_shape.as_shape(shape)

    # rank-0 does not partition
    if shape.rank == 0:  # Skip partitioning rank-0 variable.
      return self._create_variable_round_robin(var_creator, **kwargs)

    # Get the number of partitions
    num_partitions = self._variable_partitioner(shape=shape, dtype=dtype)
    if num_partitions[0] == 1:  # no partition
      return self._create_variable_round_robin(var_creator, **kwargs)

    # The number of partitions cannot exceed the size of the first dimension
    # Use "div" partition strategy to partition the variable.
    num_partitions = min(num_partitions[0], shape[0])
    base = shape[0] // num_partitions

    # Calculate the offsets
    extra = shape[0] % num_partitions
    # An example: num_partitions=4, shape[0]=10, partitions: [3, 3, 2, 2]
    # offsets: [0, 3, 6, 8, 10]
    offsets = []
    for i in range(num_partitions):
      if i == 0:
        offsets.append(0)
      else:
        prev_shard_size = base + (1 if i - 1 < extra else 0)
        offsets.append(offsets[i - 1] + prev_shard_size)
    offsets.append(shape[0])

    def init_shard_fn(shard_index):
      if not init_from_fn:
        return initial_value[offsets[shard_index]:offsets[shard_index + 1]]
    
      partition_shape = (offsets[shard_index + 1] -
                         offsets[shard_index],) + shape[1:]
      partition_offset = (offsets[shard_index],) + (0,) * len(shape[1:])
      arg_spec = tf_inspect.getfullargspec(initial_value)
      if ("shard_info" not in arg_spec.args and
          "shard_info" not in arg_spec.kwonlyargs):
        try:
          value = initial_value(
              partition_shape=partition_shape,
              partition_offset=partition_offset)
        except (TypeError, ValueError):
          # TypeError: Initializer doesn't accept kwargs
          # ValueError: Initializer doesn't accept partition kwargs
          # In both cases we go ahead creating the full value and then slice.
          value = initial_value()

        if value.shape == partition_shape:
          # Initializer supports partition: value is the partition value.
          return value
        else:
          # Initializer doesn't support partition: value is the full value
          # and needs to be sliced to get the partition value.
          return value[offsets[shard_index]:offsets[shard_index + 1]]
      else:
        # For compatibility with CheckpointInitialValueCallable .
        return initial_value(
            shard_info=trackable.ShardInfo(
                shape=tensor_shape.as_shape(partition_shape),
                offset=partition_offset))

    # Create the smaller shard variables
    var_list = []
    for i in range(num_partitions):
      kwargs["shape"] = (offsets[i + 1] - offsets[i],) + shape[1:]
      kwargs["initial_value"] = lambda: init_shard_fn(i) # initialization
      if name is not None:
        kwargs["name"] = "{}/part_{}".format(name, i)
      # Place each shard on a parameter server using _create_variable_round_robin
      var_list.append(self._create_variable_round_robin(var_creator, **kwargs))

    # Generate ShardedVariable from a list of small tensors
    result = sharded_variable.ShardedVariable(var_list)
    return result
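The excerpt above begins after offsets has already been computed. Below is a minimal pure-Python sketch of the same bookkeeping. It assumes that leftover rows go one each to the first shape[0] % num_partitions shards, because the excerpt does not show that step. The hypothetical helpers compute_offsets and shard_rows reproduce how init_shard_fn derives partition_shape, partition_offset, and the sliced initial value for each shard. A nested list of rows stands in for the tensor.

```python
def compute_offsets(dim0, num_partitions):
    """Hypothetical helper: start row of each shard, plus a closing offset.

    Assumption: the first `dim0 % num_partitions` shards get one extra row.
    """
    min_slice_size = dim0 // num_partitions
    extra = dim0 % num_partitions
    offsets = []
    offset = 0
    for i in range(num_partitions):
        offsets.append(offset)
        offset += min_slice_size + (1 if i < extra else 0)
    offsets.append(dim0)  # same role as offsets.append(shape[0])
    return offsets


def shard_rows(rows, row_shape, num_partitions):
    """Hypothetical helper: split a value (a list of rows) into row shards.

    `row_shape` plays the role of shape[1:]. Returns a list of
    (partition_shape, partition_offset, shard_rows) tuples.
    """
    offsets = compute_offsets(len(rows), num_partitions)
    shards = []
    for i in range(num_partitions):
        partition_shape = (offsets[i + 1] - offsets[i],) + row_shape
        # Only dimension 0 is split, so every other offset is 0.
        partition_offset = (offsets[i],) + (0,) * len(row_shape)
        shards.append((partition_shape, partition_offset,
                       rows[offsets[i]:offsets[i + 1]]))
    return shards


# A 7x3 "tensor" split into 3 shards.
full = [[3 * r + c for c in range(3)] for r in range(7)]
shards = shard_rows(full, (3,), 3)
for partition_shape, partition_offset, _ in shards:
    print(partition_shape, partition_offset)
# Prints (3, 3) (0, 0), then (2, 3) (3, 0), then (2, 3) (5, 0).

# Concatenating the shards restores the original value.
assert [row for _, _, part in shards for row in part] == full
```

The `(0,) * len(row_shape)` term pads the offset with zeros for every unsplit dimension. This is why the tuple must be `(0,)` and not the float `(0.)`.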

In the logic above, both branches call _create_variable_round_robin, which uses a round-robin (RR) policy to decide where each variable is placed. In practice, it only assigns the target device name to the variable. The subsequent placement step then acts on that device name.

  def _create_variable_round_robin(self, next_creator, **kwargs):
    # Clear the colocation scope to avoid possible conflicts between device
    # scope and colocation scope.
    with ops.colocate_with(None, ignore_existing=True):
      # Explicitly set CPU:0 device for PS in case create variable is called
      # inside replica_fn and worker has with GPU:0 scope.
      with ops.device("/job:ps/task:%d/device:CPU:0" %
                      (self._variable_count % self._num_ps)):
        var = next_creator(**kwargs)
        logging.debug(
            "Creating variable (name:%s, shape:%r) on "
            "/job:ps/task:%d/device:CPU:0",
            var.name, var.shape, (self._variable_count % self._num_ps))
        self._variable_count += 1
        return var
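The placement rule is modular arithmetic on a counter that all variables created by the strategy share. The minimal sketch below uses a hypothetical RoundRobinPlacer class in place of the strategy's _variable_count and _num_ps attributes. It only builds device strings and creates no TensorFlow variables.

```python
class RoundRobinPlacer:
    """Hypothetical stand-in for the strategy's variable counter.

    Mirrors the device-string logic of _create_variable_round_robin:
    the n-th variable created goes to ps task (n % num_ps), on CPU:0.
    """

    def __init__(self, num_ps):
        self._num_ps = num_ps
        self._variable_count = 0

    def next_device(self):
        device = "/job:ps/task:%d/device:CPU:0" % (
            self._variable_count % self._num_ps)
        self._variable_count += 1
        return device


placer = RoundRobinPlacer(num_ps=2)
# For example: three shards of one variable, then an unpartitioned bias.
devices = [placer.next_device() for _ in range(4)]
print(devices)
```

Because the counter is shared, the variable created right after a three-shard variable continues the rotation (task 1 here) instead of restarting at task 0.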

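The next_creator passed to _create_variable_round_robin is generally the var_creator built by _create_var_creator, shown below. When there is more than one replica, var_creator wraps each new variable first in a CachingVariable and then in an AggregatingVariable. These wrapped variables make up var_list, from which the ShardedVariable is built. We'll focus on AggregatingVariable.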

  def _create_var_creator(self, next_creator, **kwargs):
    aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE)

    def var_creator(**kwargs):
      """Create an AggregatingVariable."""
      # Create and wrap the variable.
      v = next_creator(**kwargs)
      wrapped_v = ps_values.CachingVariable(v)
      wrapped = ps_values.AggregatingVariable(self._container_strategy(),
                                              wrapped_v, aggregation)
      return wrapped

    if self._num_replicas_in_sync > 1:
      if aggregation not in (
          vs.VariableAggregation.NONE,
          vs.VariableAggregation.SUM,
          vs.VariableAggregation.MEAN,
          vs.VariableAggregation.ONLY_FIRST_REPLICA
      ):
        raise ValueError("Invalid variable aggregation mode: {} for variable: {}"
                         .format(aggregation, kwargs["name"]))
      return var_creator
    else:
      def variable_creator_single_replica(**kwargs):
        v = next_creator(**kwargs)
        return ps_values.CachingVariable(v)
      return variable_creator_single_replica
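The wrapping order in _create_var_creator can be sketched without TensorFlow. In the sketch below, Aggregation, Caching, and Aggregating are hypothetical stand-ins for vs.VariableAggregation, ps_values.CachingVariable, and ps_values.AggregatingVariable. They only record what they wrap. The error message is built with str.format because using + between a str and an enum member raises TypeError.

```python
import enum


class Aggregation(enum.Enum):
    """Hypothetical stand-in for vs.VariableAggregation."""
    NONE = 0
    SUM = 1
    MEAN = 2
    ONLY_FIRST_REPLICA = 3


class Caching:
    """Hypothetical stand-in for ps_values.CachingVariable."""
    def __init__(self, v):
        self.v = v


class Aggregating:
    """Hypothetical stand-in for ps_values.AggregatingVariable."""
    def __init__(self, v, aggregation):
        self.v = v
        self.aggregation = aggregation


def make_var_creator(next_creator, num_replicas_in_sync,
                     aggregation=Aggregation.NONE):
    if num_replicas_in_sync > 1:
        if aggregation not in (Aggregation.NONE, Aggregation.SUM,
                               Aggregation.MEAN,
                               Aggregation.ONLY_FIRST_REPLICA):
            raise ValueError(
                "Invalid variable aggregation mode: {}".format(aggregation))

        def var_creator(**kwargs):
            # Innermost: the real variable; then caching; then aggregation.
            return Aggregating(Caching(next_creator(**kwargs)), aggregation)
        return var_creator

    def variable_creator_single_replica(**kwargs):
        # With a single replica there is nothing to aggregate.
        return Caching(next_creator(**kwargs))
    return variable_creator_single_replica


def base_creator(**kwargs):
    return dict(kwargs)  # pretend "variable": just its constructor kwargs


v = make_var_creator(base_creator, 2, Aggregation.MEAN)(name="w")
print(type(v).__name__, type(v.v).__name__, v.v.v["name"])  # Aggregating Caching w
```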

2.5.4 AggregatingVariable

AggregatingVariable wraps a variable so that updates coming from different replicas are aggregated. Its _assign_func shows how this works: it applies updates to the wrapped variable through _distribute_strategy.extended.update.

# Variable used in PSStrategy TF1, TF2 and CentralStorageStrategy.
class AggregatingVariable(resource_variable_ops.BaseResourceVariable, core.Tensor):
  """A wrapper around a variable that aggregates updates across replicas."""

  def __init__(self, strategy, v, aggregation):
    self._distribute_strategy = strategy
    self._v = v
    # NOTE: We don't use "_distributed_container" here because we don't want
    # to trigger that code path in regroup().
    v._aggregating_container = weakref.ref(self)  # pylint: disable=protected-access
    self._aggregation = aggregation

  def __deepcopy__(self, memo):
    """Perform a deepcopy of the AggregatingVariable.

    Unlike the deepcopy of a regular tf.Variable, this keeps the original
    strategy and devices of the AggregatingVariable. To avoid confusion with
    the behavior of deepcopy on a regular Variable (which does copy into new
    devices), we only allow a deepcopy of an AggregatingVariable within its
    originating strategy scope.

    Args:
      memo: The memoization object for deepcopy.

    Returns:
      A deep copy of the current AggregatingVariable.

    Raises:
      RuntimeError: If trying to deepcopy into a different strategy.
    """
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      v = copy.deepcopy(self._v, memo)

    copied_variable = type(self)(
        strategy=self._distribute_strategy,
        v=v,
        aggregation=self._aggregation)

    memo[id(self)] = copied_variable

    return copied_variable

  def get(self):
    return self._v

  @property
  def distribute_strategy(self):
    return self._distribute_strategy

  def __getattr__(self, name):
    return getattr(self._v, name)

  def _assign_func(self, *args, **kwargs):
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      f = kwargs.pop("f")
      if ds_context.in_cross_replica_context():
        if distribute_lib.get_update_replica_id() is not None:
          # We are calling an assign function in an update context.
          return f(self._v, *args, **kwargs)

        # We are calling an assign function in cross replica context, wrap it in
        # an update call.
        return self._distribute_strategy.extended.update(
            self, f, args=args, kwargs=kwargs)
      else:
        replica_context = ds_context.get_replica_context()
        # We are calling an assign function in replica context.
        # We reduce the value we want to assign/add/sub. More details about how
        # we handle the different use cases can be found in the _reduce method.
        # We call the function with the reduced value.
        if self._aggregation == vs.VariableAggregation.NONE:
          raise ValueError(
              values_util.aggregation_error_msg.format(
                  variable_type="AggregatingVariable"))

        def merge_fn(strategy,
                     value,
                     use_locking=False,
                     name=None,
                     read_value=True):
          v = values_util.apply_aggregation(strategy, value, self._aggregation,
                                            self)
          if name and isinstance(name, values.PerReplica):
            name = name.values[0]
          return strategy.extended.update(
              self,
              f,
              args=(v,),
              kwargs={
                  "use_locking": use_locking,
                  "name": name,
                  "read_value": read_value
              })
        return replica_context.merge_call(merge_fn, args=args, kwargs=kwargs)
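In replica context, _assign_func does not update the variable directly. Instead, merge_fn first reduces the per-replica values according to self._aggregation (via values_util.apply_aggregation) and then applies a single update. The simplified sketch below shows that reduce-then-apply step. It assumes plain Python floats instead of per-replica tensors and ignores the update-context and cross-replica branches. Aggregation, aggregate, and ToyAggregatingVariable are hypothetical names.

```python
import enum


class Aggregation(enum.Enum):
    """Hypothetical stand-in for vs.VariableAggregation."""
    NONE = 0
    SUM = 1
    MEAN = 2
    ONLY_FIRST_REPLICA = 3


def aggregate(per_replica_values, aggregation):
    """Reduce per-replica values to one value, like merge_fn does."""
    if aggregation == Aggregation.NONE:
        # Mirrors the ValueError raised in _assign_func for NONE.
        raise ValueError("An aggregation mode is required in replica context.")
    if aggregation == Aggregation.SUM:
        return sum(per_replica_values)
    if aggregation == Aggregation.MEAN:
        return sum(per_replica_values) / len(per_replica_values)
    return per_replica_values[0]  # ONLY_FIRST_REPLICA


class ToyAggregatingVariable:
    """Hypothetical toy wrapper holding a single float value."""

    def __init__(self, value, aggregation):
        self.value = value
        self.aggregation = aggregation

    def assign_add_from_replicas(self, per_replica_deltas):
        # Stand-in for merge_call: gather every replica's delta,
        # reduce them once, then apply exactly one update.
        self.value += aggregate(per_replica_deltas, self.aggregation)
        return self.value


v = ToyAggregatingVariable(10.0, Aggregation.MEAN)
print(v.assign_add_from_replicas([1.0, 3.0]))  # 12.0
```

Raising on NONE matches the check in _assign_func. Without an aggregation mode there is no defined way to combine different values coming from several replicas into one update.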

2.6 Usage

The following example shows how ShardedVariable is used. Dense constructs a variable self.w of shape [100, 10]. Because Dense is created inside the scope of a strategy that uses a two-shard partitioner, self.w is split into two (50, 10) tensors.

  import tensorflow as tf

  class Dense(tf.Module):
    def __init__(self, name=None):
      super().__init__(name=name)
      self.w = tf.Variable(tf.random.normal([100, 10]), name='w')

    def __call__(self, x):
      return x * self.w

  # Partition the dense layer into 2 shards.
  variable_partitioner = (
    tf.distribute.experimental.partitioners.FixedShardsPartitioner(
      num_shards=2))
  strategy = tf.distribute.experimental.ParameterServerStrategy(
    cluster_resolver=...,
    variable_partitioner=variable_partitioner)
  with strategy.scope():
    # Created inside the strategy scope, so the variable is automatically
    # split into two partitions.
    dense = Dense()

  assert len(dense.variables) == 2
  assert isinstance(dense.variables[0], tf.Variable)
  assert isinstance(dense.variables[1], tf.Variable)
  assert dense.variables[0].shape == (50, 10)
  assert dense.variables[1].shape == (50, 10)

ShardedVariable is also a form of model parallelism. For example, a matrix made of two blocks, A and B, can be split across two parameter servers, where each block is multiplied by C separately. The partial products are then gathered on the worker and concatenated into the final result tensor.

Figure 23 Merging tensors
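The text and Figure 23 do not give concrete shapes. The sketch below therefore assumes the matrix is split by rows into two equal blocks, uses illustrative shapes that match self.w above, and uses NumPy arrays in place of TensorFlow tensors living on separate parameter servers. It checks that multiplying each shard separately and concatenating the partial results gives the same answer as the unsharded product.

```python
import numpy as np

rng = np.random.default_rng(0)
w = rng.normal(size=(100, 10))  # full matrix, like self.w in Dense
c = rng.normal(size=(10, 4))    # right-hand operand ("C" in the text)

# Row-split w into two (50, 10) shards ("A" and "B"), one per parameter server.
shard_a, shard_b = np.split(w, 2, axis=0)

# Each parameter server multiplies its own shard by C independently.
part_a = shard_a @ c
part_b = shard_b @ c

# The worker concatenates the partial results along the row axis.
result = np.concatenate([part_a, part_b], axis=0)
print(result.shape)  # (100, 4)
assert np.allclose(result, w @ c)
```

Concatenation works here because each row of the product depends only on the matching row of the matrix. If the matrix were split along the inner (shared) dimension instead, the partial products would have to be summed rather than concatenated.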

0xEE Personal information

★★★★ Thoughts on life and technology ★★★★★

Wechat official account: Rosie’s Thoughts
