用于对补丁进行分类的 Tensorflow 数据集管道
Tensorflow dataset pipeline for classifying patches
我正在尝试在 tensorflow
中编写数据集管道来标记图像补丁。现在我正在读取一堆 tfrecord 文件,其中每个文件都有多个补丁,但只有一个标签。标签有四个类.
当我通过管道传递单例时,Tensorflow 似乎不喜欢它。我收到以下错误:
ValueError: Value Tensor("args_1:0", shape=(), dtype=int32) has insufficient rank for batching.
我正在想办法让这个用例发挥作用。这基本上是我想要做的。我可以使用关于我应该对 y
做什么的建议,这样我就可以在管道末端为每个补丁获得一个标签。如果我需要更改 tfrecord 文件的结构,使 y
成为 onehot 编码向量,那很好;只是不知道有没有这个必要
def parse_func(proto):
features = tf.io.parse_single_example(
serialized=proto,
features={'X': tf.io.FixedLenFeature([], tf.string),
'length': tf.io.FixedLenFeature([], tf.int64),
'y': tf.io.FixedLenFeature([], tf.int64)})
y = tf.cast(features['y'], tf.int32) # this is just an integer, but maybe it should be a one-hot encoded vector
X = tf.io.decode_raw(features['X'], tf.float32)
length = tf.cast(features['length'], tf.int32)
shape = tf.stack([length, 60, 1])
return tf.reshape(X, shape), y
def get_patches(X, y):
X = X[tf.newaxis, ...]
patches = tf.image.extract_patches(X,
sizes=[1, 128, 60, 1],
strides=[1, 4, 1, 1],
rates=[1, 1, 1, 1],
padding='VALID')
patches = tf.reshape(patches, [-1, 128, 60, 1])
y = repeat_so_that_there_is_one_label_per_patch(y)
return patches, y
dataset = (tf.data.Dataset.from_tensor_slices('tf_record_file_paths')
.shuffle(100)
.interleave(lambda x: tf.data.TFRecordDataset(x), cycle_length=4)
.map(parse_func)
.map(get_patches)
.unbatch()
.shuffle(100)
.repeat()
.batch(64, drop_remainder=True)
.prefetch(1))
我按如下方式解决了这个问题:
def repeat_so_that_there_is_one_label_per_patch(y, patches):
num_patches = tf.shape(patches)[0]
tiled_y = tf.tile(y, multiples=[num_patches])
return tf.reshape(tiled_y, tf.shape(y) * num_patches)
我正在尝试在 tensorflow
中编写数据集管道来标记图像补丁。现在我正在读取一堆 tfrecord 文件,其中每个文件都有多个补丁,但只有一个标签。标签有四个类.
当我通过管道传递单例时,Tensorflow 似乎不喜欢它。我收到以下错误:
ValueError: Value Tensor("args_1:0", shape=(), dtype=int32) has insufficient rank for batching.
我正在想办法让这个用例发挥作用。这基本上是我想要做的。我可以使用关于我应该对 y
做什么的建议,这样我就可以在管道末端为每个补丁获得一个标签。如果我需要更改 tfrecord 文件的结构,使 y
成为 onehot 编码向量,那很好;只是不知道有没有这个必要
def parse_func(proto):
features = tf.io.parse_single_example(
serialized=proto,
features={'X': tf.io.FixedLenFeature([], tf.string),
'length': tf.io.FixedLenFeature([], tf.int64),
'y': tf.io.FixedLenFeature([], tf.int64)})
y = tf.cast(features['y'], tf.int32) # this is just an integer, but maybe it should be a one-hot encoded vector
X = tf.io.decode_raw(features['X'], tf.float32)
length = tf.cast(features['length'], tf.int32)
shape = tf.stack([length, 60, 1])
return tf.reshape(X, shape), y
def get_patches(X, y):
X = X[tf.newaxis, ...]
patches = tf.image.extract_patches(X,
sizes=[1, 128, 60, 1],
strides=[1, 4, 1, 1],
rates=[1, 1, 1, 1],
padding='VALID')
patches = tf.reshape(patches, [-1, 128, 60, 1])
y = repeat_so_that_there_is_one_label_per_patch(y)
return patches, y
dataset = (tf.data.Dataset.from_tensor_slices('tf_record_file_paths')
.shuffle(100)
.interleave(lambda x: tf.data.TFRecordDataset(x), cycle_length=4)
.map(parse_func)
.map(get_patches)
.unbatch()
.shuffle(100)
.repeat()
.batch(64, drop_remainder=True)
.prefetch(1))
我按如下方式解决了这个问题:
def repeat_so_that_there_is_one_label_per_patch(y, patches):
num_patches = tf.shape(patches)[0]
tiled_y = tf.tile(y, multiples=[num_patches])
return tf.reshape(tiled_y, tf.shape(y) * num_patches)