使用 tf.data.Dataset 产生多输入数据

Using tf.data.Dataset to produce multi-input data

我有一个 dataset(tf.data.Dataset 的实例),它产生 image 作为输入和 label 作为输出。我的模型需要将 [image, label] 作为输入,将 label 作为输出。那么我该如何实现呢?

我试过这个:

dataset = dataset.map(suit_IO)

def suit_IO(img, label):
  return [img, label], label

但出现此错误:

TypeError: Unsupported return value from function passed to Dataset.map(): ([<tf.Tensor 'args_0:0' shape=(320, 320, 3) dtype=float32>, <tf.Tensor 'args_1:0' shape=() dtype=int32>], <tf.Tensor 'args_1:0' shape=() dtype=int32>).

您需要使用嵌套元组,而不是列表:

def suit_IO(img, label):
  return (img, label), label