如何用元组初始化 LSTMCell

How to Initialize LSTMCell with tuple

我最近将我的 tesnorflow 从 Rev8 升级到了 Rev12。在 Rev8 中,rnn_cell.LSTMCell 中的默认 "state_is_tuple" 标志设置为 False,所以我用一个列表初始化了我的 LSTM Cell,请参见下面的代码。

#model definition  
lstm_cell = rnn_cell.LSTMCell(self.config.hidden_dim)
outputs, states = tf.nn.rnn(lstm_cell, data, initial_state=self.init_state)


#init_state place holder and feed_dict
def add_placeholders(self):
     self.init_state = tf.placeholder("float", [None, self.cell_size])

def get_feed_dict(self, data, label):
    feed_dict = {self.input_data: data,
             self.input_label: reg_label,
             self.init_state: np.zeros((self.config.batch_size, self.cell_size))}
    return feed_dict

在 Rev12 中,默认的 "state_is_tuple" 标志设置为 True,为了使我的旧代码正常工作,我必须明确地将标志设置为 False。但是,现在我收到来自 tensorflow 的警告:

"Using a concatenated state is slower and will soon be deprecated. Use state_is_tuple=True"

我尝试通过将 self.init_state 的占位符定义更改为以下内容来使用元组初始化 LSTM 单元:

self.init_state = tf.placeholder("float", (None, self.cell_size))

但现在我收到一条错误消息:

"'Tensor' object is not iterable"

有谁知道如何进行这项工作?

现在使用 cell.zero_state 将 "zero state" 馈送到 LSTM 更简单。您不需要将初始状态明确定义为占位符。将其定义为张量,并在需要时提供。这就是它的工作原理,

lstm_cell = rnn_cell.LSTMCell(self.config.hidden_dim)
self.initial_state = lstm_cell.zero_state(self.batch_size, dtype=tf.float32)
outputs, states = tf.nn.rnn(lstm_cell, data, initial_state=self.init_state)

如果你想输入一些其他值作为初始状态,比如 next_state = states[-1],在你的会话中计算它并在 feed_dict 中传递它,比如 -

feed_dict[self.initial_state] = next_state

根据您的问题,lstm_cell.zero_state() 就足够了。


无关,但请记住,您可以在提要字典中同时传递张量和占位符!这就是 self.initial_state 在上面的示例中的工作方式。查看 PTB Tutorial 中的工作示例。