前言
前文已经介绍了 fastText 开源工具的安装,接下来使用 fastText 工具来解决具体的文本分类问题(fastText 还可以训练词向量,此时 fastText 可以看成是 word2vec 的改进版,具体还是使用 skip-gram 或 CBOW 模型进行建模)。文本分类的目标是将一些文档分配到一个或者多个预先定义的类别中。
- 将文档分配到一个预定义类别称为单标签分类,比如判断一封邮件是否为垃圾邮件,一封邮件要不是垃圾邮件要不是非垃圾邮件;
- 将文档分配到多个预定义类别称为多标签分类,比如知乎问题上的话题标签,一个问题可能对应着多个话题标签;
fastText 既能解决单标签分类问题,又能解决多标签分类问题。
准备数据集
使用 fastText 工具解决文本分类任务时,存放数据集的文本文件必须满足以下两个条件:
- 文本文件中的每一行对应一个文档;
- 文档的类别标签以
__label__
为前缀放在文档的最前面;
下面举两个符合条件的小例子。
- 单标签数据集:
__label__1 i love you
__label__0 i hate you
上面的单标签数据集中一共有 2 个文档(每一行一个文档),第一个文档 "i love you",对应的类别标签为 1(具体类别名为 __label__
前缀后面的文本),第二个文档 "i hate you",对应的类别标签为 0。
- 多标签数据集:
__label__sauce __label__cheese How much does potato starch affect a cheese sauce recipe?
__label__food-safety __label__acidity Dangerous pathogens capable of growing in acidic environments
__label__cast-iron __label__stove How do I cover up the white spots on my cast iron stove?
__label__restaurant Michelin Three Star Restaurant; but if the chef is not there
__label__knife-skills __label__dicing Without knife skills, how can I quickly and accurately dice vegetables?
__label__storage-method __label__equipment __label__bread What's the purpose of a bread box?
__label__baking __label__food-safety __label__substitutions __label__peanuts how to seperate peanut oil from roasted peanuts at home?
__label__chocolate American equivalent for British chocolate terms
__label__baking __label__oven __label__convection Fan bake vs bake
__label__sauce __label__storage-lifetime __label__acidity __label__mayonnaise Regulation and balancing of readymade packed mayonnaise and other sauces
多标签数据集中的不同的类别标签用空格来分割。比如:对于 "Regulation and balancing of readymade packed mayonnaise and other sauces" 文档的类别标签有 sauce、storage-lifetime、acidity 和 mayonnaise 5 个。
单标签和多标签数据集在 fastText 的使用上并没有太大区别。为了方便,接下来以上面简单的多标签数据集为例来介绍 fastText。在这之前首先将上面的多标签数据集保存到一个名为 train.data 的文本文件中。
训练 fastText 模型
代码语言:javascript复制>>> import fasttext
>>> model = fasttext.train_supervised(input = r"./train.data")
Read 0M words
Number of words: 80
Number of labels: 20
Progress: 100.0% words/sec/thread: 8153 lr: 0.000000 avg.loss: 3.019448 ETA: 0h 0m 0s
使用 train_supervised(input = r"./train.data")
函数训练模型,其中 input 参数指定包含训练数据集的文本文件,函数返回在训练集上训练好的模型对象,我们可以通过这个模型对象访问训练模型的各种信息。
>>> model.words # 训练集的词汇表
['</s>', 'the', 'and', 'of', 'I',
'How', 'how', 'a', 'bake', 'dice', 'oil', 'peanut', 'seperate', 'to', 'box?', 'bread', 'purpose', "What's", 'vegetables?', 'peanuts', 'accurately', 'quickly', 'can', 'skills,', 'knife', 'there', 'terms', 'sauces', 'other', 'mayonnaise', 'packed', 'readymade', 'balancing', 'Regulation', 'vs', 'Fan', 'from', 'chocolate', 'British', 'for', 'equivalent', 'American', 'home?', 'at', 'Without', 'roasted', 'pathogens', 'spots', 'white', 'up', 'cover', 'do', 'environments', 'acidic', 'in', 'growing', 'capable', 'is', 'Dangerous', 'recipe?', 'sauce', 'cheese', 'affect',
'starch', 'potato', 'does', 'much', 'Restaurant;', 'not', 'on', 'my', 'cast', 'iron', 'but', 'Star', 'chef',
'if', 'Three', 'stove?', 'Michelin']
>>> model.labels # 训练集的类别标签
['__label__food-safety', '__label__acidity', '__label__sauce', '__label__baking', '__label__storage-lifetime', '__label__mayonnaise', '__label__dicing', '__label__knife-skills', '__label__cheese', '__label__convection', '__label__oven', '__label__restaurant', '__label__chocolate', '__label__storage-method', '__label__equipment', '__label__cast-iron',
'__label__stove', '__label__bread', '__label__peanuts', '__label__substitutions']
保存加载模型
使用 save_model
函数保存模型到指定的文件中,相对应的使用 load_model
函数到指定文件中加载模型。
>>> model.save_model(r"./model.bin")
代码语言:javascript复制>>> import fasttext
>>> load_model = fasttext.load_model(r"./model.bin")
评估模型
这里只是为了演示 fastText 工具的使用,为了方便,将训练集直接作为测试集使用。
代码语言:javascript复制>>> model.test(r"./train.data")
(10, 0.7, 0.2916666666666667)
其中 10 为用于测试的样本数量,0.7 为精确度,0.2916 为召回率。计算前 3 个类别的准确度和召回率。
代码语言:javascript复制>>> model.test(r"./train.data", k = 3)
(10, 0.4666666666666667, 0.5833333333333334)
模型预测
代码语言:javascript复制>>> model.predict("Which baking dish is best to bake a banana bread ?")
(('__label__storage-method',), array([0.05001693]))
对于多标签分类,还可以预测概率值最高的前 5 个类别标签。
代码语言:javascript复制>>> model.predict("Which baking dish is best to bake a banana bread ?", k = 5)
(('__label__storage-method', '__label__oven', '__label__baking', '__label__sauce', '__label__bread'), array([0.05001704, 0.05001483, 0.05001277, 0.05001069, 0.05001068]))