在 PyTorch 中计算中间节点的梯度
Computing gradients of intermediate nodes in PyTorch
我正在尝试了解 autograd 在 PyTorch 中的工作原理。在下面这个简单的程序中,我不明白为什么loss
w.r.tW1
和W2
的梯度是None
。 据我从文档中了解到,W1
和 W2
是不稳定的,因此无法计算梯度。 是吗?我的意思是,我怎么不能对损失 w.r.t 中间节点求导?谁能解释一下我在这里缺少什么?
import torch
import torch.autograd as tau
W = tau.Variable(torch.FloatTensor([[0, 1]]), requires_grad=True)
a = tau.Variable(torch.FloatTensor([[2, 2]]), requires_grad=False)
b = tau.Variable(torch.FloatTensor([[3, 3]]), requires_grad=False)
W1 = W + a * a
W2 = W1 - b * b * b
Z = W2 * W2
print 'W:', W
print 'W1:', W1
print 'W2:', W2
print 'Z:', Z
loss = torch.sum((Z - 3) * (Z - 3))
print 'loss:', loss
# free W gradient buffer in case you are running this cell more than 2 times
if W.grad is not None: W.grad.data.zero_()
loss.backward()
print 'W.grad:', W.grad
# all of them are None
print 'W1.grad:', W1.grad
print 'W2.grad:', W2.grad
print 'a.grad:', a.grad
print 'b.grad:', b.grad
print 'Z.grad:', Z.grad
需要时,累积中间梯度in a C++ buffer,但为了节省内存,它们默认不保留(暴露在python 对象中)。
仅保留使用 requires_grad=True
设置的叶变量的梯度(因此在您的示例中为 W
)
保留中间渐变的一种方法是注册一个钩子。这项工作的一个挂钩是 retain_grad()
(see PR)
在您的示例中,如果您编写 W2.retain_grad()
,W2
的中间梯度将暴露在 W2.grad
中
W1
和 W2
不是易失性的(您可以通过访问它们的 volatile
属性(即:W1.volatile
)来检查)并且不可能是因为它们不是叶子变量(例如 W
、a
和 b
)。相反,需要计算它们的梯度,请参阅它们的 requires_grad
属性。
如果只有一个叶子变量是volatile
,整个backward graph是没有构建出来的(可以做一个volatile看看损失梯度函数)
a = tau.Variable(torch.FloatTensor([[2, 2]]), volatile=True)
# ...
assert loss.grad_fn is None
总结
- Volatility 意味着没有梯度计算:在推理模式下很有用
- 只有一个叶子变量设置volatile禁用梯度计算
- 需要梯度意味着梯度计算。中间的暴露与否
- 只有一个叶子变量需要梯度计算启用梯度计算
我正在尝试了解 autograd 在 PyTorch 中的工作原理。在下面这个简单的程序中,我不明白为什么loss
w.r.tW1
和W2
的梯度是None
。 据我从文档中了解到, 是吗?我的意思是,我怎么不能对损失 w.r.t 中间节点求导?谁能解释一下我在这里缺少什么?W1
和 W2
是不稳定的,因此无法计算梯度。
import torch
import torch.autograd as tau
W = tau.Variable(torch.FloatTensor([[0, 1]]), requires_grad=True)
a = tau.Variable(torch.FloatTensor([[2, 2]]), requires_grad=False)
b = tau.Variable(torch.FloatTensor([[3, 3]]), requires_grad=False)
W1 = W + a * a
W2 = W1 - b * b * b
Z = W2 * W2
print 'W:', W
print 'W1:', W1
print 'W2:', W2
print 'Z:', Z
loss = torch.sum((Z - 3) * (Z - 3))
print 'loss:', loss
# free W gradient buffer in case you are running this cell more than 2 times
if W.grad is not None: W.grad.data.zero_()
loss.backward()
print 'W.grad:', W.grad
# all of them are None
print 'W1.grad:', W1.grad
print 'W2.grad:', W2.grad
print 'a.grad:', a.grad
print 'b.grad:', b.grad
print 'Z.grad:', Z.grad
需要时,累积中间梯度in a C++ buffer,但为了节省内存,它们默认不保留(暴露在python 对象中)。
仅保留使用 requires_grad=True
设置的叶变量的梯度(因此在您的示例中为 W
)
保留中间渐变的一种方法是注册一个钩子。这项工作的一个挂钩是 retain_grad()
(see PR)
在您的示例中,如果您编写 W2.retain_grad()
,W2
的中间梯度将暴露在 W2.grad
W1
和 W2
不是易失性的(您可以通过访问它们的 volatile
属性(即:W1.volatile
)来检查)并且不可能是因为它们不是叶子变量(例如 W
、a
和 b
)。相反,需要计算它们的梯度,请参阅它们的 requires_grad
属性。
如果只有一个叶子变量是volatile
,整个backward graph是没有构建出来的(可以做一个volatile看看损失梯度函数)
a = tau.Variable(torch.FloatTensor([[2, 2]]), volatile=True)
# ...
assert loss.grad_fn is None
总结
- Volatility 意味着没有梯度计算:在推理模式下很有用
- 只有一个叶子变量设置volatile禁用梯度计算
- 需要梯度意味着梯度计算。中间的暴露与否
- 只有一个叶子变量需要梯度计算启用梯度计算