Logistic Regression PySpark MLlib issue with multiple labels
I'm trying to train a LogisticRegression model (LogisticRegressionWithSGD), but I get the error:
org.apache.spark.SparkException: Input validation failed.
If I give it binary labels (0, 1 instead of 0, 1, 2), it does succeed.
Sample input:
parsed_data = [LabeledPoint(0.0, [4.6,3.6,1.0,0.2]),
               LabeledPoint(0.0, [5.7,4.4,1.5,0.4]),
               LabeledPoint(1.0, [6.7,3.1,4.4,1.4]),
               LabeledPoint(0.0, [4.8,3.4,1.6,0.2]),
               LabeledPoint(2.0, [4.4,3.2,1.3,0.2])]
Code:
model = LogisticRegressionWithSGD.train(parsed_data)
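For reference, a minimal sketch of the binary-label case that does train without the validation error (assuming a live SparkContext named sc; the relabelled data below is hypothetical):

from pyspark.mllib.classification import LogisticRegressionWithSGD
from pyspark.mllib.regression import LabeledPoint

# Hypothetical relabelling of the same points to just two classes (0.0 and 1.0).
binary_data = [LabeledPoint(0.0, [4.6,3.6,1.0,0.2]),
               LabeledPoint(0.0, [5.7,4.4,1.5,0.4]),
               LabeledPoint(1.0, [6.7,3.1,4.4,1.4]),
               LabeledPoint(0.0, [4.8,3.4,1.6,0.2]),
               LabeledPoint(1.0, [4.4,3.2,1.3,0.2])]

# With only two label values, input validation passes and training completes.
binary_model = LogisticRegressionWithSGD.train(sc.parallelize(binary_data))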
Should the logistic regression model in Spark be used only for binary classification?
Although this is not clear from the documentation (you have to dig into the source code to realize it), LogisticRegressionWithSGD works only with binary labels; for multinomial logistic regression you should use LogisticRegressionWithLBFGS instead:
from pyspark.mllib.classification import LogisticRegressionWithLBFGS, LogisticRegressionModel, LogisticRegressionWithSGD
from pyspark.mllib.regression import LabeledPoint
parsed_data = [LabeledPoint(0.0, [4.6,3.6,1.0,0.2]),
               LabeledPoint(0.0, [5.7,4.4,1.5,0.4]),
               LabeledPoint(1.0, [6.7,3.1,4.4,1.4]),
               LabeledPoint(0.0, [4.8,3.4,1.6,0.2]),
               LabeledPoint(2.0, [4.4,3.2,1.3,0.2])]
model = LogisticRegressionWithSGD.train(sc.parallelize(parsed_data)) # gives error:
# org.apache.spark.SparkException: Input validation failed.
model = LogisticRegressionWithLBFGS.train(sc.parallelize(parsed_data), numClasses=3) # works OK
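Once the LBFGS model is trained, a short usage sketch (assuming the model and sc from above; the save path is hypothetical):

# Predict a single feature vector, or a whole RDD of feature vectors.
print(model.predict([6.7,3.1,4.4,1.4]))                                   # one of the classes 0, 1 or 2
print(model.predict(sc.parallelize([p.features for p in parsed_data])).collect())

# The multinomial model records how many classes it was trained with.
print(model.numClasses)                                                   # 3

# LogisticRegressionModel (imported above) is used to reload a saved model.
model.save(sc, "/tmp/lr_multinomial_model")                               # hypothetical path
same_model = LogisticRegressionModel.load(sc, "/tmp/lr_multinomial_model")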