使用 [image,label] = dataset.take(2) returns 两个元组而不是一个元组

Using [image,label] = dataset.take(2) returns two tuples instead of a single one

我有一个 TFRecord 文件,我在其中存储 图像包装字节作为字符串 标签作为整数 64。我正在使用下面的代码来处理图像和标签:

# Create dataset from TFRecord file 
records_path = DATA_DIR + 'TFRecords/train_0.tfrecords'
dataset = tf.data.TFRecordDataset(filenames=records_path)

# Map dataset from parsing function
parsed_dataset = dataset.map(parsing_fn)
print(parsed_dataset)

# Take a testing sample
image,label = parsed_dataset.take(2)
print(image,label)

输出:

Tensor("ParseSingleExample/ParseSingleExample:1", shape=(), dtype=int64)
<MapDataset shapes: ((None,), ()), types: (tf.float32, tf.int64)>

((<tf.Tensor: id=635, shape=(185256,), dtype=float32, numpy=array([162., 162., 170., ...,  17.,  17., 255.], dtype=float32)>,
  <tf.Tensor: id=636, shape=(), dtype=int64, numpy=183350>),
 (<tf.Tensor: id=637, shape=(153120,), dtype=float32, numpy=array([208., 207., 202., ..., 240., 240., 242.], dtype=float32)>,
  <tf.Tensor: id=638, shape=(), dtype=int64, numpy=183350>))

这意味着 imagelabel 是包含张量 的 元组,每个张量对应于图像和标签 两张不同的图像,而不是每个图像和标签数据分别来自同一图像.

image[0] = image bytes from image 1

image[1] = label 来自 image 1

的信息

标签[0] = 图像字节 来自图像2

label[1] = label 来自 image 2

的信息

有谁知道为什么要使用“image = take(1)”returns a TakeDataset,而不是只包含一个数据样本的元组和对应的张量图像字节和标签数据?

帮手 fn 的

# Data stored format
data = {
    'image': wrap_bytes(img_bytes),
    'label': wrap_int64(label)
}

# Parsing function
def parsing_fn(serialized):

    # Define a dict with the data-names and types we expect to
    # find in the TFRecords file.

    features = \
        {
            'image': tf.io.FixedLenFeature([], tf.string),
            'label': tf.io.FixedLenFeature([], tf.int64)
        }

    # Parse the serialized data so we get a dict with our data.
    parsed_example = tf.io.parse_single_example(serialized=serialized,
                                             features=features)

    # Get the image as raw bytes.
    image_raw = parsed_example['image']

    # Decode the raw bytes so it becomes a tensor with type.
    image = tf.io.decode_raw(image_raw, tf.uint8)
    
    # The type is now uint8 but we need it to be float.
    image = tf.cast(image, tf.float32)

    # Get the label associated with the image.
    label = parsed_example['label']
    # The image and label are now correct TensorFlow types.
    return image, label

方法以创建数据集为例。它不会从数据集中提取元素。 See documentation.

如果你想从数据中提取一个元素,你可以使用: tf.compat.v1.data.make_one_shot_iterator() 我没有找到更简洁的方法来从数据集中提取元素。

示例:

iterator = tf.compat.v1.data.make_one_shot_iterator(parsed_dataset)
image, label = iterator.get_next()