你如何在pytorch中正确使用next_functions[0][0] on grad_fn?

How do you use next_functions[0][0] on grad_fn correctly in pytorch?

我在pytorch官方教程中得到了这个nn结构:

input -> conv2d -> relu -> maxpool2d -> conv2d -> relu -> maxpool2d -> view -> linear -> relu -> linear -> relu -> linear -> MSELoss -> loss

然后是如何使用 Variable 中的内置 .grad_fn 向后跟踪 grad 的示例。

# Eg: 
print(loss.grad_fn)  # MSELoss
print(loss.grad_fn.next_functions[0][0])  # Linear
print(loss.grad_fn.next_functions[0][0].next_functions[0][0])  # ReLU

所以我认为我可以通过粘贴 next_function[0][0] 9 次来达到 Conv2d 的 grad 对象,因为给定的示例但是我从索引中得到了错误元组。那么我怎样才能正确地索引这些反向传播对象呢?

PyTorch CNN tutorial 之后 运行 以下来自教程:

output = net(input)
target = torch.randn(10)  # a dummy target, for example
target = target.view(1, -1)  # make it the same shape as output
criterion = nn.MSELoss()

loss = criterion(output, target)
print(loss)

以下代码片段将打印完整图表:

def print_graph(g, level=0):
    if g == None: return
    print('*'*level*4, g)
    for subg in g.next_functions:
        print_graph(subg[0], level+1)

print_graph(loss.grad_fn, 0)

尝试运行宁

print(loss.grad_fn.next_functions[0][0].next_functions)

你会看到这给出了一个包含三个元素的数组。它实际上是您要选择的 [1][0] 元素,否则您将获得累积的梯度,并且您不能再进一步了。当你深入挖掘时,你会发现你可以一路打通网络。例如,尝试 运行ning:

print(loss.grad_fn.next_functions[0][0].next_functions[1][0].next_functions[0][0].next_functions[1][0].next_functions[0][0].next_functions[1][0].next_functions[0][0].next_functions[0][0].next_functions[0][0].next_functions)

先运行.next_functions不索引进去,然后看你需要选择哪个元素才能到达nn的下一层