PyTorch `backward()` updates multiple models

Can anyone tell me why the discriminator's gradients also change, and whether there is a way to avoid this?

for i in range(2):

    X_fake = gen_model(z)

    # Train the discriminator on real data and on detached fake data.
    pred_real = disc_model(X)
    pred_fake = disc_model(X_fake.detach())
    disc_loss = (loss_fn(pred_real, y) + loss_fn(pred_fake, y)) / 2

    disc_optimizer.zero_grad()
    disc_loss.backward()
    disc_optimizer.step()

    # Train the generator; here the fake data is NOT detached.
    pred_fake = disc_model(X_fake)
    gen_loss = loss_fn(pred_fake, y)

    gen_optimizer.zero_grad()
    if i == 1:
        print_grads(disc_model) # Checkpoint 1
    gen_loss.backward()
    if i == 1:
        print_grads(disc_model) # Checkpoint 2
    gen_optimizer.step()

Here is the rest of the code.

import torch
import torch.nn as nn

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self._linear = nn.Sequential(nn.Linear(1, 5))

    def forward(self, X):
        return self._linear(X)

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self._linear = nn.Sequential(nn.Linear(5, 1))

    def forward(self, X):
        return self._linear(X)

def print_grads(model):
    for params in model.parameters():
        print(params.grad)

# Build the model and data.
gen_model = Generator()
gen_optimizer = torch.optim.Adam(gen_model.parameters(), 1)

disc_model = Discriminator()
disc_optimizer = torch.optim.Adam(disc_model.parameters(), 1)

loss_fn = torch.nn.BCEWithLogitsLoss()

z = torch.rand((1, 1))
X = torch.rand((1, 5))
y = torch.rand((1, 1))

The discriminator's gradients change because gen_loss is backpropagated through the discriminator and into the generator itself. In your current training loop this is not a problem, because only the generator's parameters are updated on the following line, gen_optimizer.step(). In other words: yes, the discriminator's gradients are updated (they change between Checkpoint 1 and Checkpoint 2), but they are never used to update the discriminator's parameters. As long as you properly clear the cached gradients with zero_grad() before each backward pass, you are fine.
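
You can confirm this by snapshotting the discriminator's parameters before the generator step: gen_loss.backward() fills their .grad fields, but gen_optimizer.step() leaves the parameters themselves untouched, and the next zero_grad() wipes the stale gradients. A minimal check, reusing the models, data, and print_grads helper defined above:

# Recompute the generator loss so the graph is fresh.
X_fake = gen_model(z)
pred_fake = disc_model(X_fake)                 # no detach: grads flow into both models
gen_loss = loss_fn(pred_fake, y)

# Snapshot the discriminator's weights before the generator step.
disc_params_before = [p.detach().clone() for p in disc_model.parameters()]

gen_optimizer.zero_grad()
gen_loss.backward()                            # populates .grad on BOTH models
gen_optimizer.step()                           # ...but only moves the generator's weights

for before, after in zip(disc_params_before, disc_model.parameters()):
    assert torch.equal(before, after)          # discriminator parameters unchanged

disc_optimizer.zero_grad()                     # clears the stale discriminator grads
print_grads(disc_model)                        # None (or zeros) again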

A common practice is to freeze the discriminator while training the generator, which also avoids the unnecessary gradient computation:

# freeze discriminator
disc_model.requires_grad_(False)

# train the generator
gen_optimizer.zero_grad()
gen_loss.backward()
gen_optimizer.step()

# unfreeze discriminator for next step
disc_model.requires_grad_(True)
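
With the discriminator frozen like this, autograd no longer accumulates gradients into its parameters at all, so a check at Checkpoint 2 would show the same (cleared) gradients as at Checkpoint 1. A quick sketch, assuming the objects from the question are in scope:

disc_model.requires_grad_(False)

X_fake = gen_model(z)
pred_fake = disc_model(X_fake)     # gradients still flow back into the generator
gen_loss = loss_fn(pred_fake, y)

gen_optimizer.zero_grad()
disc_optimizer.zero_grad()
print_grads(disc_model)            # Checkpoint 1: None / zeros
gen_loss.backward()
print_grads(disc_model)            # Checkpoint 2: unchanged, the discriminator is frozen
gen_optimizer.step()

disc_model.requires_grad_(True)    # unfreeze for the next discriminator update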