使用 Tensorflow 的数据集时必须调用 iter.get_next 吗?

Must I call iter.get_next when using Tensorflow's Dataset?

我已经将代码从基于队列的系统转换为 tensorflow 的数据集。转换后,我发现准确性有所下降,时间也有所增加。我将此归因于我的实施不正确,目前我正在尝试解决可能的问题。现在,通过此转换中的反复试验,我根据我遇到的许多文章和示例做出了一些假设,我只是想确保我当前的实施是正确的,并且我的假设也是正确的。

以前我有大量图像,我会将它们分批放入一个队列中,然后将我的 100 张图像从队列中取出,进行处理和汇总,然后继续。我认为通过队列加载到内存中可能会造成瓶颈,所以当我听说数据集 API 时,我认为它值得一看。所以我现在检索所有图像信息并将其传递到我的方法,然后我通过数据集批处理方法执行批处理。之前和之后如下所示。我读过,没有必要在数据集上调用 iter.get_next,因为操作会自动调用它,但是根据我最后看到的准确性,我对这是否属实犹豫不决.目前如您所见,我只是将 iter.initializer 作为一个操作传递给 sess.run 和我的其他操作,然后传递 feed_dict。任何见解都会有所帮助,因为我对此有些陌生。谢谢!

上一个使用队列时的示例函数: (请注意,我会将图像排队到一个 blob 对象中并将该子集传递给此方法)

def get_summary(self, sess, images, labels, weights, keep_prob = 1.0):
        feed_dict = {self._input_images: images, self._input_labels: labels,
                     self._input_weights: weights, self._is_training: False}
        summary, acc = sess.run([self._summary_op, self._accuracy], feed_dict=feed_dict)

        return summary, acc

当前示例函数使用数据集 API: (现在在调用它之前,我用所有数据填充我的 blob 对象并使用下面的批处理功能——注意我从不调用 iter.get_next())

def get_summary(self, sess, images, labels, weights, keep_prob = 1.0, batch_size=32):
        dataset = tf.data.Dataset.from_tensor_slices((self._input_images, self._input_labels,
                                                      self._input_weights)).repeat().batch(batch_size)

        iter = dataset.make_initializable_iterator()
        feed_dict = {self._input_images: images, self._input_labels: labels,
                     self._input_weights: weights, self._is_training: False}
        _, summary, acc = sess.run([iter.initializer, self._summary_op, self._accuracy], feed_dict=feed_dict)

        return summary, acc

从该代码片段来看,您似乎从未使用过 iter 中的值,因此它应该不会对您的摘要产生影响。例如,您应该能够删除创建迭代器的行,并从传递给 sess.run() 的列表中删除 iter.initializer 并获得相同的结果。

要回答更广泛的问题 "Must I call iter.get_next()?":在基于图形的 TensorFlow 中,tf.data.Iterator 和您传递给 [=12= 的 tensor/operation 之间必须存在数据流连接] 以便使用该迭代器中的值。如果您使用的是 low-level TensorFlow API,最简单的方法是调用 iter.get_next() 来获取一个或多个 tf.Tensor 对象,然后使用这些张量作为模型的输入。

但是,如果您使用的是 高级 tf.estimator API,您的 input_fn 可以 return tf.data.Dataset 而无需创建 tf.data.Iterator(或调用 Iterator.get_next(),Estimator API 将负责创建迭代器并为您调用 get_next()