MNIST数据集:二分类问题
MNIST数据集是一组由美国高中生和人口调查局员工手写的70,000个数字的图片,每张图片上面有代表的数字标记。
这个数据集被广泛使用,被称之为机器学习领域的“Hello World”,主要是被用于分类问题。本文是对MNIST数据集执行一个二分类的建模
关键词:随机梯度下降、二元分类、混淆矩阵、召回率、精度、性能评估
导入数据
在这里是将一份存放在本地的mat文件的数据导进来:
In [1]:
代码语言:javascript复制import pandas as pd
import numpy as np
import scipy.io as si
# from sklearn.datasets import fetch_openml
In [2]:
代码语言:javascript复制mnist = si.loadmat('mnist-original.mat')
In [3]:
代码语言:javascript复制type(mnist) # 查看数据类型
Out[3]:
代码语言:javascript复制dict
In [4]:
代码语言:javascript复制mnist.keys()
Out[4]:
代码语言:javascript复制dict_keys(['__header__', '__version__', '__globals__', 'mldata_descr_ordering', 'data', 'label'])
我们发现导进来的数据是一个字典。其中data和label两个键的值就是我们想要的特征和标签数据
创建特征和标签
In [5]:
代码语言:javascript复制# 修改1:一定要转置
X, y = mnist["data"].T, mnist["label"].T
X.shape
Out[5]:
代码语言:javascript复制(70000, 784)
总共是70000张图片,每个图片中有784个特征。图片是28*28的像素,所以每个特征代表一个像素点,取值从0-255。
In [6]:
代码语言:javascript复制y.shape
Out[6]:
代码语言:javascript复制(70000, 1)
In [7]:
代码语言:javascript复制y # 每个图片有个专属的数字
Out[7]:
代码语言:javascript复制array([[0.],
[0.],
[0.],
...,
[9.],
[9.],
[9.]])
显示一张图片
In [8]:
代码语言:javascript复制import matplotlib as mpl
import matplotlib.pyplot as plt
one_digit = X[0]
one_digit_image = one_digit.reshape(28, 28)
plt.imshow(one_digit_image, cmap="binary")
plt.axis("off")
plt.show()
In [9]:
代码语言:javascript复制y[0] # 真实的标签的确是0
Out[9]:
代码语言:javascript复制array([0.]) # 结果是0
标签类型转换
元数据中标签是字符串,我们需要转成整数类型
In [10]:
代码语言:javascript复制y.dtype
Out[10]:
代码语言:javascript复制dtype('<f8')
In [11]:
代码语言:javascript复制y = y.astype(np.uint8)
创建训练集和测试集
前面的6万条是训练集,后面的1万条是测试集
In [12]:
代码语言:javascript复制X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]
二元分类器
比如现在有1张图片,显示是0,我们识别是:“0和非0”,两种情形即可,这就是简单的二元分类问题
In [13]:
代码语言:javascript复制y_train_0 = (y_train == 0) # 挑选出5的部分
y_test_0 = (y_test == 0)
随机梯度下降分类器SGD
使用scikit-learn自带的SGDClassifier分类器:能够处理非常大型的数据集,同时SGD适合在线学习
In [14]:
代码语言:javascript复制from sklearn.linear_model import SGDClassifier
sgd_c = SGDClassifier(random_state=42) # 设置随机种子,保证运行结果相同
sgd_c.fit(X_train, y_train_0)
/Applications/downloads/anaconda/anaconda3/lib/python3.7/site-packages/sklearn/utils/validation.py:993: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().
y = column_or_1d(y, warn=True)
Out[14]:
代码语言:javascript复制SGDClassifier(random_state=42)
结果验证
在这里我们检查下数字0的图片:结果为True
In [15]:
代码语言:javascript复制sgd_c.predict([one_digit]) # one_digit是0,非5 表示为False
Out[15]:
代码语言:javascript复制array([ True])
性能测量1-交叉验证
一般而言,分类问题的评估比回归问题要困难的多。
自定义交差验证(优化)
- 每个折叠由StratifiedKFold执行分层抽样,产生的每个类别中的比例符合原始数据中的比例
- 每次迭代会创建一个分类器的副本,用训练器对这个副本进行训练,然后测试集进行测试
- 最后预测出准确率,输出正确的比例
In [16]:
代码语言:javascript复制# K折交叉验证
from sklearn.model_selection import StratifiedKFold
# 用于生成分类器的副本
from sklearn.base import clone
# 实例化对象
k_folds = StratifiedKFold(
n_splits = 3, # 3折
shuffle=True, # add 一定要设置shuffle才能保证random_state生效
random_state=42
)
# 每个折叠由StratifiedKFold执行分层抽样
for train_index, test_index in k_folds.split(X_train, y_train_0):
# 分类器的副本
clone_c = clone(sgd_c)
X_train_folds = X_train[train_index] # 训练集的索引号
y_train_folds = y_train_0[train_index]
X_test_fold = X_train[test_index] # 测试集的索引号
y_test_fold = y_train_0[test_index]
clone_c.fit(X_train_folds, y_train_folds) # 模型训练
y_pred = clone_c.predict(X_test_fold) # 预测
n_correct = sum(y_pred == y_test_fold) # 预测准确的数量
print(n_correct / len(y_pred)) # 预测准确的比例
运行的结果如下:
代码语言:javascript复制[0.09875 0.09875 0.09875 ... 0.90125 0.90125 0.90125]
[0.0987 0.0987 0.0987 ... 0.9013 0.9013 0.9013]
[0.0987 0.0987 0.0987 ... 0.9013 0.9013 0.9013]
scikit_learn的交叉验证
使用cross_val_score来评估分类器:
In [17]:
代码语言:javascript复制# 评估分类器的效果
from sklearn.model_selection import cross_val_score
cross_val_score(sgd_c, # 模型
X_train, # 数据集
y_train_0,
cv=3, # 3折
scoring="accuracy" # 准确率
)
# 结果
array([0.98015, 0.95615, 0.9706 ])
可以看到准确率已经达到了95%以上,效果是相当的可观
自定义一个“非0”的简易分类器,看看效果:
In [18]:
代码语言:javascript复制from sklearn.base import BaseEstimator # 基分类器
class Never0Classifier(BaseEstimator):
def fit(self, X, y=None):
return self
def predict(self, X):
return np.zeros((len(X), 1), dtype=bool)
In [19]:
代码语言:javascript复制never_0_clf = Never0Classifier()
cross_val_score(
never_0_clf, # 模型
X_train, # 训练集样本
y_train_0, # 训练集标签
cv=3, # 折数
scoring="accuracy"
)
Out[19]:
代码语言:javascript复制array([0.70385, 1. , 1. ])
In [20]:
统计数据中每个字出现的次数:
代码语言:javascript复制pd.DataFrame(y).value_counts()
Out[20]:
代码语言:javascript复制1 7877
7 7293
3 7141
2 6990
9 6958
0 6903
6 6876
8 6825
4 6824
5 6313
dtype: int64
In [21]:
代码语言:javascript复制6903 / 70000
Out[21]:
下面显示大约有10%的概率是0这个数字
代码语言:javascript复制0.09861428571428571
In [22]:
代码语言:javascript复制(0.70385 1 1) / 3
Out[22]:
代码语言:javascript复制0.9012833333333333
可以看到判断“非0”准确率基本在90%左右,因为只有大约10%的样本是属于数字0。
所以如果猜测一张图片是非0,大约90%的概率是正确的。
性能测量2-混淆矩阵
预测结果
评估分类器性能更好的方法是混淆矩阵,总体思路是统计A类别实例被划分成B类别的次数
混淆矩阵是通过预测值和真实目标值来进行比较的。
cross_val_predict函数返回的是每个折叠的预测结果,而不是评估分数
In [23]:
代码语言:javascript复制from sklearn.model_selection import cross_val_predict
y_train_pred = cross_val_predict(
sgd_c, # 模型
X_train, # 特征训练集
y_train_0, # 标签训练集
cv=3 # 3折
)
y_train_pred
Out[23]:
代码语言:javascript复制array([ True, True, True, ..., False, False, False])
混淆矩阵
In [24]:
代码语言:javascript复制# 导入混淆矩阵
from sklearn.metrics import confusion_matrix
confusion_matrix(y_train_0, y_train_pred)
Out[24]:
代码语言:javascript复制array([[52482, 1595],
[ 267, 5656]])
混淆矩阵中:行表示实际类别,列表示预测类别
- 第一行表示“非0”:52482张被正确地分为“非0”(真负类),有1595张被错误的分成了“0”(假负类)
- 第二行表示“0”:267被错误地分为“非0”(假正类),有5656张被正确地分成了“0”(真正类)
In [25]:
代码语言:javascript复制# 假设一个完美的分类器:只存在真正类和真负类,它的值存在于对角线上
y_train_perfect_predictions = y_train_0
confusion_matrix(y_train_0, y_train_perfect_predictions)
Out[25]:
代码语言:javascript复制array([[54077, 0],
[ 0, 5923]])
精度和召回率
召回率的公式为:
混淆矩阵显示的内容:
- 左上:真负
- 右上:假正
- 左下:假负
- 右下:真正
精度:正类预测的准确率
召回率(灵敏度或真正类率):分类器正确检测到正类实例的比例
计算精度和召回率
In [26]:
代码语言:javascript复制from sklearn.metrics import precision_score, recall_score
precision_score(y_train_0, y_train_pred) # 精度
Out[26]:
代码语言:javascript复制0.78003034064267
In [27]:
代码语言:javascript复制recall_score(y_train_0, y_train_pred) # 召回率
Out[27]:
代码语言:javascript复制0.9549214924869154
F_1系数
F_1系数是精度和召回率的谐波平均值。只有当召回率和精度都很高的时候,分类器才会得到较高的F_1分数