TensorFlow:如果我在图形中处理二次计算,是否会消耗更多内存?
TensorFlow: Will more memory be consumed if I handle secondary computations within a graph?
如果我有一个来自 Google (inception-resnet-v2) 的经过训练的 Imagnet 模型,并且该模型实际上返回了两个输出:logits
和一个名为 [=22= 的列表] 我可以在其中提取已使用 softmax 激活执行的最终预测层,即名为 predictions
的变量。但是,这并没有明确地给我预测所需的 class 标签。为此,我要么必须在 label = tf.argmax(predictions, 1)
之后执行 我在图中定义了 train_op,这样我就不会影响原始计算。
或者,我可以使用从图中计算出来的 np.argmax(sess.run(predictions), 1)
。
我的问题是,如果我选择第一种方法,它是否会消耗更多内存并影响我的计算(就我可以使用的 batch_size 而言)?仅从图中计算必要的标签是否更安全更好?
当您发出多个 .run
调用时,图形定义会被缓存。如果您修改 Graph,它需要 re-encode 它并再次发送。因此,在您第一次 运行 修改图表时,graph_def.SerializeToString
可能会使用一些额外的内存,但这应该不会影响之后的 .run
步。
相关逻辑在session.py,注意检查self._graph.version > self._current_version
的那一行
def _extend_graph(self):
# Ensure any changes to the graph are reflected in the runtime.
with self._extend_lock:
if self._graph.version > self._current_version:
# pylint: disable=protected-access
graph_def, self._current_version = self._graph._as_graph_def(
from_version=self._current_version,
add_shapes=self._add_shapes)
# pylint: enable=protected-access
with errors.raise_exception_on_not_ok_status() as status:
tf_session.TF_ExtendGraph(
self._session, graph_def.SerializeToString(), status)
self._opened = True
如果我有一个来自 Google (inception-resnet-v2) 的经过训练的 Imagnet 模型,并且该模型实际上返回了两个输出:logits
和一个名为 [=22= 的列表] 我可以在其中提取已使用 softmax 激活执行的最终预测层,即名为 predictions
的变量。但是,这并没有明确地给我预测所需的 class 标签。为此,我要么必须在 label = tf.argmax(predictions, 1)
之后执行 我在图中定义了 train_op,这样我就不会影响原始计算。
或者,我可以使用从图中计算出来的 np.argmax(sess.run(predictions), 1)
。
我的问题是,如果我选择第一种方法,它是否会消耗更多内存并影响我的计算(就我可以使用的 batch_size 而言)?仅从图中计算必要的标签是否更安全更好?
当您发出多个 .run
调用时,图形定义会被缓存。如果您修改 Graph,它需要 re-encode 它并再次发送。因此,在您第一次 运行 修改图表时,graph_def.SerializeToString
可能会使用一些额外的内存,但这应该不会影响之后的 .run
步。
相关逻辑在session.py,注意检查self._graph.version > self._current_version
def _extend_graph(self):
# Ensure any changes to the graph are reflected in the runtime.
with self._extend_lock:
if self._graph.version > self._current_version:
# pylint: disable=protected-access
graph_def, self._current_version = self._graph._as_graph_def(
from_version=self._current_version,
add_shapes=self._add_shapes)
# pylint: enable=protected-access
with errors.raise_exception_on_not_ok_status() as status:
tf_session.TF_ExtendGraph(
self._session, graph_def.SerializeToString(), status)
self._opened = True