Unverified commit f8e6172b, authored by T TeslaZhao, committed by GitHub

Merge pull request #1773 from w5688414/ernie

Add Ernie Indexing Example
# In-batch Negatives
**Table of Contents**
* [Model Download](#模型下载)
* [Model Deployment](#模型部署)
<a name="模型下载"></a>
## 1. Semantic Indexing Model
**Download link for the trained semantic indexing model:**
The model architecture parameters are: `TransformerLayer: 12, Hidden: 768, Heads: 12, OutputEmbSize: 256`
|Model|Training Configuration|Hardware|MD5|
| ------------ | ------------ | ------------ |-----------|
|[batch_neg](https://bj.bcebos.com/v1/paddlenlp/models/inbatch_model.zip)|<div style="width: 150pt">margin:0.2 scale:30 epoch:3 lr:5E-5 bs:64 max_len:64 </div>|<div style="width: 100pt">4x V100-16G</div>|f3e5c7d7b0b718c2530c5e1b136b2d74|
```
wget https://bj.bcebos.com/v1/paddlenlp/models/inbatch_model.zip
unzip inbatch_model.zip -d checkpoints
```
<a name="模型部署"></a>
## 2. Model Deployment
### 2.1 Export Dynamic Graph to Static Graph
First, convert the dynamic-graph model to a static graph:
```
python export_model.py --params_path checkpoints/model_40/model_state.pdparams --output_path=./output
```
Alternatively, run the bash script below:
```
sh scripts/export_model.sh
```
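Under the hood, `export_model.py` (included later in this example) loads the checkpoint into the model defined in `base_model.py`, switches it to eval mode, and saves a static graph with `paddle.jit.save`. A condensed sketch of that flow, assuming the checkpoint path from the command above:
```
import paddle
import paddlenlp as ppnlp
from base_model import SemanticIndexBaseStatic

# Build the model with the 256-dim output head used by the released checkpoint.
pretrained_model = ppnlp.transformers.ErnieModel.from_pretrained("ernie-1.0")
model = SemanticIndexBaseStatic(pretrained_model, output_emb_size=256)

# Load the trained parameters and switch to inference mode.
model.set_dict(paddle.load("checkpoints/model_40/model_state.pdparams"))
model.eval()

# Trace with dynamic batch/sequence dimensions, then save the static graph.
model = paddle.jit.to_static(
    model,
    input_spec=[
        paddle.static.InputSpec(shape=[None, None], dtype="int64"),  # input_ids
        paddle.static.InputSpec(shape=[None, None], dtype="int64"),  # token_type_ids
    ])
paddle.jit.save(model, "./output/inference")
```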
### 2.2 Prediction with Paddle Inference
Prediction can either extract embedding vectors or compute the similarity of two texts.
Modify the samples in `id2corpus` and `corpus_list`:
```
# Extract embeddings
id2corpus={0:'国有企业引入非国有资本对创新绩效的影响——基于制造业国有上市公司的经验证据'}
# Compute similarity
corpus_list=[['中西方语言与文化的差异','中西方文化差异以及语言体现中西方文化,差异,语言体现'],
['中西方语言与文化的差异','飞桨致力于让深度学习技术的创新与应用更简单']]
```
Then run prediction with Paddle Inference:
```
python deploy/python/predict.py --model_dir=./output
```
Alternatively, run the bash script below:
```
sh deploy.sh
```
The final output is the 256-dimensional feature vector and the similarity score of each sentence pair:
```
(1, 256)
[[-0.0394925 -0.04474756 -0.065534 0.00939134 0.04359895 0.14659195
-0.0091779 -0.07303623 0.09413272 -0.01255222 -0.08685658 0.02762237
0.10138468 0.00962821 0.10888419 0.04553023 0.05898942 0.00694253
....
[0.959269642829895, 0.04725276678800583]
```
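The pair scores are cosine similarities between the two L2-normalized embeddings; `deploy/python/predict.py` computes them with `scipy.spatial`. A minimal sketch of the same computation, where `query_emb` and `title_emb` stand in for two embedding rows returned by the predictor:
```
import numpy as np
from scipy import spatial

def pair_score(query_emb, title_emb):
    # 1 - cosine distance equals the cosine similarity of the two vectors.
    return float(1 - spatial.distance.cosine(query_emb, title_emb))

# Placeholder vectors; in practice these are the 256-dim embeddings extracted above.
query_emb = np.random.rand(256)
title_emb = np.random.rand(256)
print(pair_score(query_emb, title_emb))
```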
### 2.3 Deployment with Paddle Serving
For detailed Paddle Serving documentation, see [Pipeline_Design](https://github.com/PaddlePaddle/Serving/blob/v0.7.0/doc/Python_Pipeline/Pipeline_Design_CN.md) and [Serving_Design](https://github.com/PaddlePaddle/Serving/blob/v0.7.0/doc/Serving_Design_CN.md). First, convert the static graph model into the Serving format:
```
python export_to_serving.py \
--dirname "output" \
--model_filename "inference.get_pooled_embedding.pdmodel" \
--params_filename "inference.get_pooled_embedding.pdiparams" \
--server_path "./serving_server" \
--client_path "./serving_client" \
--fetch_alias_names "output_embedding"
```
Parameter descriptions:
* `dirname`: Directory of the model files to be converted; both the Program structure file and the parameter files are stored here.
* `model_filename`: Name of the file that stores the Inference Program structure of the model to be converted. If set to None, `__model__` is used as the default file name.
* `params_filename`: Name of the file that stores all parameters of the model to be converted. It only needs to be specified when all model parameters are saved in a single binary file. If the parameters are stored in separate files, set it to None.
* `server_path`: Output path for the converted server-side model and configuration files. Defaults to serving_server.
* `client_path`: Output path for the converted client-side configuration files. Defaults to serving_client.
* `fetch_alias_names`: Aliases for the model outputs; outputs such as output_embedding can be given other names. Not set by default.
* `feed_alias_names`: Aliases for the model inputs; inputs such as input_ids can be given other names. Not set by default.
Alternatively, run the bash script below:
```
sh scripts/export_to_serving.sh
```
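The conversion script is essentially a thin wrapper around `paddle_serving_client.io.inference_model_to_serving` (see `export_to_serving.py` later in this example); a minimal sketch of the equivalent direct call, with the same arguments as above:
```
import paddle_serving_client.io as serving_io

serving_io.inference_model_to_serving(
    dirname="output",
    serving_server="./serving_server",
    serving_client="./serving_client",
    model_filename="inference.get_pooled_embedding.pdmodel",
    params_filename="inference.get_pooled_embedding.pdiparams",
    fetch_alias_names="output_embedding")
```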
Paddle Serving can be deployed in two ways: the Pipeline mode and the C++ mode. Both are described below:
#### 2.3.1 Pipeline Mode
Start the Pipeline Server:
```
python web_service.py
```
Then start the client to call the server.
First, modify the samples to be predicted in rpc_client.py:
```
list_data = [
"国有企业引入非国有资本对创新绩效的影响——基于制造业国有上市公司的经验证据",
"试论翻译过程中的文化差异与语言空缺翻译过程,文化差异,语言空缺,文化对比"
]
```
Then run:
```
python rpc_client.py
```
The model output is:
```
{'0': '国有企业引入非国有资本对创新绩效的影响——基于制造业国有上市公司的经验证据', '1': '试论翻译过程中的文化差异与语言空缺翻译过程,文化差异,语言空缺,文化对比'}
PipelineClient::predict pack_data time:1641450851.3752182
PipelineClient::predict before time:1641450851.375738
['output_embedding']
(2, 256)
[[ 0.07830612 -0.14036864 0.03433796 -0.14967982 -0.03386067 0.06630666
0.01357943 0.03531194 0.02411093 0.02000859 0.05724002 -0.08119463
......
```
As shown, the client sent 2 texts and received 2 embedding vectors in return.
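The pipeline server returns the embeddings as a stringified Python list; `rpc_client.py` converts them back into a numpy array roughly as follows (a condensed sketch of the client code shown later in this example):
```
import numpy as np
from paddle_serving_server.pipeline import PipelineClient

client = PipelineClient()
client.connect(['127.0.0.1:8088'])

# Each sample is keyed by its index in the feed dict.
feed = {"0": "国有企业引入非国有资本对创新绩效的影响——基于制造业国有上市公司的经验证据"}
ret = client.predict(feed_dict=feed)

# The returned value is a stringified list; eval() restores it as nested lists.
embeddings = np.array(eval(ret.value[0]))
print(embeddings.shape)  # (1, 256) for a single input text
```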
#### 2.3.2 C++ Mode
Start the C++ Serving server:
```
python -m paddle_serving_server.serve --model serving_server --port 9393 --gpu_id 2 --thread 5 --ir_optim True --use_trt --precision FP16
```
Or use the script:
```
sh deploy/C++/start_server.sh
```
The client can use either HTTP or RPC. For RPC:
```
python deploy/C++/rpc_client.py
```
The output is:
```
I0209 20:40:07.978225 20896 general_model.cpp:490] [client]logid=0,client_cost=395.695ms,server_cost=392.559ms.
time to cost :0.3960278034210205 seconds
{'output_embedding': array([[ 9.01343748e-02, -1.21870913e-01, 1.32834800e-02,
-1.57673359e-01, -2.60387752e-02, 6.98455423e-02,
1.58108603e-02, 3.89952064e-02, 3.22783105e-02,
3.49135026e-02, 7.66086206e-02, -9.12970975e-02,
6.25643134e-02, 7.21886680e-02, 7.03565404e-02,
5.44054210e-02, 3.25332815e-03, 5.01751155e-02,
......
```
As shown, the server returns the embedding vectors.
Alternatively, access the server with the HTTP client:
```
python deploy/C++/http_client.py
```
The output is:
```
(2, 64)
(2, 64)
outputs {
tensor {
float_data: 0.09013437479734421
float_data: -0.12187091261148453
float_data: 0.01328347995877266
float_data: -0.15767335891723633
......
```
As shown, the server returns the embedding vectors.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import sys
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
class SemanticIndexBase(nn.Layer):
def __init__(self, pretrained_model, dropout=None, output_emb_size=None):
super().__init__()
self.ptm = pretrained_model
self.dropout = nn.Dropout(dropout if dropout is not None else 0.1)
        # If output_emb_size is not None, add a Linear layer to reduce the embedding size.
        # We recommend setting output_emb_size = 256 considering the trade-off between
        # recall performance and efficiency.
self.output_emb_size = output_emb_size
if output_emb_size > 0:
weight_attr = paddle.ParamAttr(
initializer=paddle.nn.initializer.TruncatedNormal(std=0.02))
self.emb_reduce_linear = paddle.nn.Linear(
768, output_emb_size, weight_attr=weight_attr)
@paddle.jit.to_static(input_spec=[
paddle.static.InputSpec(
shape=[None, None], dtype='int64'), paddle.static.InputSpec(
shape=[None, None], dtype='int64')
])
def get_pooled_embedding(self,
input_ids,
token_type_ids=None,
position_ids=None,
attention_mask=None):
_, cls_embedding = self.ptm(input_ids, token_type_ids, position_ids,
attention_mask)
if self.output_emb_size > 0:
cls_embedding = self.emb_reduce_linear(cls_embedding)
cls_embedding = self.dropout(cls_embedding)
cls_embedding = F.normalize(cls_embedding, p=2, axis=-1)
return cls_embedding
def get_semantic_embedding(self, data_loader):
self.eval()
with paddle.no_grad():
for batch_data in data_loader:
input_ids, token_type_ids = batch_data
input_ids = paddle.to_tensor(input_ids)
token_type_ids = paddle.to_tensor(token_type_ids)
text_embeddings = self.get_pooled_embedding(
input_ids, token_type_ids=token_type_ids)
yield text_embeddings
def cosine_sim(self,
query_input_ids,
title_input_ids,
query_token_type_ids=None,
query_position_ids=None,
query_attention_mask=None,
title_token_type_ids=None,
title_position_ids=None,
title_attention_mask=None):
query_cls_embedding = self.get_pooled_embedding(
query_input_ids, query_token_type_ids, query_position_ids,
query_attention_mask)
title_cls_embedding = self.get_pooled_embedding(
title_input_ids, title_token_type_ids, title_position_ids,
title_attention_mask)
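        # Both embeddings are L2-normalized in get_pooled_embedding, so this
        # element-wise product summed over the last axis is the cosine similarity.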
cosine_sim = paddle.sum(query_cls_embedding * title_cls_embedding,
axis=-1)
return cosine_sim
@abc.abstractmethod
def forward(self):
pass
class SemanticIndexBaseStatic(nn.Layer):
def __init__(self, pretrained_model, dropout=None, output_emb_size=None):
super().__init__()
self.ptm = pretrained_model
self.dropout = nn.Dropout(dropout if dropout is not None else 0.1)
        # If output_emb_size is not None, add a Linear layer to reduce the embedding size.
        # We recommend setting output_emb_size = 256 considering the trade-off between
        # recall performance and efficiency.
self.output_emb_size = output_emb_size
if output_emb_size > 0:
weight_attr = paddle.ParamAttr(
initializer=paddle.nn.initializer.TruncatedNormal(std=0.02))
self.emb_reduce_linear = paddle.nn.Linear(
768, output_emb_size, weight_attr=weight_attr)
@paddle.jit.to_static(input_spec=[
paddle.static.InputSpec(
shape=[None, None], dtype='int64'), paddle.static.InputSpec(
shape=[None, None], dtype='int64')
])
def get_pooled_embedding(self,
input_ids,
token_type_ids=None,
position_ids=None,
attention_mask=None):
_, cls_embedding = self.ptm(input_ids, token_type_ids, position_ids,
attention_mask)
if self.output_emb_size > 0:
cls_embedding = self.emb_reduce_linear(cls_embedding)
cls_embedding = self.dropout(cls_embedding)
cls_embedding = F.normalize(cls_embedding, p=2, axis=-1)
return cls_embedding
def get_semantic_embedding(self, data_loader):
self.eval()
with paddle.no_grad():
for batch_data in data_loader:
input_ids, token_type_ids = batch_data
input_ids = paddle.to_tensor(input_ids)
token_type_ids = paddle.to_tensor(token_type_ids)
text_embeddings = self.get_pooled_embedding(
input_ids, token_type_ids=token_type_ids)
yield text_embeddings
def cosine_sim(self,
query_input_ids,
title_input_ids,
query_token_type_ids=None,
query_position_ids=None,
query_attention_mask=None,
title_token_type_ids=None,
title_position_ids=None,
title_attention_mask=None):
query_cls_embedding = self.get_pooled_embedding(
query_input_ids, query_token_type_ids, query_position_ids,
query_attention_mask)
title_cls_embedding = self.get_pooled_embedding(
title_input_ids, title_token_type_ids, title_position_ids,
title_attention_mask)
cosine_sim = paddle.sum(query_cls_embedding * title_cls_embedding,
axis=-1)
return cosine_sim
def forward(self,
input_ids,
token_type_ids=None,
position_ids=None,
attention_mask=None):
_, cls_embedding = self.ptm(input_ids, token_type_ids, position_ids,
attention_mask)
if self.output_emb_size > 0:
cls_embedding = self.emb_reduce_linear(cls_embedding)
cls_embedding = self.dropout(cls_embedding)
cls_embedding = F.normalize(cls_embedding, p=2, axis=-1)
return cls_embedding
# coding:utf-8
# pylint: disable=doc-string-missing
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import time
import numpy as np
import requests
import json
from paddle_serving_client import HttpClient
import paddlenlp as ppnlp
def convert_example(example,
tokenizer,
max_seq_length=512,
pad_to_max_seq_len=True):
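    # Tokenize each text, padding to max_seq_length, and collect input_ids / token_type_ids.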
list_input_ids = []
list_token_type_ids = []
for text in example:
encoded_inputs = tokenizer(
text=text,
max_seq_len=max_seq_length,
pad_to_max_seq_len=pad_to_max_seq_len)
input_ids = encoded_inputs["input_ids"]
token_type_ids = encoded_inputs["token_type_ids"]
list_input_ids.append(input_ids)
list_token_type_ids.append(token_type_ids)
return list_input_ids, list_token_type_ids
# Start the Python HTTP client
endpoint_list = ['127.0.0.1:9393']
client = HttpClient()
client.load_client_config('serving_client')
client.connect(endpoint_list)
feed_names = client.feed_names_
fetch_names = client.fetch_names_
print(feed_names)
print(fetch_names)
# Create the tokenizer
tokenizer = ppnlp.transformers.ErnieTokenizer.from_pretrained('ernie-1.0')
max_seq_len = 64
# Preprocess the data
list_data = ['国有企业引入非国有资本对创新绩效的影响——基于制造业国有上市公司的经验证据.', '面向生态系统服务的生态系统分类方案研发与应用']
# for i in range(5):
# list_data.extend(list_data)
# print(len(list_data))
examples = convert_example(list_data, tokenizer, max_seq_length=max_seq_len)
print(examples)
feed_dict = {}
feed_dict['input_ids'] = np.array(examples[0])
feed_dict['token_type_ids'] = np.array(examples[1])
print(feed_dict['input_ids'].shape)
print(feed_dict['token_type_ids'].shape)
# batch=True means batched prediction
b_start = time.time()
result = client.predict(feed=feed_dict, fetch=fetch_names, batch=True)
b_end = time.time()
print(result)
print("time to cost :{} seconds".format(b_end - b_start))
# coding:utf-8
# pylint: disable=doc-string-missing
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import time
import numpy as np
from paddle_serving_client import Client
import paddlenlp as ppnlp
def convert_example(example,
tokenizer,
max_seq_length=512,
pad_to_max_seq_len=True):
list_input_ids = []
list_token_type_ids = []
for text in example:
encoded_inputs = tokenizer(
text=text,
max_seq_len=max_seq_length,
pad_to_max_seq_len=pad_to_max_seq_len)
input_ids = encoded_inputs["input_ids"]
token_type_ids = encoded_inputs["token_type_ids"]
list_input_ids.append(input_ids)
list_token_type_ids.append(token_type_ids)
return list_input_ids, list_token_type_ids
# Start the Python RPC client
endpoint_list = ['127.0.0.1:9393']
client = Client()
client.load_client_config('serving_client')
client.connect(endpoint_list)
feed_names = client.feed_names_
fetch_names = client.fetch_names_
print(feed_names)
print(fetch_names)
# Create the tokenizer
tokenizer = ppnlp.transformers.ErnieTokenizer.from_pretrained('ernie-1.0')
max_seq_len = 64
# Preprocess the data
list_data = ['国有企业引入非国有资本对创新绩效的影响——基于制造业国有上市公司的经验证据.', '面向生态系统服务的生态系统分类方案研发与应用']
# for i in range(5):
# list_data.extend(list_data)
# print(len(list_data))
examples = convert_example(list_data, tokenizer, max_seq_length=max_seq_len)
print(examples)
feed_dict = {}
feed_dict['input_ids'] = np.array(examples[0])
feed_dict['token_type_ids'] = np.array(examples[1])
print(feed_dict['input_ids'].shape)
print(feed_dict['token_type_ids'].shape)
# batch=True means batched prediction
b_start = time.time()
result = client.predict(feed=feed_dict, fetch=fetch_names, batch=True)
b_end = time.time()
print("time to cost :{} seconds".format(b_end - b_start))
print(result)
python -m paddle_serving_server.serve --model serving_server --port 9393 --gpu_id 2 --thread 5 --ir_optim True --use_trt --precision FP16
# worker_num: maximum concurrency. When build_dag_each_worker=True, the framework creates worker_num processes, each with its own gRPC server and DAG.
# When build_dag_each_worker=False, the framework sets max_workers=worker_num for the main thread's gRPC thread pool.
worker_num: 20
# build_dag_each_worker: when False, one DAG is built inside the process; when True, an independent DAG is built in each worker process.
build_dag_each_worker: false
dag:
    # Op resource type: True uses the thread model; False uses the process model.
is_thread_op: False
    # Profiling: True generates Timeline performance data (with some overhead); False disables it.
tracer:
interval_s: 10
# HTTP port. rpc_port and http_port must not both be empty. When rpc_port is available and http_port is empty, no http_port is generated automatically.
http_port: 18082
# RPC port. rpc_port and http_port must not both be empty. When rpc_port is empty and http_port is not, rpc_port is automatically set to http_port+1.
rpc_port: 8088
op:
ernie:
        # Concurrency: thread-level concurrency when is_thread_op=True, otherwise process-level concurrency.
concurrency: 1
        # When the op config has no server_endpoints, the local service config is read from local_service_conf.
local_service_conf:
            # Client type: brpc, grpc or local_predictor. local_predictor does not start a Serving service; prediction runs in-process.
client_type: local_predictor
#ir_optim
ir_optim: True
# device_type, 0=cpu, 1=gpu, 2=tensorRT, 3=arm cpu, 4=kunlun xpu
device_type: 1
            # Compute device IDs: when devices is "" or unset, prediction runs on CPU; when devices is "0" or "0,1,2", prediction runs on the listed GPU cards.
devices: '2'
            # Fetch list, using the alias_name of fetch_var in client_config; if unset, all outputs are returned.
fetch_list: ['output_embedding']
            # Model path
model_config: ../../serving_server/
python predict.py --model_dir=../../output
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import sys
import numpy as np
import paddle
import paddlenlp as ppnlp
from scipy.special import softmax
from scipy import spatial
from paddle import inference
from paddlenlp.data import Stack, Tuple, Pad
from paddlenlp.datasets import load_dataset
from paddlenlp.utils.log import logger
sys.path.append('.')
# yapf: disable
parser = argparse.ArgumentParser()
parser.add_argument("--model_dir", type=str, required=True,
help="The directory to static model.")
parser.add_argument("--max_seq_length", default=128, type=int,
help="The maximum total input sequence length after tokenization. Sequences "
"longer than this will be truncated, sequences shorter will be padded.")
parser.add_argument("--batch_size", default=15, type=int,
help="Batch size per GPU/CPU for training.")
parser.add_argument('--device', choices=['cpu', 'gpu', 'xpu'], default="gpu",
help="Select which device to train model, defaults to gpu.")
parser.add_argument('--use_tensorrt', default=False, type=eval, choices=[True, False],
help='Enable to use tensorrt to speed up.')
parser.add_argument("--precision", default="fp32", type=str, choices=["fp32", "fp16", "int8"],
help='The tensorrt precision.')
parser.add_argument('--cpu_threads', default=10, type=int,
help='Number of threads to predict when using cpu.')
parser.add_argument('--enable_mkldnn', default=False, type=eval, choices=[True, False],
help='Enable to use mkldnn to speed up when using cpu.')
parser.add_argument("--benchmark", type=eval, default=False,
help="To log some information about environment and running.")
parser.add_argument("--save_log_path", type=str, default="./log_output/",
help="The file path to save log.")
args = parser.parse_args()
# yapf: enable
def convert_example(example,
tokenizer,
max_seq_length=512,
pad_to_max_seq_len=False):
"""
Builds model inputs from a sequence.
A BERT sequence has the following format:
- single sequence: ``[CLS] X [SEP]``
Args:
        example(obj:`dict`): A dict mapping a text id to the text to be converted to ids.
tokenizer(obj:`PretrainedTokenizer`): This tokenizer inherits from :class:`~paddlenlp.transformers.PretrainedTokenizer`
which contains most of the methods. Users should refer to the superclass for more information regarding methods.
        max_seq_length(obj:`int`): The maximum total input sequence length after tokenization.
Sequences longer than this will be truncated, sequences shorter will be padded.
Returns:
input_ids(obj:`list[int]`): The list of query token ids.
token_type_ids(obj: `list[int]`): List of query sequence pair mask.
"""
result = []
for key, text in example.items():
encoded_inputs = tokenizer(
text=text,
max_seq_len=max_seq_length,
pad_to_max_seq_len=pad_to_max_seq_len)
input_ids = encoded_inputs["input_ids"]
token_type_ids = encoded_inputs["token_type_ids"]
result += [input_ids, token_type_ids]
return result
class Predictor(object):
def __init__(self,
model_dir,
device="gpu",
max_seq_length=128,
batch_size=32,
use_tensorrt=False,
precision="fp32",
cpu_threads=10,
enable_mkldnn=False):
self.max_seq_length = max_seq_length
self.batch_size = batch_size
model_file = model_dir + "/inference.pdmodel"
params_file = model_dir + "/inference.pdiparams"
if not os.path.exists(model_file):
raise ValueError("not find model file path {}".format(model_file))
if not os.path.exists(params_file):
raise ValueError("not find params file path {}".format(params_file))
config = paddle.inference.Config(model_file, params_file)
if device == "gpu":
# set GPU configs accordingly
            # such as initialize the gpu memory, enable tensorrt
config.enable_use_gpu(100, 0)
precision_map = {
"fp16": inference.PrecisionType.Half,
"fp32": inference.PrecisionType.Float32,
"int8": inference.PrecisionType.Int8
}
precision_mode = precision_map[precision]
if args.use_tensorrt:
config.enable_tensorrt_engine(
max_batch_size=batch_size,
min_subgraph_size=30,
precision_mode=precision_mode)
elif device == "cpu":
# set CPU configs accordingly,
# such as enable_mkldnn, set_cpu_math_library_num_threads
config.disable_gpu()
if args.enable_mkldnn:
# cache 10 different shapes for mkldnn to avoid memory leak
config.set_mkldnn_cache_capacity(10)
config.enable_mkldnn()
config.set_cpu_math_library_num_threads(args.cpu_threads)
elif device == "xpu":
# set XPU configs accordingly
config.enable_xpu(100)
config.switch_use_feed_fetch_ops(False)
self.predictor = paddle.inference.create_predictor(config)
self.input_handles = [
self.predictor.get_input_handle(name)
for name in self.predictor.get_input_names()
]
self.output_handle = self.predictor.get_output_handle(
self.predictor.get_output_names()[0])
if args.benchmark:
import auto_log
pid = os.getpid()
self.autolog = auto_log.AutoLogger(
model_name="ernie-1.0",
model_precision=precision,
batch_size=self.batch_size,
data_shape="dynamic",
save_path=args.save_log_path,
inference_config=config,
pids=pid,
process_name=None,
gpu_ids=0,
time_keys=[
'preprocess_time', 'inference_time', 'postprocess_time'
],
warmup=0,
logger=logger)
def extract_embedding(self, data, tokenizer):
"""
        Extracts the feature embeddings of the input texts.
Args:
            data (obj:`list[dict]`): The batch data; each element is a dict mapping an id to a raw text.
tokenizer(obj:`PretrainedTokenizer`): This tokenizer inherits from :class:`~paddlenlp.transformers.PretrainedTokenizer`
which contains most of the methods. Users should refer to the superclass for more information regarding methods.
Returns:
            results(obj:`numpy.ndarray`): The feature vectors of the input texts.
"""
if args.benchmark:
self.autolog.times.start()
examples = []
for text in data:
input_ids, segment_ids = convert_example(text, tokenizer)
examples.append((input_ids, segment_ids))
batchify_fn = lambda samples, fn=Tuple(
Pad(axis=0, pad_val=tokenizer.pad_token_id), # input
Pad(axis=0, pad_val=tokenizer.pad_token_id), # segment
): fn(samples)
if args.benchmark:
self.autolog.times.stamp()
input_ids, segment_ids = batchify_fn(examples)
self.input_handles[0].copy_from_cpu(input_ids)
self.input_handles[1].copy_from_cpu(segment_ids)
self.predictor.run()
logits = self.output_handle.copy_to_cpu()
if args.benchmark:
self.autolog.times.stamp()
if args.benchmark:
self.autolog.times.end(stamp=True)
return logits
def predict(self, data, tokenizer):
"""
        Computes the similarity of text pairs.
Args:
            data (obj:`list[list[str]]`): The batch data; each element is a [query, title] text pair.
tokenizer(obj:`PretrainedTokenizer`): This tokenizer inherits from :class:`~paddlenlp.transformers.PretrainedTokenizer`
which contains most of the methods. Users should refer to the superclass for more information regarding methods.
Returns:
            results(obj:`list`): The cosine similarity of each text pair.
"""
if args.benchmark:
self.autolog.times.start()
examples = []
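        # Each element of `data` is a [query, title] pair; tokenize both sides separately.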
for idx, text in enumerate(data):
input_ids, segment_ids = convert_example({idx: text[0]}, tokenizer)
title_ids, title_segment_ids = convert_example({
idx: text[1]
}, tokenizer)
examples.append(
(input_ids, segment_ids, title_ids, title_segment_ids))
batchify_fn = lambda samples, fn=Tuple(
Pad(axis=0, pad_val=tokenizer.pad_token_id), # input
Pad(axis=0, pad_val=tokenizer.pad_token_id), # segment
Pad(axis=0, pad_val=tokenizer.pad_token_id), # segment
Pad(axis=0, pad_val=tokenizer.pad_token_id), # segment
): fn(samples)
if args.benchmark:
self.autolog.times.stamp()
query_ids, query_segment_ids, title_ids, title_segment_ids = batchify_fn(
examples)
self.input_handles[0].copy_from_cpu(query_ids)
self.input_handles[1].copy_from_cpu(query_segment_ids)
self.predictor.run()
query_logits = self.output_handle.copy_to_cpu()
self.input_handles[0].copy_from_cpu(title_ids)
self.input_handles[1].copy_from_cpu(title_segment_ids)
self.predictor.run()
title_logits = self.output_handle.copy_to_cpu()
if args.benchmark:
self.autolog.times.stamp()
if args.benchmark:
self.autolog.times.end(stamp=True)
result = [
float(1 - spatial.distance.cosine(arr1, arr2))
for arr1, arr2 in zip(query_logits, title_logits)
]
return result
if __name__ == "__main__":
# Define predictor to do prediction.
predictor = Predictor(args.model_dir, args.device, args.max_seq_length,
args.batch_size, args.use_tensorrt, args.precision,
args.cpu_threads, args.enable_mkldnn)
    # Use the ernie-1.0 tokenizer that matches the pretrained model.
output_emb_size = 256
tokenizer = ppnlp.transformers.ErnieTokenizer.from_pretrained('ernie-1.0')
id2corpus = {0: '国有企业引入非国有资本对创新绩效的影响——基于制造业国有上市公司的经验证据'}
corpus_list = [{idx: text} for idx, text in id2corpus.items()]
res = predictor.extract_embedding(corpus_list, tokenizer)
print(res.shape)
print(res)
corpus_list = [['中西方语言与文化的差异', '中西方文化差异以及语言体现中西方文化,差异,语言体现'],
['中西方语言与文化的差异', '飞桨致力于让深度学习技术的创新与应用更简单']]
res = predictor.predict(corpus_list, tokenizer)
print(res)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
import numpy as np
from paddle_serving_server.pipeline import PipelineClient
client = PipelineClient()
client.connect(['127.0.0.1:8088'])
list_data = [
"国有企业引入非国有资本对创新绩效的影响——基于制造业国有上市公司的经验证据",
"试论翻译过程中的文化差异与语言空缺翻译过程,文化差异,语言空缺,文化对比"
]
feed = {}
for i, item in enumerate(list_data):
feed[str(i)] = item
print(feed)
start_time = time.time()
ret = client.predict(feed_dict=feed)
end_time = time.time()
print("time to cost :{} seconds".format(end_time - start_time))
result = np.array(eval(ret.value[0]))
print(ret.key)
print(result.shape)
print(result)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import numpy as np
import sys
from paddle_serving_server.web_service import WebService, Op
_LOGGER = logging.getLogger()
def convert_example(example,
tokenizer,
max_seq_length=512,
pad_to_max_seq_len=False):
result = []
for text in example:
encoded_inputs = tokenizer(
text=text,
max_seq_len=max_seq_length,
pad_to_max_seq_len=pad_to_max_seq_len)
input_ids = encoded_inputs["input_ids"]
token_type_ids = encoded_inputs["token_type_ids"]
result += [input_ids, token_type_ids]
return result
class ErnieOp(Op):
def init_op(self):
import paddlenlp as ppnlp
self.tokenizer = ppnlp.transformers.ErnieTokenizer.from_pretrained(
'ernie-1.0')
def preprocess(self, input_dicts, data_id, log_id):
from paddlenlp.data import Stack, Tuple, Pad
(_, input_dict), = input_dicts.items()
print("input dict", input_dict)
batch_size = len(input_dict.keys())
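        # Tokenize every text in the request and batch the padded ids into feed tensors.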
examples = []
for i in range(batch_size):
input_ids, segment_ids = convert_example([input_dict[str(i)]],
self.tokenizer)
examples.append((input_ids, segment_ids))
batchify_fn = lambda samples, fn=Tuple(
Pad(axis=0, pad_val=self.tokenizer.pad_token_id), # input
Pad(axis=0, pad_val=self.tokenizer.pad_token_id), # segment
): fn(samples)
input_ids, segment_ids = batchify_fn(examples)
feed_dict = {}
feed_dict['input_ids'] = input_ids
feed_dict['token_type_ids'] = segment_ids
return feed_dict, False, None, ""
def postprocess(self, input_dicts, fetch_dict, data_id, log_id):
new_dict = {}
new_dict["output_embedding"] = str(fetch_dict["output_embedding"]
.tolist())
return new_dict, None, ""
class ErnieService(WebService):
def get_pipeline_response(self, read_op):
ernie_op = ErnieOp(name="ernie", input_ops=[read_op])
return ernie_op
ernie_service = ErnieService(name="ernie")
ernie_service.prepare_pipeline_config("config_nlp.yml")
ernie_service.run_service()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
from functools import partial
import numpy as np
import paddle
import paddle.nn.functional as F
import paddlenlp as ppnlp
from paddlenlp.data import Stack, Tuple, Pad
from base_model import SemanticIndexBase, SemanticIndexBaseStatic
# yapf: disable
parser = argparse.ArgumentParser()
parser.add_argument("--params_path", type=str, required=True,
default='./checkpoint/model_900/model_state.pdparams', help="The path to model parameters to be loaded.")
parser.add_argument("--output_path", type=str, default='./output',
help="The path of model parameter in static graph to be saved.")
args = parser.parse_args()
# yapf: enable
if __name__ == "__main__":
    # Load the ernie-1.0 pretrained model and tokenizer.
output_emb_size = 256
pretrained_model = ppnlp.transformers.ErnieModel.from_pretrained(
"ernie-1.0")
tokenizer = ppnlp.transformers.ErnieTokenizer.from_pretrained('ernie-1.0')
model = SemanticIndexBaseStatic(
pretrained_model, output_emb_size=output_emb_size)
if args.params_path and os.path.isfile(args.params_path):
state_dict = paddle.load(args.params_path)
model.set_dict(state_dict)
print("Loaded parameters from %s" % args.params_path)
model.eval()
# Convert to static graph with specific input description
model = paddle.jit.to_static(
model,
input_spec=[
paddle.static.InputSpec(
shape=[None, None], dtype="int64"), # input_ids
paddle.static.InputSpec(
shape=[None, None], dtype="int64") # segment_ids
])
# Save in static graph model.
save_path = os.path.join(args.output_path, "inference")
paddle.jit.save(model, save_path)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import paddle_serving_client.io as serving_io
# yapf: disable
parser = argparse.ArgumentParser()
parser.add_argument("--dirname", type=str, required=True,
default='./output', help="Path of saved model files. Program file and parameter files are saved in this directory.")
parser.add_argument("--model_filename", type=str, required=True,
default='inference.get_pooled_embedding.pdmodel', help="The name of file to load the inference program. If it is None, the default filename __model__ will be used.")
parser.add_argument("--params_filename", type=str, required=True,
default='inference.get_pooled_embedding.pdiparams', help="The name of file to load all parameters. It is only used for the case that all parameters were saved in a single binary file. If parameters were saved in separate files, set it as None. Default: None.")
parser.add_argument("--server_path", type=str, default='./serving_server',
help="The path of server parameter in static graph to be saved.")
parser.add_argument("--client_path", type=str, default='./serving_client',
help="The path of client parameter in static graph to be saved.")
parser.add_argument("--feed_alias_names", type=str, default=None,
help='set alias names for feed vars, split by comma \',\', you should run --show_proto to check the number of feed vars')
parser.add_argument("--fetch_alias_names", type=str, default=None,
                    help='set alias names for fetch vars, split by comma \',\', you should run --show_proto to check the number of fetch vars')
parser.add_argument("--show_proto", type=bool, default=False,
help='If yes, you can preview the proto and then determine your feed var alias name and fetch var alias name.')
# yapf: enable
if __name__ == "__main__":
args = parser.parse_args()
serving_io.inference_model_to_serving(
dirname=args.dirname,
serving_server=args.server_path,
serving_client=args.client_path,
model_filename=args.model_filename,
params_filename=args.params_filename,
show_proto=args.show_proto,
feed_alias_names=args.feed_alias_names,
fetch_alias_names=args.fetch_alias_names)
python export_model.py --params_path checkpoints/model_40/model_state.pdparams --output_path=./output
python export_to_serving.py \
--dirname "output" \
--model_filename "inference.get_pooled_embedding.pdmodel" \
--params_filename "inference.get_pooled_embedding.pdiparams" \
--server_path "serving_server" \
--client_path "serving_client" \
--fetch_alias_names "output_embedding"