版权声明:本文为博主原创文章,未经博主允许不得转载。 https://cloud.tencent.com/developer/article/1433068
一、多类回归简介
1. 基本介绍
代码语言:txt复制 如上篇所述,逻辑回归比较常用的是因变量为二分类的情况,这也是比较简单的一种形式。但在现实中,因变量的分类有时候多于两类,如疗效可能是“无效”“显效”“痊愈”三类,当然可以把其中两类进行合并,然后仍然按照二分类逻辑回归进行分析,但是合并的弊端是显而易见的,它可能损失一定的信息。而多分类则充分利用了完整的信息,可能提供更多的结果。如果目标类别数超过两个,这时就需要使用多类回归(Multinomial Regression)。
代码语言:txt复制 在统计学里,多类回归是一个将逻辑回归一般化成多类别问题得到的分类方法。用更加专业的话来说,它是用来预测一个具有类别分布的因变量不同可能结果的概率的模型。在多类回归中,因变量是根据一系列自变量(就是我们所说的特征、观测变量)来预测得到的。具体来说,就是通过将自变量和相应参数进行线性组合之后,使用某种概率模型来计算预测因变量中得到某个结果的概率,而自变量对应的参数,即回归系数,是通过训练数据计算得到的。
2. 模型介绍
代码语言:txt复制 实现多类回归模型最简单的方法是,对于所有K个可能的分类结果,运行K−1个独立二元逻辑回归模型,在运行过程中把其中一个类别看成是主类别,然后将其它K−1个类别和所选择的主类别分别进行回归。通过这样的方式,如果选择结果K作为主类别的话,我们可以得到以下公式:
代码语言:txt复制 上面的公式中已经引入了所有可能结果对应的回归系数集合了。然后对公式左右两边进行指数化处理可得以下公式:
需要注意的是,最后得到的概率之和必须等于1,基于这个事实我们可以得到:
这样就可以把以上公式代入到前面的公式中得到:
代码语言:txt复制 通过这样的方法我们就能计算出所有给定未预测样本情况下得到某个结果的概率。上面公式中所涉及的每一个权重向量
中的未知系数,可以通过最大后验概率(MAP)来计算,同时也可以使用其它方法来计算,如一些基于梯度的算法。
代码语言:txt复制 如果使用二元逻辑回归公式的对数模型的话,可以直接将其扩展成多类回归模型,形式如下:
这里用一个额外项
来确保所有概率能够形成一个概率分布,使得这些概率的和等于1。
然后将等式两边进行指数化,得到以下公式:
由于所有概率之和等于1,因此可以得到Z的推导公式:
综合以上公式,最后可以得到每一个结果对应的概率公式:
代码语言:txt复制 仔细观察的话可以发现,所有的概率都具有以下形式:
代码语言:txt复制 我们把具有以下形式的函数称为softmax函数:
代码语言:txt复制 这个函数能够将
之间的差别放大,当存在一个
比所有值中的最大值小很多的话,那么它对应的softmax函数值就会趋于0。相反,当
是最大值的时候,除非排在第二大的值跟它很接近,否则softmax函数值会趋于1。所以softmax函数可以构造出一个类似平滑函数一样的加权平均数。我们可以把上面的公式写成如下softmax函数的形式:
二、MADlib的多类回归函数
1. 训练函数
(1) 语法
代码语言:javascript复制multinom(source_table,
model_table,
dependent_varname,
independent_varname,
ref_category,
link_func,
grouping_col,
optim_params,
verbose
)
(2) 参数
参数名称 | 数据类型 | 描述 |
---|---|---|
source_table | VARCHAR | 包含训练数据的表名。 |
model_table | VARCHAR | 包含输出模型的表名。主输出表列和概要输出表列如表2、3所示。 |
dependent_varname | VARCHAR | 因变量列名。 |
independent_varname | VARCHAR | 评估使用的自变量的表达式列表,一般显式地由包括一个常数1项的自变量列表提供。 |
link_function(可选) | VARCHAR | 缺省值为'logit'。连接函数参数,当前仅支持logit。 |
ref_category(可选) | VARCHAR | 缺省值为'0',该参数指定参考类别。在做多类回归时,如果因变量Y有n个值,以其中一个类别作为参考类别,其它类别都同它相比较生成n-1个非冗余的logit变量模型。对于参考类别,其模型中所有系数均为0。 |
grouping_col(可选) | VARCHAR | 缺省值为NULL。和SQL中的“GROUP BY”类似,是一个将输入数据集分成离散组的表达式,每个组运行一个回归。此值为NULL时,将不使用分组,并产生一个单一的结果模型。 |
optim_params(可选) | VARCHAR | 缺省值为'max_iter=100,optimizer=irls,tolerance=1e-6',指定优化参数。 |
verbose(可选) | BOOLEAN | 缺省值为FALSE,指定是否提供训练的详细输出结果。 |
表1 multinom函数参数说明
列名 | 数据类型 | 描述 |
---|---|---|
<...> | TEXT | 分组列,取决于grouping_col输入,可能是多个列。 |
category | VARCHAR | 表示分类值的字符串 |
coef | FLOAT8[] | 回归系数向量。 |
log_likelihood | FLOAT8 | 对数似然比l(c)。 |
std_err | FLOAT8[] | 系数的标准方差向量。 |
z_stats | FLOAT8[] | 系数的z-统计量向量。 |
p_values | FLOAT8[] | 系数的P值向量。 |
num_rows_processed | INTEGER | 实际处理的行数。 |
num_missing_rows_skipped | INTEGER | 训练时因为缺失值或错误跳过的行数。 |
num_iterations | INTEGER | 实际迭代次数。 |
表2 multinom函数主输出表列说明
代码语言:txt复制 训练函数在产生输出表的同时,还会创建一个名为<model_table>_summary的概要表,具有以下列:
列名 | 数据类型 | 描述 |
---|---|---|
Method | VARCHAR | 'multinom',描述模型的字符串。 |
source_table | VARCHAR | 源表名。 |
model_table | VARCHAR | 模型表名。 |
dependent_varname | VARCHAR | 因变量表达式。 |
independent_varname | VARCHAR | 自变量表达式。 |
ref_category | VARCHAR | 参考类别的字符串表示。 |
link_func | VARCHAR | 连接函数参数,当前只实现了'logit'。 |
grouping_col | VARCHAR | 分组列。 |
optimizer_params | VARCHAR | 包含所有优化参数的字符串,形式是‘optimizer=..., max_iter=..., tolerance=...’。 |
num_all_groups | INTEGER | 分组数。 |
num_failed_groups | INTEGER | 失败分组数。 |
total_rows_processed | BIGINT | 所有分组处理的总行数。 |
total_rows_skipped | BIGINT | 所有组由于缺少值或失败跳过的总行数。 |
表3 multinom函数概要输出表列说明
2. 预测函数
(1) 语法
代码语言:javascript复制multinom_predict(model_table,
predict_table_input,
output_table,
predict_type,
verbose,
id_column
)
(2) 参数
参数名称 | 数据类型 | 描述 |
---|---|---|
model_table | TEXT | 训练函数生成的模型表名,是multinom()函数的输出表。 |
predict_table_input | TEXT | 包含被预测数据的表名。表中必须有作为主键的ID列。 |
output_table | TEXT | 包含预测结果的输出表名。输出表的列根据predict_type参数而有所不同。当predict_type = response时,输出表中包含两列:SERIAL类型的id,表示主键,TEXT类型的category列,包含预测的类别。predict_type = probability时,除id列外,每个类别输出一列,列名就是类别值,列值数据类型为FLOAT8,表示预测为该类别的概率。 |
predict_type | TEXT | 'response'或'probability'。使用前者将输出预测最大概率的类别值,使用后者将输出每种类别的预测概率。 |
verbose | BOOLEAN | 控制是否显示详细信息,缺省值为FALSE。 |
id_column | TEXT | 输入表中的ID列名。 |
表4 multinom_predict函数参数说明
三、示例
1. 问题提出
代码语言:txt复制 下表给出了对某中学20名视力低下学生视力监测的结果数据。试用多类回归方法分析视力低下程度(由轻到重共3级)与年龄、性别(1代表男性,2代表女性)之间的关系。
编号 | 视力低下程度 | 性别 | 年龄 |
---|---|---|---|
1 | 1 | 1 | 15 |
2 | 1 | 1 | 15 |
3 | 2 | 1 | 14 |
4 | 2 | 2 | 16 |
5 | 3 | 2 | 16 |
6 | 3 | 2 | 17 |
7 | 2 | 2 | 17 |
8 | 2 | 1 | 18 |
9 | 1 | 1 | 14 |
10 | 3 | 2 | 18 |
11 | 1 | 1 | 17 |
12 | 1 | 2 | 17 |
13 | 1 | 1 | 15 |
14 | 2 | 1 | 18 |
15 | 1 | 2 | 15 |
16 | 1 | 2 | 15 |
17 | 3 | 2 | 17 |
18 | 1 | 1 | 15 |
19 | 1 | 1 | 15 |
20 | 2 | 2 | 16 |
表5 视力监测结果
2. 训练模型
代码语言:javascript复制-- 建立测试数据表并装载原始数据
drop table if exists t1;
create table t1 (id int, y int, x1 int, x2 int);
insert into t1 values
(1, 1, 1, 15), (2, 1, 1, 15), (3, 2, 1, 14), (4, 2, 2, 16),
(5, 3, 2, 16), (6, 3, 2, 17), (7, 2, 2, 17), (8, 2, 1, 18),
(9, 1, 1, 14), (10, 3, 2, 18), (11, 1, 1, 17), (12, 1, 2, 17),
(13, 1, 1, 15), (14, 2, 1, 18), (15, 1, 2, 15), (16, 1, 2, 15),
(17, 3, 2, 17), (18, 1, 1, 15), (19, 1, 1, 15), (20, 2, 2, 16);
-- 调用训练函数
drop table if exists t1_output, t1_output_summary;
select madlib.multinom('t1',
't1_output',
'y',
'array[1, x1, x2]',
'1',
'logit'
);
-- 查看回归结果
x on
select * from t1_output;
代码语言:txt复制 结果:
代码语言:javascript复制-[ RECORD 1 ]------ ---------------------------------------------------------
category | 2
coef | {-14.8290303728741,0.732451965917644,0.83559834958054}
log_likelihood | -13.678884731229
std_err | {8.21112520279341,1.18343112456502,0.498230781591342}
z_stats | {-1.80596812332484,0.618922344286716,1.67713112166946}
p_values | {0.0709233188904113,0.535967517828425,0.093516844250734}
num_rows_processed | 20
num_rows_skipped | 0
num_iterations | 14
-[ RECORD 2 ]------ ---------------------------------------------------------
category | 3
coef | {-66.5008632884155,16.0811838608692,2.1124007986729}
log_likelihood | -13.678884731229
std_err | {755.074438584232,377.375648716325,1.1813345463193}
z_stats | {-0.088071930249824,0.0426131996475413,1.78814782421672}
p_values | {0.929819506081099,0.966009873251665,0.073752161545248}
num_rows_processed | 20
num_rows_skipped | 0
num_iterations | 14
3. 使用模型进行预测(使用源表数据)
代码语言:javascript复制x off
drop table if exists t1_prd;
select madlib.multinom_predict('t1_output', 't1', 't1_prd', 'probability');
-- 显示预测值
select * from t1_prd order by id;
代码语言:txt复制 结果:
代码语言:javascript复制 id | 1 | 2 | 3
---- -------------------- -------------------- ----------------------
1 | 0.826726567679996 | 0.173273426274487 | 6.04551768703297e-09
2 | 0.826726567679996 | 0.173273426274487 | 6.04551768703297e-09
3 | 0.916690056274327 | 0.0833099429149167 | 8.10755888689936e-10
4 | 0.38637457408413 | 0.38848495778712 | 0.22514046812875
5 | 0.38637457408413 | 0.38848495778712 | 0.22514046812875
6 | 0.12290156351699 | 0.284982917588031 | 0.592115518894979
7 | 0.12290156351699 | 0.284982917588031 | 0.592115518894979
8 | 0.28005405528707 | 0.719944787200939 | 1.15751199141851e-06
9 | 0.916690056274327 | 0.0833099429149167 | 8.10755888689936e-10
10 | 0.0216536931969841 | 0.115794825319112 | 0.862551481483904
11 | 0.472878151884198 | 0.527121611725927 | 2.36389875351574e-07
12 | 0.12290156351699 | 0.284982917588031 | 0.592115518894979
13 | 0.826726567679996 | 0.173273426274487 | 6.04551768703297e-09
14 | 0.28005405528707 | 0.719944787200939 | 1.15751199141851e-06
15 | 0.663808165241336 | 0.289409315483704 | 0.04678251927496
16 | 0.663808165241336 | 0.289409315483704 | 0.04678251927496
17 | 0.12290156351699 | 0.284982917588031 | 0.592115518894979
18 | 0.826726567679996 | 0.173273426274487 | 6.04551768703297e-09
19 | 0.826726567679996 | 0.173273426274487 | 6.04551768703297e-09
20 | 0.38637457408413 | 0.38848495778712 | 0.22514046812875
(20 rows)
代码语言:txt复制 返回预测类别值:
代码语言:javascript复制x off
drop table if exists t1_prd;
select madlib.multinom_predict('t1_output', 't1', 't1_prd', 'response');
-- 显示预测值
select t0.id, t1.y realvalue, t0.category predict,
case when t1.y = t0.category then 'T' else 'F' end istrue
from t1_prd t0, t1
where t0.id = t1.id
order by id;
代码语言:txt复制 结果:
代码语言:javascript复制 id | realvalue | predict | istrue
---- ----------- --------- --------
1 | 1 | 1 | T
2 | 1 | 1 | T
3 | 2 | 1 | F
4 | 2 | 2 | T
5 | 3 | 2 | F
6 | 3 | 3 | T
7 | 2 | 3 | F
8 | 2 | 2 | T
9 | 1 | 1 | T
10 | 3 | 3 | T
11 | 1 | 2 | F
12 | 1 | 3 | F
13 | 1 | 1 | T
14 | 2 | 2 | T
15 | 1 | 1 | T
16 | 1 | 1 | T
17 | 3 | 3 | T
18 | 1 | 1 | T
19 | 1 | 1 | T
20 | 2 | 2 | T
(20 rows)
代码语言:txt复制 可以看到该模型在源数据上的预测正确率只有75%。