__init__.py 31.6 KB
Newer Older
M
MRXLT 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
B
barrierye 已提交
14
# pylint: disable=doc-string-missing
M
MRXLT 已提交
15 16 17 18 19 20

import os
from .proto import server_configure_pb2 as server_sdk
from .proto import general_model_config_pb2 as m_config
import google.protobuf.text_format
import tarfile
M
MRXLT 已提交
21
import socket
22
import paddle_serving_server_gpu as paddle_serving_server
23
import time
24
from .version import serving_server_version
M
MRXLT 已提交
25
from contextlib import closing
G
guru4elephant 已提交
26
import argparse
B
barrierye 已提交
27
import collections
M
MRXLT 已提交
28
import fcntl
M
MRXLT 已提交
29
import shutil
B
barrierye 已提交
30 31 32
import numpy as np
import grpc
from .proto import multi_lang_general_model_service_pb2
B
barrierye 已提交
33 34 35
import sys
sys.path.append(
    os.path.join(os.path.abspath(os.path.dirname(__file__)), 'proto'))
B
barrierye 已提交
36 37 38 39
from .proto import multi_lang_general_model_service_pb2_grpc
from multiprocessing import Pool, Process
from concurrent import futures

B
barrierye 已提交
40

41 42 43
def serve_args(argv=None):
    """Parse the command-line options of the GPU ``serve`` entry point.

    Args:
        argv: list of argument strings to parse. ``None`` (the default) lets
            argparse read ``sys.argv[1:]``, exactly as before this parameter
            existed; passing a list allows programmatic launch and testing.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser("serve")
    parser.add_argument(
        "--thread", type=int, default=2, help="Concurrency of server")
    parser.add_argument(
        "--model", type=str, default="", help="Model for serving")
    parser.add_argument(
        "--port", type=int, default=9292, help="Port of the starting gpu")
    parser.add_argument(
        "--workdir",
        type=str,
        default="workdir",
        help="Working dir of current service")
    parser.add_argument(
        "--device", type=str, default="gpu", help="Type of device")
    parser.add_argument("--gpu_ids", type=str, default="", help="gpu ids")
    # NOTE: the default is the *string* "None", not the None object.
    parser.add_argument(
        "--name", type=str, default="None", help="Default service name")
    parser.add_argument(
        "--mem_optim_off",
        default=False,
        action="store_true",
        help="Memory optimize")
    parser.add_argument(
        "--ir_optim", default=False, action="store_true", help="Graph optimize")
    parser.add_argument(
        "--max_body_size",
        type=int,
        default=512 * 1024 * 1024,
        help="Limit sizes of messages")
    parser.add_argument(
        "--use_multilang",
        default=False,
        action="store_true",
        help="Use Multi-language-service")
    parser.add_argument(
        "--use_trt", default=False, action="store_true", help="Use TensorRT")
    parser.add_argument(
        "--product_name",
        type=str,
        default=None,
        help="product_name for authentication")
    parser.add_argument(
        "--container_id",
        type=str,
        default=None,
        help="container_id for authentication")
    return parser.parse_args(argv)
M
MRXLT 已提交
89

B
barrierye 已提交
90

M
MRXLT 已提交
91 92 93
class OpMaker(object):
    """Create serialized ``DAGNode`` descriptions for supported Serving ops."""

    def __init__(self):
        # Python-side op name -> op type string stored in ``DAGNode.type``.
        self.op_dict = {
            "general_infer": "GeneralInferOp",
            "general_reader": "GeneralReaderOp",
            "general_response": "GeneralResponseOp",
            "general_text_reader": "GeneralTextReaderOp",
            "general_text_response": "GeneralTextResponseOp",
            "general_single_kv": "GeneralSingleKVOp",
            "general_dist_kv_infer": "GeneralDistKVInferOp",
            "general_dist_kv": "GeneralDistKVOp"
        }
        # Per-op-type counter used to build unique default node names.
        self.node_name_suffix_ = collections.defaultdict(int)

    def create(self, node_type, engine_name=None, inputs=None, outputs=None):
        """Return a text-format ``DAGNode`` for an op of type ``node_type``.

        Args:
            node_type: a key of ``self.op_dict``.
            engine_name: explicit node name. When falsy, the name
                ``'<node_type>_<n>'`` is generated with a per-type counter.
            inputs: optional iterable of node strings previously returned by
                ``create``; each one becomes a read-only ("RO") dependency.
            outputs: accepted for API compatibility; not used.

        Returns:
            The node serialized with ``text_format.MessageToString``.

        Raises:
            Exception: if ``node_type`` is not in ``self.op_dict``.
        """
        if node_type not in self.op_dict:
            raise Exception("Op type {} is not supported right now".format(
                node_type))
        node = server_sdk.DAGNode()
        # node.name will be used as the infer engine name
        if engine_name:
            node.name = engine_name
        else:
            node.name = '{}_{}'.format(node_type,
                                       self.node_name_suffix_[node_type])
            self.node_name_suffix_[node_type] += 1

        node.type = self.op_dict[node_type]
        # Only the name of each input node is needed to declare the dependency.
        for dep_node_str in inputs or ():
            dep_node = server_sdk.DAGNode()
            google.protobuf.text_format.Parse(dep_node_str, dep_node)
            dep = server_sdk.DAGNodeDependency()
            dep.name = dep_node.name
            dep.mode = "RO"
            node.dependencies.extend([dep])
        # Because the return value will be used as the key value of the
        # dict, and the proto object is variable which cannot be hashed,
        # so it is processed into a string. This has little effect on
        # overall efficiency.
        return google.protobuf.text_format.MessageToString(node)
M
MRXLT 已提交
132 133 134 135 136 137 138 139


class OpSeqMaker(object):
    """Assemble ops into one linear workflow named "workflow1".

    Every op added after the first one depends on the op added just before
    it; an explicit dependency is only accepted if it names that op.
    """

    def __init__(self):
        workflow = server_sdk.Workflow()
        workflow.name = "workflow1"
        workflow.workflow_type = "Sequence"
        self.workflow = workflow

    def add_op(self, node_str):
        """Append the op serialized in ``node_str`` to the end of the sequence.

        Raises:
            Exception: if the op declares more than one dependency, or its
                single dependency is not the previously added op.
        """
        node = server_sdk.DAGNode()
        google.protobuf.text_format.Parse(node_str, node)
        dep_count = len(node.dependencies)
        if dep_count > 1:
            raise Exception(
                'Set more than one predecessor for op in OpSeqMaker is not allowed.'
            )
        if len(self.workflow.nodes) > 0:
            previous_name = self.workflow.nodes[-1].name
            if dep_count == 0:
                # Implicitly chain to the previously added op.
                link = server_sdk.DAGNodeDependency()
                link.name = previous_name
                link.mode = "RO"
                node.dependencies.extend([link])
            elif node.dependencies[0].name != previous_name:
                raise Exception(
                    'You must add op in order in OpSeqMaker. The previous op is {}, but the current op is followed by {}.'
                    .format(node.dependencies[0].name, previous_name))
        self.workflow.nodes.extend([node])

    def get_op_sequence(self):
        """Wrap the workflow in a new ``WorkflowConf`` and return it."""
        conf = server_sdk.WorkflowConf()
        conf.workflows.extend([self.workflow])
        return conf


B
barrierye 已提交
167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184
class OpGraphMaker(object):
    """Collect ops into workflow "workflow1", keeping their own dependencies."""

    def __init__(self):
        wf = server_sdk.Workflow()
        wf.name = "workflow1"
        # Currently, SDK only supports "Sequence"
        wf.workflow_type = "Sequence"
        self.workflow = wf

    def add_op(self, node_str):
        """Parse ``node_str`` into a ``DAGNode`` and append it unchanged."""
        parsed = server_sdk.DAGNode()
        google.protobuf.text_format.Parse(node_str, parsed)
        self.workflow.nodes.extend([parsed])

    def get_op_graph(self):
        """Wrap the workflow in a new ``WorkflowConf`` and return it."""
        conf = server_sdk.WorkflowConf()
        conf.workflows.extend([self.workflow])
        return conf


M
MRXLT 已提交
185 186 187 188 189 190 191
class Server(object):
    """Configure and launch a Paddle Serving server binary (GPU build).

    Typical flow: ``set_op_sequence``/``set_op_graph`` -> ``load_model_config``
    -> ``prepare_server`` (writes the prototxt configs into a work directory)
    -> ``run_server`` (runs the serving binary through ``os.system``).
    """

    def __init__(self):
        self.server_handle_ = None
        self.infer_service_conf = None
        self.model_toolkit_conf = None
        self.resource_conf = None
        self.memory_optimization = False
        self.ir_optimization = False
        self.model_conf = None
        # Names of the config files generated inside the work directory.
        self.workflow_fn = "workflow.prototxt"
        self.resource_fn = "resource.prototxt"
        self.infer_service_fn = "infer_service.prototxt"
        self.model_toolkit_fn = "model_toolkit.prototxt"
        self.general_model_config_fn = "general_model.prototxt"
        self.cube_config_fn = "cube.conf"
        self.workdir = ""
        self.max_concurrency = 0
        self.num_threads = 2
        self.port = 8080
        self.reload_interval_s = 10
        # Default limit; set_max_body_size only ever raises it.
        self.max_body_size = 64 * 1024 * 1024
        self.module_path = os.path.dirname(paddle_serving_server.__file__)
        self.cur_path = os.getcwd()
        self.use_local_bin = False
        self.gpuid = 0
        self.use_trt = False
        self.model_config_paths = None  # for multi-model in a workflow
        self.product_name = None
        self.container_id = None

    @staticmethod
    def _touch(path):
        """Create ``path`` if it is missing and set its mtime to now."""
        with open(path, "a"):
            os.utime(path, None)

    @staticmethod
    def _remove_path(path):
        """Remove the file or directory tree at ``path`` if it exists."""
        if os.path.isdir(path):
            shutil.rmtree(path, ignore_errors=True)
        elif os.path.exists(path):
            os.remove(path)

    def get_fetch_list(self):
        """Return the alias names of all fetch variables of the loaded model."""
        fetch_names = [var.alias_name for var in self.model_conf.fetch_var]
        return fetch_names

    def set_max_concurrency(self, concurrency):
        self.max_concurrency = concurrency

    def set_num_threads(self, threads):
        self.num_threads = threads

    def set_max_body_size(self, body_size):
        """Raise the message size limit; values below the current one are ignored."""
        if body_size >= self.max_body_size:
            self.max_body_size = body_size
        else:
            print(
                "max_body_size is less than default value, will use default value in service."
            )

    def set_port(self, port):
        self.port = port

    def set_reload_interval(self, interval):
        self.reload_interval_s = interval

    def set_op_sequence(self, op_seq):
        self.workflow_conf = op_seq

    def set_op_graph(self, op_graph):
        self.workflow_conf = op_graph

    def set_memory_optimize(self, flag=False):
        self.memory_optimization = flag

    def set_ir_optimize(self, flag=False):
        self.ir_optimization = flag

    def set_product_name(self, product_name=None):
        """Set the authentication product name.

        Raises:
            ValueError: if ``product_name`` is None.
        """
        if product_name is None:
            raise ValueError("product_name can't be None.")
        self.product_name = product_name

    def set_container_id(self, container_id):
        """Set the authentication container id.

        Raises:
            ValueError: if ``container_id`` is None.
        """
        if container_id is None:
            raise ValueError("container_id can't be None.")
        self.container_id = container_id

    def check_local_bin(self):
        """Use the binary named by ``$SERVING_BIN`` instead of downloading one."""
        if "SERVING_BIN" in os.environ:
            self.use_local_bin = True
            self.bin_path = os.environ["SERVING_BIN"]

    def check_cuda(self):
        """Exit unless some entry under /dev/ has "nvidia" in its name."""
        try:
            dev_entries = os.listdir("/dev/")
        except OSError:
            dev_entries = []
        if not any("nvidia" in entry for entry in dev_entries):
            raise SystemExit(
                "GPU not found, please check your environment or use cpu version by \"pip install paddle_serving_server\""
            )

    def set_gpuid(self, gpuid=0):
        self.gpuid = gpuid

    def set_trt(self):
        self.use_trt = True

    def _prepare_engine(self, model_config_paths, device):
        """Append one ``EngineDesc`` per ``{engine_name: model_dir}`` entry.

        ``device`` "cpu"/"gpu" selects the engine type; any other value leaves
        ``engine.type`` unset.
        """
        if self.model_toolkit_conf is None:
            self.model_toolkit_conf = server_sdk.ModelToolkitConf()

        for engine_name, model_config_path in model_config_paths.items():
            engine = server_sdk.EngineDesc()
            engine.name = engine_name
            # engine.reloadable_meta = model_config_path + "/fluid_time_file"
            engine.reloadable_meta = self.workdir + "/fluid_time_file"
            self._touch(engine.reloadable_meta)
            engine.reloadable_type = "timestamp_ne"
            engine.runtime_thread_num = 0
            engine.batch_infer_size = 0
            engine.enable_batch_align = 0
            engine.model_data_path = model_config_path
            engine.enable_memory_optimization = self.memory_optimization
            engine.enable_ir_optimization = self.ir_optimization
            engine.static_optimization = False
            engine.force_update_static_cache = False
            engine.use_trt = self.use_trt

            if device == "cpu":
                engine.type = "FLUID_CPU_ANALYSIS_DIR"
            elif device == "gpu":
                engine.type = "FLUID_GPU_ANALYSIS_DIR"

            self.model_toolkit_conf.engines.extend([engine])

    def _prepare_infer_service(self, port):
        """Build the InferServiceConf (service "GeneralModelService") once."""
        if self.infer_service_conf is None:
            self.infer_service_conf = server_sdk.InferServiceConf()
            self.infer_service_conf.port = port
            infer_service = server_sdk.InferService()
            infer_service.name = "GeneralModelService"
            infer_service.workflows.extend(["workflow1"])
            self.infer_service_conf.services.extend([infer_service])

    def _prepare_resource(self, workdir, cube_conf):
        """Write the model config and build the ResourceConf once.

        If any workflow node name contains "dist_kv", ``cube_conf`` is copied
        into ``workdir`` and must therefore be given.

        Raises:
            ValueError: if a dist_kv op is used but ``cube_conf`` is None.
        """
        self.workdir = workdir
        if self.resource_conf is None:
            with open("{}/{}".format(workdir, self.general_model_config_fn),
                      "w") as fout:
                fout.write(str(self.model_conf))
            self.resource_conf = server_sdk.ResourceConf()
            for workflow in self.workflow_conf.workflows:
                for node in workflow.nodes:
                    if "dist_kv" in node.name:
                        self.resource_conf.cube_config_path = workdir
                        self.resource_conf.cube_config_file = self.cube_config_fn
                        if cube_conf is None:
                            raise ValueError(
                                "Please set the path of cube.conf while use dist_kv op."
                            )
                        shutil.copy(cube_conf, workdir)
            self.resource_conf.model_toolkit_path = workdir
            self.resource_conf.model_toolkit_file = self.model_toolkit_fn
            self.resource_conf.general_model_path = workdir
            self.resource_conf.general_model_file = self.general_model_config_fn
            if self.product_name is not None:
                self.resource_conf.auth_product_name = self.product_name
            if self.container_id is not None:
                self.resource_conf.auth_container_id = self.container_id

    def _write_pb_str(self, filepath, pb_obj):
        """Write ``str(pb_obj)`` (protobuf text format) to ``filepath``."""
        with open(filepath, "w") as fout:
            fout.write(str(pb_obj))

    def load_model_config(self, model_config_paths):
        """Register model directories and load the workflow's I/O config.

        Args:
            model_config_paths: either one model directory (``str``), bound to
                the first default infer engine name found in the workflow, or
                a dict mapping node strings (from ``OpMaker.create``) to model
                directories.

        Raises:
            Exception: if a ``str`` is given but the workflow has no default
                infer engine name, or if the argument has another type.
        """
        # At present, Serving needs to configure the model path in
        # the resource.prototxt file to determine the input and output
        # format of the workflow. To ensure that the input and output
        # of multiple models are the same.
        workflow_oi_config_path = None
        if isinstance(model_config_paths, str):
            # If there is only one model path, use the default infer_op.
            # Because there are several infer_op type, we need to find
            # it from workflow_conf.
            default_engine_names = [
                'general_infer_0', 'general_dist_kv_infer_0',
                'general_dist_kv_quant_infer_0'
            ]
            engine_name = None
            for node in self.workflow_conf.workflows[0].nodes:
                if node.name in default_engine_names:
                    engine_name = node.name
                    break
            if engine_name is None:
                raise Exception(
                    "You have set the engine_name of Op. Please use the form {op: model_path} to configure model path"
                )
            self.model_config_paths = {engine_name: model_config_paths}
            workflow_oi_config_path = self.model_config_paths[engine_name]
        elif isinstance(model_config_paths, dict):
            self.model_config_paths = {}
            for node_str, path in model_config_paths.items():
                node = server_sdk.DAGNode()
                google.protobuf.text_format.Parse(node_str, node)
                self.model_config_paths[node.name] = path
            print("You have specified multiple model paths, please ensure "
                  "that the input and output of multiple models are the same.")
            workflow_oi_config_path = list(self.model_config_paths.items())[0][
                1]
        else:
            raise Exception("The type of model_config_paths must be str or "
                            "dict({op: model_path}), not {}.".format(
                                type(model_config_paths)))

        self.model_conf = m_config.GeneralModelConfig()
        with open(
                "{}/serving_server_conf.prototxt".format(
                    workflow_oi_config_path), 'r') as f:
            self.model_conf = google.protobuf.text_format.Merge(
                str(f.read()), self.model_conf)
        # check config here
        # print config here

    def download_bin(self):
        """Make sure the prebuilt serving binary is present and set ``bin_path``.

        The package variant comes from the last ``cuda_version`` line of
        ``version.py`` in the module directory. An exclusive ``flock`` on that
        file serializes concurrent callers; it is released when the file is
        closed on leaving the ``with`` block. A ``<folder>.is_download`` marker
        is created before downloading so other processes skip the download;
        it is removed again if the download or extraction fails, so a later
        run retries instead of waiting forever in ``run_server``.

        Raises:
            SystemExit: if no ``cuda_version`` line exists, or the download or
                decompression fails.
        """
        import re
        os.chdir(self.module_path)
        try:
            with open("{}/version.py".format(self.module_path),
                      "r") as version_file:
                device_version = None
                for line in version_file.readlines():
                    if re.match("cuda_version", line):
                        cuda_version = line.split("\"")[1]
                        if cuda_version != "trt":
                            device_version = "serving-gpu-cuda" + cuda_version + "-"
                        else:
                            device_version = "serving-gpu-" + cuda_version + "-"
                if device_version is None:
                    raise SystemExit(
                        "cuda_version is not found in {}/version.py".format(
                            self.module_path))

                folder_name = device_version + serving_server_version
                tar_name = folder_name + ".tar.gz"
                bin_url = "https://paddle-serving.bj.bcebos.com/bin/" + tar_name
                self.server_path = os.path.join(self.module_path, folder_name)
                download_flag = "{}/{}.is_download".format(self.module_path,
                                                           folder_name)

                # acquire lock (released when version_file is closed)
                fcntl.flock(version_file, fcntl.LOCK_EX)

                if (not os.path.exists(download_flag) and
                        not os.path.exists(self.server_path)):
                    self._touch(download_flag)
                    print(
                        'Frist time run, downloading PaddleServing components ...'
                    )
                    self._download_and_extract(bin_url, tar_name,
                                               download_flag)
        finally:
            os.chdir(self.cur_path)
        self.bin_path = self.server_path + "/serving"

    def _download_and_extract(self, bin_url, tar_name, download_flag):
        """Fetch ``bin_url`` with wget into the cwd and unpack it there.

        On failure the partial archive, any partially extracted
        ``self.server_path`` and the ``download_flag`` marker are removed
        before exiting.
        """
        r = os.system('wget ' + bin_url + ' --no-check-certificate')
        if r != 0:
            if os.path.exists(tar_name):
                os.remove(tar_name)
            self._remove_path(download_flag)
            raise SystemExit(
                'Download failed, please check your network or permission of {}.'
                .format(self.module_path))
        try:
            print('Decompressing files ..')
            tar = tarfile.open(tar_name)
            try:
                tar.extractall()
            finally:
                tar.close()
        except Exception:
            # Only reached right after a fresh download, so anything at
            # server_path was produced by this (failed) extraction.
            self._remove_path(self.server_path)
            self._remove_path(download_flag)
            raise SystemExit(
                'Decompressing failed, please check your permission of {} or disk space left.'
                .format(self.module_path))
        finally:
            os.remove(tar_name)

    def prepare_server(self,
                       workdir=None,
                       port=9292,
                       device="cpu",
                       cube_conf=None):
        """Create ``workdir`` and write all prototxt configs into it.

        Args:
            workdir: work directory; "./tmp" when None.
            port: port to serve on.
            device: "cpu" or "gpu" engine type.
            cube_conf: path of cube.conf, required when a dist_kv op is used.

        Raises:
            SystemExit: if ``port`` is already in use.
        """
        if workdir is None:
            workdir = "./tmp"
        try:
            os.makedirs(workdir)
        except OSError:
            # An existing directory is fine; anything else is a real error.
            if not os.path.isdir(workdir):
                raise
        self._touch("{}/fluid_time_file".format(workdir))

        if not self.port_is_available(port):
            raise SystemExit("Port {} is already used".format(port))

        self.set_port(port)
        self._prepare_resource(workdir, cube_conf)
        self._prepare_engine(self.model_config_paths, device)
        self._prepare_infer_service(port)
        self.workdir = workdir

        infer_service_fn = "{}/{}".format(workdir, self.infer_service_fn)
        workflow_fn = "{}/{}".format(workdir, self.workflow_fn)
        resource_fn = "{}/{}".format(workdir, self.resource_fn)
        model_toolkit_fn = "{}/{}".format(workdir, self.model_toolkit_fn)

        self._write_pb_str(infer_service_fn, self.infer_service_conf)
        self._write_pb_str(workflow_fn, self.workflow_conf)
        self._write_pb_str(resource_fn, self.resource_conf)
        self._write_pb_str(model_toolkit_fn, self.model_toolkit_conf)

    def port_is_available(self, port):
        """Return True if nothing accepts TCP connections on ``port``."""
        with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
            sock.settimeout(2)
            result = sock.connect_ex(('0.0.0.0', port))
        return result != 0

    def run_server(self):
        """Start the serving binary with the prepared configuration.

        Uses ``$SERVING_BIN`` when set, otherwise downloads the prebuilt
        binary. Blocks in ``os.system`` until the command returns.
        """
        # just run server with system command
        # currently we do not load cube
        self.check_local_bin()
        if not self.use_local_bin:
            self.download_bin()
            # wait for other process to download server bin
            while not os.path.exists(self.server_path):
                time.sleep(1)
        else:
            print("Use local bin : {}".format(self.bin_path))
        self.check_cuda()
        command = "{} " \
                  "-enable_model_toolkit " \
                  "-inferservice_path {} " \
                  "-inferservice_file {} " \
                  "-max_concurrency {} " \
                  "-num_threads {} " \
                  "-port {} " \
                  "-reload_interval_s {} " \
                  "-resource_path {} " \
                  "-resource_file {} " \
                  "-workflow_path {} " \
                  "-workflow_file {} " \
                  "-bthread_concurrency {} " \
                  "-gpuid {} " \
                  "-max_body_size {} ".format(
                      self.bin_path,
                      self.workdir,
                      self.infer_service_fn,
                      self.max_concurrency,
                      self.num_threads,
                      self.port,
                      self.reload_interval_s,
                      self.workdir,
                      self.resource_fn,
                      self.workdir,
                      self.workflow_fn,
                      self.num_threads,
                      self.gpuid,
                      self.max_body_size)
        print("Going to Run Comand")
        print(command)

        os.system(command)
B
barrierye 已提交
542 543


B
barrierye 已提交
544 545 546
class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc.
                                     MultiLangGeneralModelServiceServicer):
    """gRPC servicer that relays multi-language requests to a brpc client.

    Incoming ``InferenceRequest`` protos are decoded into numpy feed dicts,
    forwarded with ``paddle_serving_client.Client.predict``, and the results
    are encoded back into an ``InferenceResponse``.
    """

    def __init__(self, model_config_path, is_multi_model, endpoints):
        self.is_multi_model_ = is_multi_model
        self.model_config_path_ = model_config_path
        self.endpoints_ = endpoints
        with open(self.model_config_path_) as config_file:
            self.model_config_str_ = str(config_file.read())
        self._parse_model_config(self.model_config_str_)
        self._init_bclient(self.model_config_path_, self.endpoints_)

    def _init_bclient(self, model_config_path, endpoints, timeout_ms=None):
        """(Re)create ``self.bclient_`` and connect it to ``endpoints``."""
        from paddle_serving_client import Client
        self.bclient_ = Client()
        client = self.bclient_
        if timeout_ms is not None:
            client.set_rpc_timeout_ms(timeout_ms)
        client.load_client_config(model_config_path)
        client.connect(endpoints)

    def _parse_model_config(self, model_config_str):
        """Cache feed/fetch names, types, shapes and LoD flags of the model."""
        model_conf = google.protobuf.text_format.Merge(
            model_config_str, m_config.GeneralModelConfig())
        feed_vars = model_conf.feed_var
        fetch_vars = model_conf.fetch_var
        self.feed_names_ = [v.alias_name for v in feed_vars]
        self.feed_types_ = {v.alias_name: v.feed_type for v in feed_vars}
        self.feed_shapes_ = {v.alias_name: v.shape for v in feed_vars}
        self.fetch_names_ = [v.alias_name for v in fetch_vars]
        self.fetch_types_ = {v.alias_name: v.fetch_type for v in fetch_vars}
        self.lod_tensor_set_ = set(
            v.alias_name for v in list(feed_vars) + list(fetch_vars)
            if v.is_lod_tensor)

    def _flatten_list(self, nested_list):
        """Yield the leaves of arbitrarily nested lists/tuples, in order."""
        for element in nested_list:
            if not isinstance(element, (list, tuple)):
                yield element
                continue
            for leaf in self._flatten_list(element):
                yield leaf

    def _unpack_feed_inst(self, feed_inst, feed_names, is_python):
        """Decode one request instance into ``{feed_name: np.ndarray}``.

        Type codes: 0 -> int64, 1 -> float32, 2 -> int32; anything else raises.
        Python clients send raw bytes in ``data``; other clients send typed
        repeated fields.
        """
        dtype_of = {0: "int64", 1: "float32", 2: "int32"}
        field_of = {0: "int64_data", 1: "float_data", 2: "int_data"}
        feed_dict = {}
        for idx, name in enumerate(feed_names):
            tensor = feed_inst.tensor_array[idx]
            v_type = self.feed_types_[name]
            if v_type not in dtype_of:
                raise Exception("error type.")
            dtype = dtype_of[v_type]
            if is_python:
                data = np.frombuffer(tensor.data, dtype=dtype)
            else:
                data = np.array(
                    list(getattr(tensor, field_of[v_type])), dtype=dtype)
            data.shape = list(tensor.shape)
            feed_dict[name] = data
        return feed_dict

    def _unpack_inference_request(self, request):
        """Return ``(feed_batch, fetch_names, is_python, log_id)`` for a request."""
        feed_names = list(request.feed_var_names)
        fetch_names = list(request.fetch_var_names)
        is_python = request.is_python
        log_id = request.log_id
        feed_batch = [
            self._unpack_feed_inst(feed_inst, feed_names, is_python)
            for feed_inst in request.insts
        ]
        return feed_batch, fetch_names, is_python, log_id

    def _pack_inference_response(self, ret, fetch_names, is_python):
        """Encode ``predict`` output into an ``InferenceResponse``.

        ``ret`` of None yields ``err_code = 1``; otherwise it is unpacked as
        ``(results, tag)``.
        """
        resp = multi_lang_general_model_service_pb2.InferenceResponse()
        if ret is None:
            resp.err_code = 1
            return resp
        results, tag = ret
        resp.tag = tag
        resp.err_code = 0

        if not self.is_multi_model_:
            results = {'general_infer_0': results}
        field_of = {0: "int64_data", 1: "float_data", 2: "int_data"}
        for model_name, model_result in results.items():
            model_output = multi_lang_general_model_service_pb2.ModelOutput()
            inst = multi_lang_general_model_service_pb2.FetchInst()
            for name in fetch_names:
                tensor = multi_lang_general_model_service_pb2.Tensor()
                v_type = self.fetch_types_[name]
                if is_python:
                    tensor.data = model_result[name].tobytes()
                else:
                    if v_type not in field_of:
                        raise Exception("error type.")
                    getattr(tensor, field_of[v_type]).extend(
                        model_result[name].reshape(-1).tolist())
                tensor.shape.extend(list(model_result[name].shape))
                if name in self.lod_tensor_set_:
                    tensor.lod.extend(
                        model_result["{}.lod".format(name)].tolist())
                inst.tensor_array.append(tensor)
            model_output.insts.append(inst)
            model_output.engine_name = model_name
            resp.outputs.append(model_output)
        return resp

    def SetTimeout(self, request, context):
        """Rebuild the brpc client with ``request.timeout_ms``."""
        # This process and the Inference process cannot operate at the same
        # time. For performance reasons, no thread lock is added for now.
        self._init_bclient(self.model_config_path_, self.endpoints_,
                           request.timeout_ms)
        resp = multi_lang_general_model_service_pb2.SimpleResponse()
        resp.err_code = 0
        return resp

    def Inference(self, request, context):
        """Decode the request, run ``predict`` and encode the response."""
        feed_batch, fetch_names, is_python, log_id = \
            self._unpack_inference_request(request)
        prediction = self.bclient_.predict(
            feed=feed_batch,
            fetch=fetch_names,
            need_variant_tag=True,
            log_id=log_id)
        return self._pack_inference_response(prediction, fetch_names,
                                             is_python)

    def GetClientConfig(self, request, context):
        """Return the raw client config text read at construction time."""
        resp = multi_lang_general_model_service_pb2.GetClientConfigResponse()
        resp.client_config_str = self.model_config_str_
        return resp
B
barrierye 已提交
690 691 692


class MultiLangServer(object):
    """gRPC front end that wraps a brpc ``Server`` running in a child process.

    Configuration setters are forwarded to the wrapped ``Server``
    (``self.bserver_``). ``prepare_server`` reserves a separate internal
    port for the brpc server. ``run_server`` starts the brpc server in a
    child process and serves ``MultiLangGeneralModelService`` over gRPC on
    the user-facing port. The gRPC servicer is handed the internal
    endpoint to reach the brpc server.
    """

    # Default gRPC message-size limit in bytes (64 MiB). Smaller values
    # passed to set_max_body_size are raised to this for the gRPC side.
    _DEFAULT_BODY_SIZE = 64 * 1024 * 1024
    # The internal brpc port is the first free port in
    # [_INTERNAL_PORT_BASE, _INTERNAL_PORT_BASE + _INTERNAL_PORT_RANGE).
    _INTERNAL_PORT_BASE = 12000
    _INTERNAL_PORT_RANGE = 1000

    def __init__(self):
        self.bserver_ = Server()
        self.worker_num_ = 4  # gRPC thread-pool size
        self.body_size_ = self._DEFAULT_BODY_SIZE  # gRPC max message length
        self.concurrency_ = 100000  # gRPC maximum_concurrent_rpcs
        self.is_multi_model_ = False  # for model ensemble

    def set_max_concurrency(self, concurrency):
        """Set the concurrency limit for both the gRPC and brpc servers."""
        self.concurrency_ = concurrency
        self.bserver_.set_max_concurrency(concurrency)

    def set_num_threads(self, threads):
        """Set the worker-thread count for both the gRPC and brpc servers."""
        self.worker_num_ = threads
        self.bserver_.set_num_threads(threads)

    def set_max_body_size(self, body_size):
        """Set the maximum request/response body size in bytes.

        The brpc server always receives ``body_size``. The gRPC front end
        never goes below the 64 MiB default: smaller values are replaced
        by the default and a warning is printed.
        """
        self.bserver_.set_max_body_size(body_size)
        # Compare against the default (not the current value) so a second,
        # smaller call is not silently ignored on the gRPC side.
        if body_size >= self._DEFAULT_BODY_SIZE:
            self.body_size_ = body_size
        else:
            self.body_size_ = self._DEFAULT_BODY_SIZE
            print(
                "max_body_size is less than default value, will use default value in service."
            )

    def set_port(self, port):
        """Set the user-facing gRPC port (also called by prepare_server)."""
        self.gport_ = port

    def set_reload_interval(self, interval):
        """Forward the reload interval to the wrapped brpc Server."""
        self.bserver_.set_reload_interval(interval)

    def set_op_sequence(self, op_seq):
        """Forward the op sequence to the wrapped brpc Server."""
        self.bserver_.set_op_sequence(op_seq)

    def set_op_graph(self, op_graph):
        """Forward the op graph to the wrapped brpc Server."""
        self.bserver_.set_op_graph(op_graph)

    def set_memory_optimize(self, flag=False):
        """Forward the memory-optimize flag to the wrapped brpc Server."""
        self.bserver_.set_memory_optimize(flag)

    def set_ir_optimize(self, flag=False):
        """Forward the IR-optimize flag to the wrapped brpc Server."""
        self.bserver_.set_ir_optimize(flag)

    def set_gpuid(self, gpuid=0):
        """Forward the GPU id to the wrapped brpc Server."""
        self.bserver_.set_gpuid(gpuid)

    def load_model_config(self, server_config_paths, client_config_path=None):
        """Load model config(s) into the brpc server and pick the client config.

        Args:
            server_config_paths: a model directory, or a dict whose values
                are model directories (model ensemble).
            client_config_path: client config file handed to the gRPC
                servicer. Defaults to ``serving_server_conf.prototxt`` inside
                the model directory (the first dict value for an ensemble).
        """
        self.bserver_.load_model_config(server_config_paths)
        is_ensemble = isinstance(server_config_paths, dict)
        if is_ensemble:
            # Ensemble-ness depends on the server configs, not on whether the
            # caller supplied an explicit client config path.
            self.is_multi_model_ = True
        if client_config_path is None:
            if is_ensemble:
                # "First" follows dict iteration order (insertion order on
                # Python 3.7+). Presumably this should be the model that
                # receives the client's feeds; confirm with callers.
                model_dir = list(server_config_paths.values())[0]
            else:
                model_dir = server_config_paths
            client_config_path = '{}/serving_server_conf.prototxt'.format(
                model_dir)
        self.bclient_config_path_ = client_config_path

    def prepare_server(self,
                       workdir=None,
                       port=9292,
                       device="cpu",
                       cube_conf=None):
        """Reserve ports and prepare the wrapped brpc server.

        Args:
            workdir: forwarded to the brpc ``Server.prepare_server``.
            port: user-facing gRPC port. The brpc server gets a separate
                internal port chosen from the internal port range.
            device: forwarded to the brpc ``Server.prepare_server``.
            cube_conf: forwarded to the brpc ``Server.prepare_server``.

        Raises:
            SystemExit: if ``port`` is already in use, or if no internal
                port in the internal range is free.
        """
        if not self._port_is_available(port):
            raise SystemExit("Port {} is already used".format(port))
        internal_port = self._find_internal_port(exclude=port)
        if internal_port is None:
            raise SystemExit(
                "No free internal port in [{}, {}) for the brpc server".format(
                    self._INTERNAL_PORT_BASE,
                    self._INTERNAL_PORT_BASE + self._INTERNAL_PORT_RANGE))
        self.port_list_ = [internal_port]
        self.bserver_.prepare_server(
            workdir=workdir,
            port=internal_port,
            device=device,
            cube_conf=cube_conf)
        self.set_port(port)

    def _find_internal_port(self, exclude):
        """Return the first free internal port other than *exclude*, or None."""
        first = self._INTERNAL_PORT_BASE
        for candidate in range(first, first + self._INTERNAL_PORT_RANGE):
            # Test the cheap inequality first so the user port is never probed.
            if candidate != exclude and self._port_is_available(candidate):
                return candidate
        return None

    def _launch_brpc_service(self, bserver):
        """Child-process entry point: run the brpc server."""
        bserver.run_server()

    def _port_is_available(self, port):
        """Return True if a local TCP connect to *port* does not succeed."""
        with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
            sock.settimeout(2)
            # NOTE(review): connecting to 0.0.0.0 reaches the local host on
            # Linux; other platforms may reject this address.
            result = sock.connect_ex(('0.0.0.0', port))
        return result != 0

    def run_server(self):
        """Start the brpc server in a child process and serve gRPC.

        Blocks until the brpc process exits, then until the gRPC server
        terminates. ``prepare_server`` and ``load_model_config`` must have
        been called first (they set the ports and client config path).
        """
        p_bserver = Process(
            target=self._launch_brpc_service, args=(self.bserver_, ))
        p_bserver.start()
        options = [('grpc.max_send_message_length', self.body_size_),
                   ('grpc.max_receive_message_length', self.body_size_)]
        server = grpc.server(
            futures.ThreadPoolExecutor(max_workers=self.worker_num_),
            options=options,
            maximum_concurrent_rpcs=self.concurrency_)
        multi_lang_general_model_service_pb2_grpc.add_MultiLangGeneralModelServiceServicer_to_server(
            MultiLangServerServiceServicer(
                self.bclient_config_path_, self.is_multi_model_,
                ["0.0.0.0:{}".format(self.port_list_[0])]), server)
        server.add_insecure_port('[::]:{}'.format(self.gport_))
        server.start()
        p_bserver.join()
        server.wait_for_termination()