`tf.data.Dataset`如何动态传递`tf.io.FixedLenFeature`的大小

`tf.data.Dataset` how to dynamically pass the size of `tf.io.FixedLenFeature`

我们有 tfrecord 文件,其中每个 tfrecord 文件包含一个示例,但其中的特征包含一个值列表。我们按以下方式使用 tf.data.Dataset

n_rows_per_record_file = 100

def parse_tfrecord_to_example(record_bytes):
    col_map = {
    "my_col": tf.io.FixedLenFeature(
        shape=n_rows_per_record_file, dtype=tf.int64
    )}

ds = (
    tf.data.TFRecordDataset(file_paths)
    .map(parse_tfrecord_to_example)
    )

我们不想为 n_rows_per_record_file 使用固定常量,而是希望在给定文件路径的情况下查找行数。

关于如何实现这一点有什么想法吗?

我们试过使用这样的东西

def get_shape(filepath):
    return filepath, shapes[filepath]
ds = (
    tf.data.list_files(file_paths)
    .map(get_shape)
    .map(
        lambda f, shape: tf.data.TFRecordDataset(f).map(
           lambda shape: parse_tfrecord_to_example(shape)
       )
    )

但这失败了,因为 tf.data 直到需要时才急切地评估文件路径(即它仍然是 tf.Tensor)

您提出的解决方案看起来不错,是的,您的文件路径将是一个张量,但您不能在您的情况下使用某些外部 python 对象,例如 shapes。不幸的是,如果你使用 tf.Data,你需要学习很多 Tensorflow 特定的函数来做 'basic' python 事情。例如,在您的情况下,您可能想要 split the file name and then cast 一个字符串到 int。所以是的,一切都是张量。

在您的评论中,您还提到了广播。 tf.Data 不适用于 broadcasting. tf.Data is for fast loading of data in memory record by record. So, whenever you think to apply vectorisation or broadcasting you should use something else. First option, prepare your data before you save it in TFRecords using whatever tool you want: pandas, dask, spark, etc. Second option, enrich your data on the fly with one of tf lookup 实施。例如,如果您有一个包含形状的字典,并且您希望将此功能添加到基于某个类别或 ID 的每条记录,请将该数据加载到 StaticHashTable 并添加一个预处理查找步骤。注意:这个用于丰富的数据必须非常小,因为你必须在内存中,如果你使用 GPU,甚至可能是 GPU 内存。

所以这是一个带有查找的示例 table:

dataset = tf.data.Dataset.from_tensor_slices([range(10)])
keys_tensor = tf.constant(range(10))
vals_tensor = tf.constant(range(100, 110))
lookup = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(keys_tensor, vals_tensor), default_value=-1)

def map_numbers(v):
    return lookup[v]

for element in dataset.map(map_numbers):
    print(element)

tf.Tensor([100 101 102 103 104 105 106 107 108 109], shape=(10,), dtype=int32)