1.3 广告算法专题 - 交叉验证

2020-09-08 15:27:22 浏览数 (1)

  • 1.背景说明
  • 2.引出:验证数据的概念
  • 3.交叉验证
  • 4.实现

本文阐述交叉验证的相关内容,以及其中要注意的点 下面使用线性模型来进行关键点的讨论

1. 背景说明

在无论是线性模型或者svm等几乎所有的模型训练中都会用到的一项规则,那就是将训练数据分为训练数据和测试数据,来看使用训练数据训练出来的模型在测试数据上的效果

那么,在使用了一些正则化项避免过拟合的过程中,可能我们还需要一些操作

咱们先回顾一些内容,点击跳转查看【1.1 广告算法专题 -线性回归】

在正规方程求解后,得到

theta

求解后的式子,为防止

X^TX

不可逆或者在模型训练中防止过拟合,通过增加

lambda

扰动来进行定义

begin{aligned} theta = (X^TX lambda I)^{-1}X^Ty end{aligned}

然后得到相应的代价函数是【加入了相应的 L2 正则化项】

begin{aligned} J(overrightarrow theta) = frac{1}{2}sum_{i=1}^{m}(h_{overrightarrow theta}(x^{(i)}) - y^{(i)})^2 lambda sum_{j=1}^n theta_j^2 end{aligned}

2. 引出:验证数据的概念

那么,在这个时候我们就想要知道在进行训练数据求得

theta

的过程中,需要给定

lambda

的设定,但是给多大好。由此,我们引出了验证数据的概念

**重点:**给定不同的

lambda

值,进行在训练数据上的模型训练。然后使用验证数据进行对不同

lambda

的到的模型进行效果对比,选择出得分最高的模型。

然后,按照上述的方式,再进行不同特征或者不同模型的训练,挑出每个特征下或者不同模型下的得分最优项。

最后,不同的模型使用测试数据再进行效果比较,选择出相对最优的模型。

下面咱们拿一个图来描述一下

第(1)部分,是不同的模

第(2)部分,都会使用训练数据来训练样本

第(3)部分,第(1)部分中不同的模型使用不用的

lambda

进行训练,训练的结果到(4)

第(4)部分,将不同参数下的模型进行验证数据的验证

第(5)部分,选取效果最好的一组,得到相应的

lambda

theta

第(6)部分,将不同的模型下对应最好的

lambda

theta

进行测试数据的评估,找出近似最优模型

下面就差第(8)部分,进行合适的数据选取了,对模型训练有很大的帮助,看下面内容

3. 交叉验证

交叉验证主要用于防止模型过于复杂而引起的过拟合,是一种评价训练数据的数据集泛化能力的统计方法。其基本思想是将原始数据进行划分,分成训练集和测试集,训练集用来对模型进行训练,测试集用来测试训练得到的模型,以此来作为模型的评价指标

将原始数据划分为不同的部分,而不是固定的比例分配,常用的可能就是3折交叉验证,5折交叉验证。就是使用其中的

n-1

份进行训练数据,剩余的 1 份进行验证数据,如下图

这样3折交叉验证或者5折交叉验证是随机划分的折数,进行模型的训练和验证

4. 实现

使用到了Python库是 sklearn 中的 GridSearchCV 函数

这里的例子使用经典的广告效果数据,特征包括 'TV', 'Radio', 'Newspaper' 【大家这个数据网上很多随意下载一个就行】

下来整体看下数据的分布形式

代码语言:javascript复制
#!/usr/bin/python
# -*- coding:utf-8 -*-

import matplotlib.pyplot as plt
import pandas as pd

if __name__ == "__main__":
    # pandas读入
    data = pd.read_csv('data/Advertising.csv')    # TV、Radio、Newspaper、Sales
    x = data[['TV', 'Radio', 'Newspaper']]
    # x = data[['TV', 'Radio']]
    y = data['Sales']

    # 数据绘制
    plt.plot(x['TV'], y, 'ro', label='TV')
    plt.plot(x['Radio'], y, 'g^', label='Radio')
    plt.plot(x['Newspaper'], y, 'mv', label='Newspaer')
    plt.legend(loc='lower right')
    plt.grid(True)
    plt.show()

再来看整体的代码

代码语言:javascript复制
#!/usr/bin/python
# -*- coding:utf-8 -*-

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.linear_model import Lasso, Ridge, LinearRegression
from sklearn.model_selection import GridSearchCV   # 交叉验证 CV: cross validation


if __name__ == "__main__":
    # pandas读入
    data = pd.read_csv('data/Advertising.csv')    # TV、Radio、Newspaper、Sales
    x = data[['TV', 'Radio', 'Newspaper']]
    # x = data[['TV', 'Radio']]
    y = data['Sales']

    # 数据绘制
    plt.plot(x['TV'], y, 'ro', label='TV')
    plt.plot(x['Radio'], y, 'g^', label='Radio')
    plt.plot(x['Newspaper'], y, 'mv', label='Newspaer')
    plt.legend(loc='lower right')
    plt.grid(True)
    plt.show()

    x_train, x_test, y_train, y_test = train_test_split(x, y, random_state=1)
    model = Lasso()
    # model = Ridge()

    alpha_can = np.logspace(-3, 2, 10)
    lasso_model = GridSearchCV(model, param_grid={'alpha': alpha_can}, cv=5)
    lasso_model.fit(x, y)
    print '超参数:n', lasso_model.best_params_

    y_hat = lasso_model.predict(np.array(x_test))
    mse = np.average((y_hat - np.array(y_test)) ** 2)  # Mean Squared Error
    rmse = np.sqrt(mse)  # Root Mean Squared Error
    print mse, rmse

    t = np.arange(len(x_test))
    plt.plot(t, y_test, 'r-', linewidth=2, label='Test')
    plt.plot(t, y_hat, 'g-', linewidth=2, label='Predict')
    plt.legend(loc='upper right')
    plt.grid()
    plt.show()

Lasso 得到的结果:

代码语言:javascript复制
超参数:
{'alpha': 2.1544346900318843}
mse: 1.9152263138298522 rmse: 1.3839170184045906
再看看使用 Ridge 回归的得分情况

Ridge 得到的结果

代码语言:javascript复制
超参数:
{'alpha': 100.0}
mse: 1.8102674184411307 rmse: 1.3454617863176683

看起来效果要比 Lasso 的情况要好一些

下面有几个点要说明一下:

  • 我们在进行学习的时候,一直使用的是
lambda

,在sklearn中使用的是

alpha
alpha
alpha

单独看下

  • >>> alpha_can = np.logspace(-3, 2, 10) >>> alpha_can array([1.00000000e-03, 3.59381366e-03, 1.29154967e-02, 4.64158883e-02, 1.66810054e-01, 5.99484250e-01, 2.15443469e 00, 7.74263683e 00, 2.78255940e 01, 1.00000000e 02])
  • 后面的一些代码是对测试数据进行了模型效果评价,计算出mse(均方误差)和 rmse(均方根误差)进行模型效果参考

并且进行了数据的打印,如下图

好了,交叉验证的内容先聊到这里,更加深刻地内容将会把这些基本的算法模型整理差不多之后,再进行了一个深度的剖析!

作者:Johngo

配图:Pexels

0 人点赞