如何动态编辑外部 .config 文件?

How to dynamically edit an external .config file?

我正在使用 TensorFlow 对象检测 API 开发一个主动学习(active learning)管道。我的目标是动态更改网络 .config 文件中的路径。

标准配置如下所示:

    train_input_reader: {
       tf_record_input_reader {
       input_path: "/PATH_TO_CONFIGURE/train.record"
       }
       label_map_path: "/PATH_TO_CONFIGURE/label_map.pbtxt"
    }

"PATH_TO_CONFIGURE" 应该从我的 jupyter notebook 单元格中动态替换。

对象检测 API 的配置文件采用 protobuf 文本格式。以下是读取、编辑并保存它们的大致方法。

import tensorflow as tf
from google.protobuf import text_format
from object_detection.protos import pipeline_pb2

# Start from an empty pipeline message; the config text is merged into it.
pipeline = pipeline_pb2.TrainEvalPipelineConfig()

# Read the whole config file as text, then parse it once the file is closed.
with tf.gfile.GFile('config path', "r") as config_in:
    proto_str = config_in.read()
text_format.Merge(proto_str, pipeline)

# input_path is a repeated field, so its contents are replaced via slice
# assignment; label_map_path is a plain scalar field.
train_reader = pipeline.train_input_reader
train_reader.tf_record_input_reader.input_path[:] = ['your new entry']
train_reader.label_map_path = 'your new entry'

# Serialize the edited message back to text and overwrite the config file.
config_text = text_format.MessageToString(pipeline)
with tf.gfile.Open('config path', "wb") as config_out:
    config_out.write(config_text)

您需要根据自己的情况调整代码,但基本思路应该很清楚。我建议将其重构为一个函数,然后在 Jupyter 中调用。

以下是在 TensorFlow 2 下对我有效的写法(API 略有变化:tf.gfile.GFile 改为 tf.io.gfile.GFile):

import tensorflow as tf
from google.protobuf import text_format
from object_detection.protos import pipeline_pb2

def read_config(config_path='pipeline.config'):
    """Parse a text-format pipeline config file into a protobuf message.

    Args:
        config_path: Path of the config file to read. Defaults to
            'pipeline.config', the path this function previously hard-coded,
            so existing zero-argument calls behave as before.

    Returns:
        A ``pipeline_pb2.TrainEvalPipelineConfig`` populated from the file.
    """
    pipeline = pipeline_pb2.TrainEvalPipelineConfig()
    with tf.io.gfile.GFile(config_path, "r") as f:
        proto_str = f.read()
    # Parsing needs only the text, so it runs after the file is closed.
    text_format.Merge(proto_str, pipeline)
    return pipeline

def write_config(pipeline, config_path='pipeline.config'):
    """Serialize a pipeline message to text and write it to a config file.

    Args:
        pipeline: The protobuf message to serialize with
            ``text_format.MessageToString``.
        config_path: Destination path; any existing file is overwritten.
            Defaults to 'pipeline.config', the path this function previously
            hard-coded, so existing single-argument calls behave as before.
    """
    config_text = text_format.MessageToString(pipeline)
    with tf.io.gfile.GFile(config_path, "wb") as f:
        f.write(config_text)

def _set_first_entry(repeated_field, value):
    """Overwrite the first entry of a repeated field, appending if it is empty."""
    if len(repeated_field) > 0:
        repeated_field[0] = value
    else:
        # Index assignment on an empty repeated field raises IndexError.
        repeated_field.append(value)


def modify_config(pipeline):
    """Point the pipeline at the local label map and record files.

    Sets the SSD class count to 1, the fine-tune checkpoint type to
    'detection', and the label map / record paths of the train reader and
    of the first eval reader. The message is modified in place.

    Args:
        pipeline: The pipeline config message to edit. It must already
            contain at least one ``eval_input_reader`` entry.

    Returns:
        The same ``pipeline`` object, for call chaining.
    """
    # NOTE(review): this writes to model.ssd, so it presumably expects an
    # SSD-based config; confirm before using it with other architectures.
    pipeline.model.ssd.num_classes = 1
    pipeline.train_config.fine_tune_checkpoint_type = 'detection'

    train_reader = pipeline.train_input_reader
    train_reader.label_map_path = 'label_map.pbtxt'
    # input_path is a repeated field that may have no entries yet; only the
    # first entry is replaced, any additional entries are left untouched.
    _set_first_entry(train_reader.tf_record_input_reader.input_path, 'train.record')

    eval_reader = pipeline.eval_input_reader[0]
    eval_reader.label_map_path = 'label_map.pbtxt'
    _set_first_entry(eval_reader.tf_record_input_reader.input_path, 'test.record')

    return pipeline


def setup_pipeline():
    """Run the full read -> modify -> write cycle on the pipeline config."""
    # Each stage feeds the next: the parsed config is edited, then saved.
    write_config(modify_config(read_config()))

# Run the read/modify/write cycle as soon as this snippet (e.g. a notebook cell) executes.
setup_pipeline()