How to remove certain classes (labels and images) from existing TensorFlow datasets? (Fashion MNIST)
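I'm just starting to learn more about TensorFlow and NumPy. I'm currently working with TensorFlow's Fashion MNIST dataset, which contains 10 types of clothing. However, I'd like to edit the NumPy arrays that hold this dataset to remove all images and labels that are not 'T-shirts', 'Shirts', or 'Trousers'. Essentially, I want to create a dataset from Fashion MNIST that contains only these 3 types.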
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.fashion_mnist.load_data()
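The above is how I currently import the dataset; as far as I know, there are a few different ways to import it before preprocessing. How can I make sure the labels and their corresponding images are removed correctly, so that the resulting labels and images still match each other?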
import numpy as np
from tensorflow.keras.datasets import fashion_mnist

(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

# Sort images and labels by label so that each class forms one contiguous block.
# The same index array is applied to both, so image i still matches label i.
idx = np.argsort(train_labels)
train_images = train_images[idx]
train_labels = train_labels[idx]
idx = np.argsort(test_labels)
test_images = test_images[idx]
test_labels = test_labels[idx]

# Class names in label order (label 0 is "T-Shirt", label 9 is "Ankle boot")
labels = ["T-Shirt", "Trouser", "Pullover", "Dress", "Coat",
          "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]
label_mapping = dict(zip(labels, range(10)))

def get_data(mapping, classes):
    X_train, X_test, y_train, y_test = [], [], [], []
    for cls in classes:
        idx = mapping[cls]
        # Fashion MNIST is balanced: 6000 training images per class
        start = idx * 6000
        end = start + 6000
        X_train.append(train_images[start:end])
        y_train.append(train_labels[start:end])
        # ... and 1000 test images per class
        start = idx * 1000
        end = start + 1000
        X_test.append(test_images[start:end])
        y_test.append(test_labels[start:end])
    # Join the per-class blocks into single arrays instead of returning lists
    return (np.concatenate(X_train), np.concatenate(X_test),
            np.concatenate(y_train), np.concatenate(y_test))

X_train, X_test, y_train, y_test = get_data(label_mapping,
                                            classes=["T-Shirt", "Shirt", "Trouser"])
You can find the mapping between the classes and their labels here
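As an alternative that does not depend on every class having exactly 6000 training and 1000 test images, the same filtering can probably be done with a boolean mask. The sketch below is not part of the original answer: `filter_classes`, `KEEP`, and the `demo_*` arrays are hypothetical names, and the demo uses a tiny synthetic array rather than the real dataset. The label ids 0, 1 and 6 follow the class order listed above.

```python
import numpy as np

# Fashion MNIST label ids: 0 = T-Shirt, 1 = Trouser, 6 = Shirt
KEEP = (0, 1, 6)

def filter_classes(images, labels, keep=KEEP):
    """Return only the samples whose label is in `keep` (hypothetical helper)."""
    mask = np.isin(labels, keep)  # True where the label is one we want
    # Indexing both arrays with the same mask keeps image i paired with label i.
    return images[mask], labels[mask]

# Tiny synthetic stand-in for the real arrays (made-up data, not Fashion MNIST)
demo_labels = np.array([0, 3, 6, 1, 9, 6], dtype=np.uint8)
demo_images = np.arange(6 * 2 * 2).reshape(6, 2, 2)

demo_x, demo_y = filter_classes(demo_images, demo_labels)
print(demo_y)        # [0 6 1 6]
print(demo_x.shape)  # (4, 2, 2)
```

With the real data this would be called once per split, for example `filter_classes(train_images, train_labels)` and `filter_classes(test_images, test_labels)`. The original sample order is preserved, and the surviving labels are still 0, 1 and 6, so they may need remapping to 0..2 before training a 3-class model.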