Scikit-learn:覆盖 class 中的 class 方法

Scikit-learn: overriding a class method in a classifier

我正在尝试覆盖 classifier class 的 predict_proba 方法。据我所知,最简单的方法 seen 如果适用的话,是对基本 class 方法的输入进行预处理或对其输出进行后处理。

class RandomForestClassifierWrapper(RandomForestClassifier):

    def predict_proba(self, X):
        pre_process(X)
        ret = super(RandomForestClassifierWrapper, self).predict_proba(X)
        return post_process(ret)

但是,我想做的是复制一个变量,该变量是在基础 class 方法中本地创建的,在方法 returns 时进行处理和垃圾收集。我打算处理存储在这个变量中的中间结果。有没有一种直接的方法可以做到这一点而不会弄乱 base class internals?

无法从外部访问方法的局部变量。由于您拥有基本分类器的代码,您可以做的是通过从基本分类器复制代码并根据需要处理局部变量来覆盖 predict_proba 方法。

尝试覆盖:

class RandomForestClassifierWrapper(RandomForestClassifier):

    def predict_proba(self, X):
            check_is_fitted(self, 'n_outputs_')

            # Check data
            X = check_array(X, dtype=DTYPE, accept_sparse="csr")

            # Assign chunk of trees to jobs
            n_jobs, n_trees, starts = _partition_estimators(self.n_estimators,
                                                            self.n_jobs)

            # Parallel loop
            all_proba = Parallel(n_jobs=n_jobs, verbose=self.verbose,
                                 backend="threading")(

            # do something with all_proba

            return all_proba