模型的特征数量必须与输入相匹配?
Number of features of the model must match the input?
我正在尝试对我拥有的一些数据使用 RandomForestClassifier。代码如下:
print train_data[0,0:20]
print train_data[0,21::]
print test_data[0]
print 'Training...'
forest = RandomForestClassifier(n_estimators=100)
forest = forest.fit( train_data[0::,0::20], train_data[0::,21::] )
print 'Predicting...'
output = forest.predict(test_data)
但这会产生以下错误:
ValueError: Number of features of the model must match the input.
Model n_features is 3 and input n_features is 21
前三个打印语句的输出是:
[ 0. 0. 0. 0. 1. 0.
0. 0. 0. 0. 1. 0.
0. 0. 0. 37.7745986 -122.42589168
0. 0. 0. ]
[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
1. 0.]
[ 0. 0. 0. 0. 0. 0.
0. 1. 0. 0. 1. 0.
0. 0. 0. 0. 37.73505101
-122.3995877 0. 0. 0. ]
我假设数据的格式对于我的 fit
/predict
调用是正确的,但它在 predict
上出错了。谁能看出我在这里做错了什么?
用于训练模型的输入数据是 train_data[0::,0::20]
,我认为这是一个错误(为什么跳过中间的特征?)——它应该是 train_data[0::,0:20]
而不是基于调试打印你一开始就做了。
此外,最后一列似乎代表了 train_data
和 test_data
中的标签。预测时,您可能希望在调用 predict
函数时传递 test_data[:, :20]
而不是 test_data
。
我正在尝试对我拥有的一些数据使用 RandomForestClassifier。代码如下:
print train_data[0,0:20]
print train_data[0,21::]
print test_data[0]
print 'Training...'
forest = RandomForestClassifier(n_estimators=100)
forest = forest.fit( train_data[0::,0::20], train_data[0::,21::] )
print 'Predicting...'
output = forest.predict(test_data)
但这会产生以下错误:
ValueError: Number of features of the model must match the input. Model n_features is 3 and input n_features is 21
前三个打印语句的输出是:
[ 0. 0. 0. 0. 1. 0.
0. 0. 0. 0. 1. 0.
0. 0. 0. 37.7745986 -122.42589168
0. 0. 0. ]
[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
1. 0.]
[ 0. 0. 0. 0. 0. 0.
0. 1. 0. 0. 1. 0.
0. 0. 0. 0. 37.73505101
-122.3995877 0. 0. 0. ]
我假设数据的格式对于我的 fit
/predict
调用是正确的,但它在 predict
上出错了。谁能看出我在这里做错了什么?
用于训练模型的输入数据是 train_data[0::,0::20]
,我认为这是一个错误(为什么跳过中间的特征?)——它应该是 train_data[0::,0:20]
而不是基于调试打印你一开始就做了。
此外,最后一列似乎代表了 train_data
和 test_data
中的标签。预测时,您可能希望在调用 predict
函数时传递 test_data[:, :20]
而不是 test_data
。