How to filter a dataset when reading in TensorFlow?
ds_train = tf.data.experimental.make_csv_dataset(
    file_pattern="./df_profile_seq_fill_csv/*.csv",
    batch_size=batch_size, column_names=use_cols, label_name='label',
    select_columns=select_cols,
    num_parallel_reads=30,
    shuffle_buffer_size=10000)
I read the data from CSV files, where the label column holds integer class labels such as 0, 1, 2, ...
model.fit(ds_train, validation_data=ds_test, steps_per_epoch=10000,
          verbose=1,
          epochs=1000000)
I want to filter out all samples with label == 0, in both ds_train and ds_test. Is there a way to do this? Thanks.
One approach is to first create the dataset from the CSV with batch_size=1 (batch_size is a required argument), filter those single-sample "batches", and then re-batch:
import tensorflow as tf

class_number_to_get_rid_of = 0
TRAIN_DATA_URL = "https://storage.googleapis.com/tf-datasets/titanic/train.csv"
train_file_path = tf.keras.utils.get_file("train.csv", TRAIN_DATA_URL)

# batch_size=1 so the filter predicate sees one sample at a time
dataset = tf.data.experimental.make_csv_dataset(train_file_path, batch_size=1)
dataset_filtered = dataset.filter(
    lambda p: tf.reduce_all(tf.not_equal(p['survived'], [class_number_to_get_rid_of])))
# drop the size-1 batch dimension before re-batching, so tensors have shape [5], not [5, 1]
dataset_filtered = dataset_filtered.unbatch().batch(5)
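Note that when make_csv_dataset is given label_name (as in the question's ds_train), each element is a (features, label) pair rather than a single dict, so the filter predicate takes two arguments. A minimal sketch of the same idea with an in-memory stand-in dataset (the features/labels values below are illustrative, not from the original pipeline):

```python
import tensorflow as tf

# Stand-in for make_csv_dataset(..., label_name='label'):
# a dataset that yields (features, label) pairs, one sample per element.
features = {"x": tf.constant([1.0, 2.0, 3.0, 4.0])}
labels = tf.constant([0, 1, 2, 0])
ds = tf.data.Dataset.from_tensor_slices((features, labels))

# Drop every sample whose label is 0, then batch as usual.
ds_filtered = ds.filter(lambda f, l: tf.not_equal(l, 0)).batch(2)

remaining = [int(l) for _, ls in ds_filtered for l in ls]
print(remaining)  # only the non-zero labels survive
```

The same two-argument predicate can be applied to both ds_train and ds_test before model.fit.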