StratifiedShuffleSplit 报告 n_iter 的多个参数

StratifiedShuffleSplit reporting multiple args for n_iter

我正在尝试使用 scikit-learn 的 StratifiedShuffleSplit 对我的数据集进行一次拆分,以保留 class 样本比率。

from sklearn.datasets import load_files
from sklearn.model_selection import StratifiedShuffleSplit
dataset = load_files('reviews/aggregated/')
split = StratifiedShuffleSplit(dataset.target, n_iter=1, test_size=0.2)
train_idx, test_idx = next(iter(split))
train_X, train_y = dataset.data[train_idx], dataset.target[train_idx]
test_X, test_y = dataset.data[test_idx], dataset.target[test_idx]

这给了我以下错误:

TypeError: __init__() got multiple values for keyword argument 'n_iter'

但我显然只为它传递了一个值。 StratifiedShuffleSplit 是否与数据集不兼容?文档似乎没有答案

原来 the documentation 已经过时了。查看文档字符串,我发现正确的做法是:

sss = StratifiedShuffleSplit(n_iter=1, test_size=0.2)
train_idx, test_idx = next(sss.split(dataset.data, dataset.target))