server.py 28.7 KB
Newer Older
Z
update  
zhangjun 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
M
MRXLT 已提交
14 15 16

import os
import tarfile
M
MRXLT 已提交
17
import socket
Z
zhangjun 已提交
18
import paddle_serving_server as paddle_serving_server
Z
zhangjun 已提交
19
from paddle_serving_server.rpc_service import MultiLangServerServiceServicer
Z
update  
zhangjun 已提交
20 21
from .proto import server_configure_pb2 as server_sdk
from .proto import general_model_config_pb2 as m_config
Z
zhangjun 已提交
22
from .proto import multi_lang_general_model_service_pb2_grpc
Z
update  
zhangjun 已提交
23
import google.protobuf.text_format
24
import time
25
from .version import version_tag, version_suffix, device_type
M
MRXLT 已提交
26
from contextlib import closing
G
guru4elephant 已提交
27
import argparse
Z
zhangjun 已提交
28

J
Jiawei Wang 已提交
29
import sys
W
wangjiawei04 已提交
30 31
if sys.platform.startswith('win') is False:
    import fcntl
M
MRXLT 已提交
32
import shutil
Z
update  
zhangjun 已提交
33
import platform
B
barrierye 已提交
34 35
import numpy as np
import grpc
B
barrierye 已提交
36
import sys
37
import collections
Z
zhangjun 已提交
38

B
barrierye 已提交
39 40 41
from multiprocessing import Pool, Process
from concurrent import futures

Z
update  
zhangjun 已提交
42

M
MRXLT 已提交
43 44
class Server(object):
    def __init__(self):
H
HexToString 已提交
45 46 47 48 49 50 51 52 53 54 55
        """
        self.model_toolkit_conf:'list'=[] # The quantity of self.model_toolkit_conf is equal to the InferOp quantity/Engine--OP
        self.model_conf:'collections.OrderedDict()' # Save the serving_server_conf.prototxt content (feed and fetch information) this is a map for multi-model in a workflow
        self.workflow_fn:'str'="workflow.prototxt" # Only one for one Service/Workflow
        self.resource_fn:'str'="resource.prototxt" # Only one for one Service,model_toolkit_fn and general_model_config_fn is recorded in this file
        self.infer_service_fn:'str'="infer_service.prototxt" # Only one for one Service,Service--Workflow
        self.model_toolkit_fn:'list'=[] # ["general_infer_0/model_toolkit.prototxt"]The quantity is equal to the InferOp quantity,Engine--OP
        self.general_model_config_fn:'list'=[] # ["general_infer_0/general_model.prototxt"]The quantity is equal to the InferOp quantity,Feed and Fetch --OP
        self.subdirectory:'list'=[] # The quantity is equal to the InferOp quantity, and name = node.name = engine.name
        self.model_config_paths:'collections.OrderedDict()' # Save the serving_server_conf.prototxt path (feed and fetch information) this is a map for multi-model in a workflow
        """
M
MRXLT 已提交
56 57
        self.server_handle_ = None
        self.infer_service_conf = None
H
HexToString 已提交
58
        self.model_toolkit_conf = []
M
MRXLT 已提交
59 60
        self.resource_conf = None
        self.memory_optimization = False
M
MRXLT 已提交
61
        self.ir_optimization = False
Z
zhangjun 已提交
62
        self.model_conf = collections.OrderedDict()
H
HexToString 已提交
63 64 65
        self.workflow_fn = "workflow.prototxt"
        self.resource_fn = "resource.prototxt"
        self.infer_service_fn = "infer_service.prototxt"
Z
zhangjun 已提交
66 67 68
        self.model_toolkit_fn = []
        self.general_model_config_fn = []
        self.subdirectory = []
W
wangjiawei04 已提交
69
        self.cube_config_fn = "cube.conf"
M
MRXLT 已提交
70 71
        self.workdir = ""
        self.max_concurrency = 0
M
MRXLT 已提交
72
        self.num_threads = 2
M
MRXLT 已提交
73
        self.port = 8080
74 75
        self.precision = "fp32"
        self.use_calib = False
M
MRXLT 已提交
76
        self.reload_interval_s = 10
M
MRXLT 已提交
77
        self.max_body_size = 64 * 1024 * 1024
M
MRXLT 已提交
78 79
        self.module_path = os.path.dirname(paddle_serving_server.__file__)
        self.cur_path = os.getcwd()
M
MRXLT 已提交
80
        self.use_local_bin = False
Z
zhangjun 已提交
81
        self.mkl_flag = False
Z
zhangjun 已提交
82
        self.device = "cpu"
M
MRXLT 已提交
83
        self.gpuid = 0
M
add trt  
MRXLT 已提交
84
        self.use_trt = False
Z
zhangjun 已提交
85 86
        self.use_lite = False
        self.use_xpu = False
Z
zhangjun 已提交
87
        self.model_config_paths = collections.OrderedDict()
88 89
        self.product_name = None
        self.container_id = None
M
MRXLT 已提交
90

Z
zhangjun 已提交
91 92 93 94 95
    def get_fetch_list(self, infer_node_idx=-1):
        fetch_names = [
            var.alias_name
            for var in list(self.model_conf.values())[infer_node_idx].fetch_var
        ]
B
fix cpu  
barriery 已提交
96 97
        return fetch_names

M
MRXLT 已提交
98 99 100 101 102 103
    def set_max_concurrency(self, concurrency):
        self.max_concurrency = concurrency

    def set_num_threads(self, threads):
        self.num_threads = threads

M
MRXLT 已提交
104 105 106 107 108 109 110 111
    def set_max_body_size(self, body_size):
        if body_size >= self.max_body_size:
            self.max_body_size = body_size
        else:
            print(
                "max_body_size is less than default value, will use default value in service."
            )

112 113 114
    def use_encryption_model(self, flag=False):
        self.encryption_model = flag

M
MRXLT 已提交
115 116 117
    def set_port(self, port):
        self.port = port

118 119 120 121 122 123
    def set_precision(self, precision="fp32"):
        self.precision = precision

    def set_use_calib(self, use_calib=False):
        self.use_calib = use_calib

M
MRXLT 已提交
124 125 126 127 128 129
    def set_reload_interval(self, interval):
        self.reload_interval_s = interval

    def set_op_sequence(self, op_seq):
        self.workflow_conf = op_seq

B
barrierye 已提交
130 131 132
    def set_op_graph(self, op_graph):
        self.workflow_conf = op_graph

M
MRXLT 已提交
133 134 135
    def set_memory_optimize(self, flag=False):
        self.memory_optimization = flag

M
MRXLT 已提交
136 137 138
    def set_ir_optimize(self, flag=False):
        self.ir_optimization = flag

139 140 141 142 143 144 145 146 147 148
    def set_product_name(self, product_name=None):
        if product_name == None:
            raise ValueError("product_name can't be None.")
        self.product_name = product_name

    def set_container_id(self, container_id):
        if container_id == None:
            raise ValueError("container_id can't be None.")
        self.container_id = container_id

M
MRXLT 已提交
149 150 151 152
    def check_local_bin(self):
        if "SERVING_BIN" in os.environ:
            self.use_local_bin = True
            self.bin_path = os.environ["SERVING_BIN"]
M
MRXLT 已提交
153

M
MRXLT 已提交
154
    def check_cuda(self):
M
MRXLT 已提交
155 156 157
        if os.system("ls /dev/ | grep nvidia > /dev/null") == 0:
            pass
        else:
M
MRXLT 已提交
158
            raise SystemExit(
M
MRXLT 已提交
159
                "GPU not found, please check your environment or use cpu version by \"pip install paddle_serving_server\""
M
MRXLT 已提交
160 161
            )

Z
zhangjun 已提交
162 163 164
    def set_device(self, device="cpu"):
        self.device = device

M
MRXLT 已提交
165 166 167
    def set_gpuid(self, gpuid=0):
        self.gpuid = gpuid

M
bug fix  
MRXLT 已提交
168
    def set_trt(self):
M
add trt  
MRXLT 已提交
169 170
        self.use_trt = True

Z
zhangjun 已提交
171 172 173 174 175 176
    def set_lite(self):
        self.use_lite = True

    def set_xpu(self):
        self.use_xpu = True

H
HexToString 已提交
177
    def _prepare_engine(self, model_config_paths, device, use_encryption_model):
M
MRXLT 已提交
178
        if self.model_toolkit_conf == None:
H
HexToString 已提交
179
            self.model_toolkit_conf = []
M
MRXLT 已提交
180

B
barrierye 已提交
181 182 183 184
        for engine_name, model_config_path in model_config_paths.items():
            engine = server_sdk.EngineDesc()
            engine.name = engine_name
            # engine.reloadable_meta = model_config_path + "/fluid_time_file"
185
            engine.reloadable_meta = model_config_path + "/fluid_time_file"
B
barrierye 已提交
186 187 188 189 190
            os.system("touch {}".format(engine.reloadable_meta))
            engine.reloadable_type = "timestamp_ne"
            engine.runtime_thread_num = 0
            engine.batch_infer_size = 0
            engine.enable_batch_align = 0
Z
update  
zhangjun 已提交
191
            engine.model_dir = model_config_path
B
barrierye 已提交
192
            engine.enable_memory_optimization = self.memory_optimization
M
MRXLT 已提交
193
            engine.enable_ir_optimization = self.ir_optimization
M
add trt  
MRXLT 已提交
194
            engine.use_trt = self.use_trt
Z
update  
zhangjun 已提交
195 196
            engine.use_lite = self.use_lite
            engine.use_xpu = self.use_xpu
Z
zhangjun 已提交
197 198 199 200
            engine.use_gpu = False
            if self.device == "gpu":
                engine.use_gpu = True

Z
fix  
zhangjun 已提交
201
            if os.path.exists('{}/__params__'.format(model_config_path)):
Z
update  
zhangjun 已提交
202
                engine.combined_model = True
Z
fix  
zhangjun 已提交
203 204
            else:
                engine.combined_model = False
Z
update  
zhangjun 已提交
205 206
            if use_encryption_model:
                engine.encrypted_model = True
Z
fix  
zhangjun 已提交
207
            engine.type = "PADDLE_INFER"
H
HexToString 已提交
208 209
            self.model_toolkit_conf.append(server_sdk.ModelToolkitConf())
            self.model_toolkit_conf[-1].engines.extend([engine])
M
MRXLT 已提交
210 211 212 213 214 215 216 217 218 219

    def _prepare_infer_service(self, port):
        if self.infer_service_conf == None:
            self.infer_service_conf = server_sdk.InferServiceConf()
            self.infer_service_conf.port = port
            infer_service = server_sdk.InferService()
            infer_service.name = "GeneralModelService"
            infer_service.workflows.extend(["workflow1"])
            self.infer_service_conf.services.extend([infer_service])

M
MRXLT 已提交
220
    def _prepare_resource(self, workdir, cube_conf):
221
        self.workdir = workdir
M
MRXLT 已提交
222 223
        if self.resource_conf == None:
            self.resource_conf = server_sdk.ResourceConf()
Z
zhangjun 已提交
224 225
            for idx, op_general_model_config_fn in enumerate(
                    self.general_model_config_fn):
H
HexToString 已提交
226
                with open("{}/{}".format(workdir, op_general_model_config_fn),
Z
zhangjun 已提交
227
                          "w") as fout:
H
HexToString 已提交
228 229 230 231 232 233 234 235 236 237 238 239 240 241
                    fout.write(str(list(self.model_conf.values())[idx]))
                for workflow in self.workflow_conf.workflows:
                    for node in workflow.nodes:
                        if "dist_kv" in node.name:
                            self.resource_conf.cube_config_path = workdir
                            self.resource_conf.cube_config_file = self.cube_config_fn
                            if cube_conf == None:
                                raise ValueError(
                                    "Please set the path of cube.conf while use dist_kv op."
                                )
                            shutil.copy(cube_conf, workdir)
                            if "quant" in node.name:
                                self.resource_conf.cube_quant_bits = 8
                self.resource_conf.model_toolkit_path.extend([workdir])
Z
zhangjun 已提交
242 243
                self.resource_conf.model_toolkit_file.extend(
                    [self.model_toolkit_fn[idx]])
H
HexToString 已提交
244
                self.resource_conf.general_model_path.extend([workdir])
Z
zhangjun 已提交
245 246
                self.resource_conf.general_model_file.extend(
                    [op_general_model_config_fn])
H
HexToString 已提交
247 248 249 250 251
                #TODO:figure out the meaning of product_name and container_id.
                if self.product_name != None:
                    self.resource_conf.auth_product_name = self.product_name
                if self.container_id != None:
                    self.resource_conf.auth_container_id = self.container_id
M
MRXLT 已提交
252 253 254 255 256

    def _write_pb_str(self, filepath, pb_obj):
        with open(filepath, "w") as fout:
            fout.write(str(pb_obj))

H
HexToString 已提交
257
    def load_model_config(self, model_config_paths_args):
B
barrierye 已提交
258 259 260
        # At present, Serving needs to configure the model path in
        # the resource.prototxt file to determine the input and output
        # format of the workflow. To ensure that the input and output
B
barrierye 已提交
261
        # of multiple models are the same.
H
HexToString 已提交
262 263 264 265 266 267 268
        if isinstance(model_config_paths_args, str):
            model_config_paths_args = [model_config_paths_args]

        for single_model_config in model_config_paths_args:
            if os.path.isdir(single_model_config):
                pass
            elif os.path.isfile(single_model_config):
Z
zhangjun 已提交
269 270 271
                raise ValueError(
                    "The input of --model should be a dir not file.")

H
HexToString 已提交
272
        if isinstance(model_config_paths_args, list):
B
barrierye 已提交
273
            # If there is only one model path, use the default infer_op.
M
MRXLT 已提交
274
            # Because there are several infer_op type, we need to find
B
barrierye 已提交
275
            # it from workflow_conf.
H
HexToString 已提交
276
            default_engine_types = [
Z
zhangjun 已提交
277 278 279 280
                'GeneralInferOp',
                'GeneralDistKVInferOp',
                'GeneralDistKVQuantInferOp',
                'GeneralDetectionOp',
B
barrierye 已提交
281
            ]
H
HexToString 已提交
282 283 284
            # now only support single-workflow.
            # TODO:support multi-workflow
            model_config_paths_list_idx = 0
B
barrierye 已提交
285
            for node in self.workflow_conf.workflows[0].nodes:
H
HexToString 已提交
286 287 288 289 290
                if node.type in default_engine_types:
                    if node.name is None:
                        raise Exception(
                            "You have set the engine_name of Op. Please use the form {op: model_path} to configure model path"
                        )
Z
zhangjun 已提交
291

H
HexToString 已提交
292
                    f = open("{}/serving_server_conf.prototxt".format(
Z
zhangjun 已提交
293 294 295 296 297 298 299 300 301 302 303 304
                        model_config_paths_args[model_config_paths_list_idx]),
                             'r')
                    self.model_conf[
                        node.name] = google.protobuf.text_format.Merge(
                            str(f.read()), m_config.GeneralModelConfig())
                    self.model_config_paths[
                        node.name] = model_config_paths_args[
                            model_config_paths_list_idx]
                    self.general_model_config_fn.append(
                        node.name + "/general_model.prototxt")
                    self.model_toolkit_fn.append(node.name +
                                                 "/model_toolkit.prototxt")
H
HexToString 已提交
305 306
                    self.subdirectory.append(node.name)
                    model_config_paths_list_idx += 1
Z
zhangjun 已提交
307 308
                    if model_config_paths_list_idx == len(
                            model_config_paths_args):
H
HexToString 已提交
309 310 311 312 313
                        break
        #Right now, this is not useful.
        elif isinstance(model_config_paths_args, dict):
            self.model_config_paths = collections.OrderedDict()
            for node_str, path in model_config_paths_args.items():
B
barrierye 已提交
314 315 316 317 318
                node = server_sdk.DAGNode()
                google.protobuf.text_format.Parse(node_str, node)
                self.model_config_paths[node.name] = path
            print("You have specified multiple model paths, please ensure "
                  "that the input and output of multiple models are the same.")
H
HexToString 已提交
319 320
            f = open("{}/serving_server_conf.prototxt".format(path), 'r')
            self.model_conf[node.name] = google.protobuf.text_format.Merge(
Z
zhangjun 已提交
321
                str(f.read()), m_config.GeneralModelConfig())
B
barrierye 已提交
322
        else:
Z
zhangjun 已提交
323 324 325 326
            raise Exception(
                "The type of model_config_paths must be str or list or "
                "dict({op: model_path}), not {}.".format(
                    type(model_config_paths_args)))
M
MRXLT 已提交
327 328
        # check config here
        # print config here
Z
update  
zhangjun 已提交
329

Z
zhangjun 已提交
330 331 332 333 334 335 336 337 338 339 340
    def use_mkl(self, flag):
        self.mkl_flag = flag

    def get_device_version(self):
        avx_flag = False
        mkl_flag = self.mkl_flag
        r = os.system("cat /proc/cpuinfo | grep avx > /dev/null 2>&1")
        if r == 0:
            avx_flag = True
        if avx_flag:
            if mkl_flag:
Z
update  
zhangjun 已提交
341
                device_version = "cpu-avx-mkl"
Z
zhangjun 已提交
342
            else:
Z
update  
zhangjun 已提交
343
                device_version = "cpu-avx-openblas"
Z
zhangjun 已提交
344 345 346 347 348
        else:
            if mkl_flag:
                print(
                    "Your CPU does not support AVX, server will running with noavx-openblas mode."
                )
Z
update  
zhangjun 已提交
349 350 351 352 353 354 355 356 357 358 359 360 361
            device_version = "cpu-noavx-openblas"
        return device_version

    def get_serving_bin_name(self):
        if device_type == "0":
            device_version = self.get_device_version()
        elif device_type == "1":
            if version_suffix == "101" or version_suffix == "102":
                device_version = "gpu-" + version_suffix
            else:
                device_version = "gpu-cuda" + version_suffix
        elif device_type == "2":
            device_version = "xpu-" + platform.machine()
Z
zhangjun 已提交
362
        return device_version
M
MRXLT 已提交
363 364 365 366

    def download_bin(self):
        os.chdir(self.module_path)
        need_download = False
M
MRXLT 已提交
367 368 369

        #acquire lock
        version_file = open("{}/version.py".format(self.module_path), "r")
Z
update  
zhangjun 已提交
370

Z
fix  
zhangjun 已提交
371
        folder_name = "serving-%s-%s" % (self.get_serving_bin_name(),
372
                                         version_tag)
Z
fix  
zhangjun 已提交
373 374 375
        tar_name = "%s.tar.gz" % folder_name
        bin_url = "https://paddle-serving.bj.bcebos.com/bin/%s" % tar_name

376 377 378 379
        self.server_path = os.path.join(self.module_path, folder_name)

        download_flag = "{}/{}.is_download".format(self.module_path,
                                                   folder_name)
M
MRXLT 已提交
380 381 382

        fcntl.flock(version_file, fcntl.LOCK_EX)

383 384 385 386 387
        if os.path.exists(download_flag):
            os.chdir(self.cur_path)
            self.bin_path = self.server_path + "/serving"
            return

M
MRXLT 已提交
388
        if not os.path.exists(self.server_path):
389 390
            os.system("touch {}/{}.is_download".format(self.module_path,
                                                       folder_name))
M
MRXLT 已提交
391
            print('Frist time run, downloading PaddleServing components ...')
M
MRXLT 已提交
392

M
MRXLT 已提交
393 394 395 396
            r = os.system('wget ' + bin_url + ' --no-check-certificate')
            if r != 0:
                if os.path.exists(tar_name):
                    os.remove(tar_name)
M
MRXLT 已提交
397
                raise SystemExit(
T
TeslaZhao 已提交
398 399
                    'Download failed, please check your network or permission of {}.'
                    .format(self.module_path))
M
MRXLT 已提交
400 401 402 403 404 405 406 407 408
            else:
                try:
                    print('Decompressing files ..')
                    tar = tarfile.open(tar_name)
                    tar.extractall()
                    tar.close()
                except:
                    if os.path.exists(exe_path):
                        os.remove(exe_path)
M
MRXLT 已提交
409
                    raise SystemExit(
T
TeslaZhao 已提交
410 411
                        'Decompressing failed, please check your permission of {} or disk space left.'
                        .format(self.module_path))
M
MRXLT 已提交
412 413
                finally:
                    os.remove(tar_name)
M
MRXLT 已提交
414
        #release lock
B
barrierye 已提交
415
        version_file.close()
M
MRXLT 已提交
416 417
        os.chdir(self.cur_path)
        self.bin_path = self.server_path + "/serving"
Z
update  
zhangjun 已提交
418

M
MRXLT 已提交
419 420 421 422
    def prepare_server(self,
                       workdir=None,
                       port=9292,
                       device="cpu",
W
wangjiawei04 已提交
423
                       use_encryption_model=False,
M
MRXLT 已提交
424
                       cube_conf=None):
M
MRXLT 已提交
425 426
        if workdir == None:
            workdir = "./tmp"
Z
zhangjun 已提交
427
            os.system("mkdir -p {}".format(workdir))
M
MRXLT 已提交
428
        else:
Z
zhangjun 已提交
429
            os.system("mkdir -p {}".format(workdir))
H
HexToString 已提交
430
        for subdir in self.subdirectory:
431
            os.system("mkdir -p {}/{}".format(workdir, subdir))
H
HexToString 已提交
432 433
            os.system("touch {}/{}/fluid_time_file".format(workdir, subdir))

M
MRXLT 已提交
434
        if not self.port_is_available(port):
G
gongweibao 已提交
435
            raise SystemExit("Port {} is already used".format(port))
M
MRXLT 已提交
436

G
guru4elephant 已提交
437
        self.set_port(port)
M
MRXLT 已提交
438
        self._prepare_resource(workdir, cube_conf)
H
HexToString 已提交
439 440
        self._prepare_engine(self.model_config_paths, device,
                             use_encryption_model)
M
MRXLT 已提交
441 442 443 444 445
        self._prepare_infer_service(port)
        self.workdir = workdir

        infer_service_fn = "{}/{}".format(workdir, self.infer_service_fn)
        self._write_pb_str(infer_service_fn, self.infer_service_conf)
H
HexToString 已提交
446 447

        workflow_fn = "{}/{}".format(workdir, self.workflow_fn)
M
MRXLT 已提交
448
        self._write_pb_str(workflow_fn, self.workflow_conf)
H
HexToString 已提交
449 450

        resource_fn = "{}/{}".format(workdir, self.resource_fn)
M
MRXLT 已提交
451
        self._write_pb_str(resource_fn, self.resource_conf)
H
HexToString 已提交
452

Z
zhangjun 已提交
453
        for idx, single_model_toolkit_fn in enumerate(self.model_toolkit_fn):
H
HexToString 已提交
454 455
            model_toolkit_fn = "{}/{}".format(workdir, single_model_toolkit_fn)
            self._write_pb_str(model_toolkit_fn, self.model_toolkit_conf[idx])
M
MRXLT 已提交
456

M
MRXLT 已提交
457
    def port_is_available(self, port):
M
MRXLT 已提交
458 459
        with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
            sock.settimeout(2)
460
            result = sock.connect_ex(('0.0.0.0', port))
M
MRXLT 已提交
461 462 463 464 465
        if result != 0:
            return True
        else:
            return False

M
MRXLT 已提交
466 467 468
    def run_server(self):
        # just run server with system command
        # currently we do not load cube
M
MRXLT 已提交
469
        self.check_local_bin()
M
MRXLT 已提交
470 471
        if not self.use_local_bin:
            self.download_bin()
B
fix bug  
barrierye 已提交
472 473 474
            # wait for other process to download server bin
            while not os.path.exists(self.server_path):
                time.sleep(1)
M
MRXLT 已提交
475 476
        else:
            print("Use local bin : {}".format(self.bin_path))
Z
zhangjun 已提交
477
        #self.check_cuda()
Z
zhangjun 已提交
478 479
        # Todo: merge CPU and GPU code, remove device to model_toolkit
        if self.device == "cpu" or self.device == "arm":
Z
zhangjun 已提交
480 481 482 483 484 485 486
            command = "{} " \
                      "-enable_model_toolkit " \
                      "-inferservice_path {} " \
                      "-inferservice_file {} " \
                      "-max_concurrency {} " \
                      "-num_threads {} " \
                      "-port {} " \
487 488
                      "-precision {} " \
                      "-use_calib {} " \
Z
zhangjun 已提交
489 490 491 492 493 494 495 496 497 498 499 500 501
                      "-reload_interval_s {} " \
                      "-resource_path {} " \
                      "-resource_file {} " \
                      "-workflow_path {} " \
                      "-workflow_file {} " \
                      "-bthread_concurrency {} " \
                      "-max_body_size {} ".format(
                          self.bin_path,
                          self.workdir,
                          self.infer_service_fn,
                          self.max_concurrency,
                          self.num_threads,
                          self.port,
502 503
                          self.precision,
                          self.use_calib,
Z
zhangjun 已提交
504 505 506 507 508 509 510 511 512 513 514 515 516 517 518
                          self.reload_interval_s,
                          self.workdir,
                          self.resource_fn,
                          self.workdir,
                          self.workflow_fn,
                          self.num_threads,
                          self.max_body_size)
        else:
            command = "{} " \
                      "-enable_model_toolkit " \
                      "-inferservice_path {} " \
                      "-inferservice_file {} " \
                      "-max_concurrency {} " \
                      "-num_threads {} " \
                      "-port {} " \
519 520
                      "-precision {} " \
                      "-use_calib {} " \
Z
zhangjun 已提交
521 522 523 524 525 526 527 528 529 530 531 532 533 534
                      "-reload_interval_s {} " \
                      "-resource_path {} " \
                      "-resource_file {} " \
                      "-workflow_path {} " \
                      "-workflow_file {} " \
                      "-bthread_concurrency {} " \
                      "-gpuid {} " \
                      "-max_body_size {} ".format(
                          self.bin_path,
                          self.workdir,
                          self.infer_service_fn,
                          self.max_concurrency,
                          self.num_threads,
                          self.port,
535 536
                          self.precision,
                          self.use_calib,
Z
zhangjun 已提交
537 538 539 540 541 542 543 544
                          self.reload_interval_s,
                          self.workdir,
                          self.resource_fn,
                          self.workdir,
                          self.workflow_fn,
                          self.num_threads,
                          self.gpuid,
                          self.max_body_size)
M
MRXLT 已提交
545 546
        print("Going to Run Comand")
        print(command)
547

M
MRXLT 已提交
548
        os.system(command)
B
barrierye 已提交
549

Z
zhangjun 已提交
550

B
barrierye 已提交
551
class MultiLangServer(object):
    """Serve a brpc ``Server`` behind a gRPC front end for multi-language clients.

    ``run_server`` starts the wrapped ``Server`` (``self.bserver_``) in a child
    process on an internal port picked by ``prepare_server``. It then starts a
    gRPC server on the user-facing port (``self.gport_``), which serves
    ``MultiLangServerServiceServicer`` pointed at that internal brpc endpoint.
    """

    def __init__(self):
        self.bserver_ = Server()
        # Thread-pool size of the gRPC server (also forwarded to brpc by
        # set_num_threads).
        self.worker_num_ = 4
        # gRPC max send/receive message length in bytes (64 MiB).
        self.body_size_ = 64 * 1024 * 1024
        # gRPC maximum_concurrent_rpcs.
        self.concurrency_ = 100000
        self.is_multi_model_ = False  # for model ensemble, which is not useful right now.
        # Values recorded by the setters below, initialised to the setters'
        # own defaults so the attributes always exist.
        self.device = "cpu"
        self.encryption_model = False
        self.precision = "fp32"
        self.use_calib = False

    def set_max_concurrency(self, concurrency):
        """Set the gRPC concurrent-RPC limit and forward it to the brpc server."""
        self.concurrency_ = concurrency
        self.bserver_.set_max_concurrency(concurrency)

    def set_device(self, device="cpu"):
        """Record the device name.

        NOTE(review): this value is not read anywhere in this class;
        ``prepare_server`` uses its own ``device`` argument instead.
        """
        self.device = device

    def set_num_threads(self, threads):
        """Set the gRPC worker-pool size and forward it to the brpc server."""
        self.worker_num_ = threads
        self.bserver_.set_num_threads(threads)

    def set_max_body_size(self, body_size):
        """Set the maximum message body size in bytes.

        The value is always forwarded to the brpc server. The gRPC limit
        ``self.body_size_`` is only ever raised: a value below the current limit
        (initially 64 MiB) leaves it unchanged and prints a notice.
        """
        self.bserver_.set_max_body_size(body_size)
        if body_size >= self.body_size_:
            self.body_size_ = body_size
        else:
            print(
                "max_body_size is less than default value, will use default value in service."
            )

    def use_encryption_model(self, flag=False):
        """Record whether an encrypted model is used.

        NOTE(review): this value is not read anywhere in this class;
        ``prepare_server`` takes its own ``use_encryption_model`` argument.
        """
        self.encryption_model = flag

    def set_port(self, port):
        """Set the port the gRPC front end listens on (used by ``run_server``)."""
        self.gport_ = port

    def set_precision(self, precision="fp32"):
        """Set the inference precision and forward it to the brpc server.

        ``Server.run_server`` reads ``precision`` from the wrapped server
        instance when building its command line, so the value must be set
        there, not only on this wrapper.
        """
        self.precision = precision
        self.bserver_.precision = precision

    def set_use_calib(self, use_calib=False):
        """Set the ``use_calib`` flag and forward it to the brpc server.

        Like ``precision``, ``use_calib`` is read from the wrapped server
        instance by ``Server.run_server``.
        """
        self.use_calib = use_calib
        self.bserver_.use_calib = use_calib

    def set_reload_interval(self, interval):
        """Forward the model reload interval to the brpc server."""
        self.bserver_.set_reload_interval(interval)

    def set_op_sequence(self, op_seq):
        """Forward the op sequence to the brpc server."""
        self.bserver_.set_op_sequence(op_seq)

    def set_op_graph(self, op_graph):
        """Forward the op graph to the brpc server."""
        self.bserver_.set_op_graph(op_graph)

    def use_mkl(self, flag):
        """Forward the MKL flag to the brpc server."""
        self.bserver_.use_mkl(flag)

    def set_memory_optimize(self, flag=False):
        """Forward the memory-optimize flag to the brpc server."""
        self.bserver_.set_memory_optimize(flag)

    def set_ir_optimize(self, flag=False):
        """Forward the IR-optimize flag to the brpc server."""
        self.bserver_.set_ir_optimize(flag)

    def set_gpuid(self, gpuid=0):
        """Forward the GPU id to the brpc server."""
        self.bserver_.set_gpuid(gpuid)

    def load_model_config(self,
                          server_config_dir_paths,
                          client_config_path=None):
        """Validate model paths, load them into the brpc server, record client configs.

        Args:
            server_config_dir_paths: a model directory path (str) or a list of
                them.
            client_config_path: client config path(s), str or list. When None,
                the server model directories are used.

        Raises:
            Exception: if either argument is neither a str nor a list.
            ValueError: if a server path is a regular file instead of a dir.
            Warning: if the client and server path lists differ in length.

        All validation happens before the brpc server is touched, so a rejected
        call leaves the server unchanged.
        """
        if isinstance(server_config_dir_paths, str):
            server_config_dir_paths = [server_config_dir_paths]
        elif isinstance(server_config_dir_paths, list):
            # Copy so the caller's list is not shared with our state.
            server_config_dir_paths = list(server_config_dir_paths)
        else:
            raise Exception("The type of model_config_paths must be str or list"
                            ", not {}.".format(type(server_config_dir_paths)))

        for single_model_config in server_config_dir_paths:
            if os.path.isfile(single_model_config):
                raise ValueError(
                    "The input of --model should be a dir not file.")

        if client_config_path is None:
            # Only str/list inputs get past the check above (dict-based model
            # ensembles are not supported), so default to the server dirs.
            self.is_multi_model_ = False
            client_config_path = list(server_config_dir_paths)
        elif isinstance(client_config_path, str):
            client_config_path = [client_config_path]
        elif isinstance(client_config_path, list):
            client_config_path = list(client_config_path)
        else:  # dict is not supported right now.
            raise Exception(
                "The type of client_config_path must be str or list, "
                "not {}.".format(type(client_config_path)))

        if len(client_config_path) != len(server_config_dir_paths):
            raise Warning(
                "The len(client_config_path) is {}, != len(server_config_dir_paths) {}."
                .format(len(client_config_path), len(server_config_dir_paths)))

        self.bserver_.load_model_config(server_config_dir_paths)
        self.bclient_config_path_list = client_config_path

    def prepare_server(self,
                       workdir=None,
                       port=9292,
                       device="cpu",
                       use_encryption_model=False,
                       cube_conf=None):
        """Reserve ports and prepare the wrapped brpc server.

        ``port`` is the user-facing gRPC port. The brpc server is given the
        first free port in [12000, 13000) other than ``port``.

        Raises:
            SystemExit: if ``port`` is already in use, or no internal port
                for the brpc service is free.
        """
        if not self._port_is_available(port):
            raise SystemExit("Port {} is already used".format(port))
        default_port = 12000
        port_search_range = 1000
        self.port_list_ = []
        for i in range(port_search_range):
            candidate = default_port + i
            if candidate != port and self._port_is_available(candidate):
                self.port_list_.append(candidate)
                break
        if not self.port_list_:
            # Without this, the indexing below would fail with a bare IndexError.
            raise SystemExit(
                "No available port for the brpc service in [{}, {}).".format(
                    default_port, default_port + port_search_range))
        self.bserver_.prepare_server(
            workdir=workdir,
            port=self.port_list_[0],
            device=device,
            use_encryption_model=use_encryption_model,
            cube_conf=cube_conf)
        self.set_port(port)

    def _launch_brpc_service(self, bserver):
        """Child-process entry point: run the brpc server."""
        bserver.run_server()

    def _port_is_available(self, port):
        """Return True if no TCP connection to 0.0.0.0:``port`` succeeds (2 s timeout)."""
        with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
            sock.settimeout(2)
            result = sock.connect_ex(('0.0.0.0', port))
        return result != 0

    def run_server(self):
        """Start the brpc child process and the gRPC front end, then block.

        Blocks until the brpc child process exits, then waits for the gRPC
        server to terminate.
        """
        p_bserver = Process(
            target=self._launch_brpc_service, args=(self.bserver_, ))
        p_bserver.start()
        try:
            options = [('grpc.max_send_message_length', self.body_size_),
                       ('grpc.max_receive_message_length', self.body_size_)]
            server = grpc.server(
                futures.ThreadPoolExecutor(max_workers=self.worker_num_),
                options=options,
                maximum_concurrent_rpcs=self.concurrency_)
            multi_lang_general_model_service_pb2_grpc.add_MultiLangGeneralModelServiceServicer_to_server(
                MultiLangServerServiceServicer(
                    self.bclient_config_path_list, self.is_multi_model_,
                    ["0.0.0.0:{}".format(self.port_list_[0])]), server)
            server.add_insecure_port('[::]:{}'.format(self.gport_))
            server.start()
        except Exception:
            # Don't leave the brpc child running if the gRPC front end failed.
            p_bserver.terminate()
            p_bserver.join()
            raise
        p_bserver.join()
        server.wait_for_termination()