RNN Cell not present in tf.get_collection
When using tf.get_collection(), the RNN cell does not show up. What am I missing?
import tensorflow as tf
print(tf.__version__)
rnn_cell = tf.nn.rnn_cell.LSTMCell(16)
print(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))
other_var = tf.Variable(0)
print(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))
This prints:
0.12.0
[]
[<tensorflow.python.ops.variables.Variable object at 0x0000027961250B70>]
Windows 10, Python 3.5
You haven't run __call__ on the LSTMCell, which is why you don't see its variables. Try this (I'm assuming batch_size=10 and rnn_size=16):
import tensorflow as tf
print(tf.__version__)
rnn_cell = tf.nn.rnn_cell.LSTMCell(16)
a = tf.placeholder(tf.float32, [10, 16])
zero = rnn_cell.zero_state(10,tf.float32)
# The variables are created in the following __call__
b = rnn_cell(a, zero)
print(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))
other_var = tf.Variable(0)
print(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))
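A minimal pure-Python sketch of the lazy-creation pattern described above, assuming no TensorFlow at all. GLOBAL_VARIABLES, Variable and LazyCell are hypothetical stand-ins, and the names "kernel" and "bias" are illustrative, not the names TensorFlow assigns. It shows why the collection stays empty after construction. One plausible reason for this design: the weight shape depends on the input depth, which is only known once the cell is called on real inputs.

```python
# Hypothetical stand-in for tf.GraphKeys.GLOBAL_VARIABLES: a plain list.
GLOBAL_VARIABLES = []


class Variable:
    """Toy variable that registers itself on creation, as tf.Variable does."""

    def __init__(self, name, shape):
        self.name = name
        self.shape = shape
        GLOBAL_VARIABLES.append(self)


class LazyCell:
    """Hypothetical cell that defers creating weights until its first call."""

    def __init__(self, num_units):
        self.num_units = num_units
        self.kernel = None  # input depth is unknown at construction time
        self.bias = None

    def zero_state(self, batch_size):
        # One all-zero row per batch element for both c and h, like an LSTM state tuple.
        c = [[0.0] * self.num_units for _ in range(batch_size)]
        h = [[0.0] * self.num_units for _ in range(batch_size)]
        return (c, h)

    def __call__(self, inputs, state):
        if self.kernel is None:
            # The kernel shape depends on the input depth, so it is built here.
            input_depth = len(inputs[0])
            self.kernel = Variable("kernel", (input_depth + self.num_units, 4 * self.num_units))
            self.bias = Variable("bias", (4 * self.num_units,))
        return state  # the LSTM arithmetic itself is omitted


cell = LazyCell(16)
print(len(GLOBAL_VARIABLES))  # 0: constructing the cell creates nothing
inputs = [[0.0] * 16 for _ in range(10)]  # batch_size=10, input depth 16
cell(inputs, cell.zero_state(10))
print([v.name for v in GLOBAL_VARIABLES])  # ['kernel', 'bias']
cell(inputs, cell.zero_state(10))
print(len(GLOBAL_VARIABLES))  # 2: later calls reuse the same variables
```

The second call adds nothing, which mirrors how the first call is where the variables appear in the collection.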