# local_service_handler.py
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
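
"""Helpers for serving a model locally: launch one Paddle Serving RPC
server process per device, or hand out an in-process LocalPredictor."""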

import os
import logging
import multiprocessing
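
# Prefer the GPU build of paddle-serving-server; fall back to the CPU build
# and record which flavor is installed so __init__ can reject GPU device
# requests under a CPU-only install.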
try:
    from paddle_serving_server_gpu import OpMaker, OpSeqMaker, Server
    PACKAGE_VERSION = "GPU"
except ImportError:
    from paddle_serving_server import OpMaker, OpSeqMaker, Server
    PACKAGE_VERSION = "CPU"
from . import util
from paddle_serving_app.local_predict import LocalPredictor

_LOGGER = logging.getLogger(__name__)
_workdir_name_gen = util.NameGenerator("workdir_")


class LocalServiceHandler(object):
    def __init__(self,
                 model_config,
                 client_type='local_predictor',
                 workdir="",
                 thread_num=2,
                 devices="",
                 mem_optim=True,
                 ir_optim=False,
                 available_port_generator=None):
        if available_port_generator is None:
            available_port_generator = util.GetAvailablePortGenerator()

        self._model_config = model_config
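        # Reserve one free port per device: a single port when running on
        # CPU (devices == ""), otherwise one port per GPU id in the
        # comma-separated `devices` string.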
        self._port_list = []
        if devices == "":
            # cpu
            devices = [-1]
            self._port_list.append(available_port_generator.next())
            _LOGGER.info("Model({}) will be launch in cpu device. Port({})"
                         .format(model_config, self._port_list))
        else:
            # gpu
            if PACKAGE_VERSION == "CPU":
                raise ValueError(
                    "You are using the CPU version package "
                    "(paddle-serving-server), so GPU devices cannot be set")
            devices = [int(x) for x in devices.split(",")]
            for _ in devices:
                self._port_list.append(available_port_generator.next())
            _LOGGER.info("Model({}) will be launch in gpu device: {}. Port({})"
                         .format(model_config, devices, self._port_list))
        self.client_type = client_type
        self._workdir = workdir
        self._devices = devices
        self._thread_num = thread_num
        self._mem_optim = mem_optim
        self._ir_optim = ir_optim
        self.local_predictor_client = None
        self._rpc_service_list = []
        self._server_pros = []
        self._fetch_vars = None

    def get_fetch_list(self):
        return self._fetch_vars

    def get_port_list(self):
        return self._port_list

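    # Only used when client_type == 'local_predictor': lazily creates a
    # single in-process LocalPredictor (the model config is loaded in CPU
    # mode, gpu=False).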
    def get_client(self):  # for local_predictor only
        if self.local_predictor_client is None:
            self.local_predictor_client = LocalPredictor()
            self.local_predictor_client.load_model_config(
                "{}".format(self._model_config), gpu=False, profile=False)
        return self.local_predictor_client

    def get_client_config(self):
        return os.path.join(self._model_config, "serving_server_conf.prototxt")

    def _prepare_one_server(self, workdir, port, gpuid, thread_num, mem_optim,
                            ir_optim):
        device = "gpu"
        if gpuid == -1:
            device = "cpu"
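        # Build the standard single-model op graph:
        # general_reader -> general_infer -> general_response.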
        op_maker = OpMaker()
        read_op = op_maker.create('general_reader')
        general_infer_op = op_maker.create('general_infer')
        general_response_op = op_maker.create('general_response')

        op_seq_maker = OpSeqMaker()
        op_seq_maker.add_op(read_op)
        op_seq_maker.add_op(general_infer_op)
        op_seq_maker.add_op(general_response_op)

        server = Server()
        server.set_op_sequence(op_seq_maker.get_op_sequence())
        server.set_num_threads(thread_num)
        server.set_memory_optimize(mem_optim)
        server.set_ir_optimize(ir_optim)

        server.load_model_config(self._model_config)
        if gpuid >= 0:
            server.set_gpuid(gpuid)
        server.prepare_server(workdir=workdir, port=port, device=device)
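        # Cache the fetch-variable list from the first prepared server; every
        # replica loads the same model config, so the list is identical.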
        if self._fetch_vars is None:
            self._fetch_vars = server.get_fetch_list()
        return server

    def _start_one_server(self, service_idx):
        self._rpc_service_list[service_idx].run_server()

    def prepare_server(self):
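        # Prepare one RPC server per device; workdir, port, and device id
        # are matched up by index.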
        for i, device_id in enumerate(self._devices):
            if self._workdir != "":
                workdir = "{}_{}".format(self._workdir, i)
            else:
                workdir = _workdir_name_gen.next()
            self._rpc_service_list.append(
                self._prepare_one_server(
                    workdir,
                    self._port_list[i],
                    device_id,
                    thread_num=self._thread_num,
                    mem_optim=self._mem_optim,
                    ir_optim=self._ir_optim))

    def start_server(self):
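        # Launch every prepared server in a daemon child process so the
        # servers are torn down together with the parent process.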
        for i, service in enumerate(self._rpc_service_list):
            p = multiprocessing.Process(
                target=self._start_one_server, args=(i, ))
            p.daemon = True
            self._server_pros.append(p)
        for p in self._server_pros:
            p.start()
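
# Minimal usage sketch (illustrative only; "./serving_server" stands in for
# a real serving model directory and is not defined by this module):
#
#   handler = LocalServiceHandler(model_config="./serving_server",
#                                 devices="0,1")  # devices="" for CPU
#   handler.prepare_server()   # one server per device, each on a free port
#   handler.start_server()     # launched as daemon subprocesses
#   print(handler.get_port_list())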