BatchDataset 显示图像和标签

BatchDataset display images and label

我有一个 TrainValidation 批处理数据集:

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    train_path,
    label_mode = 'categorical', #it is used for multiclass classification. It is one hot encoded labels for each class
    validation_split = 0.2,     #percentage of dataset to be considered for validation
    subset = "training",        #this subset is used for training
    seed = 1337,                # seed is set so that same results are reproduced
    image_size = img_size,      # shape of input images
    batch_size = batch_size,    # This should match with model batch size
)

valid_ds = tf.keras.preprocessing.image_dataset_from_directory(
    train_path,
    label_mode ='categorical',
    validation_split = 0.2,
    subset = "validation",      #this subset is used for validation
    seed = 1337,
    image_size = img_size,
    batch_size = batch_size,
)

我试图显示 9 张图像以显示它们的外观,我设法做到了,但我似乎无法绘制它们各自的标签。

代码如下:

class_names = train_ds.class_names


plt.figure(figsize=(10, 10))
for images, labels in train_ds.take(1):
  for i in range(9):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(images[i].numpy().astype("uint8"))
    plt.axis("off")

其中显示:

如果我尝试通过添加以下内容来获取标签:plt.title(class_names[labels[i]])

我收到以下错误:TypeError: only integer scalar arrays can be converted to a scalar index

我尝试过其他帖子的解决方案,例如以下 plt.title(class_names[labels[i][0]]) 但没有成功。

当我打印标签时[i],我得到了标签的一种热编码...也许这就是为什么?

结果代码:

plt.figure(figsize=(10, 10))
for images, labels in train_ds.take(1):
  for i in range(9):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(images[i].numpy().astype("uint8"))
    plt.title(class_names[np.argmax(labels[i], axis=None, out=None)])
    plt.axis("off")

根据您最后的评论,您尝试过使用 argmax 吗?

numpy.argmax(a, axis=None, out=None)

这 returns 沿轴的最大值的索引。

试试下面的代码:

plt.figure(figsize=(10, 10))
for images, labels in train_ds.take(1):
  for i in range(9):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(images[i].numpy().astype("uint8"))
    plt.title(class_names[labels[i]])
    plt.
    plt.axis("off")