提交 143ddd36 编写于 作者: H HydrogenSulfate

refine serving related docs & codes

上级 f1ef3a0c
......@@ -9,15 +9,15 @@
# 默认编译时的${PWD}=PaddleClas/deploy/paddleserving/
python_name=${1:-'python'}
export python_name=${1:-'python'}
apt-get update
apt install -y libcurl4-openssl-dev libbz2-dev
wget -nc https://paddle-serving.bj.bcebos.com/others/centos_ssl.tar
tar xf centos_ssl.tar
rm -rf centos_ssl.tar
mv libcrypto.so.1.0.2k /usr/lib/libcrypto.so.1.0.2k
mv libssl.so.1.0.2k /usr/lib/libssl.so.1.0.2k
\mv libcrypto.so.1.0.2k /usr/lib/libcrypto.so.1.0.2k
\mv libssl.so.1.0.2k /usr/lib/libssl.so.1.0.2k
ln -sf /usr/lib/libcrypto.so.1.0.2k /usr/lib/libcrypto.so.10
ln -sf /usr/lib/libssl.so.1.0.2k /usr/lib/libssl.so.10
ln -sf /usr/lib/libcrypto.so.10 /usr/lib/libcrypto.so
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import json
import os
import requests
imgpath = "../../drink_dataset_v2.0/test_images/001.jpeg"
image_path = "../../drink_dataset_v2.0/test_images/100.jpeg"
def cv2_to_base64(image):
return base64.b64encode(image).decode('utf8')
def bytes_to_base64(image_bytes: bytes) -> bytes:
"""encode bytes using base64 algorithm
Args:
image_bytes (bytes): bytes object to be encoded
Returns:
bytes: base64 bytes
"""
return base64.b64encode(image_bytes).decode('utf8')
if __name__ == "__main__":
url = "http://127.0.0.1:18081/recognition/prediction"
with open(os.path.join(".", imgpath), 'rb') as file:
image_data1 = file.read()
image = cv2_to_base64(image_data1)
data = {"key": ["image"], "value": [image]}
with open(os.path.join(".", image_path), 'rb') as file:
image_bytes = file.read()
image_base64 = bytes_to_base64(image_bytes)
data = {"key": ["image"], "value": [image_base64]}
for i in range(1):
r = requests.post(url=url, data=json.dumps(data))
......
......@@ -15,20 +15,33 @@ try:
from paddle_serving_server_gpu.pipeline import PipelineClient
except ImportError:
from paddle_serving_server.pipeline import PipelineClient
import base64
import os
client = PipelineClient()
client.connect(['127.0.0.1:9994'])
imgpath = "../../drink_dataset_v1.0/test_images/001.jpeg"
image_path = "../../drink_dataset_v2.0/test_images/100.jpeg"
def bytes_to_base64(image_bytes: bytes) -> bytes:
"""encode bytes using base64 algorithm
Args:
image_bytes (bytes): bytes to be encoded
Returns:
bytes: base64 bytes
"""
return base64.b64encode(image_bytes).decode('utf8')
def cv2_to_base64(image):
return base64.b64encode(image).decode('utf8')
if __name__ == "__main__":
with open(imgpath, 'rb') as file:
image_data = file.read()
image = cv2_to_base64(image_data)
with open(os.path.join(".", image_path), 'rb') as file:
image_bytes = file.read()
image_base64 = bytes_to_base64(image_bytes)
for i in range(1):
ret = client.predict(feed_dict={"image": image}, fetch=["result"])
ret = client.predict(
feed_dict={"image": image_base64}, fetch=["result"])
print(ret)
......@@ -15,7 +15,7 @@ feed_var {
shape: 6
}
fetch_var {
name: "save_infer_model/scale_0.tmp_1"
name: "batch_norm_25.tmp_2"
alias_name: "features"
is_lod_tensor: false
fetch_type: 1
......
......@@ -15,7 +15,7 @@ feed_var {
shape: 6
}
fetch_var {
name: "save_infer_model/scale_0.tmp_1"
name: "batch_norm_25.tmp_2"
alias_name: "features"
is_lod_tensor: false
fetch_type: 1
......
......@@ -112,7 +112,7 @@ class RecOp(Op):
Transpose((2, 0, 1))
])
index_dir = "../../drink_dataset_v1.0/index"
index_dir = "../../drink_dataset_v2.0/index"
assert os.path.exists(os.path.join(
index_dir, "vector.index")), "vector.index not found ..."
assert os.path.exists(os.path.join(
......
......@@ -3,12 +3,12 @@ gpu_id=$1
# PP-ShiTu CPP serving script
if [[ -n "${gpu_id}" ]]; then
nohup python3.7 -m paddle_serving_server.serve \
--model ../../models/picodet_PPLCNet_x2_5_mainbody_lite_v1.0_serving ../../models/general_PPLCNet_x2_5_lite_v1.0_serving \
--model ../../models/picodet_PPLCNet_x2_5_mainbody_lite_v1.0_serving ../../models/general_PPLCNetV2_base_pretrained_v1.0_serving \
--op GeneralPicodetOp GeneralFeatureExtractOp \
--port 9400 --gpu_id="${gpu_id}" > log_PPShiTu.txt 2>&1 &
else
nohup python3.7 -m paddle_serving_server.serve \
--model ../../models/picodet_PPLCNet_x2_5_mainbody_lite_v1.0_serving ../../models/general_PPLCNet_x2_5_lite_v1.0_serving \
--model ../../models/picodet_PPLCNet_x2_5_mainbody_lite_v1.0_serving ../../models/general_PPLCNetV2_base_pretrained_v1.0_serving \
--op GeneralPicodetOp GeneralFeatureExtractOp \
--port 9400 > log_PPShiTu.txt 2>&1 &
fi
......@@ -12,20 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import os
import pickle
from paddle_serving_client import Client
from paddle_serving_app.reader import *
import cv2
import faiss
import os
import pickle
import numpy as np
from paddle_serving_client import Client
rec_nms_thresold = 0.05
rec_score_thres = 0.5
feature_normalize = True
return_k = 1
index_dir = "../../drink_dataset_v1.0/index"
index_dir = "../../drink_dataset_v2.0/index"
def init_index(index_dir):
......@@ -41,7 +40,7 @@ def init_index(index_dir):
return searcher, id_map
#get box
# get box
def nms_to_rec_results(results, thresh=0.1):
filtered_results = []
......@@ -91,21 +90,21 @@ def postprocess(fetch_dict, feature_normalize, det_boxes, searcher, id_map,
pred["rec_scores"] = scores[i][0]
results.append(pred)
#do nms
# do NMS
results = nms_to_rec_results(results, rec_nms_thresold)
return results
#do client
# do client
if __name__ == "__main__":
client = Client()
client.load_client_config([
"../../models/picodet_PPLCNet_x2_5_mainbody_lite_v1.0_client",
"../../models/general_PPLCNet_x2_5_lite_v1.0_client"
"../../models/general_PPLCNetV2_base_pretrained_v1.0_client"
])
client.connect(['127.0.0.1:9400'])
im = cv2.imread("../../drink_dataset_v1.0/test_images/001.jpeg")
im = cv2.imread("../../drink_dataset_v2.0/test_images/100.jpeg")
im_shape = np.array(im.shape[:2]).reshape(-1)
fetch_map = client.predict(
feed={"image": im,
......@@ -113,7 +112,8 @@ if __name__ == "__main__":
fetch=["features", "boxes"],
batch=False)
#add retrieval procedure
# add retrieval procedure
print(fetch_map.keys())
det_boxes = fetch_map["boxes"]
searcher, id_map = init_index(index_dir)
results = postprocess(fetch_map, feature_normalize, det_boxes, searcher,
......
......@@ -7,6 +7,7 @@
- [模型训练](#模型训练)
- [模型评估](#模型评估)
- [模型推理](#模型推理)
- [模型部署](#模型部署)
- [模块介绍](#模块介绍)
- [主体检测模型](#主体检测模型)
- [特征提取模型](#特征提取模型)
......@@ -21,18 +22,18 @@
PP-ShiTuV2 是基于 PP-ShiTuV1 改进的一个实用轻量级通用图像识别系统,相比 PP-ShiTuV1 具有更高的识别精度、更强的泛化能力以及相近的推理速度<sup>*</sup>。该系统主要针对**训练数据集**、特征提取两个部分进行优化,使用了更优的骨干网络、损失函数与训练策略。使得 PP-ShiTuV2 在多个实际应用场景上的检索性能有显著提升。
<div align="center">
<img src="../../images/structure.png" />
<img src="../../images/structure.jpg" />
</div>
### 数据集介绍
我们将训练数据进行了合理扩充与优化,更多细节请参考 [PP-ShiTuV2 数据集](../image_recognition_pipeline/feature_extraction.md#42-pp-shituv2)
我们将训练数据进行了合理扩充与优化,更多细节请参考 [PP-ShiTuV2 数据集](../image_recognition_pipeline/feature_extraction.md#4-实验部分)
下面以 [PP-ShiTuV2](../image_recognition_pipeline/feature_extraction.md#42-pp-shituv2) 的数据集为例,介绍 PP-ShiTuV2 模型的训练、评估、推理流程。
下面以 [PP-ShiTuV2](../image_recognition_pipeline/feature_extraction.md#4-实验部分) 的数据集为例,介绍 PP-ShiTuV2 模型的训练、评估、推理流程。
### 模型训练
首先下载好 [PP-ShiTuV2 数据集](../image_recognition_pipeline/feature_extraction.md#42-pp-shituv2) 中的16个数据集并手动进行合并、生成标注文本文件 `train_reg_all_data_v2.txt`,最后放置到 `dataset` 目录下。
首先下载好 [PP-ShiTuV2 数据集](../image_recognition_pipeline/feature_extraction.md#4-实验部分) 中的16个数据集并手动进行合并、生成标注文本文件 `train_reg_all_data_v2.txt`,最后放置到 `dataset` 目录下。
合并后的文件夹结构如下所示:
......@@ -80,6 +81,10 @@ python3.7 -m paddle.distributed.launch tools/train.py \
参考 [模型推理](../image_recognition_pipeline/feature_extraction.md#54-模型推理)
### 模型部署
参考 [模型部署](../inference_deployment/recognition_serving_deploy.md#3-图像识别服务部署)
## 模块介绍
### 主体检测模型
......@@ -106,7 +111,7 @@ python3.7 -m paddle.distributed.launch tools/train.py \
2. `last stride=1`:只将最后一个 stage 的 stride 改为1,即不进行下采样,以此增加最后输出的特征图的语义信息,同时不对推理速度产生太大影响。
3. `BN Neck`:在全局池化层后加入一个 `BatchNorm1D` 结构,对特征向量的每个维度进行标准化,`CELoss``TripletAngularMarginLoss` 在不同的分布下进行优化,使得模型更快地收敛。
3. `BN Neck`:在全局池化层后加入一个 `BatchNorm1D` 结构,对特征向量的每个维度进行标准化,使得模型更快地收敛。
| 模型 | training data | recall@1%(mAP%) |
| :----------------------------------------------------------------- | :---------------- | :-------------- |
......
简体中文|[English](../../en/image_recognition_pipeline/feature_extraction_en.md)
简体中文 | [English](../../en/image_recognition_pipeline/feature_extraction_en.md)
# 特征提取
## 目录
......@@ -58,7 +58,7 @@ Head 部分选用 [FC Layer](../../../ppcls/arch/gears/fc.py),使用分类头
#### 3.4 Loss
Loss 部分选用 [Cross entropy loss](../../../ppcls/loss/celoss.py)[TripletAngularMarginLoss](../../../ppcls/loss/tripletangularmarginloss.py),在训练时以分类损失和基于角度的三元组损失来指导网络进行优化。详细的配置文件见[GeneralRecognitionV2_PPLCNetV2_base.yaml](../../../ppcls/configs/GeneralRecognitionV2/GeneralRecognitionV2_PPLCNetV2_base.yaml#L63-77)
Loss 部分选用 [Cross entropy loss](../../../ppcls/loss/celoss.py)[TripletAngularMarginLoss](../../../ppcls/loss/tripletangularmarginloss.py),在训练时以分类损失和基于角度的三元组损失来指导网络进行优化。详细的配置文件见 [GeneralRecognitionV2_PPLCNetV2_base.yaml](../../../ppcls/configs/GeneralRecognitionV2/GeneralRecognitionV2_PPLCNetV2_base.yaml#L63-77)
<a name="4"></a>
......@@ -137,8 +137,8 @@ Loss 部分选用 [Cross entropy loss](../../../ppcls/loss/celoss.py) 和 [Tripl
Train:
dataset:
name: ImageNetDataset
image_root: ./dataset/ # 此处表示train数据所在的目录
cls_label_path: ./dataset/train_reg_all_data.txt # 此处表示train数据集label文件的地址
image_root: ./dataset/ # 此处表示train数据所在的目录
cls_label_path: ./dataset/train_reg_all_data_v2.txt # 此处表示train数据集对应标注文件的地址
relabel: True
```
- 修改评估数据集中query数据配置:
......@@ -147,7 +147,7 @@ Loss 部分选用 [Cross entropy loss](../../../ppcls/loss/celoss.py) 和 [Tripl
dataset:
name: VeriWild
image_root: ./dataset/Aliproduct/ # 此处表示query数据集所在的目录
cls_label_path: ./dataset/Aliproduct/val_list.txt # 此处表示query数据集label文件的地址
cls_label_path: ./dataset/Aliproduct/val_list.txt # 此处表示query数据集对应标注文件的地址
```
- 修改评估数据集中gallery数据配置:
```yaml
......@@ -155,7 +155,7 @@ Loss 部分选用 [Cross entropy loss](../../../ppcls/loss/celoss.py) 和 [Tripl
dataset:
name: VeriWild
image_root: ./dataset/Aliproduct/ # 此处表示gallery数据集所在的目录
cls_label_path: ./dataset/Aliproduct/val_list.txt # 此处表示gallery数据集label文件的地址
cls_label_path: ./dataset/Aliproduct/val_list.txt # 此处表示gallery数据集对应标注文件的地址
```
<a name="5.2"></a>
......@@ -249,8 +249,16 @@ python3.7 python/predict_rec.py \
-c configs/inference_rec.yaml \
-o Global.rec_inference_model_dir="../inference"
```
得到的特征输出格式如下图所示:
![](../../images/feature_extraction_output.png)
得到的特征输出格式如下所示:
```log
wangzai.jpg: [-4.72979844e-02 3.40038240e-02 -4.06982675e-02 2.46225717e-03
9.67078656e-03 2.70162839e-02 -8.85589980e-03 -4.71983254e-02
6.18615076e-02 1.38509814e-02 -1.11342799e-02 6.73768669e-02
-1.03065455e-02 -4.88462113e-02 8.46387446e-03 1.57074817e-02
...
-3.14170569e-02 7.35917836e-02 -3.09373233e-02 -2.31755469e-02]
```
在实际使用过程中,仅仅得到特征可能并不能满足业务需求。如果想进一步通过特征检索来进行图像识别,可以参照文档[向量检索](./vector_search.md)
......
......@@ -14,6 +14,7 @@
- [4. FAQ](#4-faq)
<a name="1"></a>
## 1. 简介
[Paddle Serving](https://github.com/PaddlePaddle/Serving) 旨在帮助深度学习开发者轻松部署在线预测服务,支持一键部署工业级的服务能力、客户端和服务端之间高并发和高效通信、并支持多种编程语言开发客户端。
......@@ -21,6 +22,7 @@
该部分以 HTTP 预测服务部署为例,介绍怎样在 PaddleClas 中使用 PaddleServing 部署模型服务。目前只支持 Linux 平台部署,暂不支持 Windows 平台。
<a name="2"></a>
## 2. Serving 安装
Serving 官网推荐使用 docker 安装并部署 Serving 环境。首先需要拉取 docker 环境并创建基于 Serving 的 docker。
......@@ -59,12 +61,12 @@ python3.7 -m pip install paddle-serving-server-gpu==0.7.0.post112 # GPU with CUD
* 如果安装速度太慢,可以通过 `-i https://pypi.tuna.tsinghua.edu.cn/simple` 更换源,加速安装过程。
* 其他环境配置安装请参考:[使用Docker安装Paddle Serving](https://github.com/PaddlePaddle/Serving/blob/v0.7.0/doc/Install_CN.md)
<a name="3"></a>
## 3. 图像识别服务部署
使用 PaddleServing 做图像识别服务化部署时,**需要将保存的多个 inference 模型都转换为 Serving 模型**。 下面以 PP-ShiTu 中的超轻量图像识别模型为例,介绍图像识别服务的部署。
<a name="3.1"></a>
### 3.1 模型转换
......@@ -108,18 +110,7 @@ python3.7 -m pip install paddle-serving-server-gpu==0.7.0.post112 # GPU with CUD
├── serving_client_conf.prototxt
└── serving_client_conf.stream.prototxt
```
- 转换通用检测 inference 模型为 Serving 模型:
```shell
# 转换通用检测模型
python3.7 -m paddle_serving_client.convert --dirname ./picodet_PPLCNet_x2_5_mainbody_lite_v1.0_infer/ \
--model_filename inference.pdmodel \
--params_filename inference.pdiparams \
--serving_server ./picodet_PPLCNet_x2_5_mainbody_lite_v1.0_serving/ \
--serving_client ./picodet_PPLCNet_x2_5_mainbody_lite_v1.0_client/
```
上述命令的参数含义与[#3.1 模型转换](#3.1)相同
识别推理模型转换完成后,会在当前文件夹多出 `general_PPLCNetV2_base_pretrained_v1.0_serving/``general_PPLCNetV2_base_pretrained_v1.0_client/` 的文件夹。分别修改 `general_PPLCNetV2_base_pretrained_v1.0_serving/``general_PPLCNetV2_base_pretrained_v1.0_client/` 目录下的 `serving_server_conf.prototxt` 中的 `alias` 名字: 将 `fetch_var` 中的 `alias_name` 改为 `features`。 修改后的 `serving_server_conf.prototxt` 内容如下
接下来分别修改 `general_PPLCNetV2_base_pretrained_v1.0_serving/``general_PPLCNetV2_base_pretrained_v1.0_client/` 目录下的 `serving_server_conf.prototxt` 中的 `alias` 名字: 将 `fetch_var` 中的 `alias_name` 改为 `features`。修改后的 `serving_server_conf.prototxt` 内容如下
```log
feed_var {
......@@ -139,6 +130,17 @@ python3.7 -m pip install paddle-serving-server-gpu==0.7.0.post112 # GPU with CUD
shape: 512
}
```
- 转换通用检测 inference 模型为 Serving 模型:
```shell
# 转换通用检测模型
python3.7 -m paddle_serving_client.convert --dirname ./picodet_PPLCNet_x2_5_mainbody_lite_v1.0_infer/ \
--model_filename inference.pdmodel \
--params_filename inference.pdiparams \
--serving_server ./picodet_PPLCNet_x2_5_mainbody_lite_v1.0_serving/ \
--serving_client ./picodet_PPLCNet_x2_5_mainbody_lite_v1.0_client/
```
上述命令的参数含义与[#3.1 模型转换](#3.1)相同
通用检测 inference 模型转换完成后,会在当前文件夹多出 `picodet_PPLCNet_x2_5_mainbody_lite_v1.0_serving/``picodet_PPLCNet_x2_5_mainbody_lite_v1.0_client/` 的文件夹,具备如下结构:
```shell
├── picodet_PPLCNet_x2_5_mainbody_lite_v1.0_serving/
......@@ -151,14 +153,14 @@ python3.7 -m pip install paddle-serving-server-gpu==0.7.0.post112 # GPU with CUD
├── serving_client_conf.prototxt
└── serving_client_conf.stream.prototxt
```
上述命令中参数具体含义如下表所示
| 参数 | 类型 | 默认值 | 描述 |
| ----------------- | ---- | ------------------ | ------------------------------------------------------------ |
| `dirname` | str | - | 需要转换的模型文件存储路径,Program结构文件和参数文件均保存在此目录。 |
| `model_filename` | str | None | 存储需要转换的模型Inference Program结构的文件名称。如果设置为None,则使用 `__model__` 作为默认的文件名 |
| `params_filename` | str | None | 存储需要转换的模型所有参数的文件名称。当且仅当所有模型参数被保>存在一个单独的二进制文件中,它才需要被指定。如果模型参数是存储在各自分离的文件中,设置它的值为None |
| `serving_server` | str | `"serving_server"` | 转换后的模型文件和配置文件的存储路径。默认值为serving_server |
| `serving_client` | str | `"serving_client"` | 转换后的客户端配置文件存储路径。默认值为serving_client |
上述转换命令的参数具体含义如下表所示
| 参数 | 类型 | 默认值 | 描述 |
| ----------------- | ---- | ------------------ | ----------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `dirname` | str | - | 需要转换的模型文件存储路径,Program结构文件和参数文件均保存在此目录。 |
| `model_filename` | str | None | 存储需要转换的模型Inference Program结构的文件名称。如果设置为None,则使用 `__model__` 作为默认的文件名 |
| `params_filename` | str | None | 存储需要转换的模型所有参数的文件名称。当且仅当所有模型参数被保>存在一个单独的二进制文件中,它才需要被指定。如果模型参数是存储在各自分离的文件中,设置它的值为None |
| `serving_server` | str | `"serving_server"` | 转换后的模型文件和配置文件的存储路径。默认值为serving_server |
| `serving_client` | str | `"serving_client"` | 转换后的客户端配置文件存储路径。默认值为serving_client |
- 下载并解压已经构建后完成的检索库 index
```shell
......@@ -167,15 +169,17 @@ python3.7 -m pip install paddle-serving-server-gpu==0.7.0.post112 # GPU with CUD
# 下载构建完成的检索库 index
wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/drink_dataset_v2.0.tar
# 解压构建完成的检索库 index
tar -xf drink_dataset_v1.0.tar
tar -xf drink_dataset_v2.0.tar
```
<a name="3.2"></a>
### 3.2 服务部署和请求
**注意:** 识别服务涉及到多个模型,出于性能考虑采用 PipeLine 部署方式。Pipeline 部署方式当前不支持 windows 平台。
- 进入到工作目录
```shell
cd ./paddleserving/recognition
cd ./deploy/paddleserving/recognition
```
paddleserving 目录包含启动 Python Pipeline 服务、C++ Serving 服务和发送预测请求的代码,包括:
```shell
......@@ -190,6 +194,7 @@ python3.7 -m pip install paddle-serving-server-gpu==0.7.0.post112 # GPU with CUD
```
<a name="3.2.1"></a>
#### 3.2.1 Python Serving
- 启动服务:
......@@ -204,17 +209,19 @@ python3.7 -m pip install paddle-serving-server-gpu==0.7.0.post112 # GPU with CUD
```
成功运行后,模型预测的结果会打印在客户端中,如下所示:
```log
{'err_no': 0, 'err_msg': '', 'key': ['result'], 'value': ["[{'bbox': [0, 0, 600, 600], 'rec_docs': '红牛-强化型', 'rec_scores': 0.7408101}]"], 'tensors': []}
{'err_no': 0, 'err_msg': '', 'key': ['result'], 'value': ["[{'bbox': [438, 71, 660, 712], 'rec_docs': '元气森林', 'rec_scores': 0.7581642}, {'bbox': [220, 72, 449, 689], 'rec_docs': '元气森林', 'rec_scores': 0.68961805}, {'bbox': [794, 104, 978, 652], 'rec_docs': '元气森林', 'rec_scores': 0.63075215}]"], 'tensors': []}
```
<a name="3.2.2"></a>
#### 3.2.2 C++ Serving
与Python Serving不同,C++ Serving客户端调用 C++ OP来预测,因此在启动服务之前,需要编译并安装 serving server包,并设置 `SERVING_BIN`
- 编译并安装Serving server包
```shell
# 进入工作目录
cd PaddleClas/deploy/paddleserving
cd ./deploy/paddleserving
# 一键编译安装Serving server、设置 SERVING_BIN
source ./build_server.sh python3.7
```
......@@ -222,8 +229,8 @@ python3.7 -m pip install paddle-serving-server-gpu==0.7.0.post112 # GPU with CUD
- C++ Serving使用的输入输出格式与Python不同,因此需要执行以下命令,将4个文件复制到下的文件覆盖掉[3.1](#31-模型转换)得到文件夹中的对应4个prototxt文件。
```shell
# 进入PaddleClas/deploy目录
cd PaddleClas/deploy/
# 回到deploy目录
cd ../
# 覆盖prototxt文件
\cp ./paddleserving/recognition/preprocess/general_PPLCNetV2_base_pretrained_v1.0_serving/*.prototxt ./models/general_PPLCNetV2_base_pretrained_v1.0_serving/
......@@ -252,9 +259,9 @@ python3.7 -m pip install paddle-serving-server-gpu==0.7.0.post112 # GPU with CUD
成功运行后,模型预测的结果会打印在客户端中,如下所示:
```log
WARNING: Logging before InitGoogleLogging() is written to STDERR
I0614 03:01:36.273097 6084 naming_service_thread.cpp:202] brpc::policy::ListNamingService("127.0.0.1:9400"): added 1
I0614 03:01:37.393564 6084 general_model.cpp:490] [client]logid=0,client_cost=1107.82ms,server_cost=1101.75ms.
[{'bbox': [0, 0, 600, 600], 'rec_docs': '红牛-强化型', 'rec_scores': 0.7508101}]
I0903 16:03:20.020586 35600 naming_service_thread.cpp:202] brpc::policy::ListNamingService("127.0.0.1:9400"): added 1
I0903 16:03:21.346057 35600 general_model.cpp:490] [client]logid=0,client_cost=1306.26ms,server_cost=1293.65ms.
[{'bbox': [437, 71, 660, 727], 'rec_docs': '元气森林', 'rec_scores': 0.76902336}, {'bbox': [222, 72, 449, 700], 'rec_docs': '元气森林', 'rec_scores': 0.69347066}, {'bbox': [794, 104, 979, 652], 'rec_docs': '元气森林', 'rec_scores': 0.6305151}]
```
- 关闭服务
......@@ -265,6 +272,7 @@ python3.7 -m pip install paddle-serving-server-gpu==0.7.0.post112 # GPU with CUD
执行完毕后出现`Process stopped`信息表示成功关闭服务。
<a name="4"></a>
## 4. FAQ
**Q1**: 发送请求后没有结果返回或者提示输出解码报错
......@@ -276,6 +284,6 @@ unset http_proxy
```
**Q2**: 启动服务后没有任何反应
**A2**: 可以检查`config.yml``model_config`对应的路径是否存在,文件夹命名是否正确
**A2**: 可以检查 `config.yml``model_config` 对应的路径是否存在,文件夹命名是否正确
更多的服务部署类型,如 `RPC 预测服务` 等,可以参考 Serving 的[github 官网](https://github.com/PaddlePaddle/Serving/tree/v0.9.0/examples)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册