pyplot做PR-curve

2020-10-10 10:35:59 浏览数 (1)

我们可以用sklearn.metrics中的precision_recall_curve()和auc()计算出PR-AUC,然后用matplotlib.pyplot画出PR-curve:

代码语言:javascript复制
from sklearn.metrics import precision_recall_curve, auc

clf = LogisticRegression()
clf.fit(X_train, y_train)
y_pred_proba = clf.predict_proba(X_test)[::,1]

clf.fit(X_train_e, y_train_e)
y_pred_proba_e = clf.predict_proba(X_test_e)[::,1]

precision, recall, thresholds = precision_recall_curve(y_test,  y_pred_proba)
precision_e, recall_e, thresholds_e = precision_recall_curve(y_test_e,  y_pred_proba_e)

pr_auc = auc(recall, precision)
pr_auc_e = auc(recall_e, precision_e)

plt.plot(recall,precision, color = 'blue')
plt.plot(recall_e,precision_e, color = 'red')
plt.title('Precision/Recall Curve')
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.legend(loc="upper right", labels = ['Full model {}'.format(round(pr_auc, 2)), 'Expression only model {}'.format(round(pr_auc_e, 2))])
plt.show()

值得注意的是,对于特别不平衡的样本,虽然ROC-AUC可能会很好看,但是PR-AUC多半很一般,甚至很不好,上采样和下采样是非常有必要的,另外不要被ROC_AUC所蒙蔽。

欢迎关注!

0 人点赞