无论如何在一次操作中多次添加相同的索引?
Is there anyway to add same index multiple times in one operation?
假设我有两个张量 value 和 index,它们包含我们需要的数据和所有索引。我想用相应的 index 向 value 中的数据添加一个。如果一个索引在张量中显示 k 次 index,那么这个数据应该被添加 k,而不是一个。
这是一个例子:
value = torch.zeros(3) # [0, 0, 0]
index = torch.zeros(10).long() #[0,0,0,0,0,0,0,0,0,0]
ret = some_func(value, index) # [10, 0, 0]
我知道用for循环遍历index中的所有索引可以解决问题,但我想问一下有没有更优雅的方法?
一种方法是使用 scatter_add
:
In [54]: value = torch.zeros(3)
In [55]: index = torch.tensor([0, 0, 1, 0, 2, 2, 1, 1, 1])
In [56]: value.scatter_add(0, index, torch.ones_like(index, dtype=value.dtype))
Out[56]: tensor([3., 4., 2.])
您可以使用scatter_add_
对value
进行原地操作。
您可能会发现使用 bincount()
:
效率更高
In [63]: index = torch.tensor([0, 0, 1, 0, 2, 2, 1, 1, 1])
In [64]: counts = index.bincount(minlength=value.shape[0])
In [65]: counts
Out[65]: tensor([3, 4, 2])
如果在您的实际问题中,value
初始化为 0,那么您就完成了——只需使用 counts
作为结果。否则,将 counts
添加到 value
。
假设我有两个张量 value 和 index,它们包含我们需要的数据和所有索引。我想用相应的 index 向 value 中的数据添加一个。如果一个索引在张量中显示 k 次 index,那么这个数据应该被添加 k,而不是一个。
这是一个例子:
value = torch.zeros(3) # [0, 0, 0]
index = torch.zeros(10).long() #[0,0,0,0,0,0,0,0,0,0]
ret = some_func(value, index) # [10, 0, 0]
我知道用for循环遍历index中的所有索引可以解决问题,但我想问一下有没有更优雅的方法?
一种方法是使用 scatter_add
:
In [54]: value = torch.zeros(3)
In [55]: index = torch.tensor([0, 0, 1, 0, 2, 2, 1, 1, 1])
In [56]: value.scatter_add(0, index, torch.ones_like(index, dtype=value.dtype))
Out[56]: tensor([3., 4., 2.])
您可以使用scatter_add_
对value
进行原地操作。
您可能会发现使用 bincount()
:
In [63]: index = torch.tensor([0, 0, 1, 0, 2, 2, 1, 1, 1])
In [64]: counts = index.bincount(minlength=value.shape[0])
In [65]: counts
Out[65]: tensor([3, 4, 2])
如果在您的实际问题中,value
初始化为 0,那么您就完成了——只需使用 counts
作为结果。否则,将 counts
添加到 value
。