# resnet50_web_service.py
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from paddle_serving_client import Client
from paddle_serving_app.reader import Sequential, URL2Image, Resize, CenterCrop, RGB2BGR, Transpose, Div, Normalize, Base64ToImage

# Usage: python resnet50_web_service.py <model> <device> <port>
if len(sys.argv) != 4:
    print("python resnet50_web_service.py model device port")
    sys.exit(-1)

device = sys.argv[2]

# Only "cpu" and "gpu" are handled consistently below: the import choice here
# and the later `set_gpus` call both key off these exact strings. Any other
# value would import the GPU WebService without configuring GPUs, so reject
# it up front.
if device not in ("cpu", "gpu"):
    print("device must be 'cpu' or 'gpu', got: {}".format(device))
    sys.exit(-1)

# Pick the serving package that matches the requested device.
if device == "cpu":
    from paddle_serving_server.web_service import WebService
else:
    from paddle_serving_server_gpu.web_service import WebService


class ImageService(WebService):
    """Web service that serves image classification over an ImageNet label set.

    ``init_imagenet_setting`` must be called before serving: it builds the
    preprocessing pipeline used by ``preprocess`` and the index -> label
    table used by ``postprocess``.
    """

    def init_imagenet_setting(self, label_file="imagenet.label"):
        """Build the preprocessing pipeline and load the class-label table.

        Args:
            label_file: text file with one label per line; the 0-based line
                number is used as the class index. Defaults to
                ``"imagenet.label"`` (relative to the working directory), as
                before.
        """
        # Pipeline steps, applied in order: URL2Image, Resize(256),
        # CenterCrop(224), RGB2BGR, Transpose((2, 0, 1)) (presumably
        # HWC -> CHW), Div(255), then per-channel Normalize with the given
        # mean/std lists.
        self.seq = Sequential([
            URL2Image(), Resize(256), CenterCrop(224), RGB2BGR(),
            Transpose((2, 0, 1)), Div(255),
            Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], True)
        ])
        self.label_dict = {}
        with open(label_file) as fin:
            self.label_dict = {
                label_idx: line.strip() for label_idx, line in enumerate(fin)
            }

    def preprocess(self, feed=None, fetch=None):
        """Run each request instance through the preprocessing pipeline.

        Args:
            feed: list of request dicts; each must contain an ``"image"``
                key whose value is passed to ``self.seq``. ``None`` means
                no instances.
            fetch: output names to fetch; returned unchanged (``None``
                becomes an empty list).

        Returns:
            ``(feed_batch, fetch)``, where ``feed_batch`` holds one
            ``{"image": <preprocessed>}`` dict per input instance.

        Raises:
            ValueError: if an instance has no ``"image"`` key.
        """
        # None sentinels replace the former shared mutable default lists.
        feed = [] if feed is None else feed
        fetch = [] if fetch is None else fetch
        feed_batch = []
        for ins in feed:
            if "image" not in ins:
                # Previously `raise ("feed data error!")`, which raises a str
                # and therefore fails with an unrelated TypeError in Python 3.
                raise ValueError("feed data error!")
            feed_batch.append({"image": self.seq(ins["image"])})
        return feed_batch, fetch

    def postprocess(self, feed=None, fetch=None, fetch_map=None):
        """Reduce each row of scores to its top-1 label and score.

        Args:
            feed: unused; accepted for interface compatibility.
            fetch: unused; accepted for interface compatibility.
            fetch_map: mapping that must contain ``"score"``, an iterable of
                per-instance score rows (presumably numpy arrays, since
                ``.tolist()`` is called on each row — confirm).

        Returns:
            ``{"label": [...], "prob": [...]}`` with one entry per row: the
            label of the highest-scoring index (with commas removed) and
            that highest score.

        Raises:
            KeyError: if ``fetch_map`` has no ``"score"`` entry.
        """
        fetch_map = {} if fetch_map is None else fetch_map
        result = {"label": [], "prob": []}
        for score in fetch_map["score"]:
            score = score.tolist()
            max_score = max(score)
            # First index holding the maximum score; ties resolve to the
            # lowest index.
            label = self.label_dict[score.index(max_score)]
            result["label"].append(label.strip().replace(",", ""))
            result["prob"].append(max_score)
        return result


# Parse the port before doing any expensive setup, so a malformed value fails
# immediately with a readable message instead of a traceback after the model
# has been loaded.
try:
    port = int(sys.argv[3])
except ValueError:
    print("port must be an integer, got: {}".format(sys.argv[3]))
    sys.exit(-1)

image_service = ImageService(name="image")
image_service.load_model_config(sys.argv[1])
image_service.init_imagenet_setting()
if device == "gpu":
    # Hard-coded GPU id "0" (presumably the first card — confirm with the
    # serving docs).
    image_service.set_gpus("0")
image_service.prepare_server(workdir="workdir", port=port, device=device)
image_service.run_rpc_service()
image_service.run_web_service()