如何在 'fit' 调用期间调试 Tensorflow logits/label 形状问题?

How to debug Tensorflow logits/label shape issue during 'fit' call?

我正在尝试使用 TensorFlow 训练分割 U-Net。我计算机上的数据集图像在 运行 模型之前进行了预处理并保存为 Tfrecords。

所以在训练之前,我用 tf.data 加载了 tfrecords。如果我看一个例子,我明白了:

ds_train = tf.data.TFRecordDataset(training_tfrecords).map(load_image_train, num_parallel_calls=tf.data.experimental.AUTOTUNE)
train_dataset = ds_train.cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE).repeat()
one_record = train_dataset.take(1)
for i, (img, seg) in enumerate(one_record):
    print(f"BatchImg{i}: {img.shape}")
    print(f"BatchSeg{i}: {seg.shape}")

输出(6=批量大小,96=img 尺寸,1=通道暗淡):

BatchImg0: (6, 96, 96, 96, 1)
BatchSeg0: (6, 96, 96, 96, 1)

到目前为止看起来还不错。但是当我尝试开始训练时,出现以下错误:

model_history = model.fit(train_dataset, epochs=EPOCHS, steps_per_epoch=180)

输出:

ValueError: logits and labels must have the same shape ((None, 96, 96, 96, 16) vs (None, None, None, None, 1))

我不确定如何进一步调试...谢谢您的帮助!

正如 Andrey 指出的(谢谢!),我不仅需要检查 我的输入图像形状,还要检查模型的输出 (logits) 形状 。问题仅来自 logits 和分割之间的形状差异(最后一个维度中的“16”与 1)。所以模型输出不正确。修复后(使用 model.summary() 找出要做什么),一切 运行 就好了!