未验证 提交 0aa85d4f 编写于 作者: B Bin Lu 提交者: GitHub

Merge pull request #1603 from Intsigstephon/develop

add cpp serving for clas and pp-shitu
# 模型服务化部署
- [1. 简介](#1)
- [2. Serving 安装](#2)
- [3. 图像分类服务部署](#3)
- [3.1 模型转换](#3.1)
- [3.2 服务部署和请求](#3.2)
- [4. 图像识别服务部署](#4)
- [4.1 模型转换](#4.1)
- [4.2 服务部署和请求](#4.2)
- [5. FAQ](#5)
<a name="1"></a>
## 1. 简介
[Paddle Serving](https://github.com/PaddlePaddle/Serving) 旨在帮助深度学习开发者轻松部署在线预测服务,支持一键部署工业级的服务能力、客户端和服务端之间高并发和高效通信、并支持多种编程语言开发客户端。
该部分以 HTTP 预测服务部署为例,介绍怎样在 PaddleClas 中使用 PaddleServing 部署模型服务。目前只支持 Linux 平台部署,暂不支持 Windows 平台。
<a name="2"></a>
## 2. Serving 安装
Serving 官网推荐使用 docker 安装并部署 Serving 环境。首先需要拉取 docker 环境并创建基于 Serving 的 docker。
```shell
docker pull paddlepaddle/serving:0.7.0-cuda10.2-cudnn7-devel
nvidia-docker run -p 9292:9292 --name test -dit paddlepaddle/serving:0.7.0-cuda10.2-cudnn7-devel bash
nvidia-docker exec -it test bash
```
进入 docker 后,需要安装 Serving 相关的 python 包。
```shell
pip3 install paddle-serving-client==0.7.0
pip3 install paddle-serving-server==0.7.0 # CPU
pip3 install paddle-serving-app==0.7.0
pip3 install paddle-serving-server-gpu==0.7.0.post102 #GPU with CUDA10.2 + TensorRT6
# 其他GPU环境需要确认环境再选择执行哪一条
pip3 install paddle-serving-server-gpu==0.7.0.post101 # GPU with CUDA10.1 + TensorRT6
pip3 install paddle-serving-server-gpu==0.7.0.post112 # GPU with CUDA11.2 + TensorRT8
```
* 如果安装速度太慢,可以通过 `-i https://pypi.tuna.tsinghua.edu.cn/simple` 更换源,加速安装过程。
* 其他环境配置安装请参考: [使用Docker安装Paddle Serving](https://github.com/PaddlePaddle/Serving/blob/v0.7.0/doc/Install_CN.md)
* 如果希望部署 CPU 服务,可以安装 serving-server 的 cpu 版本,安装命令如下。
```shell
pip install paddle-serving-server
```
<a name="3"></a>
## 3. 图像分类服务部署
<a name="3.1"></a>
### 3.1 模型转换
使用 PaddleServing 做服务化部署时,需要将保存的 inference 模型转换为 Serving 模型。下面以经典的 ResNet50_vd 模型为例,介绍如何部署图像分类服务。
- 进入工作目录:
```shell
cd deploy/paddleserving
```
- 下载 ResNet50_vd 的 inference 模型:
```shell
# 下载并解压 ResNet50_vd 模型
wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/ResNet50_vd_infer.tar && tar xf ResNet50_vd_infer.tar
```
- 用 paddle_serving_client 把下载的 inference 模型转换成易于 Server 部署的模型格式:
```
# 转换 ResNet50_vd 模型
python3 -m paddle_serving_client.convert --dirname ./ResNet50_vd_infer/ \
--model_filename inference.pdmodel \
--params_filename inference.pdiparams \
--serving_server ./ResNet50_vd_serving/ \
--serving_client ./ResNet50_vd_client/
```
ResNet50_vd 推理模型转换完成后,会在当前文件夹多出 `ResNet50_vd_serving``ResNet50_vd_client` 的文件夹,具备如下格式:
```
|- ResNet50_vd_server/
|- inference.pdiparams
|- inference.pdmodel
|- serving_server_conf.prototxt
|- serving_server_conf.stream.prototxt
|- ResNet50_vd_client
|- serving_client_conf.prototxt
|- serving_client_conf.stream.prototxt
```
得到模型文件之后,需要修改 `ResNet50_vd_server` 下文件 `serving_server_conf.prototxt` 中的 alias 名字:将 `fetch_var` 中的 `alias_name` 改为 `prediction`
**备注**: Serving 为了兼容不同模型的部署,提供了输入输出重命名的功能。这样,不同的模型在推理部署时,只需要修改配置文件的 alias_name 即可,无需修改代码即可完成推理部署。
修改后的 serving_server_conf.prototxt 如下所示:
```
feed_var {
name: "inputs"
alias_name: "inputs"
is_lod_tensor: false
feed_type: 1
shape: 3
shape: 224
shape: 224
}
fetch_var {
name: "save_infer_model/scale_0.tmp_1"
alias_name: "prediction"
is_lod_tensor: false
fetch_type: 1
shape: 1000
}
```
<a name="3.2"></a>
### 3.2 服务部署和请求
paddleserving 目录包含了启动 pipeline 服务和发送预测请求的代码,包括:
```shell
__init__.py
config.yml # 启动服务的配置文件
pipeline_http_client.py # http方式发送pipeline预测请求的脚本
pipeline_rpc_client.py # rpc方式发送pipeline预测请求的脚本
classification_web_service.py # 启动pipeline服务端的脚本
```
- 启动服务:
```shell
# 启动服务,运行日志保存在 log.txt
python3 classification_web_service.py &>log.txt &
```
成功启动服务后,log.txt 中会打印类似如下日志
![](./imgs/start_server.png)
- 发送请求:
```shell
# 发送服务请求
python3 pipeline_http_client.py
```
成功运行后,模型预测的结果会打印在 cmd 窗口中,结果示例为:
![](./imgs/results.png)
<a name="4"></a>
## 4.图像识别服务部署
使用 PaddleServing 做服务化部署时,需要将保存的 inference 模型转换为 Serving 模型。 下面以 PP-ShiTu 中的超轻量图像识别模型为例,介绍图像识别服务的部署。
<a name="4.1"></a>
## 4.1 模型转换
- 下载通用检测 inference 模型和通用识别 inference 模型
```
cd deploy
# 下载并解压通用识别模型
wget -P models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/general_PPLCNet_x2_5_lite_v1.0_infer.tar
cd models
tar -xf general_PPLCNet_x2_5_lite_v1.0_infer.tar
# 下载并解压通用检测模型
wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/picodet_PPLCNet_x2_5_mainbody_lite_v1.0_infer.tar
tar -xf picodet_PPLCNet_x2_5_mainbody_lite_v1.0_infer.tar
```
- 转换识别 inference 模型为 Serving 模型:
```
# 转换识别模型
python3 -m paddle_serving_client.convert --dirname ./general_PPLCNet_x2_5_lite_v1.0_infer/ \
--model_filename inference.pdmodel \
--params_filename inference.pdiparams \
--serving_server ./general_PPLCNet_x2_5_lite_v1.0_serving/ \
--serving_client ./general_PPLCNet_x2_5_lite_v1.0_client/
```
识别推理模型转换完成后,会在当前文件夹多出 `general_PPLCNet_x2_5_lite_v1.0_serving/``general_PPLCNet_x2_5_lite_v1.0_client/` 的文件夹。修改 `general_PPLCNet_x2_5_lite_v1.0_serving/` 目录下的 serving_server_conf.prototxt 中的 alias 名字: 将 `fetch_var` 中的 `alias_name` 改为 `features`
修改后的 serving_server_conf.prototxt 内容如下:
```
feed_var {
name: "x"
alias_name: "x"
is_lod_tensor: false
feed_type: 1
shape: 3
shape: 224
shape: 224
}
fetch_var {
name: "save_infer_model/scale_0.tmp_1"
alias_name: "features"
is_lod_tensor: false
fetch_type: 1
shape: 512
}
```
- 转换通用检测 inference 模型为 Serving 模型:
```
# 转换通用检测模型
python3 -m paddle_serving_client.convert --dirname ./picodet_PPLCNet_x2_5_mainbody_lite_v1.0_infer/ \
--model_filename inference.pdmodel \
--params_filename inference.pdiparams \
--serving_server ./picodet_PPLCNet_x2_5_mainbody_lite_v1.0_serving/ \
--serving_client ./picodet_PPLCNet_x2_5_mainbody_lite_v1.0_client/
```
检测 inference 模型转换完成后,会在当前文件夹多出 `picodet_PPLCNet_x2_5_mainbody_lite_v1.0_serving/``picodet_PPLCNet_x2_5_mainbody_lite_v1.0_client/` 的文件夹。
**注意:** 此处不需要修改 `picodet_PPLCNet_x2_5_mainbody_lite_v1.0_serving/` 目录下的 serving_server_conf.prototxt 中的 alias 名字。
- 下载并解压已经构建后的检索库 index
```
cd ../
wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/drink_dataset_v1.0.tar && tar -xf drink_dataset_v1.0.tar
```
<a name="4.2"></a>
## 4.2 服务部署和请求
**注意:** 识别服务涉及到多个模型,出于性能考虑采用 PipeLine 部署方式。Pipeline 部署方式当前不支持 windows 平台。
- 进入到工作目录
```shell
cd ./deploy/paddleserving/recognition
```
paddleserving 目录包含启动 pipeline 服务和发送预测请求的代码,包括:
```
__init__.py
config.yml # 启动服务的配置文件
pipeline_http_client.py # http方式发送pipeline预测请求的脚本
pipeline_rpc_client.py # rpc方式发送pipeline预测请求的脚本
recognition_web_service.py # 启动pipeline服务端的脚本
```
- 启动服务:
```
# 启动服务,运行日志保存在 log.txt
python3 recognition_web_service.py &>log.txt &
```
成功启动服务后,log.txt 中会打印类似如下日志
![](./imgs/start_server_shitu.png)
- 发送请求:
```
python3 pipeline_http_client.py
```
成功运行后,模型预测的结果会打印在 cmd 窗口中,结果示例为:
![](./imgs/results_shitu.png)
<a name="5"></a>
## 5.FAQ
**Q1**: 发送请求后没有结果返回或者提示输出解码报错
**A1**: 启动服务和发送请求时不要设置代理,可以在启动服务前和发送请求前关闭代理,关闭代理的命令是:
```
unset https_proxy
unset http_proxy
```
更多的服务部署类型,如 `RPC 预测服务` 等,可以参考 Serving 的[github 官网](https://github.com/PaddlePaddle/Serving/tree/v0.7.0/examples)
../../docs/zh_CN/inference_deployment/paddle_serving_deploy.md
\ No newline at end of file
nohup python3 -m paddle_serving_server.serve \
--model ../../models/picodet_PPLCNet_x2_5_mainbody_lite_v1.0_serving \
--port 9293 >>log_mainbody_detection.txt 1&>2 &
nohup python3 -m paddle_serving_server.serve \
--model ../../models/general_PPLCNet_x2_5_lite_v1.0_serving \
--port 9294 >>log_feature_extraction.txt 1&>2 &
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import numpy as np
from paddle_serving_client import Client
from paddle_serving_app.reader import *
import cv2
import faiss
import os
import pickle
class MainbodyDetect():
"""
pp-shitu mainbody detect.
include preprocess, process, postprocess
return detect results
Attention: Postprocess include num limit and box filter; no nms
"""
def __init__(self):
self.preprocess = DetectionSequential([
DetectionFile2Image(), DetectionNormalize(
[0.485, 0.456, 0.406], [0.229, 0.224, 0.225], True),
DetectionResize(
(640, 640), False, interpolation=2), DetectionTranspose(
(2, 0, 1))
])
self.client = Client()
self.client.load_client_config(
"../../models/picodet_PPLCNet_x2_5_mainbody_lite_v1.0_client/serving_client_conf.prototxt"
)
self.client.connect(['127.0.0.1:9293'])
self.max_det_result = 5
self.conf_threshold = 0.2
def predict(self, imgpath):
im, im_info = self.preprocess(imgpath)
im_shape = np.array(im.shape[1:]).reshape(-1)
scale_factor = np.array(list(im_info['scale_factor'])).reshape(-1)
fetch_map = self.client.predict(
feed={
"image": im,
"im_shape": im_shape,
"scale_factor": scale_factor,
},
fetch=["save_infer_model/scale_0.tmp_1"],
batch=False)
return self.postprocess(fetch_map, imgpath)
def postprocess(self, fetch_map, imgpath):
#1. get top max_det_result
det_results = fetch_map["save_infer_model/scale_0.tmp_1"]
if len(det_results) > self.max_det_result:
boxes_reserved = fetch_map[
"save_infer_model/scale_0.tmp_1"][:self.max_det_result]
else:
boxes_reserved = det_results
#2. do conf threshold
boxes_list = []
for i in range(boxes_reserved.shape[0]):
if (boxes_reserved[i, 1]) > self.conf_threshold:
boxes_list.append(boxes_reserved[i, :])
#3. add origin image box
origin_img = cv2.imread(imgpath)
boxes_list.append(
np.array([0, 1.0, 0, 0, origin_img.shape[1], origin_img.shape[0]]))
return np.array(boxes_list)
class ObjectRecognition():
"""
pp-shitu object recognion for all objects detected by MainbodyDetect.
include preprocess, process, postprocess
preprocess include preprocess for each image and batching.
Batch process
postprocess include retrieval and nms
"""
def __init__(self):
self.client = Client()
self.client.load_client_config(
"../../models/general_PPLCNet_x2_5_lite_v1.0_client/serving_client_conf.prototxt"
)
self.client.connect(["127.0.0.1:9294"])
self.seq = Sequential([
BGR2RGB(), Resize((224, 224)), Div(255),
Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225],
False), Transpose((2, 0, 1))
])
self.searcher, self.id_map = self.init_index()
self.rec_nms_thresold = 0.05
self.rec_score_thres = 0.5
self.feature_normalize = True
self.return_k = 1
def init_index(self):
index_dir = "../../drink_dataset_v1.0/index"
assert os.path.exists(os.path.join(
index_dir, "vector.index")), "vector.index not found ..."
assert os.path.exists(os.path.join(
index_dir, "id_map.pkl")), "id_map.pkl not found ... "
searcher = faiss.read_index(os.path.join(index_dir, "vector.index"))
with open(os.path.join(index_dir, "id_map.pkl"), "rb") as fd:
id_map = pickle.load(fd)
return searcher, id_map
def predict(self, det_boxes, imgpath):
#1. preprocess
batch_imgs = []
origin_img = cv2.imread(imgpath)
for i in range(det_boxes.shape[0]):
box = det_boxes[i]
x1, y1, x2, y2 = [int(x) for x in box[2:]]
cropped_img = origin_img[y1:y2, x1:x2, :].copy()
tmp = self.seq(cropped_img)
batch_imgs.append(tmp)
batch_imgs = np.array(batch_imgs)
#2. process
fetch_map = self.client.predict(
feed={"x": batch_imgs}, fetch=["features"], batch=True)
batch_features = fetch_map["features"]
#3. postprocess
if self.feature_normalize:
feas_norm = np.sqrt(
np.sum(np.square(batch_features), axis=1, keepdims=True))
batch_features = np.divide(batch_features, feas_norm)
scores, docs = self.searcher.search(batch_features, self.return_k)
results = []
for i in range(scores.shape[0]):
pred = {}
if scores[i][0] >= self.rec_score_thres:
pred["bbox"] = [int(x) for x in det_boxes[i, 2:]]
pred["rec_docs"] = self.id_map[docs[i][0]].split()[1]
pred["rec_scores"] = scores[i][0]
results.append(pred)
return self.nms_to_rec_results(results)
def nms_to_rec_results(self, results):
filtered_results = []
x1 = np.array([r["bbox"][0] for r in results]).astype("float32")
y1 = np.array([r["bbox"][1] for r in results]).astype("float32")
x2 = np.array([r["bbox"][2] for r in results]).astype("float32")
y2 = np.array([r["bbox"][3] for r in results]).astype("float32")
scores = np.array([r["rec_scores"] for r in results])
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
order = scores.argsort()[::-1]
while order.size > 0:
i = order[0]
xx1 = np.maximum(x1[i], x1[order[1:]])
yy1 = np.maximum(y1[i], y1[order[1:]])
xx2 = np.minimum(x2[i], x2[order[1:]])
yy2 = np.minimum(y2[i], y2[order[1:]])
w = np.maximum(0.0, xx2 - xx1 + 1)
h = np.maximum(0.0, yy2 - yy1 + 1)
inter = w * h
ovr = inter / (areas[i] + areas[order[1:]] - inter)
inds = np.where(ovr <= self.rec_nms_thresold)[0]
order = order[inds + 1]
filtered_results.append(results[i])
return filtered_results
if __name__ == "__main__":
det = MainbodyDetect()
rec = ObjectRecognition()
#1. get det_results
imgpath = "../../drink_dataset_v1.0/test_images/001.jpeg"
det_results = det.predict(imgpath)
#2. get rec_results
rec_results = rec.predict(det_results, imgpath)
print(rec_results)
#run cls server:
nohup python3 -m paddle_serving_server.serve --model ResNet50_vd_serving --port 9292 &
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from paddle_serving_client import Client
#app
from paddle_serving_app.reader import Sequential, URL2Image, Resize
from paddle_serving_app.reader import CenterCrop, RGB2BGR, Transpose, Div, Normalize
import time
client = Client()
client.load_client_config("./ResNet50_vd_serving/serving_server_conf.prototxt")
client.connect(["127.0.0.1:9292"])
label_dict = {}
label_idx = 0
with open("imagenet.label") as fin:
for line in fin:
label_dict[label_idx] = line.strip()
label_idx += 1
#preprocess
seq = Sequential([
URL2Image(), Resize(256), CenterCrop(224), RGB2BGR(), Transpose((2, 0, 1)),
Div(255), Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], True)
])
start = time.time()
image_file = "https://paddle-serving.bj.bcebos.com/imagenet-example/daisy.jpg"
for i in range(1):
img = seq(image_file)
fetch_map = client.predict(
feed={"inputs": img}, fetch=["prediction"], batch=False)
prob = max(fetch_map["prediction"][0])
label = label_dict[fetch_map["prediction"][0].tolist().index(prob)].strip(
).replace(",", "")
print("prediction: {}, probability: {}".format(label, prob))
end = time.time()
print(end - start)
......@@ -6,9 +6,13 @@
- [3. 图像分类服务部署](#3)
- [3.1 模型转换](#3.1)
- [3.2 服务部署和请求](#3.2)
- [3.2.1 Python Serving](#3.2.1)
- [3.2.2 C++ Serving](#3.2.2)
- [4. 图像识别服务部署](#4)
- [4.1 模型转换](#4.1)
- [4.2 服务部署和请求](#4.2)
- [4.2.1 Python Serving](#4.2.1)
- [4.2.2 C++ Serving](#4.2.2)
- [5. FAQ](#5)
<a name="1"></a>
......@@ -90,7 +94,7 @@ ResNet50_vd 推理模型转换完成后,会在当前文件夹多出 `ResNet50_
|- serving_client_conf.prototxt
|- serving_client_conf.stream.prototxt
```
得到模型文件之后,需要修改 `ResNet50_vd_server` 下文件 `serving_server_conf.prototxt` 中的 alias 名字:将 `fetch_var` 中的 `alias_name` 改为 `prediction`
得到模型文件之后,需要分别修改 `ResNet50_vd_server``ResNet50_vd_client` 下文件 `serving_server_conf.prototxt` 中的 alias 名字:将 `fetch_var` 中的 `alias_name` 改为 `prediction`
**备注**: Serving 为了兼容不同模型的部署,提供了输入输出重命名的功能。这样,不同的模型在推理部署时,只需要修改配置文件的 alias_name 即可,无需修改代码即可完成推理部署。
修改后的 serving_server_conf.prototxt 如下所示:
......@@ -114,30 +118,51 @@ fetch_var {
```
<a name="3.2"></a>
### 3.2 服务部署和请求
paddleserving 目录包含了启动 pipeline 服务和发送预测请求的代码,包括:
paddleserving 目录包含了启动 pipeline 服务、C++ serving服务和发送预测请求的代码,包括:
```shell
__init__.py
config.yml # 启动服务的配置文件
config.yml # 启动pipeline服务的配置文件
pipeline_http_client.py # http方式发送pipeline预测请求的脚本
pipeline_rpc_client.py # rpc方式发送pipeline预测请求的脚本
classification_web_service.py # 启动pipeline服务端的脚本
run_cpp_serving.sh # 启动C++ Serving部署的脚本
test_cpp_serving_client.py # rpc方式发送C++ serving预测请求的脚本
```
<a name="3.2.1"></a>
#### 3.2.1 Python Serving
- 启动服务:
```shell
# 启动服务,运行日志保存在 log.txt
python3 classification_web_service.py &>log.txt &
```
成功启动服务后,log.txt 中会打印类似如下日志
![](../../../deploy/paddleserving/imgs/start_server.png)
- 发送请求:
```shell
# 发送服务请求
python3 pipeline_http_client.py
```
成功运行后,模型预测的结果会打印在 cmd 窗口中,结果示例为:
![](../../../deploy/paddleserving/imgs/results.png)
成功运行后,模型预测的结果会打印在 cmd 窗口中,结果如下:
```
{'err_no': 0, 'err_msg': '', 'key': ['label', 'prob'], 'value': ["['daisy']", '[0.9341402053833008]'], 'tensors': []}
```
<a name="3.2.2"></a>
#### 3.2.2 C++ Serving
- 启动服务:
```shell
# 启动服务, 服务在后台运行,运行日志保存在 nohup.txt
sh run_cpp_serving.sh
```
- 发送请求:
```shell
# 发送服务请求
python3 test_cpp_serving_client.py
```
成功运行后,模型预测的结果会打印在 cmd 窗口中,结果如下:
```
prediction: daisy, probability: 0.9341399073600769
```
<a name="4"></a>
## 4.图像识别服务部署
......@@ -164,7 +189,7 @@ python3 -m paddle_serving_client.convert --dirname ./general_PPLCNet_x2_5_lite_v
--serving_server ./general_PPLCNet_x2_5_lite_v1.0_serving/ \
--serving_client ./general_PPLCNet_x2_5_lite_v1.0_client/
```
识别推理模型转换完成后,会在当前文件夹多出 `general_PPLCNet_x2_5_lite_v1.0_serving/``general_PPLCNet_x2_5_lite_v1.0_client/` 的文件夹。修改 `general_PPLCNet_x2_5_lite_v1.0_serving/` 目录下的 serving_server_conf.prototxt 中的 alias 名字: 将 `fetch_var` 中的 `alias_name` 改为 `features`
识别推理模型转换完成后,会在当前文件夹多出 `general_PPLCNet_x2_5_lite_v1.0_serving/``general_PPLCNet_x2_5_lite_v1.0_serving/` 的文件夹。分别修改 `general_PPLCNet_x2_5_lite_v1.0_serving/``general_PPLCNet_x2_5_lite_v1.0_client/` 目录下的 serving_server_conf.prototxt 中的 alias 名字: 将 `fetch_var` 中的 `alias_name` 改为 `features`
修改后的 serving_server_conf.prototxt 内容如下:
```
feed_var {
......@@ -209,28 +234,52 @@ wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/drink_da
```shell
cd ./deploy/paddleserving/recognition
```
paddleserving 目录包含启动 pipeline 服务和发送预测请求的代码,包括:
paddleserving 目录包含启动 Python Pipeline 服务、C++ Serving 服务和发送预测请求的代码,包括:
```
__init__.py
config.yml # 启动服务的配置文件
config.yml # 启动python pipeline服务的配置文件
pipeline_http_client.py # http方式发送pipeline预测请求的脚本
pipeline_rpc_client.py # rpc方式发送pipeline预测请求的脚本
recognition_web_service.py # 启动pipeline服务端的脚本
run_cpp_serving.sh # 启动C++ Pipeline Serving部署的脚本
test_cpp_serving_client.py # rpc方式发送C++ Pipeline serving预测请求的脚本
```
<a name="4.2.1"></a>
#### 4.2.1 Python Serving
- 启动服务:
```
# 启动服务,运行日志保存在 log.txt
python3 recognition_web_service.py &>log.txt &
```
成功启动服务后,log.txt 中会打印类似如下日志
![](../../../deploy/paddleserving/imgs/start_server_shitu.png)
- 发送请求:
```
python3 pipeline_http_client.py
```
成功运行后,模型预测的结果会打印在 cmd 窗口中,结果示例为:
![](../../../deploy/paddleserving/imgs/results_shitu.png)
成功运行后,模型预测的结果会打印在 cmd 窗口中,结果如下:
```
{'err_no': 0, 'err_msg': '', 'key': ['result'], 'value': ["[{'bbox': [345, 95, 524, 576], 'rec_docs': '红牛-强化型', 'rec_scores': 0.79903316}]"], 'tensors': []}
```
<a name="4.2.2"></a>
#### 4.2.2 C++ Serving
- 启动服务:
```shell
# 启动服务: 此处会在后台同时启动主体检测和特征提取服务,端口号分别为9293和9294;
# 运行日志分别保存在 log_mainbody_detection.txt 和 log_feature_extraction.txt中
sh run_cpp_serving.sh
```
- 发送请求:
```shell
# 发送服务请求
python3 test_cpp_serving_client.py
```
成功运行后,模型预测的结果会打印在 cmd 窗口中,结果如下所示:
```
[{'bbox': [345, 95, 524, 586], 'rec_docs': '红牛-强化型', 'rec_scores': 0.8016462}]
```
<a name="5"></a>
## 5.FAQ
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册