CNN 在 python。图像数组

CNN in python. Array of images

在发帖之前,我试图在网上深入搜索解决方案,但找不到。我的问题出现在卷积神经网络训练中读取图像的过程中。基本上,我决定创建一个函数,从一系列图像中创建输入值和输出值。我想读取集合中的所有图像,但不是同时读取所有图像,以避免 运行 内存不足,因此我创建了下一个函数:

def readImages(strSet = 'Train', nIni = 1, nFin = 20):

if strSet not in ('Train','Test'):
    return None
# 
# Inicializamos los arrays de salida: las imágenes y las etiquetas.
arrImages = []
arrLabels = []
#
# Recorremos todos y cada uno de los directorios dentro del set elegido
for strDir in os.listdir(data_dir+'/' + strSet + '/'):
    # Nombre de la clase que estamos tratando.
    strClass = strDir[strDir.find('-')+1:]
    # Número y nombre de los ficheros, por si es menor que el número n indicado.
    arrNameFiles = os.listdir(data_dir+'/' + strSet + '/'+strDir)
    nFiles = len(os.listdir(data_dir+'/' + strSet + '/'+strDir))
    #
    # Cogemos los ficheros desde el nIni al nFin. De esta forma nos aseguramos los cogemos todos en cada directorio.
    #print('nImagesClase(',strSet,',',strClass,'):',nImagesClase(strSet, strClass))
    if (nIni == -1):
        # Si el valor es -1, cogemos todas las imágenes del directorio.            
        listChosenFiles = arrNameFiles
        #print('Todos: ', len(listChosenFiles))
    else:
        if (nImagesClase(strSet, strClass)<nFin):
            # Si ya hemos dado la vuelta a todos los ficheros del grupo, los cogemos al azar.
            listChosenFiles = random.sample(arrNameFiles, min(nFiles, nFin-nIni))
            #print('Fin del directorio ',nFin,'>',nImagesClase(strSet,strClass),': ', len(listChosenFiles))
        else:
            # Si no, seguimos.
            listChosenFiles = arrNameFiles[nIni-1:min(nFin,nImagesClase(strSet, strClass))-1]
            #print('Seguimos ',nIni,'-',nFin,': ', len(listChosenFiles))
    #
    for file in listChosenFiles:
        # Lectura del fichero.
        image = plt.imread(data_dir+'/'+strSet+'/'+strDir+'/'+file)
        #print('Original Shape: ',image.shape)
        #plt.imshow(image)
        image = cv2.resize(image, (crop_width, crop_height), interpolation=cv2.INTER_NEAREST)
        #image = image.reshape((image_height,image_width,num_channels))
        #print('Al array de imágenes: ',image.shape)
        arrImages.append(image)
        # Añadimos etiquetas.
        arrLabel = np.zeros(n_classes)
        arrLabel[array_classes.index(strClass)] = 1
        arrLabels.append(arrLabel)
#
# Recogemos los valores de entrada y salida en arrays.
y = np.array(arrLabels)
X = np.array(arrImages, dtype=np.uint8)
# Una vez terminado el recorrido por todas las imágenes, reordenamos los índices para que no vayan las imágenes en secuendias de la misma clase.
arrIndexes = np.arange(X.shape[0])
np.random.shuffle(arrIndexes)
X = X[arrIndexes]
y = y[arrIndexes]
#
return X, y

为了测试此函数的行为,我只执行了以下行。

X, y = readImages(strSet = 'Train', nIni = 1, nFin = 5)

没关系,直到 nIninFin 达到某些值(例如 101-105)。在那一刻,我收到以下错误:

ValueError                                Traceback (most recent call last)
<ipython-input-125-8a690256a1fc> in <module>
----> 1 X, y = readImages(strSet = 'Train', nIni = 101, nFin = 105)

<ipython-input-123-9e9ebc660c33> in readImages(strSet, nIni, nFin)
     50     # Recogemos los valores de entrada y salida en arrays.
     51     y = np.array(arrLabels)
---> 52     X = np.array(arrImages, dtype=np.uint8)
     53     # Una vez terminado el recorrido por todas las imágenes, reordenamos los índices para que no vayan las imágenes en secuendias de la misma clase.
     54     arrIndexes = np.arange(X.shape[0])

ValueError: could not broadcast input array from shape (28,28,3) into shape (28,28)

我在读图的时候放了一些打印痕迹,读出来的每张图都是(28,28,3)的形状,所以我不太明白这个(28)是从哪里来的,28) 错误轨迹中指出的形状。

您知道可能是什么问题吗?您之前遇到过这个问题吗?

提前致谢。

您的某些图像只有单通道。使用 cv2.imread 而不是 plt.imread

image = cv2.imread(data_dir+'/'+strSet+'/'+strDir+'/'+file)