diff --git a/examples/Pipeline/PaddleNLP/semantic_indexing/README.md b/examples/Pipeline/PaddleNLP/semantic_indexing/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..20e0bc04a2b0de6c3fb21355b8636de73c625d42
--- /dev/null
+++ b/examples/Pipeline/PaddleNLP/semantic_indexing/README.md
@@ -0,0 +1,201 @@
+# In-batch Negatives
+
+**Contents**
+
+* [1. Semantic Indexing Model](#1-semantic-indexing-model)
+* [2. Model Deployment](#2-model-deployment)
+
+
+
+
+## 1. Semantic Indexing Model
+
+**Download link for the trained semantic indexing model:**
+
+The model structure parameters are: `TransformerLayer:12, Hidden:768, Heads:12, OutputEmbSize: 256`
+
+|Model|Training Configuration|Hardware|MD5|
+| ------------ | ------------ | ------------ |-----------|
+|[batch_neg](https://bj.bcebos.com/v1/paddlenlp/models/inbatch_model.zip)|margin:0.2 scale:30 epoch:3 lr:5E-5 bs:64 max_len:64|4x V100-16G|f3e5c7d7b0b718c2530c5e1b136b2d74|
+
+```
+wget https://bj.bcebos.com/v1/paddlenlp/models/inbatch_model.zip
+unzip inbatch_model.zip -d checkpoints
+```
+
+
+
+## 2. Model Deployment
+
+### 2.1 Dynamic-to-Static Export
+
+First, convert the dynamic-graph model to a static graph:
+
+```
+python export_model.py --params_path checkpoints/model_40/model_state.pdparams --output_path=./output
+```
+You can also run the bash script below:
+
+```
+sh scripts/export_model.sh
+```
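+
+To sanity-check the export, you can verify the generated files. A minimal sketch, assuming the export wrote both the `inference.*` files consumed by Paddle Inference (section 2.2) and the `inference.get_pooled_embedding.*` files consumed by Serving (section 2.3) into `./output`:
+
+```
+import os
+
+# File names taken from deploy/python/predict.py and scripts/export_to_serving.sh;
+# adjust if your export produced different names.
+for name in ("inference.pdmodel", "inference.pdiparams",
+             "inference.get_pooled_embedding.pdmodel",
+             "inference.get_pooled_embedding.pdiparams"):
+    path = os.path.join("output", name)
+    print(path, "exists:", os.path.exists(path))
+```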
+
+### 2.2 Paddle Inference Prediction
+
+Prediction can either extract embedding vectors or compute the similarity of two texts.
+
+Modify the samples in id2corpus:
+
+```
+# Extract vectors
+id2corpus = {0: '国有企业引入非国有资本对创新绩效的影响——基于制造业国有上市公司的经验证据'}
+# Compute similarity
+corpus_list = [['中西方语言与文化的差异', '中西方文化差异以及语言体现中西方文化,差异,语言体现'],
+               ['中西方语言与文化的差异', '飞桨致力于让深度学习技术的创新与应用更简单']]
+```
+
+Then run the prediction with Paddle Inference:
+
+```
+python deploy/python/predict.py --model_dir=./output
+```
+You can also run the bash script below:
+
+```
+sh deploy.sh
+```
+The final output is the 256-dimensional feature vector and the predicted similarity scores for the sentence pairs:
+
+```
+(1, 256)
+[[-0.0394925 -0.04474756 -0.065534 0.00939134 0.04359895 0.14659195
+ -0.0091779 -0.07303623 0.09413272 -0.01255222 -0.08685658 0.02762237
+ 0.10138468 0.00962821 0.10888419 0.04553023 0.05898942 0.00694253
+ ....
+
+[0.959269642829895, 0.04725276678800583]
+```
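+
+The similarity scores are cosine similarities that predict.py derives from the query and title embeddings via scipy. A condensed, self-contained sketch of that step, using random placeholder arrays in place of the real inference outputs:
+
+```
+import numpy as np
+from scipy import spatial
+
+# Placeholders standing in for the (batch_size, 256) embeddings
+# returned by the two predictor runs.
+query_logits = np.random.rand(2, 256)
+title_logits = np.random.rand(2, 256)
+
+# spatial.distance.cosine returns the cosine distance, so
+# 1 - distance is the cosine similarity, as in deploy/python/predict.py.
+result = [
+    float(1 - spatial.distance.cosine(arr1, arr2))
+    for arr1, arr2 in zip(query_logits, title_logits)
+]
+print(result)
+```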
+
+### 2.3 Paddle Serving Deployment
+
+For detailed Paddle Serving documentation, please refer to [Pipeline_Design](https://github.com/PaddlePaddle/Serving/blob/v0.7.0/doc/Python_Pipeline/Pipeline_Design_CN.md) and [Serving_Design](https://github.com/PaddlePaddle/Serving/blob/v0.7.0/doc/Serving_Design_CN.md). First, convert the static-graph model into the Serving format:
+
+```
+python export_to_serving.py \
+ --dirname "output" \
+ --model_filename "inference.get_pooled_embedding.pdmodel" \
+ --params_filename "inference.get_pooled_embedding.pdiparams" \
+ --server_path "./serving_server" \
+ --client_path "./serving_client" \
+ --fetch_alias_names "output_embedding"
+
+```
+
+Parameter descriptions:
+* `dirname`: Path of the model files to convert; both the Program structure file and the parameter files are stored in this directory.
+* `model_filename`: Name of the file storing the Inference Program structure of the model to convert. If set to None, `__model__` is used as the default file name.
+* `params_filename`: Name of the file storing all parameters of the model to convert. It needs to be specified if and only if all parameters are saved in a single binary file; if the parameters are stored in separate files, set it to None.
+* `server_path`: Storage path for the converted server-side model and configuration files. Defaults to serving_server.
+* `client_path`: Storage path for the converted client-side configuration files. Defaults to serving_client.
+* `fetch_alias_names`: Alias names for the model outputs; for example, output_embedding can be renamed to something else. Unset by default.
+* `feed_alias_names`: Alias names for the model inputs; for example, input_ids can be renamed to something else. Unset by default.
+
+You can also run the bash script below:
+```
+sh scripts/export_to_serving.sh
+```
+
+Paddle Serving can be deployed in two ways: the Pipeline mode and the C++ mode. Both are described below:
+
+#### 2.3.1 Pipeline Mode
+
+Start the Pipeline Server:
+
+```
+python web_service.py
+```
+
+Start the client to call the server.
+
+First, modify the samples to be predicted in rpc_client.py:
+
+```
+list_data = [
+ "国有企业引入非国有资本对创新绩效的影响——基于制造业国有上市公司的经验证据",
+ "试论翻译过程中的文化差异与语言空缺翻译过程,文化差异,语言空缺,文化对比"
+]
+```
+Then run:
+
+```
+python rpc_client.py
+```
+The model output is:
+
+```
+{'0': '国有企业引入非国有资本对创新绩效的影响——基于制造业国有上市公司的经验证据', '1': '试论翻译过程中的文化差异与语言空缺翻译过程,文化差异,语言空缺,文化对比'}
+PipelineClient::predict pack_data time:1641450851.3752182
+PipelineClient::predict before time:1641450851.375738
+['output_embedding']
+(2, 256)
+[[ 0.07830612 -0.14036864 0.03433796 -0.14967982 -0.03386067 0.06630666
+ 0.01357943 0.03531194 0.02411093 0.02000859 0.05724002 -0.08119463
+ ......
+```
+
+You can see that the client sent 2 texts and received 2 embedding vectors.
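+
+Because `get_pooled_embedding` L2-normalizes its output (see `base_model.py`), the cosine similarity of any two returned vectors reduces to a dot product. A minimal sketch, assuming `embeddings` holds the `(2, 256)` array parsed by rpc_client.py:
+
+```
+import numpy as np
+
+# Placeholder for the (2, 256) array returned by the pipeline server.
+embeddings = np.random.rand(2, 256)
+embeddings /= np.linalg.norm(embeddings, axis=1, keepdims=True)
+
+# Unit-length vectors: the dot product equals the cosine similarity.
+print(float(np.dot(embeddings[0], embeddings[1])))
+```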
+
+#### 2.3.2 C++ Mode
+
+Start the C++ Serving server:
+
+```
+python -m paddle_serving_server.serve --model serving_server --port 9393 --gpu_id 2 --thread 5 --ir_optim True --use_trt --precision FP16
+```
+You can also use the script:
+
+```
+sh deploy/C++/start_server.sh
+```
+The client supports both HTTP and RPC. For RPC:
+
+```
+python deploy/C++/rpc_client.py
+```
+The output is:
+```
+I0209 20:40:07.978225 20896 general_model.cpp:490] [client]logid=0,client_cost=395.695ms,server_cost=392.559ms.
+time to cost :0.3960278034210205 seconds
+{'output_embedding': array([[ 9.01343748e-02, -1.21870913e-01, 1.32834800e-02,
+ -1.57673359e-01, -2.60387752e-02, 6.98455423e-02,
+ 1.58108603e-02, 3.89952064e-02, 3.22783105e-02,
+ 3.49135026e-02, 7.66086206e-02, -9.12970975e-02,
+ 6.25643134e-02, 7.21886680e-02, 7.03565404e-02,
+ 5.44054210e-02, 3.25332815e-03, 5.01751155e-02,
+......
+```
+You can see that the server returned the embedding vectors.
+
+Alternatively, access the server with the HTTP client:
+
+```
+python deploy/C++/http_client.py
+```
+The output is:
+
+```
+(2, 64)
+(2, 64)
+outputs {
+ tensor {
+ float_data: 0.09013437479734421
+ float_data: -0.12187091261148453
+ float_data: 0.01328347995877266
+ float_data: -0.15767335891723633
+......
+```
+You can see that the server returned the embedding vectors.
+
+
+
diff --git a/examples/Pipeline/PaddleNLP/semantic_indexing/base_model.py b/examples/Pipeline/PaddleNLP/semantic_indexing/base_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..c471d126c2649fee7554fa8f026284c7300ada2f
--- /dev/null
+++ b/examples/Pipeline/PaddleNLP/semantic_indexing/base_model.py
@@ -0,0 +1,187 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import abc
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+
+class SemanticIndexBase(nn.Layer):
+ def __init__(self, pretrained_model, dropout=None, output_emb_size=None):
+ super().__init__()
+ self.ptm = pretrained_model
+ self.dropout = nn.Dropout(dropout if dropout is not None else 0.1)
+
+        # If output_emb_size is not None, add a Linear layer to reduce the embedding size.
+        # We recommend setting output_emb_size = 256, considering the trade-off between
+        # recall performance and efficiency.
+
+        self.output_emb_size = output_emb_size if output_emb_size is not None else 0
+        if self.output_emb_size > 0:
+ weight_attr = paddle.ParamAttr(
+ initializer=paddle.nn.initializer.TruncatedNormal(std=0.02))
+ self.emb_reduce_linear = paddle.nn.Linear(
+ 768, output_emb_size, weight_attr=weight_attr)
+
+ @paddle.jit.to_static(input_spec=[
+ paddle.static.InputSpec(
+ shape=[None, None], dtype='int64'), paddle.static.InputSpec(
+ shape=[None, None], dtype='int64')
+ ])
+ def get_pooled_embedding(self,
+ input_ids,
+ token_type_ids=None,
+ position_ids=None,
+ attention_mask=None):
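+        # Pool the [CLS] representation from the pretrained model, optionally
+        # project it to output_emb_size, and L2-normalize the result.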
+ _, cls_embedding = self.ptm(input_ids, token_type_ids, position_ids,
+ attention_mask)
+
+ if self.output_emb_size > 0:
+ cls_embedding = self.emb_reduce_linear(cls_embedding)
+ cls_embedding = self.dropout(cls_embedding)
+ cls_embedding = F.normalize(cls_embedding, p=2, axis=-1)
+
+ return cls_embedding
+
+ def get_semantic_embedding(self, data_loader):
+ self.eval()
+ with paddle.no_grad():
+ for batch_data in data_loader:
+ input_ids, token_type_ids = batch_data
+ input_ids = paddle.to_tensor(input_ids)
+ token_type_ids = paddle.to_tensor(token_type_ids)
+
+ text_embeddings = self.get_pooled_embedding(
+ input_ids, token_type_ids=token_type_ids)
+
+ yield text_embeddings
+
+ def cosine_sim(self,
+ query_input_ids,
+ title_input_ids,
+ query_token_type_ids=None,
+ query_position_ids=None,
+ query_attention_mask=None,
+ title_token_type_ids=None,
+ title_position_ids=None,
+ title_attention_mask=None):
+
+ query_cls_embedding = self.get_pooled_embedding(
+ query_input_ids, query_token_type_ids, query_position_ids,
+ query_attention_mask)
+
+ title_cls_embedding = self.get_pooled_embedding(
+ title_input_ids, title_token_type_ids, title_position_ids,
+ title_attention_mask)
+
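+        # Embeddings are L2-normalized, so the summed elementwise product is
+        # the cosine similarity.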
+ cosine_sim = paddle.sum(query_cls_embedding * title_cls_embedding,
+ axis=-1)
+ return cosine_sim
+
+ @abc.abstractmethod
+ def forward(self):
+ pass
+
+
+class SemanticIndexBaseStatic(nn.Layer):
+ def __init__(self, pretrained_model, dropout=None, output_emb_size=None):
+ super().__init__()
+ self.ptm = pretrained_model
+ self.dropout = nn.Dropout(dropout if dropout is not None else 0.1)
+
+        # If output_emb_size is not None, add a Linear layer to reduce the embedding size.
+        # We recommend setting output_emb_size = 256, considering the trade-off between
+        # recall performance and efficiency.
+
+        self.output_emb_size = output_emb_size if output_emb_size is not None else 0
+        if self.output_emb_size > 0:
+ weight_attr = paddle.ParamAttr(
+ initializer=paddle.nn.initializer.TruncatedNormal(std=0.02))
+ self.emb_reduce_linear = paddle.nn.Linear(
+ 768, output_emb_size, weight_attr=weight_attr)
+
+ @paddle.jit.to_static(input_spec=[
+ paddle.static.InputSpec(
+ shape=[None, None], dtype='int64'), paddle.static.InputSpec(
+ shape=[None, None], dtype='int64')
+ ])
+ def get_pooled_embedding(self,
+ input_ids,
+ token_type_ids=None,
+ position_ids=None,
+ attention_mask=None):
+ _, cls_embedding = self.ptm(input_ids, token_type_ids, position_ids,
+ attention_mask)
+
+ if self.output_emb_size > 0:
+ cls_embedding = self.emb_reduce_linear(cls_embedding)
+ cls_embedding = self.dropout(cls_embedding)
+ cls_embedding = F.normalize(cls_embedding, p=2, axis=-1)
+
+ return cls_embedding
+
+ def get_semantic_embedding(self, data_loader):
+ self.eval()
+ with paddle.no_grad():
+ for batch_data in data_loader:
+ input_ids, token_type_ids = batch_data
+ input_ids = paddle.to_tensor(input_ids)
+ token_type_ids = paddle.to_tensor(token_type_ids)
+
+ text_embeddings = self.get_pooled_embedding(
+ input_ids, token_type_ids=token_type_ids)
+
+ yield text_embeddings
+
+ def cosine_sim(self,
+ query_input_ids,
+ title_input_ids,
+ query_token_type_ids=None,
+ query_position_ids=None,
+ query_attention_mask=None,
+ title_token_type_ids=None,
+ title_position_ids=None,
+ title_attention_mask=None):
+
+ query_cls_embedding = self.get_pooled_embedding(
+ query_input_ids, query_token_type_ids, query_position_ids,
+ query_attention_mask)
+
+ title_cls_embedding = self.get_pooled_embedding(
+ title_input_ids, title_token_type_ids, title_position_ids,
+ title_attention_mask)
+
+ cosine_sim = paddle.sum(query_cls_embedding * title_cls_embedding,
+ axis=-1)
+ return cosine_sim
+
+ def forward(self,
+ input_ids,
+ token_type_ids=None,
+ position_ids=None,
+ attention_mask=None):
+ _, cls_embedding = self.ptm(input_ids, token_type_ids, position_ids,
+ attention_mask)
+
+ if self.output_emb_size > 0:
+ cls_embedding = self.emb_reduce_linear(cls_embedding)
+ cls_embedding = self.dropout(cls_embedding)
+ cls_embedding = F.normalize(cls_embedding, p=2, axis=-1)
+
+ return cls_embedding
diff --git a/examples/Pipeline/PaddleNLP/semantic_indexing/deploy/C++/http_client.py b/examples/Pipeline/PaddleNLP/semantic_indexing/deploy/C++/http_client.py
new file mode 100644
index 0000000000000000000000000000000000000000..a976ad9fc33b06ce7148adc7153d4b35183e31c0
--- /dev/null
+++ b/examples/Pipeline/PaddleNLP/semantic_indexing/deploy/C++/http_client.py
@@ -0,0 +1,81 @@
+# coding:utf-8
+# pylint: disable=doc-string-missing
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import time
+
+import numpy as np
+
+from paddle_serving_client import HttpClient
+import paddlenlp as ppnlp
+
+
+def convert_example(example,
+ tokenizer,
+ max_seq_length=512,
+ pad_to_max_seq_len=True):
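+    # Tokenize each text and pad to max_seq_length so every row in the batch
+    # has the same shape.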
+ list_input_ids = []
+ list_token_type_ids = []
+ for text in example:
+ encoded_inputs = tokenizer(
+ text=text,
+ max_seq_len=max_seq_length,
+ pad_to_max_seq_len=pad_to_max_seq_len)
+ input_ids = encoded_inputs["input_ids"]
+ token_type_ids = encoded_inputs["token_type_ids"]
+
+ list_input_ids.append(input_ids)
+ list_token_type_ids.append(token_type_ids)
+ return list_input_ids, list_token_type_ids
+
+
+# Start the Python client
+endpoint_list = ['127.0.0.1:9393']
+client = HttpClient()
+client.load_client_config('serving_client')
+client.connect(endpoint_list)
+feed_names = client.feed_names_
+fetch_names = client.fetch_names_
+print(feed_names)
+print(fetch_names)
+
+# Create the tokenizer
+tokenizer = ppnlp.transformers.ErnieTokenizer.from_pretrained('ernie-1.0')
+max_seq_len = 64
+
+# Preprocess the data
+
+list_data = ['国有企业引入非国有资本对创新绩效的影响——基于制造业国有上市公司的经验证据.', '面向生态系统服务的生态系统分类方案研发与应用']
+# for i in range(5):
+# list_data.extend(list_data)
+# print(len(list_data))
+examples = convert_example(list_data, tokenizer, max_seq_length=max_seq_len)
+print(examples)
+
+feed_dict = {}
+feed_dict['input_ids'] = np.array(examples[0])
+feed_dict['token_type_ids'] = np.array(examples[1])
+
+print(feed_dict['input_ids'].shape)
+print(feed_dict['token_type_ids'].shape)
+
+# batch=True enables batch prediction
+b_start = time.time()
+result = client.predict(feed=feed_dict, fetch=fetch_names, batch=True)
+b_end = time.time()
+print(result)
+print("time to cost :{} seconds".format(b_end - b_start))
diff --git a/examples/Pipeline/PaddleNLP/semantic_indexing/deploy/C++/rpc_client.py b/examples/Pipeline/PaddleNLP/semantic_indexing/deploy/C++/rpc_client.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ea4c245f2a10256166a512f9282282e69d9997b
--- /dev/null
+++ b/examples/Pipeline/PaddleNLP/semantic_indexing/deploy/C++/rpc_client.py
@@ -0,0 +1,77 @@
+# coding:utf-8
+# pylint: disable=doc-string-missing
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import time
+import numpy as np
+
+from paddle_serving_client import Client
+import paddlenlp as ppnlp
+
+
+def convert_example(example,
+ tokenizer,
+ max_seq_length=512,
+ pad_to_max_seq_len=True):
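+    # Tokenize each text and pad to max_seq_length so every row in the batch
+    # has the same shape.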
+ list_input_ids = []
+ list_token_type_ids = []
+ for text in example:
+ encoded_inputs = tokenizer(
+ text=text,
+ max_seq_len=max_seq_length,
+ pad_to_max_seq_len=pad_to_max_seq_len)
+ input_ids = encoded_inputs["input_ids"]
+ token_type_ids = encoded_inputs["token_type_ids"]
+ list_input_ids.append(input_ids)
+ list_token_type_ids.append(token_type_ids)
+ return list_input_ids, list_token_type_ids
+
+
+# Start the Python client
+endpoint_list = ['127.0.0.1:9393']
+client = Client()
+client.load_client_config('serving_client')
+client.connect(endpoint_list)
+feed_names = client.feed_names_
+fetch_names = client.fetch_names_
+print(feed_names)
+print(fetch_names)
+
+# Create the tokenizer
+tokenizer = ppnlp.transformers.ErnieTokenizer.from_pretrained('ernie-1.0')
+max_seq_len = 64
+
+# Preprocess the data
+
+list_data = ['国有企业引入非国有资本对创新绩效的影响——基于制造业国有上市公司的经验证据.', '面向生态系统服务的生态系统分类方案研发与应用']
+# for i in range(5):
+# list_data.extend(list_data)
+# print(len(list_data))
+examples = convert_example(list_data, tokenizer, max_seq_length=max_seq_len)
+print(examples)
+
+feed_dict = {}
+feed_dict['input_ids'] = np.array(examples[0])
+feed_dict['token_type_ids'] = np.array(examples[1])
+
+print(feed_dict['input_ids'].shape)
+print(feed_dict['token_type_ids'].shape)
+# batch=True enables batch prediction
+b_start = time.time()
+result = client.predict(feed=feed_dict, fetch=fetch_names, batch=True)
+b_end = time.time()
+print("time to cost :{} seconds".format(b_end - b_start))
+print(result)
diff --git a/examples/Pipeline/PaddleNLP/semantic_indexing/deploy/C++/start_server.sh b/examples/Pipeline/PaddleNLP/semantic_indexing/deploy/C++/start_server.sh
new file mode 100644
index 0000000000000000000000000000000000000000..55d380d6f87396887675a008c54bb8544ce2a793
--- /dev/null
+++ b/examples/Pipeline/PaddleNLP/semantic_indexing/deploy/C++/start_server.sh
@@ -0,0 +1 @@
+python -m paddle_serving_server.serve --model serving_server --port 9393 --gpu_id 2 --thread 5 --ir_optim True --use_trt --precision FP16
\ No newline at end of file
diff --git a/examples/Pipeline/PaddleNLP/semantic_indexing/deploy/python/config_nlp.yml b/examples/Pipeline/PaddleNLP/semantic_indexing/deploy/python/config_nlp.yml
new file mode 100644
index 0000000000000000000000000000000000000000..d896adbfa1f9671cb569137637cf5f3ec169ef69
--- /dev/null
+++ b/examples/Pipeline/PaddleNLP/semantic_indexing/deploy/python/config_nlp.yml
@@ -0,0 +1,34 @@
+# worker_num: maximum concurrency. When build_dag_each_worker=True, the framework
+# creates worker_num processes, each building its own grpcServer and DAG.
+# When build_dag_each_worker=False, the framework sets the main thread's grpc
+# thread pool max_workers to worker_num.
+worker_num: 20
+# build_dag_each_worker: False builds one DAG inside the process; True builds
+# an independent DAG in each worker process.
+build_dag_each_worker: false
+
+dag:
+  # Op resource type: True for the thread model, False for the process model.
+  is_thread_op: False
+  # Profiling: when enabled, generates Timeline performance data, with some
+  # performance overhead.
+  tracer:
+    interval_s: 10
+# HTTP port. rpc_port and http_port must not both be empty. When rpc_port is
+# valid and http_port is empty, no http_port is generated automatically.
+http_port: 18082
+# RPC port. rpc_port and http_port must not both be empty. When rpc_port is
+# empty and http_port is not, rpc_port is automatically set to http_port + 1.
+rpc_port: 8088
+op:
+  ernie:
+    # Concurrency: thread concurrency when is_thread_op=True, otherwise
+    # process concurrency.
+    concurrency: 1
+    # When the op has no server_endpoints configured, the local service
+    # configuration is read from local_service_conf.
+    local_service_conf:
+      # Client type: brpc, grpc or local_predictor. local_predictor does not
+      # start a Serving service; prediction runs in-process.
+      client_type: local_predictor
+      # ir_optim
+      ir_optim: True
+      # device_type: 0=cpu, 1=gpu, 2=tensorRT, 3=arm cpu, 4=kunlun xpu
+      device_type: 1
+      # Compute device IDs. Empty or unset means CPU prediction; "0" or
+      # "0,1,2" selects the GPU cards to use.
+      devices: '2'
+      # Fetch list, keyed by the alias_name of fetch_var in client_config.
+      # If unset, all outputs are returned.
+      fetch_list: ['output_embedding']
+      # Model path
+      model_config: ../../serving_server/
diff --git a/examples/Pipeline/PaddleNLP/semantic_indexing/deploy/python/deploy.sh b/examples/Pipeline/PaddleNLP/semantic_indexing/deploy/python/deploy.sh
new file mode 100644
index 0000000000000000000000000000000000000000..fe8f071e0a47a47f5dc24d84ea4eaaf8e7503c06
--- /dev/null
+++ b/examples/Pipeline/PaddleNLP/semantic_indexing/deploy/python/deploy.sh
@@ -0,0 +1 @@
+python predict.py --model_dir=../../output
\ No newline at end of file
diff --git a/examples/Pipeline/PaddleNLP/semantic_indexing/deploy/python/predict.py b/examples/Pipeline/PaddleNLP/semantic_indexing/deploy/python/predict.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e81dbb5092ce6178587f5aa8f40d758f4446a42
--- /dev/null
+++ b/examples/Pipeline/PaddleNLP/semantic_indexing/deploy/python/predict.py
@@ -0,0 +1,292 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import os
+import sys
+
+import numpy as np
+import paddle
+import paddlenlp as ppnlp
+from scipy import spatial
+from paddle import inference
+from paddlenlp.data import Tuple, Pad
+from paddlenlp.utils.log import logger
+
+sys.path.append('.')
+
+# yapf: disable
+parser = argparse.ArgumentParser()
+parser.add_argument("--model_dir", type=str, required=True,
+ help="The directory to static model.")
+
+parser.add_argument("--max_seq_length", default=128, type=int,
+ help="The maximum total input sequence length after tokenization. Sequences "
+ "longer than this will be truncated, sequences shorter will be padded.")
+parser.add_argument("--batch_size", default=15, type=int,
+ help="Batch size per GPU/CPU for training.")
+parser.add_argument('--device', choices=['cpu', 'gpu', 'xpu'], default="gpu",
+ help="Select which device to train model, defaults to gpu.")
+
+parser.add_argument('--use_tensorrt', default=False, type=eval, choices=[True, False],
+ help='Enable to use tensorrt to speed up.')
+parser.add_argument("--precision", default="fp32", type=str, choices=["fp32", "fp16", "int8"],
+ help='The tensorrt precision.')
+
+parser.add_argument('--cpu_threads', default=10, type=int,
+ help='Number of threads to predict when using cpu.')
+parser.add_argument('--enable_mkldnn', default=False, type=eval, choices=[True, False],
+ help='Enable to use mkldnn to speed up when using cpu.')
+
+parser.add_argument("--benchmark", type=eval, default=False,
+ help="To log some information about environment and running.")
+parser.add_argument("--save_log_path", type=str, default="./log_output/",
+ help="The file path to save log.")
+args = parser.parse_args()
+# yapf: enable
+
+
+def convert_example(example,
+ tokenizer,
+ max_seq_length=512,
+ pad_to_max_seq_len=False):
+ """
+ Builds model inputs from a sequence.
+
+ A BERT sequence has the following format:
+
+ - single sequence: ``[CLS] X [SEP]``
+
+ Args:
+ example(obj:`list(str)`): The list of text to be converted to ids.
+ tokenizer(obj:`PretrainedTokenizer`): This tokenizer inherits from :class:`~paddlenlp.transformers.PretrainedTokenizer`
+ which contains most of the methods. Users should refer to the superclass for more information regarding methods.
+        max_seq_length(obj:`int`): The maximum total input sequence length after tokenization.
+            Sequences longer than this will be truncated, sequences shorter will be padded.
+        pad_to_max_seq_len(obj:`bool`, defaults to `False`): Whether to pad every sequence to `max_seq_length`.
+
+ Returns:
+ input_ids(obj:`list[int]`): The list of query token ids.
+ token_type_ids(obj: `list[int]`): List of query sequence pair mask.
+ """
+
+ result = []
+ for key, text in example.items():
+ encoded_inputs = tokenizer(
+ text=text,
+ max_seq_len=max_seq_length,
+ pad_to_max_seq_len=pad_to_max_seq_len)
+ input_ids = encoded_inputs["input_ids"]
+ token_type_ids = encoded_inputs["token_type_ids"]
+ result += [input_ids, token_type_ids]
+ return result
+
+
+class Predictor(object):
+ def __init__(self,
+ model_dir,
+ device="gpu",
+ max_seq_length=128,
+ batch_size=32,
+ use_tensorrt=False,
+ precision="fp32",
+ cpu_threads=10,
+ enable_mkldnn=False):
+ self.max_seq_length = max_seq_length
+ self.batch_size = batch_size
+
+ model_file = model_dir + "/inference.pdmodel"
+ params_file = model_dir + "/inference.pdiparams"
+        if not os.path.exists(model_file):
+            raise ValueError("Model file not found at {}".format(model_file))
+        if not os.path.exists(params_file):
+            raise ValueError("Params file not found at {}".format(params_file))
+ config = paddle.inference.Config(model_file, params_file)
+
+ if device == "gpu":
+            # set GPU configs accordingly
+            # such as initialize the gpu memory, enable tensorrt
+ config.enable_use_gpu(100, 0)
+ precision_map = {
+ "fp16": inference.PrecisionType.Half,
+ "fp32": inference.PrecisionType.Float32,
+ "int8": inference.PrecisionType.Int8
+ }
+ precision_mode = precision_map[precision]
+
+ if args.use_tensorrt:
+ config.enable_tensorrt_engine(
+ max_batch_size=batch_size,
+ min_subgraph_size=30,
+ precision_mode=precision_mode)
+ elif device == "cpu":
+ # set CPU configs accordingly,
+ # such as enable_mkldnn, set_cpu_math_library_num_threads
+ config.disable_gpu()
+ if args.enable_mkldnn:
+ # cache 10 different shapes for mkldnn to avoid memory leak
+ config.set_mkldnn_cache_capacity(10)
+ config.enable_mkldnn()
+ config.set_cpu_math_library_num_threads(args.cpu_threads)
+ elif device == "xpu":
+ # set XPU configs accordingly
+ config.enable_xpu(100)
+
+ config.switch_use_feed_fetch_ops(False)
+ self.predictor = paddle.inference.create_predictor(config)
+ self.input_handles = [
+ self.predictor.get_input_handle(name)
+ for name in self.predictor.get_input_names()
+ ]
+ self.output_handle = self.predictor.get_output_handle(
+ self.predictor.get_output_names()[0])
+
+ if args.benchmark:
+ import auto_log
+ pid = os.getpid()
+ self.autolog = auto_log.AutoLogger(
+ model_name="ernie-1.0",
+ model_precision=precision,
+ batch_size=self.batch_size,
+ data_shape="dynamic",
+ save_path=args.save_log_path,
+ inference_config=config,
+ pids=pid,
+ process_name=None,
+ gpu_ids=0,
+ time_keys=[
+ 'preprocess_time', 'inference_time', 'postprocess_time'
+ ],
+ warmup=0,
+ logger=logger)
+
+ def extract_embedding(self, data, tokenizer):
+ """
+        Extracts the embedding vectors of the input data.
+
+        Args:
+            data (obj:`List(str)`): The batch data, each element of which is a raw text.
+ tokenizer(obj:`PretrainedTokenizer`): This tokenizer inherits from :class:`~paddlenlp.transformers.PretrainedTokenizer`
+ which contains most of the methods. Users should refer to the superclass for more information regarding methods.
+
+ Returns:
+            results(obj:`numpy.ndarray`): All the feature vectors.
+ """
+ if args.benchmark:
+ self.autolog.times.start()
+
+ examples = []
+ for text in data:
+ input_ids, segment_ids = convert_example(text, tokenizer)
+ examples.append((input_ids, segment_ids))
+
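+        # Pad input_ids and segment_ids to the longest sequence in the batch.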
+ batchify_fn = lambda samples, fn=Tuple(
+ Pad(axis=0, pad_val=tokenizer.pad_token_id), # input
+ Pad(axis=0, pad_val=tokenizer.pad_token_id), # segment
+ ): fn(samples)
+
+ if args.benchmark:
+ self.autolog.times.stamp()
+
+ input_ids, segment_ids = batchify_fn(examples)
+ self.input_handles[0].copy_from_cpu(input_ids)
+ self.input_handles[1].copy_from_cpu(segment_ids)
+ self.predictor.run()
+ logits = self.output_handle.copy_to_cpu()
+ if args.benchmark:
+ self.autolog.times.stamp()
+
+ if args.benchmark:
+ self.autolog.times.end(stamp=True)
+
+ return logits
+
+ def predict(self, data, tokenizer):
+ """
+        Computes the cosine similarity of each text pair.
+
+        Args:
+            data (obj:`List(list)`): The batch data, each element of which is a pair of raw texts.
+ tokenizer(obj:`PretrainedTokenizer`): This tokenizer inherits from :class:`~paddlenlp.transformers.PretrainedTokenizer`
+ which contains most of the methods. Users should refer to the superclass for more information regarding methods.
+
+ Returns:
+            results(obj:`list[float]`): The cosine similarity of each text pair.
+ """
+ if args.benchmark:
+ self.autolog.times.start()
+
+ examples = []
+ for idx, text in enumerate(data):
+ input_ids, segment_ids = convert_example({idx: text[0]}, tokenizer)
+ title_ids, title_segment_ids = convert_example({
+ idx: text[1]
+ }, tokenizer)
+ examples.append(
+ (input_ids, segment_ids, title_ids, title_segment_ids))
+
+ batchify_fn = lambda samples, fn=Tuple(
+            Pad(axis=0, pad_val=tokenizer.pad_token_id),  # query input
+            Pad(axis=0, pad_val=tokenizer.pad_token_id),  # query segment
+            Pad(axis=0, pad_val=tokenizer.pad_token_id),  # title input
+            Pad(axis=0, pad_val=tokenizer.pad_token_id),  # title segment
+ ): fn(samples)
+
+ if args.benchmark:
+ self.autolog.times.stamp()
+
+ query_ids, query_segment_ids, title_ids, title_segment_ids = batchify_fn(
+ examples)
+ self.input_handles[0].copy_from_cpu(query_ids)
+ self.input_handles[1].copy_from_cpu(query_segment_ids)
+ self.predictor.run()
+ query_logits = self.output_handle.copy_to_cpu()
+
+ self.input_handles[0].copy_from_cpu(title_ids)
+ self.input_handles[1].copy_from_cpu(title_segment_ids)
+ self.predictor.run()
+ title_logits = self.output_handle.copy_to_cpu()
+
+ if args.benchmark:
+ self.autolog.times.stamp()
+
+ if args.benchmark:
+ self.autolog.times.end(stamp=True)
+ result = [
+ float(1 - spatial.distance.cosine(arr1, arr2))
+ for arr1, arr2 in zip(query_logits, title_logits)
+ ]
+ return result
+
+
+if __name__ == "__main__":
+ # Define predictor to do prediction.
+ predictor = Predictor(args.model_dir, args.device, args.max_seq_length,
+ args.batch_size, args.use_tensorrt, args.precision,
+ args.cpu_threads, args.enable_mkldnn)
+
+    # Use the ErnieTokenizer that matches the 'ernie-1.0' pretrained model.
+ output_emb_size = 256
+ tokenizer = ppnlp.transformers.ErnieTokenizer.from_pretrained('ernie-1.0')
+ id2corpus = {0: '国有企业引入非国有资本对创新绩效的影响——基于制造业国有上市公司的经验证据'}
+ corpus_list = [{idx: text} for idx, text in id2corpus.items()]
+ res = predictor.extract_embedding(corpus_list, tokenizer)
+ print(res.shape)
+ print(res)
+ corpus_list = [['中西方语言与文化的差异', '中西方文化差异以及语言体现中西方文化,差异,语言体现'],
+ ['中西方语言与文化的差异', '飞桨致力于让深度学习技术的创新与应用更简单']]
+ res = predictor.predict(corpus_list, tokenizer)
+ print(res)
diff --git a/examples/Pipeline/PaddleNLP/semantic_indexing/deploy/python/rpc_client.py b/examples/Pipeline/PaddleNLP/semantic_indexing/deploy/python/rpc_client.py
new file mode 100644
index 0000000000000000000000000000000000000000..03863db6114b7c381dae17ee3bf33f00f15d8f4a
--- /dev/null
+++ b/examples/Pipeline/PaddleNLP/semantic_indexing/deploy/python/rpc_client.py
@@ -0,0 +1,39 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import time
+import numpy as np
+
+from paddle_serving_server.pipeline import PipelineClient
+
+client = PipelineClient()
+client.connect(['127.0.0.1:8088'])
+
+list_data = [
+ "国有企业引入非国有资本对创新绩效的影响——基于制造业国有上市公司的经验证据",
+ "试论翻译过程中的文化差异与语言空缺翻译过程,文化差异,语言空缺,文化对比"
+]
+feed = {}
+for i, item in enumerate(list_data):
+ feed[str(i)] = item
+
+print(feed)
+start_time = time.time()
+ret = client.predict(feed_dict=feed)
+end_time = time.time()
+print("time to cost :{} seconds".format(end_time - start_time))
+
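+# The server returns the embedding matrix serialized as a string (see
+# web_service.py postprocess); eval it back into a nested list and wrap it
+# in a numpy array.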
+result = np.array(eval(ret.value[0]))
+print(ret.key)
+print(result.shape)
+print(result)
diff --git a/examples/Pipeline/PaddleNLP/semantic_indexing/deploy/python/web_service.py b/examples/Pipeline/PaddleNLP/semantic_indexing/deploy/python/web_service.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ad12032b3c92d72a5297f15d732b7dfbd19589e
--- /dev/null
+++ b/examples/Pipeline/PaddleNLP/semantic_indexing/deploy/python/web_service.py
@@ -0,0 +1,82 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from paddle_serving_server.web_service import WebService, Op
+
+_LOGGER = logging.getLogger()
+
+
+def convert_example(example,
+ tokenizer,
+ max_seq_length=512,
+ pad_to_max_seq_len=False):
+ result = []
+ for text in example:
+ encoded_inputs = tokenizer(
+ text=text,
+ max_seq_len=max_seq_length,
+ pad_to_max_seq_len=pad_to_max_seq_len)
+ input_ids = encoded_inputs["input_ids"]
+ token_type_ids = encoded_inputs["token_type_ids"]
+ result += [input_ids, token_type_ids]
+ return result
+
+
+class ErnieOp(Op):
+ def init_op(self):
+ import paddlenlp as ppnlp
+ self.tokenizer = ppnlp.transformers.ErnieTokenizer.from_pretrained(
+ 'ernie-1.0')
+
+ def preprocess(self, input_dicts, data_id, log_id):
+ from paddlenlp.data import Stack, Tuple, Pad
+
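+        # Tokenize every text in the request and batch them with padding so
+        # the feed tensors are rectangular.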
+ (_, input_dict), = input_dicts.items()
+ print("input dict", input_dict)
+ batch_size = len(input_dict.keys())
+ examples = []
+ for i in range(batch_size):
+ input_ids, segment_ids = convert_example([input_dict[str(i)]],
+ self.tokenizer)
+ examples.append((input_ids, segment_ids))
+ batchify_fn = lambda samples, fn=Tuple(
+ Pad(axis=0, pad_val=self.tokenizer.pad_token_id), # input
+ Pad(axis=0, pad_val=self.tokenizer.pad_token_id), # segment
+ ): fn(samples)
+ input_ids, segment_ids = batchify_fn(examples)
+ feed_dict = {}
+ feed_dict['input_ids'] = input_ids
+ feed_dict['token_type_ids'] = segment_ids
+ return feed_dict, False, None, ""
+
+ def postprocess(self, input_dicts, fetch_dict, data_id, log_id):
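+        # Serialize the embedding matrix to a string so it can be carried in
+        # the pipeline response; the client evals it back (see rpc_client.py).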
+ new_dict = {}
+ new_dict["output_embedding"] = str(fetch_dict["output_embedding"]
+ .tolist())
+ return new_dict, None, ""
+
+
+class ErnieService(WebService):
+ def get_pipeline_response(self, read_op):
+ ernie_op = ErnieOp(name="ernie", input_ops=[read_op])
+ return ernie_op
+
+
+ernie_service = ErnieService(name="ernie")
+ernie_service.prepare_pipeline_config("config_nlp.yml")
+ernie_service.run_service()
diff --git a/examples/Pipeline/PaddleNLP/semantic_indexing/export_model.py b/examples/Pipeline/PaddleNLP/semantic_indexing/export_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..da468ea7b2c3af6eff093eef98a3e4f9393f9b3d
--- /dev/null
+++ b/examples/Pipeline/PaddleNLP/semantic_indexing/export_model.py
@@ -0,0 +1,65 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import os
+
+import paddle
+import paddlenlp as ppnlp
+
+from base_model import SemanticIndexBaseStatic
+
+# yapf: disable
+parser = argparse.ArgumentParser()
+parser.add_argument("--params_path", type=str, required=True,
+ default='./checkpoint/model_900/model_state.pdparams', help="The path to model parameters to be loaded.")
+parser.add_argument("--output_path", type=str, default='./output',
+ help="The path of model parameter in static graph to be saved.")
+args = parser.parse_args()
+# yapf: enable
+
+if __name__ == "__main__":
+    # Use the ERNIE 1.0 pretrained model with a 256-dim output embedding.
+ output_emb_size = 256
+
+ pretrained_model = ppnlp.transformers.ErnieModel.from_pretrained(
+ "ernie-1.0")
+
+ tokenizer = ppnlp.transformers.ErnieTokenizer.from_pretrained('ernie-1.0')
+ model = SemanticIndexBaseStatic(
+ pretrained_model, output_emb_size=output_emb_size)
+
+ if args.params_path and os.path.isfile(args.params_path):
+ state_dict = paddle.load(args.params_path)
+ model.set_dict(state_dict)
+ print("Loaded parameters from %s" % args.params_path)
+
+ model.eval()
+
+ # Convert to static graph with specific input description
+ model = paddle.jit.to_static(
+ model,
+ input_spec=[
+ paddle.static.InputSpec(
+ shape=[None, None], dtype="int64"), # input_ids
+ paddle.static.InputSpec(
+ shape=[None, None], dtype="int64") # segment_ids
+ ])
+    # Save the static-graph model.
+ save_path = os.path.join(args.output_path, "inference")
+ paddle.jit.save(model, save_path)
diff --git a/examples/Pipeline/PaddleNLP/semantic_indexing/export_to_serving.py b/examples/Pipeline/PaddleNLP/semantic_indexing/export_to_serving.py
new file mode 100644
index 0000000000000000000000000000000000000000..c24f931510e5662ae1b824049d1ac35c4ef34076
--- /dev/null
+++ b/examples/Pipeline/PaddleNLP/semantic_indexing/export_to_serving.py
@@ -0,0 +1,47 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import paddle_serving_client.io as serving_io
+# yapf: disable
+parser = argparse.ArgumentParser()
+parser.add_argument("--dirname", type=str, required=True,
+ default='./output', help="Path of saved model files. Program file and parameter files are saved in this directory.")
+parser.add_argument("--model_filename", type=str, required=True,
+ default='inference.get_pooled_embedding.pdmodel', help="The name of file to load the inference program. If it is None, the default filename __model__ will be used.")
+parser.add_argument("--params_filename", type=str, required=True,
+ default='inference.get_pooled_embedding.pdiparams', help="The name of file to load all parameters. It is only used for the case that all parameters were saved in a single binary file. If parameters were saved in separate files, set it as None. Default: None.")
+parser.add_argument("--server_path", type=str, default='./serving_server',
+ help="The path of server parameter in static graph to be saved.")
+parser.add_argument("--client_path", type=str, default='./serving_client',
+ help="The path of client parameter in static graph to be saved.")
+parser.add_argument("--feed_alias_names", type=str, default=None,
+ help='set alias names for feed vars, split by comma \',\', you should run --show_proto to check the number of feed vars')
+parser.add_argument("--fetch_alias_names", type=str, default=None,
+ help='set alias names for feed vars, split by comma \',\', you should run --show_proto to check the number of fetch vars')
+parser.add_argument("--show_proto", type=bool, default=False,
+ help='If yes, you can preview the proto and then determine your feed var alias name and fetch var alias name.')
+# yapf: enable
+
+if __name__ == "__main__":
+ args = parser.parse_args()
+ serving_io.inference_model_to_serving(
+ dirname=args.dirname,
+ serving_server=args.server_path,
+ serving_client=args.client_path,
+ model_filename=args.model_filename,
+ params_filename=args.params_filename,
+ show_proto=args.show_proto,
+ feed_alias_names=args.feed_alias_names,
+ fetch_alias_names=args.fetch_alias_names)
diff --git a/examples/Pipeline/PaddleNLP/semantic_indexing/scripts/export_model.sh b/examples/Pipeline/PaddleNLP/semantic_indexing/scripts/export_model.sh
new file mode 100644
index 0000000000000000000000000000000000000000..7c79266219cea03e16968ed0d00a3755615c7432
--- /dev/null
+++ b/examples/Pipeline/PaddleNLP/semantic_indexing/scripts/export_model.sh
@@ -0,0 +1 @@
+python export_model.py --params_path checkpoints/model_40/model_state.pdparams --output_path=./output
\ No newline at end of file
diff --git a/examples/Pipeline/PaddleNLP/semantic_indexing/scripts/export_to_serving.sh b/examples/Pipeline/PaddleNLP/semantic_indexing/scripts/export_to_serving.sh
new file mode 100644
index 0000000000000000000000000000000000000000..b0d7a422551fd09eb1a28cfacdf47237a8efc795
--- /dev/null
+++ b/examples/Pipeline/PaddleNLP/semantic_indexing/scripts/export_to_serving.sh
@@ -0,0 +1,7 @@
+python export_to_serving.py \
+ --dirname "output" \
+ --model_filename "inference.get_pooled_embedding.pdmodel" \
+ --params_filename "inference.get_pooled_embedding.pdiparams" \
+ --server_path "serving_server" \
+ --client_path "serving_client" \
+ --fetch_alias_names "output_embedding"