Tensorflow RNN 示例限制为固定批量大小?

Tensorflow RNN example limited to fixed batch size?

在 Tensorflow 上查看 RNN example 时,我对初始状态的构造方式有疑问。在图的构建时,我们将图限制为仅处理一个批次大小的输入。这对我来说是个问题,因为我希望能够输入单个示例并获得对该单个示例的预测。

限制这个的代码部分是:

initial_state = state = tf.zeros([batch_size, lstm.state_size])

所以我的问题是如何扩展示例,以便我可以使用可变批量大小,以便我可以使用相同的模型进行批量大小的训练,然后使用单个示例进行预测?

我就是这样做的。您可以像这样将 batch_size 作为变量传递:

batch_size = tf.placeholder(tf.int32)
init_state = cell.zero_state(batch_size, tf.float32)

其中 cell 是 RNN 单元之一(BasicLSTMCellBasicGRUCellMultiRNNCell 等)。但是,如果您在多个批次中保留状态,则无法正常工作,因为它的大小必须保持不变。

Tensorflow 文本生成教程解释了如何执行此操作(现在是 TF 2.0)。 batch_size 似乎成为构建模型的一部分,因此您必须使用新的批量大小从保存的权重中 rebuild/reload:

https://www.tensorflow.org/tutorials/text/text_generation#restore_the_latest_checkpoint

To keep this prediction step simple, use a batch size of 1.

Because of the way the RNN state is passed from timestep to timestep, the model only accepts a fixed batch size once built.

To run the model with a different batch_size, we need to rebuild the model and restore the weights from the checkpoint.

model = build_model(vocab_size, embedding_dim, rnn_units, batch_size=1)
model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))
model.build(tf.TensorShape([1, None]))
model.summary()

我不确定你为什么必须这样做,但我一直认为这是因为循环层的批处理需要管理多个并行的隐藏状态管道,所以它会预先分配它们。