使用混淆矩阵理解多标签分类器
Understanding multi-label classifier using confusion matrix
我有 12 个 classes 的多标签 class化问题。我正在使用 Tensorflow
的 slim
来使用在 ImageNet
上预训练的模型来训练模型。以下是每个 class 在训练和验证
中的存在百分比
Training Validation
class0 44.4 25
class1 55.6 50
class2 50 25
class3 55.6 50
class4 44.4 50
class5 50 75
class6 50 75
class7 55.6 50
class8 88.9 50
class9 88.9 50
class10 50 25
class11 72.2 25
问题是模型没有收敛,验证集上的 ROC
曲线下 (Az
) 很差,类似于:
Az
class0 0.99
class1 0.44
class2 0.96
class3 0.9
class4 0.99
class5 0.01
class6 0.52
class7 0.65
class8 0.97
class9 0.82
class10 0.09
class11 0.5
Average 0.65
我不知道为什么它对某些 class 有效而对其他人却无效。我决定深入细节,看看神经网络在学习什么。我知道混淆矩阵仅适用于二进制或多class class化。因此,为了能够绘制它,我不得不将问题转换为对多 class class 化。尽管模型是使用 sigmoid
训练的,为每个 class 提供预测,但对于下面混淆矩阵中的每个单元格,我显示的是概率的平均值(通过应用 sigmoid
对图像的 tensorflow 预测的函数,其中矩阵行中的 class 存在且列中的 class 不存在。这应用于验证集图像。通过这种方式,我认为我可以获得有关模型正在学习的内容的更多详细信息。为了显示目的,我只是圈出了对角线元素。
我的解读是:
- Classes 0 和 4 在存在时检测到存在,在不存在时检测到不存在。这意味着这些 class 已被很好地检测到。
- Classes 2、6 和 7 始终被检测为不存在。这不是我要找的。
- Classes 3、8 和 9 始终被检测为存在。这不是我要找的。这可以应用于class 11.
- Class 5 不存在时检测存在,存在时检测为不存在。是反检测的。
- Classes 3 & 10:我认为我们不能为这 2 classes 提取太多信息。
我的问题是解释。我不确定问题出在哪里,也不确定数据集中是否存在产生此类结果的偏差。我还想知道是否有一些指标可以帮助解决多标签 class 化问题?你能和我分享你对这种混淆矩阵的解释吗? what/where 看下一个?对其他指标的一些建议会很棒。
谢谢。
编辑:
我将问题转换为多class class化,因此对于每对classes(例如0,1)计算概率(class 0, class 1), 表示为 p(0,1)
:
我对存在工具 0 而工具 1 不存在的图像进行工具 1 的预测,并通过应用 sigmoid 函数将它们转换为概率,然后显示这些概率的平均值。对于 p(1, 0)
,我对工具 0 执行相同操作,但现在使用存在工具 1 而工具 0 不存在的图像。对于 p(0, 0)
,我使用存在工具 0 的所有图像。考虑到上图中的 p(0,4)
,N/A 意味着不存在工具 0 存在且工具 4 不存在的图像。
以下是 2 个子集的图像数量:
- 169320 张训练图像
- 37440 张图片用于验证
这是在训练集上计算的混淆矩阵(计算方式与前面描述的验证集相同),但这次颜色代码是用于计算每个概率的图像数量:
已编辑:
对于数据增强,我对网络中的每个输入图像进行随机平移、旋转和缩放。此外,这里有一些关于工具的信息:
class 0 shape is completely different than the other objects.
class 1 resembles strongly to class 4.
class 2 shape resembles to class 1 & 4 but it's always accompanied by an object different than the others objects in the scene. As a whole, it is different than the other objects.
class 3 shape is completely different than the other objects.
class 4 resembles strongly to class 1
class 5 have common shape with classes 6 & 7 (we can say that they are all from the same category of objects)
class 6 resembles strongly to class 7
class 7 resembles strongly to class 6
class 8 shape is completely different than the other objects.
class 9 resembles strongly to class 10
class 10 resembles strongly to class 9
class 11 shape is completely different than the other objects.
已编辑:
这是下面为训练集建议的代码的输出:
Avg. num labels per image = 6.892700212615167
On average, images with label 0 also have 6.365296803652968 other labels.
On average, images with label 1 also have 6.601033718926901 other labels.
On average, images with label 2 also have 6.758548914659531 other labels.
On average, images with label 3 also have 6.131520940484937 other labels.
On average, images with label 4 also have 6.219187208527648 other labels.
On average, images with label 5 also have 6.536933407946279 other labels.
On average, images with label 6 also have 6.533908387864367 other labels.
On average, images with label 7 also have 6.485973817793214 other labels.
On average, images with label 8 also have 6.1241642788920725 other labels.
On average, images with label 9 also have 5.94092288040875 other labels.
On average, images with label 10 also have 6.983303518187239 other labels.
On average, images with label 11 also have 6.1974066621953945 other labels.
对于验证集:
Avg. num labels per image = 6.001282051282051
On average, images with label 0 also have 6.0 other labels.
On average, images with label 1 also have 3.987080103359173 other labels.
On average, images with label 2 also have 6.0 other labels.
On average, images with label 3 also have 5.507731958762887 other labels.
On average, images with label 4 also have 5.506459948320414 other labels.
On average, images with label 5 also have 5.00169779286927 other labels.
On average, images with label 6 also have 5.6729452054794525 other labels.
On average, images with label 7 also have 6.0 other labels.
On average, images with label 8 also have 6.0 other labels.
On average, images with label 9 also have 5.506459948320414 other labels.
On average, images with label 10 also have 3.0 other labels.
On average, images with label 11 also have 4.666095890410959 other labels.
评论:
我认为这不仅与分布之间的差异有关,因为如果模型能够很好地泛化 class 10(意味着在训练过程中可以正确识别对象,如 class 0),则验证集的准确性就足够了。我的意思是,问题在于训练集本身及其构建方式,而不是两种分布之间的差异。它可以是:class 或对象的出现频率非常相似(如 class 10 与 class 9 非常相似的情况)或数据集或薄对象内部的偏差(代表输入图像中可能有 1% 或 2% 的像素,如 class 2)。我并不是说问题是其中之一,但我只是想指出我认为这不仅仅是两个分布之间的差异。
输出校准
我认为首先要意识到的一件事是神经网络的输出可能 校准 很差。我的意思是,它给不同实例的输出可能会产生良好的排名(带有标签 L 的图像往往比没有标签 L 的图像具有更高的标签分数),但这些分数不能总是可靠地解释为概率(对于没有标签的实例,它可能会给出非常高的分数,如 0.9
,而对于带有标签的实例,它可能会给出更高的分数,如 0.99
)。我想这是否会发生取决于您选择的损失函数等因素。
有关这方面的更多信息,请参阅示例:https://arxiv.org/abs/1706.04599
通过所有 classes 1 乘 1
Class 0: AUC(曲线下面积)= 0.99。这是一个很好的成绩。混淆矩阵中的第 0 列看起来也不错,所以这里没有问题。
Class 1: AUC = 0.44。这太糟糕了,低于 0.5,如果我没记错的话,这几乎意味着你最好故意做与你的网络对该标签的预测相反的。
查看混淆矩阵中的第 1 列,它在所有地方的分数几乎相同。对我来说,这表明网络没有设法了解很多关于这个 class 的信息,根据训练集中包含该标签的图像的百分比 (55.6%),几乎只是 "guesses"。由于这个百分比在验证集中下降到 50%,这个策略确实意味着它会比随机的稍微差一点。尽管如此,第 1 行仍然是该列所有行中数量最多的,因此它似乎至少学到了一点点,但并不多。
Class 2: AUC = 0.96。这是非常好的。
您对此 class 的解释是,根据整个列的浅色阴影,它始终被预测为不存在。我不认为这种解释是正确的。查看它在对角线上的得分如何 > 0,而列中其他地方的得分仅为 0。它在该行中的得分可能相对较低,但很容易与同一列中的其他行分开。您可能只需要设置选择该标签是否相对较低的阈值。我怀疑这是由于上面提到的校准问题。
这也是AUC其实很好的原因;有可能 select 一个阈值,使得大多数分数高于阈值的实例正确地具有标签,而大多数低于它的实例正确地没有。不过,该阈值可能不是 0.5,如果您假设校准良好,这是您可能期望的阈值。绘制此特定标签的 ROC 曲线可能会帮助您准确确定阈值的位置。
Class 3: AUC = 0.9,相当不错。
您将其解释为始终被检测为存在,并且混淆矩阵的列中确实有很多高数字,但 AUC 很好并且对角线上的单元格确实具有足够高的值,它可能很容易与其他人分开。我怀疑这是与 Class 2 类似的情况(只是翻转过来,到处都是高预测,因此正确决策需要高阈值)。
如果您想确定 well-selected 阈值是否确实可以正确地将大多数 "positives"(具有 class 3 的实例)从大多数 "negatives" 中分离出来(没有 class 3 的实例),您需要根据标签 3 的预测分数对所有实例进行排序,然后遍历整个列表并在每对连续条目之间计算您将获得的验证集的准确性如果您决定将阈值放在那里,select 最佳阈值。
Class 4: 等同于 class 0.
Class 5: AUC = 0.01,显然很糟糕。也同意你对混淆矩阵的解释。很难确定为什么它在这里表现如此糟糕。也许这是一种难以识别的物体?可能还有一些过度拟合正在进行(从第二个矩阵的列判断训练数据中有 0 个误报,尽管也有其他 classes 发生这种情况)。
从训练数据到验证数据,标签 5 图像的比例增加可能也无济于事。这意味着网络在训练期间在此标签上的表现不如在验证期间重要。
Class 6: AUC = 0.52,只比随机好一点。
从第一个矩阵的第 6 列判断,这实际上可能与 class 2 的情况类似。不过,如果我们也考虑 AUC,看起来它并没有很好地学习对实例进行排名好吧。类似于 class 5,只是没有那么糟糕。同样,训练和 vlidation分布完全不同。
Class 7: AUC = 0.65,相当平均。例如,显然不如 class 2 好,但也没有您从矩阵中解释的那么差。
Class 8: AUC = 0.97,很好,类似于class 3.
Class 9: AUC = 0.82,不太好,但仍然不错。矩阵中的列有很多黑色单元格,而且数字非常接近,在我看来 AUC 出奇地好。它几乎出现在训练数据的每张图像中,因此预测它经常出现也就不足为奇了。也许其中一些非常暗的细胞仅基于绝对数量较少的图像?这将很有趣。
Class 10: AUC = 0.09,太可怕了。对角线上的 0 非常令人担忧(您的数据标记正确吗?)。根据第一个矩阵的第 10 行,它似乎经常混淆 classes 3 和 9(棉花和 primary_incision_knives 看起来很像 secondary_incision_knives 吗?)。也可能对训练数据有些过拟合。
Class 11: AUC = 0.5,不比随机好。性能不佳(以及矩阵中明显过高的分数)可能是因为该标签存在于大多数训练图像中,但仅存在于少数验证图像中。
还有什么要绘制/测量的?
为了更深入地了解您的数据,我首先绘制热图,显示每个 class co-occurs 的频率(一张用于训练,一张用于验证数据)。单元格 (i, j) 将根据同时包含标签 i 和 j 的图像的比例进行着色。这将是一个对称图,对角线上的单元格根据您问题中的第一个数字列表着色。比较这两个热图,看看它们有什么不同,看看这是否有助于解释您的模型的性能。
此外,了解(对于两个数据集)每个图像平均有多少个不同的标签,以及对于每个单独的标签,它平均与多少个其他标签共享一个图像可能很有用。例如,我怀疑标签为 10 的图像在训练数据中的其他标签相对较少。如果网络识别出其他事物,这可能会阻止网络预测标签 10,并且如果标签 10 确实突然与验证数据中的其他对象更频繁地共享图像,则会导致性能不佳。由于伪代码比文字更容易表达意思,打印如下内容可能会很有趣:
# Do all of the following once for training data, AND once for validation data
tot_num_labels = 0
for image in images:
tot_num_labels += len(image.get_all_labels())
avg_labels_per_image = tot_num_labels / float(num_images)
print("Avg. num labels per image = ", avg_labels_per_image)
for label in range(num_labels):
tot_shared_labels = 0
for image in images_with_label(label):
tot_shared_labels += (len(image.get_all_labels()) - 1)
avg_shared_labels = tot_shared_labels / float(len(images_with_label(label)))
print("On average, images with label ", label, " also have ", avg_shared_labels, " other labels.")
对于单个数据集,这并不能提供太多有用的信息,但是如果您对训练集和验证集执行此操作,如果数字非常不同,您可以看出它们的分布非常不同
最后,我有点担心第一个矩阵中的某些列如何完全 相同的平均预测出现在许多不同的行上。我不太确定是什么导致了这种情况,但这可能有助于调查。
如何改进?
如果您还没有,我建议您查看 数据扩充 以获取您的训练数据。由于您正在处理图像,因此您可以尝试将现有图像的旋转版本添加到您的数据中。
对于您的 multi-label 具体情况,目标是检测不同类型的对象,尝试将一堆不同的图像(例如两张或四张图像)简单地连接在一起可能也很有趣。然后,您可以将它们缩小到原始图像大小,并作为标签分配原始标签集的并集。合并图像的边缘会出现有趣的不连续点,我不知道这是否有害。也许它不适合你的 multi-object 检测,我认为值得一试。
我有 12 个 classes 的多标签 class化问题。我正在使用 Tensorflow
的 slim
来使用在 ImageNet
上预训练的模型来训练模型。以下是每个 class 在训练和验证
Training Validation
class0 44.4 25
class1 55.6 50
class2 50 25
class3 55.6 50
class4 44.4 50
class5 50 75
class6 50 75
class7 55.6 50
class8 88.9 50
class9 88.9 50
class10 50 25
class11 72.2 25
问题是模型没有收敛,验证集上的 ROC
曲线下 (Az
) 很差,类似于:
Az
class0 0.99
class1 0.44
class2 0.96
class3 0.9
class4 0.99
class5 0.01
class6 0.52
class7 0.65
class8 0.97
class9 0.82
class10 0.09
class11 0.5
Average 0.65
我不知道为什么它对某些 class 有效而对其他人却无效。我决定深入细节,看看神经网络在学习什么。我知道混淆矩阵仅适用于二进制或多class class化。因此,为了能够绘制它,我不得不将问题转换为对多 class class 化。尽管模型是使用 sigmoid
训练的,为每个 class 提供预测,但对于下面混淆矩阵中的每个单元格,我显示的是概率的平均值(通过应用 sigmoid
对图像的 tensorflow 预测的函数,其中矩阵行中的 class 存在且列中的 class 不存在。这应用于验证集图像。通过这种方式,我认为我可以获得有关模型正在学习的内容的更多详细信息。为了显示目的,我只是圈出了对角线元素。
我的解读是:
- Classes 0 和 4 在存在时检测到存在,在不存在时检测到不存在。这意味着这些 class 已被很好地检测到。
- Classes 2、6 和 7 始终被检测为不存在。这不是我要找的。
- Classes 3、8 和 9 始终被检测为存在。这不是我要找的。这可以应用于class 11.
- Class 5 不存在时检测存在,存在时检测为不存在。是反检测的。
- Classes 3 & 10:我认为我们不能为这 2 classes 提取太多信息。
我的问题是解释。我不确定问题出在哪里,也不确定数据集中是否存在产生此类结果的偏差。我还想知道是否有一些指标可以帮助解决多标签 class 化问题?你能和我分享你对这种混淆矩阵的解释吗? what/where 看下一个?对其他指标的一些建议会很棒。
谢谢。
编辑:
我将问题转换为多class class化,因此对于每对classes(例如0,1)计算概率(class 0, class 1), 表示为 p(0,1)
:
我对存在工具 0 而工具 1 不存在的图像进行工具 1 的预测,并通过应用 sigmoid 函数将它们转换为概率,然后显示这些概率的平均值。对于 p(1, 0)
,我对工具 0 执行相同操作,但现在使用存在工具 1 而工具 0 不存在的图像。对于 p(0, 0)
,我使用存在工具 0 的所有图像。考虑到上图中的 p(0,4)
,N/A 意味着不存在工具 0 存在且工具 4 不存在的图像。
以下是 2 个子集的图像数量:
- 169320 张训练图像
- 37440 张图片用于验证
这是在训练集上计算的混淆矩阵(计算方式与前面描述的验证集相同),但这次颜色代码是用于计算每个概率的图像数量:
已编辑: 对于数据增强,我对网络中的每个输入图像进行随机平移、旋转和缩放。此外,这里有一些关于工具的信息:
class 0 shape is completely different than the other objects.
class 1 resembles strongly to class 4.
class 2 shape resembles to class 1 & 4 but it's always accompanied by an object different than the others objects in the scene. As a whole, it is different than the other objects.
class 3 shape is completely different than the other objects.
class 4 resembles strongly to class 1
class 5 have common shape with classes 6 & 7 (we can say that they are all from the same category of objects)
class 6 resembles strongly to class 7
class 7 resembles strongly to class 6
class 8 shape is completely different than the other objects.
class 9 resembles strongly to class 10
class 10 resembles strongly to class 9
class 11 shape is completely different than the other objects.
已编辑: 这是下面为训练集建议的代码的输出:
Avg. num labels per image = 6.892700212615167
On average, images with label 0 also have 6.365296803652968 other labels.
On average, images with label 1 also have 6.601033718926901 other labels.
On average, images with label 2 also have 6.758548914659531 other labels.
On average, images with label 3 also have 6.131520940484937 other labels.
On average, images with label 4 also have 6.219187208527648 other labels.
On average, images with label 5 also have 6.536933407946279 other labels.
On average, images with label 6 also have 6.533908387864367 other labels.
On average, images with label 7 also have 6.485973817793214 other labels.
On average, images with label 8 also have 6.1241642788920725 other labels.
On average, images with label 9 also have 5.94092288040875 other labels.
On average, images with label 10 also have 6.983303518187239 other labels.
On average, images with label 11 also have 6.1974066621953945 other labels.
对于验证集:
Avg. num labels per image = 6.001282051282051
On average, images with label 0 also have 6.0 other labels.
On average, images with label 1 also have 3.987080103359173 other labels.
On average, images with label 2 also have 6.0 other labels.
On average, images with label 3 also have 5.507731958762887 other labels.
On average, images with label 4 also have 5.506459948320414 other labels.
On average, images with label 5 also have 5.00169779286927 other labels.
On average, images with label 6 also have 5.6729452054794525 other labels.
On average, images with label 7 also have 6.0 other labels.
On average, images with label 8 also have 6.0 other labels.
On average, images with label 9 also have 5.506459948320414 other labels.
On average, images with label 10 also have 3.0 other labels.
On average, images with label 11 also have 4.666095890410959 other labels.
评论: 我认为这不仅与分布之间的差异有关,因为如果模型能够很好地泛化 class 10(意味着在训练过程中可以正确识别对象,如 class 0),则验证集的准确性就足够了。我的意思是,问题在于训练集本身及其构建方式,而不是两种分布之间的差异。它可以是:class 或对象的出现频率非常相似(如 class 10 与 class 9 非常相似的情况)或数据集或薄对象内部的偏差(代表输入图像中可能有 1% 或 2% 的像素,如 class 2)。我并不是说问题是其中之一,但我只是想指出我认为这不仅仅是两个分布之间的差异。
输出校准
我认为首先要意识到的一件事是神经网络的输出可能 校准 很差。我的意思是,它给不同实例的输出可能会产生良好的排名(带有标签 L 的图像往往比没有标签 L 的图像具有更高的标签分数),但这些分数不能总是可靠地解释为概率(对于没有标签的实例,它可能会给出非常高的分数,如 0.9
,而对于带有标签的实例,它可能会给出更高的分数,如 0.99
)。我想这是否会发生取决于您选择的损失函数等因素。
有关这方面的更多信息,请参阅示例:https://arxiv.org/abs/1706.04599
通过所有 classes 1 乘 1
Class 0: AUC(曲线下面积)= 0.99。这是一个很好的成绩。混淆矩阵中的第 0 列看起来也不错,所以这里没有问题。
Class 1: AUC = 0.44。这太糟糕了,低于 0.5,如果我没记错的话,这几乎意味着你最好故意做与你的网络对该标签的预测相反的。
查看混淆矩阵中的第 1 列,它在所有地方的分数几乎相同。对我来说,这表明网络没有设法了解很多关于这个 class 的信息,根据训练集中包含该标签的图像的百分比 (55.6%),几乎只是 "guesses"。由于这个百分比在验证集中下降到 50%,这个策略确实意味着它会比随机的稍微差一点。尽管如此,第 1 行仍然是该列所有行中数量最多的,因此它似乎至少学到了一点点,但并不多。
Class 2: AUC = 0.96。这是非常好的。
您对此 class 的解释是,根据整个列的浅色阴影,它始终被预测为不存在。我不认为这种解释是正确的。查看它在对角线上的得分如何 > 0,而列中其他地方的得分仅为 0。它在该行中的得分可能相对较低,但很容易与同一列中的其他行分开。您可能只需要设置选择该标签是否相对较低的阈值。我怀疑这是由于上面提到的校准问题。
这也是AUC其实很好的原因;有可能 select 一个阈值,使得大多数分数高于阈值的实例正确地具有标签,而大多数低于它的实例正确地没有。不过,该阈值可能不是 0.5,如果您假设校准良好,这是您可能期望的阈值。绘制此特定标签的 ROC 曲线可能会帮助您准确确定阈值的位置。
Class 3: AUC = 0.9,相当不错。
您将其解释为始终被检测为存在,并且混淆矩阵的列中确实有很多高数字,但 AUC 很好并且对角线上的单元格确实具有足够高的值,它可能很容易与其他人分开。我怀疑这是与 Class 2 类似的情况(只是翻转过来,到处都是高预测,因此正确决策需要高阈值)。
如果您想确定 well-selected 阈值是否确实可以正确地将大多数 "positives"(具有 class 3 的实例)从大多数 "negatives" 中分离出来(没有 class 3 的实例),您需要根据标签 3 的预测分数对所有实例进行排序,然后遍历整个列表并在每对连续条目之间计算您将获得的验证集的准确性如果您决定将阈值放在那里,select 最佳阈值。
Class 4: 等同于 class 0.
Class 5: AUC = 0.01,显然很糟糕。也同意你对混淆矩阵的解释。很难确定为什么它在这里表现如此糟糕。也许这是一种难以识别的物体?可能还有一些过度拟合正在进行(从第二个矩阵的列判断训练数据中有 0 个误报,尽管也有其他 classes 发生这种情况)。
从训练数据到验证数据,标签 5 图像的比例增加可能也无济于事。这意味着网络在训练期间在此标签上的表现不如在验证期间重要。
Class 6: AUC = 0.52,只比随机好一点。
从第一个矩阵的第 6 列判断,这实际上可能与 class 2 的情况类似。不过,如果我们也考虑 AUC,看起来它并没有很好地学习对实例进行排名好吧。类似于 class 5,只是没有那么糟糕。同样,训练和 vlidation分布完全不同。
Class 7: AUC = 0.65,相当平均。例如,显然不如 class 2 好,但也没有您从矩阵中解释的那么差。
Class 8: AUC = 0.97,很好,类似于class 3.
Class 9: AUC = 0.82,不太好,但仍然不错。矩阵中的列有很多黑色单元格,而且数字非常接近,在我看来 AUC 出奇地好。它几乎出现在训练数据的每张图像中,因此预测它经常出现也就不足为奇了。也许其中一些非常暗的细胞仅基于绝对数量较少的图像?这将很有趣。
Class 10: AUC = 0.09,太可怕了。对角线上的 0 非常令人担忧(您的数据标记正确吗?)。根据第一个矩阵的第 10 行,它似乎经常混淆 classes 3 和 9(棉花和 primary_incision_knives 看起来很像 secondary_incision_knives 吗?)。也可能对训练数据有些过拟合。
Class 11: AUC = 0.5,不比随机好。性能不佳(以及矩阵中明显过高的分数)可能是因为该标签存在于大多数训练图像中,但仅存在于少数验证图像中。
还有什么要绘制/测量的?
为了更深入地了解您的数据,我首先绘制热图,显示每个 class co-occurs 的频率(一张用于训练,一张用于验证数据)。单元格 (i, j) 将根据同时包含标签 i 和 j 的图像的比例进行着色。这将是一个对称图,对角线上的单元格根据您问题中的第一个数字列表着色。比较这两个热图,看看它们有什么不同,看看这是否有助于解释您的模型的性能。
此外,了解(对于两个数据集)每个图像平均有多少个不同的标签,以及对于每个单独的标签,它平均与多少个其他标签共享一个图像可能很有用。例如,我怀疑标签为 10 的图像在训练数据中的其他标签相对较少。如果网络识别出其他事物,这可能会阻止网络预测标签 10,并且如果标签 10 确实突然与验证数据中的其他对象更频繁地共享图像,则会导致性能不佳。由于伪代码比文字更容易表达意思,打印如下内容可能会很有趣:
# Do all of the following once for training data, AND once for validation data
tot_num_labels = 0
for image in images:
tot_num_labels += len(image.get_all_labels())
avg_labels_per_image = tot_num_labels / float(num_images)
print("Avg. num labels per image = ", avg_labels_per_image)
for label in range(num_labels):
tot_shared_labels = 0
for image in images_with_label(label):
tot_shared_labels += (len(image.get_all_labels()) - 1)
avg_shared_labels = tot_shared_labels / float(len(images_with_label(label)))
print("On average, images with label ", label, " also have ", avg_shared_labels, " other labels.")
对于单个数据集,这并不能提供太多有用的信息,但是如果您对训练集和验证集执行此操作,如果数字非常不同,您可以看出它们的分布非常不同
最后,我有点担心第一个矩阵中的某些列如何完全 相同的平均预测出现在许多不同的行上。我不太确定是什么导致了这种情况,但这可能有助于调查。
如何改进?
如果您还没有,我建议您查看 数据扩充 以获取您的训练数据。由于您正在处理图像,因此您可以尝试将现有图像的旋转版本添加到您的数据中。
对于您的 multi-label 具体情况,目标是检测不同类型的对象,尝试将一堆不同的图像(例如两张或四张图像)简单地连接在一起可能也很有趣。然后,您可以将它们缩小到原始图像大小,并作为标签分配原始标签集的并集。合并图像的边缘会出现有趣的不连续点,我不知道这是否有害。也许它不适合你的 multi-object 检测,我认为值得一试。