如何在 Tensorflow 中设置 ParallelMapDataset 数据类型中的图像数量?

How to set the number of images in ParallelMapDataset datatype in Tensorflow?

dataset, info = tfds.load('oxford_iiit_pet:3.*.*', with_info=True)

train_images = dataset['train']

test_images = dataset['test']

train_batches = ( 
    train_images
    .cache()
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
    .prefetch(buffer_size=tf.data.AUTOTUNE))

test_batches = test_images.batch(BATCH_SIZE)

现在我想将 test_images 大小减少到 100 张图片。 我期待一些像这样的代码:

test_images = test_images[100]

但这会产生错误:

'ParallelMapDataset' object is not subscriptable

使用 take() 方法,您可以从目标数据集中获取批次或项目。

如果数据集是批处理的:

test_images.take((100 // BATCH_SIZE) + 1)

当您对数据集进行批处理时,它将包含批次或组。

比方说,您使用大小 32 对数据进行批处理,test_images.take(1) 将 return 32 个元素,换句话说,一个批处理。 test_images.take(2) 将 return 64 个元素等


如果不批量:

test_images.take(100)

与批处理数据集不同,数据集将 return 已传递到 take() 方法的元素数量。