如何使用 tf.data 为 CNN 同时加载 RGB 图像和灰度图像？
How to include rgb and grayscale images in CNN using tf.data?
我正在尝试基于 `tf.data` 构建输入管道，使用 RGB 图像作为输入、灰度图像作为标签图像。如何修改以下代码，使标签图像只包含一个通道？
# Step 1: convert the Python lists of input-image paths and label values into
# constant tensors (input_list first, then label_list).
filenames, labels = map(tf.constant, (input_list, label_list))
# Step 2: slice both tensors in lockstep so that every dataset element is one
# (filename, label) pair.
dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
# step 3: parse every image in the dataset using `map`
def _parse_function(filename, label):
image_string = tf.io.read_file(filename)
image_decoded = tf.image.decode_jpeg(image_string, channels=3)
image = tf.cast(image_decoded, tf.float32)
return image, label
# Decode every (filename, label) element, then group consecutive results into
# batches of two.
dataset = dataset.map(_parse_function).batch(2)
# Step 4: build a one-shot iterator through the TF1-compatible API and obtain
# the tensors holding the next batch.
iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)
images, labels = iterator.get_next()
使用下面这个函数，按指定的通道数加载图像（RGB 输入用 3 通道，灰度标签用 1 通道）：
def _parse_function(filename, channels):
    """Read an image file and return it as a float32 tensor.

    Args:
        filename: string tensor holding the path of the image file. The file
            is decoded with ``tf.image.decode_jpeg``.
        channels: number of colour channels to decode to. The caller uses 3
            for RGB inputs and 1 for grayscale labels.

    Returns:
        The decoded image converted to float32 with
        ``tf.image.convert_image_dtype``, which rescales integer pixels into
        the [0, 1] range.
    """
    decoded = tf.image.decode_jpeg(tf.io.read_file(filename), channels=channels)
    return tf.image.convert_image_dtype(decoded, tf.float32)
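然后，用下面的 `map` 替换步骤 3 中的 `dataset = dataset.map(_parse_function)`（必须在 `batch` 之前调用，因为该函数接收的是单个文件路径）。这要求 `label_list` 中存放的是标签图像的文件路径：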
# Turn each (input_path, label_path) element into a pair of decoded images:
# the input as a 3-channel RGB image, the label as a 1-channel grayscale image.
dataset = dataset.map(
    lambda input_path, label_path: (
        _parse_function(input_path, channels=3),
        _parse_function(label_path, channels=1),
    )
)
我正在尝试基于 `tf.data` 构建输入管道，使用 RGB 图像作为输入、灰度图像作为标签图像。如何修改以下代码，使标签图像只包含一个通道？
# Step 1: convert the Python lists of input-image paths and label values into
# constant tensors (input_list first, then label_list).
filenames, labels = map(tf.constant, (input_list, label_list))
# Step 2: slice both tensors in lockstep so that every dataset element is one
# (filename, label) pair.
dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
# Step 3: per-element parser used by `dataset.map` below.
def _parse_function(filename, label):
    """Load one (image, label) example.

    Reads the file named by ``filename`` and decodes it as a 3-channel JPEG.
    The pixels are cast to float32 with ``tf.cast``, which does not rescale
    them, so they keep the 0-255 range of the decoded uint8 data.
    ``label`` is returned exactly as received.
    """
    raw = tf.io.read_file(filename)
    pixels = tf.cast(tf.image.decode_jpeg(raw, channels=3), tf.float32)
    return pixels, label
# Decode every (filename, label) element, then group consecutive results into
# batches of two.
dataset = dataset.map(_parse_function).batch(2)
# Step 4: build a one-shot iterator through the TF1-compatible API and obtain
# the tensors holding the next batch.
iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)
images, labels = iterator.get_next()
使用下面这个函数，按指定的通道数加载图像（RGB 输入用 3 通道，灰度标签用 1 通道）：
def _parse_function(filename, channels):
    """Read an image file and return it as a float32 tensor.

    Args:
        filename: string tensor holding the path of the image file. The file
            is decoded with ``tf.image.decode_jpeg``.
        channels: number of colour channels to decode to. The caller uses 3
            for RGB inputs and 1 for grayscale labels.

    Returns:
        The decoded image converted to float32 with
        ``tf.image.convert_image_dtype``, which rescales integer pixels into
        the [0, 1] range.
    """
    decoded = tf.image.decode_jpeg(tf.io.read_file(filename), channels=channels)
    return tf.image.convert_image_dtype(decoded, tf.float32)
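然后，用下面的 `map` 替换步骤 3 中的 `dataset = dataset.map(_parse_function)`（必须在 `batch` 之前调用，因为该函数接收的是单个文件路径）。这要求 `label_list` 中存放的是标签图像的文件路径：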
# Turn each (input_path, label_path) element into a pair of decoded images:
# the input as a 3-channel RGB image, the label as a 1-channel grayscale image.
dataset = dataset.map(
    lambda input_path, label_path: (
        _parse_function(input_path, channels=3),
        _parse_function(label_path, channels=1),
    )
)