PyTorch DDP gets stuck getting a free port

I am trying to pick a free port during PyTorch DDP initialization, but my code hangs. The following snippet reproduces the problem:

import logging
import os
import socket
from contextlib import closing

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP


def get_open_port():
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
        # SO_REUSEADDR must be set before bind() to have any effect.
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        s.bind(('', 0))
        return s.getsockname()[1]

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    port = get_open_port()  # runs independently inside every spawned rank
    os.environ['MASTER_PORT'] = str(port)  # e.g. '12345'

    # Initialize the process group.
    dist.init_process_group('nccl', rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 5)

    def forward(self, x):
        print(f'x device={x.device}')
        return self.net1(x)


def demo_basic(rank, world_size):
    setup(rank, world_size)

    logger = logging.getLogger('train')
    logger.setLevel(logging.DEBUG)
    logger.info(f'Running DDP on rank={rank}.')

    # Create model and move it to GPU.
    model = ToyModel().to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)  # optimizer takes DDP model.

    optimizer.zero_grad()
    inputs = torch.randn(20, 10)  # .to(rank) is optional: DDP moves inputs to device_ids[0]

    print(f'inputs device={inputs.device}')
    outputs = ddp_model(inputs)
    print(f'output device={outputs.device}')

    labels = torch.randn(20, 5).to(rank)
    loss_fn(outputs, labels).backward()

    optimizer.step()

    cleanup()


def run_demo(demo_func, world_size):
    mp.spawn(
        demo_func,
        args=(world_size,),
        nprocs=world_size,
        join=True
    )

if __name__ == '__main__':
    run_demo(demo_basic, 4)

The function get_open_port should release the port right after it is called. My questions are: 1. How does the hang happen? 2. How do I fix it?
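
The port really is released: a quick sanity check (a minimal sketch reusing get_open_port from the snippet above; the probe socket is my own addition) re-binds the returned port immediately and succeeds, so the hang is not caused by the port still being held:

port = get_open_port()
probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
probe.bind(('', port))  # succeeds: the with block already closed the first socket
probe.close()
print(f'port {port} is free again right after get_open_port returns')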

The answer comes from here. In detail: 1. Because each spawned process generates its own free port, the ranks end up with different MASTER_PORT values and wait on different rendezvous addresses, so init_process_group never returns; 2. The fix is to obtain one free port at the very beginning and pass it to every process.
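
To see point 1 concretely, here is a self-contained sketch (the report helper is a name of my own, just for illustration) in which two plain processes each call get_open_port; they almost always print different numbers, which is exactly what happens to the spawned DDP ranks:

import multiprocessing
import socket
from contextlib import closing

def get_open_port():
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        s.bind(('', 0))
        return s.getsockname()[1]

def report(rank):
    # Each process asks the kernel for its own ephemeral port,
    # so the ranks would rendezvous at different addresses.
    print(f'rank {rank} picked port {get_open_port()}')

if __name__ == '__main__':
    procs = [multiprocessing.Process(target=report, args=(r,)) for r in range(2)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()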

The corrected snippet (imports unchanged from above):

def get_open_port():
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
        # SO_REUSEADDR must be set before bind() to have any effect.
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        s.bind(('', 0))
        return s.getsockname()[1]


def setup(rank, world_size, port):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = str(port)  # same port value in every rank

    # Initialize the process group.
    dist.init_process_group('nccl', rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 5)

    def forward(self, x):
        print(f'x device={x.device}')
        return self.net1(x)


def demo_basic(rank, world_size, free_port):
    setup(rank, world_size, free_port)

    logger = logging.getLogger('train')
    logger.setLevel(logging.DEBUG)
    logger.info(f'Running DDP on rank={rank}.')

    # Create model and move it to GPU.
    model = ToyModel().to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)  # optimizer takes DDP model.

    optimizer.zero_grad()
    inputs = torch.randn(20, 10)  # .to(rank) is optional: DDP moves inputs to device_ids[0]

    print(f'inputs device={inputs.device}')
    outputs = ddp_model(inputs)
    print(f'output device={outputs.device}')

    labels = torch.randn(20, 5).to(rank)
    loss_fn(outputs, labels).backward()

    optimizer.step()

    cleanup()


def run_demo(demo_func, world_size, free_port):
    mp.spawn(
        demo_func,
        args=(world_size, free_port),
        nprocs=world_size,
        join=True
    )

if __name__ == '__main__':
    # Pick the port once in the parent so every rank receives the same value.
    free_port = get_open_port()
    run_demo(demo_basic, 4, free_port)
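
As a side note on the fix, the same single-port rendezvous can also be written without MASTER_ADDR/MASTER_PORT by passing a TCP init_method to init_process_group. A minimal sketch of that variant, functionally equivalent to the env-var version above:

def setup(rank, world_size, port):
    # Explicit TCP rendezvous address instead of environment variables;
    # every rank must still receive the same port from the parent.
    dist.init_process_group(
        'nccl',
        init_method=f'tcp://localhost:{port}',
        rank=rank,
        world_size=world_size,
    )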