如何按特定值筛选 tf.data.Dataset?
How can I filter tf.data.Dataset by specific values?
我通过读取 TFRecords 创建了一个数据集,我映射了值,我想过滤数据集的特定值,但由于结果是一个带有张量的字典,我无法获得一个的实际值张量或用 tf.cond()
/ tf.equal
检查它。我该怎么做?
def mapping_func(serialized_example):
feature = { 'label': tf.FixedLenFeature([1], tf.string) }
features = tf.parse_single_example(serialized_example, features=feature)
return features
def filter_func(features):
# this doesn't work
#result = features['label'] == 'some_label_value'
# neither this
result = tf.reshape(tf.equal(features['label'], 'some_label_value'), [])
return result
def main():
file_names = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.contrib.data.TFRecordDataset(file_names)
dataset = dataset.map(mapping_func)
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.filter(filter_func)
dataset = dataset.repeat()
iterator = dataset.make_one_shot_iterator()
sample = iterator.get_next()
您应该尝试使用来自
tf.data.TFRecordDatasettensorflow documentation
否则...阅读这篇关于 TFRecords 的文章以更好地了解 TFRecords TFRecords for humans
但最有可能的情况是您既无法访问也无法修改 TFRecord...github 上有关于此主题的请求 TFRecords request
我的建议是让事情尽可能简单......你必须知道你正在使用图形和会话......
在任何情况下...如果一切都失败了,请尽可能简单地尝试在 tensorflow 会话中不起作用的代码部分...可能所有这些操作都应该在 tf.session 是 运行...
我正在回答我自己的问题。我找到问题了!
我需要做的是tf.unstack()
这样的标签:
label = tf.unstack(features['label'])
label = label[0]
在我把它交给tf.equal()
之前:
result = tf.reshape(tf.equal(label, 'some_label_value'), [])
我想问题是标签被定义为一个数组,其中包含一个字符串类型的元素 tf.FixedLenFeature([1], tf.string)
,因此为了获得第一个元素和单个元素,我必须将其解包(这会创建一个列表) 然后获取索引为0的元素,如有错误请指正
我认为您首先不需要将标签设为一维数组。
与:
feature = {'label': tf.FixedLenFeature((), tf.string)}
您不需要拆开 filter_func
中的标签
阅读、过滤数据集非常容易,无需拆散任何东西。
读取数据集:
print(my_dataset, '\n\n')
##let us print the first 3 records
for record in my_dataset.take(3):
##below could be large in case of image
print(record)
##let us print a specific key
print(record['key2'])
过滤同样简单:
my_filtereddataset = my_dataset.filter(_filtcond1)
您可以根据需要在其中定义 _filtcond1。假设您的数据集中有一个 'true' 'false' 布尔标志,那么:
@tf.function
def _filtcond1(x):
return x['key_bool'] == 1
甚至是 lambda 函数:
my_filtereddataset = my_dataset.filter(lambda x: x['key_int']>13)
如果你正在阅读一个你还没有创建的数据集或者你不知道键(似乎是 OPs 的情况),你可以首先使用它来了解键和结构:
import json
from google.protobuf.json_format import MessageToJson
for raw_record in noidea_dataset.take(1):
example = tf.train.Example()
example.ParseFromString(raw_record.numpy())
##print(example) ##if image it will be toooolong
m = json.loads(MessageToJson(example))
print(m['features']['feature'].keys())
现在您可以继续过滤了
我通过读取 TFRecords 创建了一个数据集,我映射了值,我想过滤数据集的特定值,但由于结果是一个带有张量的字典,我无法获得一个的实际值张量或用 tf.cond()
/ tf.equal
检查它。我该怎么做?
def mapping_func(serialized_example):
feature = { 'label': tf.FixedLenFeature([1], tf.string) }
features = tf.parse_single_example(serialized_example, features=feature)
return features
def filter_func(features):
# this doesn't work
#result = features['label'] == 'some_label_value'
# neither this
result = tf.reshape(tf.equal(features['label'], 'some_label_value'), [])
return result
def main():
file_names = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.contrib.data.TFRecordDataset(file_names)
dataset = dataset.map(mapping_func)
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.filter(filter_func)
dataset = dataset.repeat()
iterator = dataset.make_one_shot_iterator()
sample = iterator.get_next()
您应该尝试使用来自 tf.data.TFRecordDatasettensorflow documentation
否则...阅读这篇关于 TFRecords 的文章以更好地了解 TFRecords TFRecords for humans
但最有可能的情况是您既无法访问也无法修改 TFRecord...github 上有关于此主题的请求 TFRecords request
我的建议是让事情尽可能简单......你必须知道你正在使用图形和会话......
在任何情况下...如果一切都失败了,请尽可能简单地尝试在 tensorflow 会话中不起作用的代码部分...可能所有这些操作都应该在 tf.session 是 运行...
我正在回答我自己的问题。我找到问题了!
我需要做的是tf.unstack()
这样的标签:
label = tf.unstack(features['label'])
label = label[0]
在我把它交给tf.equal()
之前:
result = tf.reshape(tf.equal(label, 'some_label_value'), [])
我想问题是标签被定义为一个数组,其中包含一个字符串类型的元素 tf.FixedLenFeature([1], tf.string)
,因此为了获得第一个元素和单个元素,我必须将其解包(这会创建一个列表) 然后获取索引为0的元素,如有错误请指正
我认为您首先不需要将标签设为一维数组。
与:
feature = {'label': tf.FixedLenFeature((), tf.string)}
您不需要拆开 filter_func
中的标签阅读、过滤数据集非常容易,无需拆散任何东西。
读取数据集:
print(my_dataset, '\n\n')
##let us print the first 3 records
for record in my_dataset.take(3):
##below could be large in case of image
print(record)
##let us print a specific key
print(record['key2'])
过滤同样简单:
my_filtereddataset = my_dataset.filter(_filtcond1)
您可以根据需要在其中定义 _filtcond1。假设您的数据集中有一个 'true' 'false' 布尔标志,那么:
@tf.function
def _filtcond1(x):
return x['key_bool'] == 1
甚至是 lambda 函数:
my_filtereddataset = my_dataset.filter(lambda x: x['key_int']>13)
如果你正在阅读一个你还没有创建的数据集或者你不知道键(似乎是 OPs 的情况),你可以首先使用它来了解键和结构:
import json
from google.protobuf.json_format import MessageToJson
for raw_record in noidea_dataset.take(1):
example = tf.train.Example()
example.ParseFromString(raw_record.numpy())
##print(example) ##if image it will be toooolong
m = json.loads(MessageToJson(example))
print(m['features']['feature'].keys())
现在您可以继续过滤了