使用 model.matrix 的单热编码

One-hot encoding using model.matrix

model.matrix中有些地方我不明白。当我在没有拦截的情况下输入一个二进制变量时,它 returns 两个级别。

> temp.data <- data.frame('x' = sample(c('A', 'B'), 1000, replace = TRUE))
> temp.data.table <- model.matrix( ~ 0 + x, data = temp.data)
> head(temp.data.table)
  xA xB
1  1  0
2  0  1
3  0  1
4  0  1
5  1  0
6  0  1

但是,当我进入另一个二进制级别时,它只创建了 3 列。这是为什么?是什么让函数的行为突然不同了?我该如何避免呢?

> temp.data <- data.frame('x' = sample(c('A', 'B'), 1000, replace = TRUE),
+                         'y' = sample(c('J', 'D'), 1000, replace = TRUE))
> temp.data.table <- model.matrix( ~ 0 + x + y, data = temp.data)
> head(temp.data.table)
  xA xB yJ
1  0  1  0
2  0  1  1
3  0  1  1
4  0  1  0
5  1  0  1
6  0  1  0

您需要使用 factors 并将 contrasts 设置为 FALSE。试试这个:

n <- 10
temp.data <- data.frame('x'=sample(c('A', 'B'), n, replace=TRUE),
                        'y'=factor(sample(c('J', 'D'), n, replace=TRUE)))
model.matrix( ~ 0 + x + y, data=temp.data,
              contrasts=list(y=contrasts(temp.data$y, contrasts=FALSE)))

#    xA xB yD yJ
# 1   0  1  1  0
# 2   1  0  0  1
# 3   0  1  1  0
# 4   1  0  0  1
# 5   0  1  0  1
# 6   1  0  1  0
# 7   1  0  1  0
# 8   0  1  1  0
# 9   0  1  0  1
# 10  0  1  1  0
# attr(,"assign")
# [1] 1 1 2 2
# attr(,"contrasts")
# attr(,"contrasts")$x
# [1] "contr.treatment"
# 
# attr(,"contrasts")$y
#   D J
# D 1 0
# J 0 1

要了解为什么会发生这种情况,请尝试:

contrasts(temp.data$y)
#   J
# D 0
# J 1

contrasts(temp.data$y, contrasts=F)
#   D J
# D 1 0
# J 0 1

对于您的 x 变量,这会通过设置 0 + 自动发生以删除截距。 (其实x也应该编码为factor)。

原因是,在线性回归中,因子变量的水平通常与参考水平(您可以使用 relevel 更改)进行比较。在您的模型矩阵中,使用 0 + 您可以删除第一个变量的截距,但不会删除以下变量(尝试 model.matrix( ~ 0 + y + x, data=temp.data) ,您只得到一个 xy )。这是在默认情况下使用处理对比的标准 contrasts 设置中确定的。

您可能想阅读 Rose Maier (2015) 的相关 post,对此进行了非常详细的解释:

您需要重新设置因子变量的对比。参见 this post

temp.data <- data.frame('x' = sample(c('A', 'B'), 1000, replace = TRUE),
+                         'y' = sample(c('J', 'D'), 1000, replace = TRUE))

dat = model.matrix(~ -1 +., data=temp.data, contrasts.arg = lapply(temp.data[,1:2], contrasts, contrasts=FALSE))
head(dat)

  xA xB yD yJ
1  0  1  0  1
2  1  0  0  1
3  1  0  0  1
4  1  0  0  1
5  0  1  1  0
6  0  1  0  1