Tensorflow (Keras) 和多处理导致 GPU 内存不足

Tensorflow (Keras) & Multiprocessing results in lack of GPU memory

我有一个自定义 DataGenerator,它使用 Python 的多处理模块来生成提供给 Tensorflow 模型的训练数据。

问题是每当初始化一个新的 DataGenerator 进程时,它似乎都会尝试初始化 Tensorflow(在代码顶部导入)并为自己分配一些 GPU 内存。

我按照 限制每个进程对 GPU 内存的访问并且我的代码有效,但我只能使用可用 GPU 内存的三分之一。

新进程和 Tensorflow 代码在同一个 Python 文件中启动。是否有适当的方法来使用多处理,同时禁止生成的进程导入 Tensorflow 并为自己分配一些 GPU 内存?

这里有一部分代码(运行于 Windows)以供说明:

from multiprocessing import Process, Queue
from multiprocessing.pool import Pool

import cv2
import numpy as np
import tensorflow as tf

from keras.models import load_model

def TrainQueueProcess(queue):
    # This Function Fills The Queue For Other Consumers

def get_model(model_path=None):
    import tensorflow as tf
    import keras.backend.tensorflow_backend as ktf

    def get_session(gpu_fraction=0.333):
        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=gpu_fraction, allow_growth=True)
        return tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))

    ktf.set_session(get_session())

    from keras import Input, Model
    from keras.applications.mobilenetv2 import MobileNetV2
    from keras.layers import Dense, Dropout
    from keras.optimizers import adam
    from keras.utils import plot_model

    input_tensor = Input(shape=(128, 128, 3))
    base_model = MobileNetV2(weights='imagenet', include_top=False, input_tensor=input_tensor, input_shape=(128, 128, 3), pooling='avg')
    for layer in base_model.layers:
        layer.trainable = True

    op = Dense(128, activation='relu')(base_model.output)
    op = Dropout(.25)(op)
    output_tensor = Dense(2, activation='softmax')(op)
    model = Model(inputs=input_tensor, outputs=output_tensor)
    model.compile(optimizer=adam(lr=0.0008), loss='binary_crossentropy', metrics=['accuracy'])

    return model


if __name__ == '__main__':
    TRAIN_QUEUE = Queue(maxsize=10)
    TRAIN_PROCESS = Process(target=TrainQueueProcess, args=(TRAIN_QUEUE))
    TRAIN_PROCESS.start()

    model = get_model(model_path)

如果您在 windows 上,请将所有 tfkeras 导入到方法中。

由于 Windows 缺少 os.fork() 所有导入都在新流程中再次导入(在您的情况下包括导入 tf)。

https://docs.python.org/2/library/multiprocessing.html#windows