# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import tarfile
import socket
import paddle_serving_server
from paddle_serving_server.rpc_service import MultiLangServerServiceServicer
from .proto import server_configure_pb2 as server_sdk
from .proto import general_model_config_pb2 as m_config
from .proto import multi_lang_general_model_service_pb2_grpc
import google.protobuf.text_format
import time
from .version import serving_server_version, version_suffix, device_type
from contextlib import closing
import argparse

import sys
if not sys.platform.startswith('win'):
    import fcntl
import shutil
import platform
import numpy as np
import grpc
import collections

from multiprocessing import Pool, Process
from concurrent import futures


class Server(object):
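    """Wraps the brpc serving binary: collects engine, resource and
    workflow configuration, writes the prototxt files the binary
    expects, then downloads the binary (if needed) and launches it."""
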
    def __init__(self):
        self.server_handle_ = None
        self.infer_service_conf = None
        self.model_toolkit_conf = []  # one conf per InferOp (Engine--Op)
        self.resource_conf = None
        self.memory_optimization = False
        self.ir_optimization = False
        # maps node name -> parsed serving_server_conf.prototxt content (feed
        # and fetch information); ordered to support multi-model workflows
        self.model_conf = collections.OrderedDict()
        self.workflow_fn = "workflow.prototxt"  # one per Service (Workflow--Op)
        self.resource_fn = "resource.prototxt"  # one per Service; records model_toolkit_fn and general_model_config_fn
        self.infer_service_fn = "infer_service.prototxt"  # one per Service (Service--Workflow)
        self.model_toolkit_fn = []  # e.g. ["general_infer_0/model_toolkit.prototxt"], one per InferOp (Engine--Op)
        self.general_model_config_fn = []  # e.g. ["general_infer_0/general_model.prototxt"], one per InferOp (feed/fetch--Op)
        self.subdirectory = []  # one per InferOp; name == node.name == engine.name
        self.cube_config_fn = "cube.conf"
        self.workdir = ""
        self.max_concurrency = 0
        self.num_threads = 2
        self.port = 8080
        self.reload_interval_s = 10
        self.max_body_size = 64 * 1024 * 1024
        self.module_path = os.path.dirname(paddle_serving_server.__file__)
        self.cur_path = os.getcwd()
        self.use_local_bin = False
        self.mkl_flag = False
        self.device = "cpu"
        self.gpuid = 0
        self.use_trt = False
        self.use_lite = False
        self.use_xpu = False
        # maps node name -> model config dir (where serving_server_conf.prototxt
        # lives); ordered to support multi-model workflows
        self.model_config_paths = collections.OrderedDict()
        self.product_name = None
        self.container_id = None

    def get_fetch_list(self, infer_node_idx=-1):
        fetch_names = [
            var.alias_name
            for var in list(self.model_conf.values())[infer_node_idx].fetch_var
        ]
        return fetch_names

    def set_max_concurrency(self, concurrency):
        self.max_concurrency = concurrency

    def set_num_threads(self, threads):
        self.num_threads = threads

    def set_max_body_size(self, body_size):
        if body_size >= self.max_body_size:
            self.max_body_size = body_size
        else:
            print(
                "max_body_size is less than default value, will use default value in service."
            )

    def use_encryption_model(self, flag=False):
        self.encryption_model = flag

    def set_port(self, port):
        self.port = port

    def set_reload_interval(self, interval):
        self.reload_interval_s = interval

    def set_op_sequence(self, op_seq):
        self.workflow_conf = op_seq

    def set_op_graph(self, op_graph):
        self.workflow_conf = op_graph

    def set_memory_optimize(self, flag=False):
        self.memory_optimization = flag

    def set_ir_optimize(self, flag=False):
        self.ir_optimization = flag

    def set_product_name(self, product_name=None):
        if product_name is None:
            raise ValueError("product_name can't be None.")
        self.product_name = product_name

    def set_container_id(self, container_id):
        if container_id is None:
            raise ValueError("container_id can't be None.")
        self.container_id = container_id

    def check_local_bin(self):
        if "SERVING_BIN" in os.environ:
            self.use_local_bin = True
            self.bin_path = os.environ["SERVING_BIN"]

    def check_cuda(self):
        if os.system("ls /dev/ | grep nvidia > /dev/null") != 0:
            raise SystemExit(
                "GPU not found, please check your environment or use cpu version by \"pip install paddle_serving_server\""
            )

    def set_device(self, device="cpu"):
        self.device = device

    def set_gpuid(self, gpuid=0):
        self.gpuid = gpuid

    def set_trt(self):
        self.use_trt = True

    def set_lite(self):
        self.use_lite = True

    def set_xpu(self):
        self.use_xpu = True

    def _prepare_engine(self, model_config_paths, device, use_encryption_model):
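        # build one EngineDesc per (engine_name, model_config_path) pair and
        # wrap each in its own ModelToolkitConf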
        if self.model_toolkit_conf is None:
            self.model_toolkit_conf = []

        for engine_name, model_config_path in model_config_paths.items():
            engine = server_sdk.EngineDesc()
            engine.name = engine_name
            engine.reloadable_meta = model_config_path + "/fluid_time_file"
            os.system("touch {}".format(engine.reloadable_meta))
            engine.reloadable_type = "timestamp_ne"
            engine.runtime_thread_num = 0
            engine.batch_infer_size = 0
            engine.enable_batch_align = 0
            engine.model_dir = model_config_path
            engine.enable_memory_optimization = self.memory_optimization
            engine.enable_ir_optimization = self.ir_optimization
            engine.use_trt = self.use_trt
            engine.use_lite = self.use_lite
            engine.use_xpu = self.use_xpu
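            # a single "__params__" file indicates the combined save format
            # (all parameters stored in one file)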
            if os.path.exists('{}/__params__'.format(model_config_path)):
                engine.combined_model = True
            else:
                engine.combined_model = False
            if use_encryption_model:
                engine.encrypted_model = True
            engine.type = "PADDLE_INFER"
            self.model_toolkit_conf.append(server_sdk.ModelToolkitConf())
            self.model_toolkit_conf[-1].engines.extend([engine])

    def _prepare_infer_service(self, port):
        if self.infer_service_conf is None:
            self.infer_service_conf = server_sdk.InferServiceConf()
            self.infer_service_conf.port = port
            infer_service = server_sdk.InferService()
            infer_service.name = "GeneralModelService"
            infer_service.workflows.extend(["workflow1"])
            self.infer_service_conf.services.extend([infer_service])

    def _prepare_resource(self, workdir, cube_conf):
        self.workdir = workdir
        if self.resource_conf is None:
            self.resource_conf = server_sdk.ResourceConf()
            for idx, op_general_model_config_fn in enumerate(
                    self.general_model_config_fn):
                with open("{}/{}".format(workdir, op_general_model_config_fn),
                          "w") as fout:
                    fout.write(str(list(self.model_conf.values())[idx]))
                for workflow in self.workflow_conf.workflows:
                    for node in workflow.nodes:
                        if "dist_kv" in node.name:
                            self.resource_conf.cube_config_path = workdir
                            self.resource_conf.cube_config_file = self.cube_config_fn
                            if cube_conf is None:
                                raise ValueError(
                                    "Please set the path of cube.conf while use dist_kv op."
                                )
                            shutil.copy(cube_conf, workdir)
                            if "quant" in node.name:
                                self.resource_conf.cube_quant_bits = 8
                self.resource_conf.model_toolkit_path.extend([workdir])
                self.resource_conf.model_toolkit_file.extend(
                    [self.model_toolkit_fn[idx]])
                self.resource_conf.general_model_path.extend([workdir])
                self.resource_conf.general_model_file.extend(
                    [op_general_model_config_fn])
                # TODO: figure out the meaning of product_name and container_id.
                if self.product_name is not None:
                    self.resource_conf.auth_product_name = self.product_name
                if self.container_id is not None:
                    self.resource_conf.auth_container_id = self.container_id

    def _write_pb_str(self, filepath, pb_obj):
        with open(filepath, "w") as fout:
            fout.write(str(pb_obj))

    def load_model_config(self, model_config_paths_args):
        # At present, Serving needs to configure the model path in
        # the resource.prototxt file to determine the input and output
        # format of the workflow, so the input and output of multiple
        # models must be the same.
        if isinstance(model_config_paths_args, str):
            model_config_paths_args = [model_config_paths_args]

        for single_model_config in model_config_paths_args:
            if os.path.isdir(single_model_config):
                pass
            elif os.path.isfile(single_model_config):
                raise ValueError(
                    "The input of --model should be a dir not file.")

        if isinstance(model_config_paths_args, list):
            # If there is only one model path, use the default infer_op.
            # Because there are several infer_op type, we need to find
            # it from workflow_conf.
            default_engine_types = [
                'GeneralInferOp', 'GeneralDistKVInferOp',
                'GeneralDistKVQuantInferOp', 'GeneralDetectionOp',
            ]
            # now only support single-workflow.
            # TODO: support multi-workflow
            model_config_paths_list_idx = 0
            for node in self.workflow_conf.workflows[0].nodes:
                if node.type in default_engine_types:
                    if node.name is None:
                        raise Exception(
                            "You have set the engine_name of Op. Please use the form {op: model_path} to configure model path"
                        )

                    with open("{}/serving_server_conf.prototxt".format(
                            model_config_paths_args[
                                model_config_paths_list_idx])) as f:
                        self.model_conf[
                            node.name] = google.protobuf.text_format.Merge(
                                str(f.read()), m_config.GeneralModelConfig())
                    self.model_config_paths[
                        node.name] = model_config_paths_args[
                            model_config_paths_list_idx]
                    self.general_model_config_fn.append(
                        node.name + "/general_model.prototxt")
                    self.model_toolkit_fn.append(
                        node.name + "/model_toolkit.prototxt")
                    self.subdirectory.append(node.name)
                    model_config_paths_list_idx += 1
                    if model_config_paths_list_idx == len(
                            model_config_paths_args):
                        break
        # right now, the dict branch is not used
        elif isinstance(model_config_paths_args, dict):
            self.model_config_paths = collections.OrderedDict()
            for node_str, path in model_config_paths_args.items():
                node = server_sdk.DAGNode()
                google.protobuf.text_format.Parse(node_str, node)
                self.model_config_paths[node.name] = path
            print("You have specified multiple model paths, please ensure "
                  "that the input and output of multiple models are the same.")
            # note: only the last (node, path) pair from the loop above is
            # read here, matching the original behavior
            with open("{}/serving_server_conf.prototxt".format(path)) as f:
                self.model_conf[node.name] = google.protobuf.text_format.Merge(
                    str(f.read()), m_config.GeneralModelConfig())
        else:
            raise Exception("The type of model_config_paths must be str or list or "
                            "dict({op: model_path}), not {}.".format(
                                type(model_config_paths_args)))
        # check config here
        # print config here

    def use_mkl(self, flag):
        self.mkl_flag = flag

    def get_device_version(self):
        avx_flag = False
        mkl_flag = self.mkl_flag
        r = os.system("cat /proc/cpuinfo | grep avx > /dev/null 2>&1")
        if r == 0:
            avx_flag = True
        if avx_flag:
            if mkl_flag:
                device_version = "cpu-avx-mkl"
            else:
                device_version = "cpu-avx-openblas"
        else:
            if mkl_flag:
                print(
                    "Your CPU does not support AVX, the server will run in noavx-openblas mode."
                )
            device_version = "cpu-noavx-openblas"
        return device_version

    def get_serving_bin_name(self):
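        # device_type comes from .version: "0" = cpu, "1" = gpu, "2" = xpu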
        if device_type == "0":
            device_version = self.get_device_version()
        elif device_type == "1":
            if version_suffix == "101" or version_suffix == "102":
                device_version = "gpu-" + version_suffix
            else:
                device_version = "gpu-cuda" + version_suffix
        elif device_type == "2":
            device_version = "xpu-" + platform.machine()
        return device_version

    def download_bin(self):
        os.chdir(self.module_path)
        need_download = False

        # acquire an exclusive lock on version.py (via fcntl.flock below) so
        # that concurrent processes do not download the binary twice
        version_file = open("{}/version.py".format(self.module_path), "r")

        folder_name = "serving-%s-%s" % (self.get_serving_bin_name(),
                                         serving_server_version)
        tar_name = "%s.tar.gz" % folder_name
        bin_url = "https://paddle-serving.bj.bcebos.com/bin/%s" % tar_name

        self.server_path = os.path.join(self.module_path, folder_name)

        download_flag = "{}/{}.is_download".format(self.module_path,
                                                   folder_name)

        fcntl.flock(version_file, fcntl.LOCK_EX)

        if os.path.exists(download_flag):
            os.chdir(self.cur_path)
            self.bin_path = self.server_path + "/serving"
            return

        if not os.path.exists(self.server_path):
            os.system("touch {}/{}.is_download".format(self.module_path,
                                                       folder_name))
            print('First time run, downloading PaddleServing components ...')

            r = os.system('wget ' + bin_url + ' --no-check-certificate')
            if r != 0:
                if os.path.exists(tar_name):
                    os.remove(tar_name)
                raise SystemExit(
                    'Download failed, please check your network or permission of {}.'
                    .format(self.module_path))
            else:
                try:
                    print('Decompressing files ..')
                    tar = tarfile.open(tar_name)
                    tar.extractall()
                    tar.close()
                except Exception:
                    # clean up the partially extracted folder so a retry can
                    # start from scratch
                    if os.path.exists(self.server_path):
                        shutil.rmtree(self.server_path)
                    raise SystemExit(
                        'Decompressing failed, please check your permission of {} or disk space left.'
                        .format(self.module_path))
                finally:
                    os.remove(tar_name)
        # release lock (closing the file drops the flock)
        version_file.close()
        os.chdir(self.cur_path)
        self.bin_path = self.server_path + "/serving"

    def prepare_server(self,
                       workdir=None,
                       port=9292,
                       device="cpu",
                       use_encryption_model=False,
                       cube_conf=None):
        if workdir is None:
            workdir = "./tmp"
        os.system("mkdir -p {}".format(workdir))
        for subdir in self.subdirectory:
            os.system("mkdir -p {}/{}".format(workdir, subdir))
            os.system("touch {}/{}/fluid_time_file".format(workdir, subdir))

        if not self.port_is_available(port):
            raise SystemExit("Port {} is already used".format(port))

        self.set_port(port)
        self._prepare_resource(workdir, cube_conf)
        self._prepare_engine(self.model_config_paths, device,
                             use_encryption_model)
        self._prepare_infer_service(port)
        self.workdir = workdir

        infer_service_fn = "{}/{}".format(workdir, self.infer_service_fn)
        self._write_pb_str(infer_service_fn, self.infer_service_conf)

        workflow_fn = "{}/{}".format(workdir, self.workflow_fn)
        self._write_pb_str(workflow_fn, self.workflow_conf)

        resource_fn = "{}/{}".format(workdir, self.resource_fn)
        self._write_pb_str(resource_fn, self.resource_conf)

        for idx, single_model_toolkit_fn in enumerate(self.model_toolkit_fn):
            model_toolkit_fn = "{}/{}".format(workdir, single_model_toolkit_fn)
            self._write_pb_str(model_toolkit_fn, self.model_toolkit_conf[idx])

    def port_is_available(self, port):
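        # connect_ex() returns 0 when something is already listening on the
        # port, so a non-zero result means the port is free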
        with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
            sock.settimeout(2)
            result = sock.connect_ex(('0.0.0.0', port))
        return result != 0

    def run_server(self):
        # just run server with system command
        # currently we do not load cube
        self.check_local_bin()
        if not self.use_local_bin:
            self.download_bin()
            # wait for other process to download server bin
            while not os.path.exists(self.server_path):
                time.sleep(1)
        else:
            print("Use local bin : {}".format(self.bin_path))
        #self.check_cuda()
        # TODO: merge the CPU and GPU code paths; move device into model_toolkit
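        # the CPU/arm and GPU command lines are identical except that the
        # GPU branch also passes -gpuid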
        if self.device == "cpu" or self.device == "arm":
            command = "{} " \
                      "-enable_model_toolkit " \
                      "-inferservice_path {} " \
                      "-inferservice_file {} " \
                      "-max_concurrency {} " \
                      "-num_threads {} " \
                      "-port {} " \
                      "-reload_interval_s {} " \
                      "-resource_path {} " \
                      "-resource_file {} " \
                      "-workflow_path {} " \
                      "-workflow_file {} " \
                      "-bthread_concurrency {} " \
                      "-max_body_size {} ".format(
                          self.bin_path,
                          self.workdir,
                          self.infer_service_fn,
                          self.max_concurrency,
                          self.num_threads,
                          self.port,
                          self.reload_interval_s,
                          self.workdir,
                          self.resource_fn,
                          self.workdir,
                          self.workflow_fn,
                          self.num_threads,
                          self.max_body_size)
        else:
            command = "{} " \
                      "-enable_model_toolkit " \
                      "-inferservice_path {} " \
                      "-inferservice_file {} " \
                      "-max_concurrency {} " \
                      "-num_threads {} " \
                      "-port {} " \
                      "-reload_interval_s {} " \
                      "-resource_path {} " \
                      "-resource_file {} " \
                      "-workflow_path {} " \
                      "-workflow_file {} " \
                      "-bthread_concurrency {} " \
                      "-gpuid {} " \
                      "-max_body_size {} ".format(
                          self.bin_path,
                          self.workdir,
                          self.infer_service_fn,
                          self.max_concurrency,
                          self.num_threads,
                          self.port,
                          self.reload_interval_s,
                          self.workdir,
                          self.resource_fn,
                          self.workdir,
                          self.workflow_fn,
                          self.num_threads,
                          self.gpuid,
                          self.max_body_size)
        print("Going to Run Comand")
        print(command)

        os.system(command)

class MultiLangServer(object):
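    """gRPC front end for a brpc Server: the brpc server runs in a child
    process on an internal port, while this class serves a
    MultiLangGeneralModelService gRPC endpoint and forwards requests."""
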
    def __init__(self):
        self.bserver_ = Server()
        self.worker_num_ = 4
        self.body_size_ = 64 * 1024 * 1024
        self.concurrency_ = 100000
        self.is_multi_model_ = False  # for model ensembles; not used right now

    def set_max_concurrency(self, concurrency):
        self.concurrency_ = concurrency
        self.bserver_.set_max_concurrency(concurrency)

    def set_device(self, device="cpu"):
        self.device = device

    def set_num_threads(self, threads):
        self.worker_num_ = threads
        self.bserver_.set_num_threads(threads)

    def set_max_body_size(self, body_size):
        self.bserver_.set_max_body_size(body_size)
        if body_size >= self.body_size_:
            self.body_size_ = body_size
        else:
            print(
                "max_body_size is less than default value, will use default value in service."
            )

    def use_encryption_model(self, flag=False):
        self.encryption_model = flag

    def set_port(self, port):
        self.gport_ = port

    def set_reload_interval(self, interval):
        self.bserver_.set_reload_interval(interval)

    def set_op_sequence(self, op_seq):
        self.bserver_.set_op_sequence(op_seq)

    def set_op_graph(self, op_graph):
        self.bserver_.set_op_graph(op_graph)

    def use_mkl(self, flag):
        self.bserver_.use_mkl(flag)

    def set_memory_optimize(self, flag=False):
        self.bserver_.set_memory_optimize(flag)

    def set_ir_optimize(self, flag=False):
        self.bserver_.set_ir_optimize(flag)

    def set_gpuid(self, gpuid=0):
        self.bserver_.set_gpuid(gpuid)

    def load_model_config(self,
                          server_config_dir_paths,
                          client_config_path=None):
        if isinstance(server_config_dir_paths, str):
            server_config_dir_paths = [server_config_dir_paths]
        elif isinstance(server_config_dir_paths, list):
            pass
        else:
            raise Exception("The type of server_config_dir_paths must be str "
                            "or list, not {}.".format(
                                type(server_config_dir_paths)))

        for single_model_config in server_config_dir_paths:
            if os.path.isdir(single_model_config):
                pass
            elif os.path.isfile(single_model_config):
                raise ValueError(
                    "The input of --model should be a dir not file.")

        self.bserver_.load_model_config(server_config_dir_paths)
        if client_config_path is None:
            # the dict form is not used right now.
            if isinstance(server_config_dir_paths, dict):
                self.is_multi_model_ = True
                client_config_path = []
                for server_config_path_items in list(
                        server_config_dir_paths.items()):
                    client_config_path.append(server_config_path_items[1])
            elif isinstance(server_config_dir_paths, list):
                self.is_multi_model_ = False
                client_config_path = server_config_dir_paths
            else:
                raise Exception("The type of model_config_paths must be str or list or "
                            "dict({op: model_path}), not {}.".format(
                                type(server_config_dir_paths)))
        if isinstance(client_config_path, str):
            client_config_path = [client_config_path]
        elif isinstance(client_config_path, list):
            pass
        else:# dict is not support right now.
            raise Exception("The type of client_config_path must be str or list or "
                            "dict({op: model_path}), not {}.".format(
                                type(client_config_path)))
        if len(client_config_path) != len(server_config_dir_paths):
            raise Warning("The len(client_config_path) is {}, != len(server_config_dir_paths) {}."
                            .format( len(client_config_path), len(server_config_dir_paths) )
                            )
        self.bclient_config_path_list = client_config_path

    def prepare_server(self,
                       workdir=None,
                       port=9292,
                       device="cpu",
                       use_encryption_model=False,
                       cube_conf=None):
        if not self._port_is_available(port):
            raise SystemExit("Prot {} is already used".format(port))
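        # pick a free internal port (searching upward from 12000) for the
        # brpc server; `port` itself is the public gRPC port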
        default_port = 12000
        self.port_list_ = []
        for i in range(1000):
            if default_port + i != port and self._port_is_available(default_port
                                                                    + i):
                self.port_list_.append(default_port + i)
                break
        self.bserver_.prepare_server(
            workdir=workdir,
            port=self.port_list_[0],
            device=device,
            use_encryption_model=use_encryption_model,
            cube_conf=cube_conf)
        self.set_port(port)

    def _launch_brpc_service(self, bserver):
        bserver.run_server()

    def _port_is_available(self, port):
        with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
            sock.settimeout(2)
            result = sock.connect_ex(('0.0.0.0', port))
        return result != 0

    def run_server(self):
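        # launch the brpc server in a child process, then serve gRPC in this
        # process, forwarding requests to the brpc port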
        p_bserver = Process(
            target=self._launch_brpc_service, args=(self.bserver_, ))
        p_bserver.start()
        options = [('grpc.max_send_message_length', self.body_size_),
                   ('grpc.max_receive_message_length', self.body_size_)]
        server = grpc.server(
            futures.ThreadPoolExecutor(max_workers=self.worker_num_),
            options=options,
            maximum_concurrent_rpcs=self.concurrency_)
        multi_lang_general_model_service_pb2_grpc.add_MultiLangGeneralModelServiceServicer_to_server(
            MultiLangServerServiceServicer(
                self.bclient_config_path_list, self.is_multi_model_,
                ["0.0.0.0:{}".format(self.port_list_[0])]), server)
        server.add_insecure_port('[::]:{}'.format(self.gport_))
        server.start()
        p_bserver.join()
        server.wait_for_termination()