"Tuple indices must be integers or slices, not str" when loading Tensorflow Dataset

I'm trying to load the MNIST dataset, but I get this error:

TypeError: tuple indices must be integers or slices, not str

Here is my code:

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

mnist_dataset = tfds.load(name='mnist', with_info=True, as_supervised=True)
mnist_train, mnist_test = mnist_dataset['train'], mnist_dataset['test']

This line raises the error:

mnist_train, mnist_test = mnist_dataset['train'], mnist_dataset['test']

If you pass with_info=True, tfds.load returns a (dataset, info) tuple, so you need to unpack it accordingly:

mnist_dataset, info = tfds.load(name='mnist', with_info=True, as_supervised=True)

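The way you wrote it, mnist_dataset is a tuple holding a two-item dict and a tfds.core.DatasetInfo object. Indexing that tuple with the string 'train' is what raises the TypeError: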

(
    {
        'test': <PrefetchDataset shapes: ((28, 28, 1), ()), types: (tf.uint8, tf.int64)>,
        'train': <PrefetchDataset shapes: ((28, 28, 1), ()), types: (tf.uint8, tf.int64)>
    },
    tfds.core.DatasetInfo(name='mnist', etc)
)
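To reproduce the error and the fix without downloading MNIST, here is a minimal pure-Python sketch. fake_load is a hypothetical stand-in, not part of tensorflow_datasets: it only imitates the return shape described above (a dict of splits, plus an info object when with_info=True). The split values and the info dict are placeholders for the real PrefetchDataset and DatasetInfo objects.

```python
# fake_load is HYPOTHETICAL: it only mimics the shape tfds.load returns,
# i.e. a (datasets_dict, info) tuple when with_info=True. It is not the real API.


def fake_load(name, with_info=False):
    """Return a dict of split name -> placeholder, plus an info object if requested."""
    datasets = {"train": f"<{name} train split>", "test": f"<{name} test split>"}
    info = {"name": name}  # placeholder standing in for tfds.core.DatasetInfo
    if with_info:
        return datasets, info  # a 2-tuple, so string indexing on it fails
    return datasets


def reproduce_error():
    """Index the tuple with a string, as in the question, and return the message."""
    result = fake_load("mnist", with_info=True)
    try:
        result["train"]  # result is a tuple, not a dict
    except TypeError as exc:
        return str(exc)
    return None


def load_correctly():
    """Unpack the tuple first, then index the dict by split name."""
    datasets, info = fake_load("mnist", with_info=True)
    return datasets["train"], datasets["test"], info


if __name__ == "__main__":
    print(reproduce_error())  # the same message as in the question
    train, test, info = load_correctly()
    print(train, test, info["name"])
```

With the real library the pattern is the same: unpack with mnist_dataset, info = tfds.load(...) and then index mnist_dataset['train'] and mnist_dataset['test'].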