如何在 cuSparse 中获取稀疏矩阵的对角线?

How to get the diagonal of a sparse matrix in cuSparse?

我在 cuSparse 中有一个稀疏矩阵,想提取它的对角线。目前我能找到的唯一办法,是先把矩阵拷贝回 CPU 内存并转换成 Eigen SparseMatrix,用 Eigen 提供的 .diagonal() 取出对角线,然后再把结果复制回 GPU。显然这是非常低效的,所以我想知道是否有办法直接在 GPU 上执行此操作。请参考以下代码:

void CuSparseTransposeToEigenSparse(
    const int *d_row,
    const int *d_col,
    const double *d_val,
    const int num_non0,
    const int mat_row,
    const int mat_col,
    Eigen::SparseMatrix<double> &mat){
  std::vector<int> outer(mat_col + 1);
  std::vector<int> inner(num_non0);
  std::vector<double> value(num_non0);

  cudaMemcpy(
      outer.data(), d_row, sizeof(int) * (mat_col + 1), cudaMemcpyDeviceToHost);

  cudaMemcpy(
      inner.data(), d_col, sizeof(int) * num_non0, cudaMemcpyDeviceToHost);

  cudaMemcpy(
      value.data(), d_val, sizeof(double) * num_non0, cudaMemcpyDeviceToHost);

  Eigen::Map<Eigen::SparseMatrix<double>> mat_map(
      mat_row, mat_col, num_non0, outer.data(), inner.data(), value.data());

  mat = mat_map.eval();
}

// Illustrates the current host round-trip for obtaining the diagonal of a
// device-resident sparse matrix: download the matrix into Eigen, take
// .diagonal() on the CPU, and upload the result to the device again.
int main(){

  // All handles start from a defined state (null / zero) so that nothing is
  // ever read with an indeterminate value if the setup below does not assign
  // one of them.
  int *d_A_row = nullptr;     // device array passed as the row/outer-index array
  int *d_A_col = nullptr;     // device array passed as the column/inner-index array
  double *d_A_val = nullptr;  // device array of stored values
  int A_len = 0;              // used as both the row and the column count
  int num_A_non0 = 0;         // number of stored entries
  double *d_A_diag = nullptr; // device destination; receives A_len doubles

  // these values are filled with some computation

  // current solution
  Eigen::SparseMatrix<double> A;

  CuSparseTransposeToEigenSparse(
      d_A_row, d_A_col, d_A_val, num_A_non0, A_len, A_len, A);

  // Diagonal extracted on the host.
  Eigen::VectorXd A_diag = A.diagonal();

  // Upload the diagonal to the device.
  cudaMemcpy(d_A_diag, A_diag.data(), sizeof(double) * A_len, cudaMemcpyHostToDevice);

  // is there a way to fill in d_A_diag without copying back to CPU?

  return 0;
}

以防有人感兴趣:我已经解决了 CSR 格式矩阵的情况。可以用如下自定义内核直接在 GPU 上按行并行地提取对角线(如果某一行没有存储对角元素,则该行的结果为 0):

// Extracts the main diagonal of a sparse matrix stored in CSR form.
//
//   A_row   row-offset array; entries [0, A_len] are read.
//   A_col   column index of every stored entry.
//   A_val   value of every stored entry.
//   A_len   number of rows, which is also the length of A_diag.
//   A_diag  output; A_diag[r] receives the stored value at (r, r), or 0.0
//           when row r has no stored diagonal entry.
//
// Offsets and column indices are used directly as array indices, i.e. the
// arrays must be zero-based.
//
// Launch layout: 1-D. Rows are distributed with a grid-stride loop, so any
// grid/block size produces a complete result (even <<<1, 1>>>); a grid that
// covers A_len threads gives one row per thread. No shared memory is used.
__global__ static void GetDiagFromSparseMat(const int *A_row,
                                            const int *A_col,
                                            const double *A_val,
                                            const int A_len,
                                            double *A_diag){
  // 64-bit arithmetic so the global index cannot wrap for very large grids.
  const long long stride = static_cast<long long>(gridDim.x) * blockDim.x;
  const long long first =
      static_cast<long long>(blockIdx.x) * blockDim.x + threadIdx.x;

  for (long long r = first; r < A_len; r += stride){
    const int row = static_cast<int>(r);

    // Load the row bounds once instead of re-reading A_row every iteration.
    const int row_begin = A_row[row];
    const int row_end = A_row[row + 1];

    // Stays 0.0 when the row stores no diagonal entry.
    double diag = 0.0;

    for (int k = row_begin; k < row_end; k++){
      if (A_col[k] == row){
        diag = A_val[k];
        break;
      }
    }

    // Single global write per row.
    A_diag[row] = diag;
  }
}