调用函数中的变量 batch_size

Variable batch_size in call function

我正在尝试使用 TensorFlow 2 实现注意力网络。因此,对于每张图像,我只想瞥一眼,即图像的一小部分。为此,我实现了 tensorflow.keras.models.Model 的一个子类,这里是其中的一个片段。

class RecurrentAttentionModel(models.Model):
# ...

def call(self, inputs):

    l = tf.random.uniform((40,2,), minval=0, maxval=1)

    for _ in range(0, self.glimpses):
        glimpse = tf.image.extract_glimpse(inputs, size=(self.retina_size, self.retina_size), offsets=l, centered=False, normalized=True)

        # some other code...
        # update l to take a glimpse somewhere else


    return result           

现在,上面的代码可以完美运行和训练,但我的问题是,我在其中硬编码了 40,即我在数据集中定义的 batch_size。我无法在调用方法中 read/get batch_size 因为变量 "inputs" 的形式是 Tensor("input_1_77:0", shape=(None, 250, 500, 1), dtype=float32) 其中 None 代表 batch_size 似乎是预期的行为。 当我使用以下代码初始化 l 时(没有 batch_size)

l = tf.random.uniform((2,), minval=0, maxval=1)

它抛出这个错误

ValueError: Shape must be rank 2 but is rank 1 for 'recurrent_attention_model_86/ExtractGlimpse' (op: 'ExtractGlimpse') with input shapes: [?,250,500,1], [2], [2]

我完全理解,但我不知道如何根据 batch_size.

实现初始值

您可以使用 tf.shape.

动态提取批量大小维度
l = tf.random.normal(tf.stack([tf.shape(inputs)[0], 2]), minval=0, maxval=1))