检查索引是否匹配的矢量化逻辑
Vectorizing logic to check if index matches
我有以下完美运行的功能,但我想对其应用矢量化...
for i = 1:size(centroids,1)
centroids(i, :) = mean(X(idx == i, :));
end
它检查 idx
是否与当前索引匹配,如果匹配,则计算与该索引对应的所有 X
值的 mean
值。
这是我的矢量化尝试,我的解决方案不起作用,我知道为什么...
centroids = mean(X(idx == [1:size(centroids,1)], :));
下面的 idx == [1:size(centroids,1)]
破坏了代码。我不知道如何检查 idx
是否等于从 1
到 size(centroids,1)
.
中的任何一个数字
tl:dr
通过矢量化摆脱 for 循环
您可以将矩阵拆分为单元格,并使用 cellfun
(在其内部操作中应用循环)从每个单元格中取平均值:
生成数据:
dim = 10;
N = 400;
nc = 20;
idx = randi(nc,[N 1]);
X = rand(N,dim);
centroids = zeros(nc,dim);
mean using loop(题目的方法)
for i = 1:size(centroids,1)
centroids(i, :) = mean(X(idx == i, :));
end
矢量化:
% split X into cells by idx
A = accumarray(idx, (1:N)', [nc,1], @(i) {X(i,:)});
% mean of each cell
C = cell2mat(cellfun(@(x) mean(x,1),A,'UniformOutput',0));
方法之间的最大绝对误差:
max(abs(C(:) - centroids(:))) % about 1e-16
一种选择是使用 arrayfun
;
nIdx = size(centroids,1);
centroids = arrayfun(@(ii) mean(X(idx==ii,:)),1:nIdx, 'UniformOutput', false);
centroids = vertcat(centroids{:})
由于单个函数调用的输出不一定是标量,因此 UniformOutput
选项必须设置为 false
。因此,arrayfun
returns 一个元胞数组,您需要 vertcat
它以获得所需的双精度数组。
我有以下完美运行的功能,但我想对其应用矢量化...
for i = 1:size(centroids,1)
centroids(i, :) = mean(X(idx == i, :));
end
它检查 idx
是否与当前索引匹配,如果匹配,则计算与该索引对应的所有 X
值的 mean
值。
这是我的矢量化尝试,我的解决方案不起作用,我知道为什么...
centroids = mean(X(idx == [1:size(centroids,1)], :));
下面的 idx == [1:size(centroids,1)]
破坏了代码。我不知道如何检查 idx
是否等于从 1
到 size(centroids,1)
.
tl:dr
通过矢量化摆脱 for 循环
您可以将矩阵拆分为单元格,并使用 cellfun
(在其内部操作中应用循环)从每个单元格中取平均值:
生成数据:
dim = 10;
N = 400;
nc = 20;
idx = randi(nc,[N 1]);
X = rand(N,dim);
centroids = zeros(nc,dim);
mean using loop(题目的方法)
for i = 1:size(centroids,1)
centroids(i, :) = mean(X(idx == i, :));
end
矢量化:
% split X into cells by idx
A = accumarray(idx, (1:N)', [nc,1], @(i) {X(i,:)});
% mean of each cell
C = cell2mat(cellfun(@(x) mean(x,1),A,'UniformOutput',0));
方法之间的最大绝对误差:
max(abs(C(:) - centroids(:))) % about 1e-16
一种选择是使用 arrayfun
;
nIdx = size(centroids,1);
centroids = arrayfun(@(ii) mean(X(idx==ii,:)),1:nIdx, 'UniformOutput', false);
centroids = vertcat(centroids{:})
由于单个函数调用的输出不一定是标量,因此 UniformOutput
选项必须设置为 false
。因此,arrayfun
returns 一个元胞数组,您需要 vertcat
它以获得所需的双精度数组。