代码语言:javascript复制KNN决策树探究泰坦尼克号幸存者问题
import pandas as pd
from sklearn.tree import DecisionTreeClassifier, export_graphviz
from sklearn.metrics import classification_report
import graphviz #决策树可视化
代码语言:javascript复制data = pd.read_csv(r"titanic_data.csv")
data.drop("PassengerId",axis = 1,inplace = True) #删除id这一列
代码语言:javascript复制data
Survived | Pclass | Sex | Age | |
---|---|---|---|---|
0 | 0 | 3 | male | 22.0 |
1 | 1 | 1 | female | 38.0 |
2 | 1 | 3 | female | 26.0 |
3 | 1 | 1 | female | 35.0 |
4 | 0 | 3 | male | 35.0 |
... | ... | ... | ... | ... |
886 | 0 | 2 | male | 27.0 |
887 | 1 | 1 | female | 19.0 |
888 | 0 | 3 | female | NaN |
889 | 1 | 1 | male | 26.0 |
890 | 0 | 3 | male | 32.0 |
891 rows × 4 columns
代码语言:javascript复制data.loc[data["Sex"] == "male","Sex"] = 1
data.loc[data["Sex"] == "female","Sex"] = 0
代码语言:javascript复制data
Survived | Pclass | Sex | Age | |
---|---|---|---|---|
0 | 0 | 3 | 1 | 22.0 |
1 | 1 | 1 | 0 | 38.0 |
2 | 1 | 3 | 0 | 26.0 |
3 | 1 | 1 | 0 | 35.0 |
4 | 0 | 3 | 1 | 35.0 |
... | ... | ... | ... | ... |
886 | 0 | 2 | 1 | 27.0 |
887 | 1 | 1 | 0 | 19.0 |
888 | 0 | 3 | 0 | NaN |
889 | 1 | 1 | 1 | 26.0 |
890 | 0 | 3 | 1 | 32.0 |
891 rows × 4 columns
代码语言:javascript复制data.fillna(data["Age"].mean(),inplace = True) #用均值来填充缺失值
代码语言:javascript复制data
Survived | Pclass | Sex | Age | |
---|---|---|---|---|
0 | 0 | 3 | 1 | 22.000000 |
1 | 1 | 1 | 0 | 38.000000 |
2 | 1 | 3 | 0 | 26.000000 |
3 | 1 | 1 | 0 | 35.000000 |
4 | 0 | 3 | 1 | 35.000000 |
... | ... | ... | ... | ... |
886 | 0 | 2 | 1 | 27.000000 |
887 | 1 | 1 | 0 | 19.000000 |
888 | 0 | 3 | 0 | 29.699118 |
889 | 1 | 1 | 1 | 26.000000 |
890 | 0 | 3 | 1 | 32.000000 |
891 rows × 4 columns
代码语言:javascript复制Dtc = DecisionTreeClassifier(max_depth = 5,random_state =8) #构建决策树
Dtc.fit(data.iloc[:,1:],data["Survived"]) #模型训练
pre = Dtc.predict(data.iloc[:,1:]) #模型预测
代码语言:javascript复制print(classification_report(pre,data["Survived"])) #混淆矩阵
代码语言:javascript复制 precision recall f1-score support
0 0.88 0.84 0.86 573
1 0.73 0.79 0.76 318
accuracy 0.82 891
macro avg 0.81 0.82 0.81 891
weighted avg 0.83 0.82 0.82 891
代码语言:javascript复制pre == data["Survived"] #比较模型预测值与实际值是否一致
代码语言:javascript复制0 True
1 True
2 True
3 True
4 True
...
886 True
887 True
888 False
889 False
890 True
Name: Survived, Length: 891, dtype: bool
可视化
代码语言:javascript复制dot_data = export_graphviz(Dtc,feature_names = ["Pclass","Sex","Age"],class_names="Survive")
代码语言:javascript复制graph = graphviz.Source(dot_data)
graph