在 Rcpp 中查找向量中所有 max/min 值的索引

Find index of all max/min values in vector in Rcpp

假设我有一个向量

v = c(1,2,3)

我可以使用

轻松找到最大的元素
cppFunction('int which_maxCpp(NumericVector v) {
  int z = which_max(v);
  return z;
}')

which_maxCpp(v)

2

但是,如果我有一个向量,例如

v2 = c(1,2,3,1,2,3)

得到

which_maxCpp(v2)

2

而我应该发现索引 2 和索引 5(如果使用 1 索引,则索引 3 和索引 6)等于向量中的最大值

有没有办法让 which_max(或 which_min)找到一个向量的所有 min/max 元素的索引,或者是另一个(我假设本机 C++ ) 需要功能吗?

我不知道本机函数,但是循环编写起来相当简单。

这里有三个版本。

两个找到向量的Rcpp::max(),然后找到匹配这个最大值的向量的索引。一个使用预分配的 Rcpp::IntegerVector() 来存储结果,然后将其子集化以删除额外的 'unused' 零。另一个使用 std::vector< int >.push_back() 来存储结果。

library(Rcpp)

cppFunction('IntegerVector which_maxCpp1(NumericVector v) {
  double m = Rcpp::max(v);
  Rcpp::IntegerVector res( v.size() );  // pre-allocate result vector

  int i;
  int counter = 0;
  for( i = 0; i < v.size(); ++i) {
    if( v[i] == m ) {
      res[ counter ] = i;
      counter++;
    }
  }
  counter--;
  Rcpp::Range rng(0, counter);  
  return res[rng];
}')

v = c(1,2,3,1,2,3)

which_maxCpp(v)
# [1] 2 5
cppFunction('IntegerVector which_maxCpp2(NumericVector v) {
  double m = Rcpp::max(v);
  std::vector< int > res;

  int i;
  for( i = 0; i < v.size(); ++i) {
    if( v[i] == m ) {
      res.push_back( i );
    }
  }
  Rcpp::IntegerVector iv( res.begin(), res.end() );
  return iv;
}')

which_maxCpp(v)
# [1] 2 5

第三个选项通过查找最大值并同时跟踪一个循环中的索引来避免对向量进行两次传递。

cppFunction('IntegerVector which_maxCpp3(NumericVector v) {

  double current_max = v[0];
  int n = v.size();
  std::vector< int > res;
  res.push_back( 0 );
  int i;

  for( i = 1; i < n; ++i) {
    double x = v[i];
    if( x > current_max ) {
      res.clear();
      current_max = x;
      res.push_back( i );
    } else if ( x == current_max ) {
      res.push_back( i );
    }
  }
  Rcpp::IntegerVector iv( res.begin(), res.end() );
  return iv;
}')

基准测试

这里有一些基准展示了这些函数如何与基本的 R 方法相比较。

library(microbenchmark)

x <- sample(1:100, size = 1e6, replace = T)

microbenchmark(
  iv = { which_maxCpp1(x) },
  stl = { which_maxCpp2(x) },
  max = { which_maxCpp3(x) },
  r = { which( x == max(x)) } 
)

# Unit: milliseconds
# expr      min        lq      mean    median       uq        max neval
#   iv 6.638583 10.617945 14.028378 10.956616 11.63981 165.719783   100
#  stl 6.830686  9.506639  9.787291  9.744488 10.17247  11.275061   100
#  max 3.161913  5.690886  5.926433  5.913899  6.19489   7.427020   100
#    r 4.044166  5.558075  5.819701  5.719940  6.00547   7.080742   100