@once_differentiable有什么用?

2022-09-02 13:31:25 浏览数 (1)

代码语言:javascript复制
from torch.autograd.function import once_differentiable
class GOF_Function(Function):
    @staticmethod #一般来说,要使用某个类的方法,需要先实例化一个对象再调用方法。 而使用@staticmethod或@classmethod,就可以不需要实例化,直接类名.方法名()来调用。
    def forward(ctx, weight, gaborFilterBank):# 在forward中,需要定义GOF_Function这个运算的forward计算过程
        ctx.save_for_backward(weight, gaborFilterBank)  # 将输入保存起来,在backward时使用
        output = _C.gof_forward(weight, gaborFilterBank)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        weight, gaborFilterBank = ctx.saved_tensors
        grad_weight = _C.gof_backward(grad_output, gaborFilterBank)
        return grad_weight, None

Pytorch提供了包torch.autograd用于自动求导。在前向过程中PyTorch会构建计算图,每个节点用Variable表示,边表示由输入节点到输出节点的函数(torch.autograd.Function对象)。Function对象不仅负责执行前向计算,在反向过程中,每个Function对象会调用.backward()函数计算输出对输入的梯度,然后将梯度传递给下一个Function对象。但是一些操作是不可导的,当你自定义的函数不可导时,在写backward函数时,就需要使用@once_differentiable。

0 人点赞