tf.data 使用标签谓词过滤数据集
tf.data filter dataset using label predicate
我正在尝试使用下面给出的特定标签过滤 CIFAR10 训练和测试数据,
import tensorflow as tf
from tensorflow.keras import datasets, layers, models
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
import numpy as np
数据集
dataset = datasets.cifar10.load_data()
拆分数据集
train_data = tf.data.Dataset.from_tensor_slices((dataset[0][0],dataset[0][1]))
test_data = tf.data.Dataset.from_tensor_slices((dataset[1][0],dataset[1][1]))
过滤功能
def filter_f(datas,filter_labels = tf.constant([0,1,2])):
x = tf.not_equal(datas[1],filter_labels)
x = tf.reduce_sum(tf.cast(x, tf.uint8))
return tf.greater(x, tf.constant(0,tf.uint8))
dataset = train_data.filter(filter_f).batch(200)
根据 。但是,上面代码中的过滤函数returns是未过滤的。
labels = []
for i, x in enumerate(tfds.as_numpy(dataset)):
labels.append(x[1][0][0])
print(labels)
Returns
[4, 7, 5, 6, 0, 5, 5, 6, 5, 3, 6, 7, 0, 0, 6, 3]
要重现结果,请使用此 colab link
我不确定下面的确切问题。不过,如果您只需要删除属于特定 class 的数据,您可以使用以下命令。
dataset = train_data.filter(lambda x,y: tf.reduce_all(tf.not_equal(y, [0,1,2]))).batch(200)
我正在尝试使用下面给出的特定标签过滤 CIFAR10 训练和测试数据,
import tensorflow as tf
from tensorflow.keras import datasets, layers, models
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
import numpy as np
数据集
dataset = datasets.cifar10.load_data()
拆分数据集
train_data = tf.data.Dataset.from_tensor_slices((dataset[0][0],dataset[0][1]))
test_data = tf.data.Dataset.from_tensor_slices((dataset[1][0],dataset[1][1]))
过滤功能
def filter_f(datas,filter_labels = tf.constant([0,1,2])):
x = tf.not_equal(datas[1],filter_labels)
x = tf.reduce_sum(tf.cast(x, tf.uint8))
return tf.greater(x, tf.constant(0,tf.uint8))
dataset = train_data.filter(filter_f).batch(200)
根据
labels = []
for i, x in enumerate(tfds.as_numpy(dataset)):
labels.append(x[1][0][0])
print(labels)
Returns
[4, 7, 5, 6, 0, 5, 5, 6, 5, 3, 6, 7, 0, 0, 6, 3]
要重现结果,请使用此 colab link
我不确定下面的确切问题。不过,如果您只需要删除属于特定 class 的数据,您可以使用以下命令。
dataset = train_data.filter(lambda x,y: tf.reduce_all(tf.not_equal(y, [0,1,2]))).batch(200)