# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import tarfile
import socket
import paddle_serving_server
from paddle_serving_server.rpc_service import MultiLangServerServiceServicer
from .proto import server_configure_pb2 as server_sdk
from .proto import general_model_config_pb2 as m_config
from .proto import multi_lang_general_model_service_pb2_grpc
import google.protobuf.text_format
import time
from .version import serving_server_version, version_suffix, device_type
from contextlib import closing
import argparse

import sys
if not sys.platform.startswith('win'):
    import fcntl
import shutil
import platform
import numpy as np
import grpc
import collections

from multiprocessing import Pool, Process
from concurrent import futures


class Server(object):
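    """Configures and launches the PaddleServing brpc server binary.

    The class collects workflow, resource, and model-toolkit settings,
    writes them out as prototxt files under a working directory, and
    then runs the serving binary with the matching command-line flags.
    """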
    def __init__(self):
        self.server_handle_ = None
        self.infer_service_conf = None
        self.model_toolkit_conf = []  # one entry per InferOp (Engine--Op)
        self.resource_conf = None
        self.memory_optimization = False
        self.ir_optimization = False
        # Maps node name to serving_server_conf.prototxt content (feed and
        # fetch information); supports multiple models in one workflow.
        self.model_conf = collections.OrderedDict()
        self.workflow_fn = "workflow.prototxt"  # only one per Service (Workflow--Op)
        self.resource_fn = "resource.prototxt"  # only one per Service; records model_toolkit_fn and general_model_config_fn
        self.infer_service_fn = "infer_service.prototxt"  # only one per Service (Service--Workflow)
        # e.g. ["general_infer_0/model_toolkit.prototxt"]; one entry per InferOp (Engine--Op)
        self.model_toolkit_fn = []
        # e.g. ["general_infer_0/general_model.prototxt"]; one entry per InferOp (feed/fetch Op)
        self.general_model_config_fn = []
        # one entry per InferOp, where name = node.name = engine.name
        self.subdirectory = []
        self.cube_config_fn = "cube.conf"
        self.workdir = ""
        self.max_concurrency = 0
        self.num_threads = 2
        self.port = 8080
        self.reload_interval_s = 10
        self.max_body_size = 64 * 1024 * 1024
        self.module_path = os.path.dirname(paddle_serving_server.__file__)
        self.cur_path = os.getcwd()
        self.use_local_bin = False
        self.mkl_flag = False
        self.device = "cpu"
        self.gpuid = 0
        self.use_trt = False
        self.use_lite = False
        self.use_xpu = False
        # Maps node name to serving_server_conf.prototxt path (feed and
        # fetch information); supports multiple models in one workflow.
        self.model_config_paths = collections.OrderedDict()
        self.product_name = None
        self.container_id = None

    def get_fetch_list(self, infer_node_idx=-1):
        fetch_names = [
            var.alias_name
            for var in list(self.model_conf.values())[infer_node_idx].fetch_var
        ]
        return fetch_names

    def set_max_concurrency(self, concurrency):
        self.max_concurrency = concurrency

    def set_num_threads(self, threads):
        self.num_threads = threads

    def set_max_body_size(self, body_size):
        if body_size >= self.max_body_size:
            self.max_body_size = body_size
        else:
            print(
                "max_body_size is less than default value, will use default value in service."
            )

    def use_encryption_model(self, flag=False):
        self.encryption_model = flag

    def set_port(self, port):
        self.port = port

    def set_reload_interval(self, interval):
        self.reload_interval_s = interval

    def set_op_sequence(self, op_seq):
        self.workflow_conf = op_seq

    def set_op_graph(self, op_graph):
        self.workflow_conf = op_graph

    def set_memory_optimize(self, flag=False):
        self.memory_optimization = flag

    def set_ir_optimize(self, flag=False):
        self.ir_optimization = flag

    def set_product_name(self, product_name=None):
        if product_name is None:
            raise ValueError("product_name can't be None.")
        self.product_name = product_name

    def set_container_id(self, container_id):
        if container_id is None:
            raise ValueError("container_id can't be None.")
        self.container_id = container_id

    def check_local_bin(self):
        if "SERVING_BIN" in os.environ:
            self.use_local_bin = True
            self.bin_path = os.environ["SERVING_BIN"]

    def check_cuda(self):
        if os.system("ls /dev/ | grep nvidia > /dev/null") != 0:
            raise SystemExit(
                "GPU not found, please check your environment or use the cpu version by \"pip install paddle_serving_server\""
            )

    def set_device(self, device="cpu"):
        self.device = device

    def set_gpuid(self, gpuid=0):
        self.gpuid = gpuid

    def set_trt(self):
        self.use_trt = True

    def set_lite(self):
        self.use_lite = True

    def set_xpu(self):
        self.use_xpu = True

    def _prepare_engine(self, model_config_paths, device, use_encryption_model):
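        """Build one EngineDesc per model and append it to
        self.model_toolkit_conf (one ModelToolkitConf per engine)."""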
        if self.model_toolkit_conf is None:
            self.model_toolkit_conf = []

        for engine_name, model_config_path in model_config_paths.items():
            engine = server_sdk.EngineDesc()
            engine.name = engine_name
            engine.reloadable_meta = model_config_path + "/fluid_time_file"
            os.system("touch {}".format(engine.reloadable_meta))
            engine.reloadable_type = "timestamp_ne"
            engine.runtime_thread_num = 0
            engine.batch_infer_size = 0
            engine.enable_batch_align = 0
            engine.model_dir = model_config_path
            engine.enable_memory_optimization = self.memory_optimization
            engine.enable_ir_optimization = self.ir_optimization
            engine.use_trt = self.use_trt
            engine.use_lite = self.use_lite
            engine.use_xpu = self.use_xpu
            engine.use_gpu = False
            if self.device == "gpu":
                engine.use_gpu = True

            if os.path.exists('{}/__params__'.format(model_config_path)):
                engine.combined_model = True
            else:
                engine.combined_model = False
            if use_encryption_model:
                engine.encrypted_model = True
            engine.type = "PADDLE_INFER"
            self.model_toolkit_conf.append(server_sdk.ModelToolkitConf())
            self.model_toolkit_conf[-1].engines.extend([engine])

    def _prepare_infer_service(self, port):
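        """Build an InferServiceConf that binds "GeneralModelService" to
        "workflow1" on the given port."""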
        if self.infer_service_conf is None:
            self.infer_service_conf = server_sdk.InferServiceConf()
            self.infer_service_conf.port = port
            infer_service = server_sdk.InferService()
            infer_service.name = "GeneralModelService"
            infer_service.workflows.extend(["workflow1"])
            self.infer_service_conf.services.extend([infer_service])

    def _prepare_resource(self, workdir, cube_conf):
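        """Write each model's general_model config under workdir and build
        a ResourceConf, including the cube config for dist_kv ops."""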
        self.workdir = workdir
        if self.resource_conf is None:
            self.resource_conf = server_sdk.ResourceConf()
            for idx, op_general_model_config_fn in enumerate(
                    self.general_model_config_fn):
                with open("{}/{}".format(workdir, op_general_model_config_fn),
                          "w") as fout:
                    fout.write(str(list(self.model_conf.values())[idx]))
                for workflow in self.workflow_conf.workflows:
                    for node in workflow.nodes:
                        if "dist_kv" in node.name:
                            self.resource_conf.cube_config_path = workdir
                            self.resource_conf.cube_config_file = self.cube_config_fn
                            if cube_conf is None:
                                raise ValueError(
                                    "Please set the path of cube.conf when using the dist_kv op."
                                )
                            shutil.copy(cube_conf, workdir)
                            if "quant" in node.name:
                                self.resource_conf.cube_quant_bits = 8
                self.resource_conf.model_toolkit_path.extend([workdir])
                self.resource_conf.model_toolkit_file.extend(
                    [self.model_toolkit_fn[idx]])
                self.resource_conf.general_model_path.extend([workdir])
                self.resource_conf.general_model_file.extend(
                    [op_general_model_config_fn])
                #TODO: figure out the meaning of product_name and container_id.
                if self.product_name is not None:
                    self.resource_conf.auth_product_name = self.product_name
                if self.container_id is not None:
                    self.resource_conf.auth_container_id = self.container_id

    def _write_pb_str(self, filepath, pb_obj):
        with open(filepath, "w") as fout:
            fout.write(str(pb_obj))

    def load_model_config(self, model_config_paths_args):
        # At present, Serving needs the model path configured in the
        # resource.prototxt file to determine the input and output
        # format of the workflow, and it assumes that the inputs and
        # outputs of multiple models are the same.
        if isinstance(model_config_paths_args, str):
            model_config_paths_args = [model_config_paths_args]

        for single_model_config in model_config_paths_args:
            if os.path.isdir(single_model_config):
                pass
            elif os.path.isfile(single_model_config):
                raise ValueError(
                    "The input of --model should be a dir not file.")

        if isinstance(model_config_paths_args, list):
            # If there is only one model path, use the default infer_op.
            # Because there are several infer_op type, we need to find
            # it from workflow_conf.
            default_engine_types = [
                'GeneralInferOp',
                'GeneralDistKVInferOp',
                'GeneralDistKVQuantInferOp',
                'GeneralDetectionOp',
            ]
            # now only support single-workflow.
            # TODO:support multi-workflow
            model_config_paths_list_idx = 0
            for node in self.workflow_conf.workflows[0].nodes:
                if node.type in default_engine_types:
                    if node.name is None:
                        raise Exception(
                            "You have set the engine_name of the Op. Please use the form {op: model_path} to configure the model path."
                        )

                    with open(
                            "{}/serving_server_conf.prototxt".format(
                                model_config_paths_args[
                                    model_config_paths_list_idx]), 'r') as f:
                        self.model_conf[
                            node.name] = google.protobuf.text_format.Merge(
                                str(f.read()), m_config.GeneralModelConfig())
                    self.model_config_paths[
                        node.name] = model_config_paths_args[
                            model_config_paths_list_idx]
                    self.general_model_config_fn.append(
                        node.name + "/general_model.prototxt")
                    self.model_toolkit_fn.append(node.name +
                                                 "/model_toolkit.prototxt")
                    self.subdirectory.append(node.name)
                    model_config_paths_list_idx += 1
                    if model_config_paths_list_idx == len(
                            model_config_paths_args):
                        break
        # Right now, the dict form is not used.
        elif isinstance(model_config_paths_args, dict):
            self.model_config_paths = collections.OrderedDict()
            for node_str, path in model_config_paths_args.items():
                node = server_sdk.DAGNode()
                google.protobuf.text_format.Parse(node_str, node)
                self.model_config_paths[node.name] = path
            print("You have specified multiple model paths, please ensure "
                  "that the input and output of multiple models are the same.")
            with open("{}/serving_server_conf.prototxt".format(path),
                      'r') as f:
                self.model_conf[node.name] = google.protobuf.text_format.Merge(
                    str(f.read()), m_config.GeneralModelConfig())
        else:
            raise Exception(
                "The type of model_config_paths must be str or list or "
                "dict({op: model_path}), not {}.".format(
                    type(model_config_paths_args)))
        # check config here
        # print config here

    def use_mkl(self, flag):
        self.mkl_flag = flag

    def get_device_version(self):
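        """Pick the CPU binary flavor from AVX support (checked via
        /proc/cpuinfo) and the MKL flag."""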
        avx_flag = False
        mkl_flag = self.mkl_flag
        r = os.system("cat /proc/cpuinfo | grep avx > /dev/null 2>&1")
        if r == 0:
            avx_flag = True
        if avx_flag:
            if mkl_flag:
                device_version = "cpu-avx-mkl"
            else:
                device_version = "cpu-avx-openblas"
        else:
            if mkl_flag:
                print(
                    "Your CPU does not support AVX; the server will run in noavx-openblas mode."
                )
            device_version = "cpu-noavx-openblas"
        return device_version

    def get_serving_bin_name(self):
        if device_type == "0":
            device_version = self.get_device_version()
        elif device_type == "1":
            if version_suffix == "101" or version_suffix == "102":
                device_version = "gpu-" + version_suffix
            else:
                device_version = "gpu-cuda" + version_suffix
        elif device_type == "2":
            device_version = "xpu-" + platform.machine()
        else:
            raise ValueError("Unknown device_type: {}".format(device_type))
        return device_version

    def download_bin(self):
        os.chdir(self.module_path)
        need_download = False

        #acquire lock
        version_file = open("{}/version.py".format(self.module_path), "r")

        folder_name = "serving-%s-%s" % (self.get_serving_bin_name(),
                                         serving_server_version)
        tar_name = "%s.tar.gz" % folder_name
        bin_url = "https://paddle-serving.bj.bcebos.com/bin/%s" % tar_name

        self.server_path = os.path.join(self.module_path, folder_name)

        download_flag = "{}/{}.is_download".format(self.module_path,
                                                   folder_name)

        fcntl.flock(version_file, fcntl.LOCK_EX)
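        # The exclusive lock on version.py serializes concurrent downloads;
        # the "<folder_name>.is_download" flag file marks a finished one.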

        if os.path.exists(download_flag):
            os.chdir(self.cur_path)
            self.bin_path = self.server_path + "/serving"
            return

        if not os.path.exists(self.server_path):
            os.system("touch {}/{}.is_download".format(self.module_path,
                                                       folder_name))
            print('First time run, downloading PaddleServing components ...')

            r = os.system('wget ' + bin_url + ' --no-check-certificate')
            if r != 0:
                if os.path.exists(tar_name):
                    os.remove(tar_name)
                raise SystemExit(
                    'Download failed, please check your network or permission of {}.'
                    .format(self.module_path))
            else:
                try:
                    print('Decompressing files ..')
                    tar = tarfile.open(tar_name)
                    tar.extractall()
                    tar.close()
                except Exception:
                    # clean up a partially extracted folder
                    if os.path.exists(self.server_path):
                        shutil.rmtree(self.server_path)
                    raise SystemExit(
                        'Decompressing failed, please check your permission of {} or disk space left.'
                        .format(self.module_path))
                finally:
                    os.remove(tar_name)
        #release lock
        version_file.close()
        os.chdir(self.cur_path)
        self.bin_path = self.server_path + "/serving"

    def prepare_server(self,
                       workdir=None,
                       port=9292,
                       device="cpu",
                       use_encryption_model=False,
                       cube_conf=None):
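        """Create the working directory, check the port, and write all
        server-side prototxt configuration files."""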
        if workdir is None:
            workdir = "./tmp"
        os.system("mkdir -p {}".format(workdir))
        for subdir in self.subdirectory:
            os.system("mkdir -p {}/{}".format(workdir, subdir))
            os.system("touch {}/{}/fluid_time_file".format(workdir, subdir))

        if not self.port_is_available(port):
            raise SystemExit("Port {} is already used".format(port))

        self.set_port(port)
        self._prepare_resource(workdir, cube_conf)
        self._prepare_engine(self.model_config_paths, device,
                             use_encryption_model)
        self._prepare_infer_service(port)
        self.workdir = workdir

        infer_service_fn = "{}/{}".format(workdir, self.infer_service_fn)
        self._write_pb_str(infer_service_fn, self.infer_service_conf)

        workflow_fn = "{}/{}".format(workdir, self.workflow_fn)
        self._write_pb_str(workflow_fn, self.workflow_conf)

        resource_fn = "{}/{}".format(workdir, self.resource_fn)
        self._write_pb_str(resource_fn, self.resource_conf)

        for idx, single_model_toolkit_fn in enumerate(self.model_toolkit_fn):
            model_toolkit_fn = "{}/{}".format(workdir, single_model_toolkit_fn)
            self._write_pb_str(model_toolkit_fn, self.model_toolkit_conf[idx])

    def port_is_available(self, port):
        with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
            sock.settimeout(2)
            result = sock.connect_ex(('0.0.0.0', port))
        # connect_ex returns 0 when something is already listening on the port
        return result != 0

    def run_server(self):
        # just run server with system command
        # currently we do not load cube
        self.check_local_bin()
        if not self.use_local_bin:
            self.download_bin()
            # wait for other process to download server bin
            while not os.path.exists(self.server_path):
                time.sleep(1)
        else:
            print("Use local bin : {}".format(self.bin_path))
        #self.check_cuda()
        # TODO: merge the CPU and GPU code; move device into model_toolkit
        if self.device == "cpu" or self.device == "arm":
            command = "{} " \
                      "-enable_model_toolkit " \
                      "-inferservice_path {} " \
                      "-inferservice_file {} " \
                      "-max_concurrency {} " \
                      "-num_threads {} " \
                      "-port {} " \
                      "-reload_interval_s {} " \
                      "-resource_path {} " \
                      "-resource_file {} " \
                      "-workflow_path {} " \
                      "-workflow_file {} " \
                      "-bthread_concurrency {} " \
                      "-max_body_size {} ".format(
                          self.bin_path,
                          self.workdir,
                          self.infer_service_fn,
                          self.max_concurrency,
                          self.num_threads,
                          self.port,
                          self.reload_interval_s,
                          self.workdir,
                          self.resource_fn,
                          self.workdir,
                          self.workflow_fn,
                          self.num_threads,
                          self.max_body_size)
        else:
            command = "{} " \
                      "-enable_model_toolkit " \
                      "-inferservice_path {} " \
                      "-inferservice_file {} " \
                      "-max_concurrency {} " \
                      "-num_threads {} " \
                      "-port {} " \
                      "-reload_interval_s {} " \
                      "-resource_path {} " \
                      "-resource_file {} " \
                      "-workflow_path {} " \
                      "-workflow_file {} " \
                      "-bthread_concurrency {} " \
                      "-gpuid {} " \
                      "-max_body_size {} ".format(
                          self.bin_path,
                          self.workdir,
                          self.infer_service_fn,
                          self.max_concurrency,
                          self.num_threads,
                          self.port,
                          self.reload_interval_s,
                          self.workdir,
                          self.resource_fn,
                          self.workdir,
                          self.workflow_fn,
                          self.num_threads,
                          self.gpuid,
                          self.max_body_size)
        print("Going to Run Command")
        print(command)

        os.system(command)


class MultiLangServer(object):
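    """gRPC front end wrapping a brpc Server.

    The underlying brpc server runs in a child process; requests arrive
    through MultiLangServerServiceServicer on the public gRPC port.
    """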
    def __init__(self):
        self.bserver_ = Server()
        self.worker_num_ = 4
        self.body_size_ = 64 * 1024 * 1024
        self.concurrency_ = 100000
        self.is_multi_model_ = False  # for model ensemble, which is not used right now.

    def set_max_concurrency(self, concurrency):
        self.concurrency_ = concurrency
        self.bserver_.set_max_concurrency(concurrency)

    def set_device(self, device="cpu"):
        self.device = device

    def set_num_threads(self, threads):
        self.worker_num_ = threads
        self.bserver_.set_num_threads(threads)

    def set_max_body_size(self, body_size):
        self.bserver_.set_max_body_size(body_size)
        if body_size >= self.body_size_:
            self.body_size_ = body_size
        else:
            print(
                "max_body_size is less than default value, will use default value in service."
            )

    def use_encryption_model(self, flag=False):
        self.encryption_model = flag

    def set_port(self, port):
        self.gport_ = port

    def set_reload_interval(self, interval):
        self.bserver_.set_reload_interval(interval)

    def set_op_sequence(self, op_seq):
        self.bserver_.set_op_sequence(op_seq)

    def set_op_graph(self, op_graph):
        self.bserver_.set_op_graph(op_graph)

    def use_mkl(self, flag):
        self.bserver_.use_mkl(flag)

    def set_memory_optimize(self, flag=False):
        self.bserver_.set_memory_optimize(flag)

    def set_ir_optimize(self, flag=False):
        self.bserver_.set_ir_optimize(flag)

    def set_gpuid(self, gpuid=0):
        self.bserver_.set_gpuid(gpuid)

    def load_model_config(self,
                          server_config_dir_paths,
                          client_config_path=None):
        if isinstance(server_config_dir_paths, str):
            server_config_dir_paths = [server_config_dir_paths]
        elif isinstance(server_config_dir_paths, list):
            pass
        else:
            raise Exception("The type of model_config_paths must be str or list"
                            ", not {}.".format(type(server_config_dir_paths)))

        for single_model_config in server_config_dir_paths:
            if os.path.isdir(single_model_config):
                pass
            elif os.path.isfile(single_model_config):
                raise ValueError(
                    "The input of --model should be a dir not file.")

        self.bserver_.load_model_config(server_config_dir_paths)
        if client_config_path is None:
            # Right now, the dict form is not used.
            if isinstance(server_config_dir_paths, dict):
                self.is_multi_model_ = True
                client_config_path = []
                for server_config_path_items in list(
                        server_config_dir_paths.items()):
                    client_config_path.append(server_config_path_items[1])
            elif isinstance(server_config_dir_paths, list):
                self.is_multi_model_ = False
                client_config_path = server_config_dir_paths
            else:
                raise Exception(
                    "The type of model_config_paths must be str or list or "
                    "dict({op: model_path}), not {}.".format(
                        type(server_config_dir_paths)))
        if isinstance(client_config_path, str):
            client_config_path = [client_config_path]
        elif isinstance(client_config_path, list):
            pass
        else:  # dict is not supported right now.
            raise Exception(
                "The type of client_config_path must be str or list or "
                "dict({op: model_path}), not {}.".format(
                    type(client_config_path)))
        if len(client_config_path) != len(server_config_dir_paths):
            raise Warning(
                "len(client_config_path) is {}, which does not equal len(server_config_dir_paths) {}."
                .format(len(client_config_path), len(server_config_dir_paths)))
        self.bclient_config_path_list = client_config_path

    def prepare_server(self,
                       workdir=None,
                       port=9292,
                       device="cpu",
                       use_encryption_model=False,
                       cube_conf=None):
        if not self._port_is_available(port):
            raise SystemExit("Port {} is already used".format(port))
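        # Find a free internal port for the underlying brpc server; the
        # public gRPC endpoint keeps listening on `port`.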
        default_port = 12000
        self.port_list_ = []
        for i in range(1000):
            if default_port + i != port and self._port_is_available(default_port
                                                                    + i):
                self.port_list_.append(default_port + i)
                break
        self.bserver_.prepare_server(
            workdir=workdir,
            port=self.port_list_[0],
            device=device,
            use_encryption_model=use_encryption_model,
            cube_conf=cube_conf)
        self.set_port(port)

    def _launch_brpc_service(self, bserver):
        bserver.run_server()

    def _port_is_available(self, port):
        with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
            sock.settimeout(2)
            result = sock.connect_ex(('0.0.0.0', port))
        return result != 0

    def run_server(self):
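        """Start the brpc server in a child process and serve the gRPC
        proxy in front of it."""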
        p_bserver = Process(
            target=self._launch_brpc_service, args=(self.bserver_, ))
        p_bserver.start()
        options = [('grpc.max_send_message_length', self.body_size_),
                   ('grpc.max_receive_message_length', self.body_size_)]
        server = grpc.server(
            futures.ThreadPoolExecutor(max_workers=self.worker_num_),
            options=options,
            maximum_concurrent_rpcs=self.concurrency_)
        multi_lang_general_model_service_pb2_grpc.add_MultiLangGeneralModelServiceServicer_to_server(
            MultiLangServerServiceServicer(
                self.bclient_config_path_list, self.is_multi_model_,
                ["0.0.0.0:{}".format(self.port_list_[0])]), server)
        server.add_insecure_port('[::]:{}'.format(self.gport_))
        server.start()
        p_bserver.join()
        server.wait_for_termination()
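

# A minimal usage sketch (for illustration only). It assumes a model
# directory exported by paddle_serving_client.io, here called
# "uci_housing_model" (hypothetical), and the OpMaker/OpSeqMaker helpers
# from paddle_serving_server:
#
#     from paddle_serving_server import OpMaker, OpSeqMaker, Server
#
#     op_maker = OpMaker()
#     op_seq_maker = OpSeqMaker()
#     op_seq_maker.add_op(op_maker.create('general_reader'))
#     op_seq_maker.add_op(op_maker.create('general_infer'))
#     op_seq_maker.add_op(op_maker.create('general_response'))
#
#     server = Server()
#     server.set_op_sequence(op_seq_maker.get_op_sequence())
#     server.load_model_config("uci_housing_model")
#     server.prepare_server(workdir="workdir", port=9292, device="cpu")
#     server.run_server()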