通过 sklearn 加载数据集
在 scikit-learn 的 datasets 模块中,包含很多机器学习和统计学中的经典数据集。
代码语言:javascript复制import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from sklearn import datasets
我们可以调用 load_iris
函数来加载鸢尾花数据集。
iris = datasets.load_iris()
调用 load_iris
函数返回的 iris 是 Bunch 对象,Bunch 对象是 sklearn 对数据集进行进一步封装的数据类型。Bunch 和 Python 内置的字典非常类似,也包含建和值。
print(iris.keys())
'''
dict_keys(['data', 'target', 'frame', 'target_names', 'DESCR', 'feature_names', 'filename'])
'''
其中:
- data - 包含花萼长度、花萼宽度、花瓣长度、花瓣宽度的测量数据, 形状为 (150, 4) 的 ndarray 数组(NumPy 数组)。如果加载数据集时设置参数
as_frame = True
,则返回的是 DataFrame(Pandas 中的数据结构)。 - target - 包含每朵鸢尾花的品种,形状为 (150, ) 的 ndarray 数组(NumPy 数组)。如果加载数据集时设置参数
as_frame = True
,则返回的是 Series(Pandas 中的数据结构)。 - frame - 返回形状为 (150, 5) 的 DataFram,只有当加载数据集时指定
as_frame = True
,才会返回包含 data 和 target 的 DataFram; - target_names - 类别标签的名字,包括 setosa, versicolor 和 virginica 三种鸢尾花类别;
- feature_names - 对特征的说明,一共有四个特征,分别是花萼长度、花萼宽度、花瓣长度、花瓣宽度;
- DESCR - 对数据集的简要说明;
- filename - 加载的数据集在本地保存的文件路径。
DESCR 属性可以查看数据集的简要说明。
代码语言:javascript复制print(iris.DESCR)
'''
.. _iris_dataset:
Iris plants dataset
--------------------
Many, many more ...
'''
data 属性返回的是形状为 (150, 4) 的二维数组,每一行代表一朵鸢尾花(一个样本),每一列表示鸢尾花的一个属性(一个特征)。
代码语言:javascript复制print(iris.data.shape)
'''
(150, 4)
'''
feature_names 属性可以查看 4 个特征具体的含义,也就是上面提到的花萼长度、花萼宽度、花瓣长度、花瓣宽度。
代码语言:javascript复制print(iris.feature_names)
'''
['sepal length (cm)'
, 'sepal width (cm)'
, 'petal length (cm)'
, 'petal width (cm)']
'''
target 属性返回形状为 (150, ) 的一维数组,每一行代表一朵鸢尾花,对应的值为鸢尾花所属的类别。target_names 属性可以查看类别的具体含义,也就是上面提到的 setosa, versicolor 和 virginica 三种鸢尾花类别
代码语言:javascript复制print(iris.target.shape)
'''
(150,)
'''
print(iris.target_names)
'''
['setosa' 'versicolor' 'virginica']
'''
简单的数据探索
好的数据能够帮助我们更好的泛化机器学习模型,所以在构建机器学习模型之前,通常需要对数据进行检查和探索。通过可视化的方式来检查和探索数据是机器学习中比较常用的方法。
对于分类问题,通常会绘制散点图,将其中一个特征作为横坐标轴,将另一个特征作为纵坐标轴,而将样本的类别用不同颜色或样式进行区分。鸢尾花数据集一共有 4 个特征,为了使用散点图进行可视化,选择前两个特征进行可视化。
代码语言:javascript复制X = iris.data[:, :2] # 选取样本的前两个特征
y = iris.target
plt.scatter(X[y == 0, 0], X[y == 0, 1], color = "red")
plt.scatter(X[y == 1, 0], X[y == 1, 1], color = "blue")
plt.scatter(X[y == 2, 0], X[y == 2, 1], color = "green")
plt.show()
这里的 X[y == 0, 0]
使用了前面介绍的 Fancy Indexing 和比较运算,其中:
y == 0
返回一个形状为 (150, ) 的布尔数组,X 和 y 相同位置表示的是相同的鸢尾花,如果鸢尾花的类别为第一类,则对应位置为 True,否则为 False;X[y == 0]
使用的是 Fancy Indexing,通过布尔数组筛选出类别为第一类的鸢尾花的测量数据;X[y == 0, 0]
筛选出类别为第一类的鸢尾花的测量数据的第一个特征。
marker 参数可以指定散点图中点的样式,更多样式可以查看 官网文档。
代码语言:javascript复制plt.scatter(X[y == 0, 0], X[y == 0, 1], color = "red", marker = "o")
plt.scatter(X[y == 1, 0], X[y == 1, 1], color = "blue", marker = " ")
plt.scatter(X[y == 2, 0], X[y == 2, 1], color = "green", marker = "x")
plt.show()
这里只使用鸢尾花的前两个特征,通过绘制的散点图可以看出红色的第一类和其余的蓝色第二类和绿色的第三类有着比较明显的边界线,但是蓝色第二类和绿色的第三类之间没有明显的边界限,不容易区分开。
如果使用鸢尾花数据集的后两个特征。
代码语言:javascript复制X = iris.data[:, 2:] # 选取样本的后两个特征
plt.scatter(X[y == 0, 0], X[y == 0, 1], color = "red", marker = "o")
plt.scatter(X[y == 1, 0], X[y == 1, 1], color = "blue", marker = " ")
plt.scatter(X[y == 2, 0], X[y == 2, 1], color = "green", marker = "x")
plt.show()
通过绘制结果可以看出,红色第一类的与其余两类依然比较容易区分,虽然蓝色第二类和绿色第三类之间依然不太容易区分,但是相比较使用前两个特征有了明显的改善。
如果特征不是太多,我们可以绘制散点图矩阵对所有特征进行两两的可视化,虽然依然无法查看两个以上特征之间的联系,但是散点图矩阵也能够帮助我们更好的了解数据。
在 Pandas 中,scatter_matrix 函数能够绘制散点图矩阵。不过要使用 scatter_matrix 函数,我们需要将 ndarray 数组(NumPy 数组)转换为 DataFrame,或者在加载鸢尾花数据集的时候直接指定 as_frame = True
,并使用 frame 属性返回鸢尾花数据集的 DataFrame 类型。
import pandas as pd
from pandas.plotting import scatter_matrix
import mglearn
from sklearn import datasets
iris = datasets.load_iris(as_frame = True)
iris_dataframe = iris.frame # 返回DataFrame
iris_data = iris_dataframe.loc[:, ('sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)')]
iris_target = iris_dataframe.loc[:, ('target')]
代码语言:javascript复制grr = scatter_matrix(iris_data
, c=iris_target
, figsize=(15, 15)
, marker='o'
, hist_kwds={'bins': 20}
, s=60
, alpha=.8
, cmap=mglearn.cm3)
References:
Python3入门机器学习 经典算法与应用: https://coding.imooc.com/class/chapter/169.html#Anchor