LSTM 的 model.reset_states 是否会影响模型中的任何其他非 LSTM 层?
Does model.reset_states for LSTM affect any other non-LSTM layers in the model?
我在 tf.keras
中使用 LSTM 的状态模式,在处理完序列数据后,我需要手动执行 reset_states
,如 所述。似乎通常人们会这样做 model.reset_states()
,但在我的例子中,我的 LSTM 层嵌入在一个复杂得多的网络中,该网络包括各种其他层,如 Dense、Conv 等。我的问题是,如果我只是在嵌入了 LSTM 的主模型上调用 model.reset_states()
(并且只有一个 LSTM),我是否应该担心该重置会影响模型中的其他层,例如 Dense 或 Conv层?寻找 LSTM 层并将 reset_states
调用隔离到该层会更好吗?
TLDR:像LSTM
/GRU
这样的层有权重和状态,而像Conv
/Dense
/这样的层Embedding
只有权重。 reset_state()
仅影响具有状态的图层。
reset_states()
所做的是,对于 LSTM,它会重置层中的 c_t
和 h_t
输出。这些是您通常通过设置 LSTM(n, return_state=True)
获得的值。
Embedding
、Dense
、Conv
层没有这样的状态。所以 model.reset_states()
不会影响那些前馈层。只是像 LSTM 和 GRU 这样的顺序层。
如果您愿意,可以查看 source code 并验证此函数是否查看每个层是否具有 reset_state
属性(前馈层没有)。
任何具有可设置stateful
属性的图层都受制于reset_states()
;该方法遍历每一层,检查它是否有 stateful=True
- 如果有,调用它的 reset_states()
方法;参见 source。
在 Keras 中,包括 ConvLSTM2D 在内的所有循环层都有一个可设置的 stateful
属性——我不知道有任何其他属性。 tensorflow.keras
,但是,有很多自定义层实现可能;您可以使用以下代码进行确认:
def print_statefuls(model):
for layer in model.layers:
if hasattr(layer, 'reset_states') and getattr(layer, 'stateful', False):
print(layer.name, "is stateful")
我在 tf.keras
中使用 LSTM 的状态模式,在处理完序列数据后,我需要手动执行 reset_states
,如 model.reset_states()
,但在我的例子中,我的 LSTM 层嵌入在一个复杂得多的网络中,该网络包括各种其他层,如 Dense、Conv 等。我的问题是,如果我只是在嵌入了 LSTM 的主模型上调用 model.reset_states()
(并且只有一个 LSTM),我是否应该担心该重置会影响模型中的其他层,例如 Dense 或 Conv层?寻找 LSTM 层并将 reset_states
调用隔离到该层会更好吗?
TLDR:像LSTM
/GRU
这样的层有权重和状态,而像Conv
/Dense
/这样的层Embedding
只有权重。 reset_state()
仅影响具有状态的图层。
reset_states()
所做的是,对于 LSTM,它会重置层中的 c_t
和 h_t
输出。这些是您通常通过设置 LSTM(n, return_state=True)
获得的值。
Embedding
、Dense
、Conv
层没有这样的状态。所以 model.reset_states()
不会影响那些前馈层。只是像 LSTM 和 GRU 这样的顺序层。
如果您愿意,可以查看 source code 并验证此函数是否查看每个层是否具有 reset_state
属性(前馈层没有)。
任何具有可设置stateful
属性的图层都受制于reset_states()
;该方法遍历每一层,检查它是否有 stateful=True
- 如果有,调用它的 reset_states()
方法;参见 source。
在 Keras 中,包括 ConvLSTM2D 在内的所有循环层都有一个可设置的 stateful
属性——我不知道有任何其他属性。 tensorflow.keras
,但是,有很多自定义层实现可能;您可以使用以下代码进行确认:
def print_statefuls(model):
for layer in model.layers:
if hasattr(layer, 'reset_states') and getattr(layer, 'stateful', False):
print(layer.name, "is stateful")