TensorFlow:将 tf.Dataset 转换为 tf.Tensor
TensorFlow: convert tf.Dataset to tf.Tensor
我想生成范围为 10 的 windows:
import tensorflow as tf
dataset = tf.data.Dataset.from_tensor_slices(tf.range(10))
dataset = dataset.window(5, shift=1, drop_remainder=True)
并想在此数据集上训练我的模型。
为此,必须将那些 windows 转换为张量。但是这些 windows 的数据类型不能通过 tf.convert_to_tensor
转换为张量。 tf.convert_to_tensor(list(window))
是可以的,但是效率很低。
有谁知道如何有效地将 tf.VariantDataset
转换为 tf.Tensor
?
感谢您的帮助!
如果你想创建一个滑动张量 windows,通过数据集来做并不是最好的方法,效率和灵活性要低得多。我认为没有合适的操作,但是对于 2D 和 3D 数组有两个相似的操作,tf.image.extract_patches
and tf.extract_volume_patches
。您可以重塑一维数据以使用它们:
import tensorflow as tf
a = tf.range(10)
win_size = 5
stride = 1
# Option 1
a_win = tf.image.extract_patches(tf.reshape(a, [1, -1, 1, 1]),
sizes=[1, win_size, 1, 1],
strides=[1, stride, 1, 1],
rates=[1, 1, 1, 1],
padding='VALID')[0, :, 0]
# Option 2
a_win = tf.extract_volume_patches(tf.reshape(a, [1, -1, 1, 1, 1]),
ksizes=[1, win_size, 1, 1, 1],
strides=[1, stride, 1, 1, 1],
padding='VALID')[0, :, 0, 0]
# Print result
print(a_win.numpy())
# [[0 1 2 3 4]
# [1 2 3 4 5]
# [2 3 4 5 6]
# [3 4 5 6 7]
# [4 5 6 7 8]
# [5 6 7 8 9]]
我想生成范围为 10 的 windows:
import tensorflow as tf
dataset = tf.data.Dataset.from_tensor_slices(tf.range(10))
dataset = dataset.window(5, shift=1, drop_remainder=True)
并想在此数据集上训练我的模型。
为此,必须将那些 windows 转换为张量。但是这些 windows 的数据类型不能通过 tf.convert_to_tensor
转换为张量。 tf.convert_to_tensor(list(window))
是可以的,但是效率很低。
有谁知道如何有效地将 tf.VariantDataset
转换为 tf.Tensor
?
感谢您的帮助!
如果你想创建一个滑动张量 windows,通过数据集来做并不是最好的方法,效率和灵活性要低得多。我认为没有合适的操作,但是对于 2D 和 3D 数组有两个相似的操作,tf.image.extract_patches
and tf.extract_volume_patches
。您可以重塑一维数据以使用它们:
import tensorflow as tf
a = tf.range(10)
win_size = 5
stride = 1
# Option 1
a_win = tf.image.extract_patches(tf.reshape(a, [1, -1, 1, 1]),
sizes=[1, win_size, 1, 1],
strides=[1, stride, 1, 1],
rates=[1, 1, 1, 1],
padding='VALID')[0, :, 0]
# Option 2
a_win = tf.extract_volume_patches(tf.reshape(a, [1, -1, 1, 1, 1]),
ksizes=[1, win_size, 1, 1, 1],
strides=[1, stride, 1, 1, 1],
padding='VALID')[0, :, 0, 0]
# Print result
print(a_win.numpy())
# [[0 1 2 3 4]
# [1 2 3 4 5]
# [2 3 4 5 6]
# [3 4 5 6 7]
# [4 5 6 7 8]
# [5 6 7 8 9]]