在 PyTorch 中以稀疏方式将梯度应用于张量
Apply gradient to a tensor in a sparse way in PyTorch
我有一个非常大的张量 L
(数百万个元素),我从中收集了一个相对较小的子张量 S
(可能有一千个元素)。
然后我将我的模型应用于 S
,计算损失,并反向传播到 S
和 L
,目的是仅更新 L
中的选定元素。问题是 PyTorch 使 L
的梯度成为连续张量,因此它基本上使 L
的内存使用量翻倍。
有没有简单的方法来计算梯度并将其应用于 L
而无需加倍内存使用?
示例代码来说明问题:
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
net = nn.Sequential(
nn.Linear(1, 64),
nn.ReLU(),
nn.Linear(64,64),
nn.ReLU(),
nn.Linear(64, 1))
L = Parameter(torch.zeros([1024*1024*256], dtype=torch.float32))
L.data.uniform_(-1, 1)
indices = torch.randint(high=256*1024*1024, size=[1024])
S = torch.unsqueeze(L[indices], dim=1)
out = net(S)
loss = out.sum()
loss.backward()
print(loss)
g = L.grad
print(g.shape) # this is huge!
您在 L
上实际上不需要 requires_grad
,因为渐变将被手动计算和应用。相反,将其设置为 S
。这将在 S
.
处停止反向传播
然后,您可以使用 S.grad
和您喜欢的优化来更新 L
的值。沿着这些线
L = torch.zeros([1024*1024*256], dtype=torch.float32)
...
S = torch.unsqueeze(L[indices], dim=1)
S.requires_grad_()
out = net(S)
loss = torch.abs(out).sum()
loss.backward()
with torch.no_grad():
L[indices] -= learning_rate * torch.squeeze(S.grad)
S.grad.zero_()
我有一个非常大的张量 L
(数百万个元素),我从中收集了一个相对较小的子张量 S
(可能有一千个元素)。
然后我将我的模型应用于 S
,计算损失,并反向传播到 S
和 L
,目的是仅更新 L
中的选定元素。问题是 PyTorch 使 L
的梯度成为连续张量,因此它基本上使 L
的内存使用量翻倍。
有没有简单的方法来计算梯度并将其应用于 L
而无需加倍内存使用?
示例代码来说明问题:
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
net = nn.Sequential(
nn.Linear(1, 64),
nn.ReLU(),
nn.Linear(64,64),
nn.ReLU(),
nn.Linear(64, 1))
L = Parameter(torch.zeros([1024*1024*256], dtype=torch.float32))
L.data.uniform_(-1, 1)
indices = torch.randint(high=256*1024*1024, size=[1024])
S = torch.unsqueeze(L[indices], dim=1)
out = net(S)
loss = out.sum()
loss.backward()
print(loss)
g = L.grad
print(g.shape) # this is huge!
您在 L
上实际上不需要 requires_grad
,因为渐变将被手动计算和应用。相反,将其设置为 S
。这将在 S
.
然后,您可以使用 S.grad
和您喜欢的优化来更新 L
的值。沿着这些线
L = torch.zeros([1024*1024*256], dtype=torch.float32)
...
S = torch.unsqueeze(L[indices], dim=1)
S.requires_grad_()
out = net(S)
loss = torch.abs(out).sum()
loss.backward()
with torch.no_grad():
L[indices] -= learning_rate * torch.squeeze(S.grad)
S.grad.zero_()