diff --git a/README.md b/README.md index 46b97be4236a9f2316c97b47396187fbce2cb22b..7c6df8d5ab4463c59c1ad250383f63ac1d01529e 100644 --- a/README.md +++ b/README.md @@ -55,7 +55,7 @@ pip install paddle-serving-server-gpu # GPU ``` You may need to use a domestic mirror source (in China, you can use the Tsinghua mirror source, add `-i https://pypi.tuna.tsinghua.edu.cn/simple` to pip command) to speed up the download. - + The client package supports CentOS 7 and Ubuntu 18; alternatively, you can use the HTTP service without installing the client.
- Why do we need to support distributed sparse parameter indexing in Paddle Serving? 1) In some recommendation scenarios, the number of features can reach hundreds of billions, so a single node cannot hold all the parameters in random access memory. 2) Paddle Serving supports distributed sparse parameter indexing that couples with Paddle inference, so users do not need to do extra work to get a low-latency inference engine with hundreds of billions of parameters. - -### 3.2 Model Management, online A/B test, Model Online Reloading -Paddle Serving's C++ engine supports model management, online A/B test and model online reloading. Currently, python API is not released yet, please wait for the next release. +### 3.2 Online A/B test + +After sufficient offline evaluation of the model, an online A/B test is usually needed to decide whether to launch the service at scale. The following figure shows the basic structure of an A/B test with Paddle Serving: once the client is properly configured, traffic is automatically distributed to different servers to achieve the A/B test. Please refer to [ABTEST in Paddle Serving](ABTEST_IN_PADDLE_SERVING.md) for specific examples. + +
+<p align="center">
+    <img src="doc/abtest.png" alt="A/B test structure">
+</p>
+
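As a minimal client-side sketch, assuming the `add_variant(tag, endpoint_list, weight)` API described in [ABTEST in Paddle Serving](ABTEST_IN_PADDLE_SERVING.md); the tags, endpoints, and weights here are illustrative:

```python
from paddle_serving_client import Client

client = Client()
client.load_client_config("serving_client_conf/serving_client_conf.prototxt")
# Assumed add_variant(tag, endpoint_list, weight) API from the A/B test doc:
# roughly 10% of requests are routed to the "bow" servers, 90% to "lstm".
client.add_variant("bow", ["127.0.0.1:8000"], 10)
client.add_variant("lstm", ["127.0.0.1:9000"], 90)
client.connect()
```

The variant tag that actually served a request can be retrieved by calling `predict` with `need_variant_tag=True`, as the client code later in this change shows.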
+### 3.3 Model Online Reloading
+
+In order to ensure service availability, models need to be hot loaded without interrupting the service. Paddle Serving supports this feature and provides a monitor tool that watches for newly produced models and updates the local model accordingly. Please refer to [Hot loading in Paddle Serving](HOT_LOADING_IN_SERVING.md) for specific examples.
+
+### 3.4 Model Management
+
+Paddle Serving's C++ engine supports model management. The corresponding Python API has not been released yet; please wait for the next release.
## 4. User Types
Paddle Serving provides RPC and HTTP protocols for users. We recommend the HTTP service for small and medium traffic services without strict latency requirements, and the RPC protocol for high-traffic services that require low latency. Users of the built-in distributed sparse parameter indexing service do not need to care about the underlying communication details. The following figure shows several scenarios in which users may want to use Paddle Serving.
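For the HTTP path, here is a minimal request sketch modeled on `python/examples/imagenet/image_http_client.py` from this change (the host, port, service name, and feed key are illustrative):

```python
import json

import requests

# The web service exposes its endpoint as http://<host>:<port>/<service name>/prediction
server = "http://127.0.0.1:9393/image/prediction"
req = json.dumps({"image": "<base64-encoded image>", "fetch": ["score"]})
r = requests.post(server, data=req, headers={"Content-Type": "application/json"})
# The web service wraps responses as {"result": fetch_map}
print(r.json()["result"]["score"])
```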
diff --git a/doc/DESIGN_DOC_CN.md b/doc/DESIGN_DOC_CN.md
index 2a63d56593dc47a5ca69f9c5c324710ee6dc3fc6..c068ac35bb6beebe70a6f873318c6d5059fc51e7 100644
--- a/doc/DESIGN_DOC_CN.md
+++ b/doc/DESIGN_DOC_CN.md
@@ -159,14 +159,30 @@ Paddle Serving的核心执行引擎是一个有向无环图,图中的每个节
- + 为什么要使用Paddle Serving提供的分布式稀疏参数索引服务?1)在一些推荐场景中,模型的输入特征规模通常可以达到上千亿,单台机器无法支撑T级别模型在内存的保存,因此需要进行分布式存储。2)Paddle Serving提供的分布式稀疏参数索引服务,具有并发请求多个节点的能力,从而以较低的延时完成预估服务。 -### 3.2 模型管理、在线A/B流量测试、模型热加载 +### 3.2 在线A/B流量测试 + +在对模型进行充分的离线评估后,通常需要进行在线A/B测试,来决定是否大规模上线服务。下图为使用Paddle Serving做A/B测试的基本结构,Client端做好相应的配置后,自动将流量分发给不同的Server,从而完成A/B测试。具体例子请参考[如何使用Paddle Serving做ABTEST](ABTEST_IN_PADDLE_SERVING_CN.md)。 + +
+<p align="center">
+    <img src="abtest.png" alt="A/B test结构">
+</p>
+ + +### 3.3 模型热加载 -Paddle Serving的C++引擎支持模型管理、在线A/B流量测试、模型热加载等功能,当前在Python API还有没完全开放这部分功能的配置,敬请期待。 +为了保证服务的可用性,需要在服务不中断的情况下对模型进行热加载。Paddle Serving对该特性进行了支持,并提供了一个监控产出模型、更新本地模型的工具,具体例子请参考[Paddle Serving中的模型热加载](HOT_LOADING_IN_SERVING_CN.md)。 + +### 3.4 模型管理 + +Paddle Serving的C++引擎支持模型管理功能,当前Python API还没有完全开放这部分功能的配置,敬请期待。 ## 4. 用户类型 + Paddle Serving为用户提供RPC和HTTP两种访问协议。对于HTTP协议,我们更倾向于由流量中小型、对延时没有严格要求的AI服务开发者使用。对于RPC协议,我们面向流量较大、对延时要求更高的用户,此外RPC的客户端可能也处在一个大系统的服务中,这种情况下非常适合使用Paddle Serving提供的RPC服务。对于使用分布式稀疏参数索引服务而言,Paddle Serving的用户不需要关心底层的细节,其调用本质也是通过RPC服务再调用RPC服务。下图给出了当前设计下用户可能会使用Paddle Serving服务的几种场景。
diff --git a/doc/INFERENCE_TO_SERVING.md b/doc/INFERENCE_TO_SERVING.md new file mode 100644 index 0000000000000000000000000000000000000000..8334159ea255ca65241a2b567e43682a148bb775 --- /dev/null +++ b/doc/INFERENCE_TO_SERVING.md @@ -0,0 +1,14 @@ +# How to Convert a Paddle Inference Model To Paddle Serving Format + +([简体中文](./INFERENCE_TO_SERVING_CN.md)|English) + +## Example + +``` python +from paddle_serving_client.io import inference_model_to_serving +inference_model_dir = "your_inference_model" +serving_client_dir = "serving_client_dir" +serving_server_dir = "serving_server_dir" +feed_var_names, fetch_var_names = inference_model_to_serving( + inference_model_dir, serving_client_dir, serving_server_dir) +``` diff --git a/doc/INFERENCE_TO_SERVING_CN.md b/doc/INFERENCE_TO_SERVING_CN.md new file mode 100644 index 0000000000000000000000000000000000000000..94d1def424db467e200020c69fbd6d1599a5ffde --- /dev/null +++ b/doc/INFERENCE_TO_SERVING_CN.md @@ -0,0 +1,14 @@ +# 如何从Paddle保存的预测模型转为Paddle Serving格式可部署的模型 + +([English](./INFERENCE_TO_SERVING.md)|简体中文) + +## 示例 + +``` python +from paddle_serving_client.io import inference_model_to_serving +inference_model_dir = "your_inference_model" +serving_client_dir = "serving_client_dir" +serving_server_dir = "serving_server_dir" +feed_var_names, fetch_var_names = inference_model_to_serving( + inference_model_dir, serving_client_dir, serving_server_dir) +``` diff --git a/doc/NEW_WEB_SERVICE.md b/doc/NEW_WEB_SERVICE.md new file mode 100644 index 0000000000000000000000000000000000000000..63f62a774d914c7271bfed1508881e04f74f2ca8 --- /dev/null +++ b/doc/NEW_WEB_SERVICE.md @@ -0,0 +1,64 @@ +# How to develop a new Web service? + +([简体中文](NEW_WEB_SERVICE_CN.md)|English) + +This document takes the image classification service based on the ImageNet dataset as an example to introduce how to develop a new web service. The complete code can be found [here](https://github.com/PaddlePaddle/Serving/blob/develop/python/examples/imagenet/image_classification_service.py). + +## WebService base class + +Paddle Serving implements the [WebService](https://github.com/PaddlePaddle/Serving/blob/develop/python/paddle_serving_server/web_service.py#L23) base class. You need to override its `preprocess` and `postprocess` methods. The default implementation is as follows: + +```python +class WebService(object): + + def preprocess(self, feed={}, fetch=[]): + return feed, fetch + def postprocess(self, feed={}, fetch=[], fetch_map=None): + return fetch_map +``` + +### preprocess + +The preprocess method has two input parameters, `feed` and `fetch`. For an HTTP request `request`: + +- The value of `feed` is the request data `request.json` +- The value of `fetch` is the fetch part `request.json["fetch"]` in the request data + +The return values are the feed and fetch values used in the prediction. + +### postprocess + +The postprocess method has three input parameters, `feed`, `fetch` and `fetch_map`: + +- The value of `feed` is the request data `request.json` +- The value of `fetch` is the fetch part `request.json["fetch"]` in the request data +- The value of `fetch_map` is the model output value. + +The return value will be wrapped as `{"result": fetch_map}` and used as the return of the HTTP request.
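Before the full example below, here is a minimal sketch of a `postprocess` override that attaches a top-1 label to the returned `fetch_map` (the fetch name `score` and the label list are hypothetical):

```python
import numpy as np
from paddle_serving_server.web_service import WebService


class LabelService(WebService):
    def postprocess(self, feed={}, fetch=[], fetch_map=None):
        # fetch_map holds the raw model outputs; keep added values as
        # numpy arrays, since the framework calls .tolist() on each entry
        labels = ["cat", "dog"]  # hypothetical label list
        top1 = int(np.argmax(fetch_map["score"]))
        fetch_map["label"] = np.array([labels[top1]])
        return fetch_map
```

The dictionary returned here is what `get_prediction` wraps into `{"result": fetch_map}` for the HTTP response.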
+ +## Develop ImageService class + ```python +class ImageService(WebService): + def preprocess(self, feed={}, fetch=[]): + reader = ImageReader() + if "image" not in feed: + raise ValueError("feed data error!") + if isinstance(feed["image"], list): + feed_batch = [] + for image in feed["image"]: + sample = base64.b64decode(image) + img = reader.process_image(sample) + res_feed = {} + res_feed["image"] = img.reshape(-1) + feed_batch.append(res_feed) + return feed_batch, fetch + else: + sample = base64.b64decode(feed["image"]) + img = reader.process_image(sample) + res_feed = {} + res_feed["image"] = img.reshape(-1) + return res_feed, fetch +``` + +For the above `ImageService`, only the `preprocess` method is overridden; it converts the Base64-encoded image data into the input format required by the model. diff --git a/doc/NEW_WEB_SERVICE_CN.md b/doc/NEW_WEB_SERVICE_CN.md new file mode 100644 index 0000000000000000000000000000000000000000..e1a21d8a0e91a114c9d94b09ef0afa9a0d29de89 --- /dev/null +++ b/doc/NEW_WEB_SERVICE_CN.md @@ -0,0 +1,64 @@ +# 如何开发一个新的Web Service? + +(简体中文|[English](NEW_WEB_SERVICE.md)) + +本文档将以Imagenet图像分类服务为例,来介绍如何开发一个新的Web Service。您可以在[这里](https://github.com/PaddlePaddle/Serving/blob/develop/python/examples/imagenet/image_classification_service.py)查阅完整的代码。 + +## WebService基类 + +Paddle Serving实现了[WebService](https://github.com/PaddlePaddle/Serving/blob/develop/python/paddle_serving_server/web_service.py#L23)基类,您需要重写它的`preprocess`方法和`postprocess`方法,默认实现如下: + +```python +class WebService(object): + + def preprocess(self, feed={}, fetch=[]): + return feed, fetch + def postprocess(self, feed={}, fetch=[], fetch_map=None): + return fetch_map +``` + +### preprocess方法 + +preprocess方法有两个输入参数,`feed`和`fetch`。对于一个HTTP请求`request`: + +- `feed`的值为请求数据`request.json` +- `fetch`的值为请求数据中的fetch部分`request.json["fetch"]` + +返回值分别是预测过程中用到的feed和fetch值。 + +### postprocess方法 + +postprocess方法有三个输入参数,`feed`、`fetch`和`fetch_map`: + +- `feed`的值为请求数据`request.json` +- `fetch`的值为请求数据中的fetch部分`request.json["fetch"]` +- `fetch_map`的值为fetch到的模型输出值 + +返回值将会被处理成`{"result": fetch_map}`作为HTTP请求的返回。 + +## 开发ImageService类 + +```python +class ImageService(WebService): + def preprocess(self, feed={}, fetch=[]): + reader = ImageReader() + if "image" not in feed: + raise ValueError("feed data error!") + if isinstance(feed["image"], list): + feed_batch = [] + for image in feed["image"]: + sample = base64.b64decode(image) + img = reader.process_image(sample) + res_feed = {} + res_feed["image"] = img.reshape(-1) + feed_batch.append(res_feed) + return feed_batch, fetch + else: + sample = base64.b64decode(feed["image"]) + img = reader.process_image(sample) + res_feed = {} + res_feed["image"] = img.reshape(-1) + return res_feed, fetch +``` + +对于上述的`ImageService`,只重写了前处理方法,将base64格式的图片数据处理成模型预测需要的数据格式。 diff --git a/doc/abtest.png b/doc/abtest.png index 3a33c4b30b96b32645d84291133cff0f0b79fcca..5e8f8980dffb46f4960390e6edb281968ae8bd83 100644 Binary files a/doc/abtest.png and b/doc/abtest.png differ diff --git a/python/examples/bert/bert_web_service.py b/python/examples/bert/bert_web_service.py index e22e379d67e076d4712c8971b6d342b4eaceadb2..f72694c0e8c5bb7ab2778278d3fc79f13516dc12 100644 --- a/python/examples/bert/bert_web_service.py +++ b/python/examples/bert/bert_web_service.py @@ -36,3 +36,4 @@ bert_service.set_gpus(gpu_ids) bert_service.prepare_server( workdir="workdir", port=int(sys.argv[2]), device="gpu") bert_service.run_server() +bert_service.run_flask() diff --git a/python/examples/faster_rcnn_model/000000570688.jpg 
b/python/examples/faster_rcnn_model/000000570688.jpg new file mode 100644 index 0000000000000000000000000000000000000000..cb304bd56c4010c08611a30dcca58ea9140cea54 Binary files /dev/null and b/python/examples/faster_rcnn_model/000000570688.jpg differ diff --git a/python/examples/faster_rcnn_model/000000570688_bbox.jpg b/python/examples/faster_rcnn_model/000000570688_bbox.jpg new file mode 100644 index 0000000000000000000000000000000000000000..61bc11c02c92b92cffac91a6c3533a90a45c4e14 Binary files /dev/null and b/python/examples/faster_rcnn_model/000000570688_bbox.jpg differ diff --git a/python/examples/faster_rcnn_model/README.md b/python/examples/faster_rcnn_model/README.md new file mode 100644 index 0000000000000000000000000000000000000000..66f65b5ad77186dd3dd08acaddc85356277fe6fd --- /dev/null +++ b/python/examples/faster_rcnn_model/README.md @@ -0,0 +1,38 @@ +# Faster RCNN model on Paddle Serving + +([简体中文](./README_CN.md)|English) + +### Get The Faster RCNN Model +``` +wget https://paddle-serving.bj.bcebos.com/pddet_demo/faster_rcnn_model.tar.gz +wget https://paddle-serving.bj.bcebos.com/pddet_demo/infer_cfg.yml +``` +If you want more detection models, please refer to the [Paddle Detection Model Zoo](https://github.com/PaddlePaddle/PaddleDetection/blob/release/0.2/docs/MODEL_ZOO_cn.md) + +### Start the service +``` +tar xf faster_rcnn_model.tar.gz +mv faster_rcnn_model/pddet* . +GLOG_v=2 python -m paddle_serving_server_gpu.serve --model pddet_serving_model --port 9494 --gpu_id 0 +``` + +### Perform prediction +``` +python test_client.py pddet_client_conf/serving_client_conf.prototxt infer_cfg.yml 000000570688.jpg +``` + +### Result analysis +
+<p align="center">
+    <img src="./000000570688.jpg" alt="input picture">
+</p>
+
+This is the input picture.
+
+<p align="center">
+    <img src="./000000570688_bbox.jpg" alt="picture with bounding boxes">
+</p>
+ +This is the picture after the bounding boxes have been drawn. You can see that the client has done the post-processing for the picture. In addition, output/bbox.json contains the index and coordinate information of each box. diff --git a/python/examples/faster_rcnn_model/README_CN.md b/python/examples/faster_rcnn_model/README_CN.md new file mode 100644 index 0000000000000000000000000000000000000000..7aa4d343f05df92068d36499b48d9aa5ad7b2a2e --- /dev/null +++ b/python/examples/faster_rcnn_model/README_CN.md @@ -0,0 +1,37 @@ +# 使用Paddle Serving部署Faster RCNN模型 + +(简体中文|[English](./README.md)) + +### 获得Faster RCNN模型 +``` +wget https://paddle-serving.bj.bcebos.com/pddet_demo/faster_rcnn_model.tar.gz +wget https://paddle-serving.bj.bcebos.com/pddet_demo/infer_cfg.yml +``` +如果你想要更多的检测模型,请参考[Paddle检测模型库](https://github.com/PaddlePaddle/PaddleDetection/blob/release/0.2/docs/MODEL_ZOO_cn.md) + +### 启动服务 +``` +tar xf faster_rcnn_model.tar.gz +mv faster_rcnn_model/pddet* . +GLOG_v=2 python -m paddle_serving_server_gpu.serve --model pddet_serving_model --port 9494 --gpu_id 0 +``` + +### 执行预测 +``` +python test_client.py pddet_client_conf/serving_client_conf.prototxt infer_cfg.yml 000000570688.jpg +``` + +### 结果分析 +
+<p align="center">
+    <img src="./000000570688.jpg" alt="输入图片">
+</p>
+
+这是输入图片。
+
+<p align="center">
+    <img src="./000000570688_bbox.jpg" alt="添加bbox之后的图片">
+</p>
+这是添加了bbox之后的图片,可以看到客户端已经为图片做好了后处理,此外在output/bbox.json也有各个框的编号和坐标信息。 diff --git a/python/examples/faster_rcnn_model/test_client.py b/python/examples/faster_rcnn_model/test_client.py new file mode 100755 index 0000000000000000000000000000000000000000..ae2e5b8f6e961d965555d8f268f38be14c0263d0 --- /dev/null +++ b/python/examples/faster_rcnn_model/test_client.py @@ -0,0 +1,33 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle_serving_client import Client +import sys +import os +import time +from paddle_serving_app.reader.pddet import Detection +import numpy as np + +py_version = sys.version_info[0] + +feed_var_names = ['image', 'im_shape', 'im_info'] +fetch_var_names = ['multiclass_nms'] +pddet = Detection(config_path=sys.argv[2], output_dir="./output") +feed_dict = pddet.preprocess(feed_var_names, sys.argv[3]) +client = Client() +client.load_client_config(sys.argv[1]) +client.connect(['127.0.0.1:9494']) +fetch_map = client.predict(feed=feed_dict, fetch=fetch_var_names) +outs = fetch_map.values() +pddet.postprocess(fetch_map, fetch_var_names) diff --git a/python/examples/imagenet/image_classification_service.py b/python/examples/imagenet/image_classification_service.py index 2776eb1bc7126fab32dbb05774fb0060506b61af..ee3ae6dd1c64bda154bbadabe8d1e91da734fb5a 100644 --- a/python/examples/imagenet/image_classification_service.py +++ b/python/examples/imagenet/image_classification_service.py @@ -31,14 +31,14 @@ class ImageService(WebService): sample = base64.b64decode(image) img = reader.process_image(sample) res_feed = {} - res_feed["image"] = img.reshape(-1) + res_feed["image"] = img feed_batch.append(res_feed) return feed_batch, fetch else: sample = base64.b64decode(feed["image"]) img = reader.process_image(sample) res_feed = {} - res_feed["image"] = img.reshape(-1) + res_feed["image"] = img return res_feed, fetch @@ -47,3 +47,4 @@ image_service.load_model_config(sys.argv[1]) image_service.prepare_server( workdir=sys.argv[2], port=int(sys.argv[3]), device="cpu") image_service.run_server() +image_service.run_flask() diff --git a/python/examples/imagenet/image_classification_service_gpu.py b/python/examples/imagenet/image_classification_service_gpu.py index 287392e4f3ea922686cb03a032ba0b8e13d39709..d8ba4ed8cda9f600fb6d33441b90accdf5ecc532 100644 --- a/python/examples/imagenet/image_classification_service_gpu.py +++ b/python/examples/imagenet/image_classification_service_gpu.py @@ -12,12 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from paddle_serving_server_gpu.web_service import WebService import sys import cv2 import base64 import numpy as np from image_reader import ImageReader +from paddle_serving_server_gpu.web_service import WebService class ImageService(WebService): @@ -32,14 +32,14 @@ class ImageService(WebService): sample = base64.b64decode(image) img = reader.process_image(sample) res_feed = {} - res_feed["image"] = img.reshape(-1) + res_feed["image"] = img feed_batch.append(res_feed) return feed_batch, fetch else: sample = base64.b64decode(feed["image"]) img = reader.process_image(sample) res_feed = {} - res_feed["image"] = img.reshape(-1) + res_feed["image"] = img return res_feed, fetch @@ -49,3 +49,4 @@ image_service.set_gpus("0,1") image_service.prepare_server( workdir=sys.argv[2], port=int(sys.argv[3]), device="gpu") image_service.run_server() +image_service.run_flask() diff --git a/python/examples/imagenet/image_http_client.py b/python/examples/imagenet/image_http_client.py index cda0f33ac82d0bd228a22a8f438cbe1aa013eadf..d920eccb06cc9ad4a87237792a1e688fd76b0d6e 100644 --- a/python/examples/imagenet/image_http_client.py +++ b/python/examples/imagenet/image_http_client.py @@ -31,7 +31,7 @@ def predict(image_path, server): r = requests.post( server, data=req, headers={"Content-Type": "application/json"}) try: - print(r.json()["score"][0]) + print(r.json()["result"]["score"]) except ValueError: print(r.text) return r diff --git a/python/examples/imagenet/image_rpc_client.py b/python/examples/imagenet/image_rpc_client.py index 76f3a043474bf75e1e96a44f18ac7dfe3da11f78..f905179629f0dfc8c9da09b0cae90bae7be3687e 100644 --- a/python/examples/imagenet/image_rpc_client.py +++ b/python/examples/imagenet/image_rpc_client.py @@ -26,7 +26,7 @@ start = time.time() for i in range(1000): with open("./data/n01440764_10026.JPEG", "rb") as f: img = f.read() - img = reader.process_image(img).reshape(-1) + img = reader.process_image(img) fetch_map = client.predict(feed={"image": img}, fetch=["score"]) end = time.time() print(end - start) diff --git a/python/examples/imdb/text_classify_service.py b/python/examples/imdb/text_classify_service.py index 50d0d1aebba34a630c16442c6e3d00460bb1bc6a..5ff919ebb44b9a2590b148e4ccf8b91ce85f3f53 100755 --- a/python/examples/imdb/text_classify_service.py +++ b/python/examples/imdb/text_classify_service.py @@ -39,3 +39,4 @@ imdb_service.prepare_server( workdir=sys.argv[2], port=int(sys.argv[3]), device="cpu") imdb_service.prepare_dict({"dict_file_path": sys.argv[4]}) imdb_service.run_server() +imdb_service.run_flask() diff --git a/python/paddle_serving_app/reader/pddet/__init__.py b/python/paddle_serving_app/reader/pddet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c4e356387b5855cb60606f7bdebe7d8c6d091814 --- /dev/null +++ b/python/paddle_serving_app/reader/pddet/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import time +import argparse +from .image_tool import Resize, Detection diff --git a/python/paddle_serving_app/reader/pddet/image_tool.py b/python/paddle_serving_app/reader/pddet/image_tool.py new file mode 100644 index 0000000000000000000000000000000000000000..4b461bc491a90d25a3259cf6db806beae6dbf593 --- /dev/null +++ b/python/paddle_serving_app/reader/pddet/image_tool.py @@ -0,0 +1,620 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import time + +import numpy as np +from PIL import Image, ImageDraw +import cv2 +import yaml +import copy +import argparse +import logging +import paddle.fluid as fluid +import json + +FORMAT = '%(asctime)s-%(levelname)s: %(message)s' +logging.basicConfig(level=logging.INFO, format=FORMAT) +logger = logging.getLogger(__name__) + +precision_map = { + 'trt_int8': fluid.core.AnalysisConfig.Precision.Int8, + 'trt_fp32': fluid.core.AnalysisConfig.Precision.Float32, + 'trt_fp16': fluid.core.AnalysisConfig.Precision.Half +} + + +class Resize(object): + def __init__(self, + target_size, + max_size=0, + interp=cv2.INTER_LINEAR, + use_cv2=True, + image_shape=None): + super(Resize, self).__init__() + self.target_size = target_size + self.max_size = max_size + self.interp = interp + self.use_cv2 = use_cv2 + self.image_shape = image_shape + + def __call__(self, im): + origin_shape = im.shape[:2] + im_c = im.shape[2] + if self.max_size != 0: + im_size_min = np.min(origin_shape[0:2]) + im_size_max = np.max(origin_shape[0:2]) + im_scale = float(self.target_size) / float(im_size_min) + if np.round(im_scale * im_size_max) > self.max_size: + im_scale = float(self.max_size) / float(im_size_max) + im_scale_x = im_scale + im_scale_y = im_scale + resize_w = int(im_scale_x * float(origin_shape[1])) + resize_h = int(im_scale_y * float(origin_shape[0])) + else: + im_scale_x = float(self.target_size) / float(origin_shape[1]) + im_scale_y = float(self.target_size) / float(origin_shape[0]) + resize_w = self.target_size + resize_h = self.target_size + if self.use_cv2: + im = cv2.resize( + im, + None, + None, + fx=im_scale_x, + fy=im_scale_y, + interpolation=self.interp) + else: + if self.max_size != 0: + raise TypeError( + 'If you set max_size to cap the maximum size of image,' + 'please set use_cv2 to True to resize the image.') + im = im.astype('uint8') + im = Image.fromarray(im) + im = im.resize((int(resize_w), int(resize_h)), self.interp) + im = np.array(im) + # padding im + if self.max_size != 0 and self.image_shape is not None: + padding_im = np.zeros( + (self.max_size, self.max_size, im_c), dtype=np.float32) + im_h, im_w = im.shape[:2] + padding_im[:im_h, :im_w, :] = im + im = padding_im + return im, im_scale_x + + +class Normalize(object): + def __init__(self, mean, std, is_scale=True, is_channel_first=False): + super(Normalize, self).__init__() + self.mean = mean + self.std = std + self.is_scale = is_scale + self.is_channel_first = is_channel_first + + def __call__(self, im): + im = 
im.astype(np.float32, copy=False) + if self.is_channel_first: + mean = np.array(self.mean)[:, np.newaxis, np.newaxis] + std = np.array(self.std)[:, np.newaxis, np.newaxis] + else: + mean = np.array(self.mean)[np.newaxis, np.newaxis, :] + std = np.array(self.std)[np.newaxis, np.newaxis, :] + if self.is_scale: + im = im / 255.0 + im -= mean + im /= std + return im + + +class Permute(object): + def __init__(self, to_bgr=False, channel_first=True): + self.to_bgr = to_bgr + self.channel_first = channel_first + + def __call__(self, im): + if self.channel_first: + im = im.transpose((2, 0, 1)) + if self.to_bgr: + im = im[[2, 1, 0], :, :] + return im.copy() + + +class PadStride(object): + def __init__(self, stride=0): + assert stride >= 0, "Unsupported stride: {}," + " the stride in PadStride must be greater " + "or equal to 0".format(stride) + self.coarsest_stride = stride + + def __call__(self, im): + coarsest_stride = self.coarsest_stride + if coarsest_stride == 0: + return im + im_c, im_h, im_w = im.shape + pad_h = int(np.ceil(float(im_h) / coarsest_stride) * coarsest_stride) + pad_w = int(np.ceil(float(im_w) / coarsest_stride) * coarsest_stride) + padding_im = np.zeros((im_c, pad_h, pad_w), dtype=np.float32) + padding_im[:, :im_h, :im_w] = im + return padding_im + + +class Detection(): + def __init__(self, config_path, output_dir): + self.config_path = config_path + self.if_visualize = True + self.if_dump_result = True + self.output_dir = output_dir + + def DecodeImage(self, im_path): + assert os.path.exists(im_path), "Image path {} can not be found".format( + im_path) + with open(im_path, 'rb') as f: + im = f.read() + data = np.frombuffer(im, dtype='uint8') + im = cv2.imdecode(data, 1) # BGR mode, but need RGB mode + im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) + return im + + def Preprocess(self, img_path, arch, config): + img = self.DecodeImage(img_path) + orig_shape = img.shape + scale = 1. + data = [] + data_config = copy.deepcopy(config) + for data_aug_conf in data_config: + obj = data_aug_conf.pop('type') + preprocess = eval(obj)(**data_aug_conf) + if obj == 'Resize': + img, scale = preprocess(img) + else: + img = preprocess(img) + + img = img[np.newaxis, :] # N, C, H, W + data.append(img) + extra_info = self.get_extra_info(img, arch, orig_shape, scale) + data += extra_info + return data + + def expand_boxes(self, boxes, scale): + """ + Expand an array of boxes by a given scale. 
+ """ + w_half = (boxes[:, 2] - boxes[:, 0]) * .5 + h_half = (boxes[:, 3] - boxes[:, 1]) * .5 + x_c = (boxes[:, 2] + boxes[:, 0]) * .5 + y_c = (boxes[:, 3] + boxes[:, 1]) * .5 + + w_half *= scale + h_half *= scale + + boxes_exp = np.zeros(boxes.shape) + boxes_exp[:, 0] = x_c - w_half + boxes_exp[:, 2] = x_c + w_half + boxes_exp[:, 1] = y_c - h_half + boxes_exp[:, 3] = y_c + h_half + + return boxes_exp + + def mask2out(self, results, clsid2catid, resolution, thresh_binarize=0.5): + import pycocotools.mask as mask_util + scale = (resolution + 2.0) / resolution + + segm_res = [] + + for t in results: + bboxes = t['bbox'][0] + lengths = t['bbox'][1][0] + if bboxes.shape == (1, 1) or bboxes is None: + continue + if len(bboxes.tolist()) == 0: + continue + masks = t['mask'][0] + + s = 0 + # for each sample + for i in range(len(lengths)): + num = lengths[i] + im_shape = t['im_shape'][i] + + bbox = bboxes[s:s + num][:, 2:] + clsid_scores = bboxes[s:s + num][:, 0:2] + mask = masks[s:s + num] + s += num + + im_h = int(im_shape[0]) + im_w = int(im_shape[1]) + + expand_bbox = expand_boxes(bbox, scale) + expand_bbox = expand_bbox.astype(np.int32) + + padded_mask = np.zeros( + (resolution + 2, resolution + 2), dtype=np.float32) + + for j in range(num): + xmin, ymin, xmax, ymax = expand_bbox[j].tolist() + clsid, score = clsid_scores[j].tolist() + clsid = int(clsid) + padded_mask[1:-1, 1:-1] = mask[j, clsid, :, :] + + catid = clsid2catid[clsid] + + w = xmax - xmin + 1 + h = ymax - ymin + 1 + w = np.maximum(w, 1) + h = np.maximum(h, 1) + + resized_mask = cv2.resize(padded_mask, (w, h)) + resized_mask = np.array( + resized_mask > thresh_binarize, dtype=np.uint8) + im_mask = np.zeros((im_h, im_w), dtype=np.uint8) + + x0 = min(max(xmin, 0), im_w) + x1 = min(max(xmax + 1, 0), im_w) + y0 = min(max(ymin, 0), im_h) + y1 = min(max(ymax + 1, 0), im_h) + + im_mask[y0:y1, x0:x1] = resized_mask[(y0 - ymin):( + y1 - ymin), (x0 - xmin):(x1 - xmin)] + segm = mask_util.encode( + np.array( + im_mask[:, :, np.newaxis], order='F'))[0] + catid = clsid2catid[clsid] + segm['counts'] = segm['counts'].decode('utf8') + coco_res = { + 'category_id': catid, + 'segmentation': segm, + 'score': score + } + segm_res.append(coco_res) + return segm_res + + def draw_bbox(self, image, catid2name, bboxes, threshold, color_list): + """ + draw bbox on image + """ + draw = ImageDraw.Draw(image) + + for dt in np.array(bboxes): + catid, bbox, score = dt['category_id'], dt['bbox'], dt['score'] + if score < threshold: + continue + + xmin, ymin, w, h = bbox + xmax = xmin + w + ymax = ymin + h + + color = tuple(color_list[catid]) + + # draw bbox + draw.line( + [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin), + (xmin, ymin)], + width=2, + fill=color) + + # draw label + text = "{} {:.2f}".format(catid2name[catid], score) + tw, th = draw.textsize(text) + draw.rectangle( + [(xmin + 1, ymin - th), (xmin + tw + 1, ymin)], fill=color) + draw.text((xmin + 1, ymin - th), text, fill=(255, 255, 255)) + + return image + + def draw_mask(self, image, masks, threshold, color_list, alpha=0.7): + """ + Draw mask on image + """ + mask_color_id = 0 + w_ratio = .4 + img_array = np.array(image).astype('float32') + for dt in np.array(masks): + segm, score = dt['segmentation'], dt['score'] + if score < threshold: + continue + import pycocotools.mask as mask_util + mask = mask_util.decode(segm) * 255 + color_mask = color_list[mask_color_id % len(color_list), 0:3] + mask_color_id += 1 + for c in range(3): + color_mask[c] = color_mask[c] * (1 - w_ratio) + w_ratio * 255 
+ idx = np.nonzero(mask) + img_array[idx[0], idx[1], :] *= 1.0 - alpha + img_array[idx[0], idx[1], :] += alpha * color_mask + return Image.fromarray(img_array.astype('uint8')) + + def get_extra_info(self, im, arch, shape, scale): + info = [] + input_shape = [] + im_shape = [] + logger.info('The architecture is {}'.format(arch)) + if 'YOLO' in arch: + im_size = np.array([shape[:2]]).astype('int32') + logger.info('Extra info: im_size') + info.append(im_size) + elif 'SSD' in arch: + im_shape = np.array([shape[:2]]).astype('int32') + logger.info('Extra info: im_shape') + info.append([im_shape]) + elif 'RetinaNet' in arch: + input_shape.extend(im.shape[2:]) + im_info = np.array([input_shape + [scale]]).astype('float32') + logger.info('Extra info: im_info') + info.append(im_info) + elif 'RCNN' in arch: + input_shape.extend(im.shape[2:]) + im_shape.extend(shape[:2]) + im_info = np.array([input_shape + [scale]]).astype('float32') + im_shape = np.array([im_shape + [1.]]).astype('float32') + logger.info('Extra info: im_info, im_shape') + info.append(im_info) + info.append(im_shape) + else: + logger.error( + "Unsupported arch: {}, expect YOLO, SSD, RetinaNet and RCNN". + format(arch)) + return info + + def offset_to_lengths(self, lod): + offset = lod[0] + lengths = [offset[i + 1] - offset[i] for i in range(len(offset) - 1)] + return [lengths] + + def bbox2out(self, results, clsid2catid, is_bbox_normalized=False): + """ + Args: + results: request a dict, should include: `bbox`, `im_id`, + if is_bbox_normalized=True, also need `im_shape`. + clsid2catid: class id to category id map of COCO2017 dataset. + is_bbox_normalized: whether or not bbox is normalized. + """ + xywh_res = [] + for t in results: + bboxes = t['bbox'][0] + lengths = t['bbox'][1][0] + if bboxes.shape == (1, 1) or bboxes is None: + continue + + k = 0 + for i in range(len(lengths)): + num = lengths[i] + for j in range(num): + dt = bboxes[k] + clsid, score, xmin, ymin, xmax, ymax = dt.tolist() + catid = (clsid2catid[int(clsid)]) + + if is_bbox_normalized: + xmin, ymin, xmax, ymax = \ + self.clip_bbox([xmin, ymin, xmax, ymax]) + w = xmax - xmin + h = ymax - ymin + im_shape = t['im_shape'][0][i].tolist() + im_height, im_width = int(im_shape[0]), int(im_shape[1]) + xmin *= im_width + ymin *= im_height + w *= im_width + h *= im_height + else: + w = xmax - xmin + 1 + h = ymax - ymin + 1 + + bbox = [xmin, ymin, w, h] + coco_res = { + 'category_id': catid, + 'bbox': bbox, + 'score': score + } + xywh_res.append(coco_res) + k += 1 + return xywh_res + + def get_bbox_result(self, fetch_map, fetch_name, result, conf, clsid2catid): + is_bbox_normalized = True if 'SSD' in conf['arch'] else False + output = fetch_map[fetch_name] + lod = [fetch_map[fetch_name + '.lod']] + lengths = self.offset_to_lengths(lod) + np_data = np.array(output) + result['bbox'] = (np_data, lengths) + result['im_id'] = np.array([[0]]) + + bbox_results = self.bbox2out([result], clsid2catid, is_bbox_normalized) + return bbox_results + + def mask2out(self, results, clsid2catid, resolution, thresh_binarize=0.5): + import pycocotools.mask as mask_util + scale = (resolution + 2.0) / resolution + + segm_res = [] + + for t in results: + bboxes = t['bbox'][0] + lengths = t['bbox'][1][0] + if bboxes.shape == (1, 1) or bboxes is None: + continue + if len(bboxes.tolist()) == 0: + continue + masks = t['mask'][0] + + s = 0 + # for each sample + for i in range(len(lengths)): + num = lengths[i] + im_shape = t['im_shape'][i] + + bbox = bboxes[s:s + num][:, 2:] + clsid_scores = bboxes[s:s + 
num][:, 0:2] + mask = masks[s:s + num] + s += num + + im_h = int(im_shape[0]) + im_w = int(im_shape[1]) + + expand_bbox = self.expand_boxes(bbox, scale) + expand_bbox = expand_bbox.astype(np.int32) + + padded_mask = np.zeros( + (resolution + 2, resolution + 2), dtype=np.float32) + + for j in range(num): + xmin, ymin, xmax, ymax = expand_bbox[j].tolist() + clsid, score = clsid_scores[j].tolist() + clsid = int(clsid) + padded_mask[1:-1, 1:-1] = mask[j, clsid, :, :] + + catid = clsid2catid[clsid] + + w = xmax - xmin + 1 + h = ymax - ymin + 1 + w = np.maximum(w, 1) + h = np.maximum(h, 1) + + resized_mask = cv2.resize(padded_mask, (w, h)) + resized_mask = np.array( + resized_mask > thresh_binarize, dtype=np.uint8) + im_mask = np.zeros((im_h, im_w), dtype=np.uint8) + + x0 = min(max(xmin, 0), im_w) + x1 = min(max(xmax + 1, 0), im_w) + y0 = min(max(ymin, 0), im_h) + y1 = min(max(ymax + 1, 0), im_h) + + im_mask[y0:y1, x0:x1] = resized_mask[(y0 - ymin):( + y1 - ymin), (x0 - xmin):(x1 - xmin)] + segm = mask_util.encode( + np.array( + im_mask[:, :, np.newaxis], order='F'))[0] + catid = clsid2catid[clsid] + segm['counts'] = segm['counts'].decode('utf8') + coco_res = { + 'category_id': catid, + 'segmentation': segm, + 'score': score + } + segm_res.append(coco_res) + return segm_res + + def get_mask_result(self, fetch_map, fetch_var_names, result, conf, + clsid2catid): + resolution = conf['mask_resolution'] + bbox_out, mask_out = fetch_map[fetch_var_names] + lengths = self.offset_to_lengths(bbox_out.lod()) + bbox = np.array(bbox_out) + mask = np.array(mask_out) + result['bbox'] = (bbox, lengths) + result['mask'] = (mask, lengths) + mask_results = self.mask2out([result], clsid2catid, + conf['mask_resolution']) + return mask_results + + def get_category_info(self, with_background, label_list): + if label_list[0] != 'background' and with_background: + label_list.insert(0, 'background') + if label_list[0] == 'background' and not with_background: + label_list = label_list[1:] + clsid2catid = {i: i for i in range(len(label_list))} + catid2name = {i: name for i, name in enumerate(label_list)} + return clsid2catid, catid2name + + def color_map(self, num_classes): + color_map = num_classes * [0, 0, 0] + for i in range(0, num_classes): + j = 0 + lab = i + while lab: + color_map[i * 3] |= (((lab >> 0) & 1) << (7 - j)) + color_map[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j)) + color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j)) + j += 1 + lab >>= 3 + color_map = np.array(color_map).reshape(-1, 3) + return color_map + + def visualize(self, + bbox_results, + catid2name, + num_classes, + mask_results=None): + image = Image.open(self.infer_img).convert('RGB') + color_list = self.color_map(num_classes) + image = self.draw_bbox(image, catid2name, bbox_results, 0.5, color_list) + if mask_results is not None: + image = self.draw_mask(image, mask_results, 0.5, color_list) + image_path = os.path.split(self.infer_img)[-1] + if not os.path.exists(self.output_dir): + os.makedirs(self.output_dir) + out_path = os.path.join(self.output_dir, image_path) + image.save(out_path, quality=95) + logger.info('Save visualize result to {}'.format(out_path)) + + def preprocess(self, feed_var_names, image_file): + self.infer_img = image_file + config_path = self.config_path + res = {} + assert config_path is not None, "Config path: {} does not exist!".format( + config_path) + with open(config_path) as f: + conf = yaml.safe_load(f) + + img_data = self.Preprocess(image_file, conf['arch'], conf['Preprocess']) + if 'SSD' in conf['arch']: + img_data, 
res['im_shape'] = img_data + img_data = [img_data] + if len(feed_var_names) != len(img_data): + raise ValueError( + 'the length of feed vars does not equal the length of the preprocessed img data, please check your feed dict' + ) + + def processImg(v): + np_data = np.array(v[0]) + res = np_data + return res + + feed_dict = {k: processImg(v) for k, v in zip(feed_var_names, img_data)} + return feed_dict + + def postprocess(self, fetch_map, fetch_var_names): + config_path = self.config_path + res = {} + with open(config_path) as f: + conf = yaml.safe_load(f) + if 'SSD' in conf['arch']: + img_data, res['im_shape'] = img_data + img_data = [img_data] + clsid2catid, catid2name = self.get_category_info( + conf['with_background'], conf['label_list']) + bbox_result = self.get_bbox_result(fetch_map, fetch_var_names[0], res, + conf, clsid2catid) + mask_result = None + if 'mask_resolution' in conf: + res['im_shape'] = img_data[-1] + mask_result = self.get_mask_result(fetch_map, fetch_var_names, res, + conf, clsid2catid) + if self.if_visualize: + if os.path.isdir(self.output_dir) is False: + os.mkdir(self.output_dir) + self.visualize(bbox_result, catid2name, + len(conf['label_list']), mask_result) + if self.if_dump_result: + if os.path.isdir(self.output_dir) is False: + os.mkdir(self.output_dir) + bbox_file = os.path.join(self.output_dir, 'bbox.json') + logger.info('dump bbox to {}'.format(bbox_file)) + with open(bbox_file, 'w') as f: + json.dump(bbox_result, f, indent=4) + if mask_result is not None: + mask_file = os.path.join(self.output_dir, 'mask.json') + logger.info('dump mask to {}'.format(mask_file)) + with open(mask_file, 'w') as f: + json.dump(mask_result, f, indent=4) diff --git a/python/paddle_serving_client/__init__.py b/python/paddle_serving_client/__init__.py index 8aeb22c92c781a4fb27b70403537f7016f05940d..801e2acb323ba64f246609684bc33194891a7250 100644 --- a/python/paddle_serving_client/__init__.py +++ b/python/paddle_serving_client/__init__.py @@ -26,6 +26,34 @@ int_type = 0 float_type = 1 +class _NOPProfiler(object): + def record(self, name): + pass + + def print_profile(self): + pass + + +class _TimeProfiler(object): + def __init__(self): + self.pid = os.getpid() + self.print_head = 'PROFILE\tpid:{}\t'.format(self.pid) + self.time_record = [self.print_head] + + def record(self, name): + self.time_record.append('{}:{} '.format( + name, int(round(time.time() * 1000000)))) + + def print_profile(self): + self.time_record.append('\n') + sys.stderr.write(''.join(self.time_record)) + self.time_record = [self.print_head] + + +_is_profile = int(os.environ.get('FLAGS_profile_client', 0)) +_Profiler = _TimeProfiler if _is_profile else _NOPProfiler + + class SDKConfig(object): def __init__(self): self.sdk_desc = sdk.SDKConf() @@ -89,6 +117,7 @@ class Client(object): self.predictor_sdk_ = None self.producers = [] self.consumer = None + self.profile_ = _Profiler() def rpath(self): lib_path = os.path.dirname(paddle_serving_client.__file__) @@ -184,6 +213,8 @@ class Client(object): key)) def predict(self, feed=None, fetch=None, need_variant_tag=False): + self.profile_.record('py_prepro_0') + if feed is None or fetch is None: raise ValueError("You should specify feed and fetch for prediction") @@ -256,11 +287,17 @@ int_slot_batch.append(int_slot) float_slot_batch.append(float_slot) + self.profile_.record('py_prepro_1') + self.profile_.record('py_client_infer_0') + result_batch = self.result_handle_ res = self.client_handle_.batch_predict( float_slot_batch, float_feed_names, 
float_shape, int_slot_batch, int_feed_names, int_shape, fetch_names, result_batch, self.pid) + self.profile_.record('py_client_infer_1') + self.profile_.record('py_postpro_0') + if res == -1: return None @@ -273,7 +310,7 @@ class Client(object): if self.fetch_names_to_type_[name] == int_type: result_map[name] = result_batch.get_int64_by_name(mi, name) shape = result_batch.get_shape(mi, name) - result_map[name] = np.array(result_map[name]) + result_map[name] = np.array(result_map[name], dtype='int64') result_map[name].shape = shape if name in self.lod_tensor_set: result_map["{}.lod".format( @@ -281,7 +318,8 @@ class Client(object): elif self.fetch_names_to_type_[name] == float_type: result_map[name] = result_batch.get_float_by_name(mi, name) shape = result_batch.get_shape(mi, name) - result_map[name] = np.array(result_map[name]) + result_map[name] = np.array( + result_map[name], dtype='float32') result_map[name].shape = shape if name in self.lod_tensor_set: result_map["{}.lod".format( @@ -299,6 +337,9 @@ class Client(object): for mi, engine_name in enumerate(model_engine_names) } + self.profile_.record('py_postpro_1') + self.profile_.print_profile() + # When using the A/B test, the tag of variant needs to be returned return ret if not need_variant_tag else [ ret, self.result_handle_.variant_tag() diff --git a/python/paddle_serving_client/io/__init__.py b/python/paddle_serving_client/io/__init__.py index d723795f214e22957bff49f0ddf8fd42086b8a7e..74a6ca871b5c1e32b3c1ecbc6656c95d7c78a399 100644 --- a/python/paddle_serving_client/io/__init__.py +++ b/python/paddle_serving_client/io/__init__.py @@ -20,6 +20,7 @@ from paddle.fluid.framework import default_main_program from paddle.fluid.framework import Program from paddle.fluid import CPUPlace from paddle.fluid.io import save_inference_model +import paddle.fluid as fluid from ..proto import general_model_config_pb2 as model_conf import os @@ -100,3 +101,20 @@ def save_model(server_model_folder, with open("{}/serving_server_conf.stream.prototxt".format( server_model_folder), "wb") as fout: fout.write(config.SerializeToString()) + + +def inference_model_to_serving(infer_model, serving_client, serving_server): + place = fluid.CPUPlace() + exe = fluid.Executor(place) + inference_program, feed_target_names, fetch_targets = \ + fluid.io.load_inference_model(dirname=infer_model, executor=exe) + feed_dict = { + x: inference_program.global_block().var(x) + for x in feed_target_names + } + fetch_dict = {x.name: x for x in fetch_targets} + save_model(serving_client, serving_server, feed_dict, fetch_dict, + inference_program) + feed_names = feed_dict.keys() + fetch_names = fetch_dict.keys() + return feed_names, fetch_names diff --git a/python/paddle_serving_server/__init__.py b/python/paddle_serving_server/__init__.py index 8062a7c83d99c0bed712ff46840b81f4557a353d..a58fb11ac3ee1fbe5086ae4381f6d6208c0c73ec 100644 --- a/python/paddle_serving_server/__init__.py +++ b/python/paddle_serving_server/__init__.py @@ -351,6 +351,7 @@ class Server(object): self._prepare_resource(workdir) self._prepare_engine(self.model_config_paths, device) self._prepare_infer_service(port) + self.port = port self.workdir = workdir infer_service_fn = "{}/{}".format(workdir, self.infer_service_fn) diff --git a/python/paddle_serving_server/web_service.py b/python/paddle_serving_server/web_service.py index 4a033cbcf1d32a55eaacbe9c0f6704e304e127b3..a03649725b1c41ca94b8ef495a2fc80e8293aba0 100755 --- a/python/paddle_serving_server/web_service.py +++ 
b/python/paddle_serving_server/web_service.py @@ -18,6 +18,8 @@ from flask import Flask, request, abort from multiprocessing import Pool, Process from paddle_serving_server import OpMaker, OpSeqMaker, Server from paddle_serving_client import Client +from contextlib import closing +import socket class WebService(object): @@ -41,19 +43,34 @@ class WebService(object): server.set_num_threads(16) server.load_model_config(self.model_config) server.prepare_server( - workdir=self.workdir, port=self.port + 1, device=self.device) + workdir=self.workdir, port=self.port_list[0], device=self.device) server.run_server() + def port_is_available(self, port): + with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock: + sock.settimeout(2) + result = sock.connect_ex(('0.0.0.0', port)) + if result != 0: + return True + else: + return False + def prepare_server(self, workdir="", port=9393, device="cpu"): self.workdir = workdir self.port = port self.device = device + default_port = 12000 + self.port_list = [] + for i in range(1000): + if self.port_is_available(default_port + i): + self.port_list.append(default_port + i) + break def _launch_web_service(self): - self.client_service = Client() - self.client_service.load_client_config( - "{}/serving_server_conf.prototxt".format(self.model_config)) - self.client_service.connect(["0.0.0.0:{}".format(self.port + 1)]) + self.client = Client() + self.client.load_client_config("{}/serving_server_conf.prototxt".format( + self.model_config)) + self.client.connect(["0.0.0.0:{}".format(self.port_list[0])]) def get_prediction(self, request): if not request.json: @@ -64,12 +81,12 @@ class WebService(object): feed, fetch = self.preprocess(request.json, request.json["fetch"]) if isinstance(feed, dict) and "fetch" in feed: del feed["fetch"] - fetch_map = self.client_service.predict(feed=feed, fetch=fetch) - for key in fetch_map: - fetch_map[key] = fetch_map[key][0].tolist() - result = self.postprocess( + fetch_map = self.client.predict(feed=feed, fetch=fetch) + fetch_map = self.postprocess( feed=request.json, fetch=fetch, fetch_map=fetch_map) - result = {"result": result} + for key in fetch_map: + fetch_map[key] = fetch_map[key].tolist() + result = {"result": fetch_map} except ValueError: result = {"result": "Request Value Error"} return result @@ -83,6 +100,24 @@ class WebService(object): p_rpc = Process(target=self._launch_rpc_service) p_rpc.start() + def run_flask(self): + app_instance = Flask(__name__) + + @app_instance.before_first_request + def init(): + self._launch_web_service() + + service_name = "/" + self.name + "/prediction" + + @app_instance.route(service_name, methods=["POST"]) + def run(): + return self.get_prediction(request) + + app_instance.run(host="0.0.0.0", + port=self.port, + threaded=False, + processes=4) + def preprocess(self, feed={}, fetch=[]): return feed, fetch diff --git a/python/paddle_serving_server_gpu/web_service.py b/python/paddle_serving_server_gpu/web_service.py index cb833ba32b20edeb22efd5b772506d32e05e4497..eb1ecfd8faaf34a6bf2955af46d5a8cf09085ad7 100644 --- a/python/paddle_serving_server_gpu/web_service.py +++ b/python/paddle_serving_server_gpu/web_service.py @@ -14,14 +14,15 @@ # pylint: disable=doc-string-missing from flask import Flask, request, abort -from paddle_serving_server_gpu import OpMaker, OpSeqMaker, Server -import paddle_serving_server_gpu as serving +from contextlib import closing from multiprocessing import Pool, Process, Queue from paddle_serving_client import Client +from paddle_serving_server_gpu import 
OpMaker, OpSeqMaker, Server from paddle_serving_server_gpu.serve import start_multi_card - +import socket import sys import numpy as np +import paddle_serving_server_gpu as serving class WebService(object): @@ -67,22 +68,39 @@ class WebService(object): def _launch_rpc_service(self, service_idx): self.rpc_service_list[service_idx].run_server() + def port_is_available(self, port): + with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock: + sock.settimeout(2) + result = sock.connect_ex(('0.0.0.0', port)) + if result != 0: + return True + else: + return False + def prepare_server(self, workdir="", port=9393, device="gpu", gpuid=0): self.workdir = workdir self.port = port self.device = device self.gpuid = gpuid + self.port_list = [] + default_port = 12000 + for i in range(1000): + if self.port_is_available(default_port + i): + self.port_list.append(default_port + i) + if len(self.port_list) > len(self.gpus): + break + if len(self.gpus) == 0: # init cpu service self.rpc_service_list.append( self.default_rpc_service( - self.workdir, self.port + 1, -1, thread_num=10)) + self.workdir, self.port_list[0], -1, thread_num=10)) else: for i, gpuid in enumerate(self.gpus): self.rpc_service_list.append( self.default_rpc_service( "{}_{}".format(self.workdir, i), - self.port + 1 + i, + self.port_list[i], gpuid, thread_num=10)) @@ -94,9 +112,9 @@ class WebService(object): endpoints = "" if gpu_num > 0: for i in range(gpu_num): - endpoints += "127.0.0.1:{},".format(self.port + i + 1) + endpoints += "127.0.0.1:{},".format(self.port_list[i]) else: - endpoints = "127.0.0.1:{}".format(self.port + 1) + endpoints = "127.0.0.1:{}".format(self.port_list[0]) self.client.connect([endpoints]) def get_prediction(self, request): @@ -109,11 +127,11 @@ class WebService(object): if isinstance(feed, dict) and "fetch" in feed: del feed["fetch"] fetch_map = self.client.predict(feed=feed, fetch=fetch) - for key in fetch_map: - fetch_map[key] = fetch_map[key][0].tolist() - result = self.postprocess( + fetch_map = self.postprocess( feed=request.json, fetch=fetch, fetch_map=fetch_map) - result = {"result": result} + for key in fetch_map: + fetch_map[key] = fetch_map[key].tolist() + result = {"result": fetch_map} except ValueError: result = {"result": "Request Value Error"} return result @@ -131,6 +149,24 @@ class WebService(object): for p in server_pros: p.start() + def run_flask(self): + app_instance = Flask(__name__) + + @app_instance.before_first_request + def init(): + self._launch_web_service() + + service_name = "/" + self.name + "/prediction" + + @app_instance.route(service_name, methods=["POST"]) + def run(): + return self.get_prediction(request) + + app_instance.run(host="0.0.0.0", + port=self.port, + threaded=False, + processes=4) + def preprocess(self, feed={}, fetch=[]): return feed, fetch diff --git a/python/setup.py.app.in b/python/setup.py.app.in index 13e71b22cdc5eb719c17af974dd2150710133491..3c0f8e065a5072919d808ba1da67f5c37eee0594 100644 --- a/python/setup.py.app.in +++ b/python/setup.py.app.in @@ -47,7 +47,8 @@ REQUIRED_PACKAGES = [ packages=['paddle_serving_app', 'paddle_serving_app.reader', - 'paddle_serving_app.utils'] + 'paddle_serving_app.utils', + 'paddle_serving_app.reader.pddet'] package_data={} package_dir={'paddle_serving_app': @@ -55,7 +56,9 @@ package_dir={'paddle_serving_app': 'paddle_serving_app.reader': '${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_app/reader', 'paddle_serving_app.utils': - '${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_app/utils',} + 
'${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_app/utils', + 'paddle_serving_app.reader.pddet': + '${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_app/reader/pddet',} setup( name='paddle-serving-app', diff --git a/tools/serving_build.sh b/tools/serving_build.sh index e4bf6ece3a9df1808b9190e9e77d8d2e8aba62c0..1e47b8f4fe26c689b5d6680c1478740201b335b9 100644 --- a/tools/serving_build.sh +++ b/tools/serving_build.sh @@ -323,6 +323,9 @@ function python_test_bert() { echo "bert RPC inference pass" ;; *) + echo "error type" + exit 1 + ;; esac echo "test bert $TYPE finished as expected." unset SERVING_BIN @@ -357,6 +360,9 @@ function python_test_imdb() { echo "imdb ignore GPU test" ;; *) + echo "error type" + exit 1 + ;; esac echo "test imdb $TYPE finished as expected." unset SERVING_BIN @@ -389,6 +395,9 @@ function python_test_lac() { echo "lac ignore GPU test" ;; *) + echo "error type" + exit 1 + ;; esac echo "test lac $TYPE finished as expected." unset SERVING_BIN @@ -408,6 +417,248 @@ function python_run_test() { cd ../.. # pwd: /Serving } +function monitor_test() { + local TYPE=$1 # pwd: /Serving + mkdir _monitor_test && cd _monitor_test # pwd: /Serving/_monitor_test + case $TYPE in + CPU): + pip install pyftpdlib + mkdir remote_path + mkdir local_path + cd remote_path # pwd: /Serving/_monitor_test/remote_path + check_cmd "python -m pyftpdlib -p 8000 &>/dev/null &" + cd .. # pwd: /Serving/_monitor_test + + # type: ftp + # remote_path: / + # remote_model_name: uci_housing.tar.gz + # local_tmp_path: ___tmp + # local_path: local_path + cd remote_path # pwd: /Serving/_monitor_test/remote_path + wget --no-check-certificate https://paddle-serving.bj.bcebos.com/uci_housing.tar.gz + touch donefile + cd .. # pwd: /Serving/_monitor_test + mkdir -p local_path/uci_housing_model + python -m paddle_serving_server.monitor \ + --type='ftp' --ftp_host='127.0.0.1' --ftp_port='8000' \ + --remote_path='/' --remote_model_name='uci_housing.tar.gz' \ + --remote_donefile_name='donefile' --local_path='local_path' \ + --local_model_name='uci_housing_model' --local_timestamp_file='fluid_time_file' \ + --local_tmp_path='___tmp' --unpacked_filename='uci_housing_model' \ + --interval='1' >/dev/null & + sleep 10 + if [ ! -f "local_path/uci_housing_model/fluid_time_file" ]; then + echo "local_path/uci_housing_model/fluid_time_file not exist." + exit 1 + fi + ps -ef | grep "monitor" | grep -v grep | awk '{print $2}' | xargs kill + rm -rf remote_path/* + rm -rf local_path/* + + # type: ftp + # remote_path: /tmp_dir + # remote_model_name: uci_housing_model + # local_tmp_path: ___tmp + # local_path: local_path + mkdir -p remote_path/tmp_dir && cd remote_path/tmp_dir # pwd: /Serving/_monitor_test/remote_path/tmp_dir + wget --no-check-certificate https://paddle-serving.bj.bcebos.com/uci_housing.tar.gz + tar -xzf uci_housing.tar.gz + touch donefile + cd ../.. # pwd: /Serving/_monitor_test + mkdir -p local_path/uci_housing_model + python -m paddle_serving_server.monitor \ + --type='ftp' --ftp_host='127.0.0.1' --ftp_port='8000' \ + --remote_path='/tmp_dir' --remote_model_name='uci_housing_model' \ + --remote_donefile_name='donefile' --local_path='local_path' \ + --local_model_name='uci_housing_model' --local_timestamp_file='fluid_time_file' \ + --local_tmp_path='___tmp' --interval='1' >/dev/null & + sleep 10 + if [ ! -f "local_path/uci_housing_model/fluid_time_file" ]; then + echo "local_path/uci_housing_model/fluid_time_file not exist." 
+ exit 1 + fi + ps -ef | grep "monitor" | grep -v grep | awk '{print $2}' | xargs kill + rm -rf remote_path/* + rm -rf local_path/* + + # type: general + # remote_path: / + # remote_model_name: uci_housing.tar.gz + # local_tmp_path: ___tmp + # local_path: local_path + cd remote_path # pwd: /Serving/_monitor_test/remote_path + wget --no-check-certificate https://paddle-serving.bj.bcebos.com/uci_housing.tar.gz + touch donefile + cd .. # pwd: /Serving/_monitor_test + mkdir -p local_path/uci_housing_model + python -m paddle_serving_server.monitor \ + --type='general' --general_host='ftp://127.0.0.1:8000' \ + --remote_path='/' --remote_model_name='uci_housing.tar.gz' \ + --remote_donefile_name='donefile' --local_path='local_path' \ + --local_model_name='uci_housing_model' --local_timestamp_file='fluid_time_file' \ + --local_tmp_path='___tmp' --unpacked_filename='uci_housing_model' \ + --interval='1' >/dev/null & + sleep 10 + if [ ! -f "local_path/uci_housing_model/fluid_time_file" ]; then + echo "local_path/uci_housing_model/fluid_time_file not exist." + exit 1 + fi + ps -ef | grep "monitor" | grep -v grep | awk '{print $2}' | xargs kill + rm -rf remote_path/* + rm -rf local_path/* + + # type: general + # remote_path: /tmp_dir + # remote_model_name: uci_housing_model + # local_tmp_path: ___tmp + # local_path: local_path + mkdir -p remote_path/tmp_dir && cd remote_path/tmp_dir # pwd: /Serving/_monitor_test/remote_path/tmp_dir + wget --no-check-certificate https://paddle-serving.bj.bcebos.com/uci_housing.tar.gz + tar -xzf uci_housing.tar.gz + touch donefile + cd ../.. # pwd: /Serving/_monitor_test + mkdir -p local_path/uci_housing_model + python -m paddle_serving_server.monitor \ + --type='general' --general_host='ftp://127.0.0.1:8000' \ + --remote_path='/tmp_dir' --remote_model_name='uci_housing_model' \ + --remote_donefile_name='donefile' --local_path='local_path' \ + --local_model_name='uci_housing_model' --local_timestamp_file='fluid_time_file' \ + --local_tmp_path='___tmp' --interval='1' >/dev/null & + sleep 10 + if [ ! -f "local_path/uci_housing_model/fluid_time_file" ]; then + echo "local_path/uci_housing_model/fluid_time_file not exist." + exit 1 + fi + ps -ef | grep "monitor" | grep -v grep | awk '{print $2}' | xargs kill + rm -rf remote_path/* + rm -rf local_path/* + + ps -ef | grep "pyftpdlib" | grep -v grep | awk '{print $2}' | xargs kill + ;; + GPU): + pip install pyftpdlib + mkdir remote_path + mkdir local_path + cd remote_path # pwd: /Serving/_monitor_test/remote_path + check_cmd "python -m pyftpdlib -p 8000 &>/dev/null &" + cd .. # pwd: /Serving/_monitor_test + + # type: ftp + # remote_path: / + # remote_model_name: uci_housing.tar.gz + # local_tmp_path: ___tmp + # local_path: local_path + cd remote_path # pwd: /Serving/_monitor_test/remote_path + wget --no-check-certificate https://paddle-serving.bj.bcebos.com/uci_housing.tar.gz + touch donefile + cd .. # pwd: /Serving/_monitor_test + mkdir -p local_path/uci_housing_model + python -m paddle_serving_server_gpu.monitor \ + --type='ftp' --ftp_host='127.0.0.1' --ftp_port='8000' \ + --remote_path='/' --remote_model_name='uci_housing.tar.gz' \ + --remote_donefile_name='donefile' --local_path='local_path' \ + --local_model_name='uci_housing_model' --local_timestamp_file='fluid_time_file' \ + --local_tmp_path='___tmp' --unpacked_filename='uci_housing_model' \ + --interval='1' >/dev/null & + sleep 10 + if [ ! -f "local_path/uci_housing_model/fluid_time_file" ]; then + echo "local_path/uci_housing_model/fluid_time_file not exist." 
+ exit 1 + fi + ps -ef | grep "monitor" | grep -v grep | awk '{print $2}' | xargs kill + rm -rf remote_path/* + rm -rf local_path/* + + # type: ftp + # remote_path: /tmp_dir + # remote_model_name: uci_housing_model + # local_tmp_path: ___tmp + # local_path: local_path + mkdir -p remote_path/tmp_dir && cd remote_path/tmp_dir # pwd: /Serving/_monitor_test/remote_path/tmp_dir + wget --no-check-certificate https://paddle-serving.bj.bcebos.com/uci_housing.tar.gz + tar -xzf uci_housing.tar.gz + touch donefile + cd ../.. # pwd: /Serving/_monitor_test + mkdir -p local_path/uci_housing_model + python -m paddle_serving_server_gpu.monitor \ + --type='ftp' --ftp_host='127.0.0.1' --ftp_port='8000' \ + --remote_path='/tmp_dir' --remote_model_name='uci_housing_model' \ + --remote_donefile_name='donefile' --local_path='local_path' \ + --local_model_name='uci_housing_model' --local_timestamp_file='fluid_time_file' \ + --local_tmp_path='___tmp' --interval='1' >/dev/null & + sleep 10 + if [ ! -f "local_path/uci_housing_model/fluid_time_file" ]; then + echo "local_path/uci_housing_model/fluid_time_file not exist." + exit 1 + fi + ps -ef | grep "monitor" | grep -v grep | awk '{print $2}' | xargs kill + rm -rf remote_path/* + rm -rf local_path/* + + # type: general + # remote_path: / + # remote_model_name: uci_housing.tar.gz + # local_tmp_path: ___tmp + # local_path: local_path + cd remote_path # pwd: /Serving/_monitor_test/remote_path + wget --no-check-certificate https://paddle-serving.bj.bcebos.com/uci_housing.tar.gz + touch donefile + cd .. # pwd: /Serving/_monitor_test + mkdir -p local_path/uci_housing_model + python -m paddle_serving_server_gpu.monitor \ + --type='general' --general_host='ftp://127.0.0.1:8000' \ + --remote_path='/' --remote_model_name='uci_housing.tar.gz' \ + --remote_donefile_name='donefile' --local_path='local_path' \ + --local_model_name='uci_housing_model' --local_timestamp_file='fluid_time_file' \ + --local_tmp_path='___tmp' --unpacked_filename='uci_housing_model' \ + --interval='1' >/dev/null & + sleep 10 + if [ ! -f "local_path/uci_housing_model/fluid_time_file" ]; then + echo "local_path/uci_housing_model/fluid_time_file not exist." + exit 1 + fi + ps -ef | grep "monitor" | grep -v grep | awk '{print $2}' | xargs kill + rm -rf remote_path/* + rm -rf local_path/* + + # type: general + # remote_path: /tmp_dir + # remote_model_name: uci_housing_model + # local_tmp_path: ___tmp + # local_path: local_path + mkdir -p remote_path/tmp_dir && cd remote_path/tmp_dir # pwd: /Serving/_monitor_test/remote_path/tmp_dir + wget --no-check-certificate https://paddle-serving.bj.bcebos.com/uci_housing.tar.gz + tar -xzf uci_housing.tar.gz + touch donefile + cd ../.. # pwd: /Serving/_monitor_test + mkdir -p local_path/uci_housing_model + python -m paddle_serving_server_gpu.monitor \ + --type='general' --general_host='ftp://127.0.0.1:8000' \ + --remote_path='/tmp_dir' --remote_model_name='uci_housing_model' \ + --remote_donefile_name='donefile' --local_path='local_path' \ + --local_model_name='uci_housing_model' --local_timestamp_file='fluid_time_file' \ + --local_tmp_path='___tmp' --interval='1' >/dev/null & + sleep 10 + if [ ! -f "local_path/uci_housing_model/fluid_time_file" ]; then + echo "local_path/uci_housing_model/fluid_time_file not exist." 
+ exit 1 + fi + ps -ef | grep "monitor" | grep -v grep | awk '{print $2}' | xargs kill + rm -rf remote_path/* + rm -rf local_path/* + + ps -ef | grep "pyftpdlib" | grep -v grep | awk '{print $2}' | xargs kill + ;; + *) + echo "error type" + exit 1 + ;; + esac + cd .. # pwd: /Serving + rm -rf _monitor_test + echo "test monitor $TYPE finished as expected." +} + function main() { local TYPE=$1 # pwd: / init # pwd: /Serving @@ -415,6 +666,7 @@ function main() { build_server $TYPE # pwd: /Serving build_app $TYPE # pwd: /Serving python_run_test $TYPE # pwd: /Serving + monitor_test $TYPE # pwd: /Serving echo "serving $TYPE part finished as expected." }