查找具有最多 1 的 PyTorch 张量的列索引

Finding column index of a PyTorch tensor that has most 1's

我有一个 PyTorch 张量 a 形状如下:

import torch
a = torch.tensor([[[1., 0., 0., 0.]],
        [[0., 1., 0., 0.]],
        [[1., 0., 0., 0.]],
        [[0., 0., 0., 1.]],
        [[1., 0., 0., 0.]],
        [[0., 0., 0., 1.]],
        [[1., 0., 0., 0.]]])

张量的每一行 a 有 4 个元素,1 和 0。假设我相应地索引了这个张量的行和列。因此,例如,第 0 行(最上面的行)中的条目是 [[1., 0., 0., 0.]],而第 3 列(最右边的列)中的条目是 [[0., 0., 0., 1., 0., 1., 0.]].

从给定的张量中,我想确定 1. 出现最频繁的列的索引。例如,对于张量 a,这样一列的索引将为 0。如果 1 的数量相同,我仍然想获得所有这些相关的列索引。

如何在 Python 上轻松完成此任务?

谢谢,

如果你的矩阵只包含0和1,你可以对每一列的元素求和,然后搜索最大的和:

import numpy as np

% sum over columns
sumsi = torch.sum(a, dim=1)

% find where maximum
col_idx = np.where(sumsi==np.max(sumsi))