使用 tidyverse 向量化 returns 多个变量的函数

vectorize a function that returns more than one variable using tidyverse

我有一个函数,returns 几个变量相互依赖。 输出它是一个具有 1 行和 n 列的数据框。输出中的列数取决于函数的输入之一。 我需要对其进行矢量化并加入“主”数据框,就像 'dplyr::mutate()' 那样。

我真的尽量使 reprex 尽可能简单:

#data
df <- data.frame("ob" = 1:30,
                 "ob_pattern" = sample(c("p1", "p2"), size = 30, replace = T),
                 "value" = runif(n = 30))
> head(df)
  ob ob_pattern     value
1  1         p1 0.5442453
2  2         p2 0.1274518
3  3         p2 0.4256460
4  4         p1 0.9319009
5  5         p2 0.9828048
6  6         p2 0.2309473

#patterns
df_pt <- data.frame("pattern" = c("p1", "p1", "p2", "p2", "p2"),
                    "name" = c("n1", "n2", "n1", "n2", "n3" ),
                    "perct" = c(0.4, 0.15, 0.3, 0.5, 0.18))

> df_pt
  pattern name perct
1      p1   n1  0.40
2      p1   n2  0.15
3      p2   n1  0.30
4      p2   n2  0.50
5      p2   n3  0.18

此函数创建 类 并将数据库中的值乘以模式 table

中的预定义模式
fun <- function(value, ob_pattern, df_pt){
  
  #filter the pattern
  sel_pt <- df_pt %>% 
    dplyr::filter(pattern == ob_pattern)
  
  out <- data.frame()
  
  for(i in 1:nrow(sel_pt)){
    out[1,i] <- sel_pt[i,2]
    out[2,i] <- sel_pt[i,3] / value
  }
  
  names(out) <- out[1,]
  out <-  out[-1,]
  return(out)
  
}

此功能“手动”运行良好:

fun(10, "p1", df_pt)
> fun(10, "p1", df_pt)
    n1    n2
2 0.04 0.015

fun(10, "p2", df_pt)
> fun(10, "p2", df_pt)
    n1   n2    n3
2 0.03 0.05 0.018

但是,在地图迭代中进展不顺利:

pmap(list(value = df$value, ob_pattern = df$ob_pattern, df_pt = df_pt), fun)

> pmap(list(value = df$value, ob_pattern = df$ob_pattern, df_pt = df_pt), fun)
Erro: Element 3 of `.l` must have length 1 or 30, not 3
Run `rlang::last_error()` to see where the error occurred.

df <- df %>% 
  mutate(pmap(list(value = value, ob_pattern = ob_pattern, df_pt = df_pt), fun))

> df <- df %>% 
+   mutate(pmap(list(value = value, ob_pattern = ob_pattern, df_pt = df_pt), fun))
Erro: Problem with `mutate()` input `..1`.
i `..1 = pmap(...)`.
x Element 3 of `.l` must have length 1 or 30, not 3
Run `rlang::last_error()` to see where the error occurred.

我的期望:

# A tibble: 6 x 30
     ob ob_pattern value    n1    n2     n3
  <dbl> <chr>      <dbl> <dbl> <dbl>  <dbl>
1     1 p1         0.544 1.36  3.63  NA    
2     2 p2         0.127 0.425 0.255  0.708
3     3 p2         0.426 1.42  0.851  2.36 
4     4 p1         0.932 2.33  6.21  NA    
5     5 p2         0.983 3.28  1.97   5.46 
6     6 p2         0.231 0.770 0.462  1.28 

问题是 df_pt 是一个 data.frame 并且它需要在每个循环元素中用作输入。所以,用 list 包裹它,这样它就可以作为一个整体被回收。当我们遍历 data.frame 时,列是一个单位,这会触发错误 Erro: Element 3 of .l must have length 1 or 30, not 3,因为列数是 3.

library(dplyr)
library(purrr)
pmap_dfr(list(value = df$value, ob_pattern = df$ob_pattern, 
      df_pt = list(df_pt)), fun, .id = 'ob') %>%
     mutate(ob_pattern = df$ob_pattern, .before = 2)

-输出

  ob ob_pattern                n1                n2                n3
1   1         p2 0.412805820786703 0.688009701311172 0.247683492472022
2   2         p2 0.819499036723223  1.36583172787204 0.491699422033934
3   3         p2 0.307851399008221 0.513085665013701 0.184710839404932
4   4         p1 0.512735060593463 0.192275647722549              <NA>
5   5         p1 0.583734910383962 0.218900591393986              <NA>
6   6         p1  1.26403823904009 0.474014339640033              <NA>
7   7         p1 0.520375965374508 0.195140987015441              <NA>
8   8         p2 0.519695574800472  0.86615929133412 0.311817344880283
9   9         p1 0.406595728747128 0.152473398280173              <NA>
10 10         p1  1.19690591834918 0.448839719380944              <NA>
11 11         p1 0.935134681128101 0.350675505423038              <NA>
12 12         p2 0.782381874921124  1.30396979153521 0.469429124952674
13 13         p1 0.902566162028802 0.338462310760801              <NA>
14 14         p2 0.412253449371353 0.687089082285588 0.247352069622812
15 15         p2 0.414083431765533 0.690139052942556  0.24845005905932
16 16         p2 0.540922520169042 0.901537533615069 0.324553512101425
17 17         p2 0.306604097963516 0.511006829939193  0.18396245877811
18 18         p2  1.94204963387021  3.23674938978369  1.16522978032213
19 19         p2 0.302096661043879 0.503494435073132 0.181257996626328
20 20         p1 0.478354496206454  0.17938293607742              <NA>
21 21         p2 0.406759159422302 0.677931932370503 0.244055495653381
22 22         p1 0.929982462421745 0.348743423408154              <NA>
23 23         p2 0.850658644553245  1.41776440758874 0.510395186731947
24 24         p1  1.24950965620306 0.468566121076146              <NA>
25 25         p1 0.807136438261923 0.302676164348221              <NA>
26 26         p2  75.9337007291282   126.55616788188  45.5602204374769
27 27         p2 0.487844654295068 0.813074423825113 0.292706792577041
28 28         p1 0.702944374408066 0.263604140403025              <NA>
29 29         p1 0.417447530041509 0.156542823765566              <NA>
30 30         p2  2.14866591202588  3.58110985337647  1.28919954721553

或者如果我们想在 mutate

中使用 pmap
library(tidyr)
df %>% 
   mutate(out = pmap(across(c(value, ob_pattern)), 
      ~ fun(..1, ..2, df_pt))) %>% 
   unnest_wider(c(out)) %>%
   type.convert(as.is = TRUE)

-输出

# A tibble: 30 × 6
      ob ob_pattern value    n1    n2     n3
   <int> <chr>      <dbl> <dbl> <dbl>  <dbl>
 1     1 p2         0.727 0.413 0.688  0.248
 2     2 p2         0.366 0.819 1.37   0.492
 3     3 p2         0.974 0.308 0.513  0.185
 4     4 p1         0.780 0.513 0.192 NA    
 5     5 p1         0.685 0.584 0.219 NA    
 6     6 p1         0.316 1.26  0.474 NA    
 7     7 p1         0.769 0.520 0.195 NA    
 8     8 p2         0.577 0.520 0.866  0.312
 9     9 p1         0.984 0.407 0.152 NA    
10    10 p1         0.334 1.20  0.449 NA    
# … with 20 more rows

注意:输出生成了 returns character 列,这只是因为 OP 的 fun 代码

中存在一些问题

或使用rowwise

df %>% 
  rowwise %>%
  mutate(out = fun(value, ob_pattern, df_pt)) %>%
  ungroup %>%
  unpack(out) %>%
  type.convert(as.is = TRUE)

-输出

# A tibble: 30 × 6
      ob ob_pattern value    n1    n2     n3
   <int> <chr>      <dbl> <dbl> <dbl>  <dbl>
 1     1 p2         0.727 0.413 0.688  0.248
 2     2 p2         0.366 0.819 1.37   0.492
 3     3 p2         0.974 0.308 0.513  0.185
 4     4 p1         0.780 0.513 0.192 NA    
 5     5 p1         0.685 0.584 0.219 NA    
 6     6 p1         0.316 1.26  0.474 NA    
 7     7 p1         0.769 0.520 0.195 NA    
 8     8 p2         0.577 0.520 0.866  0.312
 9     9 p1         0.984 0.407 0.152 NA    
10    10 p1         0.334 1.20  0.449 NA    
# … with 20 more rows

作为另一种方法,这是嵌套数据框的有力候选者。

在这种情况下,我们可以调整您的函数以从头开始获取过滤后的数据帧。

fun2 <- function(value, sel_pt){
  
  #filter the pattern
  out <- data.frame()
  
  for(i in 1:nrow(sel_pt)){
    out[1,i] <- sel_pt[i,1]
    out[2,i] <- sel_pt[i,2] / value
  }
  
  names(out) <- out[1,]
  out <-  out[-1,]
  return(out)
  
}

现在我们可以在 df_pt 上嵌套连接并将其映射为输入。

library(dplyr)
library(tidyr)
library(purrr)

df %>% 
  nest_join(df_pt, by = c(ob_pattern = "pattern")) %>% 
  mutate(output = map2(value, df_pt, fun2)) %>% 
  select(ob, ob_pattern, value, output) %>% 
  unnest_wider(output)

另一方面,这个 fun2() 可以很容易地改写如下。此 returns 列为数字,这可能是您想要的。

library(tibble)

fun3 <- function(value, sel_pt){
  
  sel_pt %>% 
    mutate(perct = perct / value) %>% 
    deframe()
}

df %>% 
  nest_join(df_pt, by = c(ob_pattern = "pattern")) %>% 
  mutate(output = map2(value, df_pt, fun3)) %>% 
  select(ob, ob_pattern, value, output) %>% 
  unnest_wider(output)
# A tibble: 30 x 6
      ob ob_pattern   value      n1     n2     n3
   <int> <chr>        <dbl>   <dbl>  <dbl>  <dbl>
 1     1 p1         0.898     0.445  0.167 NA    
 2     2 p1         0.413     0.970  0.364 NA    
 3     3 p2         0.507     0.592  0.987  0.355
 4     4 p2         0.544     0.551  0.918  0.331
 5     5 p2         0.504     0.595  0.992  0.357
 6     6 p1         0.00277 145.    54.2   NA    
 7     7 p1         0.453     0.883  0.331 NA    
 8     8 p1         0.175     2.29   0.858 NA    
 9     9 p1         0.595     0.673  0.252 NA    
10    10 p2         0.0358    8.37  13.9    5.02 
# ... with 20 more rows