sklearn's `RandomizedSearchCV` not working with `np.random.RandomState`

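I am trying to optimize a pipeline and wanted to give RandomizedSearchCV an np.random.RandomState object. I can't get it to work, although I can give it other distributions.

Is there a special syntax I can use to give RandomizedSearchCV an np.random.RandomState(0).uniform(0.1, 1.0)?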

from scipy import stats
import numpy as np
from sklearn.neighbors import KernelDensity
# sklearn.grid_search was removed in scikit-learn 0.20; the class now lives in model_selection
from sklearn.model_selection import RandomizedSearchCV

# Generate data
x = np.random.normal(5,1,size=int(1e3))

# Make model
model = KernelDensity()

# Gridsearch for best params
# This one works
search_params = RandomizedSearchCV(model, param_distributions={"bandwidth":stats.uniform(0.1, 1)}, n_iter=30, n_jobs=2)
search_params.fit(x[:, None])

# RandomizedSearchCV(cv=None, error_score='raise',
#           estimator=KernelDensity(algorithm='auto', atol=0, bandwidth=1.0, breadth_first=True,
#        kernel='gaussian', leaf_size=40, metric='euclidean',
#        metric_params=None, rtol=0),
#           fit_params={}, iid=True, n_iter=30, n_jobs=2,
#           param_distributions={'bandwidth': <scipy.stats._distn_infrastructure.rv_frozen object at 0x106ab7da0>},
#           pre_dispatch='2*n_jobs', random_state=None, refit=True,
#           scoring=None, verbose=0)

# This one doesn't work :(
search_params = RandomizedSearchCV(model, param_distributions={"bandwidth":np.random.RandomState(0).uniform(0.1, 1)}, n_iter=30, n_jobs=2)
# TypeError: object of type 'float' has no len()

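What you observe is expected: the uniform method of an np.random.RandomState() object draws a sample immediately when it is called, so param_distributions receives a single float rather than a distribution.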

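In contrast, your use of scipy's stats.uniform() creates a distribution that has not been sampled from yet. (I'm not sure it works the way you expect, though; note the parameters. stats.uniform(loc, scale) samples from [loc, loc + scale], so stats.uniform(0.1, 1) covers [0.1, 1.1] rather than [0.1, 1.0].)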
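
To make the difference concrete, here is a small sketch comparing the two calls. It only uses numpy and scipy; the variable names (`drawn`, `frozen`, `intended`) are mine, and the `scale=0.9` variant is my assumption about the range that was actually intended.

```python
import numpy as np
from scipy import stats

# RandomState.uniform draws immediately and returns a single number.
drawn = np.random.RandomState(0).uniform(0.1, 1.0)
print(type(drawn), hasattr(drawn, "rvs"))

# stats.uniform(loc, scale) returns a frozen distribution; nothing is drawn
# until rvs() is called.
frozen = stats.uniform(0.1, 1)
print(type(frozen).__name__, hasattr(frozen, "rvs"))

# The support is [loc, loc + scale]; ppf(0) and ppf(1) give its end points.
low, high = frozen.ppf([0, 1])
print(low, high)

# Assumption: if [0.1, 1.0] was the intended range, the scale would be 0.9.
intended = stats.uniform(loc=0.1, scale=0.9)
samples = intended.rvs(size=1000, random_state=0)
print(samples.min(), samples.max())
```

The float has no rvs method and no length, which is why RandomizedSearchCV raises the TypeError shown in the question.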

If you want to incorporate something based on np.random.RandomState(), you will have to build your own class, as described in the docs:

This example uses the scipy.stats module, which contains many useful distributions for sampling parameters, such as expon, gamma, uniform or randint. In principle, any function can be passed that provides a rvs (random variate sample) method to sample a value. A call to the rvs function should provide independent random samples from possible parameter values on consecutive calls.
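
A minimal sketch of such a class might look like the following. `RandomStateUniform` is a hypothetical name of my own, not something scikit-learn or numpy provides, and the data and `n_iter` are shrunk so it runs quickly. Current scikit-learn versions call `rvs(random_state=...)` on each distribution, so the method accepts that keyword; here it is deliberately ignored so that all draws come from the wrapped RandomState.

```python
import numpy as np
from sklearn.model_selection import RandomizedSearchCV
from sklearn.neighbors import KernelDensity


class RandomStateUniform:
    """Hypothetical helper: uniform on [low, high) backed by its own RandomState."""

    def __init__(self, low, high, seed=None):
        self.low = low
        self.high = high
        self._rng = np.random.RandomState(seed)

    def rvs(self, size=None, random_state=None):
        # random_state is accepted because scikit-learn passes it, but it is
        # ignored: every draw comes from the RandomState created in __init__.
        return self._rng.uniform(self.low, self.high, size=size)


# Generate data (smaller than in the question, to keep the example fast)
x = np.random.RandomState(42).normal(5, 1, size=200)

search = RandomizedSearchCV(
    KernelDensity(),
    param_distributions={"bandwidth": RandomStateUniform(0.1, 1.0, seed=0)},
    n_iter=5,
)
search.fit(x[:, None])
print(search.best_params_)
```

If reproducibility is the only goal, a custom class may not be needed: RandomizedSearchCV has its own random_state parameter (visible in the repr above), which is used when sampling from scipy.stats distributions.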