从 TF 数据集中获取样本

Get samples rom TF dataset

我有一个 TF 数据集

train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

而且由于我的数据量很大 (100,000),我只想 select 训练数据的一个子集 所以我需要从旧数据集创建一个新的 TF 数据集

你可以使用 tf.data.Dataset.shard.

Creates a Dataset that includes only 1/num_shards of this dataset.

shard is deterministic. The Dataset produced by A.shard(n, i) will contain all elements of A whose index mod n = i.

A = tf.data.Dataset.range(10)
B = A.shard(num_shards=3, index=0)

如果您想要原始 train_ds 的 1/10: new_ds = train_ds.shard(num_shards=10, index=0)