Med3D: Transfer Learning for 3D Medical Image Analysis

Paper address: https://arxiv.org/abs/1904.00625

The core of this paper is a pre-trained, general-purpose 3D encoder that can be transferred to other medical imaging tasks.

1. Article Highlights

  • The goal is to build a universal 3D backbone network that can be transferred to other medical tasks and achieve better performance than training from scratch.
  • To address the scarcity of 3D medical images and the difficulty of labeling them, Tencent Youtu proposes pre-training a model on a large collection of 3D medical images; with transfer learning, this accelerates convergence on downstream imaging tasks and improves their accuracy.
  • Many current methods cannot fully exploit the 3D spatial information in medical volumes. Given the lack of any single large-scale 3D medical image dataset, jointly training a model on 3D medical datasets from multiple domains may be a solution.
  • Main content: build a large 3D medical dataset (covering a variety of segmentation problems) and use it to train a baseline network that can be applied to other medical problems.
  • A multi-branch decoder is proposed to solve the problem of incomplete annotation.

2. Contributions of the Paper

  • The paper proposes the Med3D network for heterogeneous, multi-domain 3D medical data, which can extract general 3D features despite large differences in the data domain distributions.
  • The Med3D backbone is transferred to three new 3D medical imaging tasks, and its effectiveness is verified by extensive experiments.
  • The Med3D pre-trained models and related source code will be open-sourced, making it easier for the community to reproduce the experimental results and apply Med3D to other applications.

3. Methods

3.1 Data Collection

The data come from a variety of 3D medical image segmentation datasets, rather than classification datasets, in order to obtain richer and more representative information. The reasons:

  • (1) Compared with natural images, medical images contain fewer object categories, which makes poor generalization of the network more likely.
  • (2) Compared with natural images, a classification label is very weak supervisory information for a 3D medical volume: the label may be associated with only a small part of the image, which hinders network convergence.

3.2 Normalization

The datasets used in the experiments differ in modality (MRI, CT), 3D spatial resolution, and pixel intensity range, so their spatial resolution and gray-value distributions must be normalized.

  • Spatial normalization: Because of differences in scanners, images from different hospitals or centers differ in voxel spacing, i.e., the physical distance between two adjacent pixels, and a CNN cannot learn this physical information on its own. Spatial normalization resamples every volume to the same resolution to reduce the effect of spacing variation. (The exact resampling procedure is not specified; to be added.)
  • Pixel intensity (gray value) normalization: each volume is z-score normalized, v'_i = (v_i − v_m) / σ_v, where v_i is the raw value, v_m the volume mean, and σ_v the volume standard deviation. A minimal sketch of both steps follows.
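
A minimal sketch of the two normalization steps, assuming linear resampling with scipy (the paper does not spell out its resampling procedure); the function name, target spacing, and epsilon are illustrative:

```python
import numpy as np
from scipy import ndimage

def normalize_volume(volume: np.ndarray,
                     spacing: tuple,
                     target_spacing: tuple = (1.0, 1.0, 1.0)) -> np.ndarray:
    """Resample a 3D volume to a common spacing, then z-score its intensities."""
    # Spatial normalization: zoom factor = current spacing / target spacing,
    # so a coarsely sampled axis is interpolated onto the finer target grid.
    zoom = [s / t for s, t in zip(spacing, target_spacing)]
    volume = ndimage.zoom(volume, zoom, order=1)  # linear interpolation

    # Intensity normalization: v'_i = (v_i - v_m) / sigma_v, per volume.
    v_m, sigma_v = volume.mean(), volume.std()
    return (volume - v_m) / max(sigma_v, 1e-8)
```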

3.3 Med3D Network

A generic encoder-decoder segmentation structure is adopted, in which the encoder can be any base architecture. In this work the ResNet model family is used as the encoder, with minor modifications so the network can be trained on 3D medical data. A key shortcoming of the collected datasets is the lack of multi-organ segmentation annotations, i.e., the annotations are incomplete. For example, in a liver segmentation dataset only the liver is labeled as foreground while all other organs are treated as background, and the other datasets are similar. Such incomplete annotations would confuse the network and prevent the training process from converging.

Because it is practically infeasible to annotate a complete organ atlas in large-scale 3D medical data, a multi-branch decoder is proposed to address the incomplete-annotation problem:

  • The encoder branches into eight dataset-specific decoders, one for each dataset in 3DSeg-8 (eight different segmentation datasets).
  • In the training stage, each decoder branch only processes feature maps extracted from its corresponding dataset, and the other branches do not participate in that optimization step. All decoder branches share the same architecture: a single convolutional layer that combines the features from the encoder (that is, the same encoder with different decoders).
  • To compute the difference between the network output and the ground-truth annotation, the decoder's features are upsampled directly to the original image size. This deliberately simple decoder design forces the network to concentrate its capacity in a universal encoder. During the testing phase the decoder branches are removed, and the remaining encoder can be transferred to other tasks. A minimal sketch of this scheme follows.
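
A minimal PyTorch sketch of the multi-branch scheme, under the assumptions that each head is a single (1,1,1) convolution and upsampling is trilinear; class and argument names are illustrative, not the released Med3D code:

```python
import torch.nn as nn
import torch.nn.functional as F

class Med3DMultiBranch(nn.Module):
    """Shared encoder with one lightweight decoder head per dataset."""

    def __init__(self, encoder, enc_channels, num_classes_per_dataset):
        super().__init__()
        self.encoder = encoder  # e.g. a 3D ResNet backbone
        # One (1,1,1)-convolution head per 3DSeg-8 dataset; each maps the
        # shared encoder features to that dataset's own label space.
        self.decoders = nn.ModuleList([
            nn.Conv3d(enc_channels, n_cls, kernel_size=1)
            for n_cls in num_classes_per_dataset
        ])

    def forward(self, x, dataset_idx):
        feats = self.encoder(x)  # downsampled feature maps
        # Only the branch matching the sample's source dataset runs, so only
        # the shared encoder and that head receive gradients for this batch.
        logits = self.decoders[dataset_idx](feats)
        # Upsample straight back to the input size (the "simple decoder").
        return F.interpolate(logits, size=x.shape[2:],
                             mode="trilinear", align_corners=False)
```

During training each batch is routed to the head of its source dataset; at test time the heads are discarded and only `encoder` is transferred.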

4. Application Examples

4.1 Lung Segmentation

First, the encoder is extracted from Med3D as the feature-extraction part for segmenting the lungs in whole-body scans; it is followed by three groups of 3D decoding layers (a code sketch follows the list):

  • The first decoder group consists of a transposed convolution layer with kernel size (3,3,3) and 256 channels (used to enlarge the feature map by a factor of two) and a convolution layer with kernel size (3,3,3) and 128 channels.
  • The remaining two decoder groups follow the same layout as the first, with the channel counts halved at each successive group (continuing the 256 → 128 pattern).
  • Finally, a convolution layer with kernel size (1,1,1) generates the final output, with the number of channels equal to the number of classes.
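
A sketch of the three decoder groups in PyTorch. The kernel sizes and the 256/128 channel counts follow the text above; the encoder output width (512), padding, activations, and the two-class output are assumptions:

```python
import torch.nn as nn

def decoder_group(in_ch, up_ch, out_ch):
    """Transposed conv (2x upsample) followed by a (3,3,3) conv."""
    return nn.Sequential(
        # kernel (3,3,3), stride 2: doubles the spatial size of the feature map
        nn.ConvTranspose3d(in_ch, up_ch, kernel_size=3, stride=2,
                           padding=1, output_padding=1),
        nn.ReLU(inplace=True),
        nn.Conv3d(up_ch, out_ch, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
    )

num_classes = 2  # assumed: lung vs. background

lung_decoder = nn.Sequential(
    decoder_group(512, 256, 128),  # group 1: 256-ch upsample, 128-ch conv
    decoder_group(128, 128, 64),   # group 2: same layout, channels halved
    decoder_group(64, 64, 32),     # group 3
    nn.Conv3d(32, num_classes, kernel_size=1),  # final (1,1,1) conv
)
```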

4.2 Liver Segmentation

Inspired by the first-prize solution of the 2018 Atrial Segmentation Challenge [34], a two-stage segmentation network is used to segment the liver.

Segmentation framework:

  • First, the liver is coarsely segmented in the whole image to obtain the region of interest (ROI) of the target.

    At this stage: (1) the backbone pre-trained with Med3D is used as the encoder, followed by a convolution layer with kernel size (1,1,1) and 2 output channels (liver and background); (2) after resizing, the image is fed into the coarse segmentation network, which downsamples by a factor of 32 to extract features; (3) bilinear interpolation is then used to upsample the feature map back to the original image size.

  • Then, based on the first-stage result, the liver target region is cropped and segmented again to obtain the final segmentation. This stage performs fine segmentation of the liver contour.

    To obtain denser scale information and a larger receptive field in the feature maps, the backbone pre-trained with Med3D is embedded into the state-of-the-art DenseASPP segmentation network, which connects a group of atrous (dilated) convolutional layers in a densely connected manner. The generated multi-scale features not only cover a larger scale range but also cover it densely, without significantly increasing the model size. All 2D kernels are then replaced with their 3D counterparts. Because the stage-1 liver prediction inevitably deviates from the ground truth, the liver target region is randomly expanded and then augmented with rotation and translation to improve the robustness of the fine segmentation model; the ROI handoff is sketched below.
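
A sketch of the stage-1 to stage-2 handoff, assuming the ROI is taken as the bounding box of the coarse mask with a random margin (`max_expand` is an illustrative value; rotation/translation augmentation would be applied afterwards):

```python
import numpy as np

def crop_liver_roi(volume: np.ndarray, coarse_mask: np.ndarray,
                   max_expand: int = 16) -> np.ndarray:
    """Crop a randomly expanded bounding box around the coarse liver mask."""
    coords = np.argwhere(coarse_mask > 0)          # voxels predicted as liver
    lo, hi = coords.min(axis=0), coords.max(axis=0) + 1
    # Random expansion of the target region (the robustness trick above),
    # clamped to the volume bounds.
    lo = np.maximum(lo - np.random.randint(0, max_expand + 1, size=3), 0)
    hi = np.minimum(hi + np.random.randint(0, max_expand + 1, size=3),
                    volume.shape)
    return volume[lo[0]:hi[0], lo[1]:hi[1], lo[2]:hi[2]]
```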

5. Experiments

5.1 Transfer Learning Experiment

The pre-trained Med3D encoder was connected to the DenseASPP [35] network to demonstrate state-of-the-art performance on the liver segmentation task. A segmentation network with the same ResNet-152 backbone was built and initialized with the pre-trained Med3D weights. During training, because some volumes did not provide spacing information, the average spacing of the training data was used to standardize all volumes. The Hounsfield-unit intensities were also normalized with a window from -200 to 250. All training hyperparameters are the same as in the previous segmentation experiment. A preprocessing sketch follows.
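
A sketch of the preprocessing described above. The HU window [-200, 250] and the mean-spacing fallback come from the text; the helper name, placeholder spacing, and the linear rescaling to [0, 1] are assumptions:

```python
import numpy as np

# Assumed placeholder: the mean voxel spacing computed over the training set.
MEAN_TRAIN_SPACING = (1.0, 1.0, 1.0)

def preprocess_ct(volume: np.ndarray, spacing=None) -> np.ndarray:
    """Standardize spacing metadata, window HU values, and rescale to [0, 1]."""
    if spacing is None:            # some volumes ship without spacing information
        spacing = MEAN_TRAIN_SPACING
    # ... resample `volume` from `spacing` to MEAN_TRAIN_SPACING (see 3.2) ...
    hu = np.clip(volume, -200.0, 250.0)   # Hounsfield window [-200, 250]
    return (hu + 200.0) / 450.0           # map windowed HU linearly to [0, 1]
```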