Keras: 'TypeError: Failed to convert object of type <class 'tuple'> to Tensor' occurred when I build a self-defined layer

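I defined a custom layer by subclassing the Keras Layer class. Its only purpose is to apply a learnable weight to the input, so the layer's input and output have the same shape. When I build the model, the following error occurs: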

TypeError: Failed to convert object of type <class 'tuple'> to Tensor. Contents: (None, 256, 256). Consider casting elements to a supported type.

Here is the code (definition and call):

Definition:

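from keras import backend as K
from keras.layers import Input, Layer

class Filter_low(Layer):
    def __init__(self, **kwargs):
        super(Filter_low, self).__init__(**kwargs)

    def build(self, input_shape):
        # input_shape is (None, 256, 256) here; the None batch dimension
        # is what triggers the TypeError when it is used as a weight shape.
        self.kernel = self.add_weight(name='kernel',
                                      shape=input_shape,
                                      initializer='uniform',
                                      trainable=True)
        super(Filter_low, self).build(input_shape)

    def call(self, x):
        return K.dot(x, self.kernel)

    def compute_output_shape(self, input_shape):
        return input_shape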

Call:

fre_dct = Input(shape=(256, 256))
fw_low = Filter_low(name='Filter_low')(fre_dct)

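Answer: the input_shape passed to build includes the batch dimension, which is None (hence the (None, 256, 256) in the error), and a weight cannot have an undefined dimension. Try deriving the kernel shape from input_shape instead of passing it directly, like this: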

import tensorflow as tf

class Filter_low(tf.keras.layers.Layer):

    def __init__(self, **kwargs):
        super(Filter_low, self).__init__(**kwargs)

    def build(self, input_shape):
        # input_shape is (None, 256, 256); use only the last dimension so the
        # kernel is (256, 256) and has no undefined (None) entries.
        output_dim = input_shape[-1]
        self.kernel = self.add_weight(name='kernel',
                                      shape=(output_dim, output_dim),
                                      initializer='uniform',
                                      trainable=True)
        super(Filter_low, self).build(input_shape)

    def call(self, x):
        # (batch, 256, 256) dot (256, 256) -> (batch, 256, 256)
        return tf.keras.backend.dot(x, self.kernel)

    def compute_output_shape(self, input_shape):
        return input_shape

fre_dct = tf.keras.Input(shape=(256, 256))
fw_low = Filter_low(name='Filter_low')(fre_dct)
model = tf.keras.Model(fre_dct, fw_low)

# Random data, only to check that the model builds and trains.
X = tf.random.normal((5, 256, 256))
y = tf.random.normal((5, 256, 256))
model.compile(optimizer='adam', loss='MSE')
model.fit(X, y, epochs=2)
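The difference between the two build methods comes down to which shape is handed to add_weight. The pure-Python sketch below mimics only that shape bookkeeping; it does not call Keras. It assumes input_shape arrives as a plain tuple (None, 256, 256), matching the contents of the error message (some Keras versions pass a TensorShape instead), and the helper names are hypothetical.

```python
def question_kernel_shape(input_shape):
    # Hypothetical helper: the question passes build()'s input_shape
    # straight through, batch dimension included.
    return tuple(input_shape)


def answer_kernel_shape(input_shape):
    # Hypothetical helper: the answer builds a square (last_dim, last_dim)
    # kernel, which keeps the last dimension unchanged under K.dot(x, kernel).
    output_dim = input_shape[-1]
    return (output_dim, output_dim)


def is_fully_defined(shape):
    # A weight needs every dimension to be a concrete int, not None.
    return all(dim is not None for dim in shape)


# Input(shape=(256, 256)) reaches build() as (None, 256, 256).
input_shape = (None, 256, 256)
for label, shape in [("question", question_kernel_shape(input_shape)),
                     ("answer", answer_kernel_shape(input_shape))]:
    print(label, shape, "fully defined:", is_fully_defined(shape))
```

Only the answer's shape is fully defined, which is why its add_weight call succeeds.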

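Alternatively, you can set shape=input_shape[1:], which drops the batch dimension and also gives (256, 256) here. This only works with K.dot because the input is square: K.dot requires the kernel's first dimension to equal the input's last dimension.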
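If "add a learnable weight to the input" means scaling each element rather than a matrix product, shape=input_shape[1:] pairs naturally with element-wise multiplication and then works for non-square inputs too. The sketch below assumes TensorFlow 2.x with tf.keras; the class name ElementwiseWeight and the "ones" initializer are illustrative choices, not part of the original answer.

```python
import numpy as np
import tensorflow as tf


class ElementwiseWeight(tf.keras.layers.Layer):
    """Hypothetical variant: one trainable weight per input element.

    Assumes the goal is element-wise scaling rather than a matrix product.
    """

    def build(self, input_shape):
        # Drop the batch dimension (None) so the weight shape is fully defined,
        # e.g. (None, 256, 256) -> (256, 256).
        self.kernel = self.add_weight(
            name="kernel",
            shape=tuple(input_shape[1:]),
            initializer="ones",  # start as the identity scaling
            trainable=True,
        )
        super().build(input_shape)

    def call(self, x):
        # Broadcasts over the batch dimension: (batch, 256, 256) * (256, 256).
        return x * self.kernel

    def compute_output_shape(self, input_shape):
        return input_shape


fre_dct = tf.keras.Input(shape=(256, 256))
fw_low = ElementwiseWeight(name="Filter_low")(fre_dct)
model = tf.keras.Model(fre_dct, fw_low)

X = np.random.normal(size=(5, 256, 256)).astype("float32")
print(model(X).shape)
```

Because the kernel has the same shape as a single sample, the output shape equals the input shape whether or not the last two dimensions match.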