Pytorch 中的内存泄漏:对象检测

Memory leak in Pytorch: object detection

我正在 PyTorch 上开发 object detection tutorial。原始教程在给定的几个时期内运行良好。我将它扩展到大纪元并遇到 内存不足 错误。

我试着调试它并发现了一些有趣的东西。这是我正在使用的工具:

def debug_gpu():
    # Debug out of memory bugs.
    tensor_list = []
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
                tensor_list.append(obj)
        except:
            pass
    print(f'Count of tensors = {len(tensor_list)}.')

我用它来监控训练一个epoch的记忆:

def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq):
    ...
    for images, targets in metric_logger.log_every(data_loader, print_freq, header):
        # inference + backward + optimization
        debug_gpu()

输出是这样的:

Count of tensors = 414.
Count of tensors = 419.
Count of tensors = 424.
Count of tensors = 429.
Count of tensors = 434.
Count of tensors = 439.
Count of tensors = 439.
Count of tensors = 444.
Count of tensors = 449.
Count of tensors = 449.
Count of tensors = 454.

如您所见,垃圾收集器跟踪的张量数量不断增加。

可以找到要执行的相关文件here

我有两个问题: 1. 是什么阻碍了垃圾收集器释放这些张量? 2.内存不足怎么办?

  1. 如何识别错误? 在 tracemalloc 的帮助下,我拍了两张快照,中间有数百次迭代。该教程将向您展示它很容易遵循。

  2. 错误原因是什么? Pytorch 中的 rpn.anchor_generator._cache 是一个 python dict,它跟踪网格锚点。它是检测模型的一个属性,大小随着每个提议而增加。

  3. 如何解决? 在训练迭代结束时放置一个简单的绕过 model.rpn.anchor_generator._cache.clear()


我已经向 PyTorch 提交了 fix。自 torchvision 0.5 以来,您可能不会遇到 OOM 错误。