使用 keras.utils.Sequence 多处理和数据库 - 何时连接?
Using keras.utils.Sequence multiprocessing and data base - when to connect?
我正在使用带有 Tensorflow 后端的 Keras 训练神经网络。数据集不适合 RAM,因此,我将其存储在 Mongo 数据库中,并使用 keras.utils.Sequence
.
的 subclass 检索批次
一切正常,如果我 运行 model.fit_generator()
和 use_multiprocessing=False
。
当我打开多处理时,我在生成工作进程或连接到数据库时遇到错误。
如果我在 __init__
中创建连接,我会遇到一个异常,其文本说明了 pickling 锁对象中的错误。对不起,我记不太清了。但是连训练都没有开始
如果我在 __get_item__
中创建连接,训练开始并且 运行s 一些 epoch,然后我得到错误 [WinError 10048] Only one usage of each socket address (protocol/network address/port) is normally permitted
。
根据 the pyMongo manuals,它不是分叉安全的,每个子进程都必须创建自己的数据库连接。我使用 Windows,它不使用叉子,而是生成进程,但是,恕我直言,区别在这里并不重要。
这解释了为什么在 __init__
中无法连接。
这是 docs 的另一引述:
Create this client once for each process, and reuse it for all operations. It is a common mistake to create a new client for each request, which is very inefficient.
这解释了 __get_item__
中的错误。
但是,我的 class 是如何理解 Keras 创建了新进程的,这还不清楚。
这是我的 Sequence 实现的最后一个变体的伪代码(每个请求都有新连接):
import pymongo
import numpy as np
from keras.utils import Sequence
from keras.utils.np_utils import to_categorical
class MongoSequence(Sequence):
def __init__(self, train_set, batch_size, server=None, database="database", collection="full_set"):
self._train_set = train_set
self._server = server
self._db = database
self.collection = collection
self._batch_size = batch_size
query = {} # train_set query
self._object_ids = [ smp["_id"] for uid in train_set for smp in self._connect().find(query, {'_id': True})]
def _connect(self):
client = pymongo.MongoClient(self._server)
db = self._client[self._db]
return _db[self._collection]
def __len__(self):
return int(np.ceil(len(self._object_ids) / float(self._batch_size)))
def __getitem__(self, item):
oids = self._object_ids[item * self._batch_size: (item+1) * self._batch_size]
X = np.empty((len(oids), IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNELS), dtype=np.float32)
y = np.empty((len(oids), 2), dtype=np.float32)
for i, oid in enumerate(oids):
smp = self._connect().find({'_id': oid}).next()
X[i, :, :, :] = pickle.loads(smp['frame']).astype(np.float32)
y[i] = to_categorical(not smp['result'], 2)
return X, y
也就是说,在对象构造上,我根据条件检索所有相关的 ObjectIDs
形成训练集。在对 __getitem__
的调用中从数据库中检索实际对象。它们的 ObjectIDs
由列表切片确定。
这段调用 model.fit_generator(generator=MongoSequence(train_ids, batch_size=10), ... )
的代码产生了几个 python 进程,每个进程根据日志消息初始化 Tensorflow 后端,然后训练开始。
但最终异常从函数 connect
中抛出,位于 pymongo
的深处。
很遗憾,我没有存储调用堆栈。错误是上面描述的,我重复:[WinError 10048] Only one usage of each socket address (protocol/network address/port) is normally permitted
。
我的假设是这段代码创建了太多与服务器的连接,因此 __getitem__
中的连接是错误的。
构造函数中的连接也是错误的,因为它是在主进程中执行的,Mongo文档直接反对它。
Sequence
class、on_epoch_end
还有一个方法。但是,我需要在纪元开始而不是结束时建立连接。
引自 Keras 文档:
If you want to modify your dataset between epochs you may implement on_epoch_end
那么,有什么建议吗?文档在这里不是很具体。
看来我找到了解决办法。解决方案是 - 跟踪进程 ID 并在它更改时重新连接
class MongoSequence(Sequence):
def __init__(self, batch_size, train_set, query=None, server=None, database="database", collection="full_set"):
self._server = server
self._db = database
self._collection_name = collection
self._batch_size = batch_size
self._query = query
self._collection = self._connect()
self._object_ids = [ smp["_id"] for uid in train_set for smp in self._collection.find(self._query, {'_id': True})]
self._pid = os.getpid()
del self._collection # to be sure, that we've disconnected
self._collection = None
def _connect(self):
client = pymongo.MongoClient(self._server)
db = client[self._db]
return db[self._collection_name]
def __len__(self):
return int(np.ceil(len(self._object_ids) / float(self._batch_size)))
def __getitem__(self, item):
if self._collection is None or self._pid != os.getpid():
self._collection = self._connect()
self._pid = os.getpid()
oids = self._object_ids[item * self._batch_size: (item+1) * self._batch_size]
X = np.empty((len(oids), IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNELS), dtype=np.float32)
y = np.empty((len(oids), 2), dtype=np.float32)
for i, oid in enumerate(oids):
smp = self._connect().find({'_id': oid}).next()
X[i, :, :, :] = pickle.loads(smp['frame']).astype(np.float32)
y[i] = to_categorical(not smp['result'], 2)
return X, y
在 on_epoch_end()
中创建连接,并从“init()”方法显式调用 on_epoch_end()
。这使得 on_epoch_end()
在实践中工作,就好像 ti 是 "on epoch begin" 一样。 (每个纪元的结束,是下一个纪元的开始。第一个纪元之前没有纪元,因此在初始化中显式调用。)
我正在使用带有 Tensorflow 后端的 Keras 训练神经网络。数据集不适合 RAM,因此,我将其存储在 Mongo 数据库中,并使用 keras.utils.Sequence
.
一切正常,如果我 运行 model.fit_generator()
和 use_multiprocessing=False
。
当我打开多处理时,我在生成工作进程或连接到数据库时遇到错误。
如果我在 __init__
中创建连接,我会遇到一个异常,其文本说明了 pickling 锁对象中的错误。对不起,我记不太清了。但是连训练都没有开始
如果我在 __get_item__
中创建连接,训练开始并且 运行s 一些 epoch,然后我得到错误 [WinError 10048] Only one usage of each socket address (protocol/network address/port) is normally permitted
。
根据 the pyMongo manuals,它不是分叉安全的,每个子进程都必须创建自己的数据库连接。我使用 Windows,它不使用叉子,而是生成进程,但是,恕我直言,区别在这里并不重要。
这解释了为什么在 __init__
中无法连接。
这是 docs 的另一引述:
Create this client once for each process, and reuse it for all operations. It is a common mistake to create a new client for each request, which is very inefficient.
这解释了 __get_item__
中的错误。
但是,我的 class 是如何理解 Keras 创建了新进程的,这还不清楚。
这是我的 Sequence 实现的最后一个变体的伪代码(每个请求都有新连接):
import pymongo
import numpy as np
from keras.utils import Sequence
from keras.utils.np_utils import to_categorical
class MongoSequence(Sequence):
def __init__(self, train_set, batch_size, server=None, database="database", collection="full_set"):
self._train_set = train_set
self._server = server
self._db = database
self.collection = collection
self._batch_size = batch_size
query = {} # train_set query
self._object_ids = [ smp["_id"] for uid in train_set for smp in self._connect().find(query, {'_id': True})]
def _connect(self):
client = pymongo.MongoClient(self._server)
db = self._client[self._db]
return _db[self._collection]
def __len__(self):
return int(np.ceil(len(self._object_ids) / float(self._batch_size)))
def __getitem__(self, item):
oids = self._object_ids[item * self._batch_size: (item+1) * self._batch_size]
X = np.empty((len(oids), IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNELS), dtype=np.float32)
y = np.empty((len(oids), 2), dtype=np.float32)
for i, oid in enumerate(oids):
smp = self._connect().find({'_id': oid}).next()
X[i, :, :, :] = pickle.loads(smp['frame']).astype(np.float32)
y[i] = to_categorical(not smp['result'], 2)
return X, y
也就是说,在对象构造上,我根据条件检索所有相关的 ObjectIDs
形成训练集。在对 __getitem__
的调用中从数据库中检索实际对象。它们的 ObjectIDs
由列表切片确定。
这段调用 model.fit_generator(generator=MongoSequence(train_ids, batch_size=10), ... )
的代码产生了几个 python 进程,每个进程根据日志消息初始化 Tensorflow 后端,然后训练开始。
但最终异常从函数 connect
中抛出,位于 pymongo
的深处。
很遗憾,我没有存储调用堆栈。错误是上面描述的,我重复:[WinError 10048] Only one usage of each socket address (protocol/network address/port) is normally permitted
。
我的假设是这段代码创建了太多与服务器的连接,因此 __getitem__
中的连接是错误的。
构造函数中的连接也是错误的,因为它是在主进程中执行的,Mongo文档直接反对它。
Sequence
class、on_epoch_end
还有一个方法。但是,我需要在纪元开始而不是结束时建立连接。
引自 Keras 文档:
If you want to modify your dataset between epochs you may implement on_epoch_end
那么,有什么建议吗?文档在这里不是很具体。
看来我找到了解决办法。解决方案是 - 跟踪进程 ID 并在它更改时重新连接
class MongoSequence(Sequence):
def __init__(self, batch_size, train_set, query=None, server=None, database="database", collection="full_set"):
self._server = server
self._db = database
self._collection_name = collection
self._batch_size = batch_size
self._query = query
self._collection = self._connect()
self._object_ids = [ smp["_id"] for uid in train_set for smp in self._collection.find(self._query, {'_id': True})]
self._pid = os.getpid()
del self._collection # to be sure, that we've disconnected
self._collection = None
def _connect(self):
client = pymongo.MongoClient(self._server)
db = client[self._db]
return db[self._collection_name]
def __len__(self):
return int(np.ceil(len(self._object_ids) / float(self._batch_size)))
def __getitem__(self, item):
if self._collection is None or self._pid != os.getpid():
self._collection = self._connect()
self._pid = os.getpid()
oids = self._object_ids[item * self._batch_size: (item+1) * self._batch_size]
X = np.empty((len(oids), IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNELS), dtype=np.float32)
y = np.empty((len(oids), 2), dtype=np.float32)
for i, oid in enumerate(oids):
smp = self._connect().find({'_id': oid}).next()
X[i, :, :, :] = pickle.loads(smp['frame']).astype(np.float32)
y[i] = to_categorical(not smp['result'], 2)
return X, y
在 on_epoch_end()
中创建连接,并从“init()”方法显式调用 on_epoch_end()
。这使得 on_epoch_end()
在实践中工作,就好像 ti 是 "on epoch begin" 一样。 (每个纪元的结束,是下一个纪元的开始。第一个纪元之前没有纪元,因此在初始化中显式调用。)