  • This article mainly introduces the image preprocessing part of Keras as used in image classification tasks.
  • Note: it does not cover every image preprocessing function in Keras.

1. Introduction

When using Keras for an image classification task, if the data set is small (for example because data is hard to acquire), data augmentation can be applied in order to make full use of the value of the limited data.

A series of random transformations is applied to augment the data, which helps suppress over-fitting and improves the generalization ability of the model.

Keras provides a class for data augmentation (keras.preprocessing.image.ImageDataGenerator) to implement this. This class can:

  • set up the random transformations to be applied during training
  • instantiate generators of augmented image batches through the .flow() or .flow_from_directory(directory) methods. These generators can then be passed as input to Keras methods such as fit_generator, evaluate_generator or predict_generator.

What does that mean? The ImageDataGenerator class can not only apply random transformations to images during training and thereby enlarge the training data; it also provides batch generator objects, so there is no need to fetch batches of data manually.

2. Introduction to ImageDataGenerator

Path of the ImageDataGenerator class: keras/preprocessing/image.py

Purpose: generate batches of image tensor data with real-time data augmentation. During training, the generator loops over the data indefinitely until the specified number of epochs is reached.

This class inherits from the ImageDataGenerator class defined in keras_preprocessing/image/image_data_generator.py.

# keras/preprocessing/image.py
class ImageDataGenerator(image.ImageDataGenerator):
    def __init__(self,
                 featurewise_center=False,
                 samplewise_center=False,
                 featurewise_std_normalization=False,
                 samplewise_std_normalization=False,
                 zca_whitening=False,
                 zca_epsilon=1e-6,
                 rotation_range=0,
                 width_shift_range=0.,
                 height_shift_range=0.,
                 brightness_range=None,
                 shear_range=0.,
                 zoom_range=0.,
                 channel_shift_range=0.,
                 fill_mode='nearest',
                 cval=0.,
                 horizontal_flip=False,
                 vertical_flip=False,
                 rescale=None,
                 preprocessing_function=None,
                 data_format=None,
                 validation_split=0.0,
                 dtype=None):

Parameters

  • featurewise_center: Boolean. Set the mean of the input data set to 0, feature-wise.
  • samplewise_center: Boolean. Set the mean of each input sample to 0.
  • featurewise_std_normalization: Boolean. Divide the inputs by the standard deviation of the data set, feature-wise.
  • samplewise_std_normalization: Boolean. Divide each input sample by its own standard deviation.
  • zca_whitening: Boolean. Apply ZCA whitening to the input data.
  • zca_epsilon: epsilon used for ZCA whitening, default 1e-6.
  • rotation_range: integer. Degree range within which images are randomly rotated.
  • width_shift_range: float, 1-D array or integer. How far images are shifted horizontally during augmentation, as a fraction of the picture width or in pixels:
    • float: fraction of the total width if < 1, otherwise a shift in pixels if >= 1
    • 1-D array: random elements are drawn from the array
    • integer: integer number of pixels drawn from the interval (-width_shift_range, +width_shift_range)
    • With width_shift_range=2 the possible values are the integers [-1, 0, +1], the same as width_shift_range=[-1, 0, +1]; with width_shift_range=1.0 the possible values are floats in the half-open interval [-1.0, +1.0).
  • height_shift_range: float. How far images are shifted vertically during augmentation, as a fraction of the picture height. The meaning is otherwise the same as for width_shift_range.
  • brightness_range: tuple or list of two floats. Range from which brightness values are picked.
  • shear_range: float. Shear intensity (shear angle in the counter-clockwise direction, in degrees).
  • zoom_range: float or [lower, upper]. Range for random zoom; if a float, [lower, upper] = [1 - zoom_range, 1 + zoom_range].
  • channel_shift_range: float. Range for random channel shifts.
  • fill_mode: one of {"constant", "nearest", "reflect", "wrap"}. Default 'nearest'. Points outside the input boundary are filled according to the given mode:
    • 'constant': kkkkkkkk|abcd|kkkkkkkk (cval=k)
    • 'nearest': aaaaaaaa|abcd|dddddddd
    • 'reflect': abcddcba|abcd|dcbaabcd
    • 'wrap': abcdabcd|abcd|abcdabcd
  • cval: float or integer. Value used for points outside the boundary when fill_mode = "constant".
  • horizontal_flip: Boolean. Randomly flip inputs horizontally.
  • vertical_flip: Boolean. Randomly flip inputs vertically.
  • rescale: rescaling factor, default None. If None or 0, no rescaling is applied; otherwise the data is multiplied by the given value (after applying all other transformations).
  • preprocessing_function: function applied to each input after it has been resized and augmented. The function takes one argument, an image (a rank-3 Numpy tensor), and outputs a Numpy tensor of the same shape.
  • data_format: image data format, one of {"channels_first", "channels_last"}. "channels_last" means the input images have shape (samples, height, width, channels); "channels_first" means (samples, channels, height, width). Defaults to the image_data_format value in the Keras configuration file ~/.keras/keras.json; if you have never changed it, this is "channels_last".
  • validation_split: float. Fraction of images reserved for the validation set (strictly between 0 and 1).
  • dtype: data type used for the generated arrays.

Usage example

from keras.preprocessing.image import ImageDataGenerator
datagen = ImageDataGenerator(
        rotation_range=40,
        width_shift_range=0.2,
        height_shift_range=0.2,
        rescale=1./255,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        fill_mode='nearest')

data_generator = datagen.flow_from_directory('./datas/train', target_size=(224, 224), batch_size=32)

3. ImageDataGenerator class method

Several important methods of this class are as follows:

  • flow(): takes input data (Numpy arrays or a tuple) and labels (optional) and returns an iterator that yields tuples of the form (x, y), (x,) or (x, y, sample_weight). The method can also be given a save path, prefix and format used to save the augmented images.
  • flow_from_directory(): takes a directory path and generates batches of augmented data. Only the path of the data needs to be specified; no Numpy data or label values are passed in, and the corresponding label values are inferred automatically. Returns a DirectoryIterator yielding (x, y) tuples.
  • flow_from_dataframe(): the input data is a Pandas dataframe. Returns a DataFrameIterator yielding (x, y) tuples.

Notes

  • The main difference between these methods is the format of the input and output data.
  • flow_from_directory() and flow_from_dataframe() both resize the images to the specified size, while flow() does not perform this step.

3.1 fit()

This method fits the data generator to some sample data: it computes, from the sample data array, the internal statistics needed by the data-dependent transformations.

These statistics are only computed when featurewise_center, featurewise_std_normalization or zca_whitening is set to True.

In other words, it implements the centering / standardization / ZCA whitening of the data, and the mean and standard deviation used are those of the sample data itself.

def fit(self, x, augment=False, rounds=1, seed=None)

  • x: sample data; must have rank 4. For grayscale data the channel axis should have value 1, 3 for RGB data and 4 for RGBA data.
  • augment: Boolean (default False). Whether to apply random augmentation.
  • rounds: integer (default 1). If augment=True, how many augmentation passes over the data to use.
  • seed: integer (default None). Random seed.

Return value: None. The method only computes and stores the internal statistics (self.mean, self.std, self.principal_components) that standardize() later applies.
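As a minimal usage sketch (assuming an in-memory dataset; CIFAR-10 from keras.datasets is used here purely for illustration), fit() is called on the sample array before the generator is used:

# Sketch: fit() computes the statistics from the sample data itself,
# so it has to be called before the generator produces batches.
from keras.datasets import cifar10
from keras.preprocessing.image import ImageDataGenerator

(x_train, y_train), _ = cifar10.load_data()
x_train = x_train[:10000].astype('float32')   # a subset is enough to illustrate
y_train = y_train[:10000]

datagen = ImageDataGenerator(featurewise_center=True,
                             featurewise_std_normalization=True)
datagen.fit(x_train)      # fills datagen.mean and datagen.std from the sample data

# every batch produced afterwards is centered/standardized with those statistics
for x_batch, y_batch in datagen.flow(x_train, y_train, batch_size=32):
    print(x_batch.shape, x_batch.mean(), x_batch.std())
    break                 # the generator loops forever, so stop after one batch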

The specific implementation

def fit(self, x,
    augment=False,
    rounds=1,
    seed=None):
    
    # (input validation checks omitted here)
    # Center the data (featurewise mean subtraction)
    if self.featurewise_center:
        self.mean = np.mean(x, axis=(0, self.row_axis, self.col_axis))
        broadcast_shape = [1, 1, 1]
        broadcast_shape[self.channel_axis - 1] = x.shape[self.channel_axis]
        self.mean = np.reshape(self.mean, broadcast_shape)
        x -= self.mean
    # Data standardization
    if self.featurewise_std_normalization:
        self.std = np.std(x, axis=(0, self.row_axis, self.col_axis))
        broadcast_shape = [1, 1, 1]
        broadcast_shape[self.channel_axis - 1] = x.shape[self.channel_axis]
        self.std = np.reshape(self.std, broadcast_shape)
        x /= (self.std + 1e-6)
    # ZCA whitening of the data
    if self.zca_whitening:
        if scipy is None:
            raise ImportError('Using zca_whitening requires SciPy. '
                              'Install SciPy.')
        flat_x = np.reshape(
            x, (x.shape[0], x.shape[1] * x.shape[2] * x.shape[3]))
        sigma = np.dot(flat_x.T, flat_x) / flat_x.shape[0]
        u, s, _ = linalg.svd(sigma)
        s_inv = 1. / np.sqrt(s[np.newaxis] + self.zca_epsilon)
        self.principal_components = (u * s_inv).dot(u.T)

3.2 flow()

Takes data and label arrays and generates batches of augmented data.

Function definition:

def flow(self,
         x,
         y=None,
         batch_size=32,
         shuffle=True,
         sample_weight=None,
         seed=None,
         save_to_dir=None,
         save_prefix='',
         save_format='png',
         subset=None)

Parameters:

  • x: input data. A Numpy array of rank 4 or a tuple. If a tuple, the first element should contain the images and the second element a Numpy array or a list of Numpy arrays that is passed to the output without any modification; this can be used to feed miscellaneous data to the model together with the images. For grayscale data the channel axis of the image array should have value 1, and 3 for RGB data.
  • y: labels
  • batch_size: integer (default 32)
  • shuffle: Boolean (default True). Whether to shuffle the data.
  • sample_weight: sample weights
  • seed: default None. Random seed.
  • save_to_dir: None or string (default None). Lets you optionally specify a directory in which to save the augmented images being generated.
  • save_prefix: string (default ''). Prefix for the filenames of saved pictures (only relevant if save_to_dir is set).
  • save_format: one of "png", "jpeg" (only relevant if save_to_dir is set). Default: "png".
  • subset: subset of data ("training" or "validation") if validation_split was set in ImageDataGenerator.

Return value: an Iterator yielding tuples of (x, y), where x is a Numpy array of image data (for a single image input) or a list of Numpy arrays (when there are additional inputs) and y is a Numpy array of the corresponding labels. If sample_weight is not None, the yielded tuples have the form (x, y, sample_weight). If y is None, only the Numpy array x is returned.

Internally, flow() calls the NumpyArrayIterator class:

    return NumpyArrayIterator(
        x,
        y,
        self,
        batch_size=batch_size,
        shuffle=shuffle,
        sample_weight=sample_weight,
        seed=seed,
        data_format=self.data_format,
        save_to_dir=save_to_dir,
        save_prefix=save_prefix,
        save_format=save_format,
        subset=subset
    )
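A small usage sketch of flow() with the save_to_dir related parameters described above; the dummy data and the 'preview' directory are illustrative assumptions, not part of the library:

# Sketch: save the augmented images produced by flow() so they can be inspected.
import os
import numpy as np
from keras.preprocessing.image import ImageDataGenerator

x = np.random.randint(0, 256, size=(8, 224, 224, 3)).astype('float32')   # dummy images
y = np.arange(8)

os.makedirs('preview', exist_ok=True)
datagen = ImageDataGenerator(rotation_range=40, horizontal_flip=True)

batches = 0
for x_batch, y_batch in datagen.flow(x, y, batch_size=4,
                                     save_to_dir='preview',
                                     save_prefix='aug',
                                     save_format='jpeg'):
    batches += 1
    if batches >= 5:      # stop manually, otherwise the generator loops indefinitely
        break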

3.3 flow_from_directory()

Purpose: takes a directory path and generates batches of augmented data.

Function definition:

def flow_from_directory(self,
                directory,
                target_size=(256, 256),
                color_mode='rgb',
                classes=None,
                class_mode='categorical',
                batch_size=32,
                shuffle=True,
                seed=None,
                save_to_dir=None,
                save_prefix='',
                save_format='png',
                follow_links=False,
                subset=None,
                interpolation='nearest')

Parameters:

  • directory: path of the target directory. It should contain one subdirectory per class. Any PNG, JPG, BMP, PPM or TIF images found inside the subdirectory tree will be included in the generator.
  • target_size: integer tuple (height, width), default (256, 256). All images found will be resized to this size.
  • color_mode: one of "grayscale", "rgb". Default: "rgb". Whether the images are converted to 1 or 3 color channels.
  • classes: optional list of class subdirectories (e.g. ['dogs', 'cats']). Default: None. If not provided, the list of classes is automatically inferred from the subdirectory names/structure under directory, where each subdirectory is treated as a different class (the class names are mapped to label indices in alphanumeric order). The dictionary containing the mapping from class names to class indices can be obtained via the class_indices attribute.
  • class_mode: one of "categorical", "binary", "sparse", "input" or None. Default: "categorical". Determines the type of label array that is returned:
    • "categorical" will be 2D one-hot encoded labels,
    • "binary" will be 1D binary labels, "sparse" will be 1D integer labels,
    • "input" will be images identical to the input images (mainly used for autoencoders).
    • If None, no labels are returned (the generator only yields batches of image data, which is useful with model.predict_generator(), model.evaluate_generator(), etc.). Note that if class_mode is None, the data still needs to reside in a subdirectory of directory for it to work correctly.
  • batch_size: size of the batches of data (default 32).
  • shuffle: whether to shuffle the data (default True).
  • seed: optional random seed for shuffling and transformations.
  • save_to_dir: None or string (default None). Lets you optionally specify a directory in which to save the augmented images being generated (useful for visualizing what you are doing).
  • save_prefix: string. Prefix for the filenames of saved pictures (only relevant if save_to_dir is set).
  • save_format: one of "png", "jpeg" (only relevant if save_to_dir is set). Default: "png".
  • follow_links: whether to follow symbolic links inside class subdirectories (default False).
  • subset: subset of data ("training" or "validation") if validation_split was set in ImageDataGenerator.
  • interpolation: interpolation method used to resample the image if the target size is different from that of the loaded image. Supported methods are "nearest", "bilinear" and "bicubic". If PIL version 1.1.3 or newer is installed, "lanczos" is also supported; if PIL version 3.4.0 or newer is installed, "box" and "hamming" are also supported. Default: "nearest".

Return value: a DirectoryIterator yielding tuples of (x, y), where x is a Numpy array containing a batch of images with shape (batch_size, *target_size, channels) and y is a Numpy array of the corresponding labels.
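A sketch of how validation_split, subset and the class_indices attribute fit together, assuming the ./datas/train layout shown in section 4:

# Sketch: split one directory into training and validation subsets via
# validation_split, and inspect the inferred class-to-index mapping.
from keras.preprocessing.image import ImageDataGenerator

datagen = ImageDataGenerator(rescale=1./255, validation_split=0.2)

train_gen = datagen.flow_from_directory('./datas/train',
                                        target_size=(224, 224),
                                        batch_size=32,
                                        class_mode='categorical',
                                        subset='training')
val_gen = datagen.flow_from_directory('./datas/train',
                                      target_size=(224, 224),
                                      batch_size=32,
                                      class_mode='categorical',
                                      subset='validation')

print(train_gen.class_indices)   # e.g. {'cats': 0, 'dogs': 1}, inferred from subdirectories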

3.4 flow_from_dataframe()

Purpose: takes a dataframe and a directory path and generates batches of augmented / normalized data.

The input data for this method is in Pandas dataframe format.

Function definition:

def flow_from_dataframe(self, dataframe, directory=None,
                x_col="filename", y_col="class", weight_col=None,
                target_size=(256, 256), color_mode='rgb', classes=None,
                class_mode='categorical', batch_size=32, shuffle=True, seed=None,
                save_to_dir=None, save_prefix='', save_format='png', subset=None,
                interpolation='nearest', validate_filenames=True, **kwargs)

Parameters:

  • dataframe: Pandas dataframe with one column containing the filenames of the images and another column containing the classes of the images, or multiple columns that can be used as raw target data.
  • directory: string, path of the directory that contains all the images referenced in the dataframe.
  • x_col: string, column in the dataframe that contains the filenames of the target images (relative to directory).
  • y_col: string or list of strings, column(s) in the dataframe to use as target data.
  • has_ext: Boolean. True if the filenames in dataframe[x_col] include a file extension, otherwise False.
  • target_size: integer tuple (height, width), default (256, 256). All images found will be resized to this size.
  • color_mode: one of "grayscale", "rgb". Default: "rgb". Whether the images are converted to 1 or 3 color channels.
  • classes: optional list of classes (e.g. ['dogs', 'cats']). Default: None. If not provided, the list of classes is automatically inferred from y_col (which is mapped to class indices). The dictionary containing the mapping from class names to class indices can be obtained via the class_indices attribute.
  • class_mode: one of "categorical", "binary", "sparse", "input", "other" or None. Default: "categorical". Determines the type of label array that is returned:
    • "categorical" will be 2D one-hot encoded labels,
    • "binary" will be 1D binary labels,
    • "sparse" will be 1D integer labels,
    • "input" will be images identical to the input images (mainly used for autoencoders),
    • "other" will be a Numpy array of the y_col data; with None, no labels are returned (the generator only yields batches of image data, which is useful with model.predict_generator(), model.evaluate_generator(), etc.).
  • batch_size: size of the batches of data (default 32).
  • shuffle: whether to shuffle the data (default True).
  • seed: optional random seed for shuffling and transformations.
  • save_to_dir: None or string (default None). Lets you optionally specify a directory in which to save the augmented images being generated (useful for visualizing what you are doing).
  • save_prefix: string. Prefix for the filenames of saved pictures (only relevant if save_to_dir is set).
  • save_format: one of "png", "jpeg" (only relevant if save_to_dir is set). Default: "png".
  • follow_links: whether to follow symbolic links inside class subdirectories (default False).
  • subset: subset of data ("training" or "validation") if validation_split was set in ImageDataGenerator.
  • interpolation: interpolation method used to resample the image if the target size is different from that of the loaded image. Supported methods are "nearest", "bilinear" and "bicubic". If PIL version 1.1.3 or newer is installed, "lanczos" is also supported; if PIL version 3.4.0 or newer is installed, "box" and "hamming" are also supported. Default: "nearest".

Return value: a DataFrameIterator yielding tuples of (x, y), where x is a Numpy array containing a batch of images with shape (batch_size, *target_size, channels) and y is a Numpy array of the corresponding labels.
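A minimal sketch of flow_from_dataframe(), assuming a CSV file with 'filename' and 'class' columns; the file path and column names are illustrative assumptions:

# Sketch: feed images listed in a Pandas dataframe to the generator.
import pandas as pd
from keras.preprocessing.image import ImageDataGenerator

df = pd.read_csv('./datas/labels.csv')      # assumed columns: filename, class

datagen = ImageDataGenerator(rescale=1./255)
generator = datagen.flow_from_dataframe(df,
                                        directory='./datas/images',
                                        x_col='filename',
                                        y_col='class',
                                        target_size=(224, 224),
                                        class_mode='categorical',
                                        batch_size=32)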

3.5 standardize()

This function applies the normalization configuration to a batch of input data.

Main steps:

  • if preprocessing_function is not None, the specified function is applied: x = self.preprocessing_function(x)
  • if rescale is set, x *= self.rescale is executed
  • if samplewise_center is True, x -= np.mean(x, keepdims=True) is executed (centering)

    The mean is computed from the current sample itself

  • if samplewise_std_normalization is True, x /= (np.std(x, keepdims=True) + 1e-6) is executed (standardization)

    The standard deviation is computed from the current sample itself

  • if featurewise_center is True and self.mean is not None, x -= self.mean is executed (centering); otherwise a warning is given
  • if featurewise_std_normalization is True and self.std is not None, x /= (self.std + 1e-6) is executed (standardization); otherwise a warning is given
  • if zca_whitening is True and self.principal_components is not None, the whitening is applied; otherwise a warning is given

This function is called inside the _get_batches_of_transformed_samples() function when a batch of input data is assembled:

filepaths = self.filepaths
for i, j in enumerate(index_array):
    img = load_img(filepaths[j],
                   color_mode=self.color_mode,
                   target_size=self.target_size,
                   interpolation=self.interpolation)
    x = img_to_array(img, data_format=self.data_format)
    # Pillow images should be closed after `load_img`,
    # but not PIL images.
    if hasattr(img, 'close'):
        img.close()
    if self.image_data_generator:
        params = self.image_data_generator.get_random_transform(x.shape)
        x = self.image_data_generator.apply_transform(x, params)
        x = self.image_data_generator.standardize(x) # Perform standardized processing
    batch_x[i] = x

If featurewise_center is set to True, the .fit() function has to be called first so that self.mean is assigned; the call to self.image_data_generator.standardize(x) above then subtracts that mean from each image (and, analogously, applies the standard deviation and ZCA components).
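To make the order of operations explicit, here is a small sketch that mirrors what the iterator does for a single image; a random array stands in for an image loaded with load_img()/img_to_array():

# Sketch of the per-image pipeline used inside the iterators: a random
# geometric transform is drawn and applied first, then standardize() is called.
import numpy as np
from keras.preprocessing.image import ImageDataGenerator

datagen = ImageDataGenerator(rotation_range=30, horizontal_flip=True,
                             samplewise_center=True)

x = np.random.rand(224, 224, 3).astype('float32')      # a single image, rank 3
params = datagen.get_random_transform(x.shape)          # e.g. {'theta': ..., 'flip_horizontal': ...}
x = datagen.apply_transform(x, params)                  # random augmentation
x = datagen.standardize(x)                              # normalization (in place)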

4. Specific use

When Keras is used for image classification tasks, the training data can be saved according to the following structure:

datas/
    train/
        dogs/
            dog01.jpg
            dog02.jpg
            ...
        cats/
            cat01.jpg
            cat02.jpg
            ...
    validation/
        dogs/
            dog01.jpg
            dog02.jpg
            ...
        cats/
            cat01.jpg
            cat02.jpg
            ...
  • The images of each class are stored in their own folder, separately for the training set and the validation set.
  • When the flow_from_directory() function is called, the label values are automatically inferred from the names/structure of the data subdirectories, and each subdirectory is treated as a different class. The label values therefore do not have to be supplied.

4.1 Example 1

This example does not standardize / center / ZCA-whiten the data.

If data augmentation is required, proceed as follows:

# call ImageDataGenerator
train_datagen = ImageDataGenerator(
        rotation_range=30,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True)

# Instantiate a generator through the flow_from_directory method, which generates batch data in a continuous loop
train_generator = train_datagen.flow_from_directory(
    './datas/train',
    target_size=(config.image_size, config.image_size),
    class_mode='categorical',
    batch_size=config.batch_size)
    
# Pass the generator as an argument to the model training function fit_generator
model.fit_generator(train_generator,
        steps_per_epoch=nb_train_samples//config.batch_size + 1,
        epochs=config.epochs,
        validation_data=val_generator,
        validation_steps=nb_val_samples//config.batch_size + 1,
        callbacks=callbacks)

As the name of the training function fit_generator suggests, its input is a generator object.

The evaluate_generator and predict_generator methods can be used in the same way.
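A sketch of using the same kind of generator for evaluation and prediction, assuming a compiled model (with at least one metric) and the ./datas/validation directory from the structure above:

# Sketch: evaluation and prediction with a directory-based generator.
# `model` is assumed to be a compiled Keras model defined elsewhere.
from keras.preprocessing.image import ImageDataGenerator

val_datagen = ImageDataGenerator(rescale=1./255)
val_generator = val_datagen.flow_from_directory('./datas/validation',
                                                target_size=(224, 224),
                                                batch_size=32,
                                                class_mode='categorical',
                                                shuffle=False)   # keep file order for prediction

scores = model.evaluate_generator(val_generator, steps=len(val_generator))
print(dict(zip(model.metrics_names, scores)))   # e.g. {'loss': ..., 'acc': ...}

predictions = model.predict_generator(val_generator, steps=len(val_generator))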

4.2 Centering / standardizing the data

4.2.1 By calling the ImageDataGenerator.fit() function

As described earlier, the fit() function fits the data generator to sample data. When featurewise_center, featurewise_std_normalization or zca_whitening is set to True, the input data x is centered, standardized or ZCA-whitened respectively.

Taking centering as an example, the processing is as follows:

  • compute the mean of the input data x
  • execute x -= self.mean
# ImageDataGenerator.fit()
if self.featurewise_center:
    self.mean = np.mean(x, axis=(0, self.row_axis, self.col_axis))
    broadcast_shape = [1, 1, 1]
    broadcast_shape[self.channel_axis - 1] = x.shape[self.channel_axis]
    self.mean = np.reshape(self.mean, broadcast_shape)
    x -= self.mean

Note: the mean used here is the mean of the input data itself, and the corresponding parameter (e.g. featurewise_center) has to be set to True.
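If the training images live on disk (as in the directory layout of section 4) but featurewise statistics are still wanted, one possible pattern, sketched here under the assumption that a representative sample fits in memory, is to load such a sample, call fit() on it, and then use flow_from_directory() as usual; the paths and sample size are illustrative:

# Sketch: compute featurewise statistics from a sample of images on disk,
# then reuse them with flow_from_directory().
import glob
import numpy as np
from keras.preprocessing.image import ImageDataGenerator, load_img, img_to_array

sample_paths = glob.glob('./datas/train/*/*.jpg')[:500]   # a representative sample
sample = np.stack([img_to_array(load_img(p, target_size=(224, 224)))
                   for p in sample_paths])

datagen = ImageDataGenerator(featurewise_center=True,
                             featurewise_std_normalization=True,
                             rotation_range=30)
datagen.fit(sample)        # fills datagen.mean and datagen.std

train_generator = datagen.flow_from_directory('./datas/train',
                                              target_size=(224, 224),
                                              batch_size=32,
                                              class_mode='categorical')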

If you do not want to use the mean and standard deviation computed from the input data itself, do not call this function; use one of the other approaches described below.

4.2.2 Through the ImageDataGenerator.standardize() function

This function applies the configured normalization to a batch of input data.

Input: the batch of data to be normalized. Return: the normalized data.

This function changes the input x in place, because it is mainly used internally to standardize images before feeding them to the network, and creating copies would incur a significant performance cost.
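Because of this in-place behaviour, a copy should be passed whenever the original array must be preserved; a minimal sketch:

# Sketch: standardize() modifies its argument in place, so pass a copy
# when the original image array must stay untouched.
import numpy as np
from keras.preprocessing.image import ImageDataGenerator

datagen = ImageDataGenerator(samplewise_center=True,
                             samplewise_std_normalization=True)

x = np.random.rand(224, 224, 3).astype('float32')
x_norm = datagen.standardize(np.copy(x))    # x itself is left unchanged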

There are three cases in which this function handles the data:

  1. The featurewise_center or featurewise_std_normalization parameter is set to True, but the data has not been processed by the .fit() function. In that case the mean and standard deviation have not been obtained through fit(), so they are still empty and need to be specified manually. They default to None:

    self.mean = None
    self.std = None

    The processing is as follows:

    if self.featurewise_center:
        if self.mean is not None:
            x -= self.mean
        else:
            warnings.warn('This ImageDataGenerator specifies '
                          '`featurewise_center`, but it hasn\'t '
                          'been fit on any training data. Fit it '
                          'first by calling `.fit(numpy_data)`.')
    if self.featurewise_std_normalization:
        if self.std is not None:
            x /= (self.std + 1e-6)
        else:
            warnings.warn('This ImageDataGenerator specifies '
                          '`featurewise_std_normalization`, '
                          'but it hasn\'t '
                          'been fit on any training data. Fit it '
                          'first by calling `.fit(numpy_data)`.')

    The mean and standard deviation of the data need to be set manually before the data can be processed; otherwise a warning message is given. A correct usage is as follows:

    datagen = ImageDataGenerator(
                featurewise_center=True,
                rotation_range=30,
                shear_range=0.2,
                zoom_range=0.2)
    # Manually set the data mean
    datagen.mean = np.array(config.data_mean, dtype=np.float32).reshape((1, 1, 3))
    train_generator = datagen.flow_from_directory(config.train_data,
                              target_size=img_size,
                              batch_size=batch_size,
                              class_mode=None,
                              shuffle=False)
  2. If the data has not been processed by the .fit() function and the mean / standard deviation of the data set is not known, it can instead be processed with the samplewise_center and samplewise_std_normalization parameters.

    With these parameters, the mean and standard deviation of the current sample are computed automatically:

    # ImageDataGenerator.standardize() (excerpt)
    def standardize(self, x):
        if self.samplewise_center:
            x -= np.mean(x, keepdims=True)
        if self.samplewise_std_normalization:
            x /= (np.std(x, keepdims=True) + 1e-6)
  3. Through the preprocessing_function parameter, which specifies a processing function

    If you want to encapsulate the data processing in a single function instead of using the normalization parameters described above (samplewise_center, samplewise_std_normalization, featurewise_center, featurewise_std_normalization), you can pass that function via preprocessing_function.

    If self.preprocessing_function is set, it is executed first:

    if self.preprocessing_function:
        x = self.preprocessing_function(x)

    Usage:

    # Specify the processing function via preprocessing_function
    # (e.g. a preprocess_input function from keras.applications)
    train_datagen = ImageDataGenerator(
        preprocessing_function=preprocess_input,
        rotation_range=30,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True)
    train_generator = train_datagen.flow_from_directory(
        './datas/train',
        target_size=(224, 224),
        class_mode='categorical',
        batch_size=32)
    model.fit_generator(train_generator,
                steps_per_epoch=nb_train_samples//config.batch_size + 1,
                epochs=config.epochs,
                callbacks=callbacks)