pytorch narrow 不适用

pytorch narrow is not applicable

如何在pytorch中进行切片?我试过狭窄和麻木的切片?两者都不适用于输出训练数据和测试数据。有什么解决办法吗?

x1 = (max-min)*torch.rand(1, 21, dtype=torch.float) + min
x2 = (max-min)*torch.rand(1, 21, dtype=torch.float) + min
zipped_list = zip(x1, x2)
y = torch.empty(1, 21)
y = [torch.sin(2*x1+2) * torch.cos(0.5*x2)+0.5 for (x1, x2) in zipped_list]


print(y)

train_data = y.narrow(0,1)
test_data = y[11:21]
print(train_data)

输出是

AttributeError: 'list' object has no attribute 'narrow'

但是当我进行正常切片时,测试数据不会被正确切片

train_data = y[0:11]
test_data = y[11:21]
Train Data is: [tensor([-0.0515,  0.4574,  0.5141,  0.4865,  0.9266,  1.0984,  0.5364,  0.7042,
         0.1741, -0.4839,  0.4332,  0.2962,  0.2311,  0.6169,  0.4321,  0.4088,
         0.2443,  0.1982,  0.7978,  0.6651, -0.4453])]
Test Data is: []

在您的代码中,y 是一个 PyTorch 张量:

y = torch.empty(1, 21)

然后替换为 list 个 PyTorch 张量(实际上,只有一个):

y = [torch.sin(2*x1+2) * torch.cos(0.5*x2)+0.5 for (x1, x2) in zipped_list]

所以你需要获取 y 的第一个元素,它是一个张量,然后将其切片:

print(y[0][:5])