张量流 2 中的 ConvLSTMCell

ConvLSTMCell in tensorflow 2

从 1 升级到 tensorflow 版本 2 后,tf.contrib 中的所有模块都已折旧。

为了应用attention method,我需要每个细胞的状态。

最初,我在 tf 版本 1 中所做的是:


#ConvLSTMCell
convlstm_layer = tf.contrib.rnn.ConvLSTMCell(
                conv_ndims = 2,    
                input_shape = [10, 10, 32],
                output_channels = 32,
                kernel_shape = [2, 2],
                use_bias = True,
                skip_connection = False,
                forget_bias = 1.0,
                initializers = None,
                )


# Run RNN with ConvLSTMCell
outputs, state = tf.compat.v1.nn.dynamic_rnn(convlstm_layer, conv1_out, time_major = False, dtype = input.dtype)

现在,我正在尝试将其转换为 tf 版本 2 中的代码。

但是,正如我上面提到的,两个模块(tf.contrib 和 tf.compat)都已折旧。

我找到了 tf.compat.v1.nn.dynamic_rnn 的替代方法 tf.keras.layers.rnn

但是没有创建 ConvLSTMCell 的函数。有什么建议吗?

我想你要找的是这里:https://www.tensorflow.org/api_docs/python/tf/keras/layers/ConvLSTM2D?version=stable

您可以在您的代码中导入它,例如:

import tensorflow as tf
conv_lstm_layer = tf.keras.layers.ConvLSTM2D(my_parameters)