层 lstm_9 的输入 0 与层不兼容:预期 ndim=3,发现 ndim=4。已收到完整形状:[None, 2, 4000, 256]

Input 0 of layer lstm_9 is incompatible with the layer: expected ndim=3, found ndim=4. Full shape received: [None, 2, 4000, 256]

我尝试使用 RNN 网络创建模型,但我收到:层 lstm_9 的输入 0 与层不兼容:预期 ndim=3,发现 ndim=4。收到完整形状:[None, 2, 4000, 256] 错误。

输入

train_data.shape() = (100,2,4000)

train_labels.shape() =(100,)

labels_values = 0 or 1 (two classes)

型号

input = Input(shape=(2,4000)) # shape from train_data
embedded = Embedding(2, 256)(input) 
lstm = LSTM(1024, return_sequences=True)(embedded) # ERROR
dense = Dense(2, activation='softmax')(lstm) 

不幸的是,你设计带有嵌入层的 Keras 功能模型的整个概念是错误的。

  1. 当您使用嵌入层时,它需要二维数据。
Input shape

2D tensor with shape: (batch_size, sequence_length).

Output shape

3D tensor with shape: (batch_size, sequence_length, output_dim).

参考:https://keras.io/layers/embeddings/

词汇表采用一系列 ID 或标记。这必须是一个整数数组。

假设我们的词汇表有 len 36,我们将范围为 (0, 36) 的整数数组列表传递给它

[1, 34, 32, 23] 有效 [0.2, 0.5] 无效

  1. 通常我们用Embedding来表示reduced中的向量space,所以output_dim低于input_dim,但反之亦然关于设计。

  2. 您需要为输入数据指定input_length。

  3. 如果您使用 return_sequences = True,时间维度将被传递到下一个维度,这在您的情况下是不需要的。

  4. 您的标签格式为 (0, 1, 0, 1, 0, 0, ...) 而不是单热编码格式,所以不要使用 softmax 但sigmoid 在最后一个 dense 中有 1 个单元。

这是修正后的网络。

from tensorflow.keras.layers import *
from tensorflow.keras.models import *
import numpy as np
train_data = np.random.randint(0,3, (100, 4000))
y_labels = np.random.randint(0,2, (100,))

input_ = Input(shape=(4000)) # shape from train_data
embedded = Embedding(36, 256, input_length = 4000)(input_) 
lstm = LSTM(256, return_sequences=False)(embedded) # --> ERROR
dense = Dense(1, activation='softmax')(lstm) 

model = Model(input_, dense)
model.summary()
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_6 (InputLayer)         [(None, 4000)]            0         
_________________________________________________________________
embedding_5 (Embedding)      (None, 4000, 256)         9216      
_________________________________________________________________
lstm_5 (LSTM)                (None, 256)               525312    
_________________________________________________________________
dense (Dense)                (None, 1)                 257       
=================================================================
Total params: 534,785
Trainable params: 534,785
Non-trainable params: 0