在 ndarray 中搜索单热编码标签
Search for a one-hot encoded label in ndarray
我有一个名为 labels
的 ndarray
,其形状为 (6000, 8)
。这是具有 8 个类别的 6000 个单热编码数组。我想搜索如下所示的标签:
[1,0,0,0,0,0,0,0]
然后尝试这样做
np.where(labels==[1,0,0,0,0,0,0,0,0])
但这并没有产生预期的结果
您需要 all
沿第二个轴:
np.where((labels == [1,0,0,0,0,0,0,0]).all(1))
查看这个较小的示例:
labels = np.array([[1,0,0,1,0,0,0,0],
[0,0,0,0,0,1,1,0],
[1,0,0,0,0,0,0,0],
[0,0,0,0,0,0,0,1]])
(labels == [1,0,0,0,0,0,0,0])
array([[ True, True, True, False, True, True, True, True],
[False, True, True, True, True, False, False, True],
[ True, True, True, True, True, True, True, True],
[False, True, True, True, True, True, True, False]])
请注意,上面的比较只是 returns 一个与 labels
相同形状的数组,因为比较是沿着 labels
的行进行的。需要用all
进行聚合,检查一行中的所有元素是否都是True
:
(labels == [1,0,0,0,0,0,0,0]).all(1)
#array([False, False, True, False])
我有一个名为 labels
的 ndarray
,其形状为 (6000, 8)
。这是具有 8 个类别的 6000 个单热编码数组。我想搜索如下所示的标签:
[1,0,0,0,0,0,0,0]
然后尝试这样做
np.where(labels==[1,0,0,0,0,0,0,0,0])
但这并没有产生预期的结果
您需要 all
沿第二个轴:
np.where((labels == [1,0,0,0,0,0,0,0]).all(1))
查看这个较小的示例:
labels = np.array([[1,0,0,1,0,0,0,0],
[0,0,0,0,0,1,1,0],
[1,0,0,0,0,0,0,0],
[0,0,0,0,0,0,0,1]])
(labels == [1,0,0,0,0,0,0,0])
array([[ True, True, True, False, True, True, True, True],
[False, True, True, True, True, False, False, True],
[ True, True, True, True, True, True, True, True],
[False, True, True, True, True, True, True, False]])
请注意,上面的比较只是 returns 一个与 labels
相同形状的数组,因为比较是沿着 labels
的行进行的。需要用all
进行聚合,检查一行中的所有元素是否都是True
:
(labels == [1,0,0,0,0,0,0,0]).all(1)
#array([False, False, True, False])