如何用 Pytorch 张量中的特定值替换每一行填充为零的行?

How to replace every row that are filled with zero with a certain value from a Pytorch tensor?

我有一个大小为 bsize x 50 x 50 的 Pytorch 张量,其中一些行完全用零填充:

         [[0, 2, 0,  ..., 0, 0, 0],
         [2, 0, 2,  ..., 0, 0, 0],
         [0, 2, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]],

        [[0, 2, 0,  ..., 0, 0, 0],
         [2, 0, 2,  ..., 0, 0, 0],
         [0, 2, 0,  ..., 0, 0, 0],

我想用负值 -100 替换整个张量中用零填充的行。 预期张量 -

         [[0, 2, 0,  ..., 0, 0, 0],
         [2, 0, 2,  ..., 0, 0, 0],
         [0, 2, 0,  ..., 0, 0, 0],
         ...,
         [-100, -100, -100,  ..., -100,-100, -100],
         [-100, -100, -100,  ..., -100,-100, -100],
         [-100, -100, -100,  ..., -100,-100, -100]],

        [[0, 2, 0,  ..., 0, 0, 0],
         [2, 0, 2,  ..., 0, 0, 0],
         [0, 2, 0,  ..., 0, 0, 0],

避免在行形状上循环的最佳方法是什么?

假设 x 是你的张量 BxRxC(批次、行和列),你可以这样做:

x[(x == 0).all(dim=-1)] = -100

基本上:

  • x == 0 returns 一个布尔张量(形状 BxRxC),其中 True 等于零;
  • 然后,.all(dim=-1) returns 另一个布尔张量,现在形状为 BxR 因为我们选择在最后一个维度中做 all (-1) , True 其中 all 列是 True;
  • 最后,我们使用这个布尔张量对原始张量进行索引,并将 -100 分配给 True 位置。