为两个不同的神经网络调用 .backward() 函数但得到 retain_graph=True 错误

Calling .backward() function for two different neural networks but getting retain_graph=True error

我有一个 Actor Critic 神经网络,其中 Actor 是它自己的 class,Critic 是它自己的 class,具有自己的神经网络和 .forward() 函数。然后,我在更大的模型 class 中为每个 class 创建一个对象。我的设置如下:

self.actor = Actor().to(device)
self.actor_opt = optim.Adam(self.actor.parameters(), lr=lr)
self.critic = Critic().to(device)
self.critic_opt = optim.Adam(self.critic.parameters(), lr=lr)

然后我计算了两个不同的损失函数,并想分别更新每个神经网络。评论家:

loss_critic = F.smooth_l1_loss(value, expected)
self.critic_opt.zero_grad()
loss_critic.backward()
self.critic_opt.step()

演员:

loss_actor = -self.critic(state, action)
self.actor_opt.zero_grad()
loss_actor.backward()
self.actor_opt.step()

但是,在执行此操作时,出现以下错误:

RuntimeError: Trying to backward through the graph a second time, but the saved 
intermediate results have already been freed. Specify retain_graph=True when
calling backward the first time.

阅读本文时,我了解到在同一网络上向后调用两次时我只需要 retain_graph=True,并且在大多数情况下设置为 True 并不好,因为我会 运行 GPU 耗尽。此外,当我注释掉其中一个 .backward() 函数时,错误消失了,使我相信由于某种原因代码认为两个 backward() 函数正在同一个神经网络上调用,即使我想我是分开做的。这可能是什么原因?有没有办法指定我在哪个神经网络上调用反向函数?

编辑: 作为参考,此代码中的 optimize() 函数 https://github.com/wudongming97/PyTorch-DDPG/blob/master/train.py 使用 backward() 两次没有问题(我已经克隆了 repo 并对其进行了测试)。我希望我的代码能够以类似的方式运行,我分别通过评论家和演员进行反向传播。

是的,你不应该那样做。相反,您应该做的是通过图形的各个部分进行传播。

图表包含的内容

现在,图表包含 actorcritic。如果计算通过图的同一部分(例如,两次通过 actor),则会引发此错误。

  • 他们会,因为你清楚地使用 actorcritic 加入损失值 (这一行:loss_actor = -self.critic(state, action))

  • 不同的优化器不会改变这里的任何东西,因为这是 backward 问题(优化器只是将计算出的梯度应用到模型上)

正在尝试修复它

  • 这是在 GAN 中修复它的方法,但在这种情况下不是,请参阅下面的 Actual fix 段落,如果您对该主题感到好奇,请继续阅读

如果神经网络的一部分(在本例中为critic不参与当前优化步骤,则应将其视为常量(并且反之亦然)。

为此,您可以使用 torch.no_grad 上下文管理器 (documentation) and set critic to eval mode (documentation) 禁用 gradient,大致如下:

self.critic.eval()
with torch.no_grad():
    loss_actor = -self.critic(state, action)
...

但是,这里有一个问题:

We are turning off gradient (tape recording) for action and breaking the graph!

因此这不是一个可行的解决方案。

实际解决方案

它比你想象的要简单得多,你也可以在PyTorch's repository中看到它:

  • 不要backpropagatecritic/actor损失
  • 计算所有损失(criticactor
  • sum他们在一起
  • zero_grad 两个优化器
  • backpropagate 加上这个求和值
  • critic_optimizer.step()actor_optimizer.step()此时

类似的东西:

self.critic_opt.zero_grad()
self.actor_opt.zero_grad()

loss_critic = F.smooth_l1_loss(value, expected)
loss_actor = -self.critic(state, action)

total_loss = loss_actor + loss_critic
total_loss.backward()

self.critic_opt.step()
self.actor_opt.step()