#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing

import os
from .proto import server_configure_pb2 as server_sdk
from .proto import general_model_config_pb2 as m_config
import google.protobuf.text_format
import tarfile
import socket
import paddle_serving_server_gpu as paddle_serving_server
import time
from .version import serving_server_version
from contextlib import closing
import argparse
import collections
import sys
if not sys.platform.startswith('win'):
    import fcntl
import shutil
import numpy as np
import grpc
from .proto import multi_lang_general_model_service_pb2
sys.path.append(
    os.path.join(os.path.abspath(os.path.dirname(__file__)), 'proto'))
from .proto import multi_lang_general_model_service_pb2_grpc
from multiprocessing import Pool, Process
from concurrent import futures


def serve_args():
    parser = argparse.ArgumentParser("serve")
    parser.add_argument(
        "--thread", type=int, default=2, help="Concurrency of server")
    parser.add_argument(
        "--model", type=str, default="", help="Model for serving")
    parser.add_argument(
        "--port", type=int, default=9292, help="Port of the service")
    parser.add_argument(
        "--workdir",
        type=str,
        default="workdir",
        help="Working dir of current service")
    parser.add_argument(
        "--device", type=str, default="gpu", help="Type of device")
    parser.add_argument("--gpu_ids", type=str, default="", help="gpu ids")
    parser.add_argument(
        "--name", type=str, default="None", help="Default service name")
    parser.add_argument(
        "--mem_optim_off",
        default=False,
        action="store_true",
        help="Disable memory optimization")
    parser.add_argument(
        "--ir_optim", default=False, action="store_true", help="Graph optimize")
    parser.add_argument(
        "--max_body_size",
        type=int,
        default=512 * 1024 * 1024,
        help="Limit size of messages")
    parser.add_argument(
        "--use_multilang",
        default=False,
        action="store_true",
        help="Use Multi-language-service")
    parser.add_argument(
        "--use_trt", default=False, action="store_true", help="Use TensorRT")
    parser.add_argument(
        "--use_lite", default=False, action="store_true", help="Use PaddleLite")
    parser.add_argument(
        "--use_xpu", default=False, action="store_true", help="Use XPU")
    parser.add_argument(
        "--product_name",
        type=str,
        default=None,
        help="product_name for authentication")
    parser.add_argument(
        "--container_id",
        type=str,
        default=None,
        help="container_id for authentication")
    return parser.parse_args()
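
# Usage sketch (hypothetical model path; the exact flags are defined by
# serve_args() above):
#   python -m paddle_serving_server_gpu.serve \
#       --model uci_housing_model --port 9292 --gpu_ids 0 --thread 4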


class OpMaker(object):
    def __init__(self):
        self.op_dict = {
            "general_infer": "GeneralInferOp",
            "general_reader": "GeneralReaderOp",
            "general_response": "GeneralResponseOp",
            "general_text_reader": "GeneralTextReaderOp",
            "general_text_response": "GeneralTextResponseOp",
            "general_single_kv": "GeneralSingleKVOp",
            "general_dist_kv_infer": "GeneralDistKVInferOp",
            "general_dist_kv": "GeneralDistKVOp"
        }
        self.node_name_suffix_ = collections.defaultdict(int)

    def create(self, node_type, engine_name=None, inputs=None, outputs=None):
        if node_type not in self.op_dict:
            raise Exception("Op type {} is not supported right now".format(
                node_type))
        node = server_sdk.DAGNode()
        # node.name will be used as the infer engine name
        if engine_name:
            node.name = engine_name
        else:
            node.name = '{}_{}'.format(node_type,
                                       self.node_name_suffix_[node_type])
            self.node_name_suffix_[node_type] += 1

        node.type = self.op_dict[node_type]
        if inputs:
            for dep_node_str in inputs:
                dep_node = server_sdk.DAGNode()
                google.protobuf.text_format.Parse(dep_node_str, dep_node)
                dep = server_sdk.DAGNodeDependency()
                dep.name = dep_node.name
                dep.mode = "RO"
                node.dependencies.extend([dep])
        # Because the return value is used as a dict key and a proto
        # message is mutable (hence unhashable), the node is serialized
        # to a string here. This has little effect on overall efficiency.
        return google.protobuf.text_format.MessageToString(node)
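
    # Example (sketch): create the node strings for a standard
    # reader -> infer -> response pipeline; they are wired together
    # by OpSeqMaker below.
    #   op_maker = OpMaker()
    #   read_op = op_maker.create('general_reader')
    #   infer_op = op_maker.create('general_infer')
    #   response_op = op_maker.create('general_response')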


class OpSeqMaker(object):
    def __init__(self):
        self.workflow = server_sdk.Workflow()
        self.workflow.name = "workflow1"
        self.workflow.workflow_type = "Sequence"

    def add_op(self, node_str):
        node = server_sdk.DAGNode()
        google.protobuf.text_format.Parse(node_str, node)
        if len(node.dependencies) > 1:
            raise Exception(
                'Setting more than one predecessor for an op is not allowed in OpSeqMaker.'
            )
        if len(self.workflow.nodes) >= 1:
            if len(node.dependencies) == 0:
                dep = server_sdk.DAGNodeDependency()
                dep.name = self.workflow.nodes[-1].name
                dep.mode = "RO"
                node.dependencies.extend([dep])
            elif len(node.dependencies) == 1:
                if node.dependencies[0].name != self.workflow.nodes[-1].name:
                    raise Exception(
                        'You must add ops in order in OpSeqMaker. The last op added is {}, but the current op declares {} as its predecessor.'
                        .format(self.workflow.nodes[-1].name,
                                node.dependencies[0].name))
        self.workflow.nodes.extend([node])

    def get_op_sequence(self):
        workflow_conf = server_sdk.WorkflowConf()
        workflow_conf.workflows.extend([self.workflow])
        return workflow_conf
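
    # Example (sketch, continuing the OpMaker example above):
    #   op_seq_maker = OpSeqMaker()
    #   op_seq_maker.add_op(read_op)
    #   op_seq_maker.add_op(infer_op)
    #   op_seq_maker.add_op(response_op)
    #   workflow_conf = op_seq_maker.get_op_sequence()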


class OpGraphMaker(object):
    def __init__(self):
        self.workflow = server_sdk.Workflow()
        self.workflow.name = "workflow1"
        # Currently, SDK only supports "Sequence"
        self.workflow.workflow_type = "Sequence"

    def add_op(self, node_str):
        node = server_sdk.DAGNode()
        google.protobuf.text_format.Parse(node_str, node)
        self.workflow.nodes.extend([node])

    def get_op_graph(self):
        workflow_conf = server_sdk.WorkflowConf()
        workflow_conf.workflows.extend([self.workflow])
        return workflow_conf


class Server(object):
    def __init__(self):
        self.server_handle_ = None
        self.infer_service_conf = None
        self.model_toolkit_conf = None
        self.resource_conf = None
        self.memory_optimization = False
        self.ir_optimization = False
        self.model_conf = None
        self.workflow_fn = "workflow.prototxt"
        self.resource_fn = "resource.prototxt"
        self.infer_service_fn = "infer_service.prototxt"
        self.model_toolkit_fn = "model_toolkit.prototxt"
        self.general_model_config_fn = "general_model.prototxt"
        self.cube_config_fn = "cube.conf"
        self.workdir = ""
        self.max_concurrency = 0
        self.num_threads = 2
        self.port = 8080
        self.reload_interval_s = 10
        self.max_body_size = 64 * 1024 * 1024
        self.module_path = os.path.dirname(paddle_serving_server.__file__)
        self.cur_path = os.getcwd()
        self.use_local_bin = False
        self.gpuid = 0
        self.use_trt = False
        self.use_lite = False
        self.use_xpu = False
        self.model_config_paths = None  # for multi-model in a workflow
        self.product_name = None
        self.container_id = None

    def get_fetch_list(self):
        fetch_names = [var.alias_name for var in self.model_conf.fetch_var]
        return fetch_names

    def set_max_concurrency(self, concurrency):
        self.max_concurrency = concurrency

    def set_num_threads(self, threads):
        self.num_threads = threads

    def set_max_body_size(self, body_size):
        if body_size >= self.max_body_size:
            self.max_body_size = body_size
        else:
            print(
                "max_body_size is less than the default value; the default will be used."
            )

    def set_port(self, port):
        self.port = port

    def set_reload_interval(self, interval):
        self.reload_interval_s = interval

    def set_op_sequence(self, op_seq):
        self.workflow_conf = op_seq

    def set_op_graph(self, op_graph):
        self.workflow_conf = op_graph

    def set_memory_optimize(self, flag=False):
        self.memory_optimization = flag

    def set_ir_optimize(self, flag=False):
        self.ir_optimization = flag

    def set_product_name(self, product_name=None):
        if product_name is None:
            raise ValueError("product_name can't be None.")
        self.product_name = product_name

    def set_container_id(self, container_id):
        if container_id is None:
            raise ValueError("container_id can't be None.")
        self.container_id = container_id

    def check_local_bin(self):
        if "SERVING_BIN" in os.environ:
            self.use_local_bin = True
            self.bin_path = os.environ["SERVING_BIN"]

    def check_cuda(self):
        if os.system("ls /dev/ | grep nvidia > /dev/null") != 0:
            raise SystemExit(
                "GPU not found, please check your environment or use the cpu version by \"pip install paddle_serving_server\""
            )

    def set_gpuid(self, gpuid=0):
        self.gpuid = gpuid

    def set_trt(self):
        self.use_trt = True

    def set_lite(self):
        self.use_lite = True

    def set_xpu(self):
        self.use_xpu = True

    def _prepare_engine(self, model_config_paths, device):
        if self.model_toolkit_conf is None:
            self.model_toolkit_conf = server_sdk.ModelToolkitConf()

        for engine_name, model_config_path in model_config_paths.items():
            engine = server_sdk.EngineDesc()
            engine.name = engine_name
            # engine.reloadable_meta = model_config_path + "/fluid_time_file"
            engine.reloadable_meta = self.workdir + "/fluid_time_file"
            os.system("touch {}".format(engine.reloadable_meta))
            engine.reloadable_type = "timestamp_ne"
            engine.runtime_thread_num = 0
            engine.batch_infer_size = 0
            engine.enable_batch_align = 0
            engine.model_data_path = model_config_path
            engine.enable_memory_optimization = self.memory_optimization
            engine.enable_ir_optimization = self.ir_optimization
            engine.static_optimization = False
            engine.force_update_static_cache = False
            engine.use_trt = self.use_trt
            engine.use_lite = self.use_lite
            engine.use_xpu = self.use_xpu

            if device == "cpu":
                engine.type = "FLUID_CPU_ANALYSIS_DIR"
            elif device == "gpu":
                engine.type = "FLUID_GPU_ANALYSIS_DIR"
            elif device == "arm":
                engine.type = "FLUID_ARM_ANALYSIS_DIR"

            self.model_toolkit_conf.engines.extend([engine])

    def _prepare_infer_service(self, port):
        if self.infer_service_conf is None:
            self.infer_service_conf = server_sdk.InferServiceConf()
            self.infer_service_conf.port = port
            infer_service = server_sdk.InferService()
            infer_service.name = "GeneralModelService"
            infer_service.workflows.extend(["workflow1"])
            self.infer_service_conf.services.extend([infer_service])

    def _prepare_resource(self, workdir, cube_conf):
        self.workdir = workdir
        if self.resource_conf is None:
            with open("{}/{}".format(workdir, self.general_model_config_fn),
                      "w") as fout:
                fout.write(str(self.model_conf))
            self.resource_conf = server_sdk.ResourceConf()
            for workflow in self.workflow_conf.workflows:
                for node in workflow.nodes:
                    if "dist_kv" in node.name:
                        self.resource_conf.cube_config_path = workdir
                        self.resource_conf.cube_config_file = self.cube_config_fn
                        if cube_conf is None:
                            raise ValueError(
                                "Please set the path of cube.conf when using a dist_kv op."
                            )
                        shutil.copy(cube_conf, workdir)
            self.resource_conf.model_toolkit_path = workdir
            self.resource_conf.model_toolkit_file = self.model_toolkit_fn
            self.resource_conf.general_model_path = workdir
            self.resource_conf.general_model_file = self.general_model_config_fn
            if self.product_name is not None:
                self.resource_conf.auth_product_name = self.product_name
            if self.container_id is not None:
                self.resource_conf.auth_container_id = self.container_id

    def _write_pb_str(self, filepath, pb_obj):
        with open(filepath, "w") as fout:
            fout.write(str(pb_obj))

    def load_model_config(self, model_config_paths):
        # At present, Serving needs the model path configured in the
        # resource.prototxt file to determine the input and output
        # format of the workflow, so the input and output of multiple
        # models must be the same.
        workflow_oi_config_path = None
        if isinstance(model_config_paths, str):
            # If there is only one model path, use the default infer_op.
            # Because there are several infer_op types, we need to find
            # it from workflow_conf.
            default_engine_names = [
                'general_infer_0', 'general_dist_kv_infer_0',
                'general_dist_kv_quant_infer_0'
            ]
            engine_name = None
            for node in self.workflow_conf.workflows[0].nodes:
                if node.name in default_engine_names:
                    engine_name = node.name
                    break
            if engine_name is None:
                raise Exception(
                    "You have set the engine_name of Op. Please use the form {op: model_path} to configure model path"
                )
            self.model_config_paths = {engine_name: model_config_paths}
            workflow_oi_config_path = self.model_config_paths[engine_name]
        elif isinstance(model_config_paths, dict):
            self.model_config_paths = {}
            for node_str, path in model_config_paths.items():
                node = server_sdk.DAGNode()
                google.protobuf.text_format.Parse(node_str, node)
                self.model_config_paths[node.name] = path
            print("You have specified multiple model paths, please ensure "
                  "that the input and output of multiple models are the same.")
            workflow_oi_config_path = list(
                self.model_config_paths.items())[0][1]
        else:
            raise Exception("The type of model_config_paths must be str or "
                            "dict({op: model_path}), not {}.".format(
                                type(model_config_paths)))

        self.model_conf = m_config.GeneralModelConfig()
        with open("{}/serving_server_conf.prototxt".format(
                workflow_oi_config_path), 'r') as f:
            self.model_conf = google.protobuf.text_format.Merge(
                str(f.read()), self.model_conf)
        # check config here
        # print config here
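
        # Example (sketch): a single path uses the default infer op; a dict
        # maps node strings (as returned by OpMaker.create) to per-engine
        # model paths. Paths here are illustrative.
        #   server.load_model_config("uci_housing_model")
        #   server.load_model_config({infer_op: "model_A_path"})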

    def download_bin(self):
        os.chdir(self.module_path)
        need_download = False

        # acquire lock
        version_file = open("{}/version.py".format(self.module_path), "r")
        import re
        for line in version_file.readlines():
            if re.match("cuda_version", line):
                cuda_version = line.split("\"")[1]
                if cuda_version == "trt":
                    device_version = "serving-gpu-" + cuda_version + "-"
                elif cuda_version == "arm":
                    device_version = "serving-" + cuda_version + "-"
                else:
                    device_version = "serving-gpu-cuda" + cuda_version + "-"

        folder_name = device_version + serving_server_version
        tar_name = folder_name + ".tar.gz"
        bin_url = "https://paddle-serving.bj.bcebos.com/bin/" + tar_name
        self.server_path = os.path.join(self.module_path, folder_name)

        download_flag = "{}/{}.is_download".format(self.module_path,
                                                   folder_name)

        fcntl.flock(version_file, fcntl.LOCK_EX)

        if os.path.exists(download_flag):
            os.chdir(self.cur_path)
            self.bin_path = self.server_path + "/serving"
            return

        if not os.path.exists(self.server_path):
            os.system("touch {}/{}.is_download".format(self.module_path,
                                                       folder_name))
            print('First time run, downloading PaddleServing components ...')

            r = os.system('wget ' + bin_url + ' --no-check-certificate')
            if r != 0:
                if os.path.exists(tar_name):
                    os.remove(tar_name)
                raise SystemExit(
                    'Download failed, please check your network or permission of {}.'
                    .format(self.module_path))
            else:
                try:
                    print('Decompressing files ..')
                    tar = tarfile.open(tar_name)
                    tar.extractall()
                    tar.close()
                except:
                    # clean up partially extracted files
                    if os.path.exists(self.server_path):
                        shutil.rmtree(self.server_path)
                    raise SystemExit(
                        'Decompressing failed, please check your permission of {} or disk space left.'
                        .format(self.module_path))
                finally:
                    os.remove(tar_name)
        # release lock
        version_file.close()
        os.chdir(self.cur_path)
        self.bin_path = self.server_path + "/serving"

    def prepare_server(self,
                       workdir=None,
                       port=9292,
                       device="cpu",
                       cube_conf=None):
        if workdir is None:
            workdir = "./tmp"
        os.system("mkdir {}".format(workdir))
        os.system("touch {}/fluid_time_file".format(workdir))

        if not self.port_is_available(port):
            raise SystemExit("Port {} is already used".format(port))

        self.set_port(port)
        self._prepare_resource(workdir, cube_conf)
        self._prepare_engine(self.model_config_paths, device)
        self._prepare_infer_service(port)
        self.workdir = workdir

        infer_service_fn = "{}/{}".format(workdir, self.infer_service_fn)
        workflow_fn = "{}/{}".format(workdir, self.workflow_fn)
        resource_fn = "{}/{}".format(workdir, self.resource_fn)
        model_toolkit_fn = "{}/{}".format(workdir, self.model_toolkit_fn)

        self._write_pb_str(infer_service_fn, self.infer_service_conf)
        self._write_pb_str(workflow_fn, self.workflow_conf)
        self._write_pb_str(resource_fn, self.resource_conf)
        self._write_pb_str(model_toolkit_fn, self.model_toolkit_conf)

    def port_is_available(self, port):
        with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
            sock.settimeout(2)
            result = sock.connect_ex(('0.0.0.0', port))
        return result != 0

    def run_server(self):
        # just run server with system command
        # currently we do not load cube
        self.check_local_bin()
        if not self.use_local_bin:
            self.download_bin()
            # wait for other process to download server bin
            while not os.path.exists(self.server_path):
                time.sleep(1)
        else:
            print("Use local bin : {}".format(self.bin_path))
        #self.check_cuda()
        if self.use_lite:
            command = "{} " \
                      "-enable_model_toolkit " \
                      "-inferservice_path {} " \
                      "-inferservice_file {} " \
                      "-max_concurrency {} " \
                      "-num_threads {} " \
                      "-port {} " \
                      "-reload_interval_s {} " \
                      "-resource_path {} " \
                      "-resource_file {} " \
                      "-workflow_path {} " \
                      "-workflow_file {} " \
                      "-bthread_concurrency {} " \
                      "-max_body_size {} ".format(
                          self.bin_path,
                          self.workdir,
                          self.infer_service_fn,
                          self.max_concurrency,
                          self.num_threads,
                          self.port,
                          self.reload_interval_s,
                          self.workdir,
                          self.resource_fn,
                          self.workdir,
                          self.workflow_fn,
                          self.num_threads,
                          self.max_body_size)
        else:
            command = "{} " \
                      "-enable_model_toolkit " \
                      "-inferservice_path {} " \
                      "-inferservice_file {} " \
                      "-max_concurrency {} " \
                      "-num_threads {} " \
                      "-port {} " \
                      "-reload_interval_s {} " \
                      "-resource_path {} " \
                      "-resource_file {} " \
                      "-workflow_path {} " \
                      "-workflow_file {} " \
                      "-bthread_concurrency {} " \
                      "-gpuid {} " \
                      "-max_body_size {} ".format(
                          self.bin_path,
                          self.workdir,
                          self.infer_service_fn,
                          self.max_concurrency,
                          self.num_threads,
                          self.port,
                          self.reload_interval_s,
                          self.workdir,
                          self.resource_fn,
                          self.workdir,
                          self.workflow_fn,
                          self.num_threads,
                          self.gpuid,
                          self.max_body_size)
        print("Going to run command:")
        print(command)

        os.system(command)
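
# End-to-end sketch (hypothetical model path): assemble a sequential workflow
# and serve it on one GPU.
#   op_maker = OpMaker()
#   op_seq_maker = OpSeqMaker()
#   for op in ('general_reader', 'general_infer', 'general_response'):
#       op_seq_maker.add_op(op_maker.create(op))
#   server = Server()
#   server.set_op_sequence(op_seq_maker.get_op_sequence())
#   server.load_model_config("uci_housing_model")
#   server.set_gpuid(0)
#   server.prepare_server(workdir="workdir", port=9292, device="gpu")
#   server.run_server()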


class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc.
                                     MultiLangGeneralModelServiceServicer):
    def __init__(self, model_config_path, is_multi_model, endpoints):
        self.is_multi_model_ = is_multi_model
        self.model_config_path_ = model_config_path
        self.endpoints_ = endpoints
        with open(self.model_config_path_) as f:
            self.model_config_str_ = str(f.read())
        self._parse_model_config(self.model_config_str_)
        self._init_bclient(self.model_config_path_, self.endpoints_)

    def _init_bclient(self, model_config_path, endpoints, timeout_ms=None):
        from paddle_serving_client import Client
        self.bclient_ = Client()
        if timeout_ms is not None:
            self.bclient_.set_rpc_timeout_ms(timeout_ms)
        self.bclient_.load_client_config(model_config_path)
        self.bclient_.connect(endpoints)

    def _parse_model_config(self, model_config_str):
        model_conf = m_config.GeneralModelConfig()
        model_conf = google.protobuf.text_format.Merge(model_config_str,
                                                       model_conf)
        self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
        self.feed_types_ = {}
        self.feed_shapes_ = {}
        self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
        self.fetch_types_ = {}
        self.lod_tensor_set_ = set()
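        # Numeric type codes used in the branches below and in
        # _unpack_inference_request / _pack_inference_response:
        # 0 -> int64, 1 -> float32, 2 -> int32.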
        for i, var in enumerate(model_conf.feed_var):
            self.feed_types_[var.alias_name] = var.feed_type
            self.feed_shapes_[var.alias_name] = var.shape
            if var.is_lod_tensor:
                self.lod_tensor_set_.add(var.alias_name)
        for i, var in enumerate(model_conf.fetch_var):
            self.fetch_types_[var.alias_name] = var.fetch_type
            if var.is_lod_tensor:
                self.lod_tensor_set_.add(var.alias_name)

    def _flatten_list(self, nested_list):
        for item in nested_list:
            if isinstance(item, (list, tuple)):
                for sub_item in self._flatten_list(item):
                    yield sub_item
            else:
                yield item
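
    # e.g. list(self._flatten_list([1, [2, 3], (4,)])) -> [1, 2, 3, 4]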

    def _unpack_inference_request(self, request):
        feed_names = list(request.feed_var_names)
        fetch_names = list(request.fetch_var_names)
        is_python = request.is_python
        log_id = request.log_id
        feed_batch = []
        for feed_inst in request.insts:
            feed_dict = {}
            for idx, name in enumerate(feed_names):
                var = feed_inst.tensor_array[idx]
                v_type = self.feed_types_[name]
                data = None
                if is_python:
                    if v_type == 0:
                        data = np.frombuffer(var.data, dtype="int64")
                    elif v_type == 1:
                        data = np.frombuffer(var.data, dtype="float32")
                    elif v_type == 2:
                        data = np.frombuffer(var.data, dtype="int32")
                    else:
                        raise Exception(
                            "unsupported feed var type: {}".format(v_type))
                else:
                    if v_type == 0:  # int64
                        data = np.array(list(var.int64_data), dtype="int64")
                    elif v_type == 1:  # float32
                        data = np.array(list(var.float_data), dtype="float32")
                    elif v_type == 2:  # int32
                        data = np.array(list(var.int_data), dtype="int32")
                    else:
                        raise Exception(
                            "unsupported feed var type: {}".format(v_type))
                data.shape = list(feed_inst.tensor_array[idx].shape)
                feed_dict[name] = data
            feed_batch.append(feed_dict)
        return feed_batch, fetch_names, is_python, log_id

    def _pack_inference_response(self, ret, fetch_names, is_python):
        resp = multi_lang_general_model_service_pb2.InferenceResponse()
        if ret is None:
            resp.err_code = 1
            return resp
        results, tag = ret
        resp.tag = tag
        resp.err_code = 0

        if not self.is_multi_model_:
            results = {'general_infer_0': results}
        for model_name, model_result in results.items():
            model_output = multi_lang_general_model_service_pb2.ModelOutput()
            inst = multi_lang_general_model_service_pb2.FetchInst()
            for idx, name in enumerate(fetch_names):
                tensor = multi_lang_general_model_service_pb2.Tensor()
                v_type = self.fetch_types_[name]
                if is_python:
                    tensor.data = model_result[name].tobytes()
                else:
                    if v_type == 0:  # int64
                        tensor.int64_data.extend(model_result[name].reshape(-1)
                                                 .tolist())
                    elif v_type == 1:  # float32
                        tensor.float_data.extend(model_result[name].reshape(-1)
                                                 .tolist())
                    elif v_type == 2:  # int32
                        tensor.int_data.extend(model_result[name].reshape(-1)
                                               .tolist())
                    else:
                        raise Exception(
                            "unsupported fetch var type: {}".format(v_type))
                tensor.shape.extend(list(model_result[name].shape))
                if "{}.lod".format(name) in model_result:
                    tensor.lod.extend(model_result["{}.lod".format(name)]
                                      .tolist())
                inst.tensor_array.append(tensor)
            model_output.insts.append(inst)
            model_output.engine_name = model_name
            resp.outputs.append(model_output)
        return resp

    def SetTimeout(self, request, context):
        # This process and the Inference process cannot run at the same time.
        # For performance reasons, no thread lock is added for now.
        timeout_ms = request.timeout_ms
        self._init_bclient(self.model_config_path_, self.endpoints_, timeout_ms)
        resp = multi_lang_general_model_service_pb2.SimpleResponse()
        resp.err_code = 0
        return resp

    def Inference(self, request, context):
        feed_dict, fetch_names, is_python, log_id \
                = self._unpack_inference_request(request)
        ret = self.bclient_.predict(
            feed=feed_dict,
            fetch=fetch_names,
            need_variant_tag=True,
            log_id=log_id)
        return self._pack_inference_response(ret, fetch_names, is_python)

    def GetClientConfig(self, request, context):
        resp = multi_lang_general_model_service_pb2.GetClientConfigResponse()
        resp.client_config_str = self.model_config_str_
        return resp


class MultiLangServer(object):
    def __init__(self):
        self.bserver_ = Server()
        self.worker_num_ = 4
        self.body_size_ = 64 * 1024 * 1024
        self.concurrency_ = 100000
        self.is_multi_model_ = False  # for model ensemble

    def set_max_concurrency(self, concurrency):
        self.concurrency_ = concurrency
        self.bserver_.set_max_concurrency(concurrency)

    def set_num_threads(self, threads):
        self.worker_num_ = threads
        self.bserver_.set_num_threads(threads)

    def set_max_body_size(self, body_size):
        self.bserver_.set_max_body_size(body_size)
        if body_size >= self.body_size_:
            self.body_size_ = body_size
        else:
            print(
                "max_body_size is less than the default value; the default will be used."
            )

    def set_port(self, port):
        self.gport_ = port

    def set_reload_interval(self, interval):
        self.bserver_.set_reload_interval(interval)

    def set_op_sequence(self, op_seq):
        self.bserver_.set_op_sequence(op_seq)

    def set_op_graph(self, op_graph):
        self.bserver_.set_op_graph(op_graph)

    def set_memory_optimize(self, flag=False):
        self.bserver_.set_memory_optimize(flag)

    def set_ir_optimize(self, flag=False):
        self.bserver_.set_ir_optimize(flag)

    def set_gpuid(self, gpuid=0):
        self.bserver_.set_gpuid(gpuid)

    def load_model_config(self, server_config_paths, client_config_path=None):
        self.bserver_.load_model_config(server_config_paths)
        if client_config_path is None:
            if isinstance(server_config_paths, dict):
                self.is_multi_model_ = True
                client_config_path = '{}/serving_server_conf.prototxt'.format(
                    list(server_config_paths.items())[0][1])
            else:
                client_config_path = '{}/serving_server_conf.prototxt'.format(
                    server_config_paths)
        self.bclient_config_path_ = client_config_path

    def prepare_server(self,
                       workdir=None,
                       port=9292,
                       device="cpu",
                       cube_conf=None):
        if not self._port_is_available(port):
            raise SystemExit("Port {} is already used".format(port))
        default_port = 12000
        self.port_list_ = []
        for i in range(1000):
            if default_port + i != port and self._port_is_available(default_port
                                                                    + i):
                self.port_list_.append(default_port + i)
                break
        self.bserver_.prepare_server(
            workdir=workdir,
            port=self.port_list_[0],
            device=device,
            cube_conf=cube_conf)
        self.set_port(port)

    def _launch_brpc_service(self, bserver):
        bserver.run_server()

    def _port_is_available(self, port):
        with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
            sock.settimeout(2)
            result = sock.connect_ex(('0.0.0.0', port))
        return result != 0

    def run_server(self):
        p_bserver = Process(
            target=self._launch_brpc_service, args=(self.bserver_, ))
        p_bserver.start()
        options = [('grpc.max_send_message_length', self.body_size_),
                   ('grpc.max_receive_message_length', self.body_size_)]
        server = grpc.server(
            futures.ThreadPoolExecutor(max_workers=self.worker_num_),
            options=options,
            maximum_concurrent_rpcs=self.concurrency_)
        multi_lang_general_model_service_pb2_grpc.add_MultiLangGeneralModelServiceServicer_to_server(
            MultiLangServerServiceServicer(
                self.bclient_config_path_, self.is_multi_model_,
                ["0.0.0.0:{}".format(self.port_list_[0])]), server)
        server.add_insecure_port('[::]:{}'.format(self.gport_))
        server.start()
        p_bserver.join()
        server.wait_for_termination()
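

# MultiLangServer sketch (illustrative port/path): same workflow assembly as
# Server above, plus a gRPC front end for multi-language clients.
#   multi_server = MultiLangServer()
#   multi_server.set_op_sequence(op_seq_maker.get_op_sequence())
#   multi_server.load_model_config("uci_housing_model")
#   multi_server.prepare_server(workdir="workdir", port=9393, device="gpu")
#   multi_server.run_server()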