Pytorch 如何在不改变单个过滤器形状的情况下 reshape/reduce 过滤器的数量

Pytorch how to reshape/reduce the number of filters without altering the shape of the individual filters

对于 3D 张量的形状(过滤器的数量、高度、宽度),如何通过将原始过滤器保持为整个块的重塑来减少过滤器的数量?

假设新尺寸的尺寸选择使得所有原始过滤器可以并排放置在一个新过滤器中。所以原始尺寸 (4, 2, 2) 可以重塑为 (2, 2, 4)。

并排整形的视觉解释,您会看到标准整形将改变各个过滤器的形状:

我已经尝试了各种 pytorch 函数,例如 gatherselect_index,但没有找到以一般方式获得最终结果的方法(即适用于不同数量的过滤器和不同的过滤器尺码)。

我认为执行重塑后重新排列张量值会更容易,但无法获得 pytorch 重塑形式的张量:

[[[1,2,3,4],
  [5,6,7,8]],
 
 [[9,10,11,12],
  [13,14,15,16]]]

至:

[[[1,2,5,6],
  [3,4,7,8]],

 [[9,10,13,14],
  [11,12,15,16]]]

为了完整起见,整形前的原始张量:

[[[1,2],
  [3,4]],
 
 [[5,6],
  [7,8]],

 [[9,10],
  [11,12]],

 [[13,14],
  [15,16]]]

您可以通过分块张量然后重新组合来实现。

def side_by_side_reshape(x):
    n_pairs = x.shape[0] // 2
    filter_size = x.shape[-1]
    x = x.reshape((n_pairs, 2, filter_size, filter_size))
    return torch.stack(list(map(lambda x: torch.hstack(x.unbind()), k)))
>> p = torch.arange(1, 91).reshape((10, 3, 3))
>> side_by_side_reshape(p)

tensor([[[ 1,  2,  3, 10, 11, 12],
         [ 4,  5,  6, 13, 14, 15],
         [ 7,  8,  9, 16, 17, 18]],

        [[19, 20, 21, 28, 29, 30],
         [22, 23, 24, 31, 32, 33],
         [25, 26, 27, 34, 35, 36]],

        [[37, 38, 39, 46, 47, 48],
         [40, 41, 42, 49, 50, 51],
         [43, 44, 45, 52, 53, 54]],

        [[55, 56, 57, 64, 65, 66],
         [58, 59, 60, 67, 68, 69],
         [61, 62, 63, 70, 71, 72]],

        [[73, 74, 75, 82, 83, 84],
         [76, 77, 78, 85, 86, 87],
         [79, 80, 81, 88, 89, 90]]])

但我知道这并不理想,因为有 maplistunbind 会破坏记忆。这就是我提供的,直到我弄清楚如何仅通过视图来做到这一点(所以真正的重塑)

另一种选择是构造一个零件列表并将它们连接起来

x = torch.arange(4).reshape(4, 1, 1).repeat(1, 2, 2)
y = torch.cat([x[i::2] for i in range(2)], dim=2)

print('Before\n', x)
print('After\n', y)

这给出了

Before
 tensor([[[0, 0],
         [0, 0]],

        [[1, 1],
         [1, 1]],

        [[2, 2],
         [2, 2]],

        [[3, 3],
         [3, 3]]])
After
 tensor([[[0, 0, 1, 1],
         [0, 0, 1, 1]],

        [[2, 2, 3, 3],
         [2, 2, 3, 3]]])

或者更一般地说,我们可以编写一个函数,沿源维度获取邻居组并沿目标维度将它们连接起来

def group_neighbors(x, group_size, src_dim, dst_dim):
    assert x.shape[src_dim] % group_size == 0
    return torch.cat([x[[slice(None)] * (src_dim) + [slice(i, None, group_size)] + [slice(None)] * (len(x.shape) - (src_dim + 2))] for i in range(group_size)], dim=dst_dim)


x = torch.arange(4).reshape(4, 1, 1).repeat(1, 2, 2)
# read as "take neighbors in groups of 2 from dimension 0 and concatenate them in dimension 2"
y = group_neighbors(x, group_size=2, src_dim=0, dst_dim=2)

print('Before\n', x)
print('After\n', y)