从数据集 Tensorflow 中删除坏数据
Drop bad data from dataset Tensorflow
我有一个使用 tf.data 的训练管道。在数据集中有一些坏元素,在我的例子中是 0 值。我如何根据它们的值删除这些坏数据元素?由于数据集很大,我希望能够在训练时从管道中删除它们。
根据以下伪代码假设:
def parse_function(element):
height = element['height']
if height <= 0: skip() #How to skip this value
labels = element['label']
features['height'] = height
return features, labels
ds = tf.data.Dataset.from_tensor_slices(ds_files)
clean_ds = ds.map(parse_function)
建议是根据特征值使用 ds.skip(1),或者提供某种中性 weight/loss?
您可以使用 tf.data.Dataset.filter
:
def filter_func(elem):
""" return True if the element is to be kept """
return tf.math.greater(elem['height'],0)
ds = tf.data.Dataset.from_tensor_slices(ds_files)
clean_ds = ds.filter(filter_func)
假设 element
是您代码中的数据框,那么它将是:
def parse_function(element):
element = element.query('height>0')
labels = element['label']
features['height'] = element['height']
return features, labels
ds = tf.data.Dataset.from_tensor_slices(ds_files)
clean_ds = ds.map(parse_function)
`
我有一个使用 tf.data 的训练管道。在数据集中有一些坏元素,在我的例子中是 0 值。我如何根据它们的值删除这些坏数据元素?由于数据集很大,我希望能够在训练时从管道中删除它们。
根据以下伪代码假设:
def parse_function(element):
height = element['height']
if height <= 0: skip() #How to skip this value
labels = element['label']
features['height'] = height
return features, labels
ds = tf.data.Dataset.from_tensor_slices(ds_files)
clean_ds = ds.map(parse_function)
建议是根据特征值使用 ds.skip(1),或者提供某种中性 weight/loss?
您可以使用 tf.data.Dataset.filter
:
def filter_func(elem):
""" return True if the element is to be kept """
return tf.math.greater(elem['height'],0)
ds = tf.data.Dataset.from_tensor_slices(ds_files)
clean_ds = ds.filter(filter_func)
假设 element
是您代码中的数据框,那么它将是:
def parse_function(element):
element = element.query('height>0')
labels = element['label']
features['height'] = element['height']
return features, labels
ds = tf.data.Dataset.from_tensor_slices(ds_files)
clean_ds = ds.map(parse_function)
`