在 ndarray 中搜索单热编码标签

Search for a one-hot encoded label in ndarray

我有一个名为 labelsndarray,其形状为 (6000, 8)。这是具有 8 个类别的 6000 个单热编码数组。我想搜索如下所示的标签:

[1,0,0,0,0,0,0,0]

然后尝试这样做

np.where(labels==[1,0,0,0,0,0,0,0,0])

但这并没有产生预期的结果

您需要 all 沿第二个轴:

np.where((labels == [1,0,0,0,0,0,0,0]).all(1))

查看这个较小的示例:

labels = np.array([[1,0,0,1,0,0,0,0], 
                   [0,0,0,0,0,1,1,0], 
                   [1,0,0,0,0,0,0,0], 
                   [0,0,0,0,0,0,0,1]])

(labels == [1,0,0,0,0,0,0,0])

array([[ True,  True,  True, False,  True,  True,  True,  True],
       [False,  True,  True,  True,  True, False, False,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True],
       [False,  True,  True,  True,  True,  True,  True, False]])

请注意,上面的比较只是 returns 一个与 labels 相同形状的数组,因为比较是沿着 labels 的行进行的。需要用all进行聚合,检查一行中的所有元素是否都是True:

(labels == [1,0,0,0,0,0,0,0]).all(1)
 #array([False, False,  True, False])