在稀疏矩阵的行中查找交点
Finding intersections in rows of sparse matrices
作为 MCVE,考虑这样的稀疏矩阵(另请参阅末尾的 dput
输出)
> X
10 x 8 sparse Matrix of class "dgCMatrix"
[1,] . . . . 5.45 . . 1.75
[2,] . . 5.05 1.75 5.45 3.60 . .
[3,] 5.45 . 2.45 . . . . .
[4,] . . 5.05 . 6.50 . . .
[5,] 5.45 . . . . 2.85 . .
[6,] . . . . 5.95 . . 1.75
[7,] 5.45 . . 1.60 . . 2.45 .
[8,] 5.45 . . 1.60 . . 2.45 .
[9,] 5.45 . 2.45 . . . . .
[10,] . . 5.05 1.75 5.45 3.60 . .
例如,如果给定的交集是 c(1L, 3L)
,那么我想知道在第一列和第三列上具有非零元素的行的索引,即 c(3, 9)
。对于交集c(3L, 4L, 5L)
,应该是c(2, 10)
.
请注意,在我的申请中
- 矩阵
X
可能有几十万行and/or几千列。
- 每个交叉点通常有 2 到 3 个元素,最多 6 个元素。
- 将有数百个不同的交叉点需要
lapply
编辑,因此您可能需要进行一些预处理。
这是我现在正在做的事情
> intersections <- list(c(1L, 3L), c(3L, 4L, 5L))
> nonzero.rows <- by(X@i, rep(1:ncol(X), times=diff(X@p)), list)
> find.row.id <- function(intersection, nonzero.rows) Reduce(intersect, nonzero.rows[as.character(intersection)]) + 1
> lapply(intersections, find.row.id, nonzero.rows=nonzero.rows)
[[1]]
[1] 3 9
[[2]]
[1] 2 10
分析表明这是我的库中最大的瓶颈之一。你能让它更快吗?
> dput(X)
new("dgCMatrix", i = c(2L, 4L, 6L, 7L, 8L, 1L, 2L, 3L, 8L, 9L,
1L, 6L, 7L, 9L, 0L, 1L, 3L, 5L, 9L, 1L, 4L, 9L, 6L, 7L, 0L, 5L
), p = c(0L, 5L, 5L, 10L, 14L, 19L, 22L, 24L, 26L), Dim = c(10L,
8L), Dimnames = list(NULL, NULL), x = c(5.45, 5.45, 5.45, 5.45,
5.45, 5.05, 2.45, 5.05, 2.45, 5.05, 1.75, 1.6, 1.6, 1.75, 5.45,
5.45, 6.5, 5.95, 5.45, 3.6, 2.85, 3.6, 2.45, 2.45, 1.75, 1.75
), factors = list())
代表
library(Matrix)
set.seed(1)
X <- rsparsematrix(10000, 1000, 0.3)
intersections <- replicate(10000, sample(ncol(X), sample(2:4)))
测试一些解决方案
您的解决方案:
system.time({
nonzero.rows <- by(X@i, rep(1:ncol(X), times=diff(X@p)), list)
find.row.id <- function(intersection, nonzero.rows) Reduce(intersect, nonzero.rows[as.character(intersection)]) + 1
lapply(intersections, find.row.id, nonzero.rows=nonzero.rows)
}) # 3.4 sec
将 X
重新编码为向量列表(离您的解决方案不远,但更优雅):
system.time({
X2 <- as(X, "dgTMatrix")
X3 <- split(X2@i + 1L, factor(X2@j + 1L, levels = seq_len(ncol(X))))
lapply(intersections, function(ind) Reduce(intersect, X3[ind]))
}) # 3.4 sec
从较小的集合开始减少:
system.time({
X2 <- as(X, "dgTMatrix")
X3 <- split(X2@i + 1L, factor(X2@j + 1L, levels = seq_len(ncol(X))))
lapply(intersections, function(ind) {
X3.ind <- X3[ind]
len <- lengths(X3.ind)
Reduce(intersect, X3.ind[order(len)])
})
}) # 3.7 sec
评论中提出的解决方案:
system.time({
lapply(intersections, function(ind) {
which(Matrix::rowSums(X[, ind] != 0) == length(ind))
})
}) # 46 sec
https://coolbutuseless.github.io/2018/09/17/intersection-of-multiple-vectors/ 提出的解决方案:
system.time({
X2 <- as(X, "dgTMatrix")
X3 <- split(X2@i + 1L, factor(X2@j + 1L, levels = seq_len(ncol(X))))
lapply(intersections, function(ind) {
tally <- integer(nrow(X))
for (elements in X3[ind]) {
tally[elements] <- tally[elements] + 1L
}
which(tally == length(ind))
})
}) # 1.7 sec
您可以轻松并行化 lapply()
。
作为 MCVE,考虑这样的稀疏矩阵(另请参阅末尾的 dput
输出)
> X
10 x 8 sparse Matrix of class "dgCMatrix"
[1,] . . . . 5.45 . . 1.75
[2,] . . 5.05 1.75 5.45 3.60 . .
[3,] 5.45 . 2.45 . . . . .
[4,] . . 5.05 . 6.50 . . .
[5,] 5.45 . . . . 2.85 . .
[6,] . . . . 5.95 . . 1.75
[7,] 5.45 . . 1.60 . . 2.45 .
[8,] 5.45 . . 1.60 . . 2.45 .
[9,] 5.45 . 2.45 . . . . .
[10,] . . 5.05 1.75 5.45 3.60 . .
例如,如果给定的交集是 c(1L, 3L)
,那么我想知道在第一列和第三列上具有非零元素的行的索引,即 c(3, 9)
。对于交集c(3L, 4L, 5L)
,应该是c(2, 10)
.
请注意,在我的申请中
- 矩阵
X
可能有几十万行and/or几千列。 - 每个交叉点通常有 2 到 3 个元素,最多 6 个元素。
- 将有数百个不同的交叉点需要
lapply
编辑,因此您可能需要进行一些预处理。
这是我现在正在做的事情
> intersections <- list(c(1L, 3L), c(3L, 4L, 5L))
> nonzero.rows <- by(X@i, rep(1:ncol(X), times=diff(X@p)), list)
> find.row.id <- function(intersection, nonzero.rows) Reduce(intersect, nonzero.rows[as.character(intersection)]) + 1
> lapply(intersections, find.row.id, nonzero.rows=nonzero.rows)
[[1]]
[1] 3 9
[[2]]
[1] 2 10
分析表明这是我的库中最大的瓶颈之一。你能让它更快吗?
> dput(X)
new("dgCMatrix", i = c(2L, 4L, 6L, 7L, 8L, 1L, 2L, 3L, 8L, 9L,
1L, 6L, 7L, 9L, 0L, 1L, 3L, 5L, 9L, 1L, 4L, 9L, 6L, 7L, 0L, 5L
), p = c(0L, 5L, 5L, 10L, 14L, 19L, 22L, 24L, 26L), Dim = c(10L,
8L), Dimnames = list(NULL, NULL), x = c(5.45, 5.45, 5.45, 5.45,
5.45, 5.05, 2.45, 5.05, 2.45, 5.05, 1.75, 1.6, 1.6, 1.75, 5.45,
5.45, 6.5, 5.95, 5.45, 3.6, 2.85, 3.6, 2.45, 2.45, 1.75, 1.75
), factors = list())
代表
library(Matrix)
set.seed(1)
X <- rsparsematrix(10000, 1000, 0.3)
intersections <- replicate(10000, sample(ncol(X), sample(2:4)))
测试一些解决方案
您的解决方案:
system.time({
nonzero.rows <- by(X@i, rep(1:ncol(X), times=diff(X@p)), list)
find.row.id <- function(intersection, nonzero.rows) Reduce(intersect, nonzero.rows[as.character(intersection)]) + 1
lapply(intersections, find.row.id, nonzero.rows=nonzero.rows)
}) # 3.4 sec
将 X
重新编码为向量列表(离您的解决方案不远,但更优雅):
system.time({
X2 <- as(X, "dgTMatrix")
X3 <- split(X2@i + 1L, factor(X2@j + 1L, levels = seq_len(ncol(X))))
lapply(intersections, function(ind) Reduce(intersect, X3[ind]))
}) # 3.4 sec
从较小的集合开始减少:
system.time({
X2 <- as(X, "dgTMatrix")
X3 <- split(X2@i + 1L, factor(X2@j + 1L, levels = seq_len(ncol(X))))
lapply(intersections, function(ind) {
X3.ind <- X3[ind]
len <- lengths(X3.ind)
Reduce(intersect, X3.ind[order(len)])
})
}) # 3.7 sec
评论中提出的解决方案:
system.time({
lapply(intersections, function(ind) {
which(Matrix::rowSums(X[, ind] != 0) == length(ind))
})
}) # 46 sec
https://coolbutuseless.github.io/2018/09/17/intersection-of-multiple-vectors/ 提出的解决方案:
system.time({
X2 <- as(X, "dgTMatrix")
X3 <- split(X2@i + 1L, factor(X2@j + 1L, levels = seq_len(ncol(X))))
lapply(intersections, function(ind) {
tally <- integer(nrow(X))
for (elements in X3[ind]) {
tally[elements] <- tally[elements] + 1L
}
which(tally == length(ind))
})
}) # 1.7 sec
您可以轻松并行化 lapply()
。