MADlib——基于SQL的数据挖掘解决方案(14)——回归之多类回归

2019-05-25 19:36:07 浏览数 (1)

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://cloud.tencent.com/developer/article/1433068

一、多类回归简介

1. 基本介绍

代码语言:txt复制
    如上篇所述,逻辑回归比较常用的是因变量为二分类的情况,这也是比较简单的一种形式。但在现实中,因变量的分类有时候多于两类,如疗效可能是“无效”“显效”“痊愈”三类,当然可以把其中两类进行合并,然后仍然按照二分类逻辑回归进行分析,但是合并的弊端是显而易见的,它可能损失一定的信息。而多分类则充分利用了完整的信息,可能提供更多的结果。如果目标类别数超过两个,这时就需要使用多类回归(Multinomial Regression)。
代码语言:txt复制
    在统计学里,多类回归是一个将逻辑回归一般化成多类别问题得到的分类方法。用更加专业的话来说,它是用来预测一个具有类别分布的因变量不同可能结果的概率的模型。在多类回归中,因变量是根据一系列自变量(就是我们所说的特征、观测变量)来预测得到的。具体来说,就是通过将自变量和相应参数进行线性组合之后,使用某种概率模型来计算预测因变量中得到某个结果的概率,而自变量对应的参数,即回归系数,是通过训练数据计算得到的。

2. 模型介绍

代码语言:txt复制
    实现多类回归模型最简单的方法是,对于所有K个可能的分类结果,运行K−1个独立二元逻辑回归模型,在运行过程中把其中一个类别看成是主类别,然后将其它K−1个类别和所选择的主类别分别进行回归。通过这样的方式,如果选择结果K作为主类别的话,我们可以得到以下公式:
代码语言:txt复制
    上面的公式中已经引入了所有可能结果对应的回归系数集合了。然后对公式左右两边进行指数化处理可得以下公式:

需要注意的是,最后得到的概率之和必须等于1,基于这个事实我们可以得到:

这样就可以把以上公式代入到前面的公式中得到:

代码语言:txt复制
    通过这样的方法我们就能计算出所有给定未预测样本情况下得到某个结果的概率。上面公式中所涉及的每一个权重向量

中的未知系数,可以通过最大后验概率(MAP)来计算,同时也可以使用其它方法来计算,如一些基于梯度的算法。

代码语言:txt复制
    如果使用二元逻辑回归公式的对数模型的话,可以直接将其扩展成多类回归模型,形式如下:

这里用一个额外项

来确保所有概率能够形成一个概率分布,使得这些概率的和等于1。

然后将等式两边进行指数化,得到以下公式:

由于所有概率之和等于1,因此可以得到Z的推导公式:

综合以上公式,最后可以得到每一个结果对应的概率公式:

代码语言:txt复制
    仔细观察的话可以发现,所有的概率都具有以下形式:
代码语言:txt复制
    我们把具有以下形式的函数称为softmax函数:
代码语言:txt复制
    这个函数能够将

之间的差别放大,当存在一个

比所有值中的最大值小很多的话,那么它对应的softmax函数值就会趋于0。相反,当

是最大值的时候,除非排在第二大的值跟它很接近,否则softmax函数值会趋于1。所以softmax函数可以构造出一个类似平滑函数一样的加权平均数。我们可以把上面的公式写成如下softmax函数的形式:

二、MADlib的多类回归函数

1. 训练函数

(1) 语法

代码语言:javascript复制
multinom(source_table,  
         model_table,  
         dependent_varname,  
         independent_varname,  
         ref_category,  
         link_func,  
         grouping_col,  
         optim_params,  
         verbose  
        )  

(2) 参数

参数名称

数据类型

描述

source_table

VARCHAR

包含训练数据的表名。

model_table

VARCHAR

包含输出模型的表名。主输出表列和概要输出表列如表2、3所示。

dependent_varname

VARCHAR

因变量列名。

independent_varname

VARCHAR

评估使用的自变量的表达式列表,一般显式地由包括一个常数1项的自变量列表提供。

link_function(可选)

VARCHAR

缺省值为'logit'。连接函数参数,当前仅支持logit。

ref_category(可选)

VARCHAR

缺省值为'0',该参数指定参考类别。在做多类回归时,如果因变量Y有n个值,以其中一个类别作为参考类别,其它类别都同它相比较生成n-1个非冗余的logit变量模型。对于参考类别,其模型中所有系数均为0。

grouping_col(可选)

VARCHAR

缺省值为NULL。和SQL中的“GROUP BY”类似,是一个将输入数据集分成离散组的表达式,每个组运行一个回归。此值为NULL时,将不使用分组,并产生一个单一的结果模型。

optim_params(可选)

VARCHAR

缺省值为'max_iter=100,optimizer=irls,tolerance=1e-6',指定优化参数。

verbose(可选)

BOOLEAN

缺省值为FALSE,指定是否提供训练的详细输出结果。

表1 multinom函数参数说明

列名

数据类型

描述

<...>

TEXT

分组列,取决于grouping_col输入,可能是多个列。

category

VARCHAR

表示分类值的字符串

coef

FLOAT8[]

回归系数向量。

log_likelihood

FLOAT8

对数似然比l(c)。

std_err

FLOAT8[]

系数的标准方差向量。

z_stats

FLOAT8[]

系数的z-统计量向量。

p_values

FLOAT8[]

系数的P值向量。

num_rows_processed

INTEGER

实际处理的行数。

num_missing_rows_skipped

INTEGER

训练时因为缺失值或错误跳过的行数。

num_iterations

INTEGER

实际迭代次数。

表2 multinom函数主输出表列说明

代码语言:txt复制
    训练函数在产生输出表的同时,还会创建一个名为<model_table>_summary的概要表,具有以下列:

列名

数据类型

描述

Method

VARCHAR

'multinom',描述模型的字符串。

source_table

VARCHAR

源表名。

model_table

VARCHAR

模型表名。

dependent_varname

VARCHAR

因变量表达式。

independent_varname

VARCHAR

自变量表达式。

ref_category

VARCHAR

参考类别的字符串表示。

link_func

VARCHAR

连接函数参数,当前只实现了'logit'。

grouping_col

VARCHAR

分组列。

optimizer_params

VARCHAR

包含所有优化参数的字符串,形式是‘optimizer=..., max_iter=..., tolerance=...’。

num_all_groups

INTEGER

分组数。

num_failed_groups

INTEGER

失败分组数。

total_rows_processed

BIGINT

所有分组处理的总行数。

total_rows_skipped

BIGINT

所有组由于缺少值或失败跳过的总行数。

表3 multinom函数概要输出表列说明

2. 预测函数

(1) 语法

代码语言:javascript复制
multinom_predict(model_table,  
                 predict_table_input,  
                 output_table,  
                 predict_type,  
                 verbose,  
                 id_column  
                ) 

(2) 参数

参数名称

数据类型

描述

model_table

TEXT

训练函数生成的模型表名,是multinom()函数的输出表。

predict_table_input

TEXT

包含被预测数据的表名。表中必须有作为主键的ID列。

output_table

TEXT

包含预测结果的输出表名。输出表的列根据predict_type参数而有所不同。当predict_type = response时,输出表中包含两列:SERIAL类型的id,表示主键,TEXT类型的category列,包含预测的类别。predict_type = probability时,除id列外,每个类别输出一列,列名就是类别值,列值数据类型为FLOAT8,表示预测为该类别的概率。

predict_type

TEXT

'response'或'probability'。使用前者将输出预测最大概率的类别值,使用后者将输出每种类别的预测概率。

verbose

BOOLEAN

控制是否显示详细信息,缺省值为FALSE。

id_column

TEXT

输入表中的ID列名。

表4 multinom_predict函数参数说明

三、示例

1. 问题提出

代码语言:txt复制
    下表给出了对某中学20名视力低下学生视力监测的结果数据。试用多类回归方法分析视力低下程度(由轻到重共3级)与年龄、性别(1代表男性,2代表女性)之间的关系。

编号

视力低下程度

性别

年龄

1

1

1

15

2

1

1

15

3

2

1

14

4

2

2

16

5

3

2

16

6

3

2

17

7

2

2

17

8

2

1

18

9

1

1

14

10

3

2

18

11

1

1

17

12

1

2

17

13

1

1

15

14

2

1

18

15

1

2

15

16

1

2

15

17

3

2

17

18

1

1

15

19

1

1

15

20

2

2

16

表5 视力监测结果

2. 训练模型

代码语言:javascript复制
-- 建立测试数据表并装载原始数据  
drop table if exists t1;  
create table t1 (id int, y int, x1 int, x2 int);  
insert into t1 values   
(1, 1, 1, 15), (2, 1, 1, 15), (3, 2, 1, 14), (4, 2, 2, 16),  
(5, 3, 2, 16), (6, 3, 2, 17), (7, 2, 2, 17), (8, 2, 1, 18),  
(9, 1, 1, 14), (10, 3, 2, 18), (11, 1, 1, 17), (12, 1, 2, 17),  
(13, 1, 1, 15), (14, 2, 1, 18), (15, 1, 2, 15), (16, 1, 2, 15),  
(17, 3, 2, 17), (18, 1, 1, 15), (19, 1, 1, 15), (20, 2, 2, 16);  
  
-- 调用训练函数  
drop table if exists t1_output, t1_output_summary;  
select madlib.multinom('t1',  
                       't1_output',  
                       'y',  
                       'array[1, x1, x2]',  
                       '1',  
                       'logit'  
                       );  
  
-- 查看回归结果  
x on  
select * from t1_output;
代码语言:txt复制
    结果:
代码语言:javascript复制
-[ RECORD 1 ]------ ---------------------------------------------------------  
category           | 2  
coef               | {-14.8290303728741,0.732451965917644,0.83559834958054}  
log_likelihood     | -13.678884731229  
std_err            | {8.21112520279341,1.18343112456502,0.498230781591342}  
z_stats            | {-1.80596812332484,0.618922344286716,1.67713112166946}  
p_values           | {0.0709233188904113,0.535967517828425,0.093516844250734}  
num_rows_processed | 20  
num_rows_skipped   | 0  
num_iterations     | 14  
-[ RECORD 2 ]------ ---------------------------------------------------------  
category           | 3  
coef               | {-66.5008632884155,16.0811838608692,2.1124007986729}  
log_likelihood     | -13.678884731229  
std_err            | {755.074438584232,377.375648716325,1.1813345463193}  
z_stats            | {-0.088071930249824,0.0426131996475413,1.78814782421672}  
p_values           | {0.929819506081099,0.966009873251665,0.073752161545248}  
num_rows_processed | 20  
num_rows_skipped   | 0  
num_iterations     | 14  

3. 使用模型进行预测(使用源表数据)

代码语言:javascript复制
x off  
drop table if exists t1_prd;  
select madlib.multinom_predict('t1_output', 't1', 't1_prd', 'probability');  
-- 显示预测值  
select * from t1_prd order by id; 
代码语言:txt复制
    结果:
代码语言:javascript复制
 id |         1          |         2          |          3             
---- -------------------- -------------------- ----------------------  
  1 |  0.826726567679996 |  0.173273426274487 | 6.04551768703297e-09  
  2 |  0.826726567679996 |  0.173273426274487 | 6.04551768703297e-09  
  3 |  0.916690056274327 | 0.0833099429149167 | 8.10755888689936e-10  
  4 |   0.38637457408413 |   0.38848495778712 |     0.22514046812875  
  5 |   0.38637457408413 |   0.38848495778712 |     0.22514046812875  
  6 |   0.12290156351699 |  0.284982917588031 |    0.592115518894979  
  7 |   0.12290156351699 |  0.284982917588031 |    0.592115518894979  
  8 |   0.28005405528707 |  0.719944787200939 | 1.15751199141851e-06  
  9 |  0.916690056274327 | 0.0833099429149167 | 8.10755888689936e-10  
 10 | 0.0216536931969841 |  0.115794825319112 |    0.862551481483904  
 11 |  0.472878151884198 |  0.527121611725927 | 2.36389875351574e-07  
 12 |   0.12290156351699 |  0.284982917588031 |    0.592115518894979  
 13 |  0.826726567679996 |  0.173273426274487 | 6.04551768703297e-09  
 14 |   0.28005405528707 |  0.719944787200939 | 1.15751199141851e-06  
 15 |  0.663808165241336 |  0.289409315483704 |     0.04678251927496  
 16 |  0.663808165241336 |  0.289409315483704 |     0.04678251927496  
 17 |   0.12290156351699 |  0.284982917588031 |    0.592115518894979  
 18 |  0.826726567679996 |  0.173273426274487 | 6.04551768703297e-09  
 19 |  0.826726567679996 |  0.173273426274487 | 6.04551768703297e-09  
 20 |   0.38637457408413 |   0.38848495778712 |     0.22514046812875  
(20 rows) 
代码语言:txt复制
    返回预测类别值:
代码语言:javascript复制
x off  
drop table if exists t1_prd;  
select madlib.multinom_predict('t1_output', 't1', 't1_prd', 'response');  
-- 显示预测值  
select t0.id, t1.y realvalue, t0.category predict,   
       case when t1.y = t0.category then 'T' else 'F' end istrue  
  from t1_prd t0, t1   
 where t0.id = t1.id  
 order by id;  
代码语言:txt复制
    结果:
代码语言:javascript复制
 id | realvalue | predict | istrue   
---- ----------- --------- --------  
  1 |         1 | 1       | T  
  2 |         1 | 1       | T  
  3 |         2 | 1       | F  
  4 |         2 | 2       | T  
  5 |         3 | 2       | F  
  6 |         3 | 3       | T  
  7 |         2 | 3       | F  
  8 |         2 | 2       | T  
  9 |         1 | 1       | T  
 10 |         3 | 3       | T  
 11 |         1 | 2       | F  
 12 |         1 | 3       | F  
 13 |         1 | 1       | T  
 14 |         2 | 2       | T  
 15 |         1 | 1       | T  
 16 |         1 | 1       | T  
 17 |         3 | 3       | T  
 18 |         1 | 1       | T  
 19 |         1 | 1       | T  
 20 |         2 | 2       | T  
(20 rows)  
代码语言:txt复制
    可以看到该模型在源数据上的预测正确率只有75%。

0 人点赞