是否有一个 pytorch 函数可以在给定的张量(大小为 N*h*w*2)中找到唯一的元组?

Is there a pytorch function to find unique tuples in a given tensor (of size N*h*w*2)?

我正在尝试提取 (N * h * w * 2) 张量中的唯一元组。

例如,有 6 个元组的 1 * 2 * 3 * 2 张量:a = torch.tensor([[[[1,2], [2,3], [3,4]], [[4,5], [1,2], [3,4]]]])

我正在尝试查找唯一元组的索引(即 [1,2], [2,3], [3,4], [4,5] 的索引,其中删除了重复项)。

我已经检查过 torch.unique(),但它似乎不起作用。

您计算所有对之间的差异:

d = torch.abs(a.view(-1, 1, 2) - a.view(1, -1, 2)).sum(dim=-1)

然后你可以找到零差对(使用 triu 屏蔽非唯一对):

i, j = torch.where((d + torch.triu(torch.ones_like(d))) == 0)

结果为:

i,j
(tensor([4, 5]), tensor([0, 2]))

a中的第4对与第0对相同,第5对与第二对相同。