LSTM 的 model.reset_states 是否会影响模型中的任何其他非 LSTM 层?

Does model.reset_states for LSTM affect any other non-LSTM layers in the model?

我在 tf.keras 中使用 LSTM 的状态模式,在处理完序列数据后,我需要手动执行 reset_states,如 所述。似乎通常人们会这样做 model.reset_states(),但在我的例子中,我的 LSTM 层嵌入在一个复杂得多的网络中,该网络包括各种其他层,如 Dense、Conv 等。我的问题是,如果我只是在嵌入了 LSTM 的主模型上调用 model.reset_states()(并且只有一个 LSTM),我是否应该担心该重置会影响模型中的其他层,例如 Dense 或 Conv层?寻找 LSTM 层并将 reset_states 调用隔离到该层会更好吗?

TLDR:像LSTM/GRU这样的层有权重和状态,而像Conv/Dense/这样的层Embedding 只有权重。 reset_state() 仅影响具有状态的图层。

reset_states() 所做的是,对于 LSTM,它会重置层中的 c_th_t 输出。这些是您通常通过设置 LSTM(n, return_state=True) 获得的值。

EmbeddingDenseConv层没有这样的状态。所以 model.reset_states() 不会影响那些前馈层。只是像 LSTM 和 GRU 这样的顺序层。

如果您愿意,可以查看 source code 并验证此函数是否查看每个层是否具有 reset_state 属性(前馈层没有)。

任何具有可设置stateful属性的图层都受制于reset_states();该方法遍历每一层,检查它是否有 stateful=True - 如果有,调用它的 reset_states() 方法;参见 source

在 Keras 中,包括 ConvLSTM2D 在内的所有循环层都有一个可设置的 stateful 属性——我不知道有任何其他属性。 tensorflow.keras,但是,有很多自定义层实现可能;您可以使用以下代码进行确认:

def print_statefuls(model):
    for layer in model.layers:
        if hasattr(layer, 'reset_states') and getattr(layer, 'stateful', False):
            print(layer.name, "is stateful")