未验证 提交 290f8bc2 编写于 作者: J Jiawei Wang 提交者: GitHub

Merge branch 'develop' into fix_encryption_getdata.sh

......@@ -136,8 +136,8 @@ if (WITH_TRT)
endif()
if (WITH_LITE)
ADD_LIBRARY(paddle_api_full_bundled STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET paddle_api_full_bundled PROPERTY IMPORTED_LOCATION ${PADDLE_INSTALL_DIR}/third_party/install/lite/cxx/lib/libpaddle_api_full_bundled.a)
ADD_LIBRARY(paddle_full_api_shared STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET paddle_full_api_shared PROPERTY IMPORTED_LOCATION ${PADDLE_INSTALL_DIR}/third_party/install/lite/cxx/lib/libpaddle_full_api_shared.so)
if (WITH_XPU)
ADD_LIBRARY(xpuapi SHARED IMPORTED GLOBAL)
......@@ -157,7 +157,7 @@ LIST(APPEND paddle_depend_libs
xxhash)
if(WITH_LITE)
LIST(APPEND paddle_depend_libs paddle_api_full_bundled)
LIST(APPEND paddle_depend_libs paddle_full_api_shared)
if(WITH_XPU)
LIST(APPEND paddle_depend_libs xpuapi xpurt)
endif()
......
......@@ -34,6 +34,42 @@
**A:** http rpc
## 安装问题
#### Q: pip install安装whl包过程,报错信息如下:
```
Collecting opencv-python
Using cached opencv-python-4.3.0.38.tar.gz (88.0 MB)
Installing build dependencies ... done
Getting requirements to build wheel ... error
ERROR: Command errored out with exit status 1:
command: /home/work/Python-2.7.17/build/bin/python /home/work/Python-2.7.17/build/lib/python2.7/site-packages/pip/_vendor/pep517/_in_process.py get_requires_for_build_wheel /tmp/tmpLiweA9
cwd: /tmp/pip-install-_w6AUI/opencv-python
Complete output (22 lines):
Traceback (most recent call last):
File "/home/work/Python-2.7.17/build/lib/python2.7/site-packages/pip/_vendor/pep517/_in_process.py", line 280, in <module>
main()
File "/home/work/Python-2.7.17/build/lib/python2.7/site-packages/pip/_vendor/pep517/_in_process.py", line 263, in main
json_out['return_val'] = hook(**hook_input['kwargs'])
File "/home/work/Python-2.7.17/build/lib/python2.7/site-packages/pip/_vendor/pep517/_in_process.py", line 114, in get_requires_for_build_wheel
return hook(config_settings)
File "/tmp/pip-build-env-AUCbP4/overlay/lib/python2.7/site-packages/setuptools/build_meta.py", line 146, in get_requires_for_build_wheel
return self._get_build_requires(config_settings, requirements=['wheel'])
File "/tmp/pip-build-env-AUCbP4/overlay/lib/python2.7/site-packages/setuptools/build_meta.py", line 127, in _get_build_requires
self.run_setup()
File "/tmp/pip-build-env-AUCbP4/overlay/lib/python2.7/site-packages/setuptools/build_meta.py", line 243, in run_setup
self).run_setup(setup_script=setup_script)
File "/tmp/pip-build-env-AUCbP4/overlay/lib/python2.7/site-packages/setuptools/build_meta.py", line 142, in run_setup
exec(compile(code, __file__, 'exec'), locals())
File "setup.py", line 448, in <module>
main()
File "setup.py", line 99, in main
% {"ext": re.escape(sysconfig.get_config_var("EXT_SUFFIX"))}
File "/home/work/Python-2.7.17/build/lib/python2.7/re.py", line 210, in escape
s = list(pattern)
TypeError: 'NoneType' object is not iterable
```
**A:** 指定opencv-python版本安装,pip install opencv-python==4.2.0.32,再安装whl包
## 编译问题
......
......@@ -62,7 +62,7 @@ public class PipelineClientExample {
return false;
}
}
PipelineFuture future = StaticPipelineClient.client.asyn_pr::qedict(feed_data, fetch,false,0);
PipelineFuture future = StaticPipelineClient.client.asyn_predict(feed_data, fetch,false,0);
HashMap<String,String> result = future.get();
if (result == null) {
return false;
......
......@@ -37,7 +37,7 @@ public class StaticPipelineClient {
System.out.println("already connect.");
return true;
}
succ = clieint.connect(target);
succ = client.connect(target);
if (succ != true) {
System.out.println("connect failed.");
return false;
......
......@@ -128,20 +128,22 @@ class FluidArmAnalysisCore : public FluidFamilyCore {
config.DisableGpu();
config.SetCpuMathLibraryNumThreads(1);
if (params.enable_memory_optimization()) {
config.EnableMemoryOptim();
if (params.use_lite()) {
config.EnableLiteEngine(PrecisionType::kFloat32, true);
}
if (params.enable_memory_optimization()) {
config.EnableMemoryOptim();
if (params.use_xpu()) {
config.EnableXpu(2 * 1024 * 1024);
}
if (params.use_lite()) {
config.EnableLiteEngine(PrecisionType::kFloat32, true);
if (params.enable_memory_optimization()) {
config.EnableMemoryOptim();
}
if (params.use_xpu()) {
config.EnableXpu(100);
if (params.enable_ir_optimization()) {
config.SwitchIrOptim(true);
} else {
config.SwitchIrOptim(false);
}
config.SwitchSpecifyInputNames(true);
......@@ -173,6 +175,14 @@ class FluidArmAnalysisDirCore : public FluidFamilyCore {
config.SwitchSpecifyInputNames(true);
config.SetCpuMathLibraryNumThreads(1);
if (params.use_lite()) {
config.EnableLiteEngine(PrecisionType::kFloat32, true);
}
if (params.use_xpu()) {
config.EnableXpu(2 * 1024 * 1024);
}
if (params.enable_memory_optimization()) {
config.EnableMemoryOptim();
}
......@@ -183,14 +193,6 @@ class FluidArmAnalysisDirCore : public FluidFamilyCore {
config.SwitchIrOptim(false);
}
if (params.use_lite()) {
config.EnableLiteEngine(PrecisionType::kFloat32, true);
}
if (params.use_xpu()) {
config.EnableXpu(100);
}
AutoLock lock(GlobalPaddleCreateMutex::instance());
_core = CreatePredictor(config);
if (NULL == _core.get()) {
......
......@@ -99,15 +99,27 @@ if (SERVER)
DEPENDS ${SERVING_SERVER_CORE} server_config_py_proto ${PY_FILES})
add_custom_target(paddle_python ALL DEPENDS ${PADDLE_SERVING_BINARY_DIR}/.timestamp)
elseif(WITH_LITE)
add_custom_command(
OUTPUT ${PADDLE_SERVING_BINARY_DIR}/.timestamp
COMMAND cp -r
${CMAKE_CURRENT_SOURCE_DIR}/paddle_serving_server_gpu/ ${PADDLE_SERVING_BINARY_DIR}/python/
COMMAND env ${py_env} ${PYTHON_EXECUTABLE} gen_version.py
"server_gpu" arm
COMMAND env ${py_env} ${PYTHON_EXECUTABLE} setup.py bdist_wheel
DEPENDS ${SERVING_SERVER_CORE} server_config_py_proto ${PY_FILES})
add_custom_target(paddle_python ALL DEPENDS ${PADDLE_SERVING_BINARY_DIR}/.timestamp)
if(WITH_XPU)
add_custom_command(
OUTPUT ${PADDLE_SERVING_BINARY_DIR}/.timestamp
COMMAND cp -r
${CMAKE_CURRENT_SOURCE_DIR}/paddle_serving_server_gpu/ ${PADDLE_SERVING_BINARY_DIR}/python/
COMMAND env ${py_env} ${PYTHON_EXECUTABLE} gen_version.py
"server_gpu" arm-xpu
COMMAND env ${py_env} ${PYTHON_EXECUTABLE} setup.py bdist_wheel
DEPENDS ${SERVING_SERVER_CORE} server_config_py_proto ${PY_FILES})
add_custom_target(paddle_python ALL DEPENDS ${PADDLE_SERVING_BINARY_DIR}/.timestamp)
else()
add_custom_command(
OUTPUT ${PADDLE_SERVING_BINARY_DIR}/.timestamp
COMMAND cp -r
${CMAKE_CURRENT_SOURCE_DIR}/paddle_serving_server_gpu/ ${PADDLE_SERVING_BINARY_DIR}/python/
COMMAND env ${py_env} ${PYTHON_EXECUTABLE} gen_version.py
"server_gpu" arm
COMMAND env ${py_env} ${PYTHON_EXECUTABLE} setup.py bdist_wheel
DEPENDS ${SERVING_SERVER_CORE} server_config_py_proto ${PY_FILES})
add_custom_target(paddle_python ALL DEPENDS ${PADDLE_SERVING_BINARY_DIR}/.timestamp)
endif()
else()
add_custom_command(
OUTPUT ${PADDLE_SERVING_BINARY_DIR}/.timestamp
......
......@@ -3,9 +3,10 @@
([简体中文](./README_CN.md)|English)
In the example, a BERT model is used for semantic understanding prediction, and the text is represented as a vector, which can be used for further analysis and prediction.
If your python version is 3.X, replace the 'pip' field in the following command with 'pip3',replace 'python' with 'python3'.
### Getting Model
method 1:
This example use model [BERT Chinese Model](https://www.paddlepaddle.org.cn/hubdetail?name=bert_chinese_L-12_H-768_A-12&en_category=SemanticModel) from [Paddlehub](https://github.com/PaddlePaddle/PaddleHub).
Install paddlehub first
......@@ -22,11 +23,13 @@ the 128 in the command above means max_seq_len in BERT model, which is the lengt
the config file and model file for server side are saved in the folder bert_seq128_model.
the config file generated for client side is saved in the folder bert_seq128_client.
method 2:
You can also download the above model from BOS(max_seq_len=128). After decompression, the config file and model file for server side are stored in the bert_chinese_L-12_H-768_A-12_model folder, and the config file generated for client side is stored in the bert_chinese_L-12_H-768_A-12_client folder:
```shell
wget https://paddle-serving.bj.bcebos.com/paddle_hub_models/text/SemanticModel/bert_chinese_L-12_H-768_A-12.tar.gz
tar -xzf bert_chinese_L-12_H-768_A-12.tar.gz
```
if your model is bert_chinese_L-12_H-768_A-12_model, replace the 'bert_seq128_model' field in the following command with 'bert_chinese_L-12_H-768_A-12_model',replace 'bert_seq128_client' with 'bert_chinese_L-12_H-768_A-12_client'.
### Getting Dict and Sample Dataset
......@@ -36,11 +39,11 @@ sh get_data.sh
this script will download Chinese Dictionary File vocab.txt and Chinese Sample Data data-c.txt
### RPC Inference Service
Run
start cpu inference service,Run
```
python -m paddle_serving_server.serve --model bert_seq128_model/ --port 9292 #cpu inference service
```
Or
Or,start gpu inference service,Run
```
python -m paddle_serving_server_gpu.serve --model bert_seq128_model/ --port 9292 --gpu_ids 0 #launch gpu inference service at GPU 0
```
......@@ -59,12 +62,18 @@ head data-c.txt | python bert_client.py --model bert_seq128_client/serving_clien
the client reads data from data-c.txt and send prediction request, the prediction is given by word vector. (Due to massive data in the word vector, we do not print it).
### HTTP Inference Service
start cpu HTTP inference service,Run
```
python bert_web_service.py bert_seq128_model/ 9292 #launch gpu inference service
```
Or,start gpu HTTP inference service,Run
```
export CUDA_VISIBLE_DEVICES=0,1
```
set environmental variable to specify which gpus are used, the command above means gpu 0 and gpu 1 is used.
```
python bert_web_service.py bert_seq128_model/ 9292 #launch gpu inference service
python bert_web_service_gpu.py bert_seq128_model/ 9292 #launch gpu inference service
```
### HTTP Inference
......
......@@ -4,8 +4,9 @@
示例中采用BERT模型进行语义理解预测,将文本表示为向量的形式,可以用来做进一步的分析和预测。
若使用python的版本为3.X, 将以下命令中的pip 替换为pip3, python替换为python3.
### 获取模型
方法1:
示例中采用[Paddlehub](https://github.com/PaddlePaddle/PaddleHub)中的[BERT中文模型](https://www.paddlepaddle.org.cn/hubdetail?name=bert_chinese_L-12_H-768_A-12&en_category=SemanticModel)
请先安装paddlehub
```
......@@ -19,11 +20,15 @@ python prepare_model.py 128
生成server端配置文件与模型文件,存放在bert_seq128_model文件夹。
生成client端配置文件,存放在bert_seq128_client文件夹。
方法2:
您也可以从bos上直接下载上述模型(max_seq_len=128),解压后server端配置文件与模型文件存放在bert_chinese_L-12_H-768_A-12_model文件夹,client端配置文件存放在bert_chinese_L-12_H-768_A-12_client文件夹:
```shell
wget https://paddle-serving.bj.bcebos.com/paddle_hub_models/text/SemanticModel/bert_chinese_L-12_H-768_A-12.tar.gz
tar -xzf bert_chinese_L-12_H-768_A-12.tar.gz
```
若使用bert_chinese_L-12_H-768_A-12_model模型,将下面命令中的bert_seq128_model字段替换为bert_chinese_L-12_H-768_A-12_model,bert_seq128_client字段替换为bert_chinese_L-12_H-768_A-12_client.
### 获取词典和样例数据
......@@ -33,13 +38,15 @@ sh get_data.sh
脚本将下载中文词典vocab.txt和中文样例数据data-c.txt
### 启动RPC预测服务
执行
启动cpu预测服务,执行
```
python -m paddle_serving_server.serve --model bert_seq128_model/ --port 9292 #启动cpu预测服务
```
或者
或者,启动gpu预测服务,执行
```
python -m paddle_serving_server_gpu.serve --model bert_seq128_model/ --port 9292 --gpu_ids 0 #在gpu 0上启动gpu预测服务
```
### 执行预测
......@@ -51,17 +58,28 @@ pip install paddle_serving_app
执行
```
head data-c.txt | python bert_client.py --model bert_seq128_client/serving_client_conf.prototxt
```
启动client读取data-c.txt中的数据进行预测,预测结果为文本的向量表示(由于数据较多,脚本中没有将输出进行打印),server端的地址在脚本中修改。
### 启动HTTP预测服务
启动cpu HTTP预测服务,执行
```
python bert_web_service.py bert_seq128_model/ 9292 #启动gpu预测服务
```
或者,启动gpu HTTP预测服务,执行
```
export CUDA_VISIBLE_DEVICES=0,1
```
通过环境变量指定gpu预测服务使用的gpu,示例中指定索引为0和1的两块gpu
```
python bert_web_service.py bert_seq128_model/ 9292 #启动gpu预测服务
python bert_web_service_gpu.py bert_seq128_model/ 9292 #启动gpu预测服务
```
### 执行预测
```
......
# coding=utf-8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
from paddle_serving_server_gpu.web_service import WebService
from paddle_serving_app.reader import ChineseBertReader
import sys
import os
import numpy as np
class BertService(WebService):
def load(self):
self.reader = ChineseBertReader({
"vocab_file": "vocab.txt",
"max_seq_len": 128
})
def preprocess(self, feed=[], fetch=[]):
feed_res = []
is_batch = False
for ins in feed:
feed_dict = self.reader.process(ins["words"].encode("utf-8"))
for key in feed_dict.keys():
feed_dict[key] = np.array(feed_dict[key]).reshape(
(len(feed_dict[key]), 1))
feed_res.append(feed_dict)
return feed_res, fetch, is_batch
bert_service = BertService(name="bert")
bert_service.load()
bert_service.load_model_config(sys.argv[1])
bert_service.prepare_server(
workdir="workdir", port=int(sys.argv[2]), device="gpu")
bert_service.run_rpc_service()
bert_service.run_web_service()
......@@ -132,6 +132,7 @@ class LocalPredictor(object):
ops_filter=[])
if use_xpu:
# 2MB l3 cache
config.enable_xpu(8 * 1024 * 1024)
self.predictor = create_paddle_predictor(config)
......
......@@ -20,7 +20,7 @@ from paddle_serving_server import OpMaker, OpSeqMaker, Server
from paddle_serving_client import Client
from contextlib import closing
import socket
import numpy as np
from paddle_serving_server import pipeline
from paddle_serving_server.pipeline import Op
......@@ -64,8 +64,8 @@ class WebService(object):
f = open(client_config, 'r')
model_conf = google.protobuf.text_format.Merge(
str(f.read()), model_conf)
self.feed_names = [var.alias_name for var in model_conf.feed_var]
self.fetch_names = [var.alias_name for var in model_conf.fetch_var]
self.feed_vars = {var.name: var for var in model_conf.feed_var}
self.fetch_vars = {var.name: var for var in model_conf.fetch_var}
def _launch_rpc_service(self):
op_maker = OpMaker()
......@@ -201,6 +201,15 @@ class WebService(object):
def preprocess(self, feed=[], fetch=[]):
print("This API will be deprecated later. Please do not use it")
is_batch = True
feed_dict = {}
for var_name in self.feed_vars.keys():
feed_dict[var_name] = []
for feed_ins in feed:
for key in feed_ins:
feed_dict[key].append(np.array(feed_ins[key]).reshape(list(self.feed_vars[key].shape))[np.newaxis,:])
feed = {}
for key in feed_dict:
feed[key] = np.concatenate(feed_dict[key], axis=0)
return feed, fetch, is_batch
def postprocess(self, feed=[], fetch=[], fetch_map=None):
......
......@@ -212,6 +212,7 @@ class Server(object):
self.module_path = os.path.dirname(paddle_serving_server.__file__)
self.cur_path = os.getcwd()
self.use_local_bin = False
self.device = "cpu"
self.gpuid = 0
self.use_trt = False
self.use_lite = False
......@@ -279,6 +280,9 @@ class Server(object):
"GPU not found, please check your environment or use cpu version by \"pip install paddle_serving_server\""
)
def set_device(self, device="cpu"):
self.device = device
def set_gpuid(self, gpuid=0):
self.gpuid = gpuid
......@@ -311,18 +315,19 @@ class Server(object):
engine.static_optimization = False
engine.force_update_static_cache = False
engine.use_trt = self.use_trt
engine.use_lite = self.use_lite
engine.use_xpu = self.use_xpu
if os.path.exists('{}/__params__'.format(model_config_path)):
suffix = ""
else:
suffix = "_DIR"
if device == "arm":
engine.use_lite = self.use_lite
engine.use_xpu = self.use_xpu
if device == "cpu":
engine.type = "FLUID_CPU_ANALYSIS_DIR"
engine.type = "FLUID_CPU_ANALYSIS" + suffix
elif device == "gpu":
engine.type = "FLUID_GPU_ANALYSIS_DIR"
engine.type = "FLUID_GPU_ANALYSIS" + suffix
elif device == "arm":
engine.type = "FLUID_ARM_ANALYSIS_DIR"
engine.type = "FLUID_ARM_ANALYSIS" + suffix
self.model_toolkit_conf.engines.extend([engine])
def _prepare_infer_service(self, port):
......@@ -425,7 +430,7 @@ class Server(object):
cuda_version = line.split("\"")[1]
if cuda_version == "101" or cuda_version == "102" or cuda_version == "110":
device_version = "serving-gpu-" + cuda_version + "-"
elif cuda_version == "arm":
elif cuda_version == "arm" or cuda_version == "arm-xpu":
device_version = "serving-" + cuda_version + "-"
else:
device_version = "serving-gpu-cuda" + cuda_version + "-"
......@@ -528,7 +533,8 @@ class Server(object):
else:
print("Use local bin : {}".format(self.bin_path))
#self.check_cuda()
if self.use_lite:
# Todo: merge CPU and GPU code, remove device to model_toolkit
if self.device == "cpu" or self.device == "arm":
command = "{} " \
"-enable_model_toolkit " \
"-inferservice_path {} " \
......
......@@ -73,6 +73,7 @@ def start_gpu_card_model(index, gpuid, args): # pylint: disable=doc-string-miss
server.set_lite()
device = "arm"
server.set_device(device)
if args.use_xpu:
server.set_xpu()
......
......@@ -70,8 +70,8 @@ class WebService(object):
f = open(client_config, 'r')
model_conf = google.protobuf.text_format.Merge(
str(f.read()), model_conf)
self.feed_names = [var.alias_name for var in model_conf.feed_var]
self.fetch_names = [var.alias_name for var in model_conf.fetch_var]
self.feed_vars = {var.name: var for var in model_conf.feed_var}
self.fetch_vars = {var.name: var for var in model_conf.fetch_var}
def set_gpus(self, gpus):
print("This API will be deprecated later. Please do not use it")
......@@ -107,6 +107,7 @@ class WebService(object):
server.set_num_threads(thread_num)
server.set_memory_optimize(mem_optim)
server.set_ir_optimize(ir_optim)
server.set_device(device)
if use_lite:
server.set_lite()
......@@ -278,6 +279,15 @@ class WebService(object):
def preprocess(self, feed=[], fetch=[]):
print("This API will be deprecated later. Please do not use it")
is_batch = True
feed_dict = {}
for var_name in self.feed_vars.keys():
feed_dict[var_name] = []
for feed_ins in feed:
for key in feed_ins:
feed_dict[key].append(np.array(feed_ins[key]).reshape(list(self.feed_vars[key].shape))[np.newaxis,:])
feed = {}
for key in feed_dict:
feed[key] = np.concatenate(feed_dict[key], axis=0)
return feed, fetch, is_batch
def postprocess(self, feed=[], fetch=[], fetch_map=None):
......
......@@ -249,6 +249,8 @@ class LocalServiceHandler(object):
server = Server()
if gpuid >= 0:
server.set_gpuid(gpuid)
# TODO: support arm or arm + xpu later
server.set_device(self._device_name)
server.set_op_sequence(op_seq_maker.get_op_sequence())
server.set_num_threads(thread_num)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册