为什么这种朴素的矩阵乘法比基数 R 的乘法快?

Why is this naive matrix multiplication faster than base R's?

在 R 中,矩阵乘法非常优化,即实际上只是对 BLAS/LAPACK 的调用。然而,令我惊讶的是,这个用于矩阵向量乘法的非常简单的 C++ 代码似乎确实快了 30%。

 library(Rcpp)

 # Simple C++ code for matrix multiplication
 mm_code = 
 "NumericVector my_mm(NumericMatrix m, NumericVector v){
   int nRow = m.rows();
   int nCol = m.cols();
   NumericVector ans(nRow);
   double v_j;
   for(int j = 0; j < nCol; j++){
     v_j = v[j];
     for(int i = 0; i < nRow; i++){
       ans[i] += m(i,j) * v_j;
     }
   }
   return(ans);
 }
 "
 # Compiling
 my_mm = cppFunction(code = mm_code)

 # Simulating data to use
 nRow = 10^4
 nCol = 10^4

 m = matrix(rnorm(nRow * nCol), nrow = nRow)
 v = rnorm(nCol)

 system.time(my_ans <- my_mm(m, v))
#>    user  system elapsed 
#>   0.103   0.001   0.103 
 system.time(r_ans <- m %*% v)
#>   user  system elapsed 
#>  0.154   0.001   0.154 

 # Double checking answer is correct
 max(abs(my_ans - r_ans))
 #> [1] 0

base R 的 %*% 是否执行我跳过的某种类型的数据检查?

编辑:

在了解发生了什么之后(感谢 SO!),值得注意的是,这是 R 的 %*% 的最坏情况,即矩阵向量。例如,@RalfStubner 指出,使用矩阵向量乘法的 RcppArmadillo 实现甚至比我演示的简单实现更快,这意味着比基 R 快得多,但实际上与基 R 的 %*% 矩阵相同- 矩阵乘法(当两个矩阵都很大并且是正方形时):

 arma_code <- 
   "arma::mat arma_mm(const arma::mat& m, const arma::mat& m2) {
 return m * m2;
 };"
 arma_mm = cppFunction(code = arma_code, depends = "RcppArmadillo")

 nRow = 10^3 
 nCol = 10^3

 mat1 = matrix(rnorm(nRow * nCol), 
               nrow = nRow)
 mat2 = matrix(rnorm(nRow * nCol), 
               nrow = nRow)

 system.time(arma_mm(mat1, mat2))
#>   user  system elapsed 
#>   0.798   0.008   0.814 
 system.time(mat1 %*% mat2)
#>   user  system elapsed 
#>   0.807   0.005   0.822  

所以 R 的当前 (v3.5.0) %*% 对于矩阵-矩阵来说接近最优,但如果您可以跳过检查,对于矩阵-向量可以显着加速。

快速浏览 names.chere in particular) points you to do_matprod, the C function that is called by %*% and which is found in the file array.c. (Interestingly, it turns out, that both crossprod and tcrossprod dispatch to that same function as well). Here is a linkdo_matprod 的代码。

滚动浏览该函数,您可以看到它处理了许多您的简单实现没有处理的事情,包括:

  1. 在有意义的地方保留行名和列名。
  2. 当通过调用 %*% 操作的两个对象属于 类 时,允许分派到替代 S4 方法,并为其提供了此类方法。 (这就是函数 this portion 中发生的事情。)
  3. 处理实矩阵和复矩阵。
  4. 实现了一系列关于如何处理矩阵与矩阵、向量与矩阵、矩阵与向量以及向量与向量的乘法的规则。 (回想一下,在 R 中的交叉乘法下,LHS 上的向量被视为行向量,而在 RHS 上,它被视为列向量;这是实现这一点的代码。)

Near the end of the function, it dispatches to either of matprod or or cmatprod. Interestingly (to me at least), in the case of real matrices, if either matrix might contain NaN or Inf values, then matprod dispatches (here) to a function called simple_matprod 这和你自己的一样简单明了。否则,它将分派到几个 BLAS Fortran 例程中的一个,如果可以保证统一 'well-behaved' 矩阵元素,这些例程可能更快。

Josh 的回答解释了为什么 R 的矩阵乘法不如这种幼稚的方法快。我很想知道使用 RcppArmadillo 可以获得多少收益。代码很简单:

arma_code <- 
  "arma::vec arma_mm(const arma::mat& m, const arma::vec& v) {
       return m * v;
   };"
arma_mm = cppFunction(code = arma_code, depends = "RcppArmadillo")

基准:

> microbenchmark::microbenchmark(my_mm(m,v), m %*% v, arma_mm(m,v), times = 10)
Unit: milliseconds
          expr      min       lq      mean    median        uq       max neval
   my_mm(m, v) 71.23347 75.22364  90.13766  96.88279  98.07348  98.50182    10
       m %*% v 92.86398 95.58153 106.00601 111.61335 113.66167 116.09751    10
 arma_mm(m, v) 41.13348 41.42314  41.89311  41.81979  42.39311  42.78396    10

所以 RcppArmadillo 给了我们更好的语法和更好的性能。

好奇心战胜了我。这里有一个直接使用 BLAS 的解决方案:

blas_code = "
NumericVector blas_mm(NumericMatrix m, NumericVector v){
  int nRow = m.rows();
  int nCol = m.cols();
  NumericVector ans(nRow);
  char trans = 'N';
  double one = 1.0, zero = 0.0;
  int ione = 1;
  F77_CALL(dgemv)(&trans, &nRow, &nCol, &one, m.begin(), &nRow, v.begin(),
           &ione, &zero, ans.begin(), &ione);
  return ans;
}"
blas_mm <- cppFunction(code = blas_code, includes = "#include <R_ext/BLAS.h>")

基准:

Unit: milliseconds
          expr      min       lq      mean    median        uq       max neval
   my_mm(m, v) 72.61298 75.40050  89.75529  96.04413  96.59283  98.29938    10
       m %*% v 95.08793 98.53650 109.52715 111.93729 112.89662 128.69572    10
 arma_mm(m, v) 41.06718 41.70331  42.62366  42.47320  43.22625  45.19704    10
 blas_mm(m, v) 41.58618 42.14718  42.89853  42.68584  43.39182  44.46577    10

Armadillo 和 BLAS(在我的例子中是 OpenBLAS)几乎是一样的。 BLAS 代码也是 R 最终所做的。所以 R 所做的 2/3 是错误检查等

再补充一点Ralf Stubner的解决方案,那么可以使用下面的C++版本来

  1. 同时做多列以避免多次re-reading输出向量。
  2. 添加 __restrict__ 以潜在地允许向量操作(这里可能无关紧要,因为我猜它只是读取)。
#include <Rcpp.h>
using namespace Rcpp;

inline void mat_vec_mult_vanilla
(double const * __restrict__ m, 
 double const * __restrict__ v, 
 double * __restrict__ const res, 
 size_t const dn, size_t const dm) noexcept {
  for(size_t j = 0; j < dm; ++j, ++v){
    double * r = res;
    for(size_t i = 0; i < dn; ++i, ++r, ++m)
      *r += *m * *v;
  }
}

inline void mat_vec_mult
(double const * __restrict__ const m, 
 double const * __restrict__ const v, 
 double * __restrict__ const res, 
 size_t const dn, size_t const dm) noexcept {
  size_t j(0L);
  double const * vj = v,
               * mi = m;
  constexpr size_t const ncl(8L);
  {
    double const * mvals[ncl];
    size_t const end_j = dm - (dm % ncl),
                   inc = ncl * dn;
    for(; j < end_j; j += ncl, vj += ncl, mi += inc){
      double *r = res;
      mvals[0] = mi;
      for(size_t i = 1; i < ncl; ++i)
        mvals[i] = mvals[i - 1L] + dn;
      for(size_t i = 0; i < dn; ++i, ++r)
        for(size_t ii = 0; ii < ncl; ++ii)
          *r += *(vj + ii) * *mvals[ii]++;
    }
  }
  
  mat_vec_mult_vanilla(mi, vj, res, dn, dm - j);
}

// [[Rcpp::export("mat_vec_mult", rng = false)]]
NumericVector mat_vec_mult_cpp(NumericMatrix m, NumericVector v){
  size_t const dn = m.nrow(), 
               dm = m.ncol();
  NumericVector res(dn);
  mat_vec_mult(&m[0], &v[0], &res[0], dn, dm);
  return res;
}

// [[Rcpp::export("mat_vec_mult_vanilla", rng = false)]]
NumericVector mat_vec_mult_vanilla_cpp(NumericMatrix m, NumericVector v){
  size_t const dn = m.nrow(), 
               dm = m.ncol();
  NumericVector res(dn);
  mat_vec_mult_vanilla(&m[0], &v[0], &res[0], dn, dm);
  return res;
}

我的 Makevars 文件和 gcc-8.3 中 -O3 的结果是

set.seed(1)
dn <- 10001L
dm <- 10001L
m <- matrix(rnorm(dn * dm), dn, dm)
lv <- rnorm(dm)

all.equal(drop(m %*% lv), mat_vec_mult(m = m, v = lv))
#R> [1] TRUE
all.equal(drop(m %*% lv), mat_vec_mult_vanilla(m = m, v = lv))
#R> [1] TRUE

bench::mark(
  R              = m %*% lv, 
  `OP's version` = my_mm(m = m, v = lv), 
  `BLAS`         = blas_mm(m = m, v = lv),
  `C++ vanilla`  = mat_vec_mult_vanilla(m = m, v = lv), 
  `C++`          = mat_vec_mult(m = m, v = lv), check = FALSE)
#R> # A tibble: 5 x 13
#R>   expression        min   median `itr/sec` mem_alloc `gc/sec` n_itr  n_gc total_time result memory                 time          gc               
#R>   <bch:expr>   <bch:tm> <bch:tm>     <dbl> <bch:byt>    <dbl> <int> <dbl>   <bch:tm> <list> <list>                 <list>        <list>           
#R> 1 R             147.9ms    151ms      6.57    78.2KB        0     4     0      609ms <NULL> <Rprofmem[,3] [2 × 3]> <bch:tm [4]>  <tibble [4 × 3]> 
#R> 2 OP's version   56.9ms   57.1ms     17.4     78.2KB        0     9     0      516ms <NULL> <Rprofmem[,3] [2 × 3]> <bch:tm [9]>  <tibble [9 × 3]> 
#R> 3 BLAS           90.1ms   90.7ms     11.0     78.2KB        0     6     0      545ms <NULL> <Rprofmem[,3] [2 × 3]> <bch:tm [6]>  <tibble [6 × 3]> 
#R> 4 C++ vanilla    57.2ms   57.4ms     17.4     78.2KB        0     9     0      518ms <NULL> <Rprofmem[,3] [2 × 3]> <bch:tm [9]>  <tibble [9 × 3]> 
#R> 5 C++              51ms   51.4ms     19.3     78.2KB        0    10     0      519ms <NULL> <Rprofmem[,3] [2 × 3]> <bch:tm [10]> <tibble [10 × 3]>

所以略有改善。结果可能非常依赖于 BLAS 版本。我使用的版本是

sessionInfo()
#R> #...
#R> Matrix products: default
#R> BLAS:   /usr/lib/x86_64-linux-gnu/blas/libblas.so.3.7.1
#R> LAPACK: /usr/lib/x86_64-linux-gnu/lapack/liblapack.so.3.7.1
#R> ...

Rcpp::sourceCpp()编辑的整个文件是

#include <Rcpp.h>
#include <R_ext/BLAS.h>
using namespace Rcpp;

inline void mat_vec_mult_vanilla
(double const * __restrict__ m, 
 double const * __restrict__ v, 
 double * __restrict__ const res, 
 size_t const dn, size_t const dm) noexcept {
  for(size_t j = 0; j < dm; ++j, ++v){
    double * r = res;
    for(size_t i = 0; i < dn; ++i, ++r, ++m)
      *r += *m * *v;
  }
}

inline void mat_vec_mult
(double const * __restrict__ const m, 
 double const * __restrict__ const v, 
 double * __restrict__ const res, 
 size_t const dn, size_t const dm) noexcept {
  size_t j(0L);
  double const * vj = v,
               * mi = m;
  constexpr size_t const ncl(8L);
  {
    double const * mvals[ncl];
    size_t const end_j = dm - (dm % ncl),
                   inc = ncl * dn;
    for(; j < end_j; j += ncl, vj += ncl, mi += inc){
      double *r = res;
      mvals[0] = mi;
      for(size_t i = 1; i < ncl; ++i)
        mvals[i] = mvals[i - 1L] + dn;
      for(size_t i = 0; i < dn; ++i, ++r)
        for(size_t ii = 0; ii < ncl; ++ii)
          *r += *(vj + ii) * *mvals[ii]++;
    }
  }
  
  mat_vec_mult_vanilla(mi, vj, res, dn, dm - j);
}

// [[Rcpp::export("mat_vec_mult", rng = false)]]
NumericVector mat_vec_mult_cpp(NumericMatrix m, NumericVector v){
  size_t const dn = m.nrow(), 
               dm = m.ncol();
  NumericVector res(dn);
  mat_vec_mult(&m[0], &v[0], &res[0], dn, dm);
  return res;
}

// [[Rcpp::export("mat_vec_mult_vanilla", rng = false)]]
NumericVector mat_vec_mult_vanilla_cpp(NumericMatrix m, NumericVector v){
  size_t const dn = m.nrow(), 
               dm = m.ncol();
  NumericVector res(dn);
  mat_vec_mult_vanilla(&m[0], &v[0], &res[0], dn, dm);
  return res;
}

// [[Rcpp::export(rng = false)]]
NumericVector my_mm(NumericMatrix m, NumericVector v){
  int nRow = m.rows();
  int nCol = m.cols();
  NumericVector ans(nRow);
  double v_j;
  for(int j = 0; j < nCol; j++){
    v_j = v[j];
    for(int i = 0; i < nRow; i++){
      ans[i] += m(i,j) * v_j;
    }
  }
  return(ans);
}

// [[Rcpp::export(rng = false)]]
NumericVector blas_mm(NumericMatrix m, NumericVector v){
  int nRow = m.rows();
  int nCol = m.cols();
  NumericVector ans(nRow);
  char trans = 'N';
  double one = 1.0, zero = 0.0;
  int ione = 1;
  F77_CALL(dgemv)(&trans, &nRow, &nCol, &one, m.begin(), &nRow, v.begin(),
           &ione, &zero, ans.begin(), &ione);
  return ans;
}

/*** R
set.seed(1)
dn <- 10001L
dm <- 10001L
m <- matrix(rnorm(dn * dm), dn, dm)
lv <- rnorm(dm)

all.equal(drop(m %*% lv), mat_vec_mult(m = m, v = lv))
all.equal(drop(m %*% lv), mat_vec_mult_vanilla(m = m, v = lv))

bench::mark(
  R              = m %*% lv, 
  `OP's version` = my_mm(m = m, v = lv), 
  `BLAS`         = blas_mm(m = m, v = lv),
  `C++ vanilla`  = mat_vec_mult_vanilla(m = m, v = lv), 
  `C++`          = mat_vec_mult(m = m, v = lv), check = FALSE)
*/