PyTorch `backward()` updates multiple models
Can anyone tell me why the discriminator's gradients also change, and whether there is a way to avoid this?
for i in range(2):
    X_fake = gen_model(z)

    # Train the discriminator on real and detached fake samples.
    pred_real = disc_model(X)
    pred_fake = disc_model(X_fake.detach())
    disc_loss = (loss_fn(pred_real, y) + loss_fn(pred_fake, y)) / 2
    disc_optimizer.zero_grad()
    disc_loss.backward()
    disc_optimizer.step()

    # Train the generator through the discriminator.
    pred_fake = disc_model(X_fake)
    gen_loss = loss_fn(pred_fake, y)
    gen_optimizer.zero_grad()
    if i == 1:
        print_grads(disc_model)  # Checkpoint 1
    gen_loss.backward()
    if i == 1:
        print_grads(disc_model)  # Checkpoint 2
    gen_optimizer.step()
Here is the rest of the code.
import torch
import torch.nn as nn

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self._linear = nn.Sequential(nn.Linear(1, 5))

    def forward(self, X):
        return self._linear(X)

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self._linear = nn.Sequential(nn.Linear(5, 1))

    def forward(self, X):
        return self._linear(X)

def print_grads(model):
    for params in model.parameters():
        print(params.grad)

# Build the models and data.
gen_model = Generator()
gen_optimizer = torch.optim.Adam(gen_model.parameters(), 1)
disc_model = Discriminator()
disc_optimizer = torch.optim.Adam(disc_model.parameters(), 1)
loss_fn = torch.nn.BCEWithLogitsLoss()

z = torch.rand((1, 1))
X = torch.rand((1, 5))
y = torch.rand((1, 1))
The discriminator's gradients change because gen_loss is backpropagated through the discriminator on its way back to the generator. In your current training loop this is not a problem, because only the generator's parameters are updated by gen_optimizer.step() on the next line. In other words: yes, the discriminator's .grad buffers are written to (they change between Checkpoint 1 and Checkpoint 2), but those gradients are never used to update the discriminator's parameters. As long as you clear the gradient buffers with zero_grad() before each backward pass, as you already do, you are fine.
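To convince yourself, here is a minimal sketch of the generator step (assuming gen_loss has just been computed inside your loop, as above): snapshot the discriminator's parameters around the generator update and confirm they do not change, even though its .grad buffers do.

# Sketch: gen_optimizer.step() does not touch the discriminator's parameters,
# even though gen_loss.backward() writes into its .grad buffers.
disc_before = [p.detach().clone() for p in disc_model.parameters()]

gen_loss.backward()    # populates .grad in both models
gen_optimizer.step()   # only updates gen_model's parameters

disc_after = [p.detach().clone() for p in disc_model.parameters()]
print(all(torch.equal(b, a) for b, a in zip(disc_before, disc_after)))  # prints True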
A common practice is to freeze the discriminator while training the generator, which avoids the unnecessary gradient computation:
# Freeze the discriminator before the forward pass, so the graph is
# built without gradient tracking for its parameters.
disc_model.requires_grad_(False)

# Train the generator.
pred_fake = disc_model(X_fake)
gen_loss = loss_fn(pred_fake, y)
gen_optimizer.zero_grad()
gen_loss.backward()
gen_optimizer.step()

# Unfreeze the discriminator for its next update.
disc_model.requires_grad_(True)
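Note that freezing the discriminator this way only stops gradients from being accumulated into its parameters; gradients still flow through its layers to X_fake, which is exactly what the generator needs. If you prefer an explicit helper (a sketch; set_requires_grad is a hypothetical name, not part of the code above), the same toggle can be written per parameter:

def set_requires_grad(model, flag):
    # Toggle gradient accumulation for every parameter of `model`.
    for p in model.parameters():
        p.requires_grad_(flag)

# Usage around the generator step:
set_requires_grad(disc_model, False)   # freeze D before the forward pass
pred_fake = disc_model(X_fake)
gen_loss = loss_fn(pred_fake, y)
gen_optimizer.zero_grad()
gen_loss.backward()                    # grads still reach gen_model through the frozen D
gen_optimizer.step()
set_requires_grad(disc_model, True)    # unfreeze D for its next update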