• GAN with Keras: Application to Image Deblurring
  • By Raphael Meudec
  • Translated by: the Juejin Translation Project (gold-miner)
  • Permalink to this article: github.com/xitu/gold-m…
  • Translator: luochen
  • Proofreaders: SergeyChang, mingxing47

In 2014, Ian Goodfellow proposed Generative Adversarial Networks (GANs). This article focuses on implementing a GAN-based image deblurring model with Keras. All of the Keras code is here.

The scientific publication and a PyTorch implementation are also available.


A quick review of generative adversarial networks

In a generative adversarial network, two networks train against each other. The generator misleads the discriminator by creating fake inputs that look real. The discriminator tells real and fake inputs apart.

GAN training process – Source

There are three main training steps:

  • Use the generator to create fake inputs based on noise.
  • Train the discriminator on both real and fake inputs.
  • Train the whole model: the full model is built by chaining the generator and the discriminator.

Note that during the third step, the discriminator's weights are frozen.

We chain the two networks because there is no direct feedback on the generator's outputs: our only measure is whether the discriminator accepts the generated samples. A minimal sketch of one training step is shown below.
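
As an illustration (with hypothetical model names generator, discriminator and full_model; the actual DeblurGAN training loop is shown later in this article), one training step looks roughly like this:

# Step 1: create fake inputs from noise with the generator
fake_images = generator.predict(noise)

# Step 2: train the discriminator on real and fake inputs
discriminator.train_on_batch(real_images, real_labels)
discriminator.train_on_batch(fake_images, fake_labels)

# Step 3: train the chained model; the discriminator's weights are frozen
discriminator.trainable = False
full_model.train_on_batch(noise, real_labels)
discriminator.trainable = True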

This is only a brief review of the structure of GANs. If it still seems unclear, you can refer to this excellent introduction.


The dataset

Ian Goodfellow first applied GANs to generating MNIST data. In this tutorial, we use a generative adversarial network for image deblurring. Therefore, the generator's input is not noise but blurred images.

The dataset is GOPRO. You can download the compact version (9 GB) or the full version (35 GB). It contains artificially blurred images taken from multiple street scenes. The dataset is organized in subfolders by scene.

Let's first place the images into folders A (blurred) and B (sharp). This A/B structure matches the one from the original pix2pix article. I wrote a custom script to perform this task; use it by following the README. A minimal sketch of what such a script does is shown below.
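
For reference, here is a sketch of what such a script can look like. It assumes the GOPRO archive contains train/test splits whose scene folders each hold blur and sharp subfolders; the exact layout and options are described in the repository's README, so treat the paths below as assumptions.

import os
import shutil

def organize_gopro(source_dir, output_dir):
    # Copy blurred images into A and sharp images into B, keeping the train/test split
    for split in ('train', 'test'):
        for folder in ('A', 'B'):
            os.makedirs(os.path.join(output_dir, split, folder), exist_ok=True)
        split_dir = os.path.join(source_dir, split)
        for scene in sorted(os.listdir(split_dir)):
            for sub, folder in (('blur', 'A'), ('sharp', 'B')):
                src = os.path.join(split_dir, scene, sub)
                for name in sorted(os.listdir(src)):
                    shutil.copy(os.path.join(src, name),
                                os.path.join(output_dir, split, folder, scene + '_' + name))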


The model

The training process is the same as described above. First, let's look at the network architectures!

The generator

The generator aims to reproduce sharp images. The network is based on residual network (ResNet) blocks, which keep track of the evolutions applied to the original blurred image. The DeblurGAN article also mentions a UNet-based version, which I haven't implemented yet. Both structures are suitable for image deblurring.

Architecture of the DeblurGAN generator – Source

The core is nine ResNet blocks applied to an upsampled version of the original image. Let's look at the Keras implementation!

from keras.layers import Input, Conv2D, Activation, BatchNormalization
from keras.layers.merge import Add
from keras.layers.core import Dropout

def res_block(input, filters, kernel_size=(3, 3), strides=(1, 1), use_dropout=False):
    """
    Instantiate a Keras ResNet block using the functional API.
    :param input: input tensor
    :param filters: number of convolution filters
    :param kernel_size: convolution kernel size
    :param strides: convolution strides
    :param use_dropout: boolean, whether to use dropout
    :return: output tensor of the block
    """
    # ReflectionPadding2D is a custom layer defined in layer_utils
    x = ReflectionPadding2D((1, 1))(input)
    x = Conv2D(filters=filters,
               kernel_size=kernel_size,
               strides=strides,)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    if use_dropout:
        x = Dropout(0.5)(x)

    x = ReflectionPadding2D((1, 1))(x)
    x = Conv2D(filters=filters,
                kernel_size=kernel_size,
                strides=strides,)(x)
    x = BatchNormalization()(x)

    # Two convolution layers, then a direct connection between input and output
    merged = Add()([input, x])
    return merged

A ResNet block is basically two convolution layers whose output is added to the input to form the final output.
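
As a quick usage sketch (not part of the original repository), a block can be applied to any feature map whose number of channels equals filters, and it returns a tensor of the same shape:

from keras.layers import Input

feature_map = Input(shape=(64, 64, 256))              # a 64x64 feature map with 256 channels
out = res_block(feature_map, filters=256, use_dropout=True)
# out has the same shape as feature_map, since the block adds its input back to its output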

from keras.layers import Input, Activation, Add
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import Conv2D, Conv2DTranspose
from keras.layers.core import Lambda
from keras.layers.normalization import BatchNormalization
from keras.models import Model

from layer_utils import ReflectionPadding2D, res_block

ngf = 64
input_nc = 3
output_nc = 3
input_shape_generator = (256, 256, input_nc)
n_blocks_gen = 9


def generator_model():
    """Build the generator model."""
    # Current version : ResNet block
    inputs = Input(shape=input_shape_generator)

    x = ReflectionPadding2D((3, 3))(inputs)
    x = Conv2D(filters=ngf, kernel_size=(7, 7), padding='valid')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    # Increase filter number
    n_downsampling = 2
    for i in range(n_downsampling):
        mult = 2**i
        x = Conv2D(filters=ngf*mult*2, kernel_size=(3, 3), strides=2, padding='same')(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)

    # Apply 9 ResNet Blocks
    mult = 2**n_downsampling
    for i in range(n_blocks_gen):
        x = res_block(x, ngf*mult, use_dropout=True)

    # Upsample and decrease the number of filters again
    for i in range(n_downsampling):
        mult = 2**(n_downsampling - i)
        x = Conv2DTranspose(filters=int(ngf * mult / 2), kernel_size=(3, 3), strides=2, padding='same')(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)

    x = ReflectionPadding2D((3, 3))(x)
    x = Conv2D(filters=output_nc, kernel_size=(7, 7), padding='valid')(x)
    x = Activation('tanh')(x)

    # Add direct connection from input to output and recenter to [-1, 1]
    outputs = Add()([x, inputs])
    outputs = Lambda(lambda z: z/2)(outputs)

    model = Model(inputs=inputs, outputs=outputs, name='Generator')
    return model

The Keras implementation of the generator

As planned, the nine ResNet blocks are applied to an upsampled version of the input. We add a connection from the input to the output and divide by 2 to keep a normalized output.
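
As a quick sanity check (a hypothetical snippet, not part of the original training script), you can instantiate the generator and verify that the output keeps the 256x256x3 shape of the input:

g = generator_model()
g.summary()               # prints the layer stack
print(g.output_shape)     # expected: (None, 256, 256, 3)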

That's it for the generator; now let's look at the discriminator.

The discriminator

The goal of the discriminator is to determine whether an input image is artificial. Its architecture is therefore convolutional, and its output is a single value.

from keras.layers import Input, Activation
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import Conv2D
from keras.layers.core import Dense, Flatten
from keras.layers.normalization import BatchNormalization
from keras.models import Model

ndf = 64
output_nc = 3
input_shape_discriminator = (256, 256, output_nc)


def discriminator_model():
    """Build the discriminator model."""
    n_layers, use_sigmoid = 3, False
    inputs = Input(shape=input_shape_discriminator)

    x = Conv2D(filters=ndf, kernel_size=(4, 4), strides=2, padding='same')(inputs)
    x = LeakyReLU(0.2)(x)

    nf_mult, nf_mult_prev = 1, 1
    for n in range(n_layers):
        nf_mult_prev, nf_mult = nf_mult, min(2**n, 8)
        x = Conv2D(filters=ndf*nf_mult, kernel_size=(4, 4), strides=2, padding='same')(x)
        x = BatchNormalization()(x)
        x = LeakyReLU(0.2)(x)

    nf_mult_prev, nf_mult = nf_mult, min(2**n_layers, 8)
    x = Conv2D(filters=ndf*nf_mult, kernel_size=(4, 4), strides=1, padding='same')(x)
    x = BatchNormalization()(x)
    x = LeakyReLU(0.2)(x)

    x = Conv2D(filters=1, kernel_size=(4, 4), strides=1, padding='same')(x)
    if use_sigmoid:
        x = Activation('sigmoid')(x)

    x = Flatten()(x)
    x = Dense(1024, activation='tanh')(x)
    x = Dense(1, activation='sigmoid')(x)

    model = Model(inputs=inputs, outputs=x, name='Discriminator')
    return model

The Keras implementation of the discriminator

The last step is to build the full model. What is special about this GAN is that its inputs are real images and not noise. Therefore, we get direct feedback on the generator's outputs.

from keras.layers import Input
from keras.models import Model

def generator_containing_discriminator_multiple_outputs(generator, discriminator):
    inputs = Input(shape=image_shape)
    generated_images = generator(inputs)
    outputs = discriminator(generated_images)
    model = Model(inputs=inputs, outputs=[generated_images, outputs])
    return model

Let’s see how we can take full advantage of this particularity by using two loss functions.


Training

Loss function

We extract losses at two levels: one at the end of the generator and one at the end of the full model.

The first loss is a perceptual loss computed directly on the generator's outputs. It ensures that the GAN model is oriented towards a deblurring task. It compares the activations of an early VGG16 convolution layer for the generated and the target images.

import keras.backend as K
from keras.applications.vgg16 import VGG16
from keras.models import Model

image_shape = (256, 256, 3)

def perceptual_loss(y_true, y_pred):
    vgg = VGG16(include_top=False, weights='imagenet', input_shape=image_shape)
    loss_model = Model(inputs=vgg.input, outputs=vgg.get_layer('block3_conv3').output)
    loss_model.trainable = False
    return K.mean(K.square(loss_model(y_true) - loss_model(y_pred)))
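
Note that, as written above, a fresh VGG16 network is instantiated inside the loss function. Keras typically traces the loss only once when compiling a model, but if you reuse the loss in several places you can build the VGG sub-model a single time and close over it. This is an optional refactor of my own, not from the original repository:

def build_perceptual_loss(image_shape=(256, 256, 3)):
    # Build the VGG16 feature extractor once and reuse it for every call
    vgg = VGG16(include_top=False, weights='imagenet', input_shape=image_shape)
    loss_model = Model(inputs=vgg.input, outputs=vgg.get_layer('block3_conv3').output)
    loss_model.trainable = False

    def loss(y_true, y_pred):
        return K.mean(K.square(loss_model(y_true) - loss_model(y_pred)))

    return loss

perceptual_loss = build_perceptual_loss(image_shape)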

The second loss is the Wasserstein loss, computed on the outputs of the full model. It takes the mean of the differences between the two images. It is known to improve the convergence of generative adversarial networks.

import keras.backend as K

def wasserstein_loss(y_true, y_pred):
    return K.mean(y_true*y_pred)

The training process

The first step is to load the data and initialize the models. We use a custom function to load the dataset and add Adam optimizers to the models. We set the Keras trainable option to prevent the discriminator from being trained.

# Load dataset
data = load_images('./images/train', n_images)
y_train, x_train = data['B'], data['A']

# Initialize models
g = generator_model()
d = discriminator_model()
d_on_g = generator_containing_discriminator_multiple_outputs(g, d)

# Initialize optimizer
g_opt = Adam(lr=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
d_opt = Adam(lr=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
d_on_g_opt = Adam(lr=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)

# Compile model
d.trainable = True
d.compile(optimizer=d_opt, loss=wasserstein_loss)
d.trainable = False
loss = [perceptual_loss, wasserstein_loss]
loss_weights = [100, 1]
d_on_g.compile(optimizer=d_on_g_opt, loss=loss, loss_weights=loss_weights)
d.trainable = True
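
A detail worth noting: Keras captures the trainable flag when a model is compiled. Here d is compiled first while trainable, so it can learn from its own train_on_batch calls; it is then frozen before d_on_g is compiled, so the combined model only updates the generator; finally it is set back to trainable for the direct discriminator updates in the loop below.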

Then, we start the epochs and divide the dataset into batches.

for epoch in range(epoch_num):
  print('epoch: {}/{}'.format(epoch, epoch_num))
  print('batches: {}'.format(x_train.shape[0] / batch_size))

  # Randomly divide images into different batches
  permutated_indexes = np.random.permutation(x_train.shape[0])

  for index in range(int(x_train.shape[0] / batch_size)):
      batch_indexes = permutated_indexes[index*batch_size:(index+1)*batch_size]
      image_blur_batch = x_train[batch_indexes]
      image_full_batch = y_train[batch_indexes]

Finally, we successively train the discriminator and the whole model based on the two losses. We generate fake inputs with the generator. We train the discriminator to distinguish fake inputs from real ones, and then we train the whole model.

for epoch in range(epoch_num):
  for index in range(batches):
    # [Batch Preparation]

    # Generate fake inputs
    generated_images = g.predict(x=image_blur_batch, batch_size=batch_size)

    # Train the discriminator several times on real and fake inputs
    for _ in range(critic_updates):
        d_loss_real = d.train_on_batch(image_full_batch, output_true_batch)
        d_loss_fake = d.train_on_batch(generated_images, output_false_batch)
        d_loss = 0.5 * np.add(d_loss_fake, d_loss_real)

    d.trainable = False
    # Train generator only on discriminator's decision and generated images
    d_on_g_loss = d_on_g.train_on_batch(image_blur_batch, [image_full_batch, output_true_batch])

    d.trainable = True
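
The label batches output_true_batch and output_false_batch are not defined in the snippets above. A minimal definition consistent with the Wasserstein loss, where real images are labeled 1 and generated ones -1, could look like this (an assumption on my part; check the repository for the exact values used):

import numpy as np

# Assumed label convention for the Wasserstein critic: 1 for real images, -1 for generated ones
output_true_batch = np.ones((batch_size, 1))
output_false_batch = -np.ones((batch_size, 1))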

You can refer to the GitHub repository to see the full training loop!

Some material

I used an AWS instance (p2.xlarge) with the Deep Learning AMI (version 3.0). Training took about 5 hours (50 epochs) on the compact version of the GOPRO dataset.

Image deblurring results

From left to right: original image, blurred image, GAN output

The output above is the result of our Keras DeblurGAN. Even in cases of heavy blur, the network is able to reduce it and produce a more convincing image. The car headlights are sharper and the tree branches are clearer.

Left: GOPRO test image, right: GAN output.

One limitation is the induced pattern visible on top of the image, which may be caused by using VGG as a loss.

Left: GOPRO test image, right: GAN output.

I hope you enjoyed this article on image deblurring with generative adversarial networks. Feel free to comment, follow us, or contact me.

If you're interested in computer vision, check out our previous article on content-based image retrieval with Keras.


A list of resources on generative adversarial networks:

  • NIPS 2016: Generative Adversarial Networks by Ian Goodfellow

  • ICCV 2017: Tutorial on Generative Adversarial Networks

  • Keras implementations of Generative Adversarial Networks by Eric Linder-Noren

  • A list of Generative Adversarial Network resources by Deeplearning4J

  • Awesome GAN resources curated by Holger Caesar


The Juejin Translation Project is a community that translates high-quality technical articles from the English-speaking Internet, covering Android, iOS, front end, back end, blockchain, product, design, artificial intelligence, and more. For more high-quality translations, please follow the Juejin Translation Project, its official Weibo account, and its Zhihu column.