How can I split the output of a tf.data.Dataset?

I have some .npy data of shape [150~180, 480~512, 480~512], extracted from MRI images.

I used some functions to refine the data and convert it to a tf.data.Dataset:

train_dataset = tf.data.Dataset.from_tensor_slices((list_image,list_label))
train_dataset = train_dataset.shuffle(NUM_TRN)
train_dataset = train_dataset.batch(NUM_BATCH_NPY)
train_dataset = train_dataset.map(
    lambda x,y: tf.py_function(load_dataset, inp=[x,y], Tout=[tf.float16, tf.float16]))


Description:
1) list_image and list_label are lists of .npy file names
---- [000_image.npy, ..., 099_image.npy], [000_label.npy, ..., 099_label.npy]

2) NUM_TRN is the total number of image/label files; it is used as the shuffle buffer size so that the whole dataset is shuffled
---- 100 (the number of *_image.npy files)

3) NUM_BATCH_NPY is the number of .npy files that are loaded together
---- If NUM_BATCH_NPY is 3, three image/label pairs are loaded
---- [000_image/label.npy], [001_image/label.npy], [002_image/label.npy]

4) The function 'load_dataset' loads the arrays from the .npy files,
   refines them, and stacks them along axis 0.
---- 000_image.npy->(170,360,360,1), 000_label.npy->(170,)
---- 001_image.npy->(150,360,360,1), 001_label.npy->(150,)
---- 002_image.npy->(163,360,360,1), 002_label.npy->(163,)
---- the output shape of one dataset element will be ((483,360,360,1),(483,))
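To make the shapes in item 4 concrete, here is a minimal NumPy sketch of the stacking step only. The helper name stack_volumes is hypothetical, a stand-in for part of load_dataset and not the actual function. The refinement step is skipped, and the spatial size is shrunk to 8x8 as an assumption to keep memory low.

```python
import numpy as np

def stack_volumes(volumes, labels):
    """Concatenate per-file arrays along axis 0 (hypothetical stand-in for load_dataset)."""
    images = np.concatenate(volumes, axis=0)
    targets = np.concatenate(labels, axis=0)
    return images, targets

# Slice counts taken from the example above; 8x8 spatial size is an assumption.
slice_counts = [170, 150, 163]
volumes = [np.zeros((n, 8, 8, 1), dtype=np.float16) for n in slice_counts]
labels = [np.zeros((n,), dtype=np.float16) for n in slice_counts]

images, targets = stack_volumes(volumes, labels)
print(images.shape, targets.shape)  # (483, 8, 8, 1) (483,)
```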

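As shown above, the arrays are extracted per image. The problem is: how can I split this extracted Dataset object into slices of NUM_TRAIN_BATCH=128 samples?

---- extracted dataset (483,360,360,1) -> (128,:), (128,:), ...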

I tried to extract the batches and create new datasets from them:

datasets = []
train_dataset = train_dataset.batch(NUM_TRAIN_BATCH)
for batch in train_dataset:  # renamed from 'set' to avoid shadowing the built-in
    sliced = tf.data.Dataset.from_tensor_slices(batch)
    datasets.append(sliced)
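
One direction that might work, offered as a sketch under stated assumptions and not as a confirmed fix: flatten the stacked elements back into single slices with Dataset.unbatch(), then regroup them with batch(NUM_TRAIN_BATCH). The generator fake_stacks below is hypothetical and stands in for the output of load_dataset. The second element size (300) and the 4x4 spatial size are invented for illustration. With tf.py_function the static shapes are unknown, so the real pipeline may need explicit shape information before this works.

```python
import numpy as np
import tensorflow as tf

NUM_TRAIN_BATCH = 128

def fake_stacks():
    """Hypothetical stand-in for load_dataset output: stacks with varying axis-0 length."""
    for n in (483, 300):  # 483 is from the question; 300 is made up
        images = np.zeros((n, 4, 4, 1), dtype=np.float32)
        labels = np.zeros((n,), dtype=np.float32)
        yield images, labels

stacked = tf.data.Dataset.from_generator(
    fake_stacks,
    output_signature=(
        tf.TensorSpec(shape=(None, 4, 4, 1), dtype=tf.float32),
        tf.TensorSpec(shape=(None,), dtype=tf.float32),
    ),
)

# unbatch() emits one (slice, label) pair per element; batch() regroups them.
rebatched = stacked.unbatch().batch(NUM_TRAIN_BATCH)

batch_sizes = [int(images.shape[0]) for images, _ in rebatched]
print(batch_sizes)  # [128, 128, 128, 128, 128, 128, 15]
```

In this sketch the batches cross the boundaries between the original stacks, and the last batch is smaller unless drop_remainder=True is passed to batch().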