
In this tutorial, you will learn how to augment your network using a visual attention mechanism called a spatial transformer network. You can read more about spatial transformer networks in the DeepMind paper.

Spatial transformer networks (STN for short) are a generalization of differentiable attention to any spatial transformation. An STN allows a neural network to learn how to perform spatial transformations on the input image in order to improve the geometric invariance of the model. For example, it can crop a region of interest, scale it, and correct the orientation of an image. This is a useful mechanism because CNNs are not invariant to rotation, scale, and more general affine transformations.

One of the best things about STNs is that they can simply be plugged into any existing CNN with very little modification.
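As a rough illustration of this plug-and-play property, here is a minimal sketch (the names STNWrapper, stn_module, and classifier are hypothetical and not part of this tutorial's code) that prepends a spatial transformer to an arbitrary existing classifier:

import torch.nn as nn

class STNWrapper(nn.Module):
    """Hypothetical wrapper: prepend a spatial transformer to an existing classifier."""
    def __init__(self, stn_module, classifier):
        super(STNWrapper, self).__init__()
        self.stn = stn_module          # learns and applies the spatial transform
        self.classifier = classifier   # unchanged, pre-existing CNN

    def forward(self, x):
        x = self.stn(x)                # warp the input first
        return self.classifier(x)      # then classify as usual

The rest of the network is untouched; only the front end that warps the input is new.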

from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")
plt.ion()   # Interactive mode

Load the data

In this tutorial, we experiment with the classic MNIST dataset, using a standard convolutional network augmented with a spatial transformer network.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Training data set
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root='.', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=64, shuffle=True, num_workers=4)

# Test data set
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root='.', train=False,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=64, shuffle=True, num_workers=4)

Describing the STN network

The spatial transformer network can be divided into three main components:

- The localization network is a regular CNN that regresses the transformation parameters. The transformation is never learned explicitly from the dataset; instead, the network automatically learns the spatial transformations that improve the global accuracy.
- The grid generator produces, for each pixel of the output image, a grid of corresponding coordinates in the input image.
- The sampler takes the transformation parameters and applies them to the input image.

[Figure: STN architecture (stn-arch.png)]

Note

We need a recent version of PyTorch (1.0 or above), which contains the affine_grid and grid_sample modules.
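As a minimal, self-contained sketch of what these two functions do (dummy tensors only, not part of the tutorial's model), an identity affine matrix produces a sampling grid that reproduces the input unchanged:

import torch
import torch.nn.functional as F

# Dummy batch: 4 single-channel 28x28 images (MNIST-sized)
x = torch.randn(4, 1, 28, 28)

# One identity affine matrix [[1, 0, 0], [0, 1, 0]] per sample
theta = torch.tensor([[1., 0., 0.],
                      [0., 1., 0.]]).unsqueeze(0).repeat(4, 1, 1)

grid = F.affine_grid(theta, x.size())     # sampling grid of shape (4, 28, 28, 2)
out = F.grid_sample(x, grid)              # bilinear sampling; identity transform here
print(torch.allclose(out, x, atol=1e-5))  # True: the input is reproduced

In the network below, theta is not fixed; it is regressed from the input by the localization network.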

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

        # Spatial transformer localization network
        self.localization = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=7),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
            nn.Conv2d(8, 10, kernel_size=5),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True)
        )

        # Regressor network for the 3 * 2 affine matrix
        self.fc_loc = nn.Sequential(
            nn.Linear(10 * 3 * 3, 32),
            nn.ReLU(True),
            nn.Linear(32, 3 * 2)
        )

        # Initialize the weights/bias with the identity transformation
        self.fc_loc[2].weight.data.zero_()
        self.fc_loc[2].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))

    # forward function of STN network
    def stn(self, x):
        xs = self.localization(x)
        xs = xs.view(-1, 10 * 3 * 3)
        theta = self.fc_loc(xs)
        theta = theta.view(-1, 2, 3)

        grid = F.affine_grid(theta, x.size())
        x = F.grid_sample(x, grid)

        return x

    def forward(self, x):
        # Transform the input
        x = self.stn(x)

        # Perform the usual forward pass
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


model = Net().to(device)
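
A quick optional sanity check (not part of the original tutorial): one forward pass on a dummy MNIST-sized batch should return log-probabilities of shape (batch, 10).

dummy = torch.zeros(8, 1, 28, 28, device=device)
print(model(dummy).shape)  # expected: torch.Size([8, 10])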

Training the model

Now, let's train the model using the SGD algorithm. The network learns the classification task in a supervised manner; at the same time, the STN is learned automatically, end to end, as part of the model.

optimizer = optim.SGD(model.parameters(), lr=0.01)


def train(epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 500 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
#
# A simple test method to test STN performance on MNIST.
#


def test():
    with torch.no_grad():
        model.eval()
        test_loss = 0
        correct = 0
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)

            # sum up batch loss
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            # get the index of the max log-probability
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()

        test_loss /= len(test_loader.dataset)
        print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'
              .format(test_loss, correct, len(test_loader.dataset),
                      100. * correct / len(test_loader.dataset)))

Visualize the results of STN

We will now inspect the results of our learned visual attention mechanism.

We define a small helper function to visualize the transformation during training.

def convert_image_np(inp):
    """Convert a Tensor to a numpy image."""
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    return inp

# At the end of the training, we visualize the output of the STN layer:
# a batch of input images and the corresponding batch generated by the STN transform.


def visualize_stn():
    with torch.no_grad():
        # Get a batch of training data
        data = next(iter(test_loader))[0].to(device)

        input_tensor = data.cpu()
        transformed_input_tensor = model.stn(data).cpu()

        in_grid = convert_image_np(
            torchvision.utils.make_grid(input_tensor))

        out_grid = convert_image_np(
            torchvision.utils.make_grid(transformed_input_tensor))

        # Plot the results side-by-side
        f, axarr = plt.subplots(1, 2)
        axarr[0].imshow(in_grid)
        axarr[0].set_title('Dataset Images')

        axarr[1].imshow(out_grid)
        axarr[1].set_title('Transformed Images')

for epoch in range(1, 20 + 1):
    train(epoch)
    test()

# Visualize the STN transformation on some input batch
visualize_stn()

plt.ioff()
plt.show()
