如何以正常的非并行方式 运行 Pytorch 模型?

How to run Pytorch model in normal non-parallel way?

我正在查看 this script,这里有一个代码块考虑了 2 个选项,DataParallelDistributedDataParallel

if not args.distributed:
    if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
        model.features = torch.nn.DataParallel(model.features)
        model.cuda()
    else:
        model = torch.nn.DataParallel(model).cuda()
else:
    model.cuda()
    model = torch.nn.parallel.DistributedDataParallel(model)

如果我不想要这两个选项中的任何一个,并且 我想要 运行 它甚至没有 DataParallel 怎么办?我该怎么做?

如何定义我的模型,使其 运行 成为一个普通的 nn 而不是并行化任何东西?

  • DataParallel 是一个包装器对象,用于在同一台机器的多个 GPU 上并行计算,请参阅 here
  • DistributedDataParallel 也是一个包装器对象,可让您在多个设备上分发数据,请参阅 here

如果您不想要它,您可以简单地移除包装器并按原样使用模型:

if not args.distributed:
    if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
        model.features = model.features
        model.cuda()
    else:
        model = model.cuda()
else:
    model.cuda()
    model = model

这是为了尽量减少代码修改。当然,由于您对并行化不感兴趣,您可以将整个 if 语句放到以下行中:

if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
    model.features = model.features
model = model.cuda()

请注意,此代码假定您在 GPU 上 运行。