使用 data.table 进行迭代分区和标记

Iterative partitioning and labeling using data.table

我有一个迭代分区方法,它为每个观察分配一个标签并继续直到所有分区都小于或等于指定的最小观察值。

使用 data.table 我有 运行 合并 '{'':=' 的问题。我目前的解决方案如下:

part.test <- function(x, y, min.obs=4){
    PART = data.table(x=as.numeric(x),y=as.numeric(y),quadrant='q',prior.quadrant='q',key = c('quadrant','x','y'))
    PART=PART[,counts := .N,quadrant]
    setkey(PART,counts,quadrant,x,y)

    i=0L
    while(i>=0){

    PART=PART[,counts := .N,quadrant]
    l.PART=sum(PART$counts>min.obs)
    if(l.PART==0){break}
    min.obs.rows=PART[counts>=min.obs,which=TRUE]

    PART[min.obs.rows, prior.quadrant := quadrant]
    PART[min.obs.rows, quadrant :=
          ifelse( x <= mean(x) & y <= mean(y), paste0(quadrant,4),
            ifelse(x <= mean(x) & y > mean(y), paste0(quadrant,2),
              ifelse(x > mean(x) & y <= mean(y), paste0(quadrant,3), paste0(quadrant,1)))), 
        by=quadrant]

    i=i+1
    }

    return(PART[])

}

这是一个例子:

> set.seed(123);x=rnorm(1e5);y=rnorm(1e5)
> part.test(x,y)
                 x          y   quadrant prior.quadrant counts
     1: 2.45670228  2.4710128   q1111141        q111114      1
     2: 2.36216477  2.3211246   q1111144        q111114      1
     3: 2.03019608  3.1102172   q1111212        q111121      1
     4: 2.18349873  2.7801719   q1111213        q111121      1
     5: 2.14224180  2.5529947   q1111231        q111123      1
   ---                                                       
 99996: 0.51221861  0.1992352 q143234342      q14323434      4
 99997: 0.08995397 -0.6415489 q324423131      q32442313      4
 99998: 0.09069140 -0.6427392 q324423131      q32442313      4
 99999: 0.09077251 -0.6406127 q324423131      q32442313      4
100000: 0.09077963 -0.6413572 q324423131      q32442313      4
> system.time(part.test(x,y))
   user  system elapsed 
   3.45    0.00    3.53 

使用 data.table 提高此性能的最佳方法是什么?

编辑: 根据 Frank 的评论,我已将 setkey 移出循环。

详细说明我的评论,这里有一些改进:

f <- function(x, y, min.obs = 4){
    DT = data.table(x,y,q="q")[, counts := .N]

    while(TRUE){
        DT[counts >= min.obs, counts := .N, by=q]
        if (max(DT$counts) == min.obs) break

        w = DT[counts >= min.obs, which=TRUE]
        mDT = DT[w, lapply(.SD, mean), by=q, .SDcols = x:y]
        DT[mDT, on=.(q), q_new := {
          lox = x.x <= i.x
          loy = x.y <= i.y
          1L + lox + loy*2L
        }]

        DT[w, q := paste0(q, q_new)]
        DT[, q_new := NULL ]
    }

    setorder(DT[], counts, q, x, y)[]
}


system.time(res <- part.test(x,y))
#    user  system elapsed 
#    2.65    0.00    2.66 

system.time(fres <- f(x,y))
#    user  system elapsed 
#    0.65    0.05    0.70

# verify they match
fsetequal(
  zf <- setnames(copy(fres), "q", "quadrant"), 
  z <- copy(res)[, prior.quadrant := NULL ]
) # TRUE

也许它更快的原因:

  • GForce 用于计算 mDT 中的均值。
  • 使用算术代替 ifelse
  • Keying/sorting只做一次。

它可能比这更快。