◆◆
数据集来源
◆◆
该数据集来自若干新闻站点2012年6月—7月期间国内,国际,体育,社会,娱乐等18个频道的新闻数据。
根据新闻正文内容分析新闻的类别数据集官网链接:
http://www.sogou.com/labs/resource/tce.php.
该数据集样例格式如下所示:
在 FlyAI竞赛平台上 提供了超详细的参考代码,我们可以通过参加搜狗新闻文本分类预测练习赛进行进一步学习和优化。
◆◆
代码实现
◆◆
1.1、算法流程及实现
算法流程主要分为以下四个部分进行介绍:
1.数据加载
2.构建网络
3.模型训练
1.数据加载
对每条新闻数据的读取和处理是在processor.py文件中完成。
具体实现如下:
2.构建网络
由于是搜狗新闻文本类数据,这里我们可以使用一维卷积Conv1D BiGRU来构建网络,网络结构如下所示:
运行summary()方法后输出的网络结构如下图:
3.模型训练
这里我们设置了epoch为5,batch为128,采用adam优化器来训练网络,EarlyStopping可以加速调参过程。然后通过调用FlyAI提供的train_log方法可以在训练过程中实时的看到训练集和验证集的准确率及损失变化曲线。
1.2.最终结果
通过使用自定义CNN网络结构 双向GRU网络的方法,在epoch为10,batch为128的条件下使用adam优化器下不断优化模型参数,使用early_stopping规则在model训练达到early_stopping条件时提前终止训练提高model优化效率,最终模型在测试集的准确率达到91 。