# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import tarfile
import socket
import paddle_serving_server as paddle_serving_server
from paddle_serving_server.rpc_service import MultiLangServerServiceServicer
from .proto import server_configure_pb2 as server_sdk
from .proto import general_model_config_pb2 as m_config
from .proto import multi_lang_general_model_service_pb2_grpc
import google.protobuf.text_format
import time
from .version import version_tag, version_suffix, device_type
from contextlib import closing
import argparse
import sys
if not sys.platform.startswith('win'):
    import fcntl
import shutil
import platform
import numpy as np
import grpc
import collections
import subprocess

from multiprocessing import Pool, Process
from concurrent import futures


# The whole file is about to be discarded; we will use the default config
# file to start the C++ server.
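#
# A minimal usage sketch (illustrative; assumes OpMaker/OpSeqMaker from this
# package and a model directory exported by paddle_serving_client.io):
#
#     op_maker = paddle_serving_server.OpMaker()
#     op_seq_maker = paddle_serving_server.OpSeqMaker()
#     op_seq_maker.add_op(op_maker.create('GeneralReaderOp'))
#     op_seq_maker.add_op(op_maker.create('GeneralInferOp'))
#     op_seq_maker.add_op(op_maker.create('GeneralResponseOp'))
#
#     server = Server()
#     server.set_op_sequence(op_seq_maker.get_op_sequence())
#     server.load_model_config("serving_server")
#     server.prepare_server(workdir="workdir", port=9393, device="cpu")
#     server.run_server()
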
class Server(object):
    def __init__(self):
        """
        self.model_toolkit_conf:'list'=[] # The quantity of self.model_toolkit_conf is equal to the InferOp quantity/Engine--OP
        self.model_conf:'collections.OrderedDict()' # Save the serving_server_conf.prototxt content (feed and fetch information) this is a map for multi-model in a workflow
        self.workflow_fn:'str'="workflow.prototxt" # Only one for one Service/Workflow
        self.resource_fn:'str'="resource.prototxt" # Only one for one Service,model_toolkit_fn and general_model_config_fn is recorded in this file
        self.infer_service_fn:'str'="infer_service.prototxt" # Only one for one Service,Service--Workflow
        self.model_toolkit_fn:'list'=[] # ["general_infer_0/model_toolkit.prototxt"]The quantity is equal to the InferOp quantity,Engine--OP
        self.general_model_config_fn:'list'=[] # ["general_infer_0/general_model.prototxt"]The quantity is equal to the InferOp quantity,Feed and Fetch --OP
        self.subdirectory:'list'=[] # The quantity is equal to the InferOp quantity, and name = node.name = engine.name
        self.model_config_paths:'collections.OrderedDict()' # Save the serving_server_conf.prototxt path (feed and fetch information) this is a map for multi-model in a workflow
        """
        self.server_handle_ = None
        self.infer_service_conf = None
        self.model_toolkit_conf = []
        self.resource_conf = None
        self.memory_optimization = False
        self.ir_optimization = False
        self.model_conf = collections.OrderedDict()
        self.workflow_fn = "workflow.prototxt"
        self.resource_fn = "resource.prototxt"
        self.infer_service_fn = "infer_service.prototxt"
        self.model_toolkit_fn = []
        self.general_model_config_fn = []
        self.subdirectory = []
        self.cube_config_fn = "cube.conf"
        self.workdir = ""
        self.max_concurrency = 0
        self.num_threads = 2
        self.port = 8080
        self.precision = "fp32"
        self.use_calib = False
        self.reload_interval_s = 10
        self.max_body_size = 64 * 1024 * 1024
        self.module_path = os.path.dirname(paddle_serving_server.__file__)
        self.cur_path = os.getcwd()
        self.use_local_bin = False
        self.mkl_flag = False
        self.device = "cpu"
        self.gpuid = []
        self.op_num = [0]
        self.op_max_batch = [32]
        self.use_trt = False
        self.gpu_multi_stream = False
        self.use_lite = False
        self.use_xpu = False
        self.model_config_paths = collections.OrderedDict()
        self.product_name = None
        self.container_id = None

    def get_fetch_list(self, infer_node_idx=-1):
        fetch_names = [
            var.alias_name
            for var in list(self.model_conf.values())[infer_node_idx].fetch_var
        ]
        return fetch_names
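    # Illustrative: for a model whose fetch variable alias is "price",
    # get_fetch_list() returns ["price"].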

    def set_max_concurrency(self, concurrency):
        self.max_concurrency = concurrency

    def set_num_threads(self, threads):
        self.num_threads = threads

    def set_max_body_size(self, body_size):
        if body_size >= self.max_body_size:
            self.max_body_size = body_size
        else:
            print(
                "max_body_size is less than default value, will use default value in service."
            )

    def use_encryption_model(self, flag=False):
        self.encryption_model = flag

    def set_port(self, port):
        self.port = port

    def set_precision(self, precision="fp32"):
        self.precision = precision

    def set_use_calib(self, use_calib=False):
        self.use_calib = use_calib

    def set_reload_interval(self, interval):
        self.reload_interval_s = interval

    def set_op_sequence(self, op_seq):
        self.workflow_conf = op_seq

    def set_op_graph(self, op_graph):
        self.workflow_conf = op_graph

    def set_memory_optimize(self, flag=False):
        self.memory_optimization = flag

    def set_ir_optimize(self, flag=False):
        self.ir_optimization = flag

    # Multi-Server does not have this function.
    def set_product_name(self, product_name=None):
        if product_name is None:
            raise ValueError("product_name can't be None.")
        self.product_name = product_name

    # Multi-Server does not have this function.
    def set_container_id(self, container_id):
        if container_id is None:
            raise ValueError("container_id can't be None.")
        self.container_id = container_id

    def check_local_bin(self):
        if "SERVING_BIN" in os.environ:
            self.use_local_bin = True
            self.bin_path = os.environ["SERVING_BIN"]
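
    # Illustrative: to use a locally built server binary instead of the
    # downloaded one, e.g.
    #     export SERVING_BIN=/path/to/your/built/serving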

    def check_cuda(self):
        if os.system("ls /dev/ | grep nvidia > /dev/null") != 0:
            raise SystemExit(
                "GPU not found, please check your environment or use cpu version by \"pip install paddle_serving_server\""
            )

    def set_device(self, device="cpu"):
        self.device = device

    def set_gpuid(self, gpuid):
        if isinstance(gpuid, int):
            self.gpuid = str(gpuid)
        elif isinstance(gpuid, list):
            self.gpuid = [str(x) for x in gpuid]
        else:
            self.gpuid = gpuid
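
        # Accepted forms (illustrative):
        #     set_gpuid(0)            -> "0"
        #     set_gpuid([0, 1])       -> ["0", "1"]    (one card per op)
        #     set_gpuid(["0,1", "2"]) -> op 0 on cards 0,1; op 1 on card 2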

    def set_op_num(self, op_num):
        self.op_num = op_num

    def set_op_max_batch(self, op_max_batch):
        self.op_max_batch = op_max_batch

    def set_trt(self):
        self.use_trt = True

    def set_gpu_multi_stream(self):
        self.gpu_multi_stream = True

    def set_lite(self):
        self.use_lite = True

    def set_xpu(self):
        self.use_xpu = True

    def _prepare_engine(self, model_config_paths, device, use_encryption_model):
        if self.model_toolkit_conf is None:
            self.model_toolkit_conf = []
        self.device = device

        # Generally, self.gpuid is a str or a list of str,
        # such as "0" or ["0"] or ["0,1"] or ["0,1", "1,2"].
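        # Each InferOp later takes gpuid[index % len(self.gpuid)], so a single
        # entry can serve every op, or each op can get its own device set.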
        if isinstance(self.gpuid, str):
            self.gpuid = [self.gpuid]

        # When len(self.gpuid) == 0, no gpuid was specified.
        # If self.device == "gpu" or self.use_trt is set, we assume you forgot
        # to set gpuid, so it defaults to ['0']; otherwise it becomes ['-1'].
        if len(self.gpuid) == 0:
            if self.device == "gpu" or self.use_trt:
                self.gpuid.append("0")
            else:
                self.gpuid.append("-1")

        if isinstance(self.op_num, int):
            self.op_num = [self.op_num]
        if len(self.op_num) == 0:
            self.op_num.append(0)

        if isinstance(self.op_max_batch, int):
            self.op_max_batch = [self.op_max_batch]
        if len(self.op_max_batch) == 0:
            self.op_max_batch.append(32)

        index = 0

        for engine_name, model_config_path in model_config_paths.items():
            engine = server_sdk.EngineDesc()
            engine.name = engine_name
            engine.reloadable_meta = model_config_path + "/fluid_time_file"
            os.system("touch {}".format(engine.reloadable_meta))
            engine.reloadable_type = "timestamp_ne"
            engine.runtime_thread_num = self.op_num[index % len(self.op_num)]
            engine.batch_infer_size = self.op_max_batch[index %
                                                        len(self.op_max_batch)]

            engine.enable_batch_align = 1
            engine.model_dir = model_config_path
            engine.enable_memory_optimization = self.memory_optimization
            engine.enable_ir_optimization = self.ir_optimization
            engine.use_trt = self.use_trt
            engine.gpu_multi_stream = self.gpu_multi_stream
            engine.use_lite = self.use_lite
            engine.use_xpu = self.use_xpu
            engine.use_gpu = False

            if len(self.gpuid) == 0:
                raise ValueError("CPU: self.gpuid = -1, GPU: must set it ")
            op_gpu_list = self.gpuid[index % len(self.gpuid)].split(",")
            for ids in op_gpu_list:
                engine.gpu_ids.extend([int(ids)])

            if self.device == "gpu" or self.use_trt:
                engine.use_gpu = True
                # This handles mixed use of GPU and CPU:
                # model-1 may use the GPU (device="gpu") while gpuid[1] == "-1",
                # which means model-2 runs on CPU, so its GPU flags must be False.
                if len(op_gpu_list) == 1:
                    if int(op_gpu_list[0]) == -1:
                        engine.use_gpu = False
                        engine.gpu_multi_stream = False
                        engine.use_trt = False

            if os.path.exists('{}/__params__'.format(model_config_path)):
                engine.combined_model = True
            else:
                engine.combined_model = False
            if use_encryption_model:
                engine.encrypted_model = True
            engine.type = "PADDLE_INFER"
            self.model_toolkit_conf.append(server_sdk.ModelToolkitConf())
            self.model_toolkit_conf[-1].engines.extend([engine])
            index = index + 1

    def _prepare_infer_service(self, port):
        if self.infer_service_conf is None:
            self.infer_service_conf = server_sdk.InferServiceConf()
            self.infer_service_conf.port = port
            infer_service = server_sdk.InferService()
            infer_service.name = "GeneralModelService"
            infer_service.workflows.extend(["workflow1"])
            self.infer_service_conf.services.extend([infer_service])

    def _prepare_resource(self, workdir, cube_conf):
        self.workdir = workdir
        if self.resource_conf is None:
            self.resource_conf = server_sdk.ResourceConf()
            for idx, op_general_model_config_fn in enumerate(
                    self.general_model_config_fn):
                with open("{}/{}".format(workdir, op_general_model_config_fn),
                          "w") as fout:
                    fout.write(str(list(self.model_conf.values())[idx]))
                for workflow in self.workflow_conf.workflows:
                    for node in workflow.nodes:
                        if "dist_kv" in node.name:
                            self.resource_conf.cube_config_path = workdir
                            self.resource_conf.cube_config_file = self.cube_config_fn
                            if cube_conf is None:
                                raise ValueError(
                                    "Please set the path of cube.conf while use dist_kv op."
                                )
                            shutil.copy(cube_conf, workdir)
                            if "quant" in node.name:
                                self.resource_conf.cube_quant_bits = 8
                self.resource_conf.model_toolkit_path.extend([workdir])
                self.resource_conf.model_toolkit_file.extend(
                    [self.model_toolkit_fn[idx]])
                self.resource_conf.general_model_path.extend([workdir])
                self.resource_conf.general_model_file.extend(
                    [op_general_model_config_fn])
                # TODO: figure out the meaning of product_name and container_id.
                if self.product_name is not None:
                    self.resource_conf.auth_product_name = self.product_name
                if self.container_id is not None:
                    self.resource_conf.auth_container_id = self.container_id

    def _write_pb_str(self, filepath, pb_obj):
        with open(filepath, "w") as fout:
            fout.write(str(pb_obj))

    def load_model_config(self, model_config_paths_args):
        # At present, Serving needs the model path configured in
        # resource.prototxt to determine the input and output format of the
        # workflow, and to ensure that the inputs and outputs of multiple
        # models are consistent.
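        # Accepted forms (illustrative):
        #     load_model_config("serving_server")
        #     load_model_config(["det_server", "rec_server"])  # one dir per InferOp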
        if isinstance(model_config_paths_args, str):
            model_config_paths_args = [model_config_paths_args]

        for single_model_config in model_config_paths_args:
            if os.path.isdir(single_model_config):
                pass
            elif os.path.isfile(single_model_config):
                raise ValueError(
                    "The input of --model should be a dir, not a file.")

        if isinstance(model_config_paths_args, list):
            # If there is only one model path, use the default infer_op.
            # Because there are several infer_op types, we need to find
            # the right one from workflow_conf.
            default_engine_types = [
                'GeneralInferOp',
                'GeneralDistKVInferOp',
                'GeneralDistKVQuantInferOp',
                'GeneralDetectionOp',
            ]
            # Only a single workflow is supported for now.
            # TODO: support multi-workflow.
            model_config_paths_list_idx = 0
            for node in self.workflow_conf.workflows[0].nodes:
                if node.type in default_engine_types:
                    if node.name is None:
                        raise Exception(
                            "You have set the engine_name of Op. Please use the form {op: model_path} to configure model path"
                        )

                    with open("{}/serving_server_conf.prototxt".format(
                            model_config_paths_args[model_config_paths_list_idx]
                    ), 'r') as f:
                        self.model_conf[
                            node.name] = google.protobuf.text_format.Merge(
                                str(f.read()), m_config.GeneralModelConfig())
                    self.model_config_paths[
                        node.name] = model_config_paths_args[
                            model_config_paths_list_idx]
                    self.general_model_config_fn.append(
                        node.name + "/general_model.prototxt")
                    self.model_toolkit_fn.append(node.name +
                                                 "/model_toolkit.prototxt")
                    self.subdirectory.append(node.name)
                    model_config_paths_list_idx += 1
                    if model_config_paths_list_idx == len(
                            model_config_paths_args):
                        break
        # Right now, this is not useful.
        elif isinstance(model_config_paths_args, dict):
            self.model_config_paths = collections.OrderedDict()
            for node_str, path in model_config_paths_args.items():
                node = server_sdk.DAGNode()
                google.protobuf.text_format.Parse(node_str, node)
                self.model_config_paths[node.name] = path
            print("You have specified multiple model paths, please ensure "
                  "that the input and output of multiple models are the same.")
            with open("{}/serving_server_conf.prototxt".format(path), 'r') as f:
                self.model_conf[node.name] = google.protobuf.text_format.Merge(
                    str(f.read()), m_config.GeneralModelConfig())
        else:
            raise Exception(
                "The type of model_config_paths must be str or list or "
                "dict({op: model_path}), not {}.".format(
                    type(model_config_paths_args)))
        # TODO: check and print the final config here.

    def use_mkl(self, flag):
        self.mkl_flag = flag

    def check_avx(self):
        p = subprocess.Popen(
            ['cat /proc/cpuinfo | grep avx 2>/dev/null'],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            shell=True)
        out, err = p.communicate()
        return err == b'' and len(out) > 0

    def get_device_version(self):
        avx_flag = False
        avx_support = self.check_avx()
        if avx_support:
            avx_flag = True
            self.use_mkl(True)
        mkl_flag = self.mkl_flag
        if avx_flag:
            if mkl_flag:
                device_version = "cpu-avx-mkl"
            else:
                device_version = "cpu-avx-openblas"
        else:
            if mkl_flag:
                print(
                    "Your CPU does not support AVX, the server will run in noavx-openblas mode."
                )
            device_version = "cpu-noavx-openblas"
        return device_version

    def get_serving_bin_name(self):
        if device_type == "0":
            device_version = self.get_device_version()
        elif device_type == "1":
            if version_suffix == "101" or version_suffix == "102":
                device_version = "gpu-" + version_suffix
            else:
                device_version = "gpu-cuda" + version_suffix
        elif device_type == "2":
            device_version = "xpu-" + platform.machine()
        return device_version

    def download_bin(self):
        os.chdir(self.module_path)
        need_download = False

        # Acquire an exclusive lock on version.py so that only one process
        # downloads and unpacks the server binary at a time.
        version_file = open("{}/version.py".format(self.module_path), "r")

        folder_name = "serving-%s-%s" % (self.get_serving_bin_name(),
                                         version_tag)
        tar_name = "%s.tar.gz" % folder_name
        bin_url = "https://paddle-serving.bj.bcebos.com/bin/%s" % tar_name
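        # e.g. https://paddle-serving.bj.bcebos.com/bin/serving-cpu-avx-mkl-<tag>.tar.gz
        # (illustrative; the exact name is derived from folder_name above)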

        self.server_path = os.path.join(self.module_path, folder_name)

        download_flag = "{}/{}.is_download".format(self.module_path,
                                                   folder_name)

        fcntl.flock(version_file, fcntl.LOCK_EX)
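        # flock blocks until a concurrent downloader releases the lock
        # (version_file.close() below), so only one process downloads the
        # binary; latecomers then see download_flag and return early.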

        if os.path.exists(download_flag):
            os.chdir(self.cur_path)
            self.bin_path = self.server_path + "/serving"
            return

        if not os.path.exists(self.server_path):
            print('First time run, downloading PaddleServing components ...')

            r = os.system('wget ' + bin_url + ' --no-check-certificate')
            if r != 0:
                if os.path.exists(tar_name):
                    os.remove(tar_name)
                raise SystemExit(
                    'Download failed, please check your network or permission of {}.'
                    .format(self.module_path))
            else:
                try:
                    print('Decompressing files ..')
                    tar = tarfile.open(tar_name)
                    tar.extractall()
                    tar.close()
                    open(download_flag, "a").close()
                except Exception:
                    if os.path.exists(self.server_path):
                        shutil.rmtree(self.server_path)
                    raise SystemExit(
                        'Decompressing failed, please check your permission of {} or disk space left.'
                        .format(self.module_path))
                finally:
                    os.remove(tar_name)
        # Release the lock.
        version_file.close()
        os.chdir(self.cur_path)
        self.bin_path = self.server_path + "/serving"

    def prepare_server(self,
                       workdir=None,
                       port=9292,
                       device="cpu",
                       use_encryption_model=False,
                       cube_conf=None):
        self.device = device
        if workdir is None:
            workdir = "./tmp"
        os.system("mkdir -p {}".format(workdir))
        for subdir in self.subdirectory:
            os.system("mkdir -p {}/{}".format(workdir, subdir))
            os.system("touch {}/{}/fluid_time_file".format(workdir, subdir))

        if not self.port_is_available(port):
            raise SystemExit("Port {} is already in use".format(port))

        self.set_port(port)
        self._prepare_resource(workdir, cube_conf)
        self._prepare_engine(self.model_config_paths, device,
                             use_encryption_model)
        self._prepare_infer_service(port)
        self.workdir = workdir

        infer_service_fn = "{}/{}".format(workdir, self.infer_service_fn)
        self._write_pb_str(infer_service_fn, self.infer_service_conf)

        workflow_fn = "{}/{}".format(workdir, self.workflow_fn)
        self._write_pb_str(workflow_fn, self.workflow_conf)

        resource_fn = "{}/{}".format(workdir, self.resource_fn)
        self._write_pb_str(resource_fn, self.resource_conf)

        for idx, single_model_toolkit_fn in enumerate(self.model_toolkit_fn):
            model_toolkit_fn = "{}/{}".format(workdir, single_model_toolkit_fn)
            self._write_pb_str(model_toolkit_fn, self.model_toolkit_conf[idx])

    def port_is_available(self, port):
        with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
            sock.settimeout(2)
            result = sock.connect_ex(('0.0.0.0', port))
        return result != 0

    def run_server(self):
        # just run server with system command
        # currently we do not load cube
        self.check_local_bin()
        if not self.use_local_bin:
            self.download_bin()
            # wait for other process to download server bin
            while not os.path.exists(self.server_path):
                time.sleep(1)
        else:
            print("Use local bin : {}".format(self.bin_path))
        #self.check_cuda()
        command = "{} " \
                    "-enable_model_toolkit " \
                    "-inferservice_path {} " \
                    "-inferservice_file {} " \
                    "-max_concurrency {} " \
                    "-num_threads {} " \
                    "-port {} " \
                    "-precision {} " \
                    "-use_calib {} " \
                    "-reload_interval_s {} " \
                    "-resource_path {} " \
                    "-resource_file {} " \
                    "-workflow_path {} " \
                    "-workflow_file {} " \
                    "-bthread_concurrency {} " \
                    "-max_body_size {} ".format(
                        self.bin_path,
                        self.workdir,
                        self.infer_service_fn,
                        self.max_concurrency,
                        self.num_threads,
                        self.port,
                        self.precision,
                        self.use_calib,
                        self.reload_interval_s,
                        self.workdir,
                        self.resource_fn,
                        self.workdir,
                        self.workflow_fn,
                        self.num_threads,
                        self.max_body_size)

        print("Going to Run Command")
        print(command)

        os.system(command)


class MultiLangServer(object):
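    """A gRPC front end that forwards requests to an internal brpc Server.

    A minimal usage sketch (illustrative; the op sequence is built the same
    way as for Server above):

        server = MultiLangServer()
        server.set_op_sequence(op_seq_maker.get_op_sequence())
        server.load_model_config("serving_server")
        server.prepare_server(workdir="workdir", port=9292, device="cpu")
        server.run_server()
    """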
    def __init__(self):
        self.bserver_ = Server()
        self.worker_num_ = 4
        self.body_size_ = 64 * 1024 * 1024
        self.concurrency_ = 100000
        self.is_multi_model_ = False  # for model ensemble, which is not useful right now.

    def set_max_concurrency(self, concurrency):
        self.concurrency_ = concurrency
        self.bserver_.set_max_concurrency(concurrency)

    def set_device(self, device="cpu"):
        self.device = device

    def set_num_threads(self, threads):
        self.worker_num_ = threads
        self.bserver_.set_num_threads(threads)

    def set_max_body_size(self, body_size):
        self.bserver_.set_max_body_size(body_size)
        if body_size >= self.body_size_:
            self.body_size_ = body_size
        else:
            print(
                "max_body_size is less than default value, will use default value in service."
            )

    def use_encryption_model(self, flag=False):
        self.encryption_model = flag

    def set_port(self, port):
        self.gport_ = port

    def set_precision(self, precision="fp32"):
        self.precision = precision

    def set_use_calib(self, use_calib=False):
        self.use_calib = use_calib

    def set_reload_interval(self, interval):
        self.bserver_.set_reload_interval(interval)

    def set_op_sequence(self, op_seq):
        self.bserver_.set_op_sequence(op_seq)

    def set_op_graph(self, op_graph):
        self.bserver_.set_op_graph(op_graph)

    def use_mkl(self, flag):
        self.bserver_.use_mkl(flag)

    def set_memory_optimize(self, flag=False):
        self.bserver_.set_memory_optimize(flag)

    def set_ir_optimize(self, flag=False):
        self.bserver_.set_ir_optimize(flag)

    def set_gpuid(self, gpuid):
        self.bserver_.set_gpuid(gpuid)

    def set_op_num(self, op_num):
        self.bserver_.set_op_num(op_num)

    def set_op_max_batch(self, op_max_batch):
        self.bserver_.set_op_max_batch(op_max_batch)

    def set_trt(self):
        self.bserver_.set_trt()

    def set_gpu_multi_stream(self):
        self.bserver_.set_gpu_multi_stream()

    def set_lite(self):
        self.bserver_.set_lite()

    def set_xpu(self):
        self.bserver_.set_xpu()

    def load_model_config(self,
                          server_config_dir_paths,
                          client_config_path=None):
        if isinstance(server_config_dir_paths, str):
            server_config_dir_paths = [server_config_dir_paths]
        elif isinstance(server_config_dir_paths, list):
            pass
        else:
            raise Exception("The type of model_config_paths must be str or list"
Z
zhangjun 已提交
686
                            ", not {}.".format(type(server_config_dir_paths)))

        for single_model_config in server_config_dir_paths:
            if os.path.isdir(single_model_config):
                pass
            elif os.path.isfile(single_model_config):
                raise ValueError(
                    "The input of --model should be a dir, not a file.")

        self.bserver_.load_model_config(server_config_dir_paths)
        if client_config_path is None:
            # Right now, the dict form is not used.
            if isinstance(server_config_dir_paths, dict):
                self.is_multi_model_ = True
                client_config_path = []
                for server_config_path_items in list(
                        server_config_dir_paths.items()):
                    client_config_path.append(server_config_path_items[1])
            elif isinstance(server_config_dir_paths, list):
                self.is_multi_model_ = False
                client_config_path = server_config_dir_paths
            else:
                raise Exception(
                    "The type of server_config_dir_paths must be str or list "
                    "or dict({op: model_path}), not {}.".format(
                        type(server_config_dir_paths)))
        if isinstance(client_config_path, str):
            client_config_path = [client_config_path]
        elif isinstance(client_config_path, list):
            pass
        else:  # dict is not supported right now.
            raise Exception(
                "The type of client_config_path must be str or list or "
                "dict({op: model_path}), not {}.".format(
                    type(client_config_path)))
        if len(client_config_path) != len(server_config_dir_paths):
            raise Warning(
                "len(client_config_path) {} does not match "
                "len(server_config_dir_paths) {}."
                .format(len(client_config_path), len(server_config_dir_paths)))
        self.bclient_config_path_list = client_config_path

    def prepare_server(self,
                       workdir=None,
                       port=9292,
                       device="cpu",
                       use_encryption_model=False,
                       cube_conf=None):
        self.device = device
        if not self._port_is_available(port):
            raise SystemExit("Port {} is already in use".format(port))
        default_port = 12000
        self.port_list_ = []
        for i in range(1000):
            if default_port + i != port and self._port_is_available(default_port
                                                                    + i):
                self.port_list_.append(default_port + i)
                break
        self.bserver_.prepare_server(
M
MRXLT 已提交
744 745 746
            workdir=workdir,
            port=self.port_list_[0],
            device=device,
            use_encryption_model=use_encryption_model,
            cube_conf=cube_conf)
        self.set_port(port)

    def _launch_brpc_service(self, bserver):
        bserver.run_server()

    def _port_is_available(self, port):
        with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
            sock.settimeout(2)
            result = sock.connect_ex(('0.0.0.0', port))
        return result != 0

    def run_server(self):
        p_bserver = Process(
            target=self._launch_brpc_service, args=(self.bserver_, ))
        p_bserver.start()
        options = [('grpc.max_send_message_length', self.body_size_),
                   ('grpc.max_receive_message_length', self.body_size_)]
        server = grpc.server(
            futures.ThreadPoolExecutor(max_workers=self.worker_num_),
            options=options,
            maximum_concurrent_rpcs=self.concurrency_)
        multi_lang_general_model_service_pb2_grpc.add_MultiLangGeneralModelServiceServicer_to_server(
            MultiLangServerServiceServicer(
                self.bclient_config_path_list, self.is_multi_model_,
                ["0.0.0.0:{}".format(self.port_list_[0])]), server)
        server.add_insecure_port('[::]:{}'.format(self.gport_))
        server.start()
        p_bserver.join()
        server.wait_for_termination()