How to remove certain classes (labels and images) from existing TensorFlow datasets? (Fashion MNIST)

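I've only just started learning more about TensorFlow and NumPy. I'm currently using TensorFlow's Fashion MNIST dataset, which contains 10 classes of clothing. However, I'd like to edit the NumPy arrays that hold this dataset so that every image and label that is not a 'T-shirt', 'Shirt' or 'Trouser' is removed. Essentially, I just want to build a dataset from Fashion MNIST that contains only these 3 classes.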

import tensorflow as tf

(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.fashion_mnist.load_data()

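Above is how I'm currently importing the dataset; as far as I know, there are several different ways to import it before preprocessing. How do I make sure the labels and their corresponding images are removed correctly, so that the remaining labels and images still correspond to each other?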
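One common way to keep images and labels aligned is to build a single boolean mask from the labels and apply that same mask to both arrays. The sketch below assumes the arrays have the shapes returned by `load_data()` (images `(N, 28, 28)`, labels `(N,)`); the helper name `keep_classes` is hypothetical, and tiny synthetic arrays stand in for the real data so the pairing is easy to check.

```python
import numpy as np

def keep_classes(images, labels, keep):
    """Return only the (image, label) pairs whose label is in `keep`.

    The same boolean mask is applied to both arrays, so row i of the
    filtered images still belongs to row i of the filtered labels.
    """
    mask = np.isin(labels, keep)
    return images[mask], labels[mask]

# Synthetic stand-in for Fashion MNIST: every pixel of image i equals its label,
# which makes it easy to see that images and labels stay paired.
labels = np.array([0, 3, 1, 6, 9, 0], dtype=np.uint8)
images = np.stack([np.full((2, 2), lab, dtype=np.uint8) for lab in labels])

# 0 = T-shirt/top, 1 = Trouser, 6 = Shirt
small_images, small_labels = keep_classes(images, labels, [0, 1, 6])
print(small_labels)           # [0 1 6 0]
print(small_images[:, 0, 0])  # [0 1 6 0]
```

With the real data this would be called once per split, e.g. `keep_classes(train_images, train_labels, [0, 1, 6])`. This approach needs no sorting and does not depend on how many images each class has.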

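import numpy as np
from tensorflow.keras.datasets import fashion_mnist

(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

# Sort each split by label. The same index array is applied to the images and
# to the labels, so every image stays paired with its own label.
idx = np.argsort(train_labels, kind="stable")
train_images = train_images[idx]
train_labels = train_labels[idx]

idx = np.argsort(test_labels, kind="stable")
test_images = test_images[idx]
test_labels = test_labels[idx]

# Class names in label order (label value == list index).
labels = ["T-Shirt", "Trouser", "Pullover", "Dress", "Coat",
          "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]

label_mapping = dict(zip(labels, range(10)))

# Fashion MNIST is balanced: 6000 training and 1000 test images per class,
# so after sorting, class k occupies one contiguous block in each split.
TRAIN_PER_CLASS = 6000
TEST_PER_CLASS = 1000

def get_data(mapping, classes):
    X_train, X_test, y_train, y_test = [], [], [], []
    for cls in classes:
        idx = mapping[cls]
        start = idx * TRAIN_PER_CLASS
        end = start + TRAIN_PER_CLASS
        X_train.append(train_images[start:end])
        y_train.append(train_labels[start:end])
        start = idx * TEST_PER_CLASS
        end = start + TEST_PER_CLASS
        X_test.append(test_images[start:end])
        y_test.append(test_labels[start:end])
    # Join the per-class blocks into single arrays (still grouped by class,
    # so shuffle before training if your training loop does not).
    return (np.concatenate(X_train), np.concatenate(X_test),
            np.concatenate(y_train), np.concatenate(y_test))


X_train, X_test, y_train, y_test = get_data(label_mapping,
                                            classes=["T-Shirt", "Shirt", "Trouser"])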
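The selected labels keep their original values 0, 1 and 6. If you train a 3-class model (for example a 3-unit output layer with sparse categorical cross-entropy), the labels are expected to be 0, 1 and 2, so they may need remapping. A minimal sketch using a lookup table, assuming the labels are a 1-D NumPy array of kept classes; the helper name `remap_labels` is hypothetical:

```python
import numpy as np

def remap_labels(labels, kept):
    """Map the original label values in `kept` to 0..len(kept)-1, in the given order."""
    lookup = np.full(10, -1, dtype=np.int64)  # Fashion MNIST labels are 0..9
    lookup[np.asarray(kept)] = np.arange(len(kept))
    new = lookup[labels]
    if (new < 0).any():
        raise ValueError("labels contain a class that is not in `kept`")
    return new

y = np.array([0, 6, 1, 6, 0], dtype=np.uint8)
print(remap_labels(y, [0, 1, 6]))  # [0 2 1 2 0]
```

The order of `kept` defines the new class indices, so keep that list around to translate model predictions back to the original Fashion MNIST classes.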

You can find the mapping between the classes and their labels here.
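For reference, the official Fashion MNIST label order is 0 T-shirt/top, 1 Trouser, 2 Pullover, 3 Dress, 4 Coat, 5 Sandal, 6 Shirt, 7 Sneaker, 8 Bag, 9 Ankle boot. A small sketch that turns class names into the integer labels used in the arrays; the names `FASHION_MNIST_CLASSES` and `class_ids` are hypothetical:

```python
# Official Fashion MNIST class order: the index is the label value.
FASHION_MNIST_CLASSES = (
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
)

def class_ids(names):
    """Return the integer label for each class name (KeyError on unknown names)."""
    lookup = {name: i for i, name in enumerate(FASHION_MNIST_CLASSES)}
    return [lookup[name] for name in names]

print(class_ids(["T-shirt/top", "Shirt", "Trouser"]))  # [0, 6, 1]
```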