带有 Tensorflow 对象检测的多通道输入 API V2

Multi-channel inputs with Tensorflow Objection Detection API V2

我想在 Tensorflow V2 对象检测中构建一个网络 API 使用 5 通道图像。但是,我一直卡在如何使用Tensorflow 2.2框架修改第一个卷积层的权重。

我已经从 V2 Model Zoo 下载了预训练好的 RetinaNet。然后我尝试了以下方法来修改检查点第一层的权重并将它们保存回来:

tf_path = tf.train.latest_checkpoint('./RetinaNet/checkpoint/')
init_vars = tf.train.list_variables(tf_path)
tf_vars = {}
for name, shape in init_vars:

    array = tf.train.load_variable(tf_path, name)
    try:
        if shape[2]==3:#look for a layer who's 3rd input dimension is 3 i.e. the 1st convolutional layer
            array=np.concatenate((array,array[:,:,:2,:]),axis=2)
            array=array.astype('float32')
            tf_vars[name]=tf.Variable(array)
            
        else:
            tf_vars[name]=tf.Variable(array)
            
    except:
        tf_vars[name]=tf.Variable(array)
        
        
saver = tf.compat.v1.train.Saver(var_list=tf_vars)
sess = tf.compat.v1.Session()
saver.save(sess, './RetinaNet/checkpoint/ckpt-0')

我重新加载模型以确保第一个卷积层已更改 - 一切看起来都正常。

但是当我去训练模型时,我得到了以下错误: 模型是用形状 (None, None, None, 3) 为输入 Tensor("input_1:0", shape=(None, None, None, 3), dtype=float32), 但它是在形状不兼容的输入上调用的 (64, 128, 128, 5)

这让我相信我修改权重的方法毕竟不是那么“好”。谁能提供一些关于如何正确修改这些权重的提示?

谢谢

这现在可行,但解决方案非常棘手...这也意味着不使用模型动物园的预训练权重进行训练 - 因此您需要在配置文件中注释与 fine_tune_checkpoint 相关的所有内容. 然后,转到 .\Lib\site-packages\official\vision\image_classification\efficientnet 并更改 efficientnet_model.py 和 efficientnet_config.py 中的输入通道数和 类 数。