Using PyTorch

PyTorch Chinese documentation: pytorch-cn.readthedocs.io/zh/latest/

Building a simple network

Reference video for building a simple network: www.bilibili.com/video/BV1e4…

1. Data loading and normalization

Common datasets

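import torch
import torchvision
import torchvision.transforms as transforms

# Normalization pipeline (an example choice; adapt it to your data):
# ToTensor scales pixels to [0, 1], Normalize then maps each channel to [-1, 1]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Training dataset
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform
)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=4, shuffle=True, num_workers=2
)

# Test dataset
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform
)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=4, shuffle=False, num_workers=2
)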

Parameters

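  1. root: the directory where the dataset is stored or downloaded to
  2. train: True loads the training set, False loads the test set
  3. download: if True, the dataset is downloaded into root (skipped if it is already there)
  4. transform: the preprocessing applied to every image; here a custom normalization pipeline
  5. num_workers (a DataLoader argument): the number of worker subprocesses used to load data in parallel; 0 loads in the main process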
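To make the normalization step concrete, here is a minimal sketch of the per-channel formula that transforms.Normalize(mean, std) applies, (x - mean) / std, written with plain tensor operations. The helper name normalize and the toy 2x2 image are hypothetical, and mean = std = 0.5 is just the example choice used above.

```python
import torch


def normalize(img: torch.Tensor, mean, std) -> torch.Tensor:
    """Hypothetical helper: per-channel (x - mean) / std on a C x H x W tensor,
    the same formula transforms.Normalize applies."""
    mean_t = torch.tensor(mean, dtype=img.dtype).view(-1, 1, 1)  # shape C x 1 x 1
    std_t = torch.tensor(std, dtype=img.dtype).view(-1, 1, 1)
    return (img - mean_t) / std_t


# A fake 3-channel 2x2 "image" whose pixels are already in [0, 1],
# which is what transforms.ToTensor() produces from 8-bit images.
img = torch.tensor([[[0.0, 1.0], [0.5, 0.25]]] * 3)
out = normalize(img, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
print(out[0])  # pixel values now lie in [-1, 1]
```

With mean = std = 0.5, a pixel value of 0 becomes -1, 0.5 becomes 0 and 1 becomes 1, so the network sees inputs centred around zero.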

Private datasets

1. Loading with ImageFolder

privatrainset = torchvision.datasets.ImageFolder(
    root='./data', transform=transform
)
privadataloader = torch.utils.data.DataLoader(
    privatrainset, batch_size=4, shuffle=False, num_workers=2
)
  1. There are no train or download parameters
  2. The data must be organized in a fixed folder layout

PyTorch automatically loads the data under image_path (the root argument) during training or testing.

ImageFolder labels each image by the name of the subfolder of image_path that contains it: the subfolder names are sorted alphabetically and numbered from 0.

When loading a private dataset, the folder in which the dataset resides should maintain the following structure:

/toy_dataset
    /class 1
    /class 2
    ...
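The labeling rule can be sketched in plain Python. The helper folder_labels below is hypothetical and only mimics how ImageFolder numbers classes (sorted subfolder names mapped to 0, 1, 2, ...); unlike the real class, it does not filter by image extension or recurse into nested folders.

```python
import os
import tempfile


def folder_labels(root):
    """Hypothetical helper mimicking ImageFolder's labeling rule:
    every subdirectory of root is one class, numbered in sorted order.
    Simplification: lists files directly inside each class folder only."""
    classes = sorted(e.name for e in os.scandir(root) if e.is_dir())
    class_to_idx = {name: i for i, name in enumerate(classes)}
    samples = []
    for name in classes:
        folder = os.path.join(root, name)
        for fname in sorted(os.listdir(folder)):
            samples.append((os.path.join(name, fname), class_to_idx[name]))
    return class_to_idx, samples


# Build a throwaway toy_dataset with the layout shown above
with tempfile.TemporaryDirectory() as tmp:
    root = os.path.join(tmp, "toy_dataset")
    for cls, files in {"class 1": ["a.png", "b.png"], "class 2": ["c.png"]}.items():
        os.makedirs(os.path.join(root, cls))
        for f in files:
            open(os.path.join(root, cls, f), "w").close()  # empty placeholder files
    class_to_idx, samples = folder_labels(root)

print(class_to_idx)
print(samples)
```

On a real ImageFolder instance the same mapping is available as privatrainset.class_to_idx, which is worth printing once to confirm which index corresponds to which folder.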

2. Custom datasets and loading methods

www.jianshu.com/p/2d9927a70…

I haven't used this yet; saving the link for later.
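The linked article is not summarized here, but the general PyTorch pattern for a custom dataset is to subclass torch.utils.data.Dataset and implement __len__ and __getitem__; DataLoader then handles batching and shuffling. A minimal sketch, where ToyDataset and its random tensors are hypothetical stand-ins for real files:

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    """Hypothetical custom dataset: N random 3x32x32 "images" with integer labels.
    A real one would typically read a file from disk inside __getitem__."""

    def __init__(self, n_samples: int = 10, n_classes: int = 10):
        g = torch.Generator().manual_seed(0)
        self.images = torch.randn(n_samples, 3, 32, 32, generator=g)
        self.labels = torch.randint(0, n_classes, (n_samples,), generator=g)

    def __len__(self):
        # DataLoader uses this to know how many samples exist
        return len(self.labels)

    def __getitem__(self, idx):
        # Return one (input, label) pair; DataLoader stacks them into batches
        return self.images[idx], self.labels[idx]


dataset = ToyDataset()
loader = DataLoader(dataset, batch_size=4, shuffle=True)
batch_sizes = [images.shape[0] for images, labels in loader]
print(batch_sizes)  # the last batch is smaller because 10 is not a multiple of 4
```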

2. Define the neural network

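import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # Two 3x3 convolutions (stride 1, no padding): 32x32 -> 30x30 -> 28x28
        self.conv1 = nn.Conv2d(3, 6, kernel_size=(3, 3))
        self.conv2 = nn.Conv2d(6, 16, kernel_size=(3, 3))
        # Three fully connected layers; 10 outputs for the 10 CIFAR10 classes
        self.fc1 = nn.Linear(16*28*28, 512)
        self.fc2 = nn.Linear(512, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)

        x = self.conv2(x)
        x = F.relu(x)

        # Flatten to (batch, 16*28*28) before the fully connected layers
        x = x.view(-1, 16*28*28)
        x = self.fc1(x)
        x = F.relu(x)

        x = self.fc2(x)
        x = F.relu(x)

        x = self.fc3(x)

        # Raw class scores (logits); CrossEntropyLoss applies log-softmax internally
        return x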
  • Subclass nn.Module and declare the layers in __init__: here, two convolutional layers and three fully connected layers
  • Define forward() to set the order in which the layers are applied; autograd derives the backward pass automatically
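The 16*28*28 input size of fc1 follows from the input size. With stride 1 and no padding, a k x k convolution shrinks each side by k - 1, so a 32 x 32 CIFAR10 image becomes 30 x 30 after conv1 and 28 x 28 after conv2. A quick way to confirm this, assuming 32 x 32 inputs as in CIFAR10, is to push a dummy batch through the same two layers:

```python
import torch
import torch.nn as nn

# Same two convolution layers as in Net (stride 1, no padding)
conv1 = nn.Conv2d(3, 6, kernel_size=(3, 3))
conv2 = nn.Conv2d(6, 16, kernel_size=(3, 3))

x = torch.randn(4, 3, 32, 32)  # a batch of 4 CIFAR10-sized images
h1 = conv1(x)   # each side: 32 - 3 + 1 = 30
h2 = conv2(h1)  # each side: 30 - 3 + 1 = 28
print(h1.shape, h2.shape)

flat = h2.view(-1, 16 * 28 * 28)  # what Net.forward does before fc1
print(flat.shape)
```

If the input resolution changes, the fc1 size (and the view call) must change with it, otherwise view either fails or silently mixes samples across the batch dimension.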

3. Define the loss function and the weight update rule

import torch.nn as nn
import torch.optim as optim

net = Net()  # instance of the network defined above
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
  • The loss function is cross-entropy loss (nn.CrossEntropyLoss), which takes the network's raw scores and integer class labels
  • The optimizer is SGD with learning rate 0.001 and momentum 0.9
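To see what these two objects actually compute, the sketch below checks them against hand-written versions on toy tensors (the logits, labels, weights and toy loss are all made up for illustration). Cross-entropy is log-softmax followed by the negative log-probability of the true class, averaged over the batch; PyTorch's SGD with momentum (no dampening, no Nesterov) keeps a velocity v and updates v <- momentum * v + grad, then w <- w - lr * v.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

torch.manual_seed(0)

# 1) Cross-entropy on raw scores (logits)
logits = torch.randn(4, 10)              # 4 samples, 10 classes
labels = torch.tensor([1, 0, 9, 3])
ce = nn.CrossEntropyLoss()(logits, labels)
manual_ce = -F.log_softmax(logits, dim=1)[torch.arange(4), labels].mean()
print(ce.item(), manual_ce.item())

# 2) SGD with momentum, compared with a manual update
lr, momentum = 0.001, 0.9
w = torch.tensor([1.0, -2.0], requires_grad=True)
optimizer = optim.SGD([w], lr=lr, momentum=momentum)

w_manual = w.detach().clone()
v = torch.zeros_like(w_manual)           # velocity buffer
for _ in range(3):
    optimizer.zero_grad()
    loss = (w ** 2).sum()                # toy loss; its gradient is 2 * w
    loss.backward()
    optimizer.step()

    grad = 2 * w_manual                  # same gradient, computed by hand
    v = momentum * v + grad
    w_manual = w_manual - lr * v

print(w.detach(), w_manual)
```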

4. Train the network

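for epoch in range(1):  # number of passes over the training set
    for i, data in enumerate(trainloader):
        images, labels = data
        outputs = net(images)              # forward pass
        loss = criterion(outputs, labels)  # compare predictions with labels
        optimizer.zero_grad()              # clear gradients from the previous batch
        loss.backward()                    # backpropagation
        optimizer.step()                   # update the weights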
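  • epoch sets the number of training passes; each epoch is one full pass over the training set
  • for i, data in enumerate(trainloader) reads the training set batch by batch
  • optimizer.zero_grad() resets the gradients to zero (PyTorch accumulates gradients by default)
  • loss.backward() backpropagates the loss to compute the gradients
  • optimizer.step() updates the weights using those gradients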
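A common addition to this loop is tracking the average loss per epoch to check that training is making progress. The sketch below does that on a small synthetic, learnable task (an assumption for illustration: the label is 1 when the first feature is positive), with nn.Linear standing in for Net; the learning rate 0.1 and 20 epochs are arbitrary choices for this toy problem.

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Synthetic task: label is 1 when the first feature is positive, else 0
X = torch.randn(64, 2)
y = (X[:, 0] > 0).long()
loader = DataLoader(TensorDataset(X, y), batch_size=16, shuffle=True)

model = nn.Linear(2, 2)                 # stand-in for Net
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

epoch_losses = []
for epoch in range(20):
    model.train()
    running_loss = 0.0
    for inputs, targets in loader:
        optimizer.zero_grad()              # clear gradients from the previous step
        loss = criterion(model(inputs), targets)
        loss.backward()                    # compute gradients
        optimizer.step()                   # update weights
        running_loss += loss.item() * inputs.size(0)
    epoch_losses.append(running_loss / len(loader.dataset))

print(f"epoch 1 loss: {epoch_losses[0]:.4f}, epoch 20 loss: {epoch_losses[-1]:.4f}")
```

zero_grad() can be called either before the forward pass, as here, or right before backward(); what matters is that it runs before loss.backward(), so gradients from the previous batch are not added to the new ones.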

5. Test the model

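correct = 0
total = 0
with torch.no_grad():  # no gradients are needed for evaluation
    for data in testloader:
        images, labels = data
        outputs = net(images)
        # index of the highest score in each row = predicted class
        _, predicted = torch.max(outputs.data, 1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)
print('Accuracy:', float(correct) / total)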
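The same test loop can be wrapped in a reusable function. This is a sketch, not the original author's code: evaluate is a hypothetical helper that also calls model.eval(), which matters for layers such as dropout or batch normalization. To make the result checkable by hand, it is run with nn.Identity on fixed score tensors instead of a trained network.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def evaluate(model: nn.Module, loader) -> float:
    """Hypothetical helper: classification accuracy of model over loader."""
    model.eval()                      # eval mode for dropout/batchnorm layers
    correct, total = 0, 0
    with torch.no_grad():             # no autograd graph: less memory, faster
        for images, labels in loader:
            outputs = model(images)
            predicted = outputs.argmax(dim=1)   # predicted class per sample
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
    return correct / total


# Fixed "model outputs": predictions are [0, 1, 2, 0] against
# labels [0, 1, 0, 0], so 3 of 4 are correct.
scores = torch.tensor([[2.0, 0.1, 0.3],
                       [0.2, 1.5, 0.1],
                       [0.9, 0.2, 3.0],
                       [1.0, 0.0, 0.5]])
labels = torch.tensor([0, 1, 0, 0])
loader = DataLoader(TensorDataset(scores, labels), batch_size=2)
acc = evaluate(nn.Identity(), loader)   # Identity passes the scores through
print(acc)
```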

6. Save and load the model

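# Save the model parameters
torch.save(net.state_dict(), './model.pt')

# Load them into a fresh instance of the same architecture;
# load_state_dict expects a dict, so read the file with torch.load first
net_2 = Net()
net_2.load_state_dict(torch.load('./model.pt'))

correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net_2(images)  # evaluate the reloaded model
        _, predicted = torch.max(outputs.data, 1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)
print('Accuracy:', float(correct) / total)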
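A state_dict stores only the parameter tensors, not the class, so loading requires building the same architecture first and then copying the parameters in. A quick round-trip check is sketched below; the small nn.Sequential model is a hypothetical stand-in for Net, and a temporary directory replaces './model.pt' so the example leaves no files behind.

```python
import os
import tempfile

import torch
import torch.nn as nn


def build_model() -> nn.Module:
    # Hypothetical stand-in for Net; any architecture works the same way
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))


torch.manual_seed(0)
model = build_model()

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.pt")
    torch.save(model.state_dict(), path)        # save parameters only

    model_2 = build_model()                     # same architecture, new random weights
    model_2.load_state_dict(torch.load(path))   # overwrite them with the saved ones
    model_2.eval()

x = torch.randn(5, 4)
with torch.no_grad():
    same = torch.equal(model(x), model_2(x))
print(same)
```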