RcppEigen 更快的协方差

RcppEigen faster covariance

我已经拟合了一个输出协方差矩阵S的回归模型,用于回归参数B。我需要通过乘以X来对这个协方差矩阵进行运算,然后得到新的协方差和stderr向量

cov(X * B) = X * cov(B) * X.transpose()

因为我只需要cov(X * B)的对角线我不需要做全矩阵乘法,我可以得到每一行的协方差X_i * B然后求和

#include <RcppEigen.h>
// [[Rcpp::depends(RcppEigen)]]

using Eigen::Map;
using Eigen::MatrixXd;
using Eigen::VectorXd;
using Eigen::SparseMatrix;
using Eigen::MappedSparseMatrix;
using namespace Rcpp;
using namespace Eigen;

double foo(const Eigen::MappedSparseMatrix<double>& mm, 
           const Eigen::MappedSparseMatrix<double>& vcov) {

  int n = mm.rows();
  double out = 0;
  SparseMatrix<double> mm_t = mm.adjoint();

  SparseMatrix<double> var(1, 1);
  var.setZero();

  for (int i = 0; i < n; i++) {
    var = mm.row(i) * vcov * mm_t.col(i);
    out += var.coeff(0, 0);
  }

  return out;
}

出于某种原因,此函数在 1M 行上非常慢。我尝试使用 "blocks" 而不是逐行对 mm 进行操作,认为通过对值块进行操作可以使与 vcov 的矩阵乘法更快。这并没有使函数更快。这是一个可重现的例子

require(Matrix)

set.seed(100)
N = 2.5e5
p = 100

mm = rsparsematrix(N, p, .01)
vcov = rsparsematrix(p, p, .5)

system.time(foo(mm, vcov))

有没有办法让这个功能更快?

如果协方差矩阵是实数且对称(并且在您的情况下是协方差矩阵),您可以使用简单的数学方法 "trick"。

x %*% b %*% t(b) %*% t(x)的对角线元素之和可以计算为

sum((x %*% b)^2)

超级快。请注意,上面的公式将 b %*% t(b) 作为 "sandwich" 的 "ham" 部分,因此您需要计算 cov(B) 的平方根,然后您可以使用该公式。

或者,您可以直接在 R 中使用以下逐元素乘积

sum((mm %*% vcov) * mm)

我不太熟悉 RcppEigen 和那里的稀疏矩阵,所以以下内容可能会被优化,但看起来很快

// [[Rcpp::export]]                                                                                                                        
double foo2(const Eigen::MappedSparseMatrix<double>& mm,
           const Eigen::MappedSparseMatrix<double>& vcov) {

  double out = 0;
  SparseMatrix<double> mat;

  mat = mm.cwiseProduct(mm*vcov);


  for (int k=0; k<mat.outerSize(); ++k) {
    for (SparseMatrix<double>::InnerIterator it(mat,k); it; ++it)
      {
        out +=it.value();
      }
  }

  return out;
}

这里有一个简短的速度比较

> microbenchmark::microbenchmark(foo(mm, vcov), foo2(mm, vcov), sum((mm %*% vcov) * mm), times=2)
Unit: milliseconds
                    expr        min         lq       mean     median         uq
           foo(mm, vcov) 32575.5488 32575.5488 33587.4147 33587.4147 34599.2806
          foo2(mm, vcov)   463.9440   463.9440   492.4232   492.4232   520.9023
 sum((mm %*% vcov) * mm)   953.7902   953.7902   981.4750   981.4750  1009.1598
        max neval cld
 34599.2806     2   b
   520.9023     2  a 
  1009.1598     2  a 

相当大的改进。即使只是单独使用 R。