R:下标越界

R: subscript out of bounds

我正在使用 R,并尝试按照这篇教程进行函数优化:https://rpubs.com/Argaadya/bayesian-optimization

对于这个例子,我首先生成一些随机数据:

#load libraries
library(dplyr)


# Create some example data: two roughly normal predictors (a1, b1) and an
# integer column c1 drawn uniformly, with replacement, from 1..1000.
# Seeding first makes the example reproducible from run to run.
set.seed(42)
n_obs <- 1000
a1 <- rnorm(n_obs, mean = 100, sd = 10)
b1 <- rnorm(n_obs, mean = 100, sd = 5)
c1 <- sample.int(1000, size = n_obs, replace = TRUE)
train_data <- data.frame(a1, b1, c1)

接下来,我定义要优化的函数(“fitness”)。该函数接受 7 个输入(random_1、random_2、random_3、random_4、split_1、split_2 和 split_3),并计算出一个“总”平均值(单个标量值)。

要优化的函数(“fitness”)定义如下:

# Fitness function to be maximized. Bins train_data into categories "a", "b"
# and "c" using the four thresholds, computes the split_* quantile of c1
# within each bin, and averages the 0/1 indicator "bin quantile > c1" over
# all rows.
# NOTE(review): the value is not called "total". The last expression is the
# assignment `mean = mean(...)`, so the function returns that mean invisibly,
# as a bare numeric. rBayesianOptimization::BayesianOptimization() expects FUN
# to return list(Score = ..., Pred = ...).
# NOTE(review): reads the global `train_data`. The reassignment below creates
# a local copy and leaves the global untouched.
fitness <- function(random_1, random_2, random_3, random_4, split_1, split_2, split_3) {

    # Bin rows: "a" if a1 <= random_1 and b1 <= random_3; otherwise "b" if
    # a1 <= random_2 and b1 <= random_4; everything else is "c".
    train_data <- train_data %>% mutate(cat = ifelse(a1 <= random_1 & b1 <= random_3, "a", ifelse(a1 <= random_2 & b1 <= random_4, "b", "c")))
    
    train_data$cat = as.factor(train_data$cat)
    
    # One table per bin. The select() keeps every existing column, so it only
    # fixes the column order.
    a_table = train_data %>%
        filter(cat == "a") %>%
        select(a1, b1, c1, cat)
    
    b_table = train_data %>%
        filter(cat == "b") %>%
        select(a1, b1, c1, cat)
    
    c_table = train_data %>%
        filter(cat == "c") %>%
        select(a1, b1, c1, cat)
    

    
    # Per-bin quantile ("quant") of c1, using split_1/2/3 as the probability
    # for bins a/b/c. Each table holds one category, so group_by(cat) yields a
    # single group.
    # NOTE(review): `prob` only works through partial matching of quantile()'s
    # `probs` argument; spelling it out is safer.
    
    table_a = data.frame(a_table%>% group_by(cat) %>%
                             mutate(quant = quantile(c1, prob = split_1)))
    
    table_b = data.frame(b_table%>% group_by(cat) %>%
                             mutate(quant = quantile(c1, prob = split_2)))
    
    table_c = data.frame(c_table%>% group_by(cat) %>%
                             mutate(quant = quantile(c1, prob = split_3)))
    
    
    
    
    # diff = 1 when the bin's quantile is strictly greater than the row's c1,
    # else 0.
    table_a$diff = ifelse(table_a$quant > table_a$c1,1,0)
    table_b$diff = ifelse(table_b$quant > table_b$c1,1,0)
    table_c$diff = ifelse(table_c$quant > table_c$c1,1,0)
    
    # Stack the three bins back into one table.
    
    final_table = rbind(table_a, table_b, table_c)
# Overall mean of diff: this is the quantity to maximize.
# NOTE(review): this shadows `mean` with a local variable. The call on the
# right still resolves to base::mean because R skips non-function bindings
# when looking up a function, but the shadowing is confusing.
    mean = mean(final_table$diff)
    
    
}

**目标:**我现在想使用该教程中的“贝叶斯优化”算法(https://rpubs.com/Argaadya/bayesian-optimization),找出使“均值”最大的这 7 个输入的取值:

首先,定义“搜索范围”:

#define search bound for the 7 inputs
# Each entry of search_bound is meant to be a c(lower, upper) pair for the
# fitness() argument of the same name.

library(rBayesianOptimization)

# NOTE(review): these NULL placeholders are what `random_1` and `random_2`
# refer to inside search_bound below.
random_1 = NULL
 random_2 = NULL
random_3 = NULL
 random_4 = NULL
split_1 = NULL
split_2 = NULL
split_3 = NULL

# NOTE(review): with random_1 and random_2 still NULL, c(random_1,120) and
# c(random_2, 120) both evaluate to the single value 120, not a
# (lower, upper) pair. BayesianOptimization() presumably reads the second
# element of each bound, which is the likely source of "subscript out of
# bounds". Bounds must be constants; an ordering such as random_2 >= random_1
# cannot be expressed here.
search_bound <- list(random_1 = c(80,120), random_2 = c(random_1,120),
                     random_3 = c(85,120), random_4 = c(random_2, 120),  split_1 = c(0,1), split_2 = c(0,1), split_3 = c(0,1))

其次,设置初始样本:

#set initial sample:
# Initial design: 20 random points (one column per fitness() argument) for
# the optimizer to evaluate first.

set.seed(123)
search_grid <- data.frame(random_1 = runif(20,80,120), 
                          # NOTE(review): data.frame() arguments are evaluated
                          # in the calling environment and cannot see each
                          # other. So `random_1` here is the global NULL set
                          # earlier, not the column defined on the line above.
                          # runif() does not accept a NULL bound. The same
                          # applies to `random_2` in random_4 below. Draw the
                          # vectors first, then combine them with data.frame().
                          random_2 = runif(20,random_1,120),
                          random_3 = runif(20,85,120),
random_4 = runif(20,random_2,120),
split_1= runif(20,0,1),
split_2 = runif(20,0,1),
split_3 = runif(20,0,1)
)

最后,运行贝叶斯优化算法:

# Run Bayesian optimization with the expected-improvement ("ei") acquisition
# function. The points in search_grid are evaluated first (init_points = 0
# adds no extra random starting points), followed by n_iter = 10 further
# iterations.
set.seed(1)
bayes_finance_ei <- BayesianOptimization(
  FUN = fitness,
  bounds = search_bound,
  init_grid_dt = search_grid,
  init_points = 0,
  n_iter = 10,
  acq = "ei"
)

但这会产生以下错误:

Error in FUN(X[[i]], ...) : subscript out of bounds

有人能告诉我哪里做错了吗?我以为自己已经按照教程完成了所有必要的步骤。

谢谢

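您的代码中似乎有几处错误:例如,我认为您的适应度函数没有按所需格式返回数据(rBayesianOptimization 要求它返回包含 Score 和 Pred 的列表),而且有些向量在定义之前就被使用了。“subscript out of bounds” 这个报错很可能源于 search_bound:定义它时 random_1 和 random_2 还是 NULL,因此 c(random_1, 120) 只剩下一个元素,而 search_bound 的每一项都应当是 c(下限, 上限) 这样的两个元素。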

我做了一些修改,使代码更贴近教程,现在可以无错误地运行完毕;但我无法判断结果是否“正确”,也不确定它是否适合您的使用场景:

#load libraries
library(tidyverse)
#install.packages("rBayesianOptimization")
library(rBayesianOptimization)

# Create some example data: two roughly normal predictors (a1, b1) and an
# integer column c1 drawn uniformly, with replacement, from 1..1000.
# Seeding first makes the data, and therefore the optimization results,
# reproducible from run to run.
set.seed(42)
n_obs <- 1000
a1 <- rnorm(n_obs, mean = 100, sd = 10)
b1 <- rnorm(n_obs, mean = 100, sd = 5)
c1 <- sample.int(1000, size = n_obs, replace = TRUE)
train_data <- data.frame(a1, b1, c1)

# Fitness function for rBayesianOptimization::BayesianOptimization().
#
# Bins the rows of `data` into three categories:
#   "a": a1 <= random_1 & b1 <= random_3
#   "b": otherwise, a1 <= random_2 & b1 <= random_4
#   "c": everything else
# Within each category it takes the split_1 / split_2 / split_3 quantile of
# c1. The score is the share of all binned rows whose c1 lies strictly below
# its category's quantile.
#
# Arguments:
#   random_1..random_4  numeric thresholds on a1 / b1 (see the rule above).
#   split_1..split_3    quantile probabilities in [0, 1] for bins a / b / c.
#   data                data frame with columns a1, b1, c1. It defaults to
#                       the global `train_data`, so BayesianOptimization()
#                       can keep calling fitness() with just the 7 tuned
#                       arguments.
# Returns:
#   list(Score = <share to maximize>, Pred = 0), the shape
#   BayesianOptimization() expects from FUN (Pred is an unused placeholder).
fitness <- function(random_1, random_2, random_3, random_4,
                    split_1, split_2, split_3, data = train_data) {
  # Same nested rule as before. Rows with NA in a1/b1 can get an NA category
  # and then fall out of every bin, as they did with filter(cat == ...).
  category <- ifelse(
    data$a1 <= random_1 & data$b1 <= random_3, "a",
    ifelse(data$a1 <= random_2 & data$b1 <= random_4, "b", "c")
  )
  probs_by_category <- c(a = split_1, b = split_2, c = split_3)

  # split() drops NA categories and only returns non-empty bins, so each
  # quantile is taken over exactly the rows of one bin.
  c1_by_category <- split(data$c1, category)
  below_quantile <- Map(
    function(values, prob) {
      values < quantile(values, probs = prob, names = FALSE)
    },
    c1_by_category,
    probs_by_category[names(c1_by_category)]
  )

  # as.logical() turns the NULL from an empty unlist() into logical(0), so
  # input with no binned rows scores NaN instead of warning.
  score <- mean(as.logical(unlist(below_quantile, use.names = FALSE)))

  list(Score = score, Pred = 0)
}

# Initial design: 20 random points. Each column is drawn as its own vector
# first, because data.frame() arguments cannot refer to each other. That is
# how the original code ended up using undefined/NULL `random_1`.
# runif() is vectorized over `min`, so random_2 >= random_1 and
# random_4 >= random_2 hold row by row in the initial grid.
set.seed(123)
random_1 <- runif(20, 80, 120)
random_2 <- runif(20, random_1, 120)
random_3 <- runif(20, 85, 120)
random_4 <- runif(20, random_2, 120)
split_1 <- runif(20, 0, 1)
split_2 <- runif(20, 0, 1)
split_3 <- runif(20, 0, 1)

# Search bounds must be constant c(lower, upper) pairs. Writing
# `c(random_1, 120)` here would splice the whole 20-element random_1 vector
# into the bound, and the optimizer would treat random_1[1] and random_1[2]
# as the limits. The dependent parameters therefore get the widest range they
# can take, c(80, 120), which also contains every initial grid point.
# NOTE: the ordering random_1 <= random_2 <= random_4 is no longer enforced
# during the search. Enforce it inside fitness() if it matters.
search_bound <- list(
  random_1 = c(80, 120),
  random_2 = c(80, 120),
  random_3 = c(85, 120),
  random_4 = c(80, 120),
  split_1 = c(0, 1),
  split_2 = c(0, 1),
  split_3 = c(0, 1)
)

search_grid <- data.frame(random_1, random_2, random_3, random_4,
                          split_1, split_2, split_3)


# Run Bayesian optimization with the expected-improvement ("ei") acquisition
# function. The points in search_grid are evaluated first (init_points = 0
# adds no extra random starting points), followed by n_iter = 10 further
# iterations.
set.seed(1)
bayes_finance_ei <- BayesianOptimization(
  FUN = fitness,
  bounds = search_bound,
  init_grid_dt = search_grid,
  init_points = 0,
  n_iter = 10,
  acq = "ei"
)
#> elapsed = 0.076  Round = 1   random_1 = 91.5031  random_2 = 116.8522 random_3 = 89.9980  random_4 = 118.9459 split_1 = 0.2436195 split_2 = 0.599989  split_3 = 0.6478935 Value = 0.6020 
#> elapsed = 0.023  Round = 2   random_1 = 111.5322 random_2 = 117.3987 random_3 = 99.50912 random_4 = 117.6454 split_1 = 0.6680556 split_2 = 0.3328235 split_3 = 0.3198206 Value = 0.4650 
#> elapsed = 0.026  Round = 3   random_1 = 96.35908 random_2 = 111.5012 random_3 = 99.48035 random_4 = 114.7645 split_1 = 0.4176468 split_2 = 0.488613  split_3 = 0.30772   Value = 0.4520 
#> elapsed = 0.024  Round = 4   random_1 = 115.3207 random_2 = 119.9732 random_3 = 97.90959 random_4 = 119.9805 split_1 = 0.7881958 split_2 = 0.9544738 split_3 = 0.2197676 Value = 0.8840 
#> elapsed = 0.024  Round = 5   random_1 = 117.6187 random_2 = 119.1801 random_3 = 90.33557 random_4 = 119.8480 split_1 = 0.1028646 split_2 = 0.4829024 split_3 = 0.3694889 Value = 0.4690 
#> elapsed = 0.024  Round = 6   random_1 = 81.82226 random_2 = 108.8724 random_3 = 89.85821 random_4 = 113.8633 split_1 = 0.4348927 split_2 = 0.8903502 split_3 = 0.9842192 Value = 0.9090 
#> elapsed = 0.025  Round = 7   random_1 = 101.1242 random_2 = 111.3939 random_3 = 93.15619 random_4 = 118.3654 split_1 = 0.984957  split_2 = 0.9144382 split_3 = 0.1542023 Value = 0.8130 
#> elapsed = 0.023  Round = 8   random_1 = 115.6968 random_2 = 118.2535 random_3 = 101.3087 random_4 = 119.6723 split_1 = 0.8930511 split_2 = 0.608735  split_3 = 0.091044  Value = 0.7490 
#> elapsed = 0.024  Round = 9   random_1 = 102.0574 random_2 = 107.2457 random_3 = 94.30904 random_4 = 117.3770 split_1 = 0.8864691 split_2 = 0.4106898 split_3 = 0.1419069 Value = 0.3830 
#> elapsed = 0.022  Round = 10  random_1 = 98.26459 random_2 = 101.4622 random_3 = 115.0240 random_4 = 109.6157 split_1 = 0.1750527 split_2 = 0.1470947 split_3 = 0.6900071 Value = 0.4150 
#> elapsed = 0.023  Round = 11  random_1 = 118.2733 random_2 = 119.9362 random_3 = 86.60409 random_4 = 119.9843 split_1 = 0.1306957 split_2 = 0.9352998 split_3 = 0.6192565 Value = 0.9250 
#> elapsed = 0.023  Round = 12  random_1 = 98.13337 random_2 = 117.8636 random_3 = 100.4770 random_4 = 119.2079 split_1 = 0.6531019 split_2 = 0.3012289 split_3 = 0.8913941 Value = 0.4050 
#> elapsed = 0.025  Round = 13  random_1 = 107.1028 random_2 = 116.0110 random_3 = 112.9624 random_4 = 118.8439 split_1 = 0.3435165 split_2 = 0.06072057    split_3 = 0.6729991 Value = 0.3040 
#> elapsed = 0.023  Round = 14  random_1 = 102.9053 random_2 = 116.5036 random_3 = 89.26647 random_4 = 116.5058 split_1 = 0.6567581 split_2 = 0.9477269 split_3 = 0.7370777 Value = 0.9340 
#> elapsed = 0.022  Round = 15  random_1 = 84.11699 random_2 = 85.0002  random_3 = 104.6332 random_4 = 101.6362 split_1 = 0.3203732 split_2 = 0.7205963 split_3 = 0.5211357 Value = 0.5130 
#> elapsed = 0.023  Round = 16  random_1 = 115.9930 random_2 = 117.9075 random_3 = 92.2286  random_4 = 118.3681 split_1 = 0.1876911 split_2 = 0.1422943 split_3 = 0.6598384 Value = 0.1610 
#> elapsed = 0.024  Round = 17  random_1 = 89.84351 random_2 = 112.7160 random_3 = 89.46361 random_4 = 115.4826 split_1 = 0.7822943 split_2 = 0.5492847 split_3 = 0.8218055 Value = 0.5790 
#> elapsed = 0.025  Round = 18  random_1 = 81.68238 random_2 = 89.97462 random_3 = 111.3658 random_4 = 108.3733 split_1 = 0.09359499    split_2 = 0.9540912 split_3 = 0.7862816 Value = 0.7880 
#> elapsed = 0.023  Round = 19  random_1 = 93.11683 random_2 = 101.6705 random_3 = 116.3266 random_4 = 108.1188 split_1 = 0.466779  split_2 = 0.5854834 split_3 = 0.9798219 Value = 0.7410 
#> elapsed = 0.026  Round = 20  random_1 = 118.1801 random_2 = 118.6017 random_3 = 98.1062  random_4 = 118.7571 split_1 = 0.5115055 split_2 = 0.4045103 split_3 = 0.4394315 Value = 0.4420 
#> elapsed = 0.023  Round = 21  random_1 = 96.45638 random_2 = 111.5322 random_3 = 87.30664 random_4 = 117.3987 split_1 = 0.1028404 split_2 = 1.0000    split_3 = 0.9511012 Value = 0.9910 
#> elapsed = 0.022  Round = 22  random_1 = 117.5734 random_2 = 91.5031  random_3 = 120.0000 random_4 = 116.8522 split_1 = 2.220446e-16  split_2 = 1.0000    split_3 = 2.220446e-16  Value = 0.0020 
#> elapsed = 0.023  Round = 23  random_1 = 111.5021 random_2 = 111.5322 random_3 = 120.0000 random_4 = 117.3987 split_1 = 0.1188052 split_2 = 1.0000    split_3 = 0.6455233 Value = 0.1910 
#> elapsed = 0.028  Round = 24  random_1 = 80.0000  random_2 = 92.90239 random_3 = 86.18939 random_4 = 116.8635 split_1 = 0.2557032 split_2 = 1.0000    split_3 = 0.3517052 Value = 0.5170 
#> elapsed = 0.026  Round = 25  random_1 = 90.64032 random_2 = 92.3588  random_3 = 112.3187 random_4 = 117.3987 split_1 = 1.0000    split_2 = 1.0000    split_3 = 1.0000    Value = 0.9970 
#> elapsed = 0.022  Round = 26  random_1 = 100.4363 random_2 = 104.1665 random_3 = 106.2099 random_4 = 117.3987 split_1 = 1.0000    split_2 = 1.0000    split_3 = 1.0000    Value = 0.9970 
#> elapsed = 0.022  Round = 27  random_1 = 119.0981 random_2 = 91.5031  random_3 = 120.0000 random_4 = 117.3987 split_1 = 1.0000    split_2 = 1.0000    split_3 = 1.0000    Value = 0.9980 
#> elapsed = 0.023  Round = 28  random_1 = 89.95279 random_2 = 101.9462 random_3 = 85.0000  random_4 = 116.9137 split_1 = 2.220446e-16  split_2 = 1.0000    split_3 = 1.0000    Value = 0.9980 
#> elapsed = 0.027  Round = 29  random_1 = 113.5928 random_2 = 91.5031  random_3 = 120.0000 random_4 = 117.3987 split_1 = 1.0000    split_2 = 2.220446e-16  split_3 = 1.0000    Value = 0.9980 
#> elapsed = 0.027  Round = 30  random_1 = 116.9869 random_2 = 91.5031  random_3 = 120.0000 random_4 = 117.2949 split_1 = 1.0000    split_2 = 0.505048  split_3 = 1.0000    Value = 0.9980 
#> 
#>  Best Parameters Found: 
#> Round = 27   random_1 = 119.0981 random_2 = 91.5031  random_3 = 120.0000 random_4 = 117.3987 split_1 = 1.0000    split_2 = 1.0000    split_3 = 1.0000    Value = 0.9980

由 reprex 包(v2.0.0)于 2021-07-08 创建