TensorFlow TypeError: 'BatchDataset' object is not iterable / TypeError: 'CacheDataset' object is not subscriptable

TensorFlow TypeError: 'BatchDataset' object is not iterable / TypeError: 'CacheDataset' object is not subscriptable

我正在关注 TensorFlow starter guide。它特别表示要在鸢尾花(花)分类的示例项目上启用急切执行。

Import the required Python modules, including TensorFlow, and enable eager execution for this program. Eager execution makes TensorFlow evaluate operations immediately, returning concrete values instead of creating a computational graph that is executed later. If you are used to a REPL or the python interactive console, you'll feel at home.

所以我按照说明启用eager execution,继续说明。然而,当我到达讨论如何将数据集准备为张量流数据集的部分时,我遇到了一个错误。

小区代码

train_dataset = tf.data.TextLineDataset(train_dataset_fp)
train_dataset = train_dataset.skip(1)             # skip the first header row
train_dataset = train_dataset.map(parse_csv)      # parse each row
train_dataset = train_dataset.shuffle(buffer_size=1000)  # randomize
train_dataset = train_dataset.batch(32)

# View a single example entry from a batch
features, label = iter(train_dataset).next()
print("example features:", features[0])
print("example label:", label[0])

错误

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-5-61bfe99af85b> in <module>()

      7 
      8 # View a single example entry from a batch
----> 9 features, label = iter(train_dataset).next()
     10 print("example features:", features[0])
     11 print("example label:", label[0])

TypeError: 'BatchDataset' object is not iterable

我只想继续学习示例。我该怎么做才能将 BatchDataset 对象转换为可迭代的对象?

原来是我在项目中确实没有做某些步骤导致了这个问题。

将 TensorFlow 从 1.7 升级到 1.8:

!pip install --upgrade tensorflow

正在检查您的 TensorFlow 是否已更新

此代码单元格:

from __future__ import absolute_import, division, print_function

import os
import matplotlib.pyplot as plt

import tensorflow as tf
import tensorflow.contrib.eager as tfe

tf.enable_eager_execution()

print("TensorFlow version: {}".format(tf.VERSION))
print("Eager execution: {}".format(tf.executing_eagerly()))

应该return以下输出:

TensorFlow version: 1.8.0
Eager execution: True

参考 here 替代解决方案。我们还可以使用 as_numpy_iterator() 从 tensorflow 数据集

中获取值