python 问题 ---> 10 个数据 = mnist.data[示例]

python problem ---> 10 data = mnist.data[sample]

我是 python 的新手。我正在尝试了解此 K NN 算法的工作原理 我尝试应用此代码。

from sklearn.datasets import fetch_openml
mnist = fetch_openml('mnist_784', version=1)

print (mnist.data.shape)


print (mnist.target.shape)
import numpy as np
sample = np.random.randint(70000, size=5000)
data = mnist.data[sample]
target = mnist.target[sample]
from sklearn.model_selection import train_test_split

xtrain, xtest, ytrain, ytest = train_test_split(data, target, train_size=0.8)

但它不起作用,它显示错误

(70000, 784)
(70000,)
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-6-3b6254553355> in <module>
      8 import numpy as np
      9 sample = np.random.randint(70000, size=5000)
---> 10 data = mnist.data[sample]
     11 #target = mnist.target[sample]
     12 #from sklearn.model_selection import train_test_split

~\anaconda3\lib\site-packages\pandas\core\frame.py in __getitem__(self, key)
   3028             if is_iterator(key):
   3029                 key = list(key)
-> 3030             indexer = self.loc._get_listlike_indexer(key, axis=1, raise_missing=True)[1]
   3031 
   3032         # take() does not accept boolean indexers

你正在索引 pandas 数据帧,你应该使用 .loc 或 .iloc,正如 here 所指出的那样,而不是你习惯使用 numpy 数组的正常索引,这应该有效:

data = mnist.data.loc[sample]
target = mnist.target.loc[sample]