def show_batch( ) 不显示我的火车图像

def show_batch( ) not showing my train images

我按照他们网站上的说明在 TensorFlow 中预处理数据:https://www.tensorflow.org/tutorials/load_data/images

但是,将图像转换为张量并将每个图像分配给相应的标签后,我无法绘制它们。

我并行加载了对(图像,标签): labeled_ds = list_ds.map(process_path, num_parallel_calls=AUTOTUNE)。然后我检查图像形状和相应的标签:

for image, label in labeled_ds.take(1):
  print("Image shape: ", image.numpy().shape)
  print("Label: ", label.numpy())

然后得到 Image shape: (80, 80, 3) Label: [False False True False] 符合预期。

然后我定义了以下函数来批量准备数据集进行训练:

def prepare_for_training(ds, cache=True, shuffle_buffer_size=1000):
  ds = ds.batch(100)
  ds = ds.prefetch(buffer_size=AUTOTUNE)

  if cache:
    if isinstance(cache, str):
      ds = ds.cache(cache)
    else:
      ds = ds.cache()

  ds = ds.shuffle(buffer_size=shuffle_buffer_size)
  ds = ds.repeat()

  return ds



train_ds = prepare_for_training(labeled_ds)
image_batch, label_batch = next(iter(train_ds))

但是当我想使用 plt.show() 显示每个图像及其标签时,图像不显示。我就是这样做的:

def show_batch(image_batch, label_batch):
    plt.figure(figsize=(10,10))
    for n in range(25):
        ax = plt.subplot(5,5,n+1)
        plt.imshow(image_batch[n])
        plt.title(CLASS_NAMES[label_batch[n]==1][0].title())
        plt.axis('off')
        return plt.show()


show_batch(image_batch.numpy(), label_batch.numpy())

关于为什么我的图像可能无法显示的任何线索?

为什么要添加 return plt.show()Official tutorial 没有这条线(而且它有效!) - 通过在 for 循环的第一次迭代中调用 return,您不允许显示图像。

return plt.show 在 for 循环中。它需要像这样在外面:

def show_batch(image_batch, label_batch):
    plt.figure(figsize=(10,10))
    for n in range(25):
        ax = plt.subplot(5,5,n+1)
        plt.imshow(image_batch[n])
        plt.title(CLASS_NAMES[label_batch[n]==1][0].title())
        plt.axis('off')

    return plt.show()