检查索引是否匹配的矢量化逻辑

Vectorizing logic to check if index matches

我有以下完美运行的功能,但我想对其应用矢量化...

for i = 1:size(centroids,1)
    centroids(i, :) = mean(X(idx == i, :));
end

它检查 idx 是否与当前索引匹配,如果匹配,则计算与该索引对应的所有 X 值的 mean 值。

这是我的矢量化尝试,我的解决方案不起作用,我知道为什么...

centroids = mean(X(idx == [1:size(centroids,1)], :));

下面的 idx == [1:size(centroids,1)] 破坏了代码。我不知道如何检查 idx 是否等于从 1size(centroids,1).

中的任何一个数字

tl:dr

通过矢量化摆脱 for 循环

您可以将矩阵拆分为单元格,并使用 cellfun(在其内部操作中应用循环)从每个单元格中取平均值:

生成数据:

dim = 10;
N = 400;
nc = 20;
idx = randi(nc,[N 1]);
X = rand(N,dim);
centroids = zeros(nc,dim);

mean using loop(题目的方法)

for i = 1:size(centroids,1)
    centroids(i, :) = mean(X(idx == i, :));
end

矢量化:

% split X into cells by idx
A = accumarray(idx, (1:N)', [nc,1], @(i) {X(i,:)});
% mean of each cell
C = cell2mat(cellfun(@(x) mean(x,1),A,'UniformOutput',0));

方法之间的最大绝对误差:

max(abs(C(:) - centroids(:))) % about 1e-16

一种选择是使用 arrayfun;

nIdx      = size(centroids,1);
centroids = arrayfun(@(ii) mean(X(idx==ii,:)),1:nIdx, 'UniformOutput', false);
centroids = vertcat(centroids{:})

由于单个函数调用的输出不一定是标量,因此 UniformOutput 选项必须设置为 false。因此,arrayfun returns 一个元胞数组,您需要 vertcat 它以获得所需的双精度数组。