This is the 12th day of my participation in the First Challenge 2022. For details: First Challenge 2022.

Preface

This is a repost of my earlier blog post on CSDN. Original link: Dim problem in Pytorch

Background

Many PyTorch functions take a dim parameter, but what it actually means can be confusing. Below are two concrete examples that record my understanding.

Taking sum in PyTorch as an example

Using the tensor's sum method, we build a three-dimensional tensor of shape 2×3×4 and sum over dim=0, 1, and 2 in turn. The code and output are as follows:

>>> import torch
>>> x = [[[1, 2, 3, 4], [5, 6, 7, 8], [4, 3, 2, 1]],
...      [[2, 2, 2, 2], [1, 1, 1, 1], [3, 3, 3, 3]]]
>>> x = torch.tensor(x)
>>> x.shape
torch.Size([2, 3, 4])

The code above defines a 2×3×4 three-dimensional tensor. Next, sum over dim=0:

>>> x.sum(dim=0, keepdim=True)
tensor([[[3, 4, 5, 6],
         [6, 7, 8, 9],
         [7, 6, 5, 4]]])
>>> x.sum(dim=0, keepdim=True).shape
torch.Size([1, 3, 4])

As you can see, after summing over dim=0 (with keepdim=True), the size of dimension 0 becomes 1. That is, the values along dimension 0 are summed, while the other dimensions are left untouched.
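One way to sanity-check this (a minimal sketch, reusing the same tensor as above): summing over dim=0 should give the same result as adding the two slices x[0] and x[1] element-wise.

```python
import torch

# Same 2*3*4 tensor as above.
x = torch.tensor([[[1, 2, 3, 4], [5, 6, 7, 8], [4, 3, 2, 1]],
                  [[2, 2, 2, 2], [1, 1, 1, 1], [3, 3, 3, 3]]])

# Summing over dim=0 collapses the two slices into one;
# unsqueeze(0) restores the size-1 dimension that keepdim=True keeps.
manual = (x[0] + x[1]).unsqueeze(0)
assert torch.equal(x.sum(dim=0, keepdim=True), manual)
```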

For dim=1:

>>> x.sum(dim=1, keepdim=True)
tensor([[[10, 11, 12, 13]],

        [[ 6,  6,  6,  6]]])
>>> x.sum(dim=1, keepdim=True).shape
torch.Size([2, 1, 4])

The above code sums over dim=1: the size of dimension 1 becomes 1, and the other dimensions are unchanged.

For dim=2:

>>> x.sum(dim=2, keepdim=True)
tensor([[[10],
         [26],
         [10]],

        [[ 8],
         [ 4],
         [12]]])
>>> x.sum(dim=2, keepdim=True).shape
torch.Size([2, 3, 1])

The above code sums over dim=2: the size of dimension 2 becomes 1, and the other dimensions are unchanged.

As the result shapes of these experiments show, the value of dim indicates which dimension the operation is applied along. Here, sum adds up the values along that dimension and leaves the other dimensions undisturbed.
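Note that keepdim=True is what preserves the reduced dimension as size 1; without it (the default), that dimension is removed from the shape entirely. A small sketch:

```python
import torch

x = torch.tensor([[[1, 2, 3, 4], [5, 6, 7, 8], [4, 3, 2, 1]],
                  [[2, 2, 2, 2], [1, 1, 1, 1], [3, 3, 3, 3]]])

# With keepdim=False (the default) the summed dimension disappears.
print(x.sum(dim=1).shape)                # torch.Size([2, 4])
# With keepdim=True it is kept as a size-1 dimension.
print(x.sum(dim=1, keepdim=True).shape)  # torch.Size([2, 1, 4])
```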

Another example

Let’s look at another example, where we use the argmax function to experiment:

>>> x.argmax(dim=1, keepdim=True)
tensor([[[1, 1, 1, 1]],

        [[2, 2, 2, 2]]])
>>> x.argmax(dim=1, keepdim=True).shape
torch.Size([2, 1, 4])
>>> x.argmax(dim=0, keepdim=True)
tensor([[[1, 0, 0, 0],
         [0, 0, 0, 0],
         [0, 0, 1, 1]]])
>>> x.argmax(dim=0, keepdim=True).shape
torch.Size([1, 3, 4])
>>> x.argmax(dim=2, keepdim=True)
tensor([[[3],
         [3],
         [0]],

        [[0],
         [0],
         [0]]])
>>> x.argmax(dim=2, keepdim=True).shape
torch.Size([2, 3, 1])

This example confirms the idea above: the operation (comparison, accumulation, and so on) is performed only along the specified dim, while no operation is carried out across the other dimensions, whose size information is preserved.
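Because the other dimensions are preserved, the indices that argmax returns with keepdim=True line up with the original tensor and can be fed straight into gather to recover the maximum values themselves, a rough sketch:

```python
import torch

x = torch.tensor([[[1, 2, 3, 4], [5, 6, 7, 8], [4, 3, 2, 1]],
                  [[2, 2, 2, 2], [1, 1, 1, 1], [3, 3, 3, 3]]])

idx = x.argmax(dim=1, keepdim=True)   # shape [2, 1, 4], indices into dim 1
vals = x.gather(1, idx)               # pick those entries back out of x
# The gathered entries are exactly the per-position maxima along dim 1.
assert torch.equal(vals, x.max(dim=1, keepdim=True).values)
```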

>>> x.sum(dim=-1, keepdim=True)
tensor([[[10],
         [26],
         [10]],

        [[ 8],
         [ 4],
         [12]]])
>>> x.sum(dim=-1, keepdim=True).shape
torch.Size([2, 3, 1])

Here dim=-1 refers to the last dimension, so for this three-dimensional tensor the result is identical to summing over dim=2.
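Negative values of dim count from the end, so for a 3-D tensor dim=-1 and dim=2 are interchangeable; a quick check:

```python
import torch

x = torch.tensor([[[1, 2, 3, 4], [5, 6, 7, 8], [4, 3, 2, 1]],
                  [[2, 2, 2, 2], [1, 1, 1, 1], [3, 3, 3, 3]]])

# dim=-1 is the last dimension, i.e. dim=2 for this 3-D tensor.
assert torch.equal(x.sum(dim=-1, keepdim=True), x.sum(dim=2, keepdim=True))
```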

Conclusion

These two experiments show that most elementary reduction operations in PyTorch operate only along the specified dim: no operation is performed across the other dimensions, and their size information is kept.
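To illustrate (a sketch, not an exhaustive claim): other reductions such as torch.mean and torch.amax, available in recent PyTorch versions, follow the same shape rule as sum.

```python
import torch

# Float dtype, since torch.mean requires a floating-point tensor.
x = torch.tensor([[[1, 2, 3, 4], [5, 6, 7, 8], [4, 3, 2, 1]],
                  [[2, 2, 2, 2], [1, 1, 1, 1], [3, 3, 3, 3]]]).float()

# Each reduction collapses only dim=1; the other sizes survive.
for op in (torch.sum, torch.mean, torch.amax):
    assert op(x, dim=1, keepdim=True).shape == torch.Size([2, 1, 4])
```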