在 tensorflow 中使用 monitoredtrainingsession 获取运行时统计信息
getting runtime statistics with monitoredtrainingsession in tensorflow
我正在尝试按照 运行 时间统计指令 here 获取我的 tensorflow 代码配置文件(运行网络中每一层的宁和内存消耗)。据我了解,我需要像这样创建 运行 选项和 运行 元数据
run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
run_metadata = tf.RunMetadata()
并将它们传递给 sess.run
但是,因为我也在尝试使用 tf.train.MonitoredTrainingSession
,所以我不知道我是否可以将同样的东西传递给这个 class。一种合理的方法可以使用 Hooks,但我不知道该怎么做。我对他们还是很陌生
您可以简单地创建一个自定义挂钩并将其传递给 MonitoredTrainingSession
。无需将您自己的 tf.RunMetadata()
实例传递给 运行 调用。
这是一个 Hook 示例,它每 N 步将元数据存储到 ckptdir:
import tensorflow as tf
class TraceHook(tf.train.SessionRunHook):
"""Hook to perform Traces every N steps."""
def __init__(self, ckptdir, every_step=50, trace_level=tf.RunOptions.FULL_TRACE):
self._trace = every_step == 1
self.writer = tf.summary.FileWriter(ckptdir)
self.trace_level = trace_level
self.every_step = every_step
def begin(self):
self._global_step_tensor = tf.train.get_global_step()
if self._global_step_tensor is None:
raise RuntimeError("Global step should be created to use _TraceHook.")
def before_run(self, run_context):
if self._trace:
options = tf.RunOptions(trace_level=self.trace_level)
else:
options = None
return tf.train.SessionRunArgs(fetches=self._global_step_tensor,
options=options)
def after_run(self, run_context, run_values):
global_step = run_values.results - 1
if self._trace:
self._trace = False
self.writer.add_run_metadata(run_values.run_metadata,
f'{global_step}', global_step)
if not (global_step + 1) % self.every_step:
self._trace = True
它检查 before_run
是否必须跟踪,如果是,则添加 RunOptions。在 after_run
中,它检查是否需要跟踪下一个 运行 调用,如果需要,它再次将 _trace
设置为 True。此外,它会在可用时存储元数据。
我正在尝试按照 运行 时间统计指令 here 获取我的 tensorflow 代码配置文件(运行网络中每一层的宁和内存消耗)。据我了解,我需要像这样创建 运行 选项和 运行 元数据
run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
run_metadata = tf.RunMetadata()
并将它们传递给 sess.run
但是,因为我也在尝试使用 tf.train.MonitoredTrainingSession
,所以我不知道我是否可以将同样的东西传递给这个 class。一种合理的方法可以使用 Hooks,但我不知道该怎么做。我对他们还是很陌生
您可以简单地创建一个自定义挂钩并将其传递给 MonitoredTrainingSession
。无需将您自己的 tf.RunMetadata()
实例传递给 运行 调用。
这是一个 Hook 示例,它每 N 步将元数据存储到 ckptdir:
import tensorflow as tf
class TraceHook(tf.train.SessionRunHook):
"""Hook to perform Traces every N steps."""
def __init__(self, ckptdir, every_step=50, trace_level=tf.RunOptions.FULL_TRACE):
self._trace = every_step == 1
self.writer = tf.summary.FileWriter(ckptdir)
self.trace_level = trace_level
self.every_step = every_step
def begin(self):
self._global_step_tensor = tf.train.get_global_step()
if self._global_step_tensor is None:
raise RuntimeError("Global step should be created to use _TraceHook.")
def before_run(self, run_context):
if self._trace:
options = tf.RunOptions(trace_level=self.trace_level)
else:
options = None
return tf.train.SessionRunArgs(fetches=self._global_step_tensor,
options=options)
def after_run(self, run_context, run_values):
global_step = run_values.results - 1
if self._trace:
self._trace = False
self.writer.add_run_metadata(run_values.run_metadata,
f'{global_step}', global_step)
if not (global_step + 1) % self.every_step:
self._trace = True
它检查 before_run
是否必须跟踪,如果是,则添加 RunOptions。在 after_run
中,它检查是否需要跟踪下一个 运行 调用,如果需要,它再次将 _trace
设置为 True。此外,它会在可用时存储元数据。