PyTorch 模型跟踪不起作用:我们没有 aten::fill_ 的操作
PyTorch model tracing not working: We don't have an op for aten::fill_
我在这个特定模块上跟踪 PyTorch 模型时遇到错误:
RuntimeError: 0INTERNAL ASSERT FAILED at "../torch/csrc/jit/ir/alias_analysis.cpp":611, please report a bug to PyTorch. We don't have an op for aten::fill_ but it isn't a special case. Argument types: Tensor, bool,
Candidates:
aten::fill_.Scalar(Tensor(a!) self, Scalar value) -> (Tensor(a!))
aten::fill_.Tensor(Tensor(a!) self, Tensor value) -> (Tensor(a!))
这里是重现错误的精简代码示例:
import torch
import torch.nn.functional as F
import torch.nn as nn
class SurroundPattern(nn.Module):
    """Split a 4-D map into a centred crop and its surround, and pool each.

    A boolean mask marks a centred rectangle whose side lengths are
    ``crop_size`` times the spatial size of the input. Positions inside the
    rectangle form the "inside" part; every other position forms the
    "outside" part.
    """

    def __init__(self, crop_size=1./2):
        """Store the crop fraction.

        Args:
            crop_size: Fraction of the height/width covered by the centred
                crop (default ``0.5``).
        """
        super().__init__()
        self.crop_size = crop_size

    def forward(self, x, s):
        """Compute inside/outside statistics for ``x`` and, optionally, ``s``.

        Args:
            x: 4-D tensor; dimensions 2 and 3 are treated as height and width.
            s: 4-D tensor handled the same way, or ``None`` to skip it.

        Returns:
            ``fused`` when ``s`` is ``None``, otherwise ``(fused, fused_sal)``.
            ``fused`` has shape ``(B, C, 2, 1)``: index 0 along dim 2 is the
            mean over the inside positions, index 1 the mean over the outside
            positions. ``fused_sal`` has shape ``(B, L, 2, 1)`` where ``L`` is
            the flattened length of the inside part of ``s``.
        """
        H, W = x.shape[2:]
        crop_h = (int(H / 2 - self.crop_size / 2 * H), int(H / 2 + self.crop_size / 2 * H))
        crop_w = (int(W / 2 - self.crop_size / 2 * W), int(W / 2 + self.crop_size / 2 * W))
        x_mask = torch.zeros(H, W, device=x.device, dtype=torch.bool)
        # NOTE: this in-place fill of a bool tensor with a Python bool appears
        # to be the ``aten::fill_`` call (argument types ``Tensor, bool``)
        # named in the reported tracing error.
        x_mask[crop_h[0] : crop_h[1], crop_w[0] : crop_w[1]] = True

        # Gather masked positions; the two spatial dims collapse into one.
        inside_indices = torch.where(x_mask)
        inside_part = x[:, :, inside_indices[0], inside_indices[1]]
        inside_feat = inside_part.mean(2)
        outside_indices = torch.where(~x_mask)
        outside_part = x[:, :, outside_indices[0], outside_indices[1]]
        outside_feat = outside_part.mean(2)
        fused = torch.stack([inside_feat, outside_feat], dim=2).unsqueeze(3)
        if s is None:
            return fused

        SH, SW = s.shape[2:]
        crop_sh = (int(SH / 2 - self.crop_size / 2 * SH), int(SH / 2 + self.crop_size / 2 * SH))
        crop_sw = (int(SW / 2 - self.crop_size / 2 * SW), int(SW / 2 + self.crop_size / 2 * SW))
        s_mask = torch.zeros(SH, SW, device=s.device, dtype=torch.bool)
        # Same bool fill pattern as for ``x_mask`` above.
        s_mask[crop_sh[0] : crop_sh[1], crop_sw[0] : crop_sw[1]] = True
        s_inside_indices = torch.where(s_mask)
        inside_sal = s[:, :, s_inside_indices[0], s_inside_indices[1]].flatten(1)
        s_outside_indices = torch.where(~s_mask)
        outside_sal = s[:, :, s_outside_indices[0], s_outside_indices[1]].flatten(1)
        if outside_sal.shape != inside_sal.shape:
            # Max-pool the outside values down to the inside length so the two
            # can be stacked. The length is taken from ``inside_sal`` rather
            # than hard-coded to 784 (= 28 * 28, the inside size of a 56x56
            # single-channel map with crop_size 0.5), so other sizes work too.
            outside_sal = F.adaptive_max_pool1d(
                outside_sal.unsqueeze(1), output_size=inside_sal.shape[1]
            )
            outside_sal = outside_sal.squeeze(1)
        fused_sal = torch.stack([inside_sal, outside_sal], dim=2).unsqueeze(3)
        return fused, fused_sal
# Example inputs for tracing: a 7x7 map with 512 channels and a 56x56
# single-channel map, both with batch size 2.
x, s = torch.randn(2, 512, 7, 7), torch.randn(2, 1, 56, 56)
patt = SurroundPattern()
# Trace the module on the example inputs and show the traced result.
traced_cell = torch.jit.trace(patt, example_inputs=(x, s))
print(traced_cell)
如何找出问题的确切位置?有没有办法用其他功能修复它?
谢谢!
问题在于你试图对一个 bool 张量进行填充,而 jit 目前似乎还不支持这种操作(也可能是一个 bug)。
替换这个:
x_mask= torch.zeros(H,W,device=x.device, dtype=torch.bool)
x_mask[crop_h[0] : crop_h[1], crop_w[0] : crop_w[1]] = True
改为:
x_mask= torch.zeros(H,W,device=x.device)
x_mask[crop_h[0] : crop_h[1], crop_w[0] : crop_w[1]] = 1
这样应该可以解决该错误。对于掩码张量来说,这当然不是最理想的数据类型,但凡是可以用 torch.BoolTensor 完成的其他操作,基本上也都可以用它来完成。注意:`~x_mask` 只适用于布尔或整数类型的张量,因此对浮点掩码需要先用 `x_mask.bool()` 转换,或改用 `x_mask == 0` 来取反。
我在这个特定模块上跟踪 PyTorch 模型时遇到错误:
RuntimeError: 0INTERNAL ASSERT FAILED at "../torch/csrc/jit/ir/alias_analysis.cpp":611, please report a bug to PyTorch. We don't have an op for aten::fill_ but it isn't a special case. Argument types: Tensor, bool,
Candidates:
aten::fill_.Scalar(Tensor(a!) self, Scalar value) -> (Tensor(a!))
aten::fill_.Tensor(Tensor(a!) self, Tensor value) -> (Tensor(a!))
这里是重现错误的精简代码示例:
import torch
import torch.nn.functional as F
import torch.nn as nn
class SurroundPattern(nn.Module):
    """Split a 4-D map into a centred crop and its surround, and pool each.

    A boolean mask marks a centred rectangle whose side lengths are
    ``crop_size`` times the spatial size of the input. Positions inside the
    rectangle form the "inside" part; every other position forms the
    "outside" part.
    """

    def __init__(self, crop_size=1./2):
        """Store the crop fraction.

        Args:
            crop_size: Fraction of the height/width covered by the centred
                crop (default ``0.5``).
        """
        super().__init__()
        self.crop_size = crop_size

    def forward(self, x, s):
        """Compute inside/outside statistics for ``x`` and, optionally, ``s``.

        Args:
            x: 4-D tensor; dimensions 2 and 3 are treated as height and width.
            s: 4-D tensor handled the same way, or ``None`` to skip it.

        Returns:
            ``fused`` when ``s`` is ``None``, otherwise ``(fused, fused_sal)``.
            ``fused`` has shape ``(B, C, 2, 1)``: index 0 along dim 2 is the
            mean over the inside positions, index 1 the mean over the outside
            positions. ``fused_sal`` has shape ``(B, L, 2, 1)`` where ``L`` is
            the flattened length of the inside part of ``s``.
        """
        H, W = x.shape[2:]
        crop_h = (int(H / 2 - self.crop_size / 2 * H), int(H / 2 + self.crop_size / 2 * H))
        crop_w = (int(W / 2 - self.crop_size / 2 * W), int(W / 2 + self.crop_size / 2 * W))
        x_mask = torch.zeros(H, W, device=x.device, dtype=torch.bool)
        # NOTE: this in-place fill of a bool tensor with a Python bool appears
        # to be the ``aten::fill_`` call (argument types ``Tensor, bool``)
        # named in the reported tracing error.
        x_mask[crop_h[0] : crop_h[1], crop_w[0] : crop_w[1]] = True

        # Gather masked positions; the two spatial dims collapse into one.
        inside_indices = torch.where(x_mask)
        inside_part = x[:, :, inside_indices[0], inside_indices[1]]
        inside_feat = inside_part.mean(2)
        outside_indices = torch.where(~x_mask)
        outside_part = x[:, :, outside_indices[0], outside_indices[1]]
        outside_feat = outside_part.mean(2)
        fused = torch.stack([inside_feat, outside_feat], dim=2).unsqueeze(3)
        if s is None:
            return fused

        SH, SW = s.shape[2:]
        crop_sh = (int(SH / 2 - self.crop_size / 2 * SH), int(SH / 2 + self.crop_size / 2 * SH))
        crop_sw = (int(SW / 2 - self.crop_size / 2 * SW), int(SW / 2 + self.crop_size / 2 * SW))
        s_mask = torch.zeros(SH, SW, device=s.device, dtype=torch.bool)
        # Same bool fill pattern as for ``x_mask`` above.
        s_mask[crop_sh[0] : crop_sh[1], crop_sw[0] : crop_sw[1]] = True
        s_inside_indices = torch.where(s_mask)
        inside_sal = s[:, :, s_inside_indices[0], s_inside_indices[1]].flatten(1)
        s_outside_indices = torch.where(~s_mask)
        outside_sal = s[:, :, s_outside_indices[0], s_outside_indices[1]].flatten(1)
        if outside_sal.shape != inside_sal.shape:
            # Max-pool the outside values down to the inside length so the two
            # can be stacked. The length is taken from ``inside_sal`` rather
            # than hard-coded to 784 (= 28 * 28, the inside size of a 56x56
            # single-channel map with crop_size 0.5), so other sizes work too.
            outside_sal = F.adaptive_max_pool1d(
                outside_sal.unsqueeze(1), output_size=inside_sal.shape[1]
            )
            outside_sal = outside_sal.squeeze(1)
        fused_sal = torch.stack([inside_sal, outside_sal], dim=2).unsqueeze(3)
        return fused, fused_sal
# Example inputs for tracing: a 7x7 map with 512 channels and a 56x56
# single-channel map, both with batch size 2.
x, s = torch.randn(2, 512, 7, 7), torch.randn(2, 1, 56, 56)
patt = SurroundPattern()
# Trace the module on the example inputs and show the traced result.
traced_cell = torch.jit.trace(patt, example_inputs=(x, s))
print(traced_cell)
如何找出问题的确切位置?有没有办法用其他功能修复它? 谢谢!
问题在于你试图对一个 bool 张量进行填充,而 jit 目前似乎还不支持这种操作(也可能是一个 bug)。
替换这个:
x_mask= torch.zeros(H,W,device=x.device, dtype=torch.bool)
x_mask[crop_h[0] : crop_h[1], crop_w[0] : crop_w[1]] = True
改为:
x_mask= torch.zeros(H,W,device=x.device)
x_mask[crop_h[0] : crop_h[1], crop_w[0] : crop_w[1]] = 1
应该可以解决错误。这当然不是目标张量类型的最佳选择,但您应该能够执行您将使用 torch.BoolTensor