一文让你了解AI产品的测试 评价人工智能算法模型的几个重要指标(续)

2020-06-10 16:07:09 浏览数 (1)

程序的实现

前面讲课那么多指标,其实在Python里面可以利用sklearn这个插件快速的画出这些指标和算法。利用这个工具之前当然需要下载安装这个插件。

>pip3 install sklearn

下面来讲解一下这个代码。

代码语言:javascript复制

# coding=UTF-8
from sklearn import metrics
from sklearn.metrics import  confusion_matrix
from sklearn.metrics import accuracy_score
from sklearn.metrics import precision_score
from sklearn.metrics import recall_score
from sklearn.metrics import f1_score
import matplotlib.pylab as plt
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from sklearn.metrics import  precision_recall_curve
 
#真实值
GTlist =  [1,1,0,1,1,0,1,0,0,1]
#模型预测值
Problist = [1,0,1,1,1,1,1,1,0,1]
y_true = np.array(GTlist)
y_pred = np.array(Problist)
#混淆矩阵
confusion_matrix =  confusion_matrix(y_true, y_pred)
print("混淆矩阵:")
print(confusion_matrix)
 
#准确性
accuracy = '{:.1%}'.format(accuracy_score(y_true,  y_pred))
print("准确性:",end='')
print(accuracy)
 
 
#精确性
precision =  '{:.1%}'.format(precision_score(y_true, y_pred))
print("精确性:",end='')
print(precision)
 
#召回率
recall =  '{:.1%}'.format(recall_score(y_true, y_pred))
print("召回率:",end='')
print(recall)
 
#F1值
f1score = '{:.1%}'.format(f1_score(y_true,  y_pred))
print("F1值:",end='')
print(f1score)
 
#初始化画图数据
#真实值
GTlist = [1.0, 1.0, 0.0, 1.0, 0.0, 1.0,  0.0, 1.0, 0.0, 1.0,0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0]
#模型预测值
Problist = [0.99, 0.98, 0.97, 0.93, 0.85,  0.80, 0.79, 0.75, 0.70, 0.65,0.64, 0.63, 0.55, 0.54, 0.51, 0.49, 0.30, 0.2,  0.1, 0.09]
fpr, tpr, thresholds =  metrics.roc_curve(GTlist, Problist, pos_label=1)
roc_auc = metrics.auc(fpr, tpr)  #auc为Roc曲线下的面积
print("AUC值:",end='')
print('{:.1%}'.format(roc_auc))
#ROC曲线
plt.plot(fpr, tpr, 'b',label='AUC =  %0.2f'% roc_auc)
plt.legend(loc='lower right')
# plt.plot([0, 1], [0, 1], 'r--')
plt.xlim([-0.1, 1.1])
plt.ylim([-0.1, 1.1])
plt.xlabel('False Positive Rate') #横坐标是fpr
plt.ylabel('True Positive Rate')  #纵坐标是tpr
plt.title('Receiver operating  characteristic example')
plt.show()
 
#P-R曲线 
plt.figure("P-R Curve")
plt.title('Precision/Recall Curve')
plt.xlabel('Recall')
plt.ylabel('Precision')
#y_true为样本实际的类别,y_scores为样本为正例的概率
y_true = np.array(GTlist)
y_scores = np.array(Problist)
precision, recall, thresholds =  precision_recall_curve(y_true, y_scores)
plt.plot(recall,precision)
plt.show()

真实值GTlist =[1,1,0,1,1,0,1,0,0,1]

模型预测值Problist= [1,0,1,1,1,1,1,1,0,1]

现在有10位病人来看病,其中3号、6号、8号和9号病人是没有疾病的(绿色),其他剩余6位有疾病(红色)。

编号

1

2

3

4

5

6

7

8

9

10

实际

1

1

0

1

1

0

1

0

0

1

检查

1

0

1

1

1

1

1

1

0

1

1号、4号、5号、7号和10号病人被查出来(真阳性,红色);2号病人没有被查出来(漏诊,橙色);3号、6号和8号被误诊(误诊,蓝色),另外9号(真隐性,绿色),通过运行这段代码,得到如下结果:

混淆矩阵:

[[1 3]

[1 5]]

准确性:60.0%

精确性:62.5%

召回率:83.3%

F1值:71.4%

我们来验证一下,真阳性:5、真阴性:1、假阳性:3、假阴性:1,所以混淆矩阵为:

预测

实际

1

3

1

5

由此,可以看出算出来的矩阵与正式的矩阵的对应关系。假在前,真在后,一行代表实际中的实际中的一行。

准确性:(1 5)/10=60%

精确性:5/8=62.5%

召回率:5/6=83.3%

F1 Score=62.5%×83.3%×2/(62.5% 83.3%)=1.04125/1.458=71%

可见这些值都是正确的。接下来再看下面的数据。

#真实值

GTlist = [1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0,1.0, 0.0, 1.0,0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0]

#模型预测值

Problist = [0.99, 0.98, 0.97, 0.93, 0.85,0.80, 0.79, 0.75, 0.70, 0.65,0.64, 0.63, 0.55, 0.54, 0.51, 0.49, 0.30, 0.2,0.1, 0.09]

GTlist表示真实样本,1.0代表真样本,0.0代表假样本;

Problist表示预测样本,每个值表示预测到对应真实样本为真的概率。比如第一个0.99表示预测第一个正样本的概率为99%,第三个0.97表示预测第三个假样本的概率为97%。通过运行我们得到如下曲线图。

我们考察A(0,0)、B(1,1)、C(0,1)、D(1,0)四个点:

  • A(0,0):表示真阳率与假阳率均为0,表示什么都没有测试到;
  • B(1,1):表示真阳率与假阳率均为100%;
  • C(0,1):真阳率为100%,假阳率均为0,测试到的全是真的;
  • D(1,0):真阳率为0,假阳率均为100%,测试到的全是假的。

由此可见C点的情况最高,所以曲线越靠近左上角说明算法最好。

另外,上面代码也会给出了化P-R图的方法,对于ROC曲线,采用同一个测试数据,画出来的图如下显示。

软件安全测试

https://study.163.com/course/courseMain.htm?courseId=1209779852&share=2&shareId=480000002205486

接口自动化测试

https://study.163.com/course/courseMain.htm?courseId=1209794815&share=2&shareId=480000002205486

DevOps 和Jenkins之DevOps

https://study.163.com/course/courseMain.htm?courseId=1209817844&share=2&shareId=480000002205486

DevOps与Jenkins 2.0之Jenkins

https://study.163.com/course/courseMain.htm?courseId=1209819843&share=2&shareId=480000002205486

Selenium自动化测试

https://study.163.com/course/courseMain.htm?courseId=1209835807&share=2&shareId=480000002205486

性能测试第1季:性能测试基础知识

https://study.163.com/course/courseMain.htm?courseId=1209852815&share=2&shareId=480000002205486

性能测试第2季:LoadRunner12使用

https://study.163.com/course/courseMain.htm?courseId=1209980013&share=2&shareId=480000002205486

性能测试第3季:JMeter工具使用

https://study.163.com/course/courseMain.htm?courseId=1209903814&share=2&shareId=480000002205486

性能测试第4季:监控与调优

https://study.163.com/course/courseMain.htm?courseId=1209959801&share=2&shareId=480000002205486

Django入门

https://study.163.com/course/courseMain.htm?courseId=1210020806&share=2&shareId=480000002205486

啄木鸟顾老师漫谈软件测试

https://study.163.com/course/courseMain.htm?courseId=1209958326&share=2&shareId=480000002205486

0 人点赞