The goal of a GRL (gradient reversal layer) is this: during the forward pass the computed value is unchanged (the layer acts as an identity), while during backpropagation the gradient passed back to the preceding leaf nodes is reversed, i.e. multiplied by -1. This is best illustrated with an example:

import torch

x = torch.tensor([1., 2., 3.], requires_grad=True)
y = torch.tensor([4., 5., 6.], requires_grad=True)

z = torch.pow(x, 2) + torch.pow(y, 2)
f = z + x + y
s = 6 * f.sum()

print(s)
s.backward()
print(x)
print(x.grad)

The result of this program is:

tensor(672., grad_fn=<MulBackward0>)
tensor([1., 2., 3.], requires_grad=True)
tensor([18., 30., 42.])

For each dimension $i$ of the tensors, the scalar output is:

$$s = 6\sum_i \left(x_i^2 + y_i^2 + x_i + y_i\right)$$

So the derivative with respect to $x_i$ is:

$$\frac{\partial s}{\partial x_i} = 6\,(2x_i + 1) = 12x_i + 6$$

So for the input x = [1, 2, 3], the corresponding gradient is [18, 30, 42].
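
As a quick sanity check (a minimal snippet, not part of the original program), the analytic derivative 12x + 6 can be compared with what autograd computes:

import torch

x = torch.tensor([1., 2., 3.], requires_grad=True)
y = torch.tensor([4., 5., 6.], requires_grad=True)

s = 6 * (torch.pow(x, 2) + torch.pow(y, 2) + x + y).sum()
s.backward()

# The analytic derivative ds/dx_i = 12*x_i + 6 should match autograd.
print(x.grad)               # tensor([18., 30., 42.])
print(12 * x.detach() + 6)  # tensor([18., 30., 42.])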

That is ordinary gradient computation. So how do you flip the gradient? It is very simple; look at the code below:

import torch
from torch.autograd import Function

x = torch.tensor([1., 2., 3.], requires_grad=True)
y = torch.tensor([4., 5., 6.], requires_grad=True)

z = torch.pow(x, 2) + torch.pow(y, 2)
f = z + x + y


class GRL(Function):
    @staticmethod
    def forward(ctx, input):
        # Forward pass: the identity, so the computed value is unchanged.
        return input

    @staticmethod
    def backward(ctx, grad_output):
        # Backward pass: flip the sign of the incoming gradient.
        grad_input = grad_output.neg()
        return grad_input


s = 6 * f.sum()
s = GRL.apply(s)

print(s)
s.backward()
print(x)
print(x.grad)

The running results are as follows:

tensor(672., grad_fn=<GRLBackward>)
tensor([1., 2., 3.], requires_grad=True)
tensor([-18., -30., -42.])

The only difference between this program and the previous one is the addition of a gradient reversal layer:

class GRL(Function):
    @staticmethod
    def forward(ctx, input):
        return input

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = grad_output.neg()
        return grad_input

The forward of this layer performs no operation on its input; it is the .neg() call inside backward that reverses the gradient. In torch.autograd.Function, backward receives grad_output, the gradient flowing in from the layers above; when backward() is called on a scalar output such as s, this incoming gradient defaults to 1, so without the .neg() it would simply be passed through unchanged.
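
To confirm that grad_output really does start out as 1 for a scalar output, note that calling s.backward() is equivalent to explicitly seeding the backward pass with a gradient of 1. A minimal sketch (reusing the GRL class above, and keeping only x for brevity):

import torch
from torch.autograd import Function

class GRL(Function):
    @staticmethod
    def forward(ctx, input):
        return input

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.neg()

x = torch.tensor([1., 2., 3.], requires_grad=True)
s = GRL.apply(6 * (torch.pow(x, 2) + x).sum())

# For a scalar, s.backward() is the same as seeding with a gradient of 1:
s.backward(torch.tensor(1.0))
print(x.grad)  # tensor([-18., -30., -42.])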