如何在 Tensorflow 2.0 中管理队列?
How can I manage Queues in Tensorflow 2.0?
好吧,我正在尝试了解线程和队列。
在网上看了很多文档,没想到在tensorflow 2.0中竟然连一个这样的例子都没有
我想让我的队列做的是,
- 定义生成示例的操作。
- 定义队列。
- 定义一个 enqueue_operation 将示例放入上面使用多个线程创建的队列中。
- 控制此队列使批次出队。
我的想法是,
import tensorflow as tf
import threading
batch_size = 2
example = tf.random.normal([1, 2]) # Generate an example, shape = [1, 2]
queue = tf.queue.RandomShuffleQueue(capacity=10, min_after_dequeue=0, \
dyptes=tf.float32, shapes=[1, 2])
enqueue_op = queue.enqueue(example)
# inputs = queue.dequeue(2) # Don't run this. This would stop your computer.
我不知道我在做什么。我还了解到如何使用 tf.train.Coordinator()
管理多个线程,但我不知道在哪里使用它..
在问这个问题时,我怀疑 tf.data.Dataset
中的许多 API 替换了所有这些并且多线程可以替换为 tf.data.experimental.AUTOTUNE
。
很抱歉这里弄得一团糟。即使在询问期间我也无法正确安排。
任何意见将不胜感激。提前致谢。
我认为首选方法是使用 tf.data.Dataset
API。你可以关注这个link。我还将重点介绍可帮助您为批处理实现多线程的重要代码。
dataset = Dataset.range(1, 6) # ==> [ 1, 2, 3, 4, 5 ]
dataset = dataset.map(lambda x: x + 1,
num_parallel_calls=tf.data.AUTOTUNE,
deterministic=False)
此外,link 还记录了以下内容:
Performance can often be improved by setting num_parallel_calls so
that map will use multiple threads to process elements. If
deterministic order isn't required, it can also improve performance to
set deterministic=False.
我认为你应该遵循这个 API。
同时查看 cache
和 prefetch
API,这优化了输入管道。
好吧,我正在尝试了解线程和队列。
在网上看了很多文档,没想到在tensorflow 2.0中竟然连一个这样的例子都没有
我想让我的队列做的是,
- 定义生成示例的操作。
- 定义队列。
- 定义一个 enqueue_operation 将示例放入上面使用多个线程创建的队列中。
- 控制此队列使批次出队。
我的想法是,
import tensorflow as tf
import threading
batch_size = 2
example = tf.random.normal([1, 2]) # Generate an example, shape = [1, 2]
queue = tf.queue.RandomShuffleQueue(capacity=10, min_after_dequeue=0, \
dyptes=tf.float32, shapes=[1, 2])
enqueue_op = queue.enqueue(example)
# inputs = queue.dequeue(2) # Don't run this. This would stop your computer.
我不知道我在做什么。我还了解到如何使用 tf.train.Coordinator()
管理多个线程,但我不知道在哪里使用它..
在问这个问题时,我怀疑 tf.data.Dataset
中的许多 API 替换了所有这些并且多线程可以替换为 tf.data.experimental.AUTOTUNE
。
很抱歉这里弄得一团糟。即使在询问期间我也无法正确安排。
任何意见将不胜感激。提前致谢。
我认为首选方法是使用 tf.data.Dataset
API。你可以关注这个link。我还将重点介绍可帮助您为批处理实现多线程的重要代码。
dataset = Dataset.range(1, 6) # ==> [ 1, 2, 3, 4, 5 ]
dataset = dataset.map(lambda x: x + 1,
num_parallel_calls=tf.data.AUTOTUNE,
deterministic=False)
此外,link 还记录了以下内容:
Performance can often be improved by setting num_parallel_calls so that map will use multiple threads to process elements. If deterministic order isn't required, it can also improve performance to set deterministic=False.
我认为你应该遵循这个 API。
同时查看 cache
和 prefetch
API,这优化了输入管道。