如何从 tensorflow 数据集中解压数据?

How to unpack the data from tensorflow dataset?

这是我从 tfrecord 加载数据的代码:

def read_tfrecord(tfrecord, epochs, batch_size):

    dataset = tf.data.TFRecordDataset(tfrecord)

    def parse(record):
        features = {
            "image": tf.io.FixedLenFeature([], tf.string),
            "target": tf.io.FixedLenFeature([], tf.int64)
        }
        example = tf.io.parse_single_example(record, features)
        image = decode_image(example["image"])
        label = tf.cast(example["target"], tf.int32)
        return image, label

    dataset = dataset.map(parse)
    dataset = dataset.shuffle(buffer_size=10000)        
    dataset = dataset.prefetch(buffer_size=batch_size)  #
    dataset = dataset.batch(batch_size, drop_remainder=True)
    dataset = dataset.repeat(epochs)

    return dataset


x_train, y_train = read_tfrecord(tfrecord=train_files, epochs=EPOCHS, batch_size=BATCH_SIZE)

我收到以下错误:

ValueError: too many values to unpack (expected 2)

我的问题是:

如何从数据集中解压数据?

dataset = read_tfrecord(tfrecord=train_files, epochs=EPOCHS, batch_size=BATCH_SIZE)
# the above should return an iterator
for x, y in dataset:
    print(x)
    print(y)
    # now the unpacking parsing happens

您可以试试这个解决方案:

dataset = read_tfrecord(tfrecord=train_files, epochs=EPOCHS, batch_size=BATCH_SIZE)

iterator = iter(dataset)

x, y = next(iterator)

TensorFlow 的 get_single_element() 终于 around 可以用来解压数据集。

这避免了使用 .map()iter() 生成和使用迭代器的需要(这对于大数据集来说可能代价高昂)。

get_single_element() returns 封装数据集所有成员的张量(或张量的元组或字典)。我们需要将批处理的数据集的所有成员传递到一个元素中。

这可用于获取特征作为张量数组,或特征和标签作为元组或字典(张量数组),具体取决于原始数据集的方式已创建。

在 SO 上查看此 以获取将特征和标签解包到张量数组元组中的示例。