Class CNN 中的权重
Class weights in CNN
我有一个非常不平衡的数据集。首先,我将这个数据集分为训练数据集(80%)和验证数据集(20%)。我使用了 StratifiedShuffleSplit
,所以两个数据集都保留了每个 class 百分比。
为了解决两个数据集不平衡的问题,我使用了 class_weight
。这是我为此使用的代码:
class_weight = {0: 70.,
1: 110.,
2: 82.,
3: 17.,
4: 9.}
model.fit(train_generator, epochs = 5, class_weight=(class_weight), validation_data=(x_val, y_val))
变量class_weight
目前有整个数据集每个class的图像数量,即训练和验证数据集的组合。应该那样做吗?或者它应该有训练数据集的图像?
我还有一个问题。假设我进行数据扩充,我如何确定每个 class 的图像数量?有自动计算器之类的吗?
您似乎为 classes 硬编码了一些权重值。但是,您可以使用 sklearn.utils.class_weight.compute_class_weight
进行 class 加权来解决不平衡的数据集。它将根据 classes.
的出现计算适当的权重值
# imports
from sklearn.utils import class_weight
# compute class weight
# based on appearance of each class in y_trian
cls_wgts = class_weight.compute_class_weight('balanced',
sorted(np.unique(y_train)),
y_train)
# dict mapping
cls_wgts = {i : cls_wgts[i] for i, label in enumerate(sorted(np.unique(y_train)))}
# pass it to fit
model.fit(..., class_weight=cls_wgts)
关于你的第二个问题,如果我理解正确的话,我们通常不知道在训练时间内每 class 会发生多少增强。但是我们可以控制数据生成器中的设置,其中 minor
classes 与 major
classes 相比会得到更多增强。此外,您还可以在这里使用加权交叉熵损失函数来处理class不平衡。
我有一个非常不平衡的数据集。首先,我将这个数据集分为训练数据集(80%)和验证数据集(20%)。我使用了 StratifiedShuffleSplit
,所以两个数据集都保留了每个 class 百分比。
为了解决两个数据集不平衡的问题,我使用了 class_weight
。这是我为此使用的代码:
class_weight = {0: 70.,
1: 110.,
2: 82.,
3: 17.,
4: 9.}
model.fit(train_generator, epochs = 5, class_weight=(class_weight), validation_data=(x_val, y_val))
变量class_weight
目前有整个数据集每个class的图像数量,即训练和验证数据集的组合。应该那样做吗?或者它应该有训练数据集的图像?
我还有一个问题。假设我进行数据扩充,我如何确定每个 class 的图像数量?有自动计算器之类的吗?
您似乎为 classes 硬编码了一些权重值。但是,您可以使用 sklearn.utils.class_weight.compute_class_weight
进行 class 加权来解决不平衡的数据集。它将根据 classes.
# imports
from sklearn.utils import class_weight
# compute class weight
# based on appearance of each class in y_trian
cls_wgts = class_weight.compute_class_weight('balanced',
sorted(np.unique(y_train)),
y_train)
# dict mapping
cls_wgts = {i : cls_wgts[i] for i, label in enumerate(sorted(np.unique(y_train)))}
# pass it to fit
model.fit(..., class_weight=cls_wgts)
关于你的第二个问题,如果我理解正确的话,我们通常不知道在训练时间内每 class 会发生多少增强。但是我们可以控制数据生成器中的设置,其中 minor
classes 与 major
classes 相比会得到更多增强。此外,您还可以在这里使用加权交叉熵损失函数来处理class不平衡。