PyTorch中的model.zero_grad()和optimizer.zero_grad()

2022-09-02 21:22:33 浏览数 (1)

代码语言:javascript复制
model.zero_grad()
optimizer.zero_grad()

首先,这两种方式都是把模型中参数的梯度设为0

当optimizer = optim.Optimizer(net.parameters())时,二者等效,其中Optimizer可以是Adam、SGD等优化器

代码语言:javascript复制
def zero_grad(self):
        """Sets gradients of all model parameters to zero."""
        for p in self.parameters():
            if p.grad is not None:
                p.grad.data.zero_()

0 人点赞