In this article, you will learn how to wrap a trained model in a REST-style named entity extraction interface: you pass in a sentence, and the interface extracts and returns the person names, locations, organizations, companies, products, and time expressions it contains.
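For example (an illustrative input and output only; the exact entities returned depend on the model you trained), sending the sentence 「昨天上午,张三在北京参加了腾讯公司举办的发布会。」 to the interface could return something like {"TIME": ["昨天上午"], "PERSON_NAME": ["张三"], "LOCATION": ["北京"], "COMPANY_NAME": ["腾讯公司"], "ORG_NAME": [], "PRODUCT_NAME": []}.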
Core module: entity_extractor.py
Key functions
# Load the entity recognition model
def person_model_init():
...
# Predict the entities in a sentence
def predict(sentence, labels_config, model_dir, batch_size, id2label, label_list, graph, input_ids_p, input_mask_p,
pred_ids,
tokenizer,
sess, max_seq_length):
...
Full code
# -*- coding: utf-8 -*-
"""
Model-based entity extraction
"""
__author__ = '程序员一一涤生'
import codecs
import os
import pickle
from datetime import datetime
from pprint import pprint
import numpy as np
import tensorflow as tf
from bert_base.bert import tokenization, modeling
from bert_base.train.models import create_model, InputFeatures
from bert_base.train.train_helper import get_args_parser
args = get_args_parser()
def convert(line, model_dir, label_list, tokenizer, batch_size, max_seq_length):
feature = convert_single_example(model_dir, 0, line, label_list, max_seq_length, tokenizer, 'p')
input_ids = np.reshape([feature.input_ids], (batch_size, max_seq_length))
input_mask = np.reshape([feature.input_mask], (batch_size, max_seq_length))
segment_ids = np.reshape([feature.segment_ids], (batch_size, max_seq_length))
label_ids = np.reshape([feature.label_ids], (batch_size, max_seq_length))
return input_ids, input_mask, segment_ids, label_ids
def predict(sentence, labels_config, model_dir, batch_size, id2label, label_list, graph, input_ids_p, input_mask_p,
pred_ids,
tokenizer,
sess, max_seq_length):
with graph.as_default():
start = datetime.now()
# print(id2label)
sentence = tokenizer.tokenize(sentence)
# print('your input is:{}'.format(sentence))
input_ids, input_mask, segment_ids, label_ids = convert(sentence, model_dir, label_list, tokenizer, batch_size,
max_seq_length)
feed_dict = {input_ids_p: input_ids,
input_mask_p: input_mask}
# run session get current feed_dict result
pred_ids_result = sess.run([pred_ids], feed_dict)
pred_label_result = convert_id_to_label(pred_ids_result, id2label, batch_size)
# print(pred_ids_result)
print(pred_label_result)
        # TODO: combination strategy
result = strage_combined(sentence, pred_label_result[0], labels_config)
print('time used: {} sec'.format((datetime.now() - start).total_seconds()))
return result, pred_label_result
def convert_id_to_label(pred_ids_result, idx2label, batch_size):
"""
将id形式的结果转化为真实序列结果
:param pred_ids_result:
:param idx2label:
:return:
"""
result = []
for row in range(batch_size):
curr_seq = []
for ids in pred_ids_result[row][0]:
if ids == 0:
break
curr_label = idx2label[ids]
if curr_label in ['[CLS]', '[SEP]']:
continue
curr_seq.append(curr_label)
result.append(curr_seq)
return result
def strage_combined(tokens, tags, labels_config):
"""
组合策略
:param pred_label_result:
:param types:
:return:
"""
def get_output(rs, data, type):
words = []
for i in data:
words.append(str(i.word).replace("#", ""))
# words.append(i.word)
rs[type] = words
return rs
eval = Result(labels_config)
if len(tokens) > len(tags):
tokens = tokens[:len(tags)]
labels_dict = eval.get_result(tokens, tags)
arr = []
for k, v in labels_dict.items():
arr.append((k, v))
rs = {}
for item in arr:
rs = get_output(rs, item[1], item[0])
return rs
def convert_single_example(model_dir, ex_index, example, label_list, max_seq_length, tokenizer, mode):
"""
将一个样本进行分析,然后将字转化为id, 标签转化为id,然后结构化到InputFeatures对象中
:param ex_index: index
:param example: 一个样本
:param label_list: 标签列表
:param max_seq_length:
:param tokenizer:
:param mode:
:return:
"""
label_map = {}
    # start indexing the labels from 1
for (i, label) in enumerate(label_list, 1):
label_map[label] = i
    # save the label->index map
if not os.path.exists(os.path.join(model_dir, 'label2id.pkl')):
with codecs.open(os.path.join(model_dir, 'label2id.pkl'), 'wb') as w:
pickle.dump(label_map, w)
tokens = example
# tokens = tokenizer.tokenize(example.text)
    # truncate the sequence
if len(tokens) >= max_seq_length - 1:
        tokens = tokens[0:(max_seq_length - 2)]  # -2 because the sequence needs a start-of-sentence and an end-of-sentence token
ntokens = []
segment_ids = []
label_ids = []
ntokens.append("[CLS]") # 句子开始设置CLS 标志
segment_ids.append(0)
# append("O") or append("[CLS]") not sure!
label_ids.append(label_map["[CLS]"]) # O OR CLS 没有任何影响,不过我觉得O 会减少标签个数,不过拒收和句尾使用不同的标志来标注,使用LCS 也没毛病
for i, token in enumerate(tokens):
ntokens.append(token)
segment_ids.append(0)
label_ids.append(0)
ntokens.append("[SEP]") # 句尾添加[SEP] 标志
segment_ids.append(0)
# append("O") or append("[SEP]") not sure!
label_ids.append(label_map["[SEP]"])
    input_ids = tokenizer.convert_tokens_to_ids(ntokens)  # convert the tokens in ntokens to ids
input_mask = [1] * len(input_ids)
    # padding: pad with 0 up to max_seq_length
while len(input_ids) < max_seq_length:
input_ids.append(0)
input_mask.append(0)
segment_ids.append(0)
        # we don't care about the label ids of padded positions
label_ids.append(0)
ntokens.append("**NULL**")
# label_mask.append(0)
# print(len(input_ids))
assert len(input_ids) == max_seq_length
assert len(input_mask) == max_seq_length
assert len(segment_ids) == max_seq_length
assert len(label_ids) == max_seq_length
# assert len(label_mask) == max_seq_length
    # pack everything into an InputFeatures object
feature = InputFeatures(
input_ids=input_ids,
input_mask=input_mask,
segment_ids=segment_ids,
label_ids=label_ids,
# label_mask = label_mask
)
return feature
class Pair(object):
def __init__(self, word, start, end, type, merge=False):
self.__word = word
self.__start = start
self.__end = end
self.__merge = merge
self.__types = type
@property
def start(self):
return self.__start
@property
def end(self):
return self.__end
@property
def merge(self):
return self.__merge
@property
def word(self):
return self.__word
@property
def types(self):
return self.__types
@word.setter
def word(self, word):
self.__word = word
@start.setter
def start(self, start):
self.__start = start
@end.setter
def end(self, end):
self.__end = end
@merge.setter
def merge(self, merge):
self.__merge = merge
@types.setter
def types(self, type):
self.__types = type
def __str__(self) -> str:
line = []
line.append('entity:{}'.format(self.__word))
line.append('start:{}'.format(self.__start))
line.append('end:{}'.format(self.__end))
line.append('merge:{}'.format(self.__merge))
line.append('types:{}'.format(self.__types))
        return '\t'.join(line)
class Result(object):
def __init__(self, labels_config):
self.others = []
self.labels_config = labels_config
self.labels = {}
for la in self.labels_config:
self.labels[la] = []
def get_result(self, tokens, tags):
        # first build the tagging result
self.result_to_json(tokens, tags)
return self.labels
def result_to_json(self, string, tags):
"""
将模型标注序列和输入序列结合 转化为结果
:param string: 输入序列
:param tags: 标注结果
:return:
"""
item = {"entities": []}
entity_name = ""
entity_start = 0
idx = 0
last_tag = ''
for char, tag in zip(string, tags):
if tag[0] == "S":
                self.append(char, idx, idx + 1, tag[2:])
                item["entities"].append({"word": char, "start": idx, "end": idx + 1, "type": tag[2:]})
elif tag[0] == "B":
if entity_name != '':
self.append(entity_name, entity_start, idx, last_tag[2:])
item["entities"].append(
{"word": entity_name, "start": entity_start, "end": idx, "type": last_tag[2:]})
entity_name = ""
entity_name = char
entity_start = idx
elif tag[0] == "I":
                entity_name += char
elif tag[0] == "O":
if entity_name != '':
self.append(entity_name, entity_start, idx, last_tag[2:])
item["entities"].append(
{"word": entity_name, "start": entity_start, "end": idx, "type": last_tag[2:]})
entity_name = ""
else:
entity_name = ""
entity_start = idx
            idx += 1
last_tag = tag
if entity_name != '':
self.append(entity_name, entity_start, idx, last_tag[2:])
item["entities"].append({"word": entity_name, "start": entity_start, "end": idx, "type": last_tag[2:]})
return item
def append(self, word, start, end, tag):
if tag in self.labels_config:
self.labels[tag].append(Pair(word, start, end, tag))
else:
self.others.append(Pair(word, start, end, tag))
def person_model_init():
return model_init("person")
def model_init(model_name):
if os.name == 'nt': # windows path config
model_dir = 'E:/quickstart/deeplearning/nlp_demo/%s/model' % model_name
bert_dir = 'E:/quickstart/deeplearning/nlp_demo/bert_model_info/chinese_L-12_H-768_A-12'
else: # linux path config
model_dir = '/home/yjy/project/deeplearning/nlp_demo/%s/model' % model_name
bert_dir = '/home/yjy/project/deeplearning/nlp_demo/bert_model_info/chinese_L-12_H-768_A-12'
batch_size = 1
max_seq_length = 500
print('checkpoint path:{}'.format(os.path.join(model_dir, "checkpoint")))
if not os.path.exists(os.path.join(model_dir, "checkpoint")):
raise Exception("failed to get checkpoint. going to return ")
    # load the label->id dictionary
with codecs.open(os.path.join(model_dir, 'label2id.pkl'), 'rb') as rf:
label2id = pickle.load(rf)
id2label = {value: key for key, value in label2id.items()}
with codecs.open(os.path.join(model_dir, 'label_list.pkl'), 'rb') as rf:
label_list = pickle.load(rf)
    num_labels = len(label_list) + 1
gpu_config = tf.ConfigProto()
gpu_config.gpu_options.allow_growth = True
graph = tf.Graph()
sess = tf.Session(graph=graph, config=gpu_config)
with graph.as_default():
print("going to restore checkpoint")
# sess.run(tf.global_variables_initializer())
input_ids_p = tf.placeholder(tf.int32, [batch_size, max_seq_length], name="input_ids")
input_mask_p = tf.placeholder(tf.int32, [batch_size, max_seq_length], name="input_mask")
bert_config = modeling.BertConfig.from_json_file(os.path.join(bert_dir, 'bert_config.json'))
(total_loss, logits, trans, pred_ids) = create_model(
bert_config=bert_config, is_training=False, input_ids=input_ids_p, input_mask=input_mask_p,
segment_ids=None,
labels=None, num_labels=num_labels, use_one_hot_embeddings=False, dropout_rate=1.0)
saver = tf.train.Saver()
saver.restore(sess, tf.train.latest_checkpoint(model_dir))
tokenizer = tokenization.FullTokenizer(
vocab_file=os.path.join(bert_dir, 'vocab.txt'), do_lower_case=args.do_lower_case)
return model_dir, batch_size, id2label, label_list, graph, input_ids_p, input_mask_p, pred_ids, tokenizer, sess, max_seq_length
if __name__ == "__main__":
_model_dir, _batch_size, _id2label, _label_list, _graph, _input_ids_p, _input_mask_p, _pred_ids, _tokenizer, _sess, _max_seq_length = person_model_init()
PERSON_LABELS = ["TIME", "LOCATION", "PERSON_NAME", "ORG_NAME", "COMPANY_NAME", "PRODUCT_NAME"]
while True:
print('input the test sentence:')
_sentence = str(input())
pred_rs, pred_label_result = predict(_sentence, PERSON_LABELS, _model_dir, _batch_size, _id2label, _label_list,
_graph,
_input_ids_p,
_input_mask_p, _pred_ids, _tokenizer, _sess, _max_seq_length)
pprint(pred_rs)
Writing the REST-style interface
We will use Python's Flask framework to provide the REST interface.
First, create a new Python project and place the following directories and files under the project root:
- The bert_base and bert_model_info directories and their files can be found in the cloud-drive project provided in the previous article 用深度学习做命名实体识别(四)——模型训练 (Named Entity Recognition with Deep Learning (4): Model Training);
- The model under the person directory is the named entity recognition model we trained in the previous article, together with some auxiliary files; it can be found in that project's output directory. A sketch of a possible project layout follows.
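Assuming the paths used in the code (the nlp_demo root name and the static directory come from the code below; your names may differ), the project root would look roughly like this:
nlp_demo/
├── bert_base/                 # from the cloud-drive project of the previous article
├── bert_model_info/
│   └── chinese_L-12_H-768_A-12/
├── person/
│   └── model/                 # trained NER model: checkpoint, label2id.pkl, label_list.pkl, ...
├── static/
├── entity_extractor.py
├── flaskr.py
├── nlp_config.py
├── nlp_main.py
├── person_ner_resource.py
└── requirements.txt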
Next, create the entry file nlp_main.py with the following content:
# -*- coding: utf-8 -*-
"""
Flask entry point
"""
import os
import nlp_config as nc
from flaskr import create_app, loadProjContext
__author__ = '程序员一一涤生'
from flask import jsonify, make_response, redirect
# load the Flask configuration
# app = create_app('config.DevelopmentConfig')
app = create_app(nc.config['default'])
# load the project context (models, etc.)
loadProjContext()
@app.errorhandler(404)
def not_found(error):
return make_response(jsonify({'error': 'Not found'}), 404)
@app.errorhandler(400)
def bad_request(error):
    return make_response(jsonify({'error': '400 Bad Request,参数或参数内容异常'}), 400)
@app.route('/')
def index_sf():
# return render_template('index.html')
return redirect('index.html')
if __name__ == '__main__':
    app.run('localhost', 5006, use_reloader=False)
Next, create the Flask initialization file flaskr.py, which presets and loads some information when the project starts, with the following content:
# -*- coding: utf-8 -*-
"""
Flask initialization
"""
from logging.config import dictConfig
from flask import Flask
from flask_cors import CORS
import person_ner_resource
from entity_extractor import person_model_init
from person_ner_resource import person
__author__ = '程序员一一涤生'
def create_app(config_type):
dictConfig({
'version': 1,
'formatters': {'default': {
'format': '[%(asctime)s] %(name)s %(levelname)s in %(module)s %(lineno)d: %(message)s',
}},
'handlers': {'wsgi': {
'class': 'logging.StreamHandler',
'stream': 'ext://flask.logging.wsgi_errors_stream',
'formatter': 'default'
}},
'root': {
'level': 'DEBUG',
# 'level': 'WARN',
# 'level': 'INFO',
'handlers': ['wsgi']
}
})
    # load the Flask configuration
app = Flask(__name__, static_folder='static', static_url_path='')
    # CORS(app, resources=r'/*', origins=['192.168.1.104'])  # restrict cross-origin access to specific origins
    CORS(app, resources={r"/*": {"origins": '*'}})  # r'/*' matches every URL on this server; "origins": '*' allows cross-origin requests from any host
app.config.from_object(config_type)
app.register_blueprint(person, url_prefix='/person')
    # initialize the application context
ctx = app.app_context()
ctx.push()
return app
def loadProjContext():
    # load the person-name extraction model
model_dir, batch_size, id2label, label_list, graph, input_ids_p, input_mask_p, pred_ids, tokenizer, sess, max_seq_length = person_model_init()
person_ner_resource.model_dir = model_dir
person_ner_resource.batch_size = batch_size
person_ner_resource.id2label = id2label
person_ner_resource.label_list = label_list
person_ner_resource.graph = graph
person_ner_resource.input_ids_p = input_ids_p
person_ner_resource.input_mask_p = input_mask_p
person_ner_resource.pred_ids = pred_ids
person_ner_resource.tokenizer = tokenizer
person_ner_resource.sess = sess
person_ner_resource.max_seq_length = max_seq_length
Then create the configuration file nlp_config.py, used to switch between production, development, and test environments, with the following content:
# -*- coding: utf-8 -*-
"""
Flask configuration module
"""
import os
__author__ = '程序员一一涤生'
basedir = os.path.abspath(os.path.dirname(__file__))
class BaseConfig:  # base configuration class
    SECRET_KEY = b'\xe4r\x04\xb5\xb2\x00\xf1\xadf\xa3\xf3V\x03\xc5\x9f\x82$^\xa25O\xf0R\xda'
    JSONIFY_MIMETYPE = 'application/json; charset=utf-8'  # by default JSONIFY_MIMETYPE does not include '; charset=utf-8'
    JSON_AS_ASCII = False  # if not disabled, Chinese in jsonify responses is escaped as Unicode sequences
ENCODING = 'utf-8'
    # custom configuration items
PERSON_LABELS = ["TIME", "LOCATION", "PERSON_NAME", "ORG_NAME", "COMPANY_NAME", "PRODUCT_NAME"]
class DevelopmentConfig(BaseConfig):
ENV = 'development'
DEBUG = True
class TestingConfig(BaseConfig):
TESTING = True
WTF_CSRF_ENABLED = False
class ProductionConfig(BaseConfig):
DEBUG = False
config = {
'testing': TestingConfig,
'default': DevelopmentConfig
# 'default': ProductionConfig
}
Next, create the person-name recognition interface file person_ner_resource.py with the following content:
# -*- coding: utf-8 -*-
"""
Named entity recognition interface
"""
from entity_extractor import predict
__author__ = '程序员一一涤生'
from flask import Blueprint, make_response, request, current_app
from flask import jsonify
person = Blueprint('person', __name__)
model_dir, batch_size, id2label, label_list, graph, input_ids_p, input_mask_p, pred_ids, tokenizer, sess, max_seq_length = None, None, None, None, None, None, None, None, None, None, None
@person.route('/extract', methods=['POST'])
def extract():
params = request.get_json()
if 't' not in params or params['t'] is None or len(params['t']) > 500 or len(params['t']) < 2:
return make_response(jsonify({'error': '文本长度不符合要求,长度限制:2~500'}), 400)
sentence = params['t']
    # complete the text into a full sentence if needed
    sentence = sentence + "。" if not sentence.endswith((",", "。", "!", "?")) else sentence
    # extract entities with the model
pred_rs, pred_label_result = predict(sentence, current_app.config['PERSON_LABELS'], model_dir, batch_size, id2label,
label_list, graph, input_ids_p,
input_mask_p,
pred_ids, tokenizer, sess, max_seq_length)
print(sentence)
return jsonify(pred_rs)
if __name__ == '__main__':
pass
Then place the requirements.txt file in the project root; its content is as follows:
absl-py==0.7.0
astor==0.7.1
backcall==0.1.0
backports.weakref==1.0rc1
bleach==1.5.0
certifi==2016.2.28
click==6.7
colorama==0.4.1
colorful==0.5.0
decorator==4.3.2
defusedxml==0.5.0
entrypoints==0.3
Flask==1.0.2
Flask-Cors==3.0.3
gast==0.2.2
grpcio==1.18.0
h5py==2.9.0
html5lib==0.9999999
ipykernel==5.1.0
ipython==7.2.0
ipython-genutils==0.2.0
ipywidgets==7.4.2
itsdangerous==0.24
jedi==0.13.2
Jinja2==2.10
jsonschema==2.6.0
jupyter==1.0.0
jupyter-client==5.2.4
jupyter-console==6.0.0
jupyter-core==4.4.0
Keras-Applications==1.0.6
Keras-Preprocessing==1.0.5
Markdown==3.0.1
MarkupSafe==1.1.0
mistune==0.8.4
mock==3.0.5
nbconvert==5.4.0
nbformat==4.4.0
notebook==5.7.4
numpy==1.16.0
pandocfilters==1.4.2
parso==0.3.2
pickleshare==0.7.5
prettyprinter==0.17.0
prometheus-client==0.5.0
prompt-toolkit==2.0.8
protobuf==3.6.1
Pygments==2.3.1
python-dateutil==2.7.5
pywinpty==0.5.5
pyzmq==17.1.2
qtconsole==4.4.3
Send2Trash==1.5.0
six==1.12.0
tensorboard==1.13.1
tensorflow==1.13.1
tensorflow-estimator==1.13.0
termcolor==1.1.0
terminado==0.8.1
testpath==0.4.2
tornado==5.1.1
traitlets==4.3.2
wcwidth==0.1.7
Werkzeug==0.14.1
widgetsnbextension==3.4.2
wincertstore==0.2
Then run the following command to install the packages listed in requirements.txt:
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple -r requirements.txt
Once the above steps are done, we can try to start the project.
Starting the project
Run the following command to start the Flask project:
python nlp_main.py
Calling the interface
This article uses Postman to call the named entity extraction interface; the interface address is:
http://localhost:5006/person/extract
A demonstration of calling the interface (an illustrative request and response are shown below):
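The following is a minimal client sketch using the requests library (requests is not in requirements.txt and the example sentence is only an assumption; the exact entities in the response depend on the model you trained):
# -*- coding: utf-8 -*-
# Minimal sketch of a client for the /person/extract interface (illustrative only).
import requests

# The interface expects a JSON body with a "t" field holding the sentence (length 2~500).
resp = requests.post("http://localhost:5006/person/extract",
                     json={"t": "昨天上午,张三在北京参加了腾讯公司举办的发布会。"})
print(resp.status_code)
# An illustrative response body might look like:
# {"TIME": ["昨天上午"], "PERSON_NAME": ["张三"], "LOCATION": ["北京"],
#  "COMPANY_NAME": ["腾讯公司"], "ORG_NAME": [], "PRODUCT_NAME": []}
print(resp.json())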
Note that running the model on a CPU takes roughly 2 to 3 seconds per call; if the project is deployed on a machine with a GPU suited to deep learning, the interface will respond much faster. Of course, in that case don't forget to install tensorflow-gpu instead of tensorflow (see the example command below).
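For example, matching the CPU version pinned in requirements.txt (an assumption; adjust the version to your CUDA setup):
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple tensorflow-gpu==1.13.1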
OK, we have now built a deep-learning-based project that can extract person names, locations, organizations, companies, products, and time expressions from natural language. Starting from the next article, we will introduce BERT and CRF, the deep learning algorithms used in this project; understanding them will help us see why the model can accurately extract the entities we want from a sentence.
That's all for this article. Thanks for reading O(∩_∩)O, bye~