难以堆叠 MNIST 和 Fashion_MNIST

Difficulty with stacking MNIST and Fashion_MNIST

我知道这对你们来说太简单了,但我是初学者,需要你们的帮助。 我正在努力用 CNN 制作二元分类器。 我的最终目标是检查精度超过 0.99

我导入了 MNIST 和 FASHION_MNIST 来识别它是数字还是衣服。 所以有2类。我想将 0-60000 归类为 0,将 60001-120000 归类为 1。 我将使用 binary_crossentropy.

但我不知道如何从头开始。 首先如何使用vstack hstack来结合MNIST和FASHION_MNIST?

到目前为止我就是这样尝试的

****import numpy as np
from keras.datasets import mnist
from keras.datasets import fashion_mnist
import keras
import tensorflow as tf
from keras.utils.np_utils import to_categorical
num_classes = 2
train_images = train_images.astype("float32") / 255
test_images = test_images.astype("float32") / 255
train_images = train_images.reshape((-1, 784))
test_images = test_images.reshape((-1, 784))
train_labels = to_categorical(train_labels, num_classes)
test_labels = to_categorical(test_labels, num_classes)****

首先

它们是图像,因此最好将它们视为图像,不要将它们重塑为矢量。

现在回答问题。假设您有 mnist_train_imagefashion_train_image,两者都有 (60000, 28, 28) 输入形状。

你要做的是由两部分组成,结合输入和制定目标。

首先是输入

正如您在问题中已经写过的,您可以像这样使用 np.vstack

>>> train_image = np.vstack((fashion_train_image, mnist_train_image))
>>> train_image.shape
(120000, 28, 28)

但是您应该已经注意到,记住您是需要 vstack 还是 dstackhstack 有点痛苦。我的偏好是使用 np.concatenate 而不是

>>> train_image = np.concatenate((fashion_train_image, mnist_train_image), axis=0)
>>> train_image.shape
(120000, 28, 28)

现在不需要记住鸭子是什么 vhd 你只需要记住你想要连接的轴(或维度),在这种情况下它是第一个轴表示 0。特别是在这种情况下,“垂直”是第二个轴,因为它是一堆图像,第一个轴是“批处理”。

接下来,标签

既然您想将 0-60000 分类为 0,将 60001-120000 分类为 1,那么有很多奇特的方法可以做到这一点。

但简而言之,您可以使用 np.zeros 创建一个以 0 填充的数组。而 np.ones 可以创建一个以 1 填充的数组。但是 oneszeros 给你一个浮点数组,我不确定这是否会成为一个问题所以我在后面添加 .astype('uint8') 以防万一。您也可以在函数中添加参数 dtype='uint8'

使用上面的连接

>>> train_labels = np.concatenate((np.zeros(60000), np.ones(60000))).astype('uint8')
>>> train_labels.shape
(120000,)

对整个大小使用 1 或 0,然后减去或添加或重新分配其余部分

>>> train_labels = np.zeros(120000).astype('uint8')
>>> train_labels[60000:] = 1
#####
>>> train_labels = np.ones(120000, dtype='uint8')
>>> train_labels[:60000] -= 1

重要!!!!

你的例子中有一个关于标签的明显错误,索引从 0 开始,所以第 60,000 个索引是 59,999。

所以你真正想要的是将 0-59999 归类为 0,将 60000-119999 归类为 1。