Pytorch 中的内存泄漏:对象检测
Memory leak in Pytorch: object detection
我正在 PyTorch 上开发 object detection tutorial。原始教程在给定的几个时期内运行良好。我将它扩展到大纪元并遇到 内存不足 错误。
我试着调试它并发现了一些有趣的东西。这是我正在使用的工具:
def debug_gpu():
# Debug out of memory bugs.
tensor_list = []
for obj in gc.get_objects():
try:
if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
tensor_list.append(obj)
except:
pass
print(f'Count of tensors = {len(tensor_list)}.')
我用它来监控训练一个epoch的记忆:
def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq):
...
for images, targets in metric_logger.log_every(data_loader, print_freq, header):
# inference + backward + optimization
debug_gpu()
输出是这样的:
Count of tensors = 414.
Count of tensors = 419.
Count of tensors = 424.
Count of tensors = 429.
Count of tensors = 434.
Count of tensors = 439.
Count of tensors = 439.
Count of tensors = 444.
Count of tensors = 449.
Count of tensors = 449.
Count of tensors = 454.
如您所见,垃圾收集器跟踪的张量数量不断增加。
可以找到要执行的相关文件here。
我有两个问题:
1. 是什么阻碍了垃圾收集器释放这些张量?
2.内存不足怎么办?
如何识别错误?
在 tracemalloc 的帮助下,我拍了两张快照,中间有数百次迭代。该教程将向您展示它很容易遵循。
错误原因是什么?
Pytorch 中的 rpn.anchor_generator._cache
是一个 python dict
,它跟踪网格锚点。它是检测模型的一个属性,大小随着每个提议而增加。
如何解决?
在训练迭代结束时放置一个简单的绕过 model.rpn.anchor_generator._cache.clear()
。
我已经向 PyTorch 提交了 fix。自 torchvision 0.5 以来,您可能不会遇到 OOM 错误。
我正在 PyTorch 上开发 object detection tutorial。原始教程在给定的几个时期内运行良好。我将它扩展到大纪元并遇到 内存不足 错误。
我试着调试它并发现了一些有趣的东西。这是我正在使用的工具:
def debug_gpu():
# Debug out of memory bugs.
tensor_list = []
for obj in gc.get_objects():
try:
if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
tensor_list.append(obj)
except:
pass
print(f'Count of tensors = {len(tensor_list)}.')
我用它来监控训练一个epoch的记忆:
def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq):
...
for images, targets in metric_logger.log_every(data_loader, print_freq, header):
# inference + backward + optimization
debug_gpu()
输出是这样的:
Count of tensors = 414.
Count of tensors = 419.
Count of tensors = 424.
Count of tensors = 429.
Count of tensors = 434.
Count of tensors = 439.
Count of tensors = 439.
Count of tensors = 444.
Count of tensors = 449.
Count of tensors = 449.
Count of tensors = 454.
如您所见,垃圾收集器跟踪的张量数量不断增加。
可以找到要执行的相关文件here。
我有两个问题: 1. 是什么阻碍了垃圾收集器释放这些张量? 2.内存不足怎么办?
如何识别错误? 在 tracemalloc 的帮助下,我拍了两张快照,中间有数百次迭代。该教程将向您展示它很容易遵循。
错误原因是什么? Pytorch 中的
rpn.anchor_generator._cache
是一个 pythondict
,它跟踪网格锚点。它是检测模型的一个属性,大小随着每个提议而增加。如何解决? 在训练迭代结束时放置一个简单的绕过
model.rpn.anchor_generator._cache.clear()
。
我已经向 PyTorch 提交了 fix。自 torchvision 0.5 以来,您可能不会遇到 OOM 错误。