创建层时处理符号张量的 none 维度

Dealing with none dimension of symbolic tensor when creating a layer

我想实现我自己的 Max Unpooling 层,如 here. For that, I need the argmax output of tf.nn.max_pool_with_argmax 中所述。 我使用任何模型外部的层成功地应用了它,但是当我想将它添加到模型时(因此,TensorFlow 使用 symbolic/static 张量,我得到了错误。我尝试了很多实现,我提出了 2 个选项总结:

在这两个选项中,我在 tf.scatter_nd 上遇到了同样的问题:

*** ValueError: Tried to convert 'shape' to a tensor and failed. Error: Cannot convert a partially known TensorShape to a Tensor

我明白为什么会出现这个错误,对于动态张量,批次是未知的,因此形状是 tf.TensorShape([None, ...])。但是我该如何处理呢?

这是我的实现。

class UnPooling2D(Layer):

    def __init__(self, desired_output_shape, name=None, dtype=DEFAULT_COMPLEX_TYPE, dynamic=False, **kwargs):
        self.desired_output_shape = desired_output_shape    # If option 1
        super(ComplexUnPooling2D, self).__init__(trainable=False, name=name, dtype=self.my_dtype.real_dtype,
                                               dynamic=dynamic, **kwargs)

    def call(self, inputs, unpool_mat, reference_tensor_for_shape: Optional[Only if option 2], **kwargs):
        """
        Performs unpooling
        :param inputs: Input Tensor.
        :param unpool_mat: Result argmax from tf.nn.max_pool_with_argmax
            https://www.tensorflow.org/api_docs/python/tf/nn/max_pool_with_argmax
        """       
        updates = tf.reshape(inputs, [-1])
        indices = tf.expand_dims(tf.reshape(unpool_mat, [-1]), axis=-1)

        ####################
        # Option 1:
        flat_output_shape = tf.reduce_prod(self.desired_output_shape)
        ret = tf.scatter_nd(indices, updates, shape=(inputs.get_shape()[0]*flat_output_shape,))
        desired_output_shape_with_batch = tf.concat([[inputs.get_shape()[0]], self.desired_output_shape], axis=0)
        ret = tf.reshape(ret, shape=desired_output_shape_with_batch)

        # Option 2 (untested, speudo code):
        flatten_reference_tensor = tf.reshape(reference_tensor_for_shape, [-1])
        ret = tf.scatter_nd(indices, updates, shape=flatten_reference_tensor.get_shape())
        ret = tf.reshape(ret, reference_tensor_for_shape.get_shape())
        #################
        return ret

我尝试了一个丑陋的修复程序,但出现错误:

WARNING:tensorflow:AutoGraph could not transform <bound method UnPooling2D.call of <layers.pooling.ComplexUnPooling2D object at 0x7f95cab16220>> and will run it as-is.

我尝试做的是:

    def call(self, inputs, unpool_mat, **kwargs):
        if inputs.get_shape()[0]:
           ... other solutions
        else:   # Dynamic tensors
            ret = tf.reshape(inputs, (-1,) + self.desired_output_shape)
        return ret

我部分解决了这个问题,但我认为我的新错误实际上值得一个新问题。最后我不得不将 inputs.get_shape() 更改为 tf.shape(inputs)。这是我“有效”的最终代码。 (至少它得到了正确的模型形状,没有错误,例如 model.summary().

    def call(self, inputs, unpool_mat, **kwargs):
        flat_output_shape = tf.reduce_prod(self.desired_output_shape)

        updates = tf.reshape(inputs, [-1])
        indices = tf.expand_dims(tf.reshape(unpool_mat, [-1]), axis=-1)

        ret = tf.scatter_nd(indices, updates, shape=(tf.shape(inputs)[0]*flat_output_shape,))
        desired_output_shape_with_batch = tf.concat([[tf.shape(inputs)[0]], self.desired_output_shape], axis=0)
        ret = tf.reshape(ret, shape=desired_output_shape_with_batch)
        return ret