TensorFlow:在不同输出形状的数据集之间交替

TensorFlow: alternate between datasets of different output shapes

我正在尝试将 tf.Dataset 用于 3D 图像 CNN,其中从训练集和验证集输入的 3D 图像的形状不同(训练:(64, 64, 64 ), 验证: (176, 176, 160)).我什至不知道这是可能的,但我正在根据一篇论文重新创建这个网络,并使用经典的 feed_dict 方法网络确实有效。出于性能原因(并且只是为了学习),我正在尝试将网络切换为使用 tf.Dataset

我有两个数据集和迭代器,如下所示:

def _data_parser(dataset, shape):
        features = {"input": tf.FixedLenFeature((), tf.string),
                    "label": tf.FixedLenFeature((), tf.string)}
        parsed_features = tf.parse_single_example(dataset, features)

        image = tf.decode_raw(parsed_features["input"], tf.float32)
        image = tf.reshape(image, shape + (1,))

        label = tf.decode_raw(parsed_features["label"], tf.float32)
        label = tf.reshape(label, shape + (1,))
        return image, label

train_datasets = ["train.tfrecord"]
train_dataset = tf.data.TFRecordDataset(train_datasets)
train_dataset = train_dataset.map(lambda x: _data_parser(x, (64, 64, 64)))
train_dataset = train_dataset.batch(batch_size) # batch_size = 16
train_iterator = train_dataset.make_initializable_iterator()

val_datasets = ["validation.tfrecord"]
val_dataset = tf.data.TFRecordDataset(val_datasets)
val_dataset = val_dataset.map(lambda x: _data_parser(x, (176, 176, 160)))
val_dataset = val_dataset.batch(1)
val_iterator = val_dataset.make_initializable_iterator()

TensorFlow documentation 有关于使用 reinitializable_iteratorfeedable_iterator 在数据集之间切换的示例,但它们都在 same 输出形状的迭代器之间切换,这里不是这种情况。

那么我应该如何使用 tf.Datasettf.data.Iterator 在训练集和验证集之间切换?

只需为尺寸不匹配的轴上的形状提供未指定 (None) 的值。例如

import numpy as np
import tensorflow as tf

training_dataset = tf.data.Dataset.from_tensors(np.zeros((64, 64, 64), np.float32)).repeat().batch(4)
validation_dataset = tf.data.Dataset.from_tensors(np.zeros((176, 176, 160), np.float32)).repeat().batch(1)

iterator = tf.data.Iterator.from_structure(
    training_dataset.output_types,
    <b>tf.TensorShape([None, None, None, None])</b>)
next_element = iterator.get_next()

training_init_op = iterator.make_initializer(training_dataset)
validation_init_op = iterator.make_initializer(validation_dataset)

sess = tf.InteractiveSession()
sess.run(training_init_op)
print(sess.run(next_element).shape)
sess.run(validation_init_op)
print(sess.run(next_element).shape)