PyTorch 损失函数取决于网络相对于输入的梯度
PyTorch loss function that depends on gradient of network with respect to input
我正在尝试实现一个损失函数,该损失函数取决于网络相对于其输入的梯度。也就是说,损失函数有一项像
sum(u - grad_x(network(x)))
其中 u
是通过网络前向传播 x
计算得出的。
我可以通过调用
来计算梯度
funcApprox = funcNetwork.forward(X)
funcGrad = grad(funcApprox, X, grad_outputs=torch.ones_like(funcApprox))
这里,funcNetwork
是我的 NN,X
是输入。这些计算是在损失函数中完成的。
但是,现在如果我尝试执行以下操作
opt.zero_grad()
loss = self.loss(X) # My custom loss function that calculates funcGrad, etc., from above
opt.zero_grad()
loss.backward()
opt.step()
我看到以下错误:
RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.
在上面的 loss.backward()
行上。
我试过 create_graph
、retain_graph
等,但无济于事。
感谢任何帮助!
根据@aretor 的评论,在损失函数的 grad
调用中设置 retain_graph=True, create_graph=False
,在 backward
中设置 retain_graph=True
解决了问题。
我正在尝试实现一个损失函数,该损失函数取决于网络相对于其输入的梯度。也就是说,损失函数有一项像
sum(u - grad_x(network(x)))
其中 u
是通过网络前向传播 x
计算得出的。
我可以通过调用
来计算梯度funcApprox = funcNetwork.forward(X)
funcGrad = grad(funcApprox, X, grad_outputs=torch.ones_like(funcApprox))
这里,funcNetwork
是我的 NN,X
是输入。这些计算是在损失函数中完成的。
但是,现在如果我尝试执行以下操作
opt.zero_grad()
loss = self.loss(X) # My custom loss function that calculates funcGrad, etc., from above
opt.zero_grad()
loss.backward()
opt.step()
我看到以下错误:
RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.
在上面的 loss.backward()
行上。
我试过 create_graph
、retain_graph
等,但无济于事。
感谢任何帮助!
根据@aretor 的评论,在损失函数的 grad
调用中设置 retain_graph=True, create_graph=False
,在 backward
中设置 retain_graph=True
解决了问题。