你如何在pytorch中正确使用next_functions[0][0] on grad_fn?
How do you use next_functions[0][0] on grad_fn correctly in pytorch?
我在pytorch官方教程中得到了这个nn结构:
input
-> conv2d -> relu -> maxpool2d -> conv2d -> relu -> maxpool2d
-> view -> linear -> relu -> linear -> relu -> linear
-> MSELoss
-> loss
然后是如何使用 Variable 中的内置 .grad_fn 向后跟踪 grad 的示例。
# Eg:
print(loss.grad_fn) # MSELoss
print(loss.grad_fn.next_functions[0][0]) # Linear
print(loss.grad_fn.next_functions[0][0].next_functions[0][0]) # ReLU
所以我认为我可以通过粘贴 next_function[0][0] 9 次来达到 Conv2d 的 grad 对象,因为给定的示例但是我从索引中得到了错误元组。那么我怎样才能正确地索引这些反向传播对象呢?
在 PyTorch CNN tutorial 之后 运行 以下来自教程:
output = net(input)
target = torch.randn(10) # a dummy target, for example
target = target.view(1, -1) # make it the same shape as output
criterion = nn.MSELoss()
loss = criterion(output, target)
print(loss)
以下代码片段将打印完整图表:
def print_graph(g, level=0):
if g == None: return
print('*'*level*4, g)
for subg in g.next_functions:
print_graph(subg[0], level+1)
print_graph(loss.grad_fn, 0)
尝试运行宁
print(loss.grad_fn.next_functions[0][0].next_functions)
你会看到这给出了一个包含三个元素的数组。它实际上是您要选择的 [1][0] 元素,否则您将获得累积的梯度,并且您不能再进一步了。当你深入挖掘时,你会发现你可以一路打通网络。例如,尝试 运行ning:
print(loss.grad_fn.next_functions[0][0].next_functions[1][0].next_functions[0][0].next_functions[1][0].next_functions[0][0].next_functions[1][0].next_functions[0][0].next_functions[0][0].next_functions[0][0].next_functions)
先运行.next_functions不索引进去,然后看你需要选择哪个元素才能到达nn的下一层
我在pytorch官方教程中得到了这个nn结构:
input -> conv2d -> relu -> maxpool2d -> conv2d -> relu -> maxpool2d -> view -> linear -> relu -> linear -> relu -> linear -> MSELoss -> loss
然后是如何使用 Variable 中的内置 .grad_fn 向后跟踪 grad 的示例。
# Eg:
print(loss.grad_fn) # MSELoss
print(loss.grad_fn.next_functions[0][0]) # Linear
print(loss.grad_fn.next_functions[0][0].next_functions[0][0]) # ReLU
所以我认为我可以通过粘贴 next_function[0][0] 9 次来达到 Conv2d 的 grad 对象,因为给定的示例但是我从索引中得到了错误元组。那么我怎样才能正确地索引这些反向传播对象呢?
在 PyTorch CNN tutorial 之后 运行 以下来自教程:
output = net(input)
target = torch.randn(10) # a dummy target, for example
target = target.view(1, -1) # make it the same shape as output
criterion = nn.MSELoss()
loss = criterion(output, target)
print(loss)
以下代码片段将打印完整图表:
def print_graph(g, level=0):
if g == None: return
print('*'*level*4, g)
for subg in g.next_functions:
print_graph(subg[0], level+1)
print_graph(loss.grad_fn, 0)
尝试运行宁
print(loss.grad_fn.next_functions[0][0].next_functions)
你会看到这给出了一个包含三个元素的数组。它实际上是您要选择的 [1][0] 元素,否则您将获得累积的梯度,并且您不能再进一步了。当你深入挖掘时,你会发现你可以一路打通网络。例如,尝试 运行ning:
print(loss.grad_fn.next_functions[0][0].next_functions[1][0].next_functions[0][0].next_functions[1][0].next_functions[0][0].next_functions[1][0].next_functions[0][0].next_functions[0][0].next_functions[0][0].next_functions)
先运行.next_functions不索引进去,然后看你需要选择哪个元素才能到达nn的下一层