Return Keras 中 RNN 中跨时间步长的所有状态

Return all states across time-steps in RNN in Keras

我正在 Keras 中实现我自己的循环层,在 step 函数中我想访问所有时间步长的隐藏状态,而不仅仅是默认情况下的最后一个状态,所以我可以做一些事情,比如及时向后添加跳过连接。

我正在尝试将 tensorflow 后端 K.rnn 内的 _step 修改为 return 到目前为止的所有隐藏状态。我最初的想法是简单地将每个隐藏状态存储到 TensorArray,然后将所有这些传递给 step_function(即我层中的 step 函数)。我当前修改的函数如下,它将每个隐藏状态写入 TensorArray states_ta_t:

   def _step(time, output_ta_t, states_ta_t, *states):
            current_input = input_ta.read(time)
            # Here I'd like to return all states up to current time
            # and pass to step_function, instead of just the last
            states = [states_ta_t.read(time)]
            output, new_states = step_function(current_input,
                                               tuple(states) +
                                               tuple(constants))
            for state, new_state in zip(states, new_states):
                new_state.set_shape(state.get_shape())
            states_ta_t = states_ta_t.write(time+1, new_states[0]) # record states
            output_ta_t = output_ta_t.write(time, output)
            return (time + 1, output_ta_t, states_ta_t) + tuple(new_states) 

这个版本只是 return 的最后一个状态,就像最初的实现一样,并且作为一个普通的 RNN 工作。我怎样才能将所有状态存储在数组中,然后传递给 step_function?感觉这应该非常简单,但是我对TensorArrays的使用不是很精通...

(注意:这在展开版本中比在符号版本中更容易做到,但不幸的是,我会 运行 使用展开版本进行我的实验时内存不足)

-- 已编辑 --

我发现我误解了你的问题,对此我感到非常抱歉...

简而言之,试试这个:

states = states_ta_t.stack()[:time]

这里有一些解释:您确实将所有这些状态存储在 states_ta_t 中,但您只将最后一个传递给 step_function

您在代码中所做的是:

# Param 'time' refers to 'current time step'
states = [states_ta_t.read(time)]

这意味着您正在从 states_ta_t 读取 'current' 状态,换句话说,就是最后一个状态。

如果您想改为进行一些切片,也许 stack 函数会有所帮助。例如:

states = states_ta_t.stack()[:time]

但我不确定这是否是一个正确的实现,因为我也不熟悉 TensorArray...

希望对您有所帮助!如果没有,您愿意留下评论和我讨论是我的荣幸!