提问者:小点点

学习。predict()默认阈值


我正在研究一个不平衡类的分类问题(5%1)。我想预测课程,而不是概率。

在二进制分类问题中,是scikit的分类器。predict()默认使用0.5?如果没有,默认的方法是什么?如果是,我该如何更改它?

在Scikit中,一些分类器有class_weight='auto'选项,但并非所有分类器都有。如果使用class_weight='car'.预测()会使用实际的人口比例作为阈值吗?

在像MultinomialNB这样不支持class_weight的分类器中,有什么方法可以做到这一点?除了使用predict_proba()然后自己计算类。


共3个答案

匿名用户

可使用clf设置阈值。预测概率()

例如:

from sklearn.tree import DecisionTreeClassifier
clf = DecisionTreeClassifier(random_state = 2)
clf.fit(X_train,y_train)
# y_pred = clf.predict(X_test)  # default threshold is 0.5
y_pred = (clf.predict_proba(X_test)[:,1] >= 0.3).astype(bool) # set threshold as 0.3

匿名用户

是scikit的分类器。predict()默认使用0.5?

在概率分类器中,是的。正如其他人所解释的那样,从数学角度来看,这是唯一合理的阈值。

在不支持class\u-weight的多项式nb这样的分类器中,该怎么做?

您可以设置class_prior,这是每个类y的先验概率P(y)。这有效地改变了决策边界。例如。

# minimal dataset
>>> X = [[1, 0], [1, 0], [0, 1]]
>>> y = [0, 0, 1]
# use empirical prior, learned from y
>>> MultinomialNB().fit(X,y).predict([1,1])
array([0])
# use custom prior to make 1 more likely
>>> MultinomialNB(class_prior=[.1, .9]).fit(X,y).predict([1,1])
array([1])

匿名用户

scikit学习中的阈值对于二元分类是0.5,对于多类分类,无论哪个类具有最大的概率。在许多问题中,通过调整阈值可以获得更好的结果。然而,这必须小心进行,而不是对坚持测试数据,而是对培训数据进行交叉验证。如果您对测试数据进行任何阈值调整,那么您只是过度拟合了测试数据。

大多数调整阈值的方法都是基于接收机工作特性(ROC)和Youden的J统计,但也可以通过其他方法进行调整,如使用遗传算法进行搜索。

以下是一篇同行评议杂志文章,描述了在医学中的这一做法:

http://www.ncbi.nlm.nih.gov/pmc/articles/PMC2515362/

据我所知,在Python中没有这样做的包,但是在Python中用暴力搜索找到它相对简单(但是效率低下)。

这是一些R代码。

## load data
DD73OP <- read.table("/my_probabilites.txt", header=T, quote="\"")

library("pROC")
# No smoothing
roc_OP <- roc(DD73OP$tc, DD73OP$prob)
auc_OP <- auc(roc_OP)
auc_OP
Area under the curve: 0.8909
plot(roc_OP)

# Best threshold
# Method: Youden
#Youden's J statistic (Youden, 1950) is employed. The optimal cut-off is the threshold that maximizes the distance to the identity (diagonal) line. Can be shortened to "y".
#The optimality criterion is:
#max(sensitivities + specificities)
coords(roc_OP, "best", ret=c("threshold", "specificity", "sensitivity"), best.method="youden")
#threshold specificity sensitivity 
#0.7276835   0.9092466   0.7559022