一份半监督学习的指南-伪标签学习

2020-12-16 17:56:14 浏览数 (1)

1 引言

在ML中,有3种机器学习方法-监督学习、无监督学习和强化学习技术。 我们所知道的监督学习是指数据带有标签的情况, 无监督学习是仅存在数据而没有标签的情况,强化学习算法的思路非常简单,以游戏为例,如果在游戏中采取某种策略可以取得较高的得分,那么就进一步“强化”这种策略,以期继续取得较好的结果。

想象一下这样一种情况,在训练中,标记数据的数量更少,而未标记数据的数量更多。 一种称为半监督学习( [Semi-Supervised Learning],SSL)的新技术,它是监督学习和非监督学习的混合体。 顾名思义,半监督学习中同时存在一组标记的训练数据和另一组未标记的训练数据。 我们可以将这种情况想像成Google图片或Facebook通过其面孔(数据)识别出图片中的人物并根据该人物先前存储的图像生成建议名称(标签)的情况。

在本文中,我们将讨论如何使用半监督学习技术生成伪标签。

2 Pseudo-Labelling 伪标签

伪标签是使用标记的数据模型预测未标记数据并进行标记的过程。 首先,模型已经训练了包含标签的数据集,该模型用于为未标记的数据集生成伪标签。 最后,将数据集和标签(原始标签和伪标签)组合在一起以进行最终模型训练。 之所以称为(意味着虚幻),是因为它们可能是真实标签,也可能不是真实标签,并且是通过我们基于类似的数据模型生成的标签。

该方法的主旨思想其实很简单。首先,在标签数据上训练模型,然后使用经过训练的模型来预测无标签数据的标签,从而创建伪标签。此外,将标签数据和新生成的伪标签数据结合起来作为新的训练数据。

3 Python 实现

在这个例子中,我们使用了sklearn中的breast cancer数据集。我们知道整个已经包含了标签,但我们要修改它,将数据分成两部分,一部分有标签,另一部分没有标签。我们将从经过训练的带标签数据模型中为未带标签的数据生成我们自己的标签,然后最后使用两者合并的数据集来训练最终的模型。

3.1 数据集

Breast cancer dataset是预测肿瘤是良性(B)还是恶性(M)的分类问题。前两列为1)id和2)diagnosis(标签):

代码语言:javascript复制
a)radius_mean(从中心到外围点的距离的平均值)
b)texture_mean(灰度值的标准偏差)
c)perimeter_mean(周长)
d)area_mean(面积)
e)smoothness_mean(半径长度的局部变化)
f)compactness_mean(周长^ 2 /面积– 1.0)
g)concavity_mean(轮廓凹部的严重程度)
h) concave points_mean(轮廓的凹面部分的数量)

3.2 导入包

代码语言:javascript复制
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier

3.3 加载数据集

代码语言:javascript复制
X,y = load_breast_cancer(True)
X.shape
代码语言:javascript复制
(569, 30)

3.4 分割数据集

代码语言:javascript复制
x_train,x_test,y_train,_ = train_test_split(X,y,test_size=.6)
x_train.shape,y_train.shape,x_test.shape
代码语言:javascript复制
((227, 30), (227,), (342, 30)

3.5 训练模型

代码语言:javascript复制
model1 = RandomForestClassifier()
history = model1.fit(x_train,y_train)
history
代码语言:javascript复制
RandomForestRegressor(bootstrap=True, ccp_alpha=0.0, criterion=’mse’,
max_depth=None, max_features=’auto’, max_leaf_nodes=None,
max_samples=None, min_impurity_decrease=0.0,
min_impurity_split=None, min_samples_leaf=1,
min_samples_split=2, min_weight_fraction_leaf=0.0,
n_estimators=100, n_jobs=None, oob_score=False,
random_state=None, verbose=0, warm_start=False)

3.6 评分

代码语言:javascript复制
model1.score(x_train,y_train)
代码语言:javascript复制
1.0

3.7 预测

代码语言:javascript复制
y_new = model1.predict(x_test)
y_new.shape
代码语言:javascript复制
(342,)

合并数据集

代码语言:javascript复制
final_X = np.concatenate((x_train,x_test))
final_X.shape
代码语言:javascript复制
(569, 30)

合并原始标签与伪标签

代码语言:javascript复制
final_Y = np.concatenate((y_train,y_test))
final_Y.shape
代码语言:javascript复制
(569,)

基于合并的数据集训练最终模型

代码语言:javascript复制
model2 = RandomForestRegressor()
model2.fit(final_X,final_Y)
model2.score(final_X,final_Y)
代码语言:javascript复制
1.0

4 结论

伪标签的实现到此为止,大家可以根据自己的想法去比赛中尝试吧。

0 人点赞