在 tensorflow v2 的图中使用 tf.timestamp()

Using tf.timestamp() in the graph in tensorflow v2

我正在尝试对我的 tf 模型中的特定块进行基准测试。因此,我正在尝试使用 tf.timestamp()。我需要将它包含在图形执行中,以便每次调用模型时都会执行它。

我实际上可以通过使用 tf.compat.v1.Print() 来做到这一点,

x = self.mp1(x)
x = tf.compat.v1.Print(x, [tf.timestamp()])
x = self.c3(x)

但是这是在打印值,这会导致一些开销。相反,我想将它存储到某个变量中,以便我可以在执行后使用它。有没有其他方法可以将 tf.timestamp() 嵌入到 tf2.

的图形中

如果您使用此图进行优化(例如作为深度学习模型), 你可能会 运行 model.fit() 函数。 然后你可以使用 callback 函数在每个纪元后保存。

import tensorflow as tf
from tf.keras.callbacks import ModelCheckpoint
EPOCHS = 10
checkpoint_filepath = '/tmp/checkpoint'
model_checkpoint_callback = ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=True,
    monitor='val_acc',
    mode='max',
    save_best_only=True)
# Model weights are saved at the end of every epoch, if it's the best seen
# so far.
model.fit(epochs=EPOCHS, callbacks=[model_checkpoint_callback])
# The model weights (that are considered the best) are loaded into the model.
model.load_weights(checkpoint_filepath)

我尝试了几件事后解决了这个问题。我所做的是:

  • 将自定义层定义为
class TimeStamp(layers.Layer):
    def __init__(self):
        super(TimeStamp, self).__init__()
        self.ts = tf.Variable(initial_value=0., dtype=tf.float64, trainable=False)

    def call(self, inputs):
        self.ts = tf.timestamp()
        return tf.identity(inputs)

    def getTs(self):
        return self.ts

  • 然后在模型中多次使用它,然后通过减去这些 self.ts 值得到经过的时间。