如何根据特定 class 名称加载 CIFAR-10 数据集?

How to load CIFAR-10 datasets based on specific class name?

我正在使用 CIFAR-10 数据集进行深度学习,但我只想指定我的数据集用于水果 class。我们知道我们使用了:

(X_train, y_train), (X_test, y_test) = cifar10.load_data()

加载所有 CIFAR-10 数据集。如何只加载水果 class 的数据而不是所有数据?

如果您不介意加载额外的数据,最简单的方法是找出 是水果标签并执行如下操作:X_train, y_train = X_train[y_train == fruit_label], y_train[y_train == fruit_label],前提是您的数据已存储在 np.arrays。等效于您的测试集。

否则,您将不得不修改您的 hdf5 文件或您存储数据的位置。