diff --git a/examples/Pipeline/PaddleNLP/semantic_indexing/base_model.py b/examples/Pipeline/PaddleNLP/semantic_indexing/base_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..c471d126c2649fee7554fa8f026284c7300ada2f
--- /dev/null
+++ b/examples/Pipeline/PaddleNLP/semantic_indexing/base_model.py
@@ -0,0 +1,187 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import abc
+import sys
+
+import numpy as np
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+
+class SemanticIndexBase(nn.Layer):
+    def __init__(self, pretrained_model, dropout=None, output_emb_size=None):
+        super().__init__()
+        self.ptm = pretrained_model
+        self.dropout = nn.Dropout(dropout if dropout is not None else 0.1)
+
+        # If output_emb_size is not None, a Linear layer is added to reduce
+        # the embedding size. We recommend setting output_emb_size = 256
+        # considering the trade-off between recall performance and efficiency.
+
+        self.output_emb_size = output_emb_size
+        if output_emb_size is not None and output_emb_size > 0:
+            weight_attr = paddle.ParamAttr(
+                initializer=paddle.nn.initializer.TruncatedNormal(std=0.02))
+            self.emb_reduce_linear = paddle.nn.Linear(
+                768, output_emb_size, weight_attr=weight_attr)
+
+    @paddle.jit.to_static(input_spec=[
+        paddle.static.InputSpec(
+            shape=[None, None], dtype='int64'), paddle.static.InputSpec(
+                shape=[None, None], dtype='int64')
+    ])
+    def get_pooled_embedding(self,
+                             input_ids,
+                             token_type_ids=None,
+                             position_ids=None,
+                             attention_mask=None):
+        _, cls_embedding = self.ptm(input_ids, token_type_ids, position_ids,
+                                    attention_mask)
+
+        if self.output_emb_size is not None and self.output_emb_size > 0:
+            cls_embedding = self.emb_reduce_linear(cls_embedding)
+        cls_embedding = self.dropout(cls_embedding)
+        cls_embedding = F.normalize(cls_embedding, p=2, axis=-1)
+
+        return cls_embedding
+
+    def get_semantic_embedding(self, data_loader):
+        self.eval()
+        with paddle.no_grad():
+            for batch_data in data_loader:
+                input_ids, token_type_ids = batch_data
+                input_ids = paddle.to_tensor(input_ids)
+                token_type_ids = paddle.to_tensor(token_type_ids)
+
+                text_embeddings = self.get_pooled_embedding(
+                    input_ids, token_type_ids=token_type_ids)
+
+                yield text_embeddings
+
+    def cosine_sim(self,
+                   query_input_ids,
+                   title_input_ids,
+                   query_token_type_ids=None,
+                   query_position_ids=None,
+                   query_attention_mask=None,
+                   title_token_type_ids=None,
+                   title_position_ids=None,
+                   title_attention_mask=None):
+
+        query_cls_embedding = self.get_pooled_embedding(
+            query_input_ids, query_token_type_ids, query_position_ids,
+            query_attention_mask)
+
+        title_cls_embedding = self.get_pooled_embedding(
+            title_input_ids, title_token_type_ids, title_position_ids,
+            title_attention_mask)
+
+        cosine_sim = paddle.sum(query_cls_embedding * title_cls_embedding,
+                                axis=-1)
+        return cosine_sim
+
+    @abc.abstractmethod
+    def forward(self):
+        pass
+
+
+class SemanticIndexBaseStatic(nn.Layer):
+    def __init__(self, pretrained_model, dropout=None, output_emb_size=None):
+        super().__init__()
+        self.ptm = pretrained_model
+        self.dropout = nn.Dropout(dropout if dropout is not None else 0.1)
+
+        # If output_emb_size is not None, a Linear layer is added to reduce
+        # the embedding size. We recommend setting output_emb_size = 256
+        # considering the trade-off between recall performance and efficiency.
+
+        self.output_emb_size = output_emb_size
+        if output_emb_size is not None and output_emb_size > 0:
+            weight_attr = paddle.ParamAttr(
+                initializer=paddle.nn.initializer.TruncatedNormal(std=0.02))
+            self.emb_reduce_linear = paddle.nn.Linear(
+                768, output_emb_size, weight_attr=weight_attr)
+
+    @paddle.jit.to_static(input_spec=[
+        paddle.static.InputSpec(
+            shape=[None, None], dtype='int64'), paddle.static.InputSpec(
+                shape=[None, None], dtype='int64')
+    ])
+    def get_pooled_embedding(self,
+                             input_ids,
+                             token_type_ids=None,
+                             position_ids=None,
+                             attention_mask=None):
+        _, cls_embedding = self.ptm(input_ids, token_type_ids, position_ids,
+                                    attention_mask)
+
+        if self.output_emb_size is not None and self.output_emb_size > 0:
+            cls_embedding = self.emb_reduce_linear(cls_embedding)
+        cls_embedding = self.dropout(cls_embedding)
+        cls_embedding = F.normalize(cls_embedding, p=2, axis=-1)
+
+        return cls_embedding
+
+    def get_semantic_embedding(self, data_loader):
+        self.eval()
+        with paddle.no_grad():
+            for batch_data in data_loader:
+                input_ids, token_type_ids = batch_data
+                input_ids = paddle.to_tensor(input_ids)
+                token_type_ids = paddle.to_tensor(token_type_ids)
+
+                text_embeddings = self.get_pooled_embedding(
+                    input_ids, token_type_ids=token_type_ids)
+
+                yield text_embeddings
+
+    def cosine_sim(self,
+                   query_input_ids,
+                   title_input_ids,
+                   query_token_type_ids=None,
+                   query_position_ids=None,
+                   query_attention_mask=None,
+                   title_token_type_ids=None,
+                   title_position_ids=None,
+                   title_attention_mask=None):
+
+        query_cls_embedding = self.get_pooled_embedding(
+            query_input_ids, query_token_type_ids, query_position_ids,
+            query_attention_mask)
+
+        title_cls_embedding = self.get_pooled_embedding(
+            title_input_ids, title_token_type_ids, title_position_ids,
+            title_attention_mask)
+
+        cosine_sim = paddle.sum(query_cls_embedding * title_cls_embedding,
+                                axis=-1)
+        return cosine_sim
+
+    def forward(self,
+                input_ids,
+                token_type_ids=None,
+                position_ids=None,
+                attention_mask=None):
+        _, cls_embedding = self.ptm(input_ids, token_type_ids, position_ids,
+                                    attention_mask)
+
+        if self.output_emb_size is not None and self.output_emb_size > 0:
+            cls_embedding = self.emb_reduce_linear(cls_embedding)
+        cls_embedding = self.dropout(cls_embedding)
+        cls_embedding = F.normalize(cls_embedding, p=2, axis=-1)
+
+        return cls_embedding
diff --git a/examples/Pipeline/PaddleNLP/semantic_indexing/deploy/C++/http_client.py b/examples/Pipeline/PaddleNLP/semantic_indexing/deploy/C++/http_client.py
new file mode 100644
index 0000000000000000000000000000000000000000..a976ad9fc33b06ce7148adc7153d4b35183e31c0
--- /dev/null
+++ b/examples/Pipeline/PaddleNLP/semantic_indexing/deploy/C++/http_client.py
@@ -0,0 +1,81 @@
+# coding:utf-8
+# pylint: disable=doc-string-missing
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+import time
+import numpy as np
+import requests
+import json
+
+from paddle_serving_client import HttpClient
+import paddlenlp as ppnlp
+
+
+def convert_example(example,
+                    tokenizer,
+                    max_seq_length=512,
+                    pad_to_max_seq_len=True):
+    list_input_ids = []
+    list_token_type_ids = []
+    for text in example:
+        encoded_inputs = tokenizer(
+            text=text,
+            max_seq_len=max_seq_length,
+            pad_to_max_seq_len=pad_to_max_seq_len)
+        input_ids = encoded_inputs["input_ids"]
+        token_type_ids = encoded_inputs["token_type_ids"]
+
+        list_input_ids.append(input_ids)
+        list_token_type_ids.append(token_type_ids)
+    return list_input_ids, list_token_type_ids
+
+
+# Start the Python client.
+endpoint_list = ['127.0.0.1:9393']
+client = HttpClient()
+client.load_client_config('serving_client')
+client.connect(endpoint_list)
+feed_names = client.feed_names_
+fetch_names = client.fetch_names_
+print(feed_names)
+print(fetch_names)
+
+# Create the tokenizer.
+tokenizer = ppnlp.transformers.ErnieTokenizer.from_pretrained('ernie-1.0')
+max_seq_len = 64
+
+# Preprocess the data.
+
+list_data = ['国有企业引入非国有资本对创新绩效的影响——基于制造业国有上市公司的经验证据.', '面向生态系统服务的生态系统分类方案研发与应用']
+# for i in range(5):
+#     list_data.extend(list_data)
+# print(len(list_data))
+examples = convert_example(list_data, tokenizer, max_seq_length=max_seq_len)
+print(examples)
+
+feed_dict = {}
+feed_dict['input_ids'] = np.array(examples[0])
+feed_dict['token_type_ids'] = np.array(examples[1])
+
+print(feed_dict['input_ids'].shape)
+print(feed_dict['token_type_ids'].shape)
+
+# batch=True enables batch prediction.
+b_start = time.time()
+result = client.predict(feed=feed_dict, fetch=fetch_names, batch=True)
+b_end = time.time()
+print(result)
+print("Time cost: {} seconds".format(b_end - b_start))
diff --git a/examples/Pipeline/PaddleNLP/semantic_indexing/deploy/C++/rpc_client.py b/examples/Pipeline/PaddleNLP/semantic_indexing/deploy/C++/rpc_client.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ea4c245f2a10256166a512f9282282e69d9997b
--- /dev/null
+++ b/examples/Pipeline/PaddleNLP/semantic_indexing/deploy/C++/rpc_client.py
@@ -0,0 +1,77 @@
+# coding:utf-8
+# pylint: disable=doc-string-missing
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+import time
+import numpy as np
+
+from paddle_serving_client import Client
+import paddlenlp as ppnlp
+
+
+def convert_example(example,
+                    tokenizer,
+                    max_seq_length=512,
+                    pad_to_max_seq_len=True):
+    list_input_ids = []
+    list_token_type_ids = []
+    for text in example:
+        encoded_inputs = tokenizer(
+            text=text,
+            max_seq_len=max_seq_length,
+            pad_to_max_seq_len=pad_to_max_seq_len)
+        input_ids = encoded_inputs["input_ids"]
+        token_type_ids = encoded_inputs["token_type_ids"]
+        list_input_ids.append(input_ids)
+        list_token_type_ids.append(token_type_ids)
+    return list_input_ids, list_token_type_ids
+
+
+# Start the Python client.
+endpoint_list = ['127.0.0.1:9393']
+client = Client()
+client.load_client_config('serving_client')
+client.connect(endpoint_list)
+feed_names = client.feed_names_
+fetch_names = client.fetch_names_
+print(feed_names)
+print(fetch_names)
+
+# Create the tokenizer.
+tokenizer = ppnlp.transformers.ErnieTokenizer.from_pretrained('ernie-1.0')
+max_seq_len = 64
+
+# Preprocess the data.
+
+list_data = ['国有企业引入非国有资本对创新绩效的影响——基于制造业国有上市公司的经验证据.', '面向生态系统服务的生态系统分类方案研发与应用']
+# for i in range(5):
+#     list_data.extend(list_data)
+# print(len(list_data))
+examples = convert_example(list_data, tokenizer, max_seq_length=max_seq_len)
+print(examples)
+
+feed_dict = {}
+feed_dict['input_ids'] = np.array(examples[0])
+feed_dict['token_type_ids'] = np.array(examples[1])
+
+print(feed_dict['input_ids'].shape)
+print(feed_dict['token_type_ids'].shape)
+# batch=True enables batch prediction.
+b_start = time.time()
+result = client.predict(feed=feed_dict, fetch=fetch_names, batch=True)
+b_end = time.time()
+print("Time cost: {} seconds".format(b_end - b_start))
+print(result)
diff --git a/examples/Pipeline/PaddleNLP/semantic_indexing/deploy/C++/start_server.sh b/examples/Pipeline/PaddleNLP/semantic_indexing/deploy/C++/start_server.sh
new file mode 100644
index 0000000000000000000000000000000000000000..55d380d6f87396887675a008c54bb8544ce2a793
--- /dev/null
+++ b/examples/Pipeline/PaddleNLP/semantic_indexing/deploy/C++/start_server.sh
@@ -0,0 +1 @@
+python -m paddle_serving_server.serve --model serving_server --port 9393 --gpu_id 2 --thread 5 --ir_optim True --use_trt --precision FP16
\ No newline at end of file
diff --git a/examples/Pipeline/PaddleNLP/semantic_indexing/deploy/python/config_nlp.yml b/examples/Pipeline/PaddleNLP/semantic_indexing/deploy/python/config_nlp.yml
new file mode 100644
index 0000000000000000000000000000000000000000..d896adbfa1f9671cb569137637cf5f3ec169ef69
--- /dev/null
+++ b/examples/Pipeline/PaddleNLP/semantic_indexing/deploy/python/config_nlp.yml
@@ -0,0 +1,34 @@
+# worker_num: maximum concurrency. When build_dag_each_worker=True, the framework creates worker_num processes, each building its own gRPC server and DAG.
+# When build_dag_each_worker=False, the framework sets max_workers=worker_num for the gRPC thread pool of the main thread.
+worker_num: 20
+# build_dag_each_worker: False, the framework creates one DAG inside the process; True, the framework creates multiple independent DAGs, one per process.
+build_dag_each_worker: false
+
+dag:
+    # Op resource type: True for the thread model, False for the process model.
+    is_thread_op: False
+    # Profiling: True generates Timeline performance data, with some impact on performance; False disables it.
+    tracer:
+        interval_s: 10
+# HTTP port. rpc_port and http_port must not both be empty. When rpc_port is available and http_port is empty, http_port is not generated automatically.
+http_port: 18082
+# RPC port. rpc_port and http_port must not both be empty. When rpc_port is empty and http_port is not, rpc_port is automatically set to http_port+1.
+rpc_port: 8088
+op:
+    ernie:
+        # Concurrency. When is_thread_op=True this is thread concurrency; otherwise process concurrency.
+        concurrency: 1
+        # When the op config has no server_endpoints, the local service config is read from local_service_conf.
+        local_service_conf:
+            # Client type: brpc, grpc, or local_predictor. local_predictor does not start a Serving service; prediction runs in-process.
+            client_type: local_predictor
+            # ir_optim
+            ir_optim: True
+            # device_type: 0=cpu, 1=gpu, 2=tensorRT, 3=arm cpu, 4=kunlun xpu
+            device_type: 1
+            # Device IDs. When devices is "" or unset, prediction runs on CPU; when devices is "0" or "0,1,2", prediction runs on the specified GPU card(s).
+            devices: '2'
+            # Fetch list, matching the alias_name of fetch_var in client_config; if unset, all outputs are returned.
+            fetch_list: ['output_embedding']
+            # Model path.
+            model_config: ../../serving_server/
diff --git a/examples/Pipeline/PaddleNLP/semantic_indexing/deploy/python/deploy.sh b/examples/Pipeline/PaddleNLP/semantic_indexing/deploy/python/deploy.sh
new file mode 100644
index 0000000000000000000000000000000000000000..fe8f071e0a47a47f5dc24d84ea4eaaf8e7503c06
--- /dev/null
+++ b/examples/Pipeline/PaddleNLP/semantic_indexing/deploy/python/deploy.sh
@@ -0,0 +1 @@
+python predict.py --model_dir=../../output
\ No newline at end of file
diff --git a/examples/Pipeline/PaddleNLP/semantic_indexing/deploy/python/predict.py b/examples/Pipeline/PaddleNLP/semantic_indexing/deploy/python/predict.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e81dbb5092ce6178587f5aa8f40d758f4446a42
--- /dev/null
+++ b/examples/Pipeline/PaddleNLP/semantic_indexing/deploy/python/predict.py
@@ -0,0 +1,292 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import os
+import sys
+
+import numpy as np
+import paddle
+import paddlenlp as ppnlp
+from scipy.special import softmax
+from scipy import spatial
+from paddle import inference
+from paddlenlp.data import Stack, Tuple, Pad
+from paddlenlp.datasets import load_dataset
+from paddlenlp.utils.log import logger
+
+sys.path.append('.')
+
+# yapf: disable
+parser = argparse.ArgumentParser()
+parser.add_argument("--model_dir", type=str, required=True,
+                    help="The directory of the static graph model.")
+
+parser.add_argument("--max_seq_length", default=128, type=int,
+                    help="The maximum total input sequence length after tokenization. Sequences "
Sequences " + "longer than this will be truncated, sequences shorter will be padded.") +parser.add_argument("--batch_size", default=15, type=int, + help="Batch size per GPU/CPU for training.") +parser.add_argument('--device', choices=['cpu', 'gpu', 'xpu'], default="gpu", + help="Select which device to train model, defaults to gpu.") + +parser.add_argument('--use_tensorrt', default=False, type=eval, choices=[True, False], + help='Enable to use tensorrt to speed up.') +parser.add_argument("--precision", default="fp32", type=str, choices=["fp32", "fp16", "int8"], + help='The tensorrt precision.') + +parser.add_argument('--cpu_threads', default=10, type=int, + help='Number of threads to predict when using cpu.') +parser.add_argument('--enable_mkldnn', default=False, type=eval, choices=[True, False], + help='Enable to use mkldnn to speed up when using cpu.') + +parser.add_argument("--benchmark", type=eval, default=False, + help="To log some information about environment and running.") +parser.add_argument("--save_log_path", type=str, default="./log_output/", + help="The file path to save log.") +args = parser.parse_args() +# yapf: enable + + +def convert_example(example, + tokenizer, + max_seq_length=512, + pad_to_max_seq_len=False): + """ + Builds model inputs from a sequence. + + A BERT sequence has the following format: + + - single sequence: ``[CLS] X [SEP]`` + + Args: + example(obj:`list(str)`): The list of text to be converted to ids. + tokenizer(obj:`PretrainedTokenizer`): This tokenizer inherits from :class:`~paddlenlp.transformers.PretrainedTokenizer` + which contains most of the methods. Users should refer to the superclass for more information regarding methods. + max_seq_len(obj:`int`): The maximum total input sequence length after tokenization. + Sequences longer than this will be truncated, sequences shorter will be padded. + is_test(obj:`False`, defaults to `False`): Whether the example contains label or not. + + Returns: + input_ids(obj:`list[int]`): The list of query token ids. + token_type_ids(obj: `list[int]`): List of query sequence pair mask. 
+ """ + + result = [] + for key, text in example.items(): + encoded_inputs = tokenizer( + text=text, + max_seq_len=max_seq_length, + pad_to_max_seq_len=pad_to_max_seq_len) + input_ids = encoded_inputs["input_ids"] + token_type_ids = encoded_inputs["token_type_ids"] + result += [input_ids, token_type_ids] + return result + + +class Predictor(object): + def __init__(self, + model_dir, + device="gpu", + max_seq_length=128, + batch_size=32, + use_tensorrt=False, + precision="fp32", + cpu_threads=10, + enable_mkldnn=False): + self.max_seq_length = max_seq_length + self.batch_size = batch_size + + model_file = model_dir + "/inference.pdmodel" + params_file = model_dir + "/inference.pdiparams" + if not os.path.exists(model_file): + raise ValueError("not find model file path {}".format(model_file)) + if not os.path.exists(params_file): + raise ValueError("not find params file path {}".format(params_file)) + config = paddle.inference.Config(model_file, params_file) + + if device == "gpu": + # set GPU configs accordingly + # such as intialize the gpu memory, enable tensorrt + config.enable_use_gpu(100, 0) + precision_map = { + "fp16": inference.PrecisionType.Half, + "fp32": inference.PrecisionType.Float32, + "int8": inference.PrecisionType.Int8 + } + precision_mode = precision_map[precision] + + if args.use_tensorrt: + config.enable_tensorrt_engine( + max_batch_size=batch_size, + min_subgraph_size=30, + precision_mode=precision_mode) + elif device == "cpu": + # set CPU configs accordingly, + # such as enable_mkldnn, set_cpu_math_library_num_threads + config.disable_gpu() + if args.enable_mkldnn: + # cache 10 different shapes for mkldnn to avoid memory leak + config.set_mkldnn_cache_capacity(10) + config.enable_mkldnn() + config.set_cpu_math_library_num_threads(args.cpu_threads) + elif device == "xpu": + # set XPU configs accordingly + config.enable_xpu(100) + + config.switch_use_feed_fetch_ops(False) + self.predictor = paddle.inference.create_predictor(config) + self.input_handles = [ + self.predictor.get_input_handle(name) + for name in self.predictor.get_input_names() + ] + self.output_handle = self.predictor.get_output_handle( + self.predictor.get_output_names()[0]) + + if args.benchmark: + import auto_log + pid = os.getpid() + self.autolog = auto_log.AutoLogger( + model_name="ernie-1.0", + model_precision=precision, + batch_size=self.batch_size, + data_shape="dynamic", + save_path=args.save_log_path, + inference_config=config, + pids=pid, + process_name=None, + gpu_ids=0, + time_keys=[ + 'preprocess_time', 'inference_time', 'postprocess_time' + ], + warmup=0, + logger=logger) + + def extract_embedding(self, data, tokenizer): + """ + Predicts the data labels. + + Args: + data (obj:`List(str)`): The batch data whose each element is a raw text. + tokenizer(obj:`PretrainedTokenizer`): This tokenizer inherits from :class:`~paddlenlp.transformers.PretrainedTokenizer` + which contains most of the methods. Users should refer to the superclass for more information regarding methods. + + Returns: + results(obj:`dict`): All the feature vectors. 
+ """ + if args.benchmark: + self.autolog.times.start() + + examples = [] + for text in data: + input_ids, segment_ids = convert_example(text, tokenizer) + examples.append((input_ids, segment_ids)) + + batchify_fn = lambda samples, fn=Tuple( + Pad(axis=0, pad_val=tokenizer.pad_token_id), # input + Pad(axis=0, pad_val=tokenizer.pad_token_id), # segment + ): fn(samples) + + if args.benchmark: + self.autolog.times.stamp() + + input_ids, segment_ids = batchify_fn(examples) + self.input_handles[0].copy_from_cpu(input_ids) + self.input_handles[1].copy_from_cpu(segment_ids) + self.predictor.run() + logits = self.output_handle.copy_to_cpu() + if args.benchmark: + self.autolog.times.stamp() + + if args.benchmark: + self.autolog.times.end(stamp=True) + + return logits + + def predict(self, data, tokenizer): + """ + Predicts the data labels. + + Args: + data (obj:`List(str)`): The batch data whose each element is a raw text. + tokenizer(obj:`PretrainedTokenizer`): This tokenizer inherits from :class:`~paddlenlp.transformers.PretrainedTokenizer` + which contains most of the methods. Users should refer to the superclass for more information regarding methods. + + Returns: + results(obj:`dict`): All the predictions probs. + """ + if args.benchmark: + self.autolog.times.start() + + examples = [] + for idx, text in enumerate(data): + input_ids, segment_ids = convert_example({idx: text[0]}, tokenizer) + title_ids, title_segment_ids = convert_example({ + idx: text[1] + }, tokenizer) + examples.append( + (input_ids, segment_ids, title_ids, title_segment_ids)) + + batchify_fn = lambda samples, fn=Tuple( + Pad(axis=0, pad_val=tokenizer.pad_token_id), # input + Pad(axis=0, pad_val=tokenizer.pad_token_id), # segment + Pad(axis=0, pad_val=tokenizer.pad_token_id), # segment + Pad(axis=0, pad_val=tokenizer.pad_token_id), # segment + ): fn(samples) + + if args.benchmark: + self.autolog.times.stamp() + + query_ids, query_segment_ids, title_ids, title_segment_ids = batchify_fn( + examples) + self.input_handles[0].copy_from_cpu(query_ids) + self.input_handles[1].copy_from_cpu(query_segment_ids) + self.predictor.run() + query_logits = self.output_handle.copy_to_cpu() + + self.input_handles[0].copy_from_cpu(title_ids) + self.input_handles[1].copy_from_cpu(title_segment_ids) + self.predictor.run() + title_logits = self.output_handle.copy_to_cpu() + + if args.benchmark: + self.autolog.times.stamp() + + if args.benchmark: + self.autolog.times.end(stamp=True) + result = [ + float(1 - spatial.distance.cosine(arr1, arr2)) + for arr1, arr2 in zip(query_logits, title_logits) + ] + return result + + +if __name__ == "__main__": + # Define predictor to do prediction. + predictor = Predictor(args.model_dir, args.device, args.max_seq_length, + args.batch_size, args.use_tensorrt, args.precision, + args.cpu_threads, args.enable_mkldnn) + + # ErnieTinyTokenizer is special for ernie-tiny pretained model. 
+    output_emb_size = 256
+    tokenizer = ppnlp.transformers.ErnieTokenizer.from_pretrained('ernie-1.0')
+    id2corpus = {0: '国有企业引入非国有资本对创新绩效的影响——基于制造业国有上市公司的经验证据'}
+    corpus_list = [{idx: text} for idx, text in id2corpus.items()]
+    res = predictor.extract_embedding(corpus_list, tokenizer)
+    print(res.shape)
+    print(res)
+    corpus_list = [['中西方语言与文化的差异', '中西方文化差异以及语言体现中西方文化,差异,语言体现'],
+                   ['中西方语言与文化的差异', '飞桨致力于让深度学习技术的创新与应用更简单']]
+    res = predictor.predict(corpus_list, tokenizer)
+    print(res)
diff --git a/examples/Pipeline/PaddleNLP/semantic_indexing/deploy/python/rpc_client.py b/examples/Pipeline/PaddleNLP/semantic_indexing/deploy/python/rpc_client.py
new file mode 100644
index 0000000000000000000000000000000000000000..03863db6114b7c381dae17ee3bf33f00f15d8f4a
--- /dev/null
+++ b/examples/Pipeline/PaddleNLP/semantic_indexing/deploy/python/rpc_client.py
@@ -0,0 +1,39 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import time
+import numpy as np
+
+from paddle_serving_server.pipeline import PipelineClient
+
+client = PipelineClient()
+client.connect(['127.0.0.1:8088'])
+
+list_data = [
+    "国有企业引入非国有资本对创新绩效的影响——基于制造业国有上市公司的经验证据",
+    "试论翻译过程中的文化差异与语言空缺翻译过程,文化差异,语言空缺,文化对比"
+]
+feed = {}
+for i, item in enumerate(list_data):
+    feed[str(i)] = item
+
+print(feed)
+start_time = time.time()
+ret = client.predict(feed_dict=feed)
+end_time = time.time()
+print("Time cost: {} seconds".format(end_time - start_time))
+
+result = np.array(eval(ret.value[0]))
+print(ret.key)
+print(result.shape)
+print(result)
diff --git a/examples/Pipeline/PaddleNLP/semantic_indexing/deploy/python/web_service.py b/examples/Pipeline/PaddleNLP/semantic_indexing/deploy/python/web_service.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ad12032b3c92d72a5297f15d732b7dfbd19589e
--- /dev/null
+++ b/examples/Pipeline/PaddleNLP/semantic_indexing/deploy/python/web_service.py
@@ -0,0 +1,82 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import numpy as np
+import sys
+
+from paddle_serving_server.web_service import WebService, Op
+
+_LOGGER = logging.getLogger()
+
+
+def convert_example(example,
+                    tokenizer,
+                    max_seq_length=512,
+                    pad_to_max_seq_len=False):
+    result = []
+    for text in example:
+        encoded_inputs = tokenizer(
+            text=text,
+            max_seq_len=max_seq_length,
+            pad_to_max_seq_len=pad_to_max_seq_len)
+        input_ids = encoded_inputs["input_ids"]
+        token_type_ids = encoded_inputs["token_type_ids"]
+        result += [input_ids, token_type_ids]
+    return result
+
+
+class ErnieOp(Op):
+    def init_op(self):
+        import paddlenlp as ppnlp
+        self.tokenizer = ppnlp.transformers.ErnieTokenizer.from_pretrained(
+            'ernie-1.0')
+
+    def preprocess(self, input_dicts, data_id, log_id):
+        from paddlenlp.data import Stack, Tuple, Pad
+
+        (_, input_dict), = input_dicts.items()
+        _LOGGER.info("input dict: {}".format(input_dict))
+        batch_size = len(input_dict.keys())
+        examples = []
+        for i in range(batch_size):
+            input_ids, segment_ids = convert_example([input_dict[str(i)]],
+                                                     self.tokenizer)
+            examples.append((input_ids, segment_ids))
+        batchify_fn = lambda samples, fn=Tuple(
+            Pad(axis=0, pad_val=self.tokenizer.pad_token_id),  # input
+            Pad(axis=0, pad_val=self.tokenizer.pad_token_id),  # segment
+        ): fn(samples)
+        input_ids, segment_ids = batchify_fn(examples)
+        feed_dict = {}
+        feed_dict['input_ids'] = input_ids
+        feed_dict['token_type_ids'] = segment_ids
+        return feed_dict, False, None, ""
+
+    def postprocess(self, input_dicts, fetch_dict, data_id, log_id):
+        new_dict = {}
+        new_dict["output_embedding"] = str(fetch_dict["output_embedding"]
+                                           .tolist())
+        return new_dict, None, ""
+
+
+class ErnieService(WebService):
+    def get_pipeline_response(self, read_op):
+        ernie_op = ErnieOp(name="ernie", input_ops=[read_op])
+        return ernie_op
+
+
+ernie_service = ErnieService(name="ernie")
+ernie_service.prepare_pipeline_config("config_nlp.yml")
+ernie_service.run_service()
diff --git a/examples/Pipeline/PaddleNLP/semantic_indexing/export_model.py b/examples/Pipeline/PaddleNLP/semantic_indexing/export_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..da468ea7b2c3af6eff093eef98a3e4f9393f9b3d
--- /dev/null
+++ b/examples/Pipeline/PaddleNLP/semantic_indexing/export_model.py
@@ -0,0 +1,65 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import os
+from functools import partial
+
+import numpy as np
+import paddle
+import paddle.nn.functional as F
+import paddlenlp as ppnlp
+from paddlenlp.data import Stack, Tuple, Pad
+
+from base_model import SemanticIndexBase, SemanticIndexBaseStatic
+
+# yapf: disable
+parser = argparse.ArgumentParser()
+parser.add_argument("--params_path", type=str, required=True,
+                    default='./checkpoint/model_900/model_state.pdparams', help="The path to model parameters to be loaded.")
+parser.add_argument("--output_path", type=str, default='./output',
+                    help="The path of model parameter in static graph to be saved.")
+args = parser.parse_args()
+# yapf: enable
+
+if __name__ == "__main__":
+    # The output embedding of ernie-1.0 is reduced to 256 dims by a linear layer (see base_model.py).
+    output_emb_size = 256
+
+    pretrained_model = ppnlp.transformers.ErnieModel.from_pretrained(
+        "ernie-1.0")
+
+    tokenizer = ppnlp.transformers.ErnieTokenizer.from_pretrained('ernie-1.0')
+    model = SemanticIndexBaseStatic(
+        pretrained_model, output_emb_size=output_emb_size)
+
+    if args.params_path and os.path.isfile(args.params_path):
+        state_dict = paddle.load(args.params_path)
+        model.set_dict(state_dict)
+        print("Loaded parameters from %s" % args.params_path)
+
+    model.eval()
+
+    # Convert to static graph with specific input description.
+    model = paddle.jit.to_static(
+        model,
+        input_spec=[
+            paddle.static.InputSpec(
+                shape=[None, None], dtype="int64"),  # input_ids
+            paddle.static.InputSpec(
+                shape=[None, None], dtype="int64")  # segment_ids
+        ])
+    # Save the static graph model.
+    save_path = os.path.join(args.output_path, "inference")
+    paddle.jit.save(model, save_path)
diff --git a/examples/Pipeline/PaddleNLP/semantic_indexing/export_to_serving.py b/examples/Pipeline/PaddleNLP/semantic_indexing/export_to_serving.py
new file mode 100644
index 0000000000000000000000000000000000000000..c24f931510e5662ae1b824049d1ac35c4ef34076
--- /dev/null
+++ b/examples/Pipeline/PaddleNLP/semantic_indexing/export_to_serving.py
@@ -0,0 +1,47 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import paddle_serving_client.io as serving_io
+# yapf: disable
+parser = argparse.ArgumentParser()
+parser.add_argument("--dirname", type=str, required=True,
+                    default='./output', help="Path of saved model files. Program file and parameter files are saved in this directory.")
+parser.add_argument("--model_filename", type=str, required=True,
+                    default='inference.get_pooled_embedding.pdmodel', help="The name of file to load the inference program. If it is None, the default filename __model__ will be used.")
+parser.add_argument("--params_filename", type=str, required=True,
+                    default='inference.get_pooled_embedding.pdiparams', help="The name of file to load all parameters. It is only used for the case that all parameters were saved in a single binary file. If parameters were saved in separate files, set it as None. Default: None.")
+parser.add_argument("--server_path", type=str, default='./serving_server',
+                    help="The path of server parameter in static graph to be saved.")
+parser.add_argument("--client_path", type=str, default='./serving_client',
+                    help="The path of client parameter in static graph to be saved.")
+parser.add_argument("--feed_alias_names", type=str, default=None,
+                    help='Set alias names for feed vars, split by comma \',\'. Run with --show_proto first to check the number of feed vars.')
+parser.add_argument("--fetch_alias_names", type=str, default=None,
+                    help='Set alias names for fetch vars, split by comma \',\'. Run with --show_proto first to check the number of fetch vars.')
+parser.add_argument("--show_proto", type=eval, default=False, choices=[True, False],
+                    help='If True, preview the proto and then determine your feed var alias name and fetch var alias name.')
+# yapf: enable
+
+if __name__ == "__main__":
+    args = parser.parse_args()
+    serving_io.inference_model_to_serving(
+        dirname=args.dirname,
+        serving_server=args.server_path,
+        serving_client=args.client_path,
+        model_filename=args.model_filename,
+        params_filename=args.params_filename,
+        show_proto=args.show_proto,
+        feed_alias_names=args.feed_alias_names,
+        fetch_alias_names=args.fetch_alias_names)
diff --git a/examples/Pipeline/PaddleNLP/semantic_indexing/scripts/export_model.sh b/examples/Pipeline/PaddleNLP/semantic_indexing/scripts/export_model.sh
new file mode 100644
index 0000000000000000000000000000000000000000..7c79266219cea03e16968ed0d00a3755615c7432
--- /dev/null
+++ b/examples/Pipeline/PaddleNLP/semantic_indexing/scripts/export_model.sh
@@ -0,0 +1 @@
+python export_model.py --params_path checkpoints/model_40/model_state.pdparams --output_path=./output
\ No newline at end of file
diff --git a/examples/Pipeline/PaddleNLP/semantic_indexing/scripts/export_to_serving.sh b/examples/Pipeline/PaddleNLP/semantic_indexing/scripts/export_to_serving.sh
new file mode 100644
index 0000000000000000000000000000000000000000..b0d7a422551fd09eb1a28cfacdf47237a8efc795
--- /dev/null
+++ b/examples/Pipeline/PaddleNLP/semantic_indexing/scripts/export_to_serving.sh
@@ -0,0 +1,7 @@
+python export_to_serving.py \
+    --dirname "output" \
+    --model_filename "inference.get_pooled_embedding.pdmodel" \
+    --params_filename "inference.get_pooled_embedding.pdiparams" \
+    --server_path "serving_server" \
+    --client_path "serving_client" \
+    --fetch_alias_names "output_embedding"