ND输入的im2col算法

The im2col algorithm for ND input

我正在尝试为输入维度 > 2D 编写自己的 im2col 算法。 目前我正在查看 Matlab im2col 实现。但是,我找不到任何关于任何超过 2 维的输入发生了什么的文档。

如果我将 3D 张量输入函数,我确实会得到一个输出。但是我真的不明白你是如何从 2D 到 ND 的。文档中没有提到这一点的事实表明它很简单,但我还是不明白。

哎呀,我什至不明白为什么输出矩阵的大小是现在的大小。

让我先说 im2col 仅适用于二维矩阵。事实上它有时会起作用(我的意思是返回一个结果而不会抛出错误)只是一个快乐的巧合。

现在我看了一下edit im2col.m,不用过多研究代码,distinctsliding方法的第一行应该给你一个直觉发生了什么:

...
if strcmp(kind, 'distinct')
    [m,n] = size(a);
    ...
elseif strcmp(kind,'sliding')
    [ma,na] = size(a);
    ...
end
...

首先回想一下 [s1,s2] = size(arr) 其中 arr 是一个 3d 数组会将第 2 维和第 3 维的大小折叠为一个大小。这是相关的 doc size:

[d1,d2,d3,...,dn] = size(X) returns the sizes of the dimensions of the array X, provided the number of output arguments n equals ndims(X). If n < ndims(X), di equals the size of the ith dimension of X for 0<i<n, but dn equals the product of the sizes of the remaining dimensions of X, that is, dimensions n through ndims(X).

所以基本上对于大小为 M-by-N-by-P 的数组,函数反而认为它是大小为 M-by-(N*P) 的矩阵。现在 MATLAB 有一些古怪的索引规则,可以让你做这样的事情:

>> x = reshape(1:4*3*2,4,3,2)
x(:,:,1) =
     1     5     9
     2     6    10
     3     7    11
     4     8    12
x(:,:,2) =
    13    17    21
    14    18    22
    15    19    23
    16    20    24
>> x(:,:)
ans =
     1     5     9    13    17    21
     2     6    10    14    18    22
     3     7    11    15    19    23
     4     8    12    16    20    24

我认为这就是最终发生在这里的事情。以下是确认 im2col 在 RGB 图像上的行为的示例:

% normal case (grayscale image)
>> M = magic(5);
>> B1 = im2col(M, [3 3], 'sliding');

% (RGB image)
>> MM = cat(3, M, M+50, M+100);
>> B2 = im2col(MM, [3 3], 'sliding');
>> B3 = im2col(reshape(MM, [5 5*3]), [3 3], 'sliding');
>> assert(isequal(B2,B3))

注意B2B3是相等的,所以基本上认为im2col在数组arr = cat(3,R,G,B)上的结果与[=的结果相同41=](水平连接)。

有趣的是,使用 "distinct" 块方法你不会那么幸运:

>> B1 = im2col(M, [3 3], 'distinct')    % works
% ..snip..

>> B2 = im2col(MM, [3 3], 'distinct')   % errors
Subscripted assignment dimension mismatch.
Error in im2col (line 59)
    aa(1:m,1:n) = a; 

现在我们了解了发生的事情,让我们考虑如何为 3D 阵列正确执行此操作。

在我看来,要为彩色图像实现 im2col,我会在每个颜色通道上分别 运行(每个都是一个二维矩阵),然后沿三维连接结果。所以像这个包装函数:

function B = im2col_rgb(img, sz, varargin)
    B = cell(1,size(img,3));
    for i=1:size(img,3)
        B{i} = im2col(img(:,:,i), sz, varargin{:});
    end
    B = cat(3, B{:});
end