Named Entity Recognition with Deep Learning (Part 5): Using the Model
Published: 2019-06-06


In this article you will learn how to build a REST-style named entity extraction API on top of the trained model: you pass in a sentence, and the API extracts and returns the person names, addresses, organizations, companies, products, and time expressions it contains.

Core module: entity_extractor.py

Key functions
# load the entity recognition model
def person_model_init():
    ...

# predict the entities in a sentence
def predict(sentence, labels_config, model_dir, batch_size, id2label, label_list, graph, input_ids_p, input_mask_p,
            pred_ids, tokenizer, sess, max_seq_length):
    ...
Full code
# -*- coding: utf-8 -*-"""基于模型的地址提取"""__author__ = '程序员一一涤生'import codecsimport osimport picklefrom datetime import datetimefrom pprint import pprintimport numpy as npimport tensorflow as tffrom bert_base.bert import tokenization, modelingfrom bert_base.train.models import create_model, InputFeaturesfrom bert_base.train.train_helper import get_args_parserargs = get_args_parser()def convert(line, model_dir, label_list, tokenizer, batch_size, max_seq_length):    feature = convert_single_example(model_dir, 0, line, label_list, max_seq_length, tokenizer, 'p')    input_ids = np.reshape([feature.input_ids], (batch_size, max_seq_length))    input_mask = np.reshape([feature.input_mask], (batch_size, max_seq_length))    segment_ids = np.reshape([feature.segment_ids], (batch_size, max_seq_length))    label_ids = np.reshape([feature.label_ids], (batch_size, max_seq_length))    return input_ids, input_mask, segment_ids, label_idsdef predict(sentence, labels_config, model_dir, batch_size, id2label, label_list, graph, input_ids_p, input_mask_p,            pred_ids,            tokenizer,            sess, max_seq_length):    with graph.as_default():        start = datetime.now()        # print(id2label)        sentence = tokenizer.tokenize(sentence)        # print('your input is:{}'.format(sentence))        input_ids, input_mask, segment_ids, label_ids = convert(sentence, model_dir, label_list, tokenizer, batch_size,                                                                max_seq_length)        feed_dict = {input_ids_p: input_ids,                     input_mask_p: input_mask}        # run session get current feed_dict result        pred_ids_result = sess.run([pred_ids], feed_dict)        pred_label_result = convert_id_to_label(pred_ids_result, id2label, batch_size)        # print(pred_ids_result)        print(pred_label_result)        # todo: 组合策略        result = strage_combined(sentence, pred_label_result[0], labels_config)        print('time used: {} sec'.format((datetime.now() - start).total_seconds()))    return result, pred_label_resultdef convert_id_to_label(pred_ids_result, idx2label, batch_size):    """    将id形式的结果转化为真实序列结果    :param pred_ids_result:    :param idx2label:    :return:    """    result = []    for row in range(batch_size):        curr_seq = []        for ids in pred_ids_result[row][0]:            if ids == 0:                break            curr_label = idx2label[ids]            if curr_label in ['[CLS]', '[SEP]']:                continue            curr_seq.append(curr_label)        result.append(curr_seq)    return resultdef strage_combined(tokens, tags, labels_config):    """    组合策略    :param pred_label_result:    :param types:    :return:    """    def get_output(rs, data, type):        words = []        for i in data:            words.append(str(i.word).replace("#", ""))            # words.append(i.word)        rs[type] = words        return rs    eval = Result(labels_config)    if len(tokens) > len(tags):        tokens = tokens[:len(tags)]    labels_dict = eval.get_result(tokens, tags)    arr = []    for k, v in labels_dict.items():        arr.append((k, v))    rs = {}    for item in arr:        rs = get_output(rs, item[1], item[0])    return rsdef convert_single_example(model_dir, ex_index, example, label_list, max_seq_length, tokenizer, mode):    """    将一个样本进行分析,然后将字转化为id, 标签转化为id,然后结构化到InputFeatures对象中    :param ex_index: index    :param example: 一个样本    :param label_list: 标签列表    :param max_seq_length:    :param tokenizer:    :param mode:    :return:    """    
label_map = {}    # 1表示从1开始对label进行index化    for (i, label) in enumerate(label_list, 1):        label_map[label] = i    # 保存label->index 的map    if not os.path.exists(os.path.join(model_dir, 'label2id.pkl')):        with codecs.open(os.path.join(model_dir, 'label2id.pkl'), 'wb') as w:            pickle.dump(label_map, w)    tokens = example    # tokens = tokenizer.tokenize(example.text)    # 序列截断    if len(tokens) >= max_seq_length - 1:        tokens = tokens[0:(max_seq_length - 2)]  # -2 的原因是因为序列需要加一个句首和句尾标志    ntokens = []    segment_ids = []    label_ids = []    ntokens.append("[CLS]")  # 句子开始设置CLS 标志    segment_ids.append(0)    # append("O") or append("[CLS]") not sure!    label_ids.append(label_map["[CLS]"])  # O OR CLS 没有任何影响,不过我觉得O 会减少标签个数,不过拒收和句尾使用不同的标志来标注,使用LCS 也没毛病    for i, token in enumerate(tokens):        ntokens.append(token)        segment_ids.append(0)        label_ids.append(0)    ntokens.append("[SEP]")  # 句尾添加[SEP] 标志    segment_ids.append(0)    # append("O") or append("[SEP]") not sure!    label_ids.append(label_map["[SEP]"])    input_ids = tokenizer.convert_tokens_to_ids(ntokens)  # 将序列中的字(ntokens)转化为ID形式    input_mask = [1] * len(input_ids)    # padding, 使用    while len(input_ids) < max_seq_length:        input_ids.append(0)        input_mask.append(0)        segment_ids.append(0)        # we don't concerned about it!        label_ids.append(0)        ntokens.append("**NULL**")        # label_mask.append(0)    # print(len(input_ids))    assert len(input_ids) == max_seq_length    assert len(input_mask) == max_seq_length    assert len(segment_ids) == max_seq_length    assert len(label_ids) == max_seq_length    # assert len(label_mask) == max_seq_length    # 结构化为一个类    feature = InputFeatures(        input_ids=input_ids,        input_mask=input_mask,        segment_ids=segment_ids,        label_ids=label_ids,        # label_mask = label_mask    )    return featureclass Pair(object):    def __init__(self, word, start, end, type, merge=False):        self.__word = word        self.__start = start        self.__end = end        self.__merge = merge        self.__types = type    @property    def start(self):        return self.__start    @property    def end(self):        return self.__end    @property    def merge(self):        return self.__merge    @property    def word(self):        return self.__word    @property    def types(self):        return self.__types    @word.setter    def word(self, word):        self.__word = word    @start.setter    def start(self, start):        self.__start = start    @end.setter    def end(self, end):        self.__end = end    @merge.setter    def merge(self, merge):        self.__merge = merge    @types.setter    def types(self, type):        self.__types = type    def __str__(self) -> str:        line = []        line.append('entity:{}'.format(self.__word))        line.append('start:{}'.format(self.__start))        line.append('end:{}'.format(self.__end))        line.append('merge:{}'.format(self.__merge))        line.append('types:{}'.format(self.__types))        return '\t'.join(line)class Result(object):    def __init__(self, labels_config):        self.others = []        self.labels_config = labels_config        self.labels = {}        for la in self.labels_config:            self.labels[la] = []    def get_result(self, tokens, tags):        # 先获取标注结果        self.result_to_json(tokens, tags)        return self.labels    def result_to_json(self, string, tags):        """        将模型标注序列和输入序列结合 转化为结果        :param string: 输入序列        
:param tags: 标注结果        :return:        """        item = {"entities": []}        entity_name = ""        entity_start = 0        idx = 0        last_tag = ''        for char, tag in zip(string, tags):            if tag[0] == "S":                self.append(char, idx, idx + 1, tag[2:])                item["entities"].append({"word": char, "start": idx, "end": idx + 1, "type": tag[2:]})            elif tag[0] == "B":                if entity_name != '':                    self.append(entity_name, entity_start, idx, last_tag[2:])                    item["entities"].append(                        {"word": entity_name, "start": entity_start, "end": idx, "type": last_tag[2:]})                    entity_name = ""                entity_name += char                entity_start = idx            elif tag[0] == "I":                entity_name += char            elif tag[0] == "O":                if entity_name != '':                    self.append(entity_name, entity_start, idx, last_tag[2:])                    item["entities"].append(                        {"word": entity_name, "start": entity_start, "end": idx, "type": last_tag[2:]})                    entity_name = ""            else:                entity_name = ""                entity_start = idx            idx += 1            last_tag = tag        if entity_name != '':            self.append(entity_name, entity_start, idx, last_tag[2:])            item["entities"].append({"word": entity_name, "start": entity_start, "end": idx, "type": last_tag[2:]})        return item    def append(self, word, start, end, tag):        if tag in self.labels_config:            self.labels[tag].append(Pair(word, start, end, tag))        else:            self.others.append(Pair(word, start, end, tag))def person_model_init():    return model_init("person")def model_init(model_name):    if os.name == 'nt':  # windows path config        model_dir = 'E:/quickstart/deeplearning/nlp_demo/%s/model' % model_name        bert_dir = 'E:/quickstart/deeplearning/nlp_demo/bert_model_info/chinese_L-12_H-768_A-12'    else:  # linux path config        model_dir = '/home/yjy/project/deeplearning/nlp_demo/%s/model' % model_name        bert_dir = '/home/yjy/project/deeplearning/nlp_demo/bert_model_info/chinese_L-12_H-768_A-12'    batch_size = 1    max_seq_length = 500    print('checkpoint path:{}'.format(os.path.join(model_dir, "checkpoint")))    if not os.path.exists(os.path.join(model_dir, "checkpoint")):        raise Exception("failed to get checkpoint. 
going to return ")    # 加载label->id的词典    with codecs.open(os.path.join(model_dir, 'label2id.pkl'), 'rb') as rf:        label2id = pickle.load(rf)        id2label = {value: key for key, value in label2id.items()}    with codecs.open(os.path.join(model_dir, 'label_list.pkl'), 'rb') as rf:        label_list = pickle.load(rf)    num_labels = len(label_list) + 1    gpu_config = tf.ConfigProto()    gpu_config.gpu_options.allow_growth = True    graph = tf.Graph()    sess = tf.Session(graph=graph, config=gpu_config)    with graph.as_default():        print("going to restore checkpoint")        # sess.run(tf.global_variables_initializer())        input_ids_p = tf.placeholder(tf.int32, [batch_size, max_seq_length], name="input_ids")        input_mask_p = tf.placeholder(tf.int32, [batch_size, max_seq_length], name="input_mask")        bert_config = modeling.BertConfig.from_json_file(os.path.join(bert_dir, 'bert_config.json'))        (total_loss, logits, trans, pred_ids) = create_model(            bert_config=bert_config, is_training=False, input_ids=input_ids_p, input_mask=input_mask_p,            segment_ids=None,            labels=None, num_labels=num_labels, use_one_hot_embeddings=False, dropout_rate=1.0)        saver = tf.train.Saver()        saver.restore(sess, tf.train.latest_checkpoint(model_dir))    tokenizer = tokenization.FullTokenizer(        vocab_file=os.path.join(bert_dir, 'vocab.txt'), do_lower_case=args.do_lower_case)    return model_dir, batch_size, id2label, label_list, graph, input_ids_p, input_mask_p, pred_ids, tokenizer, sess, max_seq_lengthif __name__ == "__main__":    _model_dir, _batch_size, _id2label, _label_list, _graph, _input_ids_p, _input_mask_p, _pred_ids, _tokenizer, _sess, _max_seq_length = person_model_init()    PERSON_LABELS = ["TIME", "LOCATION", "PERSON_NAME", "ORG_NAME", "COMPANY_NAME", "PRODUCT_NAME"]    while True:        print('input the test sentence:')        _sentence = str(input())        pred_rs, pred_label_result = predict(_sentence, ADDRESS_LABELS, _model_dir, _batch_size, _id2label, _label_list,                                             _graph,                                             _input_ids_p,                                             _input_mask_p, _pred_ids, _tokenizer, _sess, _max_seq_length)        pprint(pred_rs)

Writing the REST-style API

We will use Python's Flask framework to serve the REST API.

First, create a new Python project and put the following directories and files in the project root:

[Screenshot: project directory layout]

  • The bert_base directory and files and the bert_model_info directory and files can be found in the cloud-drive project shared in the previous article;
  • The model under the person directory is the named entity recognition model we trained in the previous article plus its companion files; you can get it from the project's output directory (see the directory sketch below).
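Roughly, the project root should look like this; the exact tree was in the screenshot above, and this sketch is reconstructed from the paths referenced in the code, with the root directory name being illustrative:

nlp_demo/
├── bert_base/                          # from the cloud-drive project (previous article)
├── bert_model_info/
│   └── chinese_L-12_H-768_A-12/        # pretrained Chinese BERT (config, vocab, checkpoint)
├── person/
│   └── model/                          # trained NER checkpoint, label2id.pkl, label_list.pkl
├── static/                             # static files served by Flask (e.g. index.html)
├── entity_extractor.py
├── flaskr.py
├── nlp_config.py
├── nlp_main.py
├── person_ner_resource.py
└── requirements.txt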
Next, create the startup file nlp_main.py with the following content:
# -*- coding: utf-8 -*-"""flask 入口"""import osimport nlp_config as ncfrom flaskr import create_app, loadProjContext__author__ = '程序员一一涤生'from flask import jsonify, make_response, redirect# 加载flask配置信息# app = create_app('config.DevelopmentConfig')app = create_app(nc.config['default'])# 加载项目上下文信息loadProjContext()@app.errorhandler(404)def not_found(error):    return make_response(jsonify({'error': 'Not found'}), 404)@app.errorhandler(400)def not_found(error):    return make_response(jsonify({'error': '400 Bad Request,参数或参数内容异常'}), 400)@app.route('/')def index_sf():    # return render_template('index.html')    return redirect('index.html')if __name__ == '__main__':    app.run('localhost', 5006, app, use_reloader=False)
Then create flaskr.py, the project's Flask initialization file, which presets and loads information when the project starts:
# -*- coding: utf-8 -*-"""flask初始化"""from logging.config import dictConfigfrom flask import Flaskfrom flask_cors import CORSimport address_ner_resourceimport person_ner_resourcefrom address_ner_resource import addressfrom entity_extractor import address_model_init, person_model_initfrom person_ner_resource import person__author__ = '程序员一一涤生'def create_app(config_type):    dictConfig({        'version': 1,        'formatters': {'default': {            'format': '[%(asctime)s] %(name)s %(levelname)s in %(module)s %(lineno)d: %(message)s',        }},        'handlers': {'wsgi': {            'class': 'logging.StreamHandler',            'stream': 'ext://flask.logging.wsgi_errors_stream',            'formatter': 'default'        }},        'root': {            'level': 'DEBUG',            # 'level': 'WARN',            # 'level': 'INFO',            'handlers': ['wsgi']        }    })    # 加载flask配置信息    app = Flask(__name__, static_folder='static', static_url_path='')    # CORS(app, resources=r'/*',origins=['192.168.1.104'])  # r'/*' 是通配符,允许跨域请求本服务器所有的URL,"origins": '*'表示允许所有ip跨域访问本服务器的url    CORS(app, resources={r"/*": {"origins": '*'}})  # r'/*' 是通配符,允许跨域请求本服务器所有的URL,"origins": '*'表示允许所有ip跨域访问本服务器的url    app.config.from_object(config_type)    app.register_blueprint(person, url_prefix='/person')    # 初始化上下文    ctx = app.app_context()    ctx.push()    return appdef loadProjContext():    # 加载人名提取模型    model_dir, batch_size, id2label, label_list, graph, input_ids_p, input_mask_p, pred_ids, tokenizer, sess, max_seq_length = person_model_init()    person_ner_resource.model_dir = model_dir    person_ner_resource.batch_size = batch_size    person_ner_resource.id2label = id2label    person_ner_resource.label_list = label_list    person_ner_resource.graph = graph    person_ner_resource.input_ids_p = input_ids_p    person_ner_resource.input_mask_p = input_mask_p    person_ner_resource.pred_ids = pred_ids    person_ner_resource.tokenizer = tokenizer    person_ner_resource.sess = sess    person_ner_resource.max_seq_length = max_seq_length
Then create the configuration file nlp_config.py, used to switch between production, development, and testing environments:
# -*- coding: utf-8 -*-"""本模块是Flask的配置模块"""import os__author__ = '程序员一一涤生'basedir = os.path.abspath(os.path.dirname(__file__))class BaseConfig:  # 基本配置类    SECRET_KEY = b'\xe4r\x04\xb5\xb2\x00\xf1\xadf\xa3\xf3V\x03\xc5\x9f\x82$^\xa25O\xf0R\xda'    JSONIFY_MIMETYPE = 'application/json; charset=utf-8'  # 默认JSONIFY_MIMETYPE的配置是不带'; charset=utf-8的'    JSON_AS_ASCII = False  # 若不关闭,使用JSONIFY返回json时中文会显示为Unicode字符    ENCODING = 'utf-8'    # 自定义的配置项    ADDRESS_LABELS = ["COUNTY", "STREET", "COMMUNITY", "ROAD", "NUM", "POI", "CITY", "VILLAGE"]    PERSON_LABELS = ["TIME", "LOCATION", "PERSON_NAME", "ORG_NAME", "COMPANY_NAME", "PRODUCT_NAME"]class DevelopmentConfig(BaseConfig):    ENV = 'development'    DEBUG = Trueclass TestingConfig(BaseConfig):    TESTING = True    WTF_CSRF_ENABLED = Falseclass ProductionConfig(BaseConfig):    DEBUG = Falseconfig = {    'testing': TestingConfig,    'default': DevelopmentConfig    # 'default': ProductionConfig}
Next, create the person-name recognition API file person_ner_resource.py:
# -*- coding: utf-8 -*-"""命名实体识别接口"""from entity_extractor import predict__author__ = '程序员一一涤生'from flask import Blueprint, make_response, request, current_appfrom flask import jsonifyperson = Blueprint('person', __name__)model_dir, batch_size, id2label, label_list, graph, input_ids_p, input_mask_p, pred_ids, tokenizer, sess, max_seq_length = None, None, None, None, None, None, None, None, None, None, None@person.route('/extract', methods=['POST'])def extract():    params = request.get_json()    if 't' not in params or params['t'] is None or len(params['t']) > 500 or len(params['t']) < 2:        return make_response(jsonify({'error': '文本长度不符合要求,长度限制:2~500'}), 400)    sentence = params['t']    # 成句    sentence = sentence + "。" if not sentence.endswith((",", "。", "!", "?")) else sentence    # 利用模型提取    pred_rs, pred_label_result = predict(sentence, current_app.config['PERSON_LABELS'], model_dir, batch_size, id2label,                                         label_list, graph, input_ids_p,                                         input_mask_p,                                         pred_ids, tokenizer, sess, max_seq_length)    print(sentence)    return jsonify(pred_rs)if __name__ == '__main__':    pass
Next, place requirements.txt in the project root with the following content:
absl-py==0.7.0
astor==0.7.1
backcall==0.1.0
backports.weakref==1.0rc1
bleach==1.5.0
certifi==2016.2.28
click==6.7
colorama==0.4.1
colorful==0.5.0
decorator==4.3.2
defusedxml==0.5.0
entrypoints==0.3
Flask==1.0.2
Flask-Cors==3.0.3
gast==0.2.2
grpcio==1.18.0
h5py==2.9.0
html5lib==0.9999999
ipykernel==5.1.0
ipython==7.2.0
ipython-genutils==0.2.0
ipywidgets==7.4.2
itsdangerous==0.24
jedi==0.13.2
Jinja2==2.10
jsonschema==2.6.0
jupyter==1.0.0
jupyter-client==5.2.4
jupyter-console==6.0.0
jupyter-core==4.4.0
Keras-Applications==1.0.6
Keras-Preprocessing==1.0.5
Markdown==3.0.1
MarkupSafe==1.1.0
mistune==0.8.4
mock==3.0.5
nbconvert==5.4.0
nbformat==4.4.0
notebook==5.7.4
numpy==1.16.0
pandocfilters==1.4.2
parso==0.3.2
pickleshare==0.7.5
prettyprinter==0.17.0
prometheus-client==0.5.0
prompt-toolkit==2.0.8
protobuf==3.6.1
Pygments==2.3.1
python-dateutil==2.7.5
pywinpty==0.5.5
pyzmq==17.1.2
qtconsole==4.4.3
Send2Trash==1.5.0
six==1.12.0
tensorboard==1.13.1
tensorflow==1.13.1
tensorflow-estimator==1.13.0
termcolor==1.1.0
terminado==0.8.1
testpath==0.4.2
tornado==5.1.1
traitlets==4.3.2
wcwidth==0.1.7
Werkzeug==0.14.1
widgetsnbextension==3.4.2
wincertstore==0.2
Then run the following command to install the packages listed in requirements.txt:
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple -r requirements.txt

Once these steps are done, we can try starting the project.

Starting the project

Run the following command to start the Flask project:

python nlp_main.py

Calling the API

This article uses Postman to call the named entity extraction endpoint; the address is:

http://localhost:5006/person/extract
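Besides Postman, you can exercise the endpoint from Python. A minimal sketch, assuming the requests package (not in requirements.txt); the sentence and the printed result are illustrative:

import requests

# the endpoint expects a JSON body with the text under key 't' (length 2~500)
resp = requests.post('http://localhost:5006/person/extract',
                     json={'t': '昨天马云在杭州发布了新产品'})
print(resp.json())
# e.g. {'TIME': ['昨天'], 'PERSON_NAME': ['马云'], 'LOCATION': ['杭州'], ...}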

Sample results:

[Screenshots: Postman request and the extracted entities in the JSON response]

Note: on a CPU, one prediction takes roughly 2 to 3 seconds. If the project is deployed on a machine with a GPU that supports deep learning, the API responds much faster; in that case, don't forget to install tensorflow-gpu instead of tensorflow.
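For example, using the same mirror as above and matching the TensorFlow version pinned in requirements.txt (this assumes a CUDA setup compatible with TensorFlow 1.13):

pip install -i https://pypi.tuna.tsinghua.edu.cn/simple tensorflow-gpu==1.13.1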

At this point, we have used deep learning to build a project that extracts person names, addresses, organizations, companies, products, and time expressions from natural language. Starting with the next article, we will look at the algorithms this project relies on, BERT and CRF; understanding them will make it much clearer why the model can accurately extract the entities we want from a sentence.

OK, that's all for this article. Thanks for reading O(∩_∩)O, see you next time~

This content comes from the WeChat official account "程序员一一涤生"; you're welcome to follow it o(∩_∩)o


Reposted from: https://www.cnblogs.com/anai/p/11571812.html
