将可变大小的张量填充到自定义层中的特定大小

Pad a variable size tensor to a specific size in a custom layer

我有一个具有动态第一维和固定第二维的二维张量。我使用 segment_sum 更新此张量,其中第一维的大小可能会发生变化,因此,我想将修改后的张量零填充为与输入相同的形状。

中提供的答案对我没有帮助,因此根据我的用例的具体情况提出这个问题。

class MyLayer(layers.Layer):
    def call(self, inputs):
        x, segment_ids = inputs
        x_ = tf.math.segment_sum(x, segment_ids)
        # the first dimension of x and x_ 
        # may not be equal at this point, hence zero-padding.
        x_ = tf.pad(
            x_, 
            [(0, 0), (tf.shape(x)[0] - tf.shape(x_)[0], 0)])
        return x_

如果我没看错 and ,上面的方法应该有效!?但是,我在下游 MyLayer 层收到错误抱怨:

unsupported operand type(s) for *=: 'int' and 'NoneType'

我什至在第一个纪元开始之前就收到了这个错误,不管我提供的输入是什么;但是,如果我删除填充,我将只为我的一些输入得到不兼容的形状。因此,我猜测错误与填充有关,也许错误消息中给出的 tf.shape NoneType?


例子

这是一个最小的示例,我可以在示例末尾使用简单的填充获得预期的输出。但是,在我的实际用例中,模型是根据动态输入大小定义的,即大小为 None,类似的方法失败并出现上述错误。

x = [[1, 10], [2, 20], [3, 30], [4, 40]]
i = [0, 1, 1, 1]
o = tf.math.segment_sum(x, i)
o = tf.get_static_value(o)
print(f"shape: {o.shape}")
print(f"\ntensor:\n{o}")

输出:

shape: (2, 2)

tensor:
[[ 1 10]
 [ 9 90]]

预期输出

o = tf.pad(o, [(0, len(x) - o.shape[0]), (0, 0)])

shape: (4, 2)

tensor:
[[ 1 10]
 [ 9 90]
 [ 0  0]
 [ 0  0]]

您可以尝试使用 tf.concat 填充您的情况:

import tensorflow as tf

class MyLayer(tf.keras.layers.Layer):

    def call(self, inputs):
        x, segment_ids = inputs
        x_ = tf.math.segment_sum(x, segment_ids)
        x_ = tf.concat([x_, tf.zeros((int(tf.shape(x)[0] - tf.shape(x_)[0]), tf.shape(x)[1]))], axis=0)
        return x_

inputs = tf.keras.layers.Input((2,))
l = MyLayer()
x = l([inputs, [0, 1, 1, 1]])
outputs = tf.keras.layers.Dense(1)(x)
model = tf.keras.Model(inputs, outputs)
model.compile(optimizer='adam', loss='mse')

x = tf.constant([[1, 10], [2, 20], [3, 30], [4, 40]], dtype=tf.float32)
print(l([x, [0, 1, 1, 1]]))

y = tf.constant([1, 2, 1, 4])
model.fit(x, y, epochs=2)
tf.Tensor(
[[ 1. 10.]
 [ 9. 90.]
 [ 0.  0.]
 [ 0.  0.]], shape=(4, 2), dtype=float32)
Epoch 1/2
1/1 [==============================] - 1s 683ms/step - loss: 2624.1050
Epoch 2/2
1/1 [==============================] - 0s 12ms/step - loss: 2361.9163
<keras.callbacks.History at 0x7f7fdd4cf810>