PyTorch model tracing not working: We don't have an op for aten::fill_

I get an error when tracing a PyTorch model on this particular module:

RuntimeError: 0 INTERNAL ASSERT FAILED at "../torch/csrc/jit/ir/alias_analysis.cpp":611, please report a bug to PyTorch. We don't have an op for aten::fill_ but it isn't a special case.  Argument types: Tensor, bool,

Candidates:
    aten::fill_.Scalar(Tensor(a!) self, Scalar value) -> (Tensor(a!))
    aten::fill_.Tensor(Tensor(a!) self, Tensor value) -> (Tensor(a!))

Here is a minimal code example that reproduces the error:

import torch
import torch.nn.functional as F
import torch.nn as nn

class SurroundPattern(nn.Module):
    def __init__(self, crop_size=1./2):
        super(SurroundPattern, self).__init__()
        self.crop_size = crop_size

    def forward(self, x, s):
        H,W         = x.shape[2:]
        crop_h      = (int(H / 2 - self.crop_size / 2 * H), int(H / 2 + self.crop_size / 2 * H))
        crop_w      = (int(W / 2 - self.crop_size / 2 * W), int(W / 2 + self.crop_size / 2 * W))
        x_mask      = torch.zeros(H,W,device=x.device, dtype=torch.bool)
        x_mask[crop_h[0] : crop_h[1], crop_w[0] : crop_w[1]] = True

        inside_indices  = torch.where(x_mask)
        inside_part = x[:, :, inside_indices[0], inside_indices[1]]
        inside_feat = inside_part.mean(2)

        outside_indices = torch.where(~x_mask)
        outside_part    = x[:, :, outside_indices[0], outside_indices[1]]
        outside_feat    = outside_part.mean(2)
        fused = torch.stack([inside_feat, outside_feat], dim=2).unsqueeze(3)
        if s is None:
            return fused

        SH,SW       = s.shape[2:]
        crop_sh     = (int(SH / 2 - self.crop_size / 2 * SH), int(SH / 2 + self.crop_size / 2 * SH))
        crop_sw     = (int(SW / 2 - self.crop_size / 2 * SW), int(SW / 2 + self.crop_size / 2 * SW))
        s_mask      = torch.zeros(SH, SW, device=s.device, dtype=torch.bool)
        s_mask[crop_sh[0] : crop_sh[1], crop_sw[0] : crop_sw[1]] = True

        s_inside_indices = torch.where(s_mask)
        inside_sal  = s[:, :, s_inside_indices[0], s_inside_indices[1]].flatten(1)

        s_outside_indices = torch.where(~s_mask)
        outside_sal = s[:, :, s_outside_indices[0], s_outside_indices[1]].flatten(1)
        if outside_sal.shape != inside_sal.shape:
            outside_sal = F.adaptive_max_pool1d(outside_sal.unsqueeze(1), output_size=784)
            outside_sal = outside_sal.squeeze(1)
        fused_sal    = torch.stack([inside_sal, outside_sal], dim=2).unsqueeze(3)
        return fused, fused_sal

x = torch.randn(2, 512, 7, 7)
s = torch.randn(2, 1, 56, 56)

patt = SurroundPattern()

traced_cell = torch.jit.trace(patt, (x, s))
print(traced_cell)

How can I find out exactly where the problem is? Is there a way to fix it using other functions? Thanks!

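The problem is that you are trying to fill a bool Tensor, which apparently is not yet supported in jit (or is a bug). The error message points at this: the call has argument types Tensor, bool, while both listed aten::fill_ candidates expect a Scalar or a Tensor value.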

Replace this:

    x_mask= torch.zeros(H,W,device=x.device, dtype=torch.bool)
    x_mask[crop_h[0] : crop_h[1], crop_w[0] : crop_w[1]] = True

with:

    x_mask= torch.zeros(H,W,device=x.device)
    x_mask[crop_h[0] : crop_h[1], crop_w[0] : crop_w[1]] = 1

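That should resolve the error. A float tensor is of course not the ideal type for a mask, but you should be able to perform any other operation on it that you would perform with a torch.BoolTensor.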
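
One caveat when applying this to the code in the question: as far as I know, `~` is only defined for integer and boolean tensors, so `torch.where(~x_mask)` would fail on a float mask. Below is a minimal sketch of one way around that: fill a float mask with a numeric scalar as suggested above, then convert it to bool once. The helper name `center_mask` is hypothetical (not from the original post), and I have only reasoned about this in eager mode, so whether it traces cleanly on your PyTorch version is an assumption to verify.

```python
import torch


def center_mask(H, W, crop_size=0.5, device=None):
    """Hypothetical helper: boolean mask that is True in the central crop."""
    # Same crop arithmetic as in the question's forward().
    h0, h1 = int(H / 2 - crop_size / 2 * H), int(H / 2 + crop_size / 2 * H)
    w0, w1 = int(W / 2 - crop_size / 2 * W), int(W / 2 + crop_size / 2 * W)
    # Float mask filled with a numeric scalar instead of the Python bool True.
    mask = torch.zeros(H, W, device=device)
    mask[h0:h1, w0:w1] = 1
    # Convert once so that `~mask` and boolean indexing keep working.
    return mask.bool()


x = torch.randn(2, 512, 7, 7)
mask = center_mask(7, 7, device=x.device)

inside = torch.where(mask)    # indices inside the central crop
outside = torch.where(~mask)  # `~` is valid here because mask is bool

inside_feat = x[:, :, inside[0], inside[1]].mean(2)
outside_feat = x[:, :, outside[0], outside[1]].mean(2)
print(inside_feat.shape, outside_feat.shape)
```

If you would rather keep the mask as float throughout, replacing `~x_mask` with a comparison such as `x_mask == 0` should give the same outside indices.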