我们可以在 GPU 上使用 pytorch scatter_
Can we use pytorch scatter_ on GPU
我正在尝试在 GPU 模式下使用 pyTorch 对某些数据进行一次热编码,但是,它一直给我一个例外。有人可以帮助我吗?
这是一个例子:
def char_OneHotEncoding(x):
coded = torch.zeros(x.shape[0], x.shape[1], 101)
for i in range(x.shape[1]):
coded[:,i] = scatter(x[:,i])
return coded
def scatter(x):
return torch.zeros(x.shape[0], 101).scatter_(1, x.view(-1,1), 1)
所以如果我在 GPU 上给它一个张量,它会显示如下:
x_train = [[ 0, 0, 0, 0, 0],
[ 0, 0, 0, 0, 0],
[ 0, 0, 0, 0, 0],
[14, 13, 83, 18, 14],
[ 0, 0, 0, 0, 0]]
print(char_OneHotEncoding(torch.tensor(x_train, dtype=torch.long).cuda()).shape)
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-62-95c0c4ade406> in <module>()
4 [14, 13, 83, 18, 14],
5 [ 0, 0, 0, 0, 0]]
----> 6 print(char_OneHotEncoding(torch.tensor(x_train, dtype=torch.long).cuda()).shape)
7 x_train[:5, maxlen:maxlen+5]
<ipython-input-53-055f1bf71306> in char_OneHotEncoding(x)
2 coded = torch.zeros(x.shape[0], x.shape[1], 101)
3 for i in range(x.shape[1]):
----> 4 coded[:,i] = scatter(x[:,i])
5 return coded
6
<ipython-input-53-055f1bf71306> in scatter(x)
7
8 def scatter(x):
----> 9 return torch.zeros(x.shape[0], 101).scatter_(1, x.view(-1,1), 1)
RuntimeError: Expected object of backend CPU but got backend CUDA for argument #3 'index'
顺便说一句,如果我们简单地删除这里的.cuda()
,一切都会很好
print(char_OneHotEncoding(torch.tensor(x_train, dtype=torch.long)).shape)
torch.Size([5, 5, 101])
是的,这是可能的。你必须注意 all 张量都在 GPU 上。特别是,默认情况下,像 torch.zeros
这样的构造函数在 CPU 上分配,这将导致这种不匹配。您的代码可以通过 device=x.device
构造来修复,如下所示
import torch
def char_OneHotEncoding(x):
coded = torch.zeros(x.shape[0], x.shape[1], 101, device=x.device)
for i in range(x.shape[1]):
coded[:,i] = scatter(x[:,i])
return coded
def scatter(x):
return torch.zeros(x.shape[0], 101, device=x.device).scatter_(1, x.view(-1,1), 1)
x_train = torch.tensor([
[ 0, 0, 0, 0, 0],
[ 0, 0, 0, 0, 0],
[ 0, 0, 0, 0, 0],
[14, 13, 83, 18, 14],
[ 0, 0, 0, 0, 0]
], dtype=torch.long, device='cuda')
print(char_OneHotEncoding(x_train).shape)
另一个替代方案是称为 xxx_like
的构造函数,例如 zeros_like
,但在这种情况下,由于您需要与 x
不同的形状,我发现 device=x.device
更具可读性.
我正在尝试在 GPU 模式下使用 pyTorch 对某些数据进行一次热编码,但是,它一直给我一个例外。有人可以帮助我吗?
这是一个例子:
def char_OneHotEncoding(x):
coded = torch.zeros(x.shape[0], x.shape[1], 101)
for i in range(x.shape[1]):
coded[:,i] = scatter(x[:,i])
return coded
def scatter(x):
return torch.zeros(x.shape[0], 101).scatter_(1, x.view(-1,1), 1)
所以如果我在 GPU 上给它一个张量,它会显示如下:
x_train = [[ 0, 0, 0, 0, 0],
[ 0, 0, 0, 0, 0],
[ 0, 0, 0, 0, 0],
[14, 13, 83, 18, 14],
[ 0, 0, 0, 0, 0]]
print(char_OneHotEncoding(torch.tensor(x_train, dtype=torch.long).cuda()).shape)
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-62-95c0c4ade406> in <module>()
4 [14, 13, 83, 18, 14],
5 [ 0, 0, 0, 0, 0]]
----> 6 print(char_OneHotEncoding(torch.tensor(x_train, dtype=torch.long).cuda()).shape)
7 x_train[:5, maxlen:maxlen+5]
<ipython-input-53-055f1bf71306> in char_OneHotEncoding(x)
2 coded = torch.zeros(x.shape[0], x.shape[1], 101)
3 for i in range(x.shape[1]):
----> 4 coded[:,i] = scatter(x[:,i])
5 return coded
6
<ipython-input-53-055f1bf71306> in scatter(x)
7
8 def scatter(x):
----> 9 return torch.zeros(x.shape[0], 101).scatter_(1, x.view(-1,1), 1)
RuntimeError: Expected object of backend CPU but got backend CUDA for argument #3 'index'
顺便说一句,如果我们简单地删除这里的.cuda()
,一切都会很好
print(char_OneHotEncoding(torch.tensor(x_train, dtype=torch.long)).shape)
torch.Size([5, 5, 101])
是的,这是可能的。你必须注意 all 张量都在 GPU 上。特别是,默认情况下,像 torch.zeros
这样的构造函数在 CPU 上分配,这将导致这种不匹配。您的代码可以通过 device=x.device
构造来修复,如下所示
import torch
def char_OneHotEncoding(x):
coded = torch.zeros(x.shape[0], x.shape[1], 101, device=x.device)
for i in range(x.shape[1]):
coded[:,i] = scatter(x[:,i])
return coded
def scatter(x):
return torch.zeros(x.shape[0], 101, device=x.device).scatter_(1, x.view(-1,1), 1)
x_train = torch.tensor([
[ 0, 0, 0, 0, 0],
[ 0, 0, 0, 0, 0],
[ 0, 0, 0, 0, 0],
[14, 13, 83, 18, 14],
[ 0, 0, 0, 0, 0]
], dtype=torch.long, device='cuda')
print(char_OneHotEncoding(x_train).shape)
另一个替代方案是称为 xxx_like
的构造函数,例如 zeros_like
,但在这种情况下,由于您需要与 x
不同的形状,我发现 device=x.device
更具可读性.