找到每行最大值索引的最快方法

Fastest way to find the index of the maximum of each row

我正在尝试找到一种最佳方法来查找每行中最大值的索引。问题是我找不到真正有效的方法。 一个例子:

Dummy <- matrix(runif(500000000,0,3), ncol = 10000)
> system.time(max.col(Dummy, "first"))
   user  system elapsed 
  5.532   0.075   5.599 
> system.time(apply(Dummy,1,which.max))
   user  system elapsed 
 14.638   0.210  14.828 
> system.time(rowRanges(Dummy))
   user  system elapsed 
  2.083   0.029   2.109 

我的主要问题是,与使用 rowRanges 函数计算最大值和最小值相比,为什么计算最大值的索引要慢 2 倍以上。有什么方法可以提高计算每行最大值的索引的性能吗?

R 以列优先(column-major)顺序存储矩阵。因此,按列遍历通常比按行遍历更快,因为同一列的值在内存中彼此相邻,可以一次性载入缓存层次结构:

Dummy <- matrix(runif(100000000,0,3), ncol = 10000)
system.time(apply(Dummy,1,function(x) NULL))
##   user  system elapsed 
##  1.360   0.160   1.519 
system.time(apply(Dummy,2,function(x) NULL))
##   user  system elapsed 
##   0.94    0.12    1.06 

这应该接近即使是最快的 Rcpp 解决方案所能达到的最短时间。任何使用 apply() 的解决方案都必须复制每一列/行,而使用 Rcpp 可以省去这部分复制开销。潜在的约 2 倍加速是否值得投入额外精力,由您自行判断。

通常,在 R 中做事最快的方法是调用 C、C++ 或 FORTRAN。

似乎 matrixStats::rowRanges 是用 C 实现的,这解释了它为什么是最快的。

如果你想进一步提高性能,可以修改 rowRanges.c 的代码,忽略最小值而只求最大值,这或许能带来一点速度提升,但我认为收益会很小。

这是一个非常基本的 Rcpp 实现:

#include <Rcpp.h>

#include <vector>

// Returns, for each row of `m`, the 1-based column index of its first maximum,
// matching max.col(m, "first"). Rows that contain NA/NaN yield NA (as max.col
// does), and a matrix with no columns yields NA for every row instead of
// reading m(i, 0) out of bounds.
//
// R stores matrices column-major, so the scan runs column by column: the inner
// loop then touches contiguous memory, which is friendlier to the cache than
// walking along a row.
// [[Rcpp::export]]
Rcpp::NumericVector MaxCol(Rcpp::NumericMatrix m) {
    const R_xlen_t nr = m.nrow(), nc = m.ncol();
    Rcpp::NumericVector result(nr, NA_REAL);
    if (nc == 0) return result;

    // Running maximum per row, seeded from the first column so that rows whose
    // values are all -Inf still report column 1.
    std::vector<double> best(nr);
    std::vector<char> has_nan(nr, 0);
    for (R_xlen_t i = 0; i < nr; ++i) {
        best[i] = m(i, 0);
        has_nan[i] = ISNAN(best[i]) ? 1 : 0;
        result[i] = 1.0;
    }

    for (R_xlen_t j = 1; j < nc; ++j) {
        for (R_xlen_t i = 0; i < nr; ++i) {
            const double v = m(i, j);
            if (v > best[i]) {            // strict '>' keeps the first of tied maxima
                best[i] = v;
                result[i] = static_cast<double>(j + 1);
            } else if (ISNAN(v)) {        // NaN never compares greater; remember it instead
                has_nan[i] = 1;
            }
        }
    }

    for (R_xlen_t i = 0; i < nr; ++i) {
        if (has_nan[i]) result[i] = NA_REAL;
    }
    return result;
}

/*** R

microbenchmark::microbenchmark(
    "Rcpp" = MaxCol(Dummy), 
    "R" = max.col(Dummy, "first"),
    times = 200L
)
#Unit: milliseconds
# expr      min       lq     mean   median       uq      max neval
# Rcpp 221.7777 224.7442 242.0089 229.6407 239.6339 455.9549   200
# R    513.4391 524.7585 562.7465 539.4829 562.3732 944.7587   200

*/

由于我的笔记本电脑内存不足,我不得不将你的样本数据缩小一个数量级,但结论应同样适用于你的原始样本数据:

Dummy <- matrix(runif(50000000,0,3), ncol = 10000)
all.equal(MaxCol(Dummy), max.col(Dummy, "first"))
#[1] TRUE

稍作修改即可返回每行中最小值和最大值的索引:

// For each row of `m`, returns the 1-based column indices of its first minimum
// (result column 1) and first maximum (result column 2). Rows containing
// NA/NaN get NA in both columns, matching max.col's treatment of missing
// values; with zero columns every entry is NA rather than reading m(i, 0)
// past the end of the matrix.
// [[Rcpp::export]]
Rcpp::NumericMatrix MinMaxCol(Rcpp::NumericMatrix m) {
    const R_xlen_t nr = m.nrow(), nc = m.ncol();
    Rcpp::NumericMatrix result(nr, 2);

    for (R_xlen_t i = 0; i < nr; ++i) {
        // Short-circuit keeps m(i, 0) from being evaluated when nc == 0.
        if (nc == 0 || ISNAN(m(i, 0))) {
            result(i, 0) = NA_REAL;
            result(i, 1) = NA_REAL;
            continue;
        }

        double cmin = m(i, 0), cmax = cmin;
        R_xlen_t min_idx = 0, max_idx = 0;
        bool has_nan = false;
        for (R_xlen_t j = 1; j < nc; ++j) {
            const double v = m(i, j);
            if (v > cmax) {
                cmax = v;
                max_idx = j;
            } else if (v < cmin) {
                // cmin <= cmax always holds, so a new maximum is never also a new minimum.
                cmin = v;
                min_idx = j;
            } else if (ISNAN(v)) {
                has_nan = true;
                break;
            }
        }
        result(i, 0) = has_nan ? NA_REAL : static_cast<double>(min_idx + 1);
        result(i, 1) = has_nan ? NA_REAL : static_cast<double>(max_idx + 1);
    }
    return result;
}

在 krlmlr 答案的基础上,补充一些基准测试:

在数据集上:

set.seed(007); Dummy <- matrix(runif(50000000,0,3), ncol = 1000)

maxCol_R 是一个 R by-column 循环,maxCol_col 是一个 C by-column 循环,maxCol_row 是一个 C by-row 循环。

microbenchmark::microbenchmark(max.col(Dummy, "first"), maxCol_R(Dummy), maxCol_col(Dummy), maxCol_row(Dummy), times = 30)
#Unit: milliseconds
#                    expr        min         lq     median         uq       max neval
# max.col(Dummy, "first") 1209.28408 1245.24872 1268.34146 1291.26612 1504.0072    30
#         maxCol_R(Dummy) 1060.99994 1084.80260 1099.41400 1154.11213 1436.2136    30
#       maxCol_col(Dummy)   86.52765   87.22713   89.00142   93.29838  122.2456    30
#       maxCol_row(Dummy)  577.51613  583.96600  598.76010  616.88250  671.9191    30
all.equal(max.col(Dummy, "first"), maxCol_R(Dummy))
#[1] TRUE
all.equal(max.col(Dummy, "first"), maxCol_col(Dummy))
#[1] TRUE
all.equal(max.col(Dummy, "first"), maxCol_row(Dummy))
#[1] TRUE

以及函数:

maxCol_R = function(x)
{
    # Row-wise index of the first maximum, computed one column at a time with
    # vectorised comparisons; mirrors max.col(x, "first"). Rows containing
    # NA/NaN get NA, as max.col does; a matrix without columns gives all NA.
    nc = ncol(x)
    if (nc == 0L) return(rep_len(NA_integer_, nrow(x)))

    ans = rep_len(1L, nrow(x))
    mx = x[, 1L]
    has_na = is.na(mx)

    # seq_len(nc)[-1L] is empty for a one-column matrix, whereas 2:ncol(x)
    # would be c(2, 1) and index a column that does not exist.
    for(j in seq_len(nc)[-1L]) {
        tmp = x[, j]
        has_na = has_na | is.na(tmp)
        wh = which(tmp > mx)

        ans[wh] = j
        mx[wh] = tmp[wh]
    }

    ans[has_na] = NA_integer_
    ans
}

maxCol_col = inline::cfunction(sig = c(x = "matrix"), body = '
    /* Row-wise index (1-based) of the first maximum of a double matrix,
       scanned column by column: R matrices are column-major, so the inner
       loop reads contiguous memory.  Every entry of ans is written: rows
       are seeded from column 1 (so an all -Inf row reports 1), rows that
       contain NaN/NA report NA as max.col does, and a matrix without
       columns yields all NA.  Offsets use R_xlen_t so that matrices with
       more than INT_MAX cells do not overflow j * nr. */
    SEXP dim = getAttrib(x, R_DimSymbol);
    R_xlen_t nr = INTEGER(dim)[0], nc = INTEGER(dim)[1];
    double *px = REAL(x);

    SEXP ans = PROTECT(allocVector(INTSXP, nr));
    int *pans = INTEGER(ans);

    if (nc == 0) {
        for (R_xlen_t i = 0; i < nr; i++) pans[i] = NA_INTEGER;
        UNPROTECT(1);
        return(ans);
    }

    /* running maximum per row, seeded from the first column */
    double *buf = (double *) R_alloc(nr, sizeof(double));
    memcpy(buf, px, nr * sizeof(double));
    for (R_xlen_t i = 0; i < nr; i++) pans[i] = ISNAN(px[i]) ? NA_INTEGER : 1;

    for (R_xlen_t j = 1; j < nc; j++) {
        const double *col = px + j * nr;
        for (R_xlen_t i = 0; i < nr; i++) {
            if (pans[i] == NA_INTEGER) continue;  /* row already holds a NaN */
            double v = col[i];
            if (v > buf[i]) {
                buf[i] = v;
                pans[i] = (int) (j + 1);
            } else if (ISNAN(v)) {
                pans[i] = NA_INTEGER;
            }
        }
    }

    UNPROTECT(1);
    return(ans);
', language = "C")

maxCol_row = inline::cfunction(sig = c(x = "matrix"), body = '
    /* Row-wise index (1-based) of the first maximum of a double matrix,
       deliberately scanned row by row (strided access into column-major
       storage) as the counterpart of maxCol_col.  Each row is seeded from
       its first column, so every entry of ans is written; rows containing
       NaN/NA give NA as max.col does, and a matrix without columns gives
       all NA.  Offsets use R_xlen_t to avoid int overflow in i + j * nr. */
    SEXP dim = getAttrib(x, R_DimSymbol);
    R_xlen_t nr = INTEGER(dim)[0], nc = INTEGER(dim)[1];
    double *px = REAL(x);

    SEXP ans = PROTECT(allocVector(INTSXP, nr));
    int *pans = INTEGER(ans);

    for (R_xlen_t i = 0; i < nr; i++) {
        if (nc == 0) {
            pans[i] = NA_INTEGER;
            continue;
        }
        double best = px[i];
        int idx = ISNAN(best) ? NA_INTEGER : 1;
        for (R_xlen_t j = 1; j < nc && idx != NA_INTEGER; j++) {
            double v = px[i + j * nr];
            if (v > best) {
                best = v;
                idx = (int) (j + 1);
            } else if (ISNAN(v)) {
                idx = NA_INTEGER;
            }
        }
        pans[i] = idx;
    }

    UNPROTECT(1);
    return(ans);
', language = "C")

编辑(2016 年 6 月 10 日):

稍作修改即可同时找到最大值和最小值的索引:

rangeCol = inline::cfunction(sig = c(x = "matrix"), body = '
    /* For each row of a double matrix returns the 1-based column index of
       its first maximum (result column 1) and first minimum (result
       column 2), scanning column by column over column-major storage.
       Rows containing NaN/NA get NA in both columns, and a matrix without
       columns yields all NA instead of copying from a non-existent first
       column.  Offsets use R_xlen_t to avoid int overflow in j * nr. */
    SEXP dim = getAttrib(x, R_DimSymbol);
    R_xlen_t nr = INTEGER(dim)[0], nc = INTEGER(dim)[1];
    double *px = REAL(x);

    SEXP ans = PROTECT(allocMatrix(INTSXP, nr, 2));
    int *idx_max = INTEGER(ans), *idx_min = idx_max + nr;

    if (nc == 0) {
        for (R_xlen_t i = 0; i < 2 * nr; i++) idx_max[i] = NA_INTEGER;
        UNPROTECT(1);
        return(ans);
    }

    /* running extremes per row, both seeded from the first column */
    double *maxbuf = (double *) R_alloc(nr, sizeof(double)),
           *minbuf = (double *) R_alloc(nr, sizeof(double));
    memcpy(maxbuf, px, nr * sizeof(double));
    memcpy(minbuf, px, nr * sizeof(double));
    for (R_xlen_t i = 0; i < nr; i++)
        idx_max[i] = idx_min[i] = ISNAN(px[i]) ? NA_INTEGER : 1;

    for (R_xlen_t j = 1; j < nc; j++) {
        const double *col = px + j * nr;
        for (R_xlen_t i = 0; i < nr; i++) {
            if (idx_max[i] == NA_INTEGER) continue;  /* row already holds a NaN */
            double v = col[i];
            if (v > maxbuf[i]) {
                maxbuf[i] = v;
                idx_max[i] = (int) (j + 1);
            } else if (v < minbuf[i]) {
                /* minbuf <= maxbuf, so a new maximum is never a new minimum */
                minbuf[i] = v;
                idx_min[i] = (int) (j + 1);
            } else if (ISNAN(v)) {
                idx_max[i] = idx_min[i] = NA_INTEGER;
            }
        }
    }

    UNPROTECT(1);
    return(ans);
', language = "C")

set.seed(007); m = matrix(sample(24) + 0, 6, 4)
m
#     [,1] [,2] [,3] [,4]
#[1,]   24    7   23    6
#[2,]   10   17   21   11
#[3,]    3   22   20   14
#[4,]    2   18    1   15
#[5,]    5   19   12    8
#[6,]   16    4    9   13
rangeCol(m)
#     [,1] [,2]
#[1,]    1    4
#[2,]    3    1
#[3,]    2    1
#[4,]    2    3
#[5,]    2    1
#[6,]    1    2       

尝试使用 STL 算法和 RcppArmadillo。

microbenchmark::microbenchmark(MaxColArmadillo(Dummy), #Using RcppArmadillo
                               MaxColAlgorithm(Dummy), #Using STL algorithm max_element
                               maxCol_col(Dummy), #Column processing
                               maxCol_row(Dummy)) #Row processing

Unit: milliseconds
                   expr       min        lq     mean    median       uq      max neval
 MaxColArmadillo(Dummy) 227.95864 235.01426 261.4913 250.17897 276.7593 399.6183   100
 MaxColAlgorithm(Dummy) 292.77041 345.84008 392.1704 390.66578 433.8009 552.2349   100
      maxCol_col(Dummy)  40.64343  42.41487  53.7250  48.10126  61.3781 128.4968   100
      maxCol_row(Dummy) 146.96077 158.84512 173.0941 169.20323 178.7959 272.6261   100

STL 实现

#include <Rcpp.h>

#include <algorithm>
#include <iterator>
#include <vector>

// Argument is a numeric matrix; returns a vector holding, for each row, the
// 1-based column index of its maximum, found with std::max_element (which
// yields the first of several equal maxima). Rows that contain NA/NaN and
// matrices without columns yield NA; the latter guard matters because
// max_element on an empty range returns end(), which must not be dereferenced.
// [[Rcpp::export]]
Rcpp::NumericVector MaxColAlgorithm(Rcpp::NumericMatrix m) {
  const R_xlen_t n_rows = m.nrow();
  const R_xlen_t n_cols = m.ncol();

  // One slot per row, NA until a valid index is found
  Rcpp::NumericVector total(n_rows, NA_REAL);
  if (n_cols == 0) return total;

  // Rows are strided in R's column-major storage, so each row is gathered
  // into one reusable buffer instead of allocating a fresh R vector per row.
  std::vector<double> row(n_cols);
  for (R_xlen_t i = 0; i < n_rows; ++i) {
    for (R_xlen_t j = 0; j < n_cols; ++j) row[j] = m(i, j);

    // max_element compares with '<', which is always false for NaN; report
    // such rows as NA rather than an arbitrary position.
    if (std::any_of(row.begin(), row.end(), [](double v) { return ISNAN(v); })) continue;

    const auto best = std::max_element(row.begin(), row.end());
    total[i] = static_cast<double>(std::distance(row.begin(), best) + 1);
  }

  return total;
}

RcppArmadillo 实现

#include <RcppArmadillo.h>
// [[Rcpp::depends(RcppArmadillo)]]
using namespace Rcpp;

// Returns an n_rows x 1 matrix holding, for each row of `x`, the 1-based column
// index of its maximum. arma::index_max reports 0-based indices, hence the +1.
// `x` is taken by const reference so RcppArmadillo can reuse R's memory rather
// than copying the whole matrix on every call (a by-value arma::mat is a copy).
// NOTE(review): index_max behaviour on NaN entries is not pinned down here --
// confirm against the Armadillo docs before relying on it with missing data.
// [[Rcpp::export]]
arma::mat MaxColArmadillo(const arma::mat& x)
{
  if (x.n_cols == 0) {
    // No column to point at: NaN (NA-like on the R side) for every row
    arma::mat none(x.n_rows, 1);
    none.fill(arma::datum::nan);
    return none;
  }
  // dim = 1 means one result per row of the matrix
  return arma::conv_to<arma::mat>::from(arma::index_max(x, 1)) + 1.0;
}