diff --git a/docs/zh_CN/extension/paddle_serving.md b/docs/zh_CN/extension/paddle_serving.md
index 62b8cda8b74283fea701a86f6f05e5194b0a6da8..7b9102042b5018ab972bbd2542526186b5b402df 100644
--- a/docs/zh_CN/extension/paddle_serving.md
+++ b/docs/zh_CN/extension/paddle_serving.md
@@ -1,4 +1,65 @@
 # Model Service Deployment
-[Paddle Serving](https://github.com/PaddlePaddle/Serving) aims to help deep learning developers easily deploy online prediction services. It supports one-click deployment of industrial-grade serving, high-concurrency and efficient communication between client and server, and clients in multiple programming languages. For details, please refer to the [Paddle Serving documentation](https://github.com/PaddlePaddle/Serving).
+## 1. Introduction
+[Paddle Serving](https://github.com/PaddlePaddle/Serving) aims to help deep learning developers easily deploy online prediction services. It supports one-click deployment of industrial-grade serving, high-concurrency and efficient communication between client and server, and clients in multiple programming languages.
+Taking the deployment of an HTTP prediction service as an example, this section introduces how to deploy a model service in PaddleClas with Paddle Serving.
+
+
+## 2. Install Serving
+
+The Serving project recommends installing and deploying the Serving environment with Docker. First, pull the Docker image and create a Serving-based container.
+
+```shell
+nvidia-docker pull hub.baidubce.com/paddlepaddle/serving:0.2.0-gpu
+nvidia-docker run -p 9292:9292 --name test -dit hub.baidubce.com/paddlepaddle/serving:0.2.0-gpu
+nvidia-docker exec -it test bash
+```
+
+After entering the container, install the Python packages required by Serving.
+
+```shell
+pip install paddlepaddle-gpu
+pip install paddle-serving-client
+pip install paddle-serving-server-gpu
+```
+
+* If the installation is too slow, you can switch to another index with `-i https://pypi.tuna.tsinghua.edu.cn/simple` to speed it up.
+
+* If you want to deploy a CPU service, install the CPU version of serving-server instead:
+
+```shell
+pip install paddle-serving-server
+```
+
+## 3. Export the Model
+
+Use the `tools/export_serving_model.py` script to export a Serving model. Taking `ResNet50_vd` as an example:
+
+```shell
+python tools/export_serving_model.py -m ResNet50_vd -p ./pretrained/ResNet50_vd_pretrained/ -o serving
+```
+
+Two folders, `ppcls_client_conf` and `ppcls_model`, are generated under the `serving` folder, storing the client configuration and the model parameters and structure files, respectively.
+
+
+## 4. Deploy the Service and Send Requests
+
+* Start the Serving service as follows.
+
+```shell
+python tools/serving/image_service_gpu.py serving/ppcls_model workdir 9292
+```
+
+Here, `serving/ppcls_model` is the path of the Serving model saved above, `workdir` is the working directory, and `9292` is the service port.
+
+
+* Use the following script to send a recognition request to the Serving service and get the result back.
+
+```shell
+python tools/serving/image_http_client.py 9292 ./docs/images/logo.png
+```
+
+`9292` is the port the request is sent to, which must match the port used when starting the service, and `./docs/images/logo.png` is the image file to be recognized. The response contains the class ID and probability of the Top-1 prediction.
+
+* For more deployment types, such as the `RPC prediction service`, please refer to the Serving examples on GitHub: [https://github.com/PaddlePaddle/Serving/tree/develop/python/examples/imagenet](https://github.com/PaddlePaddle/Serving/tree/develop/python/examples/imagenet)
diff --git a/tools/export_serving_model.py b/tools/export_serving_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..e6e4472cdbf8dfd1738dede98b5aa61121f8191a
--- /dev/null
+++ b/tools/export_serving_model.py
@@ -0,0 +1,76 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
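+
+# Export a PaddleClas model for Paddle Serving: build the inference program
+# for the selected architecture, load the pretrained weights, and save the
+# serving model and client configuration via paddle_serving_client.io.save_model.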
+ +import argparse +import os +from ppcls.modeling import architectures + +import paddle.fluid as fluid +import paddle_serving_client.io as serving_io + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("-m", "--model", type=str) + parser.add_argument("-p", "--pretrained_model", type=str) + parser.add_argument("-o", "--output_path", type=str, default="") + parser.add_argument("--class_dim", type=int, default=1000) + parser.add_argument("--img_size", type=int, default=224) + + return parser.parse_args() + + +def create_input(img_size=224): + image = fluid.data( + name='image', shape=[None, 3, img_size, img_size], dtype='float32') + return image + + +def create_model(args, model, input, class_dim=1000): + if args.model == "GoogLeNet": + out, _, _ = model.net(input=input, class_dim=class_dim) + else: + out = model.net(input=input, class_dim=class_dim) + out = fluid.layers.softmax(out) + return out + + +def main(): + args = parse_args() + + model = architectures.__dict__[args.model]() + + place = fluid.CPUPlace() + exe = fluid.Executor(place) + + startup_prog = fluid.Program() + infer_prog = fluid.Program() + + with fluid.program_guard(infer_prog, startup_prog): + with fluid.unique_name.guard(): + image = create_input(args.img_size) + out = create_model(args, model, image, class_dim=args.class_dim) + + infer_prog = infer_prog.clone(for_test=True) + fluid.load( + program=infer_prog, model_path=args.pretrained_model, executor=exe) + + model_path = os.path.join(args.output_path, "ppcls_model") + conf_path = os.path.join(args.output_path, "ppcls_client_conf") + serving_io.save_model(model_path, conf_path, {"image": image}, + {"prediction": out}, infer_prog) + + +if __name__ == "__main__": + main() diff --git a/tools/serving/image_http_client.py b/tools/serving/image_http_client.py new file mode 100644 index 0000000000000000000000000000000000000000..3b92091c659613c83e4423a3f22b0d4d20321f43 --- /dev/null +++ b/tools/serving/image_http_client.py @@ -0,0 +1,47 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
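+
+# Minimal HTTP client for the image service: base64-encode an image, POST it
+# to the /image/prediction endpoint, and print the Top-1 class id and score.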
+
+import requests
+import base64
+import json
+import sys
+import numpy as np
+
+py_version = sys.version_info[0]
+
+
+def predict(image_path, server):
+    # base64-encode the raw image bytes so they can be sent as JSON
+    if py_version == 2:
+        image = base64.b64encode(open(image_path, "rb").read())
+    else:
+        image = base64.b64encode(open(image_path, "rb").read()).decode("utf-8")
+    req = json.dumps({"feed": [{"image": image}], "fetch": ["prediction"]})
+    r = requests.post(
+        server, data=req, headers={"Content-Type": "application/json"})
+    try:
+        # the service returns the softmax output; keep only the Top-1 result
+        pred = r.json()["result"]["prediction"][0]
+        cls_id = np.argmax(pred)
+        score = pred[cls_id]
+        pred = {"cls_id": cls_id, "score": score}
+        return pred
+    except ValueError:
+        print(r.text)
+        return r
+
+
+if __name__ == "__main__":
+    server = "http://127.0.0.1:{}/image/prediction".format(sys.argv[1])
+    image_file = sys.argv[2]
+    res = predict(image_file, server)
+    print("res:", res)
diff --git a/tools/serving/image_service_cpu.py b/tools/serving/image_service_cpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..92f67d3220670ffa880ff663ed887984325a0723
--- /dev/null
+++ b/tools/serving/image_service_cpu.py
@@ -0,0 +1,60 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+import base64
+from paddle_serving_server.web_service import WebService
+import utils
+
+
+class ImageService(WebService):
+    def __init__(self, name):
+        super(ImageService, self).__init__(name=name)
+        self.operators = self.create_operators()
+
+    def create_operators(self):
+        # standard ImageNet preprocessing: decode, resize, center crop,
+        # normalize and transpose to CHW
+        size = 224
+        img_mean = [0.485, 0.456, 0.406]
+        img_std = [0.229, 0.224, 0.225]
+        img_scale = 1.0 / 255.0
+        decode_op = utils.DecodeImage()
+        resize_op = utils.ResizeImage(resize_short=256)
+        crop_op = utils.CropImage(size=(size, size))
+        normalize_op = utils.NormalizeImage(
+            scale=img_scale, mean=img_mean, std=img_std)
+        totensor_op = utils.ToTensor()
+        return [decode_op, resize_op, crop_op, normalize_op, totensor_op]
+
+    def _process_image(self, data, ops):
+        for op in ops:
+            data = op(data)
+        return data
+
+    def preprocess(self, feed={}, fetch=[]):
+        feed_batch = []
+        for ins in feed:
+            if "image" not in ins:
+                raise ValueError("feed data error!")
+            sample = base64.b64decode(ins["image"])
+            img = self._process_image(sample, self.operators)
+            feed_batch.append({"image": img})
+        return feed_batch, fetch
+
+
+image_service = ImageService(name="image")
+image_service.load_model_config(sys.argv[1])
+image_service.prepare_server(
+    workdir=sys.argv[2], port=int(sys.argv[3]), device="cpu")
+image_service.run_server()
+image_service.run_flask()
diff --git a/tools/serving/image_service_gpu.py b/tools/serving/image_service_gpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..df61cdd60659713ac77176beb8c1ecfad1c8efd8
--- /dev/null
+++ b/tools/serving/image_service_gpu.py
@@ -0,0 +1,62 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+import base64
+from paddle_serving_server_gpu.web_service import WebService
+
+import utils
+
+
+class ImageService(WebService):
+    def __init__(self, name):
+        super(ImageService, self).__init__(name=name)
+        self.operators = self.create_operators()
+
+    def create_operators(self):
+        # standard ImageNet preprocessing: decode, resize, center crop,
+        # normalize and transpose to CHW
+        size = 224
+        img_mean = [0.485, 0.456, 0.406]
+        img_std = [0.229, 0.224, 0.225]
+        img_scale = 1.0 / 255.0
+        decode_op = utils.DecodeImage()
+        resize_op = utils.ResizeImage(resize_short=256)
+        crop_op = utils.CropImage(size=(size, size))
+        normalize_op = utils.NormalizeImage(
+            scale=img_scale, mean=img_mean, std=img_std)
+        totensor_op = utils.ToTensor()
+        return [decode_op, resize_op, crop_op, normalize_op, totensor_op]
+
+    def _process_image(self, data, ops):
+        for op in ops:
+            data = op(data)
+        return data
+
+    def preprocess(self, feed={}, fetch=[]):
+        feed_batch = []
+        for ins in feed:
+            if "image" not in ins:
+                raise ValueError("feed data error!")
+            sample = base64.b64decode(ins["image"])
+            img = self._process_image(sample, self.operators)
+            feed_batch.append({"image": img})
+        return feed_batch, fetch
+
+
+image_service = ImageService(name="image")
+image_service.load_model_config(sys.argv[1])
+image_service.set_gpus("0")
+image_service.prepare_server(
+    workdir=sys.argv[2], port=int(sys.argv[3]), device="gpu")
+image_service.run_server()
+image_service.run_flask()
diff --git a/tools/serving/utils.py b/tools/serving/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c4a75e1afe2fc1e1710c7e8213f8ac4de8ffcc2
--- /dev/null
+++ b/tools/serving/utils.py
@@ -0,0 +1,84 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
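+
+# Image preprocessing operators (decode, resize by short side, center crop,
+# normalize, HWC->CHW transpose) used by the Serving web services above.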
+ +import cv2 +import numpy as np + + +class DecodeImage(object): + def __init__(self, to_rgb=True): + self.to_rgb = to_rgb + + def __call__(self, img): + data = np.frombuffer(img, dtype='uint8') + img = cv2.imdecode(data, 1) + if self.to_rgb: + assert img.shape[2] == 3, 'invalid shape of image[%s]' % ( + img.shape) + img = img[:, :, ::-1] + + return img + + +class ResizeImage(object): + def __init__(self, resize_short=None): + self.resize_short = resize_short + + def __call__(self, img): + img_h, img_w = img.shape[:2] + percent = float(self.resize_short) / min(img_w, img_h) + w = int(round(img_w * percent)) + h = int(round(img_h * percent)) + return cv2.resize(img, (w, h)) + + +class CropImage(object): + def __init__(self, size): + if type(size) is int: + self.size = (size, size) + else: + self.size = size + + def __call__(self, img): + w, h = self.size + img_h, img_w = img.shape[:2] + w_start = (img_w - w) // 2 + h_start = (img_h - h) // 2 + + w_end = w_start + w + h_end = h_start + h + return img[h_start:h_end, w_start:w_end, :] + + +class NormalizeImage(object): + def __init__(self, scale=None, mean=None, std=None): + self.scale = np.float32(scale if scale is not None else 1.0 / 255.0) + mean = mean if mean is not None else [0.485, 0.456, 0.406] + std = std if std is not None else [0.229, 0.224, 0.225] + + shape = (1, 1, 3) + self.mean = np.array(mean).reshape(shape).astype('float32') + self.std = np.array(std).reshape(shape).astype('float32') + + def __call__(self, img): + return (img.astype('float32') * self.scale - self.mean) / self.std + + +class ToTensor(object): + def __init__(self): + pass + + def __call__(self, img): + img = img.transpose((2, 0, 1)) + return img