web_service.py 13.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#!flask/bin/python
B
barrierye 已提交
15 16
# pylint: disable=doc-string-missing

17
from flask import Flask, request, abort
M
MRXLT 已提交
18
from contextlib import closing
Z
zhangjun 已提交
19 20
from multiprocessing import Pool, Process, Queue
from paddle_serving_client import Client
Z
zhangjun 已提交
21 22
from paddle_serving_server import OpMaker, OpSeqMaker, Server
from paddle_serving_server.serve import start_multi_card
M
MRXLT 已提交
23
import socket
Z
zhangjun 已提交
24
import sys
W
wangjiawei04 已提交
25
import numpy as np
H
HexToString 已提交
26
import os
Z
zhangjun 已提交
27 28
from paddle_serving_server import pipeline
from paddle_serving_server.pipeline import Op
B
barrierye 已提交
29

H
HexToString 已提交
30

H
HexToString 已提交
31 32 33 34 35 36 37 38 39
def port_is_available(port):
    """Return True when nothing on this host is accepting connections on `port`.

    Probes the port with a short TCP connect: a failed connect (non-zero
    `connect_ex` result) means the port is free for our server to bind.
    """
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as probe:
        probe.settimeout(2)
        rc = probe.connect_ex(('0.0.0.0', port))
    return rc != 0

H
HexToString 已提交
40

41 42 43
class WebService(object):
    """HTTP/RPC serving wrapper around a Paddle Serving model.

    Two usage modes coexist:
      * pipeline mode (`prepare_pipeline_config` + `run_service`) — current API;
      * legacy flask + brpc mode (`prepare_server`, `run_rpc_service`,
        `run_web_service`, ...) — kept for compatibility, prints deprecation
        notices.
    """

    def __init__(self, name="default_service"):
        self.name = name
        # pipeline
        self._server = pipeline.PipelineServer(self.name)

        self.gpus = []  # deprecated
        self.rpc_service_list = []  # deprecated

    def get_pipeline_response(self, read_op):
        """Hook for subclasses: build the op graph and return the last Op.

        The base implementation returns None, which makes
        `prepare_pipeline_config` raise — subclasses must override this.
        """
        return None

    def prepare_pipeline_config(self, yaml_file):
        """Build the request -> user ops -> response DAG and load `yaml_file`."""
        # build dag
        read_op = pipeline.RequestOp()
        last_op = self.get_pipeline_response(read_op)
        if not isinstance(last_op, Op):
            raise ValueError("The return value type of `get_pipeline_response` "
                             "function is not Op type, please check function "
                             "`get_pipeline_response`.")
        response_op = pipeline.ResponseOp(input_ops=[last_op])
        self._server.set_response_op(response_op)
        self._server.prepare_server(yaml_file)

    def run_service(self):
        """Start the pipeline server (blocking)."""
        self._server.run_server()

    def load_model_config(self,
                          server_config_dir_paths,
                          client_config_path=None):
        """Load model config(s) and cache feed/fetch variable metadata.

        Args:
            server_config_dir_paths: one model dir path (str) or a list of
                dirs for chained models (e.g. detection + recognition).
            client_config_path: optional explicit client config path; when
                None, the server prototxt paths are reused for the client.

        Raises:
            ValueError: if any given path is a file instead of a directory.
        """
        if isinstance(server_config_dir_paths, str):
            server_config_dir_paths = [server_config_dir_paths]

        for single_model_config in server_config_dir_paths:
            if os.path.isfile(single_model_config):
                raise ValueError(
                    "The input of --model should be a dir not file.")
        self.server_config_dir_paths = server_config_dir_paths
        from .proto import general_model_config_pb2 as m_config
        import google.protobuf.text_format
        file_path_list = [
            "{}/serving_server_conf.prototxt".format(single_model_config)
            for single_model_config in self.server_config_dir_paths
        ]

        # Feed vars come from the FIRST model of the chain.
        # Fix: use `with` so the prototxt handles are closed (they leaked before).
        model_conf = m_config.GeneralModelConfig()
        with open(file_path_list[0], 'r') as f:
            model_conf = google.protobuf.text_format.Merge(
                str(f.read()), model_conf)
        self.feed_vars = {var.name: var for var in model_conf.feed_var}

        # Fetch vars come from the LAST model of the chain (if more than one).
        if len(file_path_list) > 1:
            model_conf = m_config.GeneralModelConfig()
            with open(file_path_list[-1], 'r') as f:
                model_conf = google.protobuf.text_format.Merge(
                    str(f.read()), model_conf)

        self.fetch_vars = {var.name: var for var in model_conf.fetch_var}
        if client_config_path is None:
            self.client_config_path = file_path_list
        else:
            # Fix: a user-supplied client_config_path used to be silently
            # dropped, leaving self.client_config_path unset.
            self.client_config_path = client_config_path

    def set_gpus(self, gpus):
        """Deprecated. Parse a comma-separated GPU-id string into self.gpus."""
        print("This API will be deprecated later. Please do not use it")
        self.gpus = [int(x) for x in gpus.split(",")]

    def default_rpc_service(self,
                            workdir="conf",
                            port=9292,
                            gpuid=0,
                            thread_num=2,
                            mem_optim=True,
                            use_lite=False,
                            use_xpu=False,
                            ir_optim=False,
                            precision="fp32",
                            use_calib=False):
        """Deprecated. Build and prepare one brpc Server instance.

        gpuid == -1 selects CPU (or ARM when use_lite); gpuid >= 0 pins
        the server to that GPU.
        """
        device = "gpu"
        if gpuid == -1:
            if use_lite:
                device = "arm"
            else:
                device = "cpu"
        op_maker = OpMaker()
        op_seq_maker = OpSeqMaker()

        read_op = op_maker.create('general_reader')
        op_seq_maker.add_op(read_op)

        for idx, single_model in enumerate(self.server_config_dir_paths):
            # Two chained models is the detection + recognition layout:
            # the first op is the detector, the rest are plain infer ops.
            if len(self.server_config_dir_paths) == 2 and idx == 0:
                infer_op_name = "general_detection"
            else:
                infer_op_name = "general_infer"
            general_infer_op = op_maker.create(infer_op_name)
            op_seq_maker.add_op(general_infer_op)

        general_response_op = op_maker.create('general_response')
        op_seq_maker.add_op(general_response_op)

        server = Server()
        server.set_op_sequence(op_seq_maker.get_op_sequence())
        server.set_num_threads(thread_num)
        server.set_memory_optimize(mem_optim)
        server.set_ir_optimize(ir_optim)
        server.set_device(device)
        server.set_precision(precision)
        server.set_use_calib(use_calib)

        if use_lite:
            server.set_lite()
        if use_xpu:
            server.set_xpu()

        server.load_model_config(self.server_config_dir_paths
                                 )  #brpc Server support server_config_dir_paths
        if gpuid >= 0:
            server.set_gpuid(gpuid)
        server.prepare_server(workdir=workdir, port=port, device=device)
        return server

    def _launch_rpc_service(self, service_idx):
        # Blocking call; intended to run inside a child Process.
        self.rpc_service_list[service_idx].run_server()

    def create_rpc_config(self):
        """Deprecated. Populate rpc_service_list: one CPU server, or one
        server per configured GPU id."""
        if len(self.gpus) == 0:
            # init cpu service
            self.rpc_service_list.append(
                self.default_rpc_service(
                    self.workdir,
                    self.port_list[0],
                    -1,
                    thread_num=self.thread_num,
                    mem_optim=self.mem_optim,
                    use_lite=self.use_lite,
                    use_xpu=self.use_xpu,
                    ir_optim=self.ir_optim,
                    precision=self.precision,
                    use_calib=self.use_calib))
        else:
            for i, gpuid in enumerate(self.gpus):
                self.rpc_service_list.append(
                    self.default_rpc_service(
                        "{}_{}".format(self.workdir, i),
                        self.port_list[i],
                        gpuid,
                        thread_num=self.thread_num,
                        mem_optim=self.mem_optim,
                        use_lite=self.use_lite,
                        use_xpu=self.use_xpu,
                        ir_optim=self.ir_optim,
                        precision=self.precision,
                        use_calib=self.use_calib))

    def prepare_server(self,
                       workdir="",
                       port=9393,
                       device="gpu",
                       precision="fp32",
                       use_calib=False,
                       use_lite=False,
                       use_xpu=False,
                       ir_optim=False,
                       gpuid=0,
                       thread_num=2,
                       mem_optim=True):
        """Deprecated. Record server options and reserve one free local port
        per backend (len(self.gpus) + 1 ports, starting at 12000)."""
        print("This API will be deprecated later. Please do not use it")
        self.workdir = workdir
        self.port = port
        self.thread_num = thread_num
        self.device = device
        self.precision = precision
        self.use_calib = use_calib
        self.use_lite = use_lite
        self.use_xpu = use_xpu
        self.ir_optim = ir_optim
        self.mem_optim = mem_optim
        self.gpuid = gpuid
        self.port_list = []
        default_port = 12000
        for i in range(1000):
            if port_is_available(default_port + i):
                self.port_list.append(default_port + i)
            if len(self.port_list) > len(self.gpus):
                break

    def _launch_web_service(self):
        # Connect a client to every launched rpc backend; with GPUs the
        # endpoint list is comma-joined so the client load-balances.
        gpu_num = len(self.gpus)
        self.client = Client()
        self.client.load_client_config(self.client_config_path)
        endpoints = ""
        if gpu_num > 0:
            for i in range(gpu_num):
                endpoints += "127.0.0.1:{},".format(self.port_list[i])
        else:
            endpoints = "127.0.0.1:{}".format(self.port_list[0])
        self.client.connect([endpoints])

    def get_prediction(self, request):
        """Handle one flask POST: preprocess -> predict -> postprocess.

        Returns a dict {"result": ...}; ValueError from any stage is
        reported as {"result": "<message>"} instead of a 500.
        """
        if not request.json:
            abort(400)
        if "fetch" not in request.json:
            abort(400)
        try:
            feed, fetch, is_batch = self.preprocess(request.json["feed"],
                                                    request.json["fetch"])
            if isinstance(feed, dict) and "fetch" in feed:
                del feed["fetch"]
            if len(feed) == 0:
                raise ValueError("empty input")
            fetch_map = self.client.predict(
                feed=feed, fetch=fetch, batch=is_batch)
            result = self.postprocess(
                feed=request.json["feed"], fetch=fetch, fetch_map=fetch_map)
            result = {"result": result}
        except ValueError as err:
            result = {"result": str(err)}
        return result

    def _build_flask_app(self, init_fn):
        # Shared flask wiring for run_rpc_service / run_debugger_service:
        # init_fn runs once before the first request, and
        # POST /<name>/prediction dispatches to get_prediction.
        app_instance = Flask(__name__)

        @app_instance.before_first_request
        def init():
            init_fn()

        service_name = "/" + self.name + "/prediction"

        @app_instance.route(service_name, methods=["POST"])
        def run():
            return self.get_prediction(request)

        self.app_instance = app_instance

    def run_rpc_service(self):
        """Deprecated. Fork one brpc server per backend and build the flask app."""
        print("This API will be deprecated later. Please do not use it")
        import socket
        localIP = socket.gethostbyname(socket.gethostname())
        print("web service address:")
        print("http://{}:{}/{}/prediction".format(localIP, self.port,
                                                  self.name))
        server_pros = []
        self.create_rpc_config()
        for i, service in enumerate(self.rpc_service_list):
            p = Process(target=self._launch_rpc_service, args=(i, ))
            server_pros.append(p)
        for p in server_pros:
            p.start()

        self._build_flask_app(self._launch_web_service)

    # TODO: maybe change another API name: maybe run_local_predictor?
    def run_debugger_service(self, gpu=False):
        """Deprecated. Build the flask app backed by an in-process LocalPredictor."""
        print("This API will be deprecated later. Please do not use it")
        import socket
        localIP = socket.gethostbyname(socket.gethostname())
        print("web service address:")
        print("http://{}:{}/{}/prediction".format(localIP, self.port,
                                                  self.name))
        self._build_flask_app(lambda: self._launch_local_predictor(gpu))

    def _launch_local_predictor(self, gpu):
        # actually, LocalPredictor is like a server, but it is WebService Request initiator
        # for WebService it is a Client.
        # local_predictor only support single-Model DirPath - Type:str
        # so the input must be self.server_config_dir_paths[0]
        from paddle_serving_app.local_predict import LocalPredictor
        self.client = LocalPredictor()
        if gpu:
            # if user forget to call function `set_gpus` to set self.gpus.
            # default self.gpus = [0].
            if len(self.gpus) == 0:
                self.gpus.append(0)
            self.client.load_model_config(
                self.server_config_dir_paths[0],
                use_gpu=True,
                gpu_id=self.gpus[0])
        else:
            self.client.load_model_config(
                self.server_config_dir_paths[0], use_gpu=False)

    def run_web_service(self):
        """Deprecated. Run the flask app (blocking, threaded)."""
        print("This API will be deprecated later. Please do not use it")
        self.app_instance.run(host="0.0.0.0", port=self.port, threaded=True)

    def get_app_instance(self):
        return self.app_instance

    def preprocess(self, feed=[], fetch=[]):
        """Deprecated default preprocess: stack each feed var into a batched
        ndarray of shape (batch, *var.shape). Returns (feed, fetch, is_batch)."""
        print("This API will be deprecated later. Please do not use it")
        is_batch = True
        feed_dict = {}
        for var_name in self.feed_vars.keys():
            feed_dict[var_name] = []
        for feed_ins in feed:
            for key in feed_ins:
                feed_dict[key].append(
                    np.array(feed_ins[key]).reshape(
                        list(self.feed_vars[key].shape))[np.newaxis, :])
        feed = {}
        for key in feed_dict:
            feed[key] = np.concatenate(feed_dict[key], axis=0)
        return feed, fetch, is_batch

    def postprocess(self, feed=[], fetch=[], fetch_map=None):
        """Deprecated default postprocess: make fetch_map JSON-serializable
        by converting each ndarray value to a nested list (in place)."""
        print("This API will be deprecated later. Please do not use it")
        for key in fetch_map:
            fetch_map[key] = fetch_map[key].tolist()
        return fetch_map