如何修复 tf.keras 中的断言错误?

How to fix this AssertionError in tf.keras?

我有一个简单的 tf.keras 函数模型,它有两个输入,一个用于 CNN 层,一个用于 Dense 层。

Batch_Size=64
train_dataset1=Dataset.from_tensor_slices((Xs_train, Xk_train, 
    y_train)).shuffle(Batch_Size*2).batch(Batch_Size)
valid_dataset1=Dataset.from_tensor_slices((Xs_valid, Xk_valid, y2_valid)).batch(Batch_Size)

# Input for CNN 1D 
input1 = Input(shape=(col_seqs, ))
reshaped_input = Reshape((rc_maxlen,CHARS_COUNT))(input1)
conv1=Conv1D(16, 4, activation='relu')(reshaped_input)
flat1=Flatten()(conv1)
# Input for Dense
input2=Input(shape=(col_kmers, ))
dense1=Dense(16, activation='relu')(input2)
# Merge CNN and Dense outputs
merged=concatenate([flat1, dense1])
# Multiclass Classification layers
dense2=Dense(64, activation='relu')(merged)
output=Dense(num_classes, activation='softmax')(dense2)

model=Model(inputs=[input1, input2], outputs=output)

opt=Adam()
model.compile(optimizer=opt, loss='sparse_categorical_crossentropy', metrics='accuracy')  
es = EarlyStopping(monitor='val_accuracy', patience=15, mode='max', verbose=0, restore_best_weights=True)

history=model.fit(train_dataset1, verbose=1, 
              validation_data=valid_dataset1, 
              callbacks=[es], 
              epochs=500) 

模型图看起来也正确。

但是,我收到一条我不理解的错误消息:

关于如何修复它的任何指示?

问题在于您指定输入训练和输入验证格式的方式。

这里的答案(两个输入,两个输出):,看看输入训练和输入验证的馈送方式:

history = model.fit({'I1':x1, 'I2':x2},
                    {'O1':y1, 'O2': y2},
                    validation_data=([val_x1,val_x2], [val_y1,val_y2]),
                    epochs=100,
                    verbose = 1)