如何修复 tf.keras 中的断言错误?
How to fix this AssertionError in tf.keras?
我有一个简单的 tf.keras 函数模型,它有两个输入,一个用于 CNN 层,一个用于 Dense 层。
Batch_Size=64
train_dataset1=Dataset.from_tensor_slices((Xs_train, Xk_train,
y_train)).shuffle(Batch_Size*2).batch(Batch_Size)
valid_dataset1=Dataset.from_tensor_slices((Xs_valid, Xk_valid, y2_valid)).batch(Batch_Size)
# Input for CNN 1D
input1 = Input(shape=(col_seqs, ))
reshaped_input = Reshape((rc_maxlen,CHARS_COUNT))(input1)
conv1=Conv1D(16, 4, activation='relu')(reshaped_input)
flat1=Flatten()(conv1)
# Input for Dense
input2=Input(shape=(col_kmers, ))
dense1=Dense(16, activation='relu')(input2)
# Merge CNN and Dense outputs
merged=concatenate([flat1, dense1])
# Multiclass Classification layers
dense2=Dense(64, activation='relu')(merged)
output=Dense(num_classes, activation='softmax')(dense2)
model=Model(inputs=[input1, input2], outputs=output)
opt=Adam()
model.compile(optimizer=opt, loss='sparse_categorical_crossentropy', metrics='accuracy')
es = EarlyStopping(monitor='val_accuracy', patience=15, mode='max', verbose=0, restore_best_weights=True)
history=model.fit(train_dataset1, verbose=1,
validation_data=valid_dataset1,
callbacks=[es],
epochs=500)
模型图看起来也正确。
但是,我收到一条我不理解的错误消息:
关于如何修复它的任何指示?
问题在于您指定输入训练和输入验证格式的方式。
这里的答案(两个输入,两个输出):,看看输入训练和输入验证的馈送方式:
history = model.fit({'I1':x1, 'I2':x2},
{'O1':y1, 'O2': y2},
validation_data=([val_x1,val_x2], [val_y1,val_y2]),
epochs=100,
verbose = 1)
我有一个简单的 tf.keras 函数模型,它有两个输入,一个用于 CNN 层,一个用于 Dense 层。
Batch_Size=64
train_dataset1=Dataset.from_tensor_slices((Xs_train, Xk_train,
y_train)).shuffle(Batch_Size*2).batch(Batch_Size)
valid_dataset1=Dataset.from_tensor_slices((Xs_valid, Xk_valid, y2_valid)).batch(Batch_Size)
# Input for CNN 1D
input1 = Input(shape=(col_seqs, ))
reshaped_input = Reshape((rc_maxlen,CHARS_COUNT))(input1)
conv1=Conv1D(16, 4, activation='relu')(reshaped_input)
flat1=Flatten()(conv1)
# Input for Dense
input2=Input(shape=(col_kmers, ))
dense1=Dense(16, activation='relu')(input2)
# Merge CNN and Dense outputs
merged=concatenate([flat1, dense1])
# Multiclass Classification layers
dense2=Dense(64, activation='relu')(merged)
output=Dense(num_classes, activation='softmax')(dense2)
model=Model(inputs=[input1, input2], outputs=output)
opt=Adam()
model.compile(optimizer=opt, loss='sparse_categorical_crossentropy', metrics='accuracy')
es = EarlyStopping(monitor='val_accuracy', patience=15, mode='max', verbose=0, restore_best_weights=True)
history=model.fit(train_dataset1, verbose=1,
validation_data=valid_dataset1,
callbacks=[es],
epochs=500)
模型图看起来也正确。
但是,我收到一条我不理解的错误消息:
关于如何修复它的任何指示?
问题在于您指定输入训练和输入验证格式的方式。
这里的答案(两个输入,两个输出):
history = model.fit({'I1':x1, 'I2':x2},
{'O1':y1, 'O2': y2},
validation_data=([val_x1,val_x2], [val_y1,val_y2]),
epochs=100,
verbose = 1)