Multiclass Classification of Iris Flowers with sklearn's LogisticRegression


The iris (Chinese: 鸢尾花, yuān wěi huā), also known as the blue butterfly, purple butterfly, or flat bamboo flower, belongs to a genus of roughly 300 species native to central China and Japan, and it is the national flower of France. Irises are mainly blue-purple and enjoy the nickname "blue enchantress"; the flower is named for its petals, shaped like the tail of a kite (the bird 鸢). It comes in blue, purple, yellow, white, and red, and its English name "iris" is transliterated into Chinese as "爱丽丝" (Alice).

This post uses sklearn's logistic regression model to run multiclass prediction on the iris dataset and compares the predictions obtained under the OvR and OvO multiclass strategies.

1. Problem Description

  • Given the iris feature dataset (sepal and petal lengths and widths)
  • Predict which species each sample belongs to (Setosa, Versicolor, Virginica)
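
Before stepping through the details, here is a minimal end-to-end sketch of the task (the parameter choices are explained in the sections below):

from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

X, y = load_iris(return_X_y=True)  # 150 samples, 4 features, 3 species
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=777)
clf = LogisticRegression(multi_class='multinomial', solver='newton-cg')
clf.fit(X_train, y_train)
print(clf.score(X_test, y_test))  # mean test accuracy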

2. Dataset Overview

from sklearn import datasets
iris = datasets.load_iris()
print(dir(iris))    # list the attributes and methods of the dataset object
# ['DESCR', 'data', 'feature_names', 'filename', 'target', 'target_names']

The dataset object exposes quite a few attributes and methods; let's look at them in turn:

2.1 Data description

print(iris.DESCR)   # dataset description
  • The dataset contains 150 samples (50 per species)
  • Each sample carries 4 size features (sepal/petal length and width) plus its class
  • For each feature, the description lists the range, mean, standard deviation, and class correlation
  • It also records whether values are missing, the creator, date, source, typical uses, and references
.. _iris_dataset:

Iris plants dataset
--------------------

**Data Set Characteristics:**

    :Number of Instances: 150 (50 in each of three classes)
    :Number of Attributes: 4 numeric, predictive attributes and the class
    :Attribute Information:
        - sepal length in cm
        - sepal width in cm
        - petal length in cm
        - petal width in cm
        - class:
                - Iris-Setosa
                - Iris-Versicolour
                - Iris-Virginica
                
    :Summary Statistics:

    ============== ==== ==== ======= ===== ====================
                    Min  Max   Mean    SD   Class Correlation
    ============== ==== ==== ======= ===== ====================
    sepal length:   4.3  7.9   5.84   0.83    0.7826
    sepal width:    2.0  4.4   3.05   0.43   -0.4194
    petal length:   1.0  6.9   3.76   1.76    0.9490  (high!)
    petal width:    0.1  2.5   1.20   0.76    0.9565  (high!)
    ============== ==== ==== ======= ===== ====================

    :Missing Attribute Values: None
    :Class Distribution: 33.3% for each of 3 classes.
    :Creator: R.A. Fisher
    :Donor: Michael Marshall (MARSHALL%PLU@io.arc.nasa.gov)
    :Date: July, 1988

The famous Iris database, first used by Sir R.A. Fisher. The dataset is taken
from Fisher's paper. Note that it's the same as in R, but not as in the UCI
Machine Learning Repository, which has two wrong data points.

This is perhaps the best known database to be found in the
pattern recognition literature.  Fisher's paper is a classic in the field and
is referenced frequently to this day.  (See Duda & Hart, for example.)  The
data set contains 3 classes of 50 instances each, where each class refers to a
type of iris plant.  One class is linearly separable from the other 2; the
latter are NOT linearly separable from each other.

.. topic:: References

   - Fisher, R.A. "The use of multiple measurements in taxonomic problems"
     Annual Eugenics, 7, Part II, 179-188 (1936); also in "Contributions to
     Mathematical Statistics" (John Wiley, NY, 1950).
   - Duda, R.O., & Hart, P.E. (1973) Pattern Classification and Scene Analysis.
     (Q327.D83) John Wiley & Sons.  ISBN 0-471-22361-1.  See page 218.
   - Dasarathy, B.V. (1980) "Nosing Around the Neighborhood: A New System
     Structure and Classification Rule for Recognition in Partially Exposed
     Environments".  IEEE Transactions on Pattern Analysis and Machine
     Intelligence, Vol. PAMI-2, No. 1, 67-71.
   - Gates, G.W. (1972) "The Reduced Nearest Neighbor Rule".  IEEE Transactions
     on Information Theory, May 1972, 431-433.
   - See also: 1988 MLC Proceedings, 54-64.  Cheeseman et al"s AUTOCLASS II
     conceptual clustering system finds 3 classes in the data.
   - Many, many more ...

2.2 Data

print(iris.data)	# feature data
# 150 rows × 4 columns, <class 'numpy.ndarray'>
print(iris.feature_names)	# feature names
# ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
print(iris.filename)	# file path
C:\Users\***\AppData\Roaming\Python\Python37\site-packages\sklearn\datasets\data\iris.csv
print(iris.target)	# class labels, length 150
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2
 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
 2 2]
print(iris.target_names)	# names of the 3 iris species
# ['setosa' 'versicolor' 'virginica']
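
As a one-line sanity check of the class balance stated in DESCR (a sketch):

import numpy as np
print(np.bincount(iris.target))  # [50 50 50]: 50 samples per species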

2.3 Data visualization

Since a plane can only display 2 feature dimensions, we pick 2 of the 4 features to plot.

import matplotlib.pyplot as plt
from sklearn import datasets

def show_data_set(X, y, data):
    plt.plot(X[y == 0, 0], X[y == 0, 1], 'rs', label=data.target_names[0])
    plt.plot(X[y == 1, 0], X[y == 1, 1], 'bx', label=data.target_names[1])
    plt.plot(X[y == 2, 0], X[y == 2, 1], 'go', label=data.target_names[2])
    plt.xlabel(data.feature_names[0])
    plt.ylabel(data.feature_names[1])
    plt.title("鸢尾花2维数据")  # "iris data in 2 feature dimensions"
    plt.legend()
    plt.rcParams['font.sans-serif'] = 'SimHei'  # render the Chinese title correctly
    plt.show()

iris = datasets.load_iris()
# print(dir(iris))    # attributes and methods of the dataset object
# print(iris.data)    # feature data
# print(iris.DESCR)   # dataset description
X = iris.data[:, :2]  # first 2 columns: the sepal features (a plane shows only 2 dims)
# X = iris.data[:, 2:4]   # the 2 petal features
# X = iris.data  # all 4 features
y = iris.target  # class labels
show_data_set(X, y, iris)

3. Model Selection

Related posts by the author:

  • The logistic regression (LR) model
  • Binary classification with sklearn's LogisticRegression

sklearn's multiclass and multilabel algorithms:

  • Multiclass classification means a classification task over more than two classes, e.g. classifying images of oranges, apples, and pears. It assumes each sample has exactly one label: a fruit can be an apple or a pear, but not both at once.
  • Inherently multiclass: sklearn.linear_model.LogisticRegression (setting multi_class="multinomial")
  • One-vs-rest multiclass: sklearn.linear_model.LogisticRegression (setting multi_class="ovr")

Classifier strategies:

  • One-vs-the-rest (OvR), also known as one-vs-all, implemented by the OneVsRestClassifier module. The strategy fits one classifier per class, and each classifier separates its class from all the others. Besides computational efficiency (only n_classes classifiers are needed), one advantage is interpretability: since each class is represented by exactly one classifier, inspecting that classifier tells you about that class. This is the most commonly used strategy and a fair default choice.
  • One-vs-one (OvO), implemented by OneVsOneClassifier, fits one classifier per pair of classes; at prediction time, the class that receives the most votes is selected. In the event of a tie (two classes with an equal number of votes), it selects the class with the highest aggregate classification confidence, obtained by summing the pairwise confidence levels computed by the underlying binary classifiers. Since it trains n_classes * (n_classes - 1) / 2 classifiers, the O(n_classes^2) complexity makes this method usually slower than one-vs-the-rest. It can still pay off, e.g. with kernel methods that scale poorly in n_samples: each individual learning problem only involves a small subset of the data, whereas one-vs-the-rest uses the complete dataset n_classes times. OvO also tends to be more accurate than OvR. (A sketch contrasting the two classifier counts follows this list.)
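
To make the classifier counts concrete, here is a minimal sketch (the variable names are mine). With the 3 iris classes, OvR trains n_classes = 3 estimators and OvO trains 3 * (3 - 1) / 2 = 3, so the two counts happen to coincide here:

from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsOneClassifier, OneVsRestClassifier

X, y = load_iris(return_X_y=True)  # 3 classes
ovr = OneVsRestClassifier(LogisticRegression(solver='liblinear')).fit(X, y)
ovo = OneVsOneClassifier(LogisticRegression(solver='liblinear')).fit(X, y)
print(len(ovr.estimators_))  # 3: one binary classifier per class
print(len(ovo.estimators_))  # 3: one binary classifier per pair of classes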

3.1 Inherently multiclass classifier

  • sklearn.linear_model.LogisticRegression (setting multi_class=”multinomial”)

From the help text on the multi_class parameter:

In the multiclass case, the training algorithm uses the one-vs-rest (OvR) scheme if the ‘multi_class’ option is set to ‘ovr’, and uses the cross-entropy loss if the ‘multi_class’ option is set to ‘multinomial’. (Currently the ‘multinomial’ option is supported only by the ‘lbfgs’, ‘sag’, ‘saga’ and ‘newton-cg’ solvers.)

multi_class : {‘auto’, ‘ovr’, ‘multinomial’}, default=‘auto’ If the option chosen is ‘ovr’, then a binary problem is fit for each label. For ‘multinomial’ the loss minimised is the multinomial loss fit across the entire probability distribution, even when the data is binary. ‘multinomial’ is unavailable when solver=‘liblinear’. ‘auto’ selects ‘ovr’ if the data is binary, or if solver=‘liblinear’, and otherwise selects ‘multinomial’.
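
A quick sketch (variable names are mine) contrasting the two settings; the rows of predict_proba sum to 1 in both cases, but they are produced differently:

from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=777)
for mc, solver in [('ovr', 'liblinear'), ('multinomial', 'newton-cg')]:
    clf = LogisticRegression(multi_class=mc, solver=solver).fit(X_train, y_train)
    # 'multinomial' applies one softmax over all classes;
    # 'ovr' normalizes the per-class sigmoid outputs across classes
    print(mc, clf.predict_proba(X_test[:2]).round(3))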

Set LogisticRegression's parameters directly to multi_class='multinomial', solver='newton-cg'; the code is as follows:

def test1(X_train, X_test, y_train, y_test, multi_class='multinomial', solver='newton-cg'):
    log_reg = LogisticRegression(multi_class=multi_class, solver=solver)
    # multinomial multiclass; supported solvers include newton-cg and lbfgs
    log_reg.fit(X_train, y_train)
    predict_train = log_reg.predict(X_train)
    sys.stdout.write("LR(multi_class = %s, solver = %s) Train Accuracy : %.4g\n" % (
        multi_class, solver, metrics.accuracy_score(y_train, predict_train)))
    predict_test = log_reg.predict(X_test)
    sys.stdout.write("LR(multi_class = %s, solver = %s) Test  Accuracy : %.4g\n" % (
        multi_class, solver, metrics.accuracy_score(y_test, predict_test)))
    plot_decision_boundary(4, 8.5, 1.5, 4.5, lambda x: log_reg.predict(x))  # bounds for the first 2 (sepal) features; comment out with all 4 features
    # plot_decision_boundary(0.5, 7.5, 0, 3, lambda x: log_reg.predict(x))  # bounds for the last 2 (petal) features; comment out with all 4 features
    plot_data(X_train, y_train)

3.2 One-vs-rest multiclass classifier

  • sklearn.linear_model.LogisticRegression (setting multi_class=”ovr”)

Set LogisticRegression's parameters directly to multi_class='ovr', solver='liblinear'; the code is as follows:

def test1(X_train, X_test, y_train, y_test, multi_class='ovr', solver='liblinear'):
    log_reg = LogisticRegression(multi_class=multi_class, solver=solver)
    # ovr multiclass with the liblinear solver
    log_reg.fit(X_train, y_train)
    predict_train = log_reg.predict(X_train)
    sys.stdout.write("LR(multi_class = %s, solver = %s) Train Accuracy : %.4g\n" % (
        multi_class, solver, metrics.accuracy_score(y_train, predict_train)))
    predict_test = log_reg.predict(X_test)
    sys.stdout.write("LR(multi_class = %s, solver = %s) Test  Accuracy : %.4g\n" % (
        multi_class, solver, metrics.accuracy_score(y_test, predict_test)))
    plot_decision_boundary(4, 8.5, 1.5, 4.5, lambda x: log_reg.predict(x))  # bounds for the first 2 (sepal) features; comment out with all 4 features
    # plot_decision_boundary(0.5, 7.5, 0, 3, lambda x: log_reg.predict(x))  # bounds for the last 2 (petal) features; comment out with all 4 features
    plot_data(X_train, y_train)

3.3 OneVsRestClassifier

class sklearn.multiclass.OneVsRestClassifier(estimator, n_jobs=None)

The classifier takes an estimator object: first define an LR model log_reg, then pass it into the OvR classifier with ovr = OneVsRestClassifier(log_reg).

def test2(X_train, X_test, y_train, y_test):
    # multi_class defaults to 'auto':
    # 'auto' selects 'ovr' if the data is binary, or if solver='liblinear',
    # and otherwise selects 'multinomial'.
    # Per the help, 'auto' picks 'ovr' here because the solver is liblinear,
    # so test1 and test2 do the same thing, just written differently.
    log_reg = LogisticRegression(solver='liblinear')
    ovr = OneVsRestClassifier(log_reg)  # pass the LR into the OvR classifier
    ovr.fit(X_train, y_train)
    predict_train = ovr.predict(X_train)
    sys.stdout.write("LR(ovr) Train Accuracy : %.4g\n" % (
        metrics.accuracy_score(y_train, predict_train)))
    predict_test = ovr.predict(X_test)
    sys.stdout.write("LR(ovr) Test  Accuracy : %.4g\n" % (
        metrics.accuracy_score(y_test, predict_test)))
    plot_decision_boundary(4, 8.5, 1.5, 4.5, lambda x: ovr.predict(x))  # bounds for the first 2 (sepal) features; comment out with all 4 features
    # plot_decision_boundary(0.5, 7.5, 0, 3, lambda x: ovr.predict(x))  # bounds for the last 2 (petal) features; comment out with all 4 features
    plot_data(X_train, y_train)

3.4 OneVsOneClassifier

class sklearn.multiclass.OneVsOneClassifier(estimator, n_jobs=None)

The classifier takes an estimator object: first define an LR model log_reg, then pass it into the OvO classifier with ovo = OneVsOneClassifier(log_reg).

def test3(X_train, X_test, y_train, y_test):
    # For multiclass problems, only 'newton-cg', 'sag', 'saga' and 'lbfgs' handle multinomial loss;
    log_reg = LogisticRegression(multi_class='multinomial', solver='newton-cg')
    # OvO multiclass, passing in LR(multinomial, newton-cg or lbfgs). In testing,
    # multi_class='ovr' gave identical results -- likely because each OvO
    # sub-problem is binary, where softmax and sigmoid yield the same decisions.
    ovo = OneVsOneClassifier(log_reg)
    ovo.fit(X_train, y_train)
    predict_train = ovo.predict(X_train)
    sys.stdout.write("LR(ovo) Train Accuracy : %.4g\n" % (
        metrics.accuracy_score(y_train, predict_train)))
    predict_test = ovo.predict(X_test)
    sys.stdout.write("LR(ovo) Test  Accuracy : %.4g\n" % (
        metrics.accuracy_score(y_test, predict_test)))
    plot_decision_boundary(4, 8.5, 1.5, 4.5, lambda x: ovo.predict(x))  # bounds for the first 2 (sepal) features; comment out with all 4 features
    # plot_decision_boundary(0.5, 7.5, 0, 3, lambda x: ovo.predict(x))  # bounds for the last 2 (petal) features; comment out with all 4 features
    plot_data(X_train, y_train)

4. Results and Analysis

Run the predictions:

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=777)  # default test fraction is 0.25
test1(X_train, X_test, y_train, y_test, multi_class='ovr', solver='liblinear')
test2(X_train, X_test, y_train, y_test)
test1(X_train, X_test, y_train, y_test, multi_class='multinomial', solver='newton-cg')
test3(X_train, X_test, y_train, y_test)

Accuracy is reported as train / test:

| Configuration | LR(ovr, liblinear) | OvR classifier with LR(liblinear) | LR(multinomial, newton-cg) | OvO classifier with LR(multinomial, newton-cg) |
| --- | --- | --- | --- | --- |
| seed(520), 2 features (sepal L, W) | 0.7679 / 0.8421 | 0.7679 / 0.8421 | 0.7768 / 0.8947 | 0.7768 / 0.8684 |
| seed(777), 2 features (sepal L, W) | 0.7589 / 0.7368 | 0.7589 / 0.7368 | 0.7768 / 0.8158 | 0.7946 / 0.8158 |
| seed(520), 2 features (petal L, W) | 0.8750 / 0.9474 | 0.8750 / 0.9474 | 0.9554 / 1 | 0.9554 / 1 |
| seed(777), 2 features (petal L, W) | 0.9196 / 0.9474 | 0.9196 / 0.9474 | 0.9554 / 1 | 0.9554 / 1 |
| seed(520), 4 features | 0.9464 / 1 | 0.9464 / 1 | 0.9643 / 1 | 0.9732 / 1 |
| seed(777), 4 features | 0.9464 / 1 | 0.9464 / 1 | 0.9643 / 1 | 0.9732 / 1 |

  • The first two columns are OvR multiclass; the code is written differently but the predictions are exactly the same
  • The last two columns are OvO multiclass (sklearn's LogisticRegression offers no built-in 'ovo' option)
  • Comparing the two strategies, OvO predicts better than OvR, at the cost of O(n_classes^2) complexity
  • With the sepal length/width features, the 2D decision plot shows setosa is linearly separable from the other two classes, which are not linearly separable from each other
  • Predictions from the petal length/width features are more accurate than those from the sepal features; the plots likewise show their boundaries separating the 3 species more cleanly
  • With all 4 features, OvO trains more accurately than OvR, every model reaches 100% test accuracy, and 4 features predict better than 2 (a sketch that regenerates a table row follows)
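
For reference, a minimal sketch (variable names are mine) that regenerates one row of the table, here seed(520) with the sepal features:

from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.multiclass import OneVsOneClassifier, OneVsRestClassifier

iris = load_iris()
X, y = iris.data[:, :2], iris.target  # sepal features only
X_tr, X_te, y_tr, y_te = train_test_split(X, y, random_state=520)

models = {
    'LR(ovr, liblinear)': LogisticRegression(multi_class='ovr', solver='liblinear'),
    'OvR(LR(liblinear))': OneVsRestClassifier(LogisticRegression(solver='liblinear')),
    'LR(multinomial, newton-cg)': LogisticRegression(multi_class='multinomial', solver='newton-cg'),
    'OvO(LR(multinomial, newton-cg))': OneVsOneClassifier(
        LogisticRegression(multi_class='multinomial', solver='newton-cg')),
}
for name, model in models.items():
    model.fit(X_tr, y_tr)
    print("%-32s train %.4f  test %.4f" % (
        name, accuracy_score(y_tr, model.predict(X_tr)),
        accuracy_score(y_te, model.predict(X_te))))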

Regarding how to parameterize the LR model passed into the OvR and OvO classifiers above, I ran the following additional configurations on top of the table above (if anyone experienced sees this, please advise!):

    OvR classifier with LR(ovr, liblinear)
added: OvR classifier with LR(multinomial, newton-cg)
-------------------------------------------------
    OvO classifier with LR(multinomial, newton-cg)
added: OvO classifier with LR(ovr, liblinear)

Accuracy is reported as train / test:

| Configuration | LR(ovr, liblinear) | OvR with LR(ovr, liblinear) | OvR with LR(multinomial, newton-cg) | LR(multinomial, newton-cg) | OvO with LR(multinomial, newton-cg) | OvO with LR(ovr, liblinear) |
| --- | --- | --- | --- | --- | --- | --- |
| seed(520), 2 features (sepal L, W) | 0.7679 / 0.8421 | 0.7679 / 0.8421 | 0.7857 / 0.8947 | 0.7768 / 0.8947 | 0.7768 / 0.8684 | 0.7500 / 0.7105 |
| seed(777), 2 features (sepal L, W) | 0.7589 / 0.7368 | 0.7589 / 0.7368 | 0.7589 / 0.8158 | 0.7768 / 0.8158 | 0.7946 / 0.8158 | 0.7232 / 0.7105 |
| seed(520), 2 features (petal L, W) | 0.8750 / 0.9474 | 0.8750 / 0.9474 | 0.9375 / 1 | 0.9554 / 1 | 0.9554 / 1 | 0.9464 / 1 |
| seed(777), 2 features (petal L, W) | 0.9196 / 0.9474 | 0.9196 / 0.9474 | 0.9464 / 1 | 0.9554 / 1 | 0.9554 / 1 | 0.9554 / 0.9737 |
| seed(520), 4 features | 0.9464 / 1 | 0.9464 / 1 | 0.9464 / 1 | 0.9643 / 1 | 0.9732 / 1 | 0.9732 / 1 |
| seed(777), 4 features | 0.9464 / 1 | 0.9464 / 1 | 0.9464 / 1 | 0.9643 / 1 | 0.9732 / 1 | 0.9821 / 0.9737 |

    OvR classifier with LR(ovr, liblinear)
added: OvR classifier with LR(multinomial, newton-cg)  # OvR does better with these parameters
-------------------------------------------------
    OvO classifier with LR(multinomial, newton-cg)     # OvO does better with these parameters
added: OvO classifier with LR(ovr, liblinear)

Based on the data above, my own tentative conjectures:

  • In most cases, OvR < OvO, and LR('ovr') < LR('multinomial')
  • Put together: under the same OvR or OvO wrapper, passing in LR('multinomial') yields higher accuracy

If anyone can shed light on this, please do!
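
One plausible explanation (my own reasoning, not something the sklearn docs state): every sub-problem OneVsOneClassifier hands to LR is binary, and for two classes the multinomial softmax reduces to the same decision rule as the sigmoid, so with the same solver the multi_class setting should barely matter; the remaining differences in the table would then stem from the solver and its regularization details (liblinear, I believe, also penalizes the intercept) rather than from 'ovr' vs 'multinomial' itself. A minimal sketch to check the binary case:

# compare multi_class settings on one binary sub-problem (classes 1 and 2),
# as OneVsOneClassifier would see it; a sketch, not a proof
import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression

X, y = load_iris(return_X_y=True)
mask = y > 0  # keep versicolor vs virginica only
Xb, yb = X[mask], y[mask]
ovr = LogisticRegression(multi_class='ovr', solver='newton-cg').fit(Xb, yb)
mnl = LogisticRegression(multi_class='multinomial', solver='newton-cg').fit(Xb, yb)
print(np.mean(ovr.predict(Xb) == mnl.predict(Xb)))  # expect 1.0 or very close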

5. Full Code Listing

'''
    When you run into an unfamiliar library, module, class, or function, try in order:
    1) search the web (Google tends to give better results); e.g. "matplotlib.pyplot" turns up good posts to learn from
    2) "terminal --> python --> import xx --> help(xx.yy)"; this seems useless at first, but it is a must-have skill for a seasoned engineer
    3) tweak some parameters and watch how the output changes; this approach is used repeatedly in the code below
'''
# written by hitskyer, I just wanna say thank you !
# modified by Michael Ming on 2020.2.20
# Python 3.7
import sys
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn import metrics
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier
from sklearn.multiclass import OneVsOneClassifier


def show_data_set(X, y, data):
    plt.plot(X[y == 0, 0], X[y == 0, 1], 'rs', label=data.target_names[0])
    plt.plot(X[y == 1, 0], X[y == 1, 1], 'bx', label=data.target_names[1])
    plt.plot(X[y == 2, 0], X[y == 2, 1], 'go', label=data.target_names[2])
    plt.xlabel(data.feature_names[0])
    plt.ylabel(data.feature_names[1])
    plt.title("鸢尾花2维数据")  # "iris data in 2 feature dimensions"
    plt.legend()
    plt.rcParams['font.sans-serif'] = 'SimHei'  # render the Chinese title correctly
    plt.show()


def plot_data(X, y):
    plt.plot(X[y == 0, 0], X[y == 0, 1], 'rs', label='setosa')
    plt.plot(X[y == 1, 0], X[y == 1, 1], 'bx', label='versicolor')
    plt.plot(X[y == 2, 0], X[y == 2, 1], 'go', label='virginica')
    plt.xlabel("sepal length (cm)")
    plt.ylabel("sepal width (cm)")
    # plt.xlabel("petal length (cm)")
    # plt.ylabel("petal width (cm)")
    plt.title("预测分类边界")  # "predicted decision boundary"
    plt.legend()
    plt.rcParams['font.sans-serif'] = 'SimHei'  # render the Chinese title correctly
    plt.show()


def plot_decision_boundary(x_min, x_max, y_min, y_max, pred_func):
    h = 0.01  # grid step size
    # build a dense grid covering the plotting area
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
    # predict the class of every grid point, then restore the grid shape
    Z = pred_func(np.c_[xx.ravel(), yy.ravel()])
    Z = Z.reshape(xx.shape)
    plt.contourf(xx, yy, Z, cmap=plt.cm.Spectral)  # filled contours = class regions


def test1(X_train, X_test, y_train, y_test, multi_class='ovr', solver='liblinear'):
    log_reg = LogisticRegression(multi_class=multi_class, solver=solver)  # ovr or multinomial, per the arguments
    log_reg.fit(X_train, y_train)
    predict_train = log_reg.predict(X_train)
    sys.stdout.write("LR(multi_class = %s, solver = %s) Train Accuracy : %.4g\n" % (
        multi_class, solver, metrics.accuracy_score(y_train, predict_train)))
    predict_test = log_reg.predict(X_test)
    sys.stdout.write("LR(multi_class = %s, solver = %s) Test  Accuracy : %.4g\n" % (
        multi_class, solver, metrics.accuracy_score(y_test, predict_test)))
    plot_decision_boundary(4, 8.5, 1.5, 4.5, lambda x: log_reg.predict(x))  # bounds for the first 2 (sepal) features; comment out with all 4 features
    # plot_decision_boundary(0.5, 7.5, 0, 3, lambda x: log_reg.predict(x))  # bounds for the last 2 (petal) features; comment out with all 4 features
    plot_data(X_train, y_train)


def test2(X_train, X_test, y_train, y_test):
    # multi_class defaults to 'auto':
    # 'auto' selects 'ovr' if the data is binary, or if solver='liblinear',
    # and otherwise selects 'multinomial'.
    # Per the help, 'auto' picks 'ovr' here because the solver is liblinear,
    # so test1 and test2 do the same thing, just written differently.
    log_reg = LogisticRegression(solver='liblinear')
    ovr = OneVsRestClassifier(log_reg)
    ovr.fit(X_train, y_train)
    predict_train = ovr.predict(X_train)
    sys.stdout.write("LR(ovr) Train Accuracy : %.4g\n" % (
        metrics.accuracy_score(y_train, predict_train)))
    predict_test = ovr.predict(X_test)
    sys.stdout.write("LR(ovr) Test  Accuracy : %.4g\n" % (
        metrics.accuracy_score(y_test, predict_test)))
    plot_decision_boundary(4, 8.5, 1.5, 4.5, lambda x: ovr.predict(x))  # bounds for the first 2 (sepal) features; comment out with all 4 features
    # plot_decision_boundary(0.5, 7.5, 0, 3, lambda x: ovr.predict(x))  # bounds for the last 2 (petal) features; comment out with all 4 features
    plot_data(X_train, y_train)


def test3(X_train, X_test, y_train, y_test):
    # For multiclass problems, only 'newton-cg', 'sag', 'saga' and 'lbfgs' handle multinomial loss;
    log_reg = LogisticRegression(multi_class='multinomial', solver='newton-cg')
    ovo = OneVsOneClassifier(log_reg)  # OvO multiclass, passing in LR(multinomial, newton-cg or lbfgs)
    ovo.fit(X_train, y_train)
    predict_train = ovo.predict(X_train)
    sys.stdout.write("LR(ovo) Train Accuracy : %.4g\n" % (
        metrics.accuracy_score(y_train, predict_train)))
    predict_test = ovo.predict(X_test)
    sys.stdout.write("LR(ovo) Test  Accuracy : %.4g\n" % (
        metrics.accuracy_score(y_test, predict_test)))
    plot_decision_boundary(4, 8.5, 1.5, 4.5, lambda x: ovo.predict(x))  # bounds for the first 2 (sepal) features; comment out with all 4 features
    # plot_decision_boundary(0.5, 7.5, 0, 3, lambda x: ovo.predict(x))  # bounds for the last 2 (petal) features; comment out with all 4 features
    plot_data(X_train, y_train)


if __name__ == '__main__':
    iris = datasets.load_iris()
    # print(dir(iris))    # attributes and methods of the dataset object
    # print(iris.data)    # feature data
    # print(iris.DESCR)   # dataset description
    X = iris.data[:, :2]  # first 2 columns: the sepal features (a plane shows only 2 dims)
    # X = iris.data[:, 2:4]   # the 2 petal features
    # X = iris.data  # all 4 features
    y = iris.target  # class labels
    show_data_set(X, y, iris)
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=777)  # default test fraction is 0.25
    test1(X_train, X_test, y_train, y_test, multi_class='ovr', solver='liblinear')
    test2(X_train, X_test, y_train, y_test)
    test1(X_train, X_test, y_train, y_test, multi_class='multinomial', solver='newton-cg')
    test3(X_train, X_test, y_train, y_test)
