磐创AI分享
作者 | Eryk Lewinson 编译 | VK 来源 | Towards Data Science
决策树是一类非常重要的机器学习模型,也是许多更高级算法的组成部分,如随机林或著名的XGBoost。这些树也是基线模型的良好起点,我们随后尝试使用更复杂的算法对其进行改进。
决策树的最大优点之一是它的可解释性——在拟合模型之后,它是一组有效的规则,可以用来预测目标变量。这也是为什么很容易绘制规则并将其展示给涉众,这样他们就可以很容易地理解模型的底层逻辑。当然,只要树不太深。
使用scikitlearn和matplotlib的组合,可视化决策树非常简单。然而,有一个很好的名为dtreeviz的库,它带来了更多内容,可以创建了不仅更漂亮而且能传达更多决策过程信息的可视化效果。
在本文中,我将首先展示绘制决策树的“旧方法”,然后介绍使用dtreeviz的改进方法。
安装程序
一如既往,我们需要从导入所需的库开始。
代码语言:javascript复制import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris, load_boston
from sklearn import tree
from dtreeviz.trees import *
然后,我们从scikit learn加载Iris数据集。我们还将讨论一个回归示例,但稍后将为此加载波士顿住房数据集。
代码语言:javascript复制# 加载数据集
iris = load_iris()
boston = load_boston()
“老办法”
下一步包括创建训练/测试集,并将决策树分类器与iris数据集相匹配。在本文中,我们只关注可视化决策树。因此,我们不注意拟合模型或寻找一组好的超参数(关于这些主题的文章很多)。我们唯一要“调整”的是树的最大深度—我们将其限制为3,这样树仍然可以适应图像并保持可读性。
代码语言:javascript复制# 准备数据
X = iris.data
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# 拟合
clf = tree.DecisionTreeClassifier(max_depth=3, random_state=42)
clf.fit(X_train, y_train)
现在我们有了一个合适的决策树模型,我们可以继续可视化的树。我们从最简单的方法开始-使用scikit learn中的plot_tree函数。
代码语言:javascript复制tree.plot_tree(clf);
好吧,这也不错。但是它的可读性不强,例如,没有特征名称(只有它们的列索引)或类标签。我们可以通过运行以下代码片段轻松地改进这一点。
代码语言:javascript复制tree.plot_tree(clf,
feature_names = iris.feature_names,
class_names=iris.target_names,
rounded=True,
filled = True);
好多了!现在,我们可以很容易地解释决策树。也可以使用graphviz库来可视化决策树,但是,结果非常相似,具有与上图相同的元素集。这就是为什么我们将在这里跳过它。
dtreeviz
在了解了绘制决策树的老方法之后,让我们直接进入dtreeviz方法。
代码语言:javascript复制viz = dtreeviz(clf,
x_data=X_train,
y_data=y_train,
target_name='class',
feature_names=iris.feature_names,
class_names=list(iris.target_names),
title="Decision Tree - Iris data set")
viz
代码片段几乎是不言自明的,因此我们可以继续讨论结果。首先,让我们花一点时间来确认它有多大的改进,特别是考虑到函数调用非常相似。
让我们一步一步地看图表。在每个节点上,我们都可以看到用于分割观测值的特征的堆叠直方图,并按类别着色。
通过这种方式,我们可以看到类是如何通过来分割的。x轴的小三角形是拆分点。在第一个柱状图中,我们可以清楚地看到,所有观察到的刚毛类的花瓣长度都小于2.45厘米。
树的右分支表示选择大于或等于拆分值的值,而左分支表示选择小于拆分值的值。叶节点用饼图表示,饼图显示叶中的观察值属于哪个类。这样,我们就可以很容易地看到哪个类是最主要的,所以也可以看到模型的预测。
在这张图上,我们没有看到的是每个节点的基尼系数。在我看来,柱状图提供了更多关于分割的直观信息,在向利益相关者呈现的情况下,基尼的值可能没有那么重要。
注意:我们也可以为测试集创建一个类似的可视化,我们只需要在调用函数时替换x_data和y_data参数。
如果你不喜欢直方图并且希望简化绘图,可以指定fancy=False来接收以下简化绘图。
dtreeviz的另一个方便的功能是提高模型的可解释性,即在绘图上突出显示特定观测值的路径。通过这种方式,我们可以清楚地看到哪些特征有助于类预测。
使用下面的代码片段,我们突出显示测试集的第一个样本的路径。
代码语言:javascript复制viz = dtreeviz(clf,
x_data=X_train,
y_data=y_train,
target_name='class',
feature_names=iris.feature_names,
class_names=list(iris.target_names),
title="Decision Tree - Iris data set",
#orientation="LR",
X=X_test[0])
viz
这张图与前一张非常相似,然而,橙色突出显示清楚地显示了样本所遵循的路径。此外,我们可以在每个直方图上看到橙色三角形。它表示给定特征的观察值。最后,我们看到了这个样本的所有特征的值,用于决策的特征用橙色突出显示。在这种情况下,只有两个特征被用来预测观察属于花色类。
提示:我们还可以通过设置orientation=“LR”从上到下再从左到右更改绘图的方向。在本文中我们不展示它,因为对于屏幕较窄的设备,图表的缩放效果不会很好。
最后,我们可以用通俗易懂的英语打印这个观察预测所用的决定。为此,我们运行以下命令。
代码语言:javascript复制print(explain_prediction_path(clf, X_test[0],
feature_names=iris.feature_names,
explanation_type="plain_english"))
# 2.45 <= petal length (cm) < 4.75
# petal width (cm) < 1.65
这样,我们就可以清楚地看到这个观察所满足的条件。
回归示例
我们已经介绍了一个分类示例,它显示了库的大多数有趣的功能。但为了完整性起见,我们还讨论了一个回归问题的例子,来说明曲线图是如何不同的。我们使用另一个流行的数据集——波士顿住房数据集。我们使用一组不同的地区来预测波士顿某些地区的房价中值。
代码语言:javascript复制# 准备数据
X = boston.data
y = boston.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# 拟合
reg = tree.DecisionTreeRegressor(max_depth=2, random_state=42)
reg.fit(X_train, y_train)
# 绘图
viz = dtreeviz(reg,
x_data=X_train,
y_data=y_train,
target_name='price',
feature_names=boston.feature_names,
title="Decision Tree - Boston housing",
show_node_labels = True)
viz
代码已经让人感觉很相似了。唯一的变化是我们添加了show_node_labels = True。对于较大的决策树,它尤其方便。
让我们深入研究分类树和回归树之间的区别。这一次,我们不看直方图,而是检查用于分割和目标的特征散点图。在这些散点图上,我们看到一些虚线。其解释如下:
- 水平线是决策节点中左右边的目标平均值。
- 垂直线是分割点。它与黑色三角形表示的信息完全相同。
在叶节点中,虚线表示叶内目标的平均值,这也是模型的预测。
我们已经展示了我们可以突出某个观察的决策路径。我们可以更进一步,只绘制用于预测的节点。为此,我们指定show_just_path=True。下图仅显示上面树中选定的节点。
结论
在本文中,我演示了如何使用dtreeviz库来创建决策树的优雅而有见地的可视化。玩了一段时间之后,我肯定会继续使用它作为可视化决策树的工具。我相信使用这个库创建的图对于那些不经常使用ML的人来说更容易理解,并且可以帮助向涉众传达模型的逻辑。
还值得一提的是,dtreeviz支持XGBoost和Spark MLlib树的一些可视化。
你可以在我的GitHub上找到本文使用的代码:https://github.com/erykml/medium_articles/blob/master/Machine Learning/decision_tree_visualization.ipynb
如果你喜欢这篇文章,你可能还对以下内容之一感兴趣:
https://towardsdatascience.com/improve-the-train-test-split-with-the-hashing-function-f38f32b721fb
https://towardsdatascience.com/lazy-predict-fit-and-evaluate-all-the-models-from-scikit-learn-with-a-single-line-of-code-7fe510c7281
https://towardsdatascience.com/explaining-feature-importance-by-example-of-a-random-forest-d9166011959e
参考引用
https://github.com/parrt/dtreeviz
https://explained.ai/decision-tree-viz/index.html