如何解码 Tensorflow 2 中的示例(从 1.12 移植)
How to decode examples in Tensorflow 2 (porting from 1.12)
我有以下方法可以解码序列化 TFRecordDataset
中的样本:
def decode_example(self, serialized_example):
"""Return a dict of Tensors from a serialized tensorflow.Example."""
data_fields, data_items_to_decoders = self.example_reading_spec()
# Necessary to rejoin examples in the correct order with the Cloud ML Engine
# batch prediction API.
data_fields['batch_prediction_key'] = tf.io.FixedLenFeature([1], tf.int64, 0)
if data_items_to_decoders is None:
data_items_to_decoders = {
field: tf.contrib.slim.tfexample_decoder.Tensor(field)
for field in data_fields
}
decoder = tf.contrib.slim.tfexample_decoder.TFExampleDecoder(data_fields, data_items_to_decoders)
decode_items = list(sorted(data_items_to_decoders))
decoded = decoder.decode(serialized_example, items=decode_items)
return dict(zip(decode_items, decoded))
但是,这在 Tensorflow 2 下不起作用。
tf.contrib
不存在了,我找不到任何可以用来解码这些例子的东西。
安装 tensorflow-data-validation
后我什至找不到 TFExampleDecoder
。
知道那里出了什么问题吗and/or我如何解码我的示例?
我能够使用 tf.io.parse_single_example
使其工作。
我们必须像往常一样声明我们的数据字段 (example_reading_spec
),然后我们可以使用它来解码示例:
def example_reading_spec():
data_fields = {
'inputs': tf.io.VarLenFeature(tf.float32),
'targets': tf.io.VarLenFeature(tf.int64),
}
return data_fields
def decode_example(serialized_example):
"""Return a dict of Tensors from a serialized tensorflow.Example."""
return tf.io.parse_single_example(
serialized_example,
features=example_reading_spec()
)
现在我们可以使用 Dataset.map
像这样加载我们的数据集分片:
record_dataset = tf.data.TFRecordDataset(filenames, buffer_size=1024)
record_dataset = record_dataset.map(decode_example)
我有以下方法可以解码序列化 TFRecordDataset
中的样本:
def decode_example(self, serialized_example):
"""Return a dict of Tensors from a serialized tensorflow.Example."""
data_fields, data_items_to_decoders = self.example_reading_spec()
# Necessary to rejoin examples in the correct order with the Cloud ML Engine
# batch prediction API.
data_fields['batch_prediction_key'] = tf.io.FixedLenFeature([1], tf.int64, 0)
if data_items_to_decoders is None:
data_items_to_decoders = {
field: tf.contrib.slim.tfexample_decoder.Tensor(field)
for field in data_fields
}
decoder = tf.contrib.slim.tfexample_decoder.TFExampleDecoder(data_fields, data_items_to_decoders)
decode_items = list(sorted(data_items_to_decoders))
decoded = decoder.decode(serialized_example, items=decode_items)
return dict(zip(decode_items, decoded))
但是,这在 Tensorflow 2 下不起作用。
tf.contrib
不存在了,我找不到任何可以用来解码这些例子的东西。
安装 tensorflow-data-validation
后我什至找不到 TFExampleDecoder
。
知道那里出了什么问题吗and/or我如何解码我的示例?
我能够使用 tf.io.parse_single_example
使其工作。
我们必须像往常一样声明我们的数据字段 (example_reading_spec
),然后我们可以使用它来解码示例:
def example_reading_spec():
data_fields = {
'inputs': tf.io.VarLenFeature(tf.float32),
'targets': tf.io.VarLenFeature(tf.int64),
}
return data_fields
def decode_example(serialized_example):
"""Return a dict of Tensors from a serialized tensorflow.Example."""
return tf.io.parse_single_example(
serialized_example,
features=example_reading_spec()
)
现在我们可以使用 Dataset.map
像这样加载我们的数据集分片:
record_dataset = tf.data.TFRecordDataset(filenames, buffer_size=1024)
record_dataset = record_dataset.map(decode_example)