Not a Tensor API call inside a Tensor block

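I have code that uses the tensor API. I added a function call to it that, as far as I know, cannot be used among tensor operations (though I am not sure). In the code below, the call is:

image = io.imread(image_path, plugin='simpleitk')

but everything else operates on tensors. Currently, the run fails for an unknown reason. Is there a way to change the code (either io.imread or the rest) so that the two work together?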

    !pip install SimpleITK
    # Image preprocessing utils
    import skimage.io as io
    import tensorflow as tf

    @tf.function
    def parse_images(image_path):
        # io.imread is not a TensorFlow op; this is the call in question
        image = io.imread(image_path, plugin='simpleitk')

        # Everything below operates on tensors
        image = tf.image.convert_image_dtype(image, tf.float32)
        image = tf.image.resize(image, size=[224, 224])

        return image

It is called by this code:

import sys
# Create TensorFlow dataset
BATCH_SIZE = 64

train_ds = tf.data.Dataset.from_tensor_slices( my_train_images) 

train_ds = (
    train_ds
    .map(parse_images, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .shuffle(1024)
    .batch(BATCH_SIZE, drop_remainder=True)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

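skimage.io seems to fail every time, in both graph mode and eager execution mode. Inside Dataset.map the function is traced as a graph, so image_path arrives as a symbolic tensor rather than a Python string that io.imread could open. You can try using tf.io.read_file instead: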

Generate random data

import numpy
from PIL import Image

for i in range(5):
  imarray = numpy.random.rand(300,300,3) * 255
  im = Image.fromarray(imarray.astype('uint8')).convert('RGBA')
  im.save('result_image{}.png'.format(i))

Process data

import tensorflow as tf
import matplotlib.pyplot as plt

normalization_layer = tf.keras.layers.Rescaling(1./255)

@tf.function
def parse_images(image_path):
    raw = tf.io.read_file(image_path) 
    image = tf.image.decode_png(raw, channels=3)
    image = tf.image.resize(normalization_layer(image), size=[224, 224])
    return image

train_ds = tf.data.Dataset.from_tensor_slices(['/content/result_image0.png', 
                                               '/content/result_image1.png', 
                                               '/content/result_image2.png', 
                                               '/content/result_image3.png', 
                                               '/content/result_image4.png']) 

train_ds = (
    train_ds
    .map(parse_images, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .shuffle(5)
    .batch(2, drop_remainder=True)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

images = next(iter(train_ds.take(1)))

image = images[0] # (224, 224, 3)
plt.imshow(image.numpy())
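
To see why a plain Python reader cannot be dropped into map directly, it may help to check how the mapped function is executed. The sketch below is only an illustration: the function names and the file names "a.png" and "b.png" are hypothetical, and no files are read. It records tf.executing_eagerly() inside a function passed straight to map and inside one wrapped with tf.py_function (a TensorFlow wrapper that runs Python code eagerly):

```python
import tensorflow as tf

modes = {}

def in_map(image_path):
    # Runs once, while Dataset.map traces the function into a graph;
    # image_path is a symbolic tensor here.
    modes["map"] = tf.executing_eagerly()
    return image_path

def in_py_function(image_path):
    # Runs once per element, eagerly; image_path is an eager tensor,
    # so image_path.numpy() would be available here.
    modes["py_function"] = tf.executing_eagerly()
    return image_path

# Hypothetical file names; nothing is read from disk.
paths = tf.data.Dataset.from_tensor_slices(["a.png", "b.png"])
traced = paths.map(in_map)
wrapped = paths.map(lambda p: tf.py_function(in_py_function, [p], tf.string))

for _ in wrapped:  # iterate so the py_function body actually executes
    pass

print(modes)
```

The function passed directly to map should report graph mode, which is why only TensorFlow ops such as tf.io.read_file can consume the path there.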

Update 1: For *.mhd images, you can try using tf.py_function together with the SimpleITK library:

import tensorflow as tf
import matplotlib.pyplot as plt
import SimpleITK as sitk

normalization_layer = tf.keras.layers.Rescaling(1./255)

def parse_images(image_path):
    itkimage = sitk.ReadImage(image_path.numpy().decode("utf-8"))
    image = sitk.GetArrayFromImage(itkimage)
    image = tf.image.resize(normalization_layer(image), size=[224, 224])
    return image

train_ds = tf.data.Dataset.from_tensor_slices(['/content/result_image0.mhd', 
                                               '/content/result_image1.mhd', 
                                               '/content/result_image2.mhd', 
                                               '/content/result_image3.mhd', 
                                               '/content/result_image4.mhd']) 
train_ds = (
    train_ds
    .map(lambda x: tf.py_function(parse_images, [x], tf.float32))
    .shuffle(5)
    .batch(2, drop_remainder=True)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

images = next(iter(train_ds.take(1)))

image = images[0] # (224, 224, 3)
plt.imshow(image.numpy())
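
One thing to watch with the tf.py_function approach: the tensors it returns carry no static shape, which can cause trouble once the dataset is fed to layers that need known dimensions. Below is a minimal sketch of declaring the shape afterwards, assuming the loader yields 3-channel images. load_array is a hypothetical stand-in for the SimpleITK read so that the snippet runs without real *.mhd files, and the file names are placeholders:

```python
import numpy as np
import tensorflow as tf

def load_array(image_path):
    # Hypothetical stand-in for sitk.ReadImage + sitk.GetArrayFromImage:
    # decodes the path as the real loader would, then returns random pixels.
    _ = image_path.numpy().decode("utf-8")  # eager tensor -> Python str
    image = np.random.rand(300, 300, 3).astype("float32")
    return tf.image.resize(image, size=[224, 224])

def parse_with_shape(image_path):
    image = tf.py_function(load_array, [image_path], tf.float32)
    # py_function drops static shape information; declare it explicitly.
    # Assumption: every image is resized to 224 x 224 with 3 channels.
    image.set_shape([224, 224, 3])
    return image

# Placeholder file names; load_array never opens them.
ds = tf.data.Dataset.from_tensor_slices(["a.mhd", "b.mhd"]).map(parse_with_shape)
print(ds.element_spec)
```

If the real images are single-channel or volumetric, the declared shape would need to be adjusted to match what the loader actually returns.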