Introduction

The Variational Autoencoder (VAE) is a generative model. Another common generative model is the Generative Adversarial Network (GAN).

Here we introduce the principles of the VAE and implement one with Keras.

Principle

We often want to learn, from a large number of samples, how to generate new ones.

Take MNIST as an example: after seeing thousands of handwritten digit images, we would like to generate similar images that do not appear in the original data, with some variation but still looking like handwritten digits.

In other words, we need to learn the distribution of the data x, so that new samples can easily be generated from that distribution.


But estimating the distribution of the data directly is not easy, especially when data is scarce.

Instead, we can introduce a latent (hidden) variable z, assume that z follows a Gaussian distribution, and assume that x is produced from z by a complex mapping.

Therefore, to obtain the distribution of the original data, we only need to learn the parameters of the Gaussian distribution followed by the latent variable, together with the mapping function.

To learn the parameters of that Gaussian distribution, we need enough samples of z.

However, samples of z cannot be obtained directly, so we also need a mapping (a conditional probability distribution) that produces the corresponding z samples from the existing x samples.
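In the notation commonly used for VAEs (the symbols are our labels: θ stands for the decoder parameters and φ for the encoder parameters), the setup described above can be sketched as:

$$p(z) = \mathcal{N}(z;\, 0,\, I), \qquad p_\theta(x) = \int p_\theta(x \mid z)\, p(z)\, dz$$

$$q_\phi(z \mid x) = \mathcal{N}\big(z;\ \mu_\phi(x),\ \operatorname{diag}\,\sigma_\phi^2(x)\big)$$

Here p_θ(x | z) is the complex mapping from z to x (the decoder), and q_φ(z | x) is the conditional distribution that produces z from x (the encoder). Taking the prior p(z) to be a standard normal is an assumption on our part, though it matches the KL term used in the Keras code later on.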


This looks very similar to an autoencoder (AE), which encodes the data into a hidden-layer representation and then decodes that representation back into the data.

However, VAE and AE differ as follows:

  • In an AE, the distribution of the hidden variables is unknown; in a VAE, the hidden variables are assumed to follow a Gaussian distribution
  • An AE learns only the encoder and decoder; a VAE also learns the distribution of the hidden variables, i.e. the mean and variance of the Gaussian
  • An AE can only map a given x to its reconstruction
  • A VAE can sample a new z and decode it into a new x, that is, generate new samples
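The last point can be made concrete with a minimal NumPy sketch. The decoder here is hypothetical: toy_decoder has random, untrained weights, and only its layer sizes (Dense(256, relu) followed by Dense(784, sigmoid)) follow the Keras decoder used later. Generating a new sample needs nothing more than a draw of z from the prior N(0, I) followed by one decoder pass; a plain AE has no such prior to draw from.

```python
import numpy as np

rng = np.random.default_rng(42)
latent_dim, intermediate_dim, original_dim = 2, 256, 784

# Hypothetical, untrained weights standing in for a trained decoder
W1 = rng.normal(scale=0.1, size=(latent_dim, intermediate_dim))
b1 = np.zeros(intermediate_dim)
W2 = rng.normal(scale=0.1, size=(intermediate_dim, original_dim))
b2 = np.zeros(original_dim)

def toy_decoder(z):
    """Map latent points of shape (m, 2) to 'images' of shape (m, 784)."""
    h = np.maximum(0.0, z @ W1 + b1)              # like Dense(256, activation='relu')
    return 1.0 / (1.0 + np.exp(-(h @ W2 + b2)))  # like Dense(784, activation='sigmoid')

# VAE-style generation: sample new z from the prior N(0, I), then decode
z_new = rng.standard_normal((5, latent_dim))
x_new = toy_decoder(z_new)
print(x_new.shape)  # (5, 784)
```

With trained weights in place of the random ones, each row of x_new would be a new 28 × 28 image flattened to 784 pixel values.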

Loss function

Besides the reconstruction error, the loss contains a second term. Since the VAE assumes that the latent variable z follows a Gaussian distribution, the conditional distribution produced by the encoder should be as close as possible to that Gaussian (the standard normal N(0, I) in the implementation below).
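Written out, the per-sample loss is the sum of these two terms. This is a sketch in standard notation, assuming a Bernoulli decoder p_θ(x | z) over the D = 784 pixel values with output x̂, a diagonal Gaussian encoder q_φ(z | x) with mean μ and variance σ² over d latent dimensions, and the standard normal prior p(z) = N(0, I), which is what the Keras vae_loss below computes:

$$\mathcal{L}(x) = -\,\mathbb{E}_{q_\phi(z \mid x)}\big[\log p_\theta(x \mid z)\big] + D_{\mathrm{KL}}\big(q_\phi(z \mid x)\,\|\,p(z)\big)$$

$$-\log p_\theta(x \mid z) = -\sum_{i=1}^{D}\big[x_i \log \hat{x}_i + (1 - x_i)\log(1 - \hat{x}_i)\big]$$

$$D_{\mathrm{KL}}\big(\mathcal{N}(\mu, \operatorname{diag}\sigma^2)\,\|\,\mathcal{N}(0, I)\big) = -\frac{1}{2}\sum_{j=1}^{d}\big(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\big)$$

In the code, x̂ corresponds to x_decoded_mean, log σ² to z_log_var, and the expectation is approximated with the single sampled z.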

The difference, or distance, between two distributions can be measured with the relative entropy, also known as the Kullback-Leibler (KL) divergence. Strictly speaking it is not a true distance, because it is asymmetric: KL(P || Q) is generally not equal to KL(Q || P).
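The asymmetry is easy to check numerically. A minimal NumPy sketch, using the closed form for two univariate Gaussians P = N(μ_p, σ_p²) and Q = N(μ_q, σ_q²): KL(P || Q) = ½[log(σ_q²/σ_p²) + (σ_p² + (μ_p - μ_q)²)/σ_q² - 1]. The function name kl_gauss and the example distributions are our own.

```python
import numpy as np

def kl_gauss(mu_p, var_p, mu_q, var_q):
    """Closed-form KL(P || Q) for univariate Gaussians P = N(mu_p, var_p), Q = N(mu_q, var_q)."""
    return 0.5 * (np.log(var_q / var_p) + (var_p + (mu_p - mu_q) ** 2) / var_q - 1.0)

p = (0.0, 1.0)  # N(0, 1): mean 0, variance 1
q = (1.0, 4.0)  # N(1, 4): mean 1, variance 4

print(kl_gauss(*p, *q))  # about 0.4431
print(kl_gauss(*q, *p))  # about 1.3069: swapping the arguments changes the value
print(kl_gauss(*p, *p))  # 0.0: a distribution has zero divergence from itself
```

Setting Q to N(0, 1) reduces this formula to -½(1 + log σ² - μ² - σ²), the per-dimension KL term that appears in the VAE loss.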


Implementation

Taking MNIST as an example, we learn the mean and variance of the Gaussian distribution of the latent variable z, so that new values of z can be decoded into images x that are not present in the original data.

The encoder and decoder each use two fully connected layers. This is deliberately simple, since the main purpose is to illustrate how a VAE is implemented.

Load the libraries

# -*- coding: utf-8 -*-

import numpy as np
import matplotlib.pyplot as plt

from keras.layers import Input, Dense, Lambda
from keras.models import Model
from keras import backend as K
from keras import objectives
from keras.datasets import mnist

Define some constants

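batch_size = 100
original_dim = 784      # 28 x 28 pixels, flattened
intermediate_dim = 256  # width of the hidden Dense layers
latent_dim = 2          # 2-D latent space, so it can be plotted directly
epochs = 50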

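Encoder: two fully connected layers (a hidden layer, then two parallel output layers) whose output, the hidden representation, consists of the mean and the log-variance of the latent Gaussian. Predicting the log-variance lets the layer output any real number while the variance exp(z_log_var) stays positive.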

x = Input(shape=(original_dim,))
h = Dense(intermediate_dim, activation='relu')(x)
# Two parallel heads: the mean and the log-variance of q(z|x)
z_mean = Dense(latent_dim)(h)
z_log_var = Dense(latent_dim)(h)

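Sampling uses the reparameterization trick: instead of drawing z directly from N(z_mean, exp(z_log_var)), draw epsilon from a standard normal and compute z = z_mean + exp(z_log_var / 2) * epsilon. The Lambda layer has no trainable weights, but gradients still flow through it to z_mean and z_log_var, which is what allows the encoder to be trained end to end.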

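def sampling(args):
    z_mean, z_log_var = args
    # epsilon ~ N(0, I); all the randomness is isolated here.
    # Note: the shape uses the fixed batch_size, so batches fed to the full VAE
    # must contain exactly batch_size samples
    epsilon = K.random_normal(shape=(batch_size, latent_dim), mean=0.)
    # z = mu + sigma * epsilon, with sigma = exp(log_var / 2)
    return z_mean + K.exp(z_log_var / 2) * epsilon

z = Lambda(sampling, output_shape=(latent_dim,))([z_mean, z_log_var])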
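The reparameterization trick can be checked outside Keras with a minimal NumPy sketch (the function name reparameterize and the example values are our own). It confirms two things: the samples have mean z_mean and standard deviation exp(z_log_var / 2), and, with epsilon held fixed, z is a smooth function of z_log_var, so a gradient with respect to it exists.

```python
import numpy as np

def reparameterize(z_mean, z_log_var, epsilon):
    """z = mu + sigma * epsilon with sigma = exp(log_var / 2), mirroring sampling() above."""
    return z_mean + np.exp(z_log_var / 2) * epsilon

rng = np.random.default_rng(0)
n_samples = 100_000
z_mean = np.array([1.0, -2.0])
z_log_var = np.array([0.0, np.log(0.25)])  # variances 1.0 and 0.25

# Draw epsilon ~ N(0, I); broadcasting applies the same mean and log-variance to every row
epsilon = rng.standard_normal((n_samples, 2))
z = reparameterize(z_mean, z_log_var, epsilon)
print(z.mean(axis=0))  # close to [1.0, -2.0]
print(z.std(axis=0))   # close to [1.0, 0.5]

# With epsilon fixed, dz/dz_log_var = 0.5 * exp(z_log_var / 2) * epsilon;
# compare the analytic derivative with a central finite difference
eps0, mu0, lv0, h = 0.7, 0.3, 0.2, 1e-6
numeric = (reparameterize(mu0, lv0 + h, eps0) - reparameterize(mu0, lv0 - h, eps0)) / (2 * h)
analytic = 0.5 * np.exp(lv0 / 2) * eps0
print(numeric, analytic)
```

Sampling z directly from N(z_mean, exp(z_log_var)) would give the same distribution, but the sampling operation itself would sit between the loss and the encoder outputs; moving the randomness into epsilon is what lets backpropagation reach them.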

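Decoder: two fully connected layers; x_decoded_mean is the reconstructed output. The layers are created as separate objects so that they can be reused later to build a standalone generator.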

decoder_h = Dense(intermediate_dim, activation='relu')
decoder_mean = Dense(original_dim, activation='sigmoid')
h_decoded = decoder_h(z)
x_decoded_mean = decoder_mean(h_decoded)

Define the custom total loss (reconstruction error plus KL divergence) and compile the model.

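def vae_loss(x, x_decoded_mean):
    # Reconstruction term: binary_crossentropy averages over the 784 pixels,
    # so multiplying by original_dim turns it into a per-image sum
    xent_loss = original_dim * objectives.binary_crossentropy(x, x_decoded_mean)
    # KL divergence between N(z_mean, exp(z_log_var)) and the standard normal prior N(0, I)
    kl_loss = -0.5 * K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
    return xent_loss + kl_loss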

vae = Model(x, x_decoded_mean)
vae.compile(optimizer='rmsprop', loss=vae_loss)
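To see what vae_loss computes for each sample, here is a minimal NumPy re-implementation. It is a sketch rather than the Keras code path: vae_loss_np is a hypothetical name, and clipping predictions to [1e-7, 1 - 1e-7] is our assumption, similar to how Keras guards against log(0). It returns one loss value per sample; Keras then averages these over the batch.

```python
import numpy as np

def vae_loss_np(x, x_decoded_mean, z_mean, z_log_var, eps=1e-7):
    """Per-sample VAE loss: pixel-wise binary cross-entropy summed over pixels, plus KL to N(0, I)."""
    x_hat = np.clip(x_decoded_mean, eps, 1 - eps)  # avoid log(0)
    bce = -(x * np.log(x_hat) + (1 - x) * np.log(1 - x_hat))
    # original_dim * (mean over pixels) equals the sum over pixels
    xent_loss = x.shape[-1] * bce.mean(axis=-1)
    kl_loss = -0.5 * np.sum(1 + z_log_var - np.square(z_mean) - np.exp(z_log_var), axis=-1)
    return xent_loss + kl_loss

# Two toy "images" with 2 pixels each, and a 2-D latent code per image
x = np.array([[1.0, 0.0], [1.0, 0.0]])
x_decoded_mean = np.array([[0.5, 0.5], [0.5, 0.5]])
z_mean = np.array([[0.0, 0.0], [1.0, 0.0]])  # the second code is shifted away from the prior mean
z_log_var = np.zeros((2, 2))                  # unit variance in every dimension

print(vae_loss_np(x, x_decoded_mean, z_mean, z_log_var))  # about [1.3863, 1.8863]
```

Both samples have the same reconstruction error (2 ln 2); the second pays an extra 0.5 in KL because its mean is one unit away from the prior's.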

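Load the data and train, using each image as both the input and the target. Pixel values are scaled to [0, 1] and each 28 × 28 image is flattened into a 784-dimensional vector. Training speed on a CPU is tolerable.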

(x_train, y_train), (x_test, y_test) = mnist.load_data()

x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.
x_train = x_train.reshape((len(x_train), np.prod(x_train.shape[1:])))
x_test = x_test.reshape((len(x_test), np.prod(x_test.shape[1:])))

vae.fit(x_train, x_train,
        shuffle=True,
        epochs=epochs,
        batch_size=batch_size,
        validation_data=(x_test, x_test))

Define an encoder model that maps x to z_mean, and plot where the MNIST test images land in the hidden (latent) space, colored by digit label.

encoder = Model(x, z_mean)

x_test_encoded = encoder.predict(x_test, batch_size=batch_size)
plt.figure(figsize=(6, 6))
plt.scatter(x_test_encoded[:, 0], x_test_encoded[:, 1], c=y_test)
plt.colorbar()
plt.show()

The resulting plot shows that different digits are well separated in the two-dimensional latent space.

Define a generator that maps a point in the latent space to an output image, reusing the trained decoder layers, so that new samples can be generated.

decoder_input = Input(shape=(latent_dim,))
_h_decoded = decoder_h(decoder_input)
_x_decoded_mean = decoder_mean(_h_decoded)
generator = Model(decoder_input, _x_decoded_mean)

Generate a 20 × 20 grid of 2-D points evenly spaced over [-4, 4] in each dimension, feed each point to the generator as a new z, and tile the resulting images x into one figure.

n = 20
digit_size = 28
figure = np.zeros((digit_size * n, digit_size * n))
grid_x = np.linspace(-4, 4, n)
grid_y = np.linspace(-4, 4, n)

for i, xi in enumerate(grid_x):
    for j, yi in enumerate(grid_y):
        z_sample = np.array([[yi, xi]])
        x_decoded = generator.predict(z_sample)
        digit = x_decoded[0].reshape(digit_size, digit_size)
        figure[(n - i - 1) * digit_size: (n - i) * digit_size,
               j * digit_size: (j + 1) * digit_size] = digit

plt.figure(figsize=(10, 10))
plt.imshow(figure)
plt.show()
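Two possible refinements of the grid loop, as a hedged NumPy sketch (latent_axis and tile_latent_grid are hypothetical helper names): decode all n × n grid points in one batched call instead of calling generator.predict 400 times, and optionally place the grid points at quantiles of the standard normal prior instead of evenly over [-4, 4], so that more points fall where the prior has most of its mass. The decode argument is assumed to behave like generator.predict, mapping an (m, 2) array to (m, 784) images.

```python
import numpy as np
from statistics import NormalDist

def latent_axis(n, use_quantiles=False, limit=4.0):
    """1-D latent grid: evenly spaced in [-limit, limit], or at N(0, 1) quantiles 0.05..0.95."""
    if use_quantiles:
        return np.array([NormalDist().inv_cdf(p) for p in np.linspace(0.05, 0.95, n)])
    return np.linspace(-limit, limit, n)

def tile_latent_grid(decode, axis, digit_size=28):
    """Decode every grid point in one batch and tile the images into a single figure.

    Same layout as the loop version: column j uses z[0] = axis[j], and the row block
    counted from the bottom uses z[1] = axis[i].
    """
    n = len(axis)
    z0, z1 = np.meshgrid(axis, axis[::-1])  # reversed so the top row gets the largest z[1]
    z_grid = np.stack([z0.ravel(), z1.ravel()], axis=1)  # shape (n * n, 2), row-major
    images = decode(z_grid).reshape(n, n, digit_size, digit_size)
    # (row, col, y, x) -> (row, y, col, x), then merge into one 2-D image
    return images.transpose(0, 2, 1, 3).reshape(n * digit_size, n * digit_size)
```

With the trained model, figure = tile_latent_grid(generator.predict, latent_axis(20)) produces the same figure as the loop above, and latent_axis(20, use_quantiles=True) switches to the quantile-spaced grid.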

The result is consistent with the latent-space scatter plot above, and even shows transitional shapes between different digits.

Because randomness is involved (weight initialization, data shuffling and the sampled epsilon), the results differ somewhat from run to run.

Replacing the fully connected layers with a CNN (convolutional layers) should give better representations.

Extension

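Once you have mastered the above, you can apply the same method to the Fashion-MNIST dataset, which has exactly the same format as MNIST: 28 × 28 grayscale images in 10 classes, with 60,000 training and 10,000 test samples.

Only four lines need to change (the import, the data loading, and the grid range, narrowed to [-3, 3]):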

from keras.datasets import fashion_mnist

(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()

grid_x = np.linspace(-3, 3, n)
grid_y = np.linspace(-3, 3, n)

The complete code is as follows

# -*- coding: utf-8 -*-

import numpy as np
import matplotlib.pyplot as plt

from keras.layers import Input, Dense, Lambda
from keras.models import Model
from keras import backend as K
from keras import objectives
from keras.datasets import fashion_mnist

batch_size = 100
original_dim = 784
intermediate_dim = 256
latent_dim = 2
epochs = 50

x = Input(shape=(original_dim,))
h = Dense(intermediate_dim, activation='relu')(x)
z_mean = Dense(latent_dim)(h)
z_log_var = Dense(latent_dim)(h)

def sampling(args):
    z_mean, z_log_var = args
    epsilon = K.random_normal(shape=(batch_size, latent_dim), mean=0.)
    return z_mean + K.exp(z_log_var / 2) * epsilon

z = Lambda(sampling, output_shape=(latent_dim,))([z_mean, z_log_var])

decoder_h = Dense(intermediate_dim, activation='relu')
decoder_mean = Dense(original_dim, activation='sigmoid')
h_decoded = decoder_h(z)
x_decoded_mean = decoder_mean(h_decoded)

def vae_loss(x, x_decoded_mean):
    xent_loss = original_dim * objectives.binary_crossentropy(x, x_decoded_mean)
    kl_loss = -0.5 * K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
    return xent_loss + kl_loss

vae = Model(x, x_decoded_mean)
vae.compile(optimizer='rmsprop', loss=vae_loss)

(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()

x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.
x_train = x_train.reshape((len(x_train), np.prod(x_train.shape[1:])))
x_test = x_test.reshape((len(x_test), np.prod(x_test.shape[1:])))

vae.fit(x_train, x_train,
        shuffle=True,
        epochs=epochs,
        batch_size=batch_size,
        validation_data=(x_test, x_test))

encoder = Model(x, z_mean)

x_test_encoded = encoder.predict(x_test, batch_size=batch_size)
plt.figure(figsize=(6, 6))
plt.scatter(x_test_encoded[:, 0], x_test_encoded[:, 1], c=y_test)
plt.colorbar()
plt.show()

decoder_input = Input(shape=(latent_dim,))
_h_decoded = decoder_h(decoder_input)
_x_decoded_mean = decoder_mean(_h_decoded)
generator = Model(decoder_input, _x_decoded_mean)

n = 20
digit_size = 28
figure = np.zeros((digit_size * n, digit_size * n))
grid_x = np.linspace(-3, 3, n)
grid_y = np.linspace(-3, 3, n)

for i, xi in enumerate(grid_x):
    for j, yi in enumerate(grid_y):
        z_sample = np.array([[yi, xi]])
        x_decoded = generator.predict(z_sample)
        digit = x_decoded[0].reshape(digit_size, digit_size)
        figure[(n - i - 1) * digit_size: (n - i) * digit_size,
               j * digit_size: (j + 1) * digit_size] = digit

plt.figure(figsize=(10, 10))
plt.imshow(figure)
plt.show()

Looking at the latent representation again, the different classes of clothing are also well separated.

Then generate images from the grid to see the transitions between different kinds of clothing.

