Function in R to group linear features in a scatterplot

I have a scatterplot generated with ggplot.

The data are here:

Retirees<- data.frame(No_Races = c(11L, 11L, 21L, 12L, 15L, 10L, 23L, 13L, 19L, 11L, 22L, 14L, 
  17L, 33L, 12L, 45L, 15L, 55L, 22L, 41L, 19L, 16L, 13L, 26L, 23L, 
  10L, 20L, 44L, 17L, 14L, 21L, 28L, 18L, 11L, 22L, 37L, 26L, 30L, 
  15L, 23L, 31L, 47L, 12L, 20L, 16L, 36L, 37L, 29L, 21L, 17L, 68L, 
  56L, 13L, 22L, 27L, 18L, 23L, 14L, 28L, 19L, 24L, 10L, 15L, 20L, 
  35L, 25L, 36L, 31L, 26L, 21L, 37L, 16L, 32L, 27L, 11L, 22L, 33L, 
  28L, 17L, 40L, 23L, 35L, 53L, 65L, 12L, 18L, 30L, 66L, 24L, 36L, 
  48L, 49L, 25L, 63L, 19L, 38L, 32L, 45L, 13L, 26L, 78L, 39L, 52L, 
  46L, 33L, 20L, 60L, 27L, 34L, 41L, 69L, 14L, 21L, 42L, 28L, 35L, 
  57L, 50L, 36L, 29L, 22L, 81L, 37L, 15L, 38L, 23L, 46L, 31L, 48L, 
  16L, 24L, 32L, 57L, 49L, 41L, 33L, 25L, 42L, 17L, 34L, 26L, 35L, 
  18L, 27L, 54L, 45L, 37L, 28L, 47L, 19L, 57L, 29L, 39L, 10L, 20L, 
  30L, 40L, 41L, 31L, 21L, 42L, 32L, 54L, 11L, 22L, 33L, 44L, 66L, 
  67L, 45L, 34L, 23L, 69L, 46L, 35L, 47L, 59L, 71L, 36L, 12L, 24L, 
  60L, 61L, 37L, 25L, 38L, 51L, 64L, 13L, 39L, 26L, 53L, 80L, 67L, 
  27L, 68L, 55L, 14L, 42L, 28L, 43L, 58L, 29L, 44L, 74L, 15L, 45L, 
  30L, 46L, 31L, 47L, 79L, 16L, 32L, 48L, 49L, 33L, 50L, 17L, 34L, 
  52L, 35L, 53L, 18L, 72L, 36L, 54L, 37L, 56L, 19L, 38L, 39L, 20L, 
  40L, 61L, 41L, 83L, 21L, 42L, 43L, 65L, 87L, 22L, 45L, 23L, 46L, 
  93L, 47L, 71L, 24L, 72L, 49L, 25L, 100L, 51L, 26L, 52L, 27L, 
  54L, 28L, 85L, 57L, 29L, 59L, 30L, 31L, 32L, 96L, 65L, 33L, 67L, 
  34L, 35L, 36L, 37L, 74L, 38L, 77L, 39L, 40L, 41L, 42L, 43L, 44L, 
  45L, 46L, 47L, 48L, 49L, 50L, 51L, 52L, 53L, 54L, 55L, 56L, 59L, 
  60L, 61L, 62L, 64L, 65L, 68L, 69L, 71L, 73L, 77L, 82L, 83L, 92L, 
  98L), 
  Perc_Retired = c(54.5454545454545, 45.4545454545455, 42.8571428571429, 41.6666666666667, 
      40, 40, 39.1304347826087, 38.4615384615385, 36.8421052631579, 
                         36.3636363636364, 36.3636363636364, 35.7142857142857, 35.2941176470588, 
                         33.3333333333333, 33.3333333333333, 33.3333333333333, 33.3333333333333, 
                         32.7272727272727, 31.8181818181818, 31.7073170731707, 31.5789473684211, 
                         31.25, 30.7692307692308, 30.7692307692308, 30.4347826086957, 
                         30, 30, 29.5454545454545, 29.4117647058824, 28.5714285714286, 
                         28.5714285714286, 28.5714285714286, 27.7777777777778, 27.2727272727273, 
                         27.2727272727273, 27.027027027027, 26.9230769230769, 26.6666666666667, 
                         26.6666666666667, 26.0869565217391, 25.8064516129032, 25.531914893617, 
                         25, 25, 25, 25, 24.3243243243243, 24.1379310344828, 23.8095238095238, 
                         23.5294117647059, 23.5294117647059, 23.2142857142857, 23.0769230769231, 
                         22.7272727272727, 22.2222222222222, 22.2222222222222, 21.7391304347826, 
                         21.4285714285714, 21.4285714285714, 21.0526315789474, 20.8333333333333, 
                         20, 20, 20, 20, 20, 19.4444444444444, 19.3548387096774, 19.2307692307692, 
                         19.047619047619, 18.9189189189189, 18.75, 18.75, 18.5185185185185, 
                         18.1818181818182, 18.1818181818182, 18.1818181818182, 17.8571428571429, 
                         17.6470588235294, 17.5, 17.3913043478261, 17.1428571428571, 16.9811320754717, 
                         16.9230769230769, 16.6666666666667, 16.6666666666667, 16.6666666666667, 
                         16.6666666666667, 16.6666666666667, 16.6666666666667, 16.6666666666667, 
                         16.3265306122449, 16, 15.8730158730159, 15.7894736842105, 15.7894736842105, 
                         15.625, 15.5555555555556, 15.3846153846154, 15.3846153846154, 
                         15.3846153846154, 15.3846153846154, 15.3846153846154, 15.2173913043478, 
                         15.1515151515152, 15, 15, 14.8148148148148, 14.7058823529412, 
                         14.6341463414634, 14.4927536231884, 14.2857142857143, 14.2857142857143, 
                         14.2857142857143, 14.2857142857143, 14.2857142857143, 14.0350877192982, 
                         14, 13.8888888888889, 13.7931034482759, 13.6363636363636, 13.5802469135802, 
                         13.5135135135135, 13.3333333333333, 13.1578947368421, 13.0434782608696, 
                         13.0434782608696, 12.9032258064516, 12.5, 12.5, 12.5, 12.5, 12.280701754386, 
                         12.2448979591837, 12.1951219512195, 12.1212121212121, 12, 11.9047619047619, 
                         11.7647058823529, 11.7647058823529, 11.5384615384615, 11.4285714285714, 
                         11.1111111111111, 11.1111111111111, 11.1111111111111, 11.1111111111111, 
                         10.8108108108108, 10.7142857142857, 10.6382978723404, 10.5263157894737, 
                         10.5263157894737, 10.3448275862069, 10.2564102564103, 10, 10, 
                         10, 10, 9.75609756097561, 9.67741935483871, 9.52380952380952, 
                         9.52380952380952, 9.375, 9.25925925925926, 9.09090909090909, 
                         9.09090909090909, 9.09090909090909, 9.09090909090909, 9.09090909090909, 
                         8.95522388059701, 8.88888888888889, 8.82352941176471, 8.69565217391304, 
                         8.69565217391304, 8.69565217391304, 8.57142857142857, 8.51063829787234, 
                         8.47457627118644, 8.45070422535211, 8.33333333333333, 8.33333333333333, 
                         8.33333333333333, 8.33333333333333, 8.19672131147541, 8.10810810810811, 
                         8, 7.89473684210526, 7.84313725490196, 7.8125, 7.69230769230769, 
                         7.69230769230769, 7.69230769230769, 7.54716981132075, 7.5, 7.46268656716418, 
                         7.40740740740741, 7.35294117647059, 7.27272727272727, 7.14285714285714, 
                         7.14285714285714, 7.14285714285714, 6.97674418604651, 6.89655172413793, 
                         6.89655172413793, 6.81818181818182, 6.75675675675676, 6.66666666666667, 
                         6.66666666666667, 6.66666666666667, 6.52173913043478, 6.45161290322581, 
                         6.38297872340426, 6.32911392405063, 6.25, 6.25, 6.25, 6.12244897959184, 
                         6.06060606060606, 6, 5.88235294117647, 5.88235294117647, 5.76923076923077, 
                         5.71428571428571, 5.66037735849057, 5.55555555555556, 5.55555555555556, 
                         5.55555555555556, 5.55555555555556, 5.40540540540541, 5.35714285714286, 
                         5.26315789473684, 5.26315789473684, 5.12820512820513, 5, 5, 4.91803278688525, 
                         4.8780487804878, 4.81927710843374, 4.76190476190476, 4.76190476190476, 
                         4.65116279069767, 4.61538461538461, 4.59770114942529, 4.54545454545455, 
                         4.44444444444444, 4.34782608695652, 4.34782608695652, 4.3010752688172, 
                         4.25531914893617, 4.22535211267606, 4.16666666666667, 4.16666666666667, 
                         4.08163265306122, 4, 4, 3.92156862745098, 3.84615384615385, 3.84615384615385, 
                         3.7037037037037, 3.7037037037037, 3.57142857142857, 3.52941176470588, 
                         3.50877192982456, 3.44827586206897, 3.38983050847458, 3.33333333333333, 
                         3.2258064516129, 3.125, 3.125, 3.07692307692308, 3.03030303030303, 
                         2.98507462686567, 2.94117647058824, 2.85714285714286, 2.77777777777778, 
                         2.7027027027027, 2.7027027027027, 2.63157894736842, 2.5974025974026, 
                         2.56410256410256, 2.5, 2.4390243902439, 2.38095238095238, 2.32558139534884, 
                         2.27272727272727, 2.22222222222222, 2.17391304347826, 2.12765957446809, 
                         2.08333333333333, 2.04081632653061, 2, 1.96078431372549, 1.92307692307692, 
                         1.88679245283019, 1.85185185185185, 1.81818181818182, 1.78571428571429, 
                         1.69491525423729, 1.66666666666667, 1.63934426229508, 1.61290322580645, 
                         1.5625, 1.53846153846154, 1.47058823529412, 1.44927536231884, 
                         1.40845070422535, 1.36986301369863, 1.2987012987013, 1.21951219512195, 
                         1.20481927710843, 1.08695652173913, 1.02040816326531))

When plotted, the data look like this:

ggplot(Retirees,aes(x=No_Races,y=Perc_Retired)) + 
  geom_point()

There are obvious linear groups here. Is there any function in R that will let me assign each point to one of these linear groups?

I have tried k-means clustering, but unsurprisingly the resulting clusters do not follow the linear groups:
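For reference, a minimal sketch of the kind of plain k-means attempt described above (the choice of five centres is my assumption, not something stated in the question):

library(ggplot2)

# k-means on the raw columns partitions the plane into roughly round blobs,
# so the resulting clusters cut across the visible linear bands.
set.seed(1)
km <- kmeans(Retirees[, c("No_Races", "Perc_Retired")], centers = 5)

ggplot(Retirees, aes(x = No_Races, y = Perc_Retired, colour = factor(km$cluster))) +
  geom_point()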

This is too long for a comment, so I am posting it as an answer.

My impression is that DBSCAN tends to do well on this kind of problem. See, for example, the vignette here for a nice visualization of various clustering algorithms and their performance. DBSCAN is implemented in R in a package of the same name; see here. Needless to say, you will need to tune some parameters, but after about a minute of tinkering I produced the following:

library(dbscan)

db <- dbscan(Retirees, eps = 2, minPts = 3)
plot(Retirees, col = db$cluster + 1L, pch = db$cluster + 1L)

This is obviously not perfect, but hey, it took me about a minute. I am sure you can improve on it by tuning the parameters appropriately.

The reason it does not work is that you forgot to scale the data appropriately.
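A minimal sketch of that suggestion (the eps value on the scaled data is only a guess and would need re-tuning; kNNdistplot is just one way to choose it):

library(dbscan)

# Standardize both columns so that a single eps is meaningful in both
# dimensions; eps must be chosen again after scaling (0.15 is only a guess).
Retirees_scaled <- scale(Retirees[, c("No_Races", "Perc_Retired")])
db_scaled <- dbscan(Retirees_scaled, eps = 0.15, minPts = 3)
plot(Retirees, col = db_scaled$cluster + 1L, pch = db_scaled$cluster + 1L)

# The k-nearest-neighbour distance plot helps pick eps: look for the knee.
kNNdistplot(Retirees_scaled, k = 3)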

However, I have some bad news for you:

These are not clusters; they are artifacts.

Apparently some of the data was discretized at some point before being transformed, which is why you get these ugly artifacts. They are not real.

Note that the lines follow 1/x, 2/x, and so on.

To be precise, undo the percentage calculation. You will then see that the "clusters" are simply where 1, 2, 3, ... retired. There is nothing new to be found with these "clusters".
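A quick way to check this is to back out the implied count of retirees from the percentage (No_Retired is a column name I am introducing just for this check):

# If the bands are discretization artifacts, the implied counts are whole
# numbers and each band corresponds to No_Retired = 1, 2, 3, ...
Retirees$No_Retired <- Retirees$Perc_Retired * Retirees$No_Races / 100
summary(Retirees$No_Retired - round(Retirees$No_Retired))  # differences are ~0
table(round(Retirees$No_Retired))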

Before getting into the answer, a word of caution that occurred to me: make sure your data does not have this form simply because of some discretization in the data-generating process, or because you have discrete counts that you then convert to percentages.

That said, even if these are artifacts, I still think this is an interesting problem. I ended up creating a hybrid linear-regression/k-means sort of thing that I think solves your problem quite well. I am not sure whether I am reinventing the wheel, and I am certain improvements could be made, but I think it should work for you.

First note (my design choice): since you were already willing to try k-means, I stuck with that approach. The inherent downside is that you need to choose $K$ (the number of clusters). I suggest trying a few values and seeing which gives you the result you want.

Second note (local optima and optimization): the code I wrote uses a type of "genetic algorithm" (at least I think that is the right name for it) to overcome local optima. Essentially it runs n.unit optimizations simultaneously from random initial starting points, and on each iteration it takes the worst one and replaces it with the best one. This reduces the time wasted on poor solutions that are stuck in local optima. In addition, I use random sampling of group membership to further help escape local optima. I found it ran very well on my simulated data and was quite fast.

First - I simulated data that looks somewhat like yours.

library(tidyverse)

lambda <- seq(0.01, .1, by=0.01)
x <- 5:100
pars <- expand.grid("lambda" = lambda, "x" = x)
dat <- cbind(pars$x, exp(-pars$lambda*pars$x))
plot(dat)

Second - log-transform your data

I would first suggest log-transforming the Perc_Retired values. This should make the groups look more linear.

dat.log <- cbind(dat[,1], log(dat[,2]))
plot(dat.log)

Third - the actual "Linear-Regression-K-Means Thing" I made

# Create Design Matrix and Response
X <- cbind(1, dat.log[,1]) # Design matrix, add intercept
Y <- dat.log[,2] # Responses
K <- 10  # Number of Clusters
n.unit <- 10 # Number of parallel optimizations to run
n <- nrow(X) # Number of observations

# Function to convert a vector/matrix to probabilities (i.e., normalize rows to sum to 1)
miniclo <- function (c){
  if (is.vector(c)) 
    c <- matrix(c, nrow = 1)
  (c/rowSums(c)) 
}

# Randomly initialize the group assignments.
gs <- as.list(1:n.unit) %>% 
  map(~sample(1:K, size = nrow(dat.log), replace = TRUE)) 

n.iter <- 100 # Number of iterations to run the optimization for
for (i in 1:n.iter) {
  # Start by fitting a linear regression to each group
  fits <- gs %>% 
    map(~split(1:n, .x)) %>% 
    map_depth(2, ~lm.fit(X[.x, , drop = FALSE], Y[.x])) # Fit one model per group

  # Calculate the squared residuals of each data point under each
  # group's fitted linear model. Note I also add a small value (0.0001)
  # to avoid zero values that can show up later when converting to probabilities
  # and inverting. 
  sq.resids <- fits %>% 
    at_depth(2, "coefficients") %>% # Extract Coefficients
    at_depth(2, ~((Y-X%*%.x)^2)+0.0001) %>%  # Predict and Compute Squared Residuals
    map(bind_cols)

  # Store which "unit" which of the n.unit optimiztions did the 
  # best and which did the worst
  best.worst <- sq.resids %>% 
    map(sum) %>% 
    unlist()
  best.worst <- c(which.min(best.worst), which.max(best.worst))

  # Compute new group assignments. Note that I convert the relative
  # squared residual under each model into a probability and then use the
  # inverse of this probability as the probability that a data point
  # belongs to that group
  new.gs <- sq.resids %>% 
    map(miniclo) %>% # Normalize squared residuals to row probabilities
    map(~.x^(-1)) %>% 
    map(miniclo) %>% 
    map(~apply(.x, 1, function(x) sample(1:K, 1, replace = TRUE, prob = x))) # Add sampling to get over some minima

  # Replace the worst unit with the best
  new.gs[[best.worst[2]]] <- new.gs[[best.worst[1]]]

  # Update the cluster assignments for each unit
  gs <- new.gs
}

# Now investigate results of best unit and plot 
pal <- rainbow(10) 
best.g <- new.gs[[best.worst[1]]]

# Function to enable easier plotting of the results.
plot_groupings <- function(d, g, col=rainbow(length(unique(g)))){
  K <- length(unique(g))
  plot(d)
  for (i in 1:K){
    d.local <- d[g == i,]
    d.local <- d.local[order(d.local[,1]), ]
    lines(d.local, col = col[i], lwd=2)
  }
}

First, a plot on the log-transformed scale on which we fit the model:

plot_groupings(dat.log, best.g)

Now plotted on the original scale of the data:

plot_groupings(dat, best.g)

As you can see, the model works quite well. A handful of points are misclassified (I bet this could be improved by increasing n.unit or n.iter), but overall I am very happy with the result.

Potential generalizations or alternatives

Note that this approach of using squared residuals as the basis for a k-means-style clustering is quite general and applies to other models (not just linear ones). I would love to know whether this has been done before (I would be shocked if I had actually invented something new here).
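As a rough illustration of that generality, here is a minimal sketch of the refit-and-reassign step with an arbitrary per-group model plugged in (refit_and_reassign, quad_fit and quad_pred are hypothetical helpers, the reassignment here is hard rather than sampled, and very small groups could make the quadratic fit fail):

# Fit a model to each group, score every point by its squared residual under
# every group's fitted model, then reassign each point to the group whose
# model explains it best.
refit_and_reassign <- function(d, g, fit_fun, predict_fun) {
  labels <- sort(unique(g))
  resid_mat <- sapply(labels, function(k) {
    fit <- fit_fun(d[g == k, , drop = FALSE])
    (d[, 2] - predict_fun(fit, d))^2
  })
  labels[apply(resid_mat, 1, which.min)]
}

# Example: a quadratic fit per group instead of a straight line.
quad_fit  <- function(d) lm(y ~ poly(x, 2), data = data.frame(x = d[, 1], y = d[, 2]))
quad_pred <- function(fit, d) predict(fit, newdata = data.frame(x = d[, 1]))

new_g <- refit_and_reassign(dat.log, best.g, quad_fit, quad_pred)
plot_groupings(dat.log, new_g)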

I have not done much web searching to find out. You may also find some useful material by searching for "functional data clustering" or something similar.

Anyway, I hope this helps.

Update: applied to the data you provided

I applied my model to the data you provided; the only difference is that I also log-transformed the No_Races variable (not just Perc_Retired) in order to linearize the data.
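For concreteness, a sketch of the data preparation this implies (the rest of the pipeline above then runs unchanged on this dat.log):

# Log-transform both variables from the provided data and rebuild the
# design matrix and response used by the code above.
dat     <- cbind(Retirees$No_Races, Retirees$Perc_Retired)
dat.log <- cbind(log(dat[, 1]), log(dat[, 2]))

X <- cbind(1, dat.log[, 1]) # Design matrix with intercept
Y <- dat.log[, 2]           # Response
n <- nrow(X)                # Number of observations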

I think the results are a bit better than those from dbscan.