读取tensorflow时如何过滤数据集?

how to filter dataset when reading in tensorflow?

ds_train = tf.data.experimental.make_csv_dataset(
    file_pattern = "./df_profile_seq_fill_csv/*.csv",
    batch_size=batch_size, column_names=use_cols, label_name='label',
    select_columns= select_cols,
    num_parallel_reads=30, 
    shuffle_buffer_size=10000)

我从csv中读取数据,其中label列是整数的标签,比如0,1,2 ...

model.fit( ds_train, validation_data=ds_test, steps_per_epoch=10000,
     verbose=1,
    epochs=1000000
)

我想过滤掉 label == 0 的所有样本,包括 ds_train 和 ds_test。 有什么方法可以实现这一点?谢谢

一种方法是首先使用 batch 1 从 csv 创建数据集(batch 是必需的参数)。然后过滤“批次”,这是示例,然后再次重新批次:

class_number_to_get_rid_of = 0
TRAIN_DATA_URL = "https://storage.googleapis.com/tf-datasets/titanic/train.csv"

train_file_path = tf.keras.utils.get_file("train.csv", TRAIN_DATA_URL)
dataset = tf.data.experimental.make_csv_dataset(train_file_path, batch_size=1)
dataset_filtered = dataset.filter(lambda p: tf.reduce_all(tf.not_equal(p['survived'], [class_number_to_get_rid_of])))
dataset_filtered = dataset_filtered.batch(5)