使用 keras.utils.Sequence 多处理和数据库 - 何时连接?

Using keras.utils.Sequence multiprocessing and data base - when to connect?

我正在使用带有 Tensorflow 后端的 Keras 训练神经网络。数据集不适合 RAM,因此,我将其存储在 Mongo 数据库中,并使用 keras.utils.Sequence.

的 subclass 检索批次

一切正常,如果我 运行 model.fit_generator()use_multiprocessing=False

当我打开多处理时,我在生成工作进程或连接到数据库时遇到错误。

如果我在 __init__ 中创建连接,我会遇到一个异常,其文本说明了 pickling 锁对象中的错误。对不起,我记不太清了。但是连训练都没有开始

如果我在 __get_item__ 中创建连接,训练开始并且 运行s 一些 epoch,然后我得到错误 [WinError 10048] Only one usage of each socket address (protocol/network address/port) is normally permitted

根据 the pyMongo manuals,它不是分叉安全的,每个子进程都必须创建自己的数据库连接。我使用 Windows,它不使用叉子,而是生成进程,但是,恕我直言,区别在这里并不重要。

这解释了为什么在 __init__ 中无法连接。

这是 docs 的另一引述:

Create this client once for each process, and reuse it for all operations. It is a common mistake to create a new client for each request, which is very inefficient.

这解释了 __get_item__ 中的错误。

但是,我的 class 是如何理解 Keras 创建了新进程的,这还不清楚。

这是我的 Sequence 实现的最后一个变体的伪代码(每个请求都有新连接):

import pymongo
import numpy as np
from keras.utils import Sequence
from keras.utils.np_utils import to_categorical

class MongoSequence(Sequence):
    def __init__(self, train_set, batch_size, server=None, database="database", collection="full_set"):
        self._train_set = train_set
        self._server = server
        self._db = database
        self.collection = collection
        self._batch_size = batch_size

        query = {}  # train_set query
        self._object_ids = [ smp["_id"] for uid in train_set for smp in self._connect().find(query, {'_id': True})]

    def _connect(self):
        client = pymongo.MongoClient(self._server)
        db = self._client[self._db]
        return _db[self._collection]

    def __len__(self):
        return int(np.ceil(len(self._object_ids) / float(self._batch_size)))

    def __getitem__(self, item):
        oids = self._object_ids[item * self._batch_size: (item+1) * self._batch_size]
        X = np.empty((len(oids), IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNELS), dtype=np.float32)
        y = np.empty((len(oids), 2), dtype=np.float32)
        for i, oid in enumerate(oids):
            smp = self._connect().find({'_id': oid}).next()
            X[i, :, :, :] = pickle.loads(smp['frame']).astype(np.float32)
            y[i] = to_categorical(not smp['result'], 2)
        return X, y

也就是说,在对象构造上,我根据条件检索所有相关的 ObjectIDs 形成训练集。在对 __getitem__ 的调用中从数据库中检索实际对象。它们的 ObjectIDs 由列表切片确定。

这段调用 model.fit_generator(generator=MongoSequence(train_ids, batch_size=10), ... ) 的代码产生了几个 python 进程,每个进程根据日志消息初始化 Tensorflow 后端,然后训练开始。

但最终异常从函数 connect 中抛出,位于 pymongo 的深处。

很遗憾,我没有存储调用堆栈。错误是上面描述的,我重复:[WinError 10048] Only one usage of each socket address (protocol/network address/port) is normally permitted

我的假设是这段代码创建了太多与服务器的连接,因此 __getitem__ 中的连接是错误的。

构造函数中的连接也是错误的,因为它是在主进程中执行的,Mongo文档直接反对它。

Sequenceclass、on_epoch_end还有一个方法。但是,我需要在纪元开始而不是结束时建立连接。

引自 Keras 文档:

If you want to modify your dataset between epochs you may implement on_epoch_end

那么,有什么建议吗?文档在这里不是很具体。

看来我找到了解决办法。解决方案是 - 跟踪进程 ID 并在它更改时重新连接

class MongoSequence(Sequence):
    def __init__(self, batch_size, train_set, query=None, server=None, database="database", collection="full_set"):
        self._server = server
        self._db = database
        self._collection_name = collection
        self._batch_size = batch_size
        self._query = query
        self._collection = self._connect()

        self._object_ids = [ smp["_id"] for uid in train_set for smp in self._collection.find(self._query, {'_id': True})]

        self._pid = os.getpid()
        del self._collection   #  to be sure, that we've disconnected
        self._collection = None

    def _connect(self):
        client = pymongo.MongoClient(self._server)
        db = client[self._db]
        return db[self._collection_name]

    def __len__(self):
        return int(np.ceil(len(self._object_ids) / float(self._batch_size)))

    def __getitem__(self, item):
        if self._collection is None or self._pid != os.getpid():
            self._collection = self._connect()
            self._pid = os.getpid()

        oids = self._object_ids[item * self._batch_size: (item+1) * self._batch_size]
        X = np.empty((len(oids), IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNELS), dtype=np.float32)
        y = np.empty((len(oids), 2), dtype=np.float32)
        for i, oid in enumerate(oids):
            smp = self._connect().find({'_id': oid}).next()
            X[i, :, :, :] = pickle.loads(smp['frame']).astype(np.float32)
            y[i] = to_categorical(not smp['result'], 2)
        return X, y

on_epoch_end() 中创建连接,并从“init()”方法显式调用 on_epoch_end()。这使得 on_epoch_end() 在实践中工作,就好像 ti 是 "on epoch begin" 一样。 (每个纪元的结束,是下一个纪元的开始。第一个纪元之前没有纪元,因此在初始化中显式调用。)