使用 Tensorflow 的数据集时必须调用 iter.get_next 吗?
Must I call iter.get_next when using Tensorflow's Dataset?
我已经将代码从基于队列的系统转换为 tensorflow 的数据集。转换后,我发现准确性有所下降,时间也有所增加。我将此归因于我的实施不正确,目前我正在尝试解决可能的问题。现在,通过此转换中的反复试验,我根据我遇到的许多文章和示例做出了一些假设,我只是想确保我当前的实施是正确的,并且我的假设也是正确的。
以前我有大量图像,我会将它们分批放入一个队列中,然后将我的 100 张图像从队列中取出,进行处理和汇总,然后继续。我认为通过队列加载到内存中可能会造成瓶颈,所以当我听说数据集 API 时,我认为它值得一看。所以我现在检索所有图像信息并将其传递到我的方法,然后我通过数据集批处理方法执行批处理。之前和之后如下所示。我读过,没有必要在数据集上调用 iter.get_next,因为操作会自动调用它,但是根据我最后看到的准确性,我对这是否属实犹豫不决.目前如您所见,我只是将 iter.initializer 作为一个操作传递给 sess.run 和我的其他操作,然后传递 feed_dict。任何见解都会有所帮助,因为我对此有些陌生。谢谢!
上一个使用队列时的示例函数:
(请注意,我会将图像排队到一个 blob 对象中并将该子集传递给此方法)
def get_summary(self, sess, images, labels, weights, keep_prob = 1.0):
feed_dict = {self._input_images: images, self._input_labels: labels,
self._input_weights: weights, self._is_training: False}
summary, acc = sess.run([self._summary_op, self._accuracy], feed_dict=feed_dict)
return summary, acc
当前示例函数使用数据集 API:
(现在在调用它之前,我用所有数据填充我的 blob 对象并使用下面的批处理功能——注意我从不调用 iter.get_next())
def get_summary(self, sess, images, labels, weights, keep_prob = 1.0, batch_size=32):
dataset = tf.data.Dataset.from_tensor_slices((self._input_images, self._input_labels,
self._input_weights)).repeat().batch(batch_size)
iter = dataset.make_initializable_iterator()
feed_dict = {self._input_images: images, self._input_labels: labels,
self._input_weights: weights, self._is_training: False}
_, summary, acc = sess.run([iter.initializer, self._summary_op, self._accuracy], feed_dict=feed_dict)
return summary, acc
从该代码片段来看,您似乎从未使用过 iter
中的值,因此它应该不会对您的摘要产生影响。例如,您应该能够删除创建迭代器的行,并从传递给 sess.run()
的列表中删除 iter.initializer
并获得相同的结果。
要回答更广泛的问题 "Must I call iter.get_next()
?":在基于图形的 TensorFlow 中,tf.data.Iterator
和您传递给 [=12= 的 tensor/operation 之间必须存在数据流连接] 以便使用该迭代器中的值。如果您使用的是 low-level TensorFlow API,最简单的方法是调用 iter.get_next()
来获取一个或多个 tf.Tensor
对象,然后使用这些张量作为模型的输入。
但是,如果您使用的是 高级 tf.estimator
API,您的 input_fn
可以 return tf.data.Dataset
而无需创建 tf.data.Iterator
(或调用 Iterator.get_next()
,Estimator API 将负责创建迭代器并为您调用 get_next()
。
我已经将代码从基于队列的系统转换为 tensorflow 的数据集。转换后,我发现准确性有所下降,时间也有所增加。我将此归因于我的实施不正确,目前我正在尝试解决可能的问题。现在,通过此转换中的反复试验,我根据我遇到的许多文章和示例做出了一些假设,我只是想确保我当前的实施是正确的,并且我的假设也是正确的。
以前我有大量图像,我会将它们分批放入一个队列中,然后将我的 100 张图像从队列中取出,进行处理和汇总,然后继续。我认为通过队列加载到内存中可能会造成瓶颈,所以当我听说数据集 API 时,我认为它值得一看。所以我现在检索所有图像信息并将其传递到我的方法,然后我通过数据集批处理方法执行批处理。之前和之后如下所示。我读过,没有必要在数据集上调用 iter.get_next,因为操作会自动调用它,但是根据我最后看到的准确性,我对这是否属实犹豫不决.目前如您所见,我只是将 iter.initializer 作为一个操作传递给 sess.run 和我的其他操作,然后传递 feed_dict。任何见解都会有所帮助,因为我对此有些陌生。谢谢!
上一个使用队列时的示例函数: (请注意,我会将图像排队到一个 blob 对象中并将该子集传递给此方法)
def get_summary(self, sess, images, labels, weights, keep_prob = 1.0):
feed_dict = {self._input_images: images, self._input_labels: labels,
self._input_weights: weights, self._is_training: False}
summary, acc = sess.run([self._summary_op, self._accuracy], feed_dict=feed_dict)
return summary, acc
当前示例函数使用数据集 API: (现在在调用它之前,我用所有数据填充我的 blob 对象并使用下面的批处理功能——注意我从不调用 iter.get_next())
def get_summary(self, sess, images, labels, weights, keep_prob = 1.0, batch_size=32):
dataset = tf.data.Dataset.from_tensor_slices((self._input_images, self._input_labels,
self._input_weights)).repeat().batch(batch_size)
iter = dataset.make_initializable_iterator()
feed_dict = {self._input_images: images, self._input_labels: labels,
self._input_weights: weights, self._is_training: False}
_, summary, acc = sess.run([iter.initializer, self._summary_op, self._accuracy], feed_dict=feed_dict)
return summary, acc
从该代码片段来看,您似乎从未使用过 iter
中的值,因此它应该不会对您的摘要产生影响。例如,您应该能够删除创建迭代器的行,并从传递给 sess.run()
的列表中删除 iter.initializer
并获得相同的结果。
要回答更广泛的问题 "Must I call iter.get_next()
?":在基于图形的 TensorFlow 中,tf.data.Iterator
和您传递给 [=12= 的 tensor/operation 之间必须存在数据流连接] 以便使用该迭代器中的值。如果您使用的是 low-level TensorFlow API,最简单的方法是调用 iter.get_next()
来获取一个或多个 tf.Tensor
对象,然后使用这些张量作为模型的输入。
但是,如果您使用的是 高级 tf.estimator
API,您的 input_fn
可以 return tf.data.Dataset
而无需创建 tf.data.Iterator
(或调用 Iterator.get_next()
,Estimator API 将负责创建迭代器并为您调用 get_next()
。