This article is shared by AI Hao from the Huawei Cloud community: "Fully Convolutional Network (FCN): Semantic Segmentation using FCN".

FCN classifies images at the pixel level, which solves the semantic segmentation problem. Unlike a classic CNN, which appends fully connected layers after the convolutional layers to obtain a fixed-length feature vector for classification (fully connected layers + Softmax output), FCN accepts input images of any size and uses deconvolution (transposed convolution) layers to upsample the feature map of the last convolutional layer back to the size of the input image. A prediction can therefore be generated for each pixel while preserving the spatial information of the original input, and pixel-by-pixel classification is finally performed on the upsampled feature map.

The following figure is a schematic diagram of the fully convolutional network (FCN) used for semantic segmentation:

What are the disadvantages of traditional CNN-based segmentation methods?

Traditional CNN-based segmentation methods: in order to classify a pixel, an image patch around that pixel is used as the input to the CNN for training and prediction. This approach has several shortcomings:

1) High storage overhead. For example, a 15 x 15 image patch is taken for each pixel, and patches are fed into the CNN for classification by continuously sliding the window, so the required storage grows sharply with the number and size of the sliding windows.

2) Low efficiency. Adjacent patches largely overlap, yet the convolution is computed for each patch separately, so a great deal of computation is repeated.

3) The patch size limits the size of the receptive field. The patch is usually much smaller than the whole image, so only local features can be extracted, which limits classification performance. The fully convolutional network (FCN) instead recovers the category of each pixel from the abstract features, i.e., it moves from image-level classification to pixel-level classification.

What does FCN change?

A typical classification CNN, such as VGG or ResNet, appends several fully connected layers at the end of the network and obtains category probabilities after a Softmax. However, this probability information is one-dimensional: it can only identify the category of the whole image, not of each pixel, so this fully connected approach is unsuitable for image segmentation. FCN proposes replacing these fully connected layers with convolutions, so that a two-dimensional feature map is obtained; a softmax layer then yields classification information for every pixel, which solves the segmentation problem, as shown in Figure 4.
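To make the contrast concrete, here is a minimal sketch (my own illustration, not code from the paper or from the repository used later) of what "replacing the fully connected head with convolutions" means in PyTorch:

import torch
import torch.nn as nn

# A fully connected head works only on a fixed-length vector and yields one
# prediction for the whole image.
fc_head = nn.Linear(512, 21)
# A 1x1 convolution keeps the spatial layout, producing a 2-D score map with
# one 21-way score vector per location.
conv_head = nn.Conv2d(512, 21, kernel_size=1)
# A transposed convolution ("deconvolution") upsamples the score map 32x back
# towards the input resolution for per-pixel prediction.
upsample = nn.ConvTranspose2d(21, 21, kernel_size=64, stride=32, padding=16)

feats = torch.randn(1, 512, 7, 7)                  # backbone features of a 224x224 input
image_scores = fc_head(feats.mean(dim=(2, 3)))     # (1, 21): one label for the whole image
pixel_scores = upsample(conv_head(feats))          # (1, 21, 224, 224): one score map per class
pixel_probs = torch.softmax(pixel_scores, dim=1)   # per-pixel class probabilities
print(image_scores.shape, pixel_scores.shape, pixel_probs.shape)

The fully connected head collapses the spatial dimensions into a single image-level prediction, while the 1x1 convolution keeps one prediction per location and the transposed convolution restores the input resolution, which is exactly the pixel-level output FCN needs.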

FCN shortcomings

(1) The results are still not precise enough. Although 8x upsampling is much better than 32x, the upsampled result is still blurry and smooth, and not sensitive to details in the image.

(2) The classification of each pixel does not fully consider the relationships between pixels. The spatial regularization step used in pixel-classification-based segmentation methods is ignored, so the result lacks spatial consistency.

The data set

This example uses the PASCAL VOC 2012 dataset, which has twenty categories:

Person: person

Animal: bird, cat, cow, dog, horse, sheep

Vehicle: aeroplane, bicycle, boat, bus, car, motorbike, train

Indoor: bottle, chair, dining table, potted plant, sofa, tv/monitor

The PASCAL Visual Object Classes Challenge 2012 (VOC2012) (ox.ac.uk)

Data set structure:

VOCdevkit
└── VOC2012
    ├── Annotations               image annotation information (XML files)
    ├── ImageSets
    │   ├── Action                action image sets
    │   ├── Layout                person-layout image sets
    │   ├── Main                  object detection / classification image sets
    │   │   ├── train.txt         training set (5717)
    │   │   ├── val.txt           validation set (5823)
    │   │   └── trainval.txt      training + validation set (11540)
    │   └── Segmentation          segmentation image sets
    │       ├── train.txt         training set (1464)
    │       ├── val.txt           validation set (1449)
    │       └── trainval.txt      training + validation set (2913)
    ├── JPEGImages                all image files
    ├── SegmentationClass         semantic segmentation masks (PNG files)
    └── SegmentationObject        instance segmentation masks (PNG files)

The dataset covers both object detection and semantic segmentation. We only need the semantic segmentation part, so we can delete the redundant images.

1. Get the names of all images.

2. Obtain the names of all semantic segmentation masks.

3. Find the difference set of the two and delete the name of the difference set.

The code is as follows:

import glob
import os

# 1. names of all images
image_all = glob.glob('data/VOCdevkit/VOC2012/JPEGImages/*.jpg')
image_all_name = [image_file.replace('\\', '/').split('/')[-1].split('.')[0] for image_file in image_all]

# 2. names of all semantic segmentation masks
image_SegmentationClass = glob.glob('data/VOCdevkit/VOC2012/SegmentationClass/*.png')
image_se_name = [image_file.replace('\\', '/').split('/')[-1].split('.')[0] for image_file in image_SegmentationClass]

# 3. images without a mask are the difference set; delete them
image_other = list(set(image_all_name) - set(image_se_name))
print(image_other)
for image_name in image_other:
    os.remove('data/VOCdevkit/VOC2012/JPEGImages/{}.jpg'.format(image_name))

Code link

The code selected in this example comes from deep-learning-for-image-processing/ Pytorch_segmentation/FCN at master · WZMIAOMIAO/deep-learning-for-image-processing (github.com)

There are many other implementations out there, but this one is relatively easy to understand.

In fact, there is a better image segmentation library: github.com/qubvel/segm…

The image segmentation collection was created by Russian programmer Pavel Yakubovskiy. I’ll also use this library for demos in a later article.

The project structure

├── src: the model backbone and the FCN construction
├── train_utils: modules for training, validation, and multi-GPU training
├── my_dataset.py: custom dataset for reading the VOC dataset
├── train.py: training with fcn_resnet50 (using dilated/atrous convolution)
├── predict.py: simple prediction script that runs inference with trained weights
├── validation.py: validates/tests data with trained weights (mIoU and other metrics)
└── pascal_voc_classes.json: pascal_voc label file

Since there is too much code to explain line by line, I will analyze only the important parts below.

Custom dataset reading

my_dataset.py customizes the data reading as follows:

import os
import torch.utils.data as data
from PIL import Image

class VOCSegmentation(data.Dataset):
    def __init__(self, voc_root, year="2012", transforms=None, txt_name: str = "train.txt"):
        super(VOCSegmentation, self).__init__()
        assert year in ["2007", "2012"], "year must be in ['2007', '2012']"
        root = os.path.join(voc_root, "VOCdevkit", f"VOC{year}")
        root=root.replace('\\','/')
        assert os.path.exists(root), "path '{}' does not exist.".format(root)
        image_dir = os.path.join(root, 'JPEGImages')
        mask_dir = os.path.join(root, 'SegmentationClass')

        txt_path = os.path.join(root, "ImageSets", "Segmentation", txt_name)
        txt_path=txt_path.replace('\\','/')
        assert os.path.exists(txt_path), "file '{}' does not exist.".format(txt_path)
        with open(os.path.join(txt_path), "r") as f:
            file_names = [x.strip() for x in f.readlines() if len(x.strip()) > 0]

        self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
        self.masks = [os.path.join(mask_dir, x + ".png") for x in file_names]
        assert (len(self.images) == len(self.masks))
        self.transforms = transforms

Import the required packages.

Define the VOC dataset reading class VOCSegmentation. The core of the __init__ method is reading the image list and the mask list.

    def __getitem__(self, index):
        img = Image.open(self.images[index]).convert('RGB')
        target = Image.open(self.masks[index])

        if self.transforms is not None:
            img, target = self.transforms(img, target)
        return img, target

The __getitem__ method loads a single image and its corresponding mask, then applies the transforms (augmentation) to both.

    @staticmethod
    def collate_fn(batch):
        images, targets = list(zip(*batch))
        batched_imgs = cat_list(images, fill_value=0)
        batched_targets = cat_list(targets, fill_value=255)
        return batched_imgs, batched_targets

The collate_fn method calls cat_list to align the samples in a batch, since VOC images can have different sizes; a sketch of cat_list follows.
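cat_list is defined in my_dataset.py as well. The idea is to pad every sample up to the largest size in the batch; here is a sketch along those lines (the exact implementation in the repository may differ slightly):

def cat_list(images, fill_value=0):
    # the largest size along every dimension within this batch
    max_size = tuple(max(s) for s in zip(*[img.shape for img in images]))
    batch_shape = (len(images),) + max_size
    # allocate a padded batch tensor and copy each sample into its top-left corner
    batched_imgs = images[0].new(*batch_shape).fill_(fill_value)
    for img, pad_img in zip(images, batched_imgs):
        pad_img[..., :img.shape[-2], :img.shape[-1]].copy_(img)
    return batched_imgs

Images are padded with 0, while masks are padded with 255 so that the padded region can be ignored when the loss is computed.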

In train.py, torch.utils.data.DataLoader is called as follows:

 train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               num_workers=num_workers,
                                               shuffle=True,
                                               pin_memory=True,
                                               collate_fn=train_dataset.collate_fn)
  val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=1,
                                             num_workers=num_workers,
                                             pin_memory=True,
                                             collate_fn=val_dataset.collate_fn)

Training

The important parameters

Open train.py and let's take a look at the important parameters:

def parse_args():
    import argparse
    parser = argparse.ArgumentParser(description="pytorch fcn training")
    # the folder containing the dataset root directory (VOCdevkit)
    parser.add_argument("--data-path", default="data/", help="VOCdevkit root")
    parser.add_argument("--num-classes", default=20, type=int)
    parser.add_argument("--aux", default=True, type=bool, help="auxilier loss")
    parser.add_argument("--device", default="cuda", help="training device")
    parser.add_argument("-b", "--batch-size", default=32, type=int)
    parser.add_argument("--epochs", default=30, type=int, metavar="N",
                        help="number of total epochs to train")
    parser.add_argument('--lr', default=0.0001, type=float, help='initial learning rate')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                        metavar='W', help='weight decay (default: 1e-4)',
                        dest='weight_decay')
    parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    # mixed precision training parameters
    parser.add_argument("--amp", default=False, type=bool,
                        help="Use torch.cuda.amp for mixed precision training")

    args = parser.parse_args()
    return args

Data-path: the folder containing the root directory of the dataset (VOCdevkit).

Num-classes: the number of target classes (excluding the background).

Aux: Indicates whether to use aux_classifier.

Device: Use CPU or GPU for training. The default is CUDA.

Batch-size: indicates BatchSize.

Epochs: indicates the number of epochs.

Lr: learning rate.

Resume: the checkpoint to resume training from.

Start-epoch: the starting epoch; when resuming, you do not need to start from 0.

Amp: Whether to use torch’s automatic mixed precision training.

Data augmentation

The augmentation calls the methods defined in transforms.py.

The training-set augmentation is as follows:

class SegmentationPresetTrain:
    def __init__(self, base_size, crop_size, hflip_prob=0.5,
                 mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        min_size = int(0.5 * base_size)
        max_size = int(2.0 * base_size)

        # random resize augmentation
        trans = [T.RandomResize(min_size, max_size)]
        if hflip_prob > 0:
            # random horizontal flip
            trans.append(T.RandomHorizontalFlip(hflip_prob))
        trans.extend([
            # random crop
            T.RandomCrop(crop_size),
            T.ToTensor(),
            T.Normalize(mean=mean, std=std),
        ])
        self.transforms = T.Compose(trans)

    def __call__(self, img, target):
        return self.transforms(img, target)

The training-set augmentation includes a random resize, a random horizontal flip, and a random crop.

Validation-set augmentation:

class SegmentationPresetEval:
    def __init__(self, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        self.transforms = T.Compose([
            T.RandomResize(base_size, base_size),
            T.ToTensor(),
            T.Normalize(mean=mean, std=std),
        ])

    def __call__(self, img, target):
        return self.transforms(img, target)

The validation-set transform is simpler: only a resize (RandomResize with the same minimum and maximum size, i.e., a fixed resize), followed by ToTensor and normalization.
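Note that T here is not the standard torchvision.transforms module: every transform takes the image and its mask together so that geometric changes stay synchronized. A minimal sketch of what Compose and RandomHorizontalFlip in transforms.py look like, reconstructed from how they are called (treat the details as assumptions):

import random
import torchvision.transforms.functional as F

class Compose:
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        # apply every transform to the (image, mask) pair
        for t in self.transforms:
            image, target = t(image, target)
        return image, target

class RandomHorizontalFlip:
    def __init__(self, flip_prob=0.5):
        self.flip_prob = flip_prob

    def __call__(self, image, target):
        if random.random() < self.flip_prob:
            image = F.hflip(image)
            target = F.hflip(target)  # flip the mask the same way
        return image, target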

The Main method

I have made some changes to the main method; the code is as follows:

    # use the official torchvision pretrained model
    model = fcn_resnet50(pretrained=True)
    # the default number of classes is 21; if num_classes differs, replace the classifier heads
    if num_classes != 21:
        model.classifier[4] = torch.nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))
        model.aux_classifier[4] = torch.nn.Conv2d(256, num_classes, kernel_size=(1, 1), stride=(1, 1))
    print(model)
    model.to(device)
    # if there is more than one GPU, use DataParallel
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = torch.nn.DataParallel(model)

I changed the model to the official torchvision implementation; use the official model whenever possible.

The default number of classes is 21; if yours is not 21, replace the classifier heads accordingly.

Check whether the system has multiple GPUs; if it does, using all of them avoids wasting resources.

If you don't want to use all the GPUs but only specific ones, you can use:

os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'

You can also pass device_ids to DataParallel:

model = torch.nn.DataParallel(model, device_ids=[0, 1])

When multiple GPUs are used, the model's parameters must be accessed as model.module.xxx, for example:

    params = [p for p in model.module.aux_classifier.parameters() if p.requires_grad]
    params_to_optimize.append({"params": params, "lr": args.lr * 10})

Once all of the above is done, you can start training.

Testing

Before starting the test, we also need to obtain the palette. Create a new script, get_palette.py, as follows:

import json
import numpy as np
from PIL import Image

# read a mask label
target = Image.open("./2007_001288.png")
# get the palette
palette = target.getpalette()
palette = np.reshape(palette, (-1, 3)).tolist()
print(palette)
# convert to dictionary form
pd = dict((i, color) for i, color in enumerate(palette))
json_str = json.dumps(pd)
with open("palette.json", "w") as f:
    f.write(json_str)

Pick a mask, read its palette using the getpalette method, and save it in dictionary format.

Next, to start the prediction section, create a new predict.py and insert the following code:

import os
import time
import json
import torch
from torchvision import transforms
import numpy as np
from PIL import Image
from torchvision.models.segmentation import fcn_resnet50

Import the packages the program requires. Then, in the main method:

def main():
    aux = False  # inference time does not need the aux_classifier
    classes = 20
    weights_path = "./save_weights/model_5.pth"
    img_path = "./2007_000123.jpg"
    palette_path = "./palette.json"
    assert os.path.exists(weights_path), f"weights {weights_path} not found."
    assert os.path.exists(img_path), f"image {img_path} not found."
    assert os.path.exists(palette_path), f"palette {palette_path} not found."
    with open(palette_path, "rb") as f:
        pallette_dict = json.load(f)
        pallette = []
        for v in pallette_dict.values():
            pallette += v
  • Define whether the aux_classifier is required; it is not needed at inference time, so set it to False.
  • Set the category to 20, excluding backgrounds.
  • Define the path of the weight.
  • Define the path to the palette.
  • Read the palette.

Next comes loading the model. Loading a model trained on a single GPU is different from loading one trained on multiple GPUs, so let's first look at how to load a single-GPU model.

    # create model
    model = fcn_resnet50(num_classes=classes+1)
    print(model)
    # delete weights about aux_classifier
    weights_dict = torch.load(weights_path, map_location='cpu')['model']
    for k in list(weights_dict.keys()):
        if "aux_classifier" in k:
            del weights_dict[k]
    # load weights
    model.load_state_dict(weights_dict)
    model.to(device)

Define the model fcn_resnet50, with num_classes set to the number of categories + 1 (for the background).

Load the trained checkpoint and delete the aux_classifier weights.

Then load the weights into the model.

Here's how to load a multi-GPU model:

    # create model
    model = fcn_resnet50(num_classes=classes+1)
    model = torch.nn.DataParallel(model)
    # delete weights about aux_classifier
    weights_dict = torch.load(weights_path, map_location='cpu')['model']
    print(weights_dict)
    for k in list(weights_dict.keys()):
        if "aux_classifier" in k:
            del weights_dict[k]
    # load weights
    model.load_state_dict(weights_dict)
    model=model.module
    model.to(device)

Define the model fcn_resnet50, set num_classes to the number of categories + 1 (background), and wrap the model in DataParallel.

Load the trained model and delete the aux_classifier.

Load weights.

When torch.nn.DataParallel(model) is executed, the original model is placed in model.module, so here we assign model.module back to model.
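A small standalone illustration (not code from this repository) of why that is: DataParallel stores the wrapped network in .module, and the keys of its state_dict gain a "module." prefix:

import torch.nn as nn

net = nn.Linear(4, 2)
dp = nn.DataParallel(net)  # construction works even without a GPU, good enough for illustration

print(list(net.state_dict().keys()))  # ['weight', 'bias']
print(list(dp.state_dict().keys()))   # ['module.weight', 'module.bias']
print(dp.module is net)               # True: unwrap the original network via model.module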

Next is the processing of image data:

    # load image
    original_img = Image.open(img_path)

    # from pil image to tensor and normalize
    data_transform = transforms.Compose([transforms.Resize(520),
                                         transforms.ToTensor(),
                                         transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                                              std=(0.229, 0.224, 0.225))])
    img = data_transform(original_img)
    # expand batch dimension
    img = torch.unsqueeze(img, dim=0)

Load the image.

Resize the image, convert it to a tensor, and normalize it.

Use torch.unsqueeze to add a batch dimension.

Once the image is processed, the prediction can begin.

    model.eval()  # enter evaluation mode
    with torch.no_grad():
        # init model with a dummy input of the same size
        img_height, img_width = img.shape[-2:]
        init_img = torch.zeros((1, 3, img_height, img_width), device=device)
        model(init_img)

        t_start = time_synchronized()
        output = model(img.to(device))
        t_end = time_synchronized()
        print("inference+NMS time: {}".format(t_end - t_start))

        prediction = output['out'].argmax(1).squeeze(0)
        prediction = prediction.to("cpu").numpy().astype(np.uint8)
        np.set_printoptions(threshold=sys.maxsize)
        print(prediction.shape)
        mask = Image.fromarray(prediction)
        mask.putpalette(pallette)
        mask.save("test_result.png")

Save the predicted results to test_result.png. View the run result:

The original:

Results:

Printed data:

Category List:

{
    "aeroplane": 1,
    "bicycle": 2,
    "bird": 3,
    "boat": 4,
    "bottle": 5,
    "bus": 6,
    "car": 7,
    "cat": 8,
    "chair": 9,
    "cow": 10,
    "diningtable": 11,
    "dog": 12,
    "horse": 13,
    "motorbike": 14,
    "person": 15,
    "pottedplant": 16,
    "sheep": 17,
    "sofa": 18,
    "train": 19,
    "tvmonitor": 20
}

From the result, the predicted category in the image is "train".

Conclusion

The core content of this article is to explain how to use FCN to achieve semantic segmentation of images.

At the beginning of the article, we discussed the structure of FCN and its advantages and disadvantages. Then we explained how to read the dataset and how to implement training. Finally, we covered testing and presented the results.

I hope this article can help you. The complete code: download.csdn.net/download/hh…

Click to follow, the first time to learn about Huawei cloud fresh technology ~