文本生成是自然语言处理中非常重要且热门的领域。摘要抽取、智能回复、诗词创作、生成标题、生成商品描述、机器人写新闻等等都属于文本生成的范畴,应用极其广泛。
今天给大家推荐一个高效、易用的文本生成的开源项目-TextBox(妙笔),是由中国人民大学 AI BOX 团队推出的。
首先奉上 GitHub 地址:https://github.com/RUCAIBox/TextBox。
TextBox 是基于 Python 和 Pytorch 开发的,该文本生成库提供了 21 种文本生成算法和 9 种数据集。
算法主要涵盖两项任务:
- 无条件(无输入)生成
- 有条件的(Seq2Seq)生成,包括机器翻译、文本摘要、对话系统等
TextBox 的总体架构是这样的:
我们再来看下 TextBox 的四大特色。
- 统一和模块化的框架。TextBox 建立在 PyTorch 的基础上,将各种模型分离为一组高度可重用的模块,被设计为高度模块化。
- 全面的模型、基准数据集和标准化评估。TextBox 还包含多种文本生成模型,涵盖基于 VAE、GAN、RNN 或 Transformer 的模型以及预训练语言模型(PLM)的类别。
- 可扩展且灵活的框架。TextBox 在文本生成模型、RNN 编码器-解码器、Transformer编码器-解码器和预训练语言模型中提供了各种常用功能或模块的便捷接口。
- 轻松便捷地开始使用。TextBox 提供了灵活的配置文件,可以让绿色的手在不修改源代码的情况下进行实验,并允许研究人员通过修改少量配置来进行定性分析。
from textbox.quick_start import run_textbox
run_textbox(config_dict={'model': 'RNN',
'dataset': 'COCO',
'data_path': './dataset'})
我们再来看下如何快速上手 TextBox。
关于安装
Python >= 3.6.2
torch >= 1.6.0
GCC >= 5.1.0
TextBox 支持 pip 和源码安装。
- pip安装:pip install textbox
- 源码安装: git clone https://github.com/RUCAIBox/TextBox.git && cd TextBox
pip install -e . --verbose
如何运行
环境配置好了,我们如何来使用这个库呢?
该库提供了直接运行的脚本,使用 python run_textbox.py 运行该脚本即可,在 COCO 数据集上运行 RNN 模型进行无条件生成。
如果要改变参数,比如 rnn_type 等,只需根据需要设置其他命令参数:
代码语言:javascript复制python run_textbox.py --rnn_type = lstm --max_vocab_size = 4000。
如果要更改数据集和模型,也是只需通过修改相应的命令参数来运行脚本:
代码语言:javascript复制python run_textbox.py --model = [model_name] --dataset = [dataset_name]
其中模型名称可以选择,比如 RNN、GPT-2 等。
如果你是通过 pip 安装了 TextBox,则可以创建一个新的 python 文件,通过调用 api 即可实现模型的训练和测试。
代码语言:javascript复制from textbox.quick_start import run_textbox
run_textbox(config_dict={'model': 'RNN',
'dataset': 'COCO',
'data_path': './dataset'})
这是在 COCO 数据集上进行 RNN 模型的训练和测试。如果想用不同的数据集和模型进行修改即可。 使用预训练语言模型
TextBox 支持应用部分预训练的语言模型(PLM)进行文本生成。以GPT-2 为例,下面将展示如何使用 PLM 进行微调。
从 huggingface 提供的模型源(https://huggingface.co/gpt2/tree/main) 中下载 GPT-2 模型,包括 config.json
,merges.txt
,pytorch_model.bin
,tokenizer.json
和 vocab.json
。然后将它们放在与相同级别的文件夹中 textbox
,例如 pretrained_model/gpt2
。
下载后,您只需要运行以下命令:
代码语言:javascript复制
python run_textbox.py --model=GPT2 --dataset=COCO
--pretrained_model_path=pretrained_model/gpt2
使用分布式数据并行(DDP)进行训练
TextBox 支持使用多个 GPU 训练模型。您无需修改模型,只需运行以下命令:
代码语言:javascript复制python -m torch.distributed.launch --nproc_per_node=[gpu_num]
run_textbox.py --model=[model_name]
--dataset=[dataset_name] --gpu_id=[gpu_ids] --DDP=True
最后,给大家展示下涵盖到的 21 种模型和 9 种数据集。
其中的 21 种模型如下表所示:
Category | Model | Reference |
---|---|---|
VAE | LSTMVAE | (Bowman et al., 2016) |
CNNVAE | (Yang et al., 2017) | |
HybridVAE | (Semeniuta et al., 2017) | |
CVAE | (Li et al., 2018) | |
GAN | SeqGAN | (Yu et al., 2017) |
TextGAN | (Zhang et al., 2017) | |
RankGAN | (Lin et al., 2017) | |
MaliGAN | (Che et al., 2017) | |
LeakGAN | (Guo et al., 2018) | |
MaskGAN | (Fedus et al., 2018) | |
PLM | GPT-2 | (Radford et al., 2019) |
XLNet | (Yang et al., 2019) | |
BERT2BERT | (Rothe et al., 2020) | |
BART | (Lewis et al., 2020) | |
T5 | (Raffel et al., 2020) | |
ProphetNet | (Qi et al., 2020) | |
Seq2Seq | RNN | (Sutskever et al., 2014) |
Transformer | (Vaswani et al., 2017b) | |
Context2Seq | (Tang et al., 2016) | |
Attr2Seq | (Dong et al., 2017) | |
HRED | (Serban et al., 2016) |
其中的 9 种数据集如下表:
Task | Dataset |
---|---|
Unconditional | Image COCO Caption |
EMNLP2017 WMT News | |
IMDB Movie Review | |
Translation | IWSLT2014 German-English |
WMT2014 English-German | |
Summarization | GigaWord |
Dialog | Persona Chat |
Attribute to Text | Amazon Electronic |
Poem Generation | Chinese Classical Poetry Corpus |
说了这么多,总之,妙笔(TextBox),真的是文本生成的妙妙工具!