机器学习-AUC-ROC-python

2021-02-04 11:15:17 浏览数 (2)

简介:

ROC(receiver operating characteristic curve):简称接收者操作特征曲线,是由二战中的电子工程师和雷达工程师发明的,主要用于检测此种方法的准确率有多高。图示:

如下图,其中class 0-5代表6种方法,或者6种手段,横轴为假阳性率,纵轴为真阳性率,越靠近左上方代表此种方法越准确。ROC代表曲线,而AUC代表一条曲线与下方以及右侧轴形成的面积。如果某种方法的准确率为100%,则AUC=1×1=1,AUC的区间在0-1之间,越大越好。

####实现:

>用一个官方例子来实现,鸢尾花的相关的分类问题,使用SVM来判断鸢尾花属于哪一类

>样本:鸢尾花分为三类,‘setosa’, ‘versicolor’, ‘virginica。4个特征:sepal length (cm)’, ‘sepal width (cm)’, ‘petal length (cm)’, ‘petal width (cm)’ 150个样本,每类50个

环境:python3 conda sklearn 上述都是主要的

步骤:

代码语言:javascript复制
conda install  sklearn 
#安装sklearn
ipython qtconsole
#启动ipython的IDE
import numpy as np
import matplotlib.pyplot as plt
from itertools import cycle
from sklearn import svm, datasets
from sklearn.metrics import roc_curve
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import label_binarize
from sklearn.multiclass import OneVsRestClassifier
from scipy import interp
#导入各种包,作用下面会说
iris = datasets.load_iris()
X = iris.data
y = iris.target
#导入鸢尾花的数据集,并且设定X和y,X指的是各种特征的数据,y指的是分类结果。他们均是np.array形式。
y = label_binarize(y, classes=[0, 1, 2])
n_classes = y.shape[1]
#n_classes为有几种分类,这里的n_classes为3
random_state = np.random.RandomState(0)
#设置随机数
n_samples, n_features = X.shape
X = np.c_[X, random_state.randn(n_samples, 200 * n_features)]
#这里设置样本以及特征,n_samples为150,n_features为4

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5,  random_state=0)
#将数据集分为训练集和测试集,比例为1:1
classifier = OneVsRestClassifier(svm.SVC(kernel='linear', probability=True, random_state=random_state))
#设置一个svm的分类器
y_score = classifier.fit(X_train,y_train).decision_function(X_test)
#在数据集上运行,通过decision_function()计算得到的y_score的值,用在roc_curve()函数
y_score = classifier.fit(X_train, y_train).decision_function(X_test)
# 计算ROC
fpr = dict()
tpr = dict()
roc_auc = dict()
for i in range(n_classes):
    fpr[i], tpr[i], _ = roc_curve(y_test[:, i], y_score[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])
fpr["micro"], tpr["micro"], _ = roc_curve(y_test.ravel(), y_score.ravel())
roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])
#显示到当前界面,保存为svm.png
plt.figure()
lw = 2
plt.plot(fpr[2], tpr[2], color='darkorange',
         lw=lw, label='ROC curve (area = %0.2f)' % roc_auc[2])
plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver operating characteristic example')
plt.legend(loc="lower right")
plt.show()
plt.savefig('svm.png')

结果:

0 人点赞