PyTorch 损失函数取决于网络相对于输入的梯度

PyTorch loss function that depends on gradient of network with respect to input

我正在尝试实现一个损失函数,该损失函数取决于网络相对于其输入的梯度。也就是说,损失函数有一项像

sum(u - grad_x(network(x)))

其中 u 是通过网络前向传播 x 计算得出的。

我可以通过调用

来计算梯度
funcApprox = funcNetwork.forward(X)
funcGrad = grad(funcApprox, X, grad_outputs=torch.ones_like(funcApprox))

这里,funcNetwork 是我的 NN,X 是输入。这些计算是在损失函数中完成的。

但是,现在如果我尝试执行以下操作

opt.zero_grad()
loss = self.loss(X) # My custom loss function that calculates funcGrad, etc., from above
            
opt.zero_grad()
loss.backward()
opt.step()

我看到以下错误:

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

在上面的 loss.backward() 行上。

我试过 create_graphretain_graph 等,但无济于事。

感谢任何帮助!

根据@aretor 的评论,在损失函数的 grad 调用中设置 retain_graph=True, create_graph=False,在 backward 中设置 retain_graph=True 解决了问题。