用于对补丁进行分类的 Tensorflow 数据集管道

Tensorflow dataset pipeline for classifying patches

我正在尝试在 tensorflow 中编写数据集管道来标记图像补丁。现在我正在读取一堆 tfrecord 文件,其中每个文件都有多个补丁,但只有一个标签。标签有四个类.

当我通过管道传递单例时,Tensorflow 似乎不喜欢它。我收到以下错误:

ValueError: Value Tensor("args_1:0", shape=(), dtype=int32) has insufficient rank for batching.

我正在想办法让这个用例发挥作用。这基本上是我想要做的。我可以使用关于我应该对 y 做什么的建议,这样我就可以在管道末端为每个补丁获得一个标签。如果我需要更改 tfrecord 文件的结构,使 y 成为 onehot 编码向量,那很好;只是不知道有没有这个必要

def parse_func(proto):
    features = tf.io.parse_single_example(
        serialized=proto,
        features={'X': tf.io.FixedLenFeature([], tf.string),
                  'length': tf.io.FixedLenFeature([], tf.int64),
                  'y': tf.io.FixedLenFeature([], tf.int64)})

    y = tf.cast(features['y'], tf.int32)  # this is just an integer, but maybe it should be a one-hot encoded vector

    X = tf.io.decode_raw(features['X'], tf.float32)
    length = tf.cast(features['length'], tf.int32)
    shape = tf.stack([length, 60, 1])
    return tf.reshape(X, shape), y


def get_patches(X, y):
    X = X[tf.newaxis, ...]

    patches = tf.image.extract_patches(X,
                                       sizes=[1,  128, 60, 1],
                                       strides=[1, 4, 1, 1],
                                       rates=[1, 1, 1, 1],
                                       padding='VALID')
    patches = tf.reshape(patches, [-1, 128, 60, 1])
    y = repeat_so_that_there_is_one_label_per_patch(y)
    return patches, y


dataset = (tf.data.Dataset.from_tensor_slices('tf_record_file_paths')
           .shuffle(100)
           .interleave(lambda x: tf.data.TFRecordDataset(x), cycle_length=4)
           .map(parse_func)
           .map(get_patches)
           .unbatch()
           .shuffle(100)
           .repeat()
           .batch(64, drop_remainder=True)
           .prefetch(1))

我按如下方式解决了这个问题:

def repeat_so_that_there_is_one_label_per_patch(y, patches):
    num_patches = tf.shape(patches)[0]
    tiled_y = tf.tile(y, multiples=[num_patches])
    return tf.reshape(tiled_y, tf.shape(y) * num_patches)