R dist:计算少到多的距离

R dist: calculating distance of few to many

假设你在给定的实验中有一些参与者和控制,这些参与者和控制在三个特征中进行评估,如下所示:

part_A <- c(3, 5, 4)
part_B <- c(12, 15, 18)
part_C <- c(50, 40, 45)

ctrl_1 <- c(4, 5, 5)
ctrl_2 <- c(1, 0, 4)
ctrl_3 <- c(13, 16, 17)
ctrl_4 <- c(28, 30, 35)
ctrl_5 <- c(51, 43, 44)

我想为每个参与者找到最接近匹配的控制案例。

如果我用dist()函数,我可以得到,但是计算控件之间的距离也需要很多时间,这对我没用(而且在真实数据中,有控制案例比参与者案例多 1000 倍)。

有没有办法求出每个 这些 元素与每个 那些 元素之间的距离?以及适用于非常大的数据集的东西?

在上面的例子中,我想要的结果是:

  Participant Closest_Ctrl
1      part_A       ctrl_1
2      part_B       ctrl_3
3      part_C       ctrl_5

将输入转换为数据帧

parts <- do.call(data.frame, mget(ls(pattern = "part_[A-C]")))
ctrl <- do.call(data.frame, mget(ls(pattern = "ctrl_[1-5]")))

生成输出

# calculate distances
dists <- outer(parts, ctrl, Vectorize(function(x, y) sqrt(sum((x - y)^2))))

# generate output by calculating column with min value (max negative value)
data.frame(Participant = names(parts), 
           Closest_Ctrl = names(ctrl)[max.col(-dists)])

#   Participant Closest_Ctrl
# 1      part_A       ctrl_1
# 2      part_B       ctrl_3
# 3      part_C       ctrl_5

基准

parts <- do.call(data.frame, mget(ls(pattern = "part_[A-C]")))
ctrl <- do.call(data.frame, mget(ls(pattern = "ctrl_[1-5]")))
parts <- do.call(cbind, replicate(100, parts, simplify = F))
ctrl <- do.call(cbind, replicate(100, ctrl, simplify = F))

r1 <- f1()
r2 <- f2()

all.equal(r1 %>% lapply(as.factor) %>% setNames(1:2), 
          r2[2:1] %>% lapply(as.factor) %>% setNames(1:2))
# [1] TRUE


f1 <- function(x){
  dists <- outer(parts, ctrl, Vectorize(function(x, y) sqrt(sum((x - y)^2))))
  # generate output by calculating column with min value (max negative value)
  data.frame(Participant = names(parts), 
             Closest_Ctrl = names(ctrl)[max.col(-dists)])
}

f2 <- function(x){
  res <- vapply(parts, function(x)  colnames(ctrl)[which.min(sqrt(colSums(x - ctrl)^2))], 
            FUN.VALUE = character(1))

  stack(res)
}

microbenchmark::microbenchmark(f1(), f2(), times = 5)        
# Unit: milliseconds
#  expr        min         lq       mean     median         uq        max neval
#  f1()   305.7324   314.8356   435.3961   324.6116   461.4788   770.3221     5
#  f2() 12359.6995 12831.7995 13567.8296 13616.5216 14244.0836 14787.0438     5

基准 2

parts <- do.call(data.frame, mget(ls(pattern = "part_[A-C]")))
ctrl <- do.call(data.frame, mget(ls(pattern = "ctrl_[1-5]")))
parts <- do.call(cbind, replicate(10, parts, simplify = F))
ctrl <- do.call(cbind, replicate(10*1000, ctrl, simplify = F))

r1 <- f1()
r2 <- f2()

all.equal(r1 %>% lapply(as.factor) %>% setNames(1:2), 
          r2[2:1] %>% lapply(as.factor) %>% setNames(1:2))
# [1] TRUE


f1 <- function(x){
  dists <- outer(parts, ctrl, Vectorize(function(x, y) sqrt(sum((x - y)^2))))
  # generate output by calculating column with min value (max negative value)
  data.frame(Participant = names(parts), 
             Closest_Ctrl = names(ctrl)[max.col(-dists)])
}

f2 <- function(x){
  res <- vapply(parts, function(x)  colnames(ctrl)[which.min(sqrt(colSums(x - ctrl)^2))], 
            FUN.VALUE = character(1))

  stack(res)
}

microbenchmark::microbenchmark(f1(), f2(), times = 5)        
# Unit: seconds
#  expr        min         lq       mean     median         uq        max neval
#  f1()   3.450176   4.211997   4.493805   4.339818   5.154191   5.312844     5
#  f2() 119.120484 124.280423 132.637003 130.858727 131.148630 157.776749     5

这是一个解决方案,对于不太多的参与者来说应该足够快:

ctrl <- do.call(cbind, mget(ls(pattern = "ctrl_\d+")))

dat <- mget(ls(pattern = "part_[[:upper:]+]"))

res <- vapply(dat, function(x)  colnames(ctrl)[which.min(sqrt(colSums(x - ctrl)^2))], 
                FUN.VALUE = character(1))

stack(res)
#  values    ind
#1 ctrl_1 part_A
#2 ctrl_3 part_B
#3 ctrl_5 part_C

如果这太慢,我会很快用 Rcpp 编写代码。