ValueError: unpack: when trying to split fashion_mnist into 3 splits

ValueError: unpack: when trying to split fashion_mnist into 3 splits

(train_dataset,validation_dataset,test_dataset) = tfds.load('fashion_mnist',
                            with_info=True, as_supervised=True,
                            split=['train[:80%]', 'train[80%:90%]', 'train[90%:]'])

我正在尝试将 fashion_mnist 分成 3 组训练测试和验证。我不确定这里的错误是什么,因为我根本无法解决它。

"fashion_mnist" 数据集在 Tensorflow 数据集中只有一个训练和一个测试拆分(参见 documentation,拆分部分),因此在 split 参数中它期望一个列表具有长度最多为 2,但是您使用的是长度为 3 的列表。为了获得训练、验证和测试拆分,您可以执行以下操作:

whole_ds,info_ds = tfds.load("fashion_mnist", with_info = True, split='train+test', as_supervised=True)

n = tf.data.experimental.cardinality(whole_ds).numpy() # 70 000
train_num = int(n*0.8)
val_num = int(n*0.1)

train_ds = whole_ds.take(train_num)
val_ds = whole_ds.skip(train_num).take(val_num)
test_ds = whole_ds.skip(train_num+val_num)

如果您想保留提供的测试数据作为您的测试数据:

(train_data, validation_data, test_data),info = tfds.load(
        name="fashion_mnist", 
        split=['train[:80%]', 'train[80%:]', 'test'],
        as_supervised=True,
        with_info=True)