火炬就地操作以节省内存(softmax)

torch in-place operations to save memory (softmax)

torch 中的一些操作是就地执行的。 Shorthand 运算符,例如 +=。

是否可以就地执行其他操作,例如softmax

我目前从事语言处理工作。该模型在大量词汇表上产生一长串概率分布。这个最终输出张量负责大约 60% 的已分配内存。这是一个大问题,因为我需要计算它的 softmax 并且需要双倍的内存。

这是问题的一个例子。我对张量 t 不感兴趣,只对它的 softmax 感兴趣:

import numpy as np
import torch
import torch.nn.functional as F

t = torch.tensor(np.zeros((30000,30000))).cuda()  #allocates 6.71 GB of GPU
softmax = F.softmax(t, 1)  #out of memory error
del t  #too late, program crashed

即使以下方法也不起作用:

F.softmax(torch.tensor(np.zeros((30000,30000))).cuda(), 1)

目前无法使用 PyTorch。您可以尝试自己滚动 GPU kernel,但我看到前面有麻烦(如果不是一堵墙的话),这很可能是此操作最初不可用的原因。

Softmax 可以很容易地并行应用,除了归一化,这需要减少。减少是重要的,它们可以就地 并行(请记住,原子的使用相当于并发但非并行操作)。这意味着您的就地操作无论如何都必须在幕后分配,这违背了目的,或者 运行 非常慢。请认为我的回答有点推测性,我不是 GPGPU 专家,但我的观点是,这至少是一个廉价、快速和正确解决的难题。

话虽如此,如果您只关心 softmax,在我看来您就像在进行推理。也许您的应用程序将 logits 移动到 CPU 和 运行 softmax 是一个可行的选择,而您的 GPU 已经在处理下一批?

我创建了一个就地版本的 softmax:

import numpy as np
import torch
import torch.nn.functional as F

# in-place version
t = torch.tensor(np.ones((100,200)))
torch.exp(t, out=t)
summed = torch.sum(t, dim=1, keepdim=True)
t /= summed

# original version
t2 = torch.tensor(np.ones((100,200)))
softmax = F.softmax(t2, 1)

assert torch.allclose(t, softmax)

回答我的问题:如果你想要就地函数,你必须通过插入低级操作来自己创建它们:

  • 许多函数,例如 torch.exp 可以指定一个可选的 out 参数。
  • 作业t[idx] = something就位
  • shorthand 运算符 /=*=+=-= 就位

这需要仔细调试并且可能不直观:

t = t / summed  #not in-place
t /= summed  #in-place

我读到就地操作会产生梯度问题。我将使用此代码进行更多测试。