Tensorflow 警告:提供给 MultiRNNCell 的两个单元格是同一个对象

Tensorflow warning: two cells provided to MultiRNNCell are the same object

我在执行 tensorflow 脚本时一直收到以下警告

WARNING:tensorflow:At least two cells provided to MultiRNNCell are the same object and will share weights.

lstm_layer=rnn.LSTMBlockCell(num_units,forget_bias=1)
lstm_layer=rnn.DropoutWrapper(lstm_layer, output_keep_prob=output_keep_prob)
stacked_lstm = rnn.MultiRNNCell([lstm_layer] * num_layers)
outputs,_=rnn.static_rnn(stacked_lstm,input,dtype="float32")

但是,有问题的 RNN 似乎 运行 很好,并且正在做出准确的预测。

警告消息有什么含义?可以安全地忽略它吗?如果可能很严重,如何评估其影响?

您使用 [lstm_layer] * num_layers 创建多个 RNN 层,这些层实际上引用 python 中的同一对象。这种用法在某些版本的tensorflow中是正常的,有些版本会报错。

正如警告所说,由于所有 RNN 层都是同一个对象,因此它们的权重将保持不变。所有错误都反馈到 RNN 层。相当于减少了模型的参数,降低了模型的复杂度。

如果你想创建多个不同的RNN层和复杂的模型,你可以使用下面的用法。这两种不同方法的有效性评估取决于具体的应用场景和结果。如果您的模型结果足够好,再复杂的模型也没有多大意义。

rnn_layers = []
for _ in range(num_layers):
    lstm_layer = rnn.LSTMBlockCell(num_units, forget_bias=1)
    lstm_layer = rnn.DropoutWrapper(lstm_layer, output_keep_prob=output_keep_prob)
    rnn_layers.append(lstm_layer)

stacked_lstm = rnn.MultiRNNCell(rnn_layers)