Checkpoints

This article describes how to save and restore TensorFlow models built with Estimators. TensorFlow provides two model formats:

  • Checkpoints: a format that depends on the code that created the model.
  • SavedModel: a format that is independent of the code that created the model.

This article focuses on checkpoints. For details on SavedModel, see the Saving and Restoring guide.

Sample code

This article relies on the same sample code as Premade Estimators. To get it:

git clone https://github.com/tensorflow/models/
cd models/samples/core/get_started

Most of the code snippets in this article are lightly modified versions of premade_estimator.py.

Saving partially trained models

Estimators automatically write the following to disk:

  • Checkpoints: versions of the model created during training.
  • Event files: information that TensorBoard uses to create visualizations.

To specify the top-level directory in which the Estimator stores its information, assign a value to the optional model_dir argument of any Estimator's constructor. For example, the following code sets model_dir to the models/iris directory:

classifier = tf.estimator.DNNClassifier(
    feature_columns=my_feature_columns,
    hidden_units=[10, 10],
    n_classes=3,
    model_dir='models/iris')

Suppose you then call the Estimator's train method. For example:

classifier.train(
    input_fn=lambda: train_input_fn(train_x, train_y, batch_size=100),
    steps=200)

The first call to train adds checkpoints and other files to the model_dir directory. On a Unix-like system, you can use the ls command to inspect that directory:

$ ls -1 models/iris
checkpoint
events.out.tfevents.timestamp.hostname
graph.pbtxt
model.ckpt-1.data-00000-of-00001
model.ckpt-1.index
model.ckpt-1.meta
model.ckpt-200.data-00000-of-00001
model.ckpt-200.index
model.ckpt-200.meta

The ls command above shows that this Estimator generates checkpoints at step 1 (at the start of training) and step 200 (at the end of training).
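The step number embedded in each filename is how tooling identifies the most recent checkpoint (TensorFlow exposes this through tf.train.latest_checkpoint). As an illustration only, not TensorFlow's actual implementation, here is a small pure-Python sketch that extracts the latest step from a listing like the one above:

```python
import re

def latest_checkpoint_step(filenames):
    """Return the highest training step among checkpoint files.

    Checkpoint files follow the pattern 'model.ckpt-<step>.*',
    as in the directory listing above.
    """
    steps = []
    for name in filenames:
        match = re.match(r"model\.ckpt-(\d+)\.", name)
        if match:
            steps.append(int(match.group(1)))
    return max(steps) if steps else None

files = [
    "checkpoint",
    "graph.pbtxt",
    "model.ckpt-1.index",
    "model.ckpt-1.meta",
    "model.ckpt-200.index",
    "model.ckpt-200.meta",
]
print(latest_checkpoint_step(files))  # → 200
```

In real code, prefer tf.train.latest_checkpoint(model_dir), which reads the `checkpoint` bookkeeping file rather than guessing from filenames.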

Default checkpoint directory

If you do not specify the model_dir argument in an Estimator's constructor, the Estimator writes checkpoint files to a temporary directory chosen by Python's tempfile.mkdtemp function. For example, the following Estimator constructor does not specify model_dir:

classifier = tf.estimator.DNNClassifier(
    feature_columns=my_feature_columns,
    hidden_units=[10, 10],
    n_classes=3)

print(classifier.model_dir)

The tempfile.mkdtemp function selects a secure temporary directory appropriate for your operating system. For example, a typical temporary directory on macOS might be:

/var/folders/0s/5q9kfzfj3gx2knj0vj8p68yc00dhcr/T/tmpYm1Rwa
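To see the kind of path mkdtemp chooses on your own system, you can call it directly. A minimal sketch (the exact path varies by operating system and run):

```python
import os
import tempfile

# Create a fresh, securely created temporary directory, just as an
# Estimator does when model_dir is not specified.
tmp_dir = tempfile.mkdtemp()
print(tmp_dir)                 # e.g. /var/folders/.../T/tmpYm1Rwa on macOS
print(os.path.isdir(tmp_dir))  # True

# Unlike tempfile.TemporaryDirectory, mkdtemp does not clean up after
# itself; remove the directory when you are done with it.
os.rmdir(tmp_dir)
```

Because the directory is freshly created each run, checkpoints written there are effectively discarded; specify model_dir whenever you want to restore a model later.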

Checkpointing frequency

By default, the Estimator saves checkpoints in model_dir according to the following schedule:

  • Writes a checkpoint every 10 minutes (600 seconds).
  • Writes a checkpoint when the train method starts (first iteration) and completes (final iteration).
  • Retains only the 5 most recent checkpoints in the directory.

You can change the default schedule by taking the following steps:

  1. Create a tf.estimator.RunConfig object that defines the desired schedule.
  2. When instantiating the Estimator, pass that RunConfig object to the Estimator's config parameter.

For example, the following code changes the checkpoint save policy to save every 20 minutes and keep the last 10 checkpoints:

my_checkpointing_config = tf.estimator.RunConfig(
    save_checkpoints_secs = 20*60,  # Save checkpoints every 20 minutes.
    keep_checkpoint_max = 10,       # Retain the 10 most recent checkpoints.
)

classifier = tf.estimator.DNNClassifier(
    feature_columns=my_feature_columns,
    hidden_units=[10, 10],
    n_classes=3,
    model_dir='models/iris',
    config=my_checkpointing_config)
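The keep_checkpoint_max policy simply discards the oldest checkpoints once the limit is exceeded. The following pure-Python sketch illustrates that pruning rule (an illustration only, not TensorFlow's actual implementation):

```python
def prune_checkpoints(steps, keep_checkpoint_max):
    """Return the checkpoint steps that survive pruning.

    Mirrors the retention rule: keep only the most recent
    `keep_checkpoint_max` checkpoints, discarding older ones.
    """
    return sorted(steps)[-keep_checkpoint_max:]

# Twelve checkpoints written, but only the 10 most recent are retained.
saved = prune_checkpoints(range(1, 13), keep_checkpoint_max=10)
print(saved)  # → [3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
```

Note that RunConfig also accepts save_checkpoints_steps as an alternative to save_checkpoints_secs; the two are mutually exclusive ways of expressing the schedule.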

Restore your model

When an Estimator's train method is called for the first time, TensorFlow saves a checkpoint in model_dir. Each subsequent call to the Estimator's train, evaluate, or predict method causes the following:

  1. The Estimator builds the model's graph by running the model_fn().
  2. The Estimator initializes the weights of the new model from the data stored in the most recent checkpoint.

In other words, once checkpoints exist, TensorFlow rebuilds the model each time you call train(), evaluate(), or predict().

Avoiding a bad restoration

Restoring a model's state from a checkpoint only works if the model and checkpoint are compatible. For example, suppose you trained a DNNClassifier Estimator containing two hidden layers, each having 10 nodes:

classifier = tf.estimator.DNNClassifier(
    feature_columns=feature_columns,
    hidden_units=[10, 10],
    n_classes=3,
    model_dir='models/iris')

classifier.train(
    input_fn=lambda: train_input_fn(train_x, train_y, batch_size=100),
    steps=200)

After training (and therefore after creating checkpoints in models/iris), suppose you change the number of nodes in each hidden layer from 10 to 20 and then attempt to retrain the model:

classifier2 = tf.estimator.DNNClassifier(
    feature_columns=my_feature_columns,
    hidden_units=[20, 20],
    n_classes=3,
    model_dir='models/iris')

classifier2.train(
    input_fn=lambda: train_input_fn(train_x, train_y, batch_size=100),
    steps=200)

Because the state in the checkpoint is incompatible with the model described by classifier2, restoring the model fails with the following error:

...
InvalidArgumentError (see above for traceback): tensor_name =
dnn/hiddenlayer_1/bias/t_0/Adagrad; shape in shape_and_slice spec [10]
does not match the shape stored in checkpoint: [20]
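Conceptually, the restore fails because TensorFlow compares, tensor by tensor, the shape the new graph expects against the shape stored in the checkpoint. The following simplified pure-Python sketch of that compatibility check uses a hypothetical helper, not TensorFlow's actual restore code:

```python
def check_restore_compatibility(model_shapes, checkpoint_shapes):
    """Raise ValueError on the first tensor whose shapes disagree.

    Both arguments map tensor names to shape lists; only tensors
    present in both the model and the checkpoint are compared.
    """
    for name, shape in model_shapes.items():
        stored = checkpoint_shapes.get(name)
        if stored is not None and stored != shape:
            raise ValueError(
                f"tensor_name = {name}; shape {shape} does not match "
                f"the shape stored in checkpoint: {stored}")

# The checkpoint was written with 10-node hidden layers...
checkpoint_shapes = {"dnn/hiddenlayer_1/bias": [10]}
# ...but classifier2's graph expects 20-node layers.
model_shapes = {"dnn/hiddenlayer_1/bias": [20]}

try:
    check_restore_compatibility(model_shapes, checkpoint_shapes)
except ValueError as err:
    print(err)
```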

To run experiments in which you train and compare slightly different versions of a model, save a copy of the code that created each model_dir, for example by creating a separate git branch for each version. This separation keeps your checkpoints recoverable.

Conclusion

Checkpoints provide an easy, automatic mechanism for saving and restoring models created by Estimators.

See the Saving and Restoring guide to learn more about:

  • Saving and restoring models with the low-level TensorFlow APIs.
  • Exporting and importing models in the SavedModel format, a language-neutral, recoverable, serializable format.