代码语言:javascript复制
from torch.autograd.function import once_differentiable
class GOF_Function(Function):
@staticmethod #一般来说,要使用某个类的方法,需要先实例化一个对象再调用方法。 而使用@staticmethod或@classmethod,就可以不需要实例化,直接类名.方法名()来调用。
def forward(ctx, weight, gaborFilterBank):# 在forward中,需要定义GOF_Function这个运算的forward计算过程
ctx.save_for_backward(weight, gaborFilterBank) # 将输入保存起来,在backward时使用
output = _C.gof_forward(weight, gaborFilterBank)
return output
@staticmethod
@once_differentiable
def backward(ctx, grad_output):
weight, gaborFilterBank = ctx.saved_tensors
grad_weight = _C.gof_backward(grad_output, gaborFilterBank)
return grad_weight, None
Pytorch提供了包torch.autograd用于自动求导。在前向过程中PyTorch会构建计算图,每个节点用Variable表示,边表示由输入节点到输出节点的函数(torch.autograd.Function对象)。Function对象不仅负责执行前向计算,在反向过程中,每个Function对象会调用.backward()函数计算输出对输入的梯度,然后将梯度传递给下一个Function对象。但是一些操作是不可导的,当你自定义的函数不可导时,在写backward函数时,就需要使用@once_differentiable。