TensorFlow 中 graph collections 的用途是什么?

What is the purpose of graph collections in TensorFlow?

API 讨论 Graph Collections which judging from the code 是通用 key/data 存储。那些collections的目的是什么?

请记住,在幕后,Tensorflow 是一个用于指定然后执行计算数据流图的系统。图集合用作跟踪构建的图以及必须如何执行它们的一部分。例如,当您创建某些类型的操作时,例如 tf.train.batch_join,添加操作的代码也会将一些队列运行器添加到 QUEUE_RUNNERS 图形集合中。稍后,当您调用 start_queue_runners() 时,默认情况下,它会查看 QUEUE_RUNNERS 集合以了解要启动哪些跑步者。

到目前为止,我认为至少有两个好处:

  1. 当您将程序分发到多个 GPU 或机器上时,可以方便地从同一集合中的不同设备收集损失。使用 tf.add_n 将它们相加以累积损失。
  2. 以我自己的方式更新一组特定的变量,如权重和偏差。

例如:

import tensorflow as tf    
w = tf.Variable([1,2,3], collections=[tf.GraphKeys.WEIGHTS], dtype=tf.float32)    
w2 = tf.Variable([11,22,32], collections=[tf.GraphKeys.WEIGHTS], dtype=tf.float32)
weight_init_op = tf.variables_initializer(tf.get_collection_ref(tf.GraphKeys.WEIGHTS))
sess = tf.InteractiveSession()
weight_init_op.run()
for vari in tf.get_collection_ref(tf.GraphKeys.WEIGHTS): 
    tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, vari.assign(0.2 * vari))
weight_update_ops = tf.get_collection_ref(tf.GraphKeys.UPDATE_OPS)
for op in weight_update_ops:
    print(op.eval())

输出:

[0.2 0.4 0.6]
[2.2 4.4 6.4]