Tensorflow - 在 "batch-level" 而不是 "example-level" 处洗牌
Tensorflow - shuffling at "batch-level" instead of"example-level"
我有一个问题,我会尝试用一个例子来解释,以便于理解。
我想对橙子 (O) 和苹果 (A) 进行分类。由于 technical/legacy 个原因(网络中的一个组件),每个批次应该只有 O 或只有 A 个示例。因此,示例级别的传统洗牌不是 possible/adequate,因为我负担不起包含 O 和 A 示例混合的批处理。然而,某种洗牌是可取的,因为这是训练深度网络的常见做法。
这些是我采取的步骤:
- 我首先需要将原始 data/examples 转换为 TFRecords。
- 我打乱了原始示例的顺序,然后我创建了单独的 TFRecords,其中仅包含打乱后的 O 示例,或仅包含打乱后的 A 示例。我们称此为“example-level”洗牌。这是离线发生的事情,而且只发生一次。
- 此时我有 "clean batches":仅包含 O 个示例的 O-baches,以及仅包含 A 个示例的 A-batches。
- 我不想先向网络提供所有 O 批次,然后按顺序向网络提供所有 A 批次。这可能对收敛没有多大帮助。
- 我可以在“batch-level”上打乱这些批次,即不影响它们的内部吗?
如果您使用 Dataset
api 则相当简单。只需压缩 O
和 A
批次,然后使用 Dataset.map()
:
应用随机选择函数
ds0 = tf.data.Dataset.from_tensor_slices([0])
ds0 = ds0.repeat()
ds0 = ds0.batch(5)
ds1 = tf.data.Dataset.from_tensor_slices([1])
ds1 = ds1.repeat()
ds1 = ds1.batch(5)
def rand_select(ds0, ds1):
rval = tf.random_uniform([])
return tf.cond(rval<0.5, lambda: ds0, lambda: ds1)
dataset = tf.data.Dataset()
dataset = dataset.zip((ds0, ds1)).map(lambda ds0, ds1: rand_select(ds0, ds1))
iterator = dataset.make_one_shot_iterator()
ds = iterator.get_next()
with tf.Session() as sess:
for _ in range(5):
print(sess.run(ds))
> [0 0 0 0 0]
[1 1 1 1 1]
[1 1 1 1 1]
[0 0 0 0 0]
[0 0 0 0 0]
我有一个问题,我会尝试用一个例子来解释,以便于理解。
我想对橙子 (O) 和苹果 (A) 进行分类。由于 technical/legacy 个原因(网络中的一个组件),每个批次应该只有 O 或只有 A 个示例。因此,示例级别的传统洗牌不是 possible/adequate,因为我负担不起包含 O 和 A 示例混合的批处理。然而,某种洗牌是可取的,因为这是训练深度网络的常见做法。
这些是我采取的步骤:
- 我首先需要将原始 data/examples 转换为 TFRecords。
- 我打乱了原始示例的顺序,然后我创建了单独的 TFRecords,其中仅包含打乱后的 O 示例,或仅包含打乱后的 A 示例。我们称此为“example-level”洗牌。这是离线发生的事情,而且只发生一次。
- 此时我有 "clean batches":仅包含 O 个示例的 O-baches,以及仅包含 A 个示例的 A-batches。
- 我不想先向网络提供所有 O 批次,然后按顺序向网络提供所有 A 批次。这可能对收敛没有多大帮助。
- 我可以在“batch-level”上打乱这些批次,即不影响它们的内部吗?
如果您使用 Dataset
api 则相当简单。只需压缩 O
和 A
批次,然后使用 Dataset.map()
:
ds0 = tf.data.Dataset.from_tensor_slices([0])
ds0 = ds0.repeat()
ds0 = ds0.batch(5)
ds1 = tf.data.Dataset.from_tensor_slices([1])
ds1 = ds1.repeat()
ds1 = ds1.batch(5)
def rand_select(ds0, ds1):
rval = tf.random_uniform([])
return tf.cond(rval<0.5, lambda: ds0, lambda: ds1)
dataset = tf.data.Dataset()
dataset = dataset.zip((ds0, ds1)).map(lambda ds0, ds1: rand_select(ds0, ds1))
iterator = dataset.make_one_shot_iterator()
ds = iterator.get_next()
with tf.Session() as sess:
for _ in range(5):
print(sess.run(ds))
> [0 0 0 0 0]
[1 1 1 1 1]
[1 1 1 1 1]
[0 0 0 0 0]
[0 0 0 0 0]