在 for 循环中使用 tf.concat 添加值是否慢?

Is adding values with tf.concat slow in for-loops?

我正在使用 tensorflow 2.0 并尝试通过稍微优化我的代码来加快我的训练速度。

我 运行 我的模型是分批处理的,并希望保护每批的结果,以便在一个张量中的一个纪元结束时获得所有结果。

我的代码是这样的:

...
for epoch in range(start_epoch, end_epoch):

    # this vector shall hold all results for one epoch
    predictions_epoch = tf.zeros(0,)
   
    for batch in tf_dataset: 
        # get prediction with predictions_batch.shape[0] euqals batch_size
        predictions_batch = model(batch)   
        
        # Add the batch result to the previous results
        predictions_epoch = tf.concat(predictions_batch, predictions_epoch)
        
        # DO SOME OTHER STUFF LIKE BACKPROB
        ...

    # predictions_epoch.shape[0] now equals number of all samples in dataset
    with writer.as_default():
        tf.summary.histogram(name='predictions', data=predictions_epoch, step=epoch)

让我们假设,一个预测只是一个标量值。所以 predictions_batch 是一个 shape=[batchsize,].

的张量

这种串联的方式效果很好。

现在我的问题是: 这个 tf.concat() 操作会减慢我的整个训练速度吗?我也用了tf.stack()这个目的,不过速度好像没什么区别。

我想知道,因为一旦我使用 Matlab,在 for 循环中向 Vector 添加新值(并因此更改其大小)非常慢。用零初始化向量然后在循环中赋值在速度方面效率更高。

tensorflow 也是这样吗?或者是否有另一种更多 'proper' 的方法来做一些事情,比如在 for 循环中将张量加在一起,这样更干净或更快? 我没有在网上找到任何替代解决方案。

感谢您的帮助。

是的,这不是最值得推荐的方法。最好将每个张量简单地添加到一个列表中,并在最后将它们连接一次:

for epoch in range(start_epoch, end_epoch):
    predictions_batches = []
    for batch in tf_dataset:
        predictions_batch = model(batch)
        predictions_batches.append(predictions_batch)
        # ...
    predictions_epoch = tf.concat(predictions_batches)

您也可以使用 tf.TensorArray, which may be better if you want to decorate the code with tf.function

for epoch in range(start_epoch, end_epoch):
    # Pass arguments as required
    # If the number of batches is know or an upper bound
    # can be estimated use that and dynamic_size=False
    predictions_batches = tf.TensorArray(
        tf.float32, INTIAL_SIZE, dynamic_size=True, element_shape=[BATCH_SIZE])
    i = tf.constant(0)
    for batch in tf_dataset:
        predictions_batch = model(batch)
        predictions_batches = predictions_batches.write(i, predictions_batch)
        i += 1
        # ...
    predictions_epoch = predictions_batches.concat()