了解 index_put 在 PyTorch 中的行为

Understanding behavior of index_put in PyTorch

我想了解 index_put 在 PyTorch 中的行为,但文档对我来说不是很清楚。

给出

a = torch.zeros(2, 3)
a.index_put([torch.tensor(1, 0), torch.tensor([1, 1])], torch.tensor(1.))

它returns

tensor([[1., 1., 0.], 
       [0., 0., 0.])

同时给予

a = torch.zeros(2, 3)
a.index_put([torch.tensor(0, 0), torch.tensor([1, 1])], torch.tensor(1.))

它returns

tensor([[0., 1., 0.], 
       [0., 0., 0.])

我想知道index_put到底有什么规律?如果我想将三个值赋给 a 怎么办,这样它 returns

tensor([0., 1., 1.,],
       [0., 1., 0.])

感谢任何帮助!

我在此处复制了您的示例,其中插入了参数名称、固定括号和正确的输出(您的已被交换):

a.index_put(indices=[torch.tensor([1, 0]), torch.tensor([1, 1])], values=torch.tensor(1.))

tensor([[0., 1., 0.],
        [0., 1., 0.]])

a.index_put(indices=[torch.tensor([0, 0]), torch.tensor([0, 1])], values = torch.tensor(1.))

tensor([[1., 1., 0.],
        [0., 0., 0.]]

此方法的作用是将值插入 indices 指示的原始 a 张量中的位置。 indices 是插入的 x 坐标和插入的 y 坐标的列表。值可以是单个值或一维张量。

要获得所需的输出,请使用:

a.index_put(indices=[torch.tensor([0,0,1]), torch.tensor([1, 2, 1])], values=torch.tensor(1.))

tensor([[0., 1., 1.],
        [0., 1., 0.]])

此外,您可以在 values 参数中传递多个值以将它们插入指定位置:

a.index_put(indices=[torch.tensor([0,0,1]), torch.tensor([1, 2, 1])], values=torch.tensor([1., 2., 3.]))

tensor([[0., 1., 2.],
        [0., 3., 0.]])