Paddle: Waybill Information Extraction with a Pretrained Model


Contents

    • 1. Imports
    • 2. Data processing
    • 3. Helper functions
      • 3.1 Evaluation function
      • 3.2 Prediction function
      • 3.3 Decoding predictions
    • 4. Training

When filling out a waybill, the sender can simply paste all the information into the client as one block of text; the client automatically recognizes the province/city, person name, phone number and other fields, fills them in separately, and prints the label. No manual form filling is needed, which speeds up the workflow.

Learned from: https://aistudio.baidu.com/aistudio/projectdetail/1329361

By fine-tuning a pretrained model, we train a waybill information extraction model.

1. Imports

# Waybill information extraction
from functools import partial  # bind default arguments to a function
import paddle
from paddlenlp.datasets import MapDataset  # custom dataset wrapper
from paddlenlp.data import Stack, Tuple, Pad  # batching utility functions
from paddlenlp.transformers import ErnieTokenizer, ErnieForTokenClassification
from paddlenlp.metrics import ChunkEvaluator  # metric computation
from paddle.utils.download import get_path_from_url

2. Data processing

URL = "https://paddlenlp.bj.bcebos.com/paddlenlp/datasets/waybill.tar.gz"
get_path_from_url(URL, "./")
epochs = 10
batch_size = 16


def load_dict(dict_path):  # read the tag dictionary
    vocab = {}
    i = 0
    for line in open(dict_path, 'r', encoding='utf-8'):
        key = line.strip('\n')
        vocab[key] = i
        i += 1
    return vocab


# Show the data format
with open("./data/test.txt", 'r', encoding='utf-8') as f:
    i = 0
    for line in f:
        print(line)
        i += 1
        if i > 5:
            break


# Each line has two tab-separated columns: text_a	label
#
# 黑龙江省双鸭山市尖山区八马路与东平行路交叉口北40米韦业涛18600009172
# A1-B A1-I A1-I A1-I A2-B A2-I A2-I A2-I A3-B A3-I A3-I A4-B A4-I A4-I A4-I A4-I A4-I A4-I A4-I A4-I A4-I A4-I A4-I A4-I A4-I A4-I P-B P-I P-I T-B T-I T-I T-I T-I T-I T-I T-I T-I T-I T-I
# (characters and tags are separated by the '\002' control character, which is invisible when printed)
# A1 = province; -B marks the beginning of a span, -I the inside; P = person name, T = phone number
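To make the storage format concrete, here is a minimal sketch of the per-character alignment (the characters and tags are borrowed from the training sample shown in the comments of load_dataset below; this is not a full line from the dataset):

# Minimal sketch: '\002' separates the items inside each column, and
# characters and tags line up one-to-one.
text = "\002".join(["甘", "肃", "省", "宣", "荣", "嗣"])
label = "\002".join(["A1-B", "A1-I", "A1-I", "P-B", "P-I", "P-I"])
print(list(zip(text.split("\002"), label.split("\002"))))
# [('甘', 'A1-B'), ('肃', 'A1-I'), ('省', 'A1-I'), ('宣', 'P-B'), ('荣', 'P-I'), ('嗣', 'P-I')]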
  • Data conversion function: convert the text and tags into numeric id sequences
def convert_example(example, tokenizer, label_vocab):
    tokens, labels = example
    tokenized_input = tokenizer(
        tokens, return_length=True, is_split_into_words=True)
    # Token '[CLS]' and '[SEP]' will get label 'O'
    labels = ['O'] + labels + ['O']  # uppercase letter O
    tokenized_input['labels'] = [label_vocab[x] for x in labels]
    # return lists of ids
    return tokenized_input['input_ids'], tokenized_input['token_type_ids'], \
           tokenized_input['seq_len'], tokenized_input['labels']
  • Load the dataset
def load_dataset(datafiles):
    def read(data_path):
        with open(data_path, 'r', encoding='utf-8') as fp:
            next(fp)  # Skip header
            for line in fp.readlines():
                words, labels = line.strip('\n').split('\t')
                words = words.split('\002')
                # ['1', '6', '6', '2', '0', '2', '0', '0', '0', '7', '7',
                # '宣', '荣', '嗣',
                # '甘', '肃', '省',
                # '白', '银', '市',
                # '会', '宁', '县',
                # '河', '畔', '镇', '十', '字', '街', '金', '海', '超', '市', '西', '行', '5', '0', '米']
                labels = labels.split('\002')
                # ['T-B', 'T-I', 'T-I', 'T-I', 'T-I', 'T-I', 'T-I', 'T-I', 'T-I', 'T-I', 'T-I',
                # 'P-B', 'P-I', 'P-I',
                # 'A1-B', 'A1-I', 'A1-I',
                # 'A2-B', 'A2-I', 'A2-I',
                # 'A3-B', 'A3-I', 'A3-I',
                # 'A4-B', 'A4-I', 'A4-I', 'A4-I', 'A4-I', 'A4-I', 'A4-I', 'A4-I', 'A4-I', 'A4-I', 'A4-I', 'A4-I', 'A4-I', 'A4-I', 'A4-I']
                yield words, labels

    if isinstance(datafiles, str):
        return MapDataset(list(read(datafiles)))
    elif isinstance(datafiles, list) or isinstance(datafiles, tuple):
        return [MapDataset(list(read(datafile))) for datafile in datafiles]


# Create dataset, tokenizer and dataloader.
train_ds, dev_ds, test_ds = load_dataset(datafiles=(
    './data/train.txt', './data/dev.txt', './data/test.txt'))
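A quick sanity check (not in the original tutorial): before map is applied, each item is still a raw (characters, tags) tuple of equal length.

print(len(train_ds), len(dev_ds), len(test_ds))  # dataset sizes
words, labels = train_ds[0]                      # raw sample, not yet converted to ids
assert len(words) == len(labels)
print(words[:3], labels[:3])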
  • Batch the data
label_vocab = load_dict('./data/tag.dic')
# {'P-B': 0, 'P-I': 1, 'T-B': 2, 'T-I': 3, 'A1-B': 4, 'A1-I': 5,
# 'A2-B': 6, 'A2-I': 7, 'A3-B': 8, 'A3-I': 9, 'A4-B': 10, 'A4-I': 11, 'O': 12}

tokenizer = ErnieTokenizer.from_pretrained('ernie-1.0')

trans_func = partial(convert_example, tokenizer=tokenizer, label_vocab=label_vocab)

train_ds.map(trans_func)
dev_ds.map(trans_func)
test_ds.map(trans_func)
print(train_ds[0])
# ([1, 208, 515, 515, 249, 540, 249, 540, 540, 540, 589, 589, 803, 838, 2914, 1222, 1734, 244, 368, 797, 99, 32, 863, 308, 457, 2778, 484, 167, 436, 930, 192, 233, 634, 99, 213, 40, 317, 540, 256, 2], 
# [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 
# 40, 
# [12, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 1, 1, 4, 5, 5, 6, 7, 7, 8, 9, 9, 10, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 12])

ignore_label = -1

# Batching function
batchify_fn = lambda samples, fn=Tuple(
    Pad(axis=0, pad_val=tokenizer.pad_token_id),  # input_ids
    Pad(axis=0, pad_val=tokenizer.pad_token_type_id),  # token_type_ids
    Stack(dtype='int64'),  # seq_len
    Pad(axis=0, pad_val=ignore_label, dtype='int64')  # labels
): fn(samples)

# data_loader
train_loader = paddle.io.DataLoader(
    dataset=train_ds,
    batch_size=batch_size,
    return_list=True,
    collate_fn=batchify_fn)
dev_loader = paddle.io.DataLoader(
    dataset=dev_ds,
    batch_size=batch_size,
    return_list=True,
    collate_fn=batchify_fn)
test_loader = paddle.io.DataLoader(
    dataset=test_ds,
    batch_size=batch_size,
    return_list=True,
    collate_fn=batchify_fn)
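To see what batchify_fn actually produces, here is a minimal sketch with made-up toy ids (not real tokenizer output), laid out as the (input_ids, token_type_ids, seq_len, labels) tuples returned by convert_example:

toy_samples = [
    ([1, 208, 515, 2], [0, 0, 0, 0], 4, [12, 4, 5, 12]),  # toy ids, for illustration only
    ([1, 244, 2],      [0, 0, 0],    3, [12, 0, 12]),
]
input_ids, token_type_ids, seq_lens, labels = batchify_fn(toy_samples)
print(input_ids)  # the shorter sample is right-padded with tokenizer.pad_token_id
print(seq_lens)   # stacked into one int64 array: [4 3]
print(labels)     # labels padded with ignore_label (-1)

Padding the labels with -1 matches ignore_index=ignore_label in the loss function defined in section 4, so padded positions never contribute to the loss.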

3. Helper functions

3.1 Evaluation function

@paddle.no_grad()
def evaluate(model, metric, data_loader):
    model.eval()
    metric.reset()
    for input_ids, seg_ids, lens, labels in data_loader:
        logits = model(input_ids, seg_ids)
        preds = paddle.argmax(logits, axis=-1)
        n_infer, n_label, n_correct = metric.compute(None, lens, preds, labels)
        metric.update(n_infer.numpy(), n_label.numpy(), n_correct.numpy())
        precision, recall, f1_score = metric.accumulate()
    print("eval precision: %f - recall: %f - f1: %f" %
          (precision, recall, f1_score))
    model.train()

3.2 Prediction function

def predict(model, data_loader, ds, label_vocab):
    pred_list = []
    len_list = []
    for input_ids, seg_ids, lens, labels in data_loader:
        logits = model(input_ids, seg_ids)
        pred = paddle.argmax(logits, axis=-1)
        pred_list.append(pred.numpy())
        len_list.append(lens.numpy())
    preds = parse_decodes(ds, pred_list, len_list, label_vocab)
    return preds

3.3 Decoding predictions

def parse_decodes(ds, decodes, lens, label_vocab):
    decodes = [x for batch in decodes for x in batch]
    lens = [x for batch in lens for x in batch]
    id_label = dict(zip(label_vocab.values(), label_vocab.keys()))

    outputs = []
    for idx, end in enumerate(lens):
        sent = ds.data[idx][0][:end]
        tags = [id_label[x] for x in decodes[idx][1:end]]
        sent_out = []
        tags_out = []
        words = ""
        for s, t in zip(sent, tags):
            if t.endswith('-B') or t == 'O':
                if len(words):
                    sent_out.append(words)
                tags_out.append(t.split('-')[0])
                words = s
            else:
                words += s
        if len(sent_out) < len(tags_out):
            sent_out.append(words)
        outputs.append(''.join(
            [str((s, t)) for s, t in zip(sent_out, tags_out)]))
    return outputs
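A quick way to sanity-check the decoding logic is to run parse_decodes on a hand-made mini sample; the fake dataset object and tag ids below are made up for illustration (the ids follow the label_vocab shown in section 2):

from types import SimpleNamespace

# One fake batch with one sample: [CLS] + 6 characters + [SEP], so seq_len = 8;
# position 0 ([CLS]) is skipped by parse_decodes.
toy_ds = SimpleNamespace(data=[(list("甘肃省宣荣嗣"), None)])
toy_decodes = [[[12, 4, 5, 5, 0, 1, 1, 12]]]  # O, A1-B, A1-I, A1-I, P-B, P-I, P-I, O
toy_lens = [[8]]
print(parse_decodes(toy_ds, toy_decodes, toy_lens, label_vocab))
# ["('甘肃省', 'A1')('宣荣嗣', 'P')"]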

4. Training

# Load the pretrained model
model = ErnieForTokenClassification.from_pretrained("ernie-1.0", num_classes=len(label_vocab))
# Metric
metric = ChunkEvaluator(label_list=label_vocab.keys(), suffix=True)
# Loss function
loss_fn = paddle.nn.loss.CrossEntropyLoss(ignore_index=ignore_label)
# Optimizer
optimizer = paddle.optimizer.AdamW(learning_rate=2e-5, parameters=model.parameters())

# Training loop
step = 0
for epoch in range(epochs):
    for idx, (input_ids, token_type_ids, length, labels) in enumerate(train_loader):
        logits = model(input_ids, token_type_ids)
        loss = paddle.mean(loss_fn(logits, labels))
        loss.backward()
        optimizer.step()
        optimizer.clear_grad()
        step += 1
        print("epoch:%d - step:%d - loss: %f" % (epoch, step, loss))
    # Evaluate once per epoch
    evaluate(model, metric, dev_loader)
    # Save model parameters
    paddle.save(model.state_dict(),
                './ernie_result/model_%d.pdparams' % step)

# After training, load the saved model parameters
state_dict = paddle.load("./ernie_result/model_450.pdparams")
model.load_dict(state_dict)

# Prediction
preds = predict(model, test_loader, test_ds, label_vocab)
file_path = "ernie_results.txt"
with open(file_path, "w", encoding="utf8") as fout:
    fout.write("\n".join(preds))
# Print some predictions
print(
    "The results have been saved in the file: %s, some examples are shown below: "
    % file_path)
print("\n".join(preds[:10]))

Training log:

epoch:0 - step:1 - loss: 2.788503
epoch:0 - step:2 - loss: 2.520449
epoch:0 - step:3 - loss: 2.365216
epoch:0 - step:4 - loss: 2.255839
epoch:0 - step:5 - loss: 2.108390
epoch:0 - step:6 - loss: 2.006438
...
epoch:0 - step:100 - loss: 0.045199
eval precision: 0.969141 - recall: 0.977292 - f1: 0.973199
epoch:1 - step:101 - loss: 0.026065
...
epoch:1 - step:200 - loss: 0.012335
eval precision: 0.984925 - recall: 0.989066 - f1: 0.986991
epoch:2 - step:201 - loss: 0.014337
...
epoch:2 - step:300 - loss: 0.004556
eval precision: 0.987427 - recall: 0.990749 - f1: 0.989085
epoch:3 - step:301 - loss: 0.003423
...
epoch:3 - step:400 - loss: 0.002968
eval precision: 0.987427 - recall: 0.990749 - f1: 0.989085
epoch:4 - step:401 - loss: 0.001868
...
epoch:4 - step:500 - loss: 0.016371
eval precision: 0.989933 - recall: 0.992431 - f1: 0.991180
epoch:5 - step:501 - loss: 0.006276
...
epoch:5 - step:530 - loss: 0.001634
...

Some prediction results:

The results have been saved in the file: ernie_results.txt, some examples are shown below: 
('黑龙江省', 'A1')('双鸭山市', 'A2')('尖山区', 'A3')('八马路与东平行路交叉口北40米', 'A4')('韦业涛', 'P')('18600009172', 'T')
('广西壮族自治区', 'A1')('桂林市', 'A2')('雁山区', 'A3')('雁山镇西龙村老年活动中心', 'A4')('17610348888', 'T')('羊卓卫', 'P')
('15652864561', 'T')('河南省', 'A1')('开封市', 'A2')('顺河回族区', 'A3')('顺河区公园路32号', 'A4')('赵本山', 'P')
('河北省', 'A1')('唐山市', 'A2')('玉田县', 'A3')('无终大街159号', 'A4')('18614253058', 'T')('尚汉生', 'P')
('台湾', 'A1')('台中市', 'A2')('北区', 'A3')('北区锦新街18号', 'A4')('18511226708', 'T')('蓟丽', 'P')
('廖梓琪', 'P')('18514743222', 'T')('湖北省', 'A1')('宜昌市', 'A2')('长阳土家族自治县', 'A3')('贺家坪镇贺家坪村一组临河1号', 'A4')
('江苏省', 'A1')('南通市', 'A2')('海门市', 'A3')('孝威村孝威路88号', 'A4')('18611840623', 'T')('计星仪', 'P')
('17601674746', 'T')('赵春丽', 'P')('内蒙古自治区', 'A1')('乌兰察布市', 'A2')('凉城县', 'A3')('新建街', 'A4')
('云南省', 'A1')('临沧市', 'A2')('耿马傣族佤族自治县', 'A3')('鑫源路法院对面', 'A4')('许贞爱', 'P')('18510566685', 'T')
('四川省', 'A1')('成都市', 'A2')('双流区', 'A3')('东升镇北仓路196号', 'A4')('耿丕岭', 'P')('18513466161', 'T')
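As a possible next step (not part of the original tutorial), each prediction line can be parsed back into a field dict for the auto-fill scenario described in the introduction; the helper below and its English field names are a hypothetical sketch:

import ast

# Hypothetical mapping from tag to form field; A2/A3/A4 are inferred from the examples above.
FIELD_NAMES = {'A1': 'province', 'A2': 'city', 'A3': 'district',
               'A4': 'detail', 'P': 'name', 'T': 'phone'}

def to_fields(pred_line):
    # "('黑龙江省', 'A1')('双鸭山市', 'A2')..." -> list of (text, tag) tuples
    pairs = ast.literal_eval('[' + pred_line.replace(')(', '),(') + ']')
    return {FIELD_NAMES.get(tag, tag): text for text, tag in pairs}

print(to_fields(preds[0]))
# {'province': '黑龙江省', 'city': '双鸭山市', 'district': '尖山区', ...}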
