__init__.py 23.5 KB
Newer Older
G
guru4elephant 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
B
barrierye 已提交
14
# pylint: disable=doc-string-missing
G
guru4elephant 已提交
15

G
guru4elephant 已提交
16
import collections
import fcntl
import itertools
import os
import shutil
import socket
import tarfile
from concurrent import futures
from contextlib import closing
from multiprocessing import Pool, Process

import google.protobuf.text_format
import grpc
import numpy as np

import gserver_general_model_service_pb2
import gserver_general_model_service_pb2_grpc
import paddle_serving_server as paddle_serving_server
from .proto import general_model_config_pb2 as m_config
from .proto import server_configure_pb2 as server_sdk
from .version import serving_server_version
B
barrierye 已提交
35

G
guru4elephant 已提交
36 37 38

class OpMaker(object):
    """Factory producing text-format ``DAGNode`` descriptions for serving ops."""

    def __init__(self):
        # Maps the user-facing op type to the op class name stored in
        # ``DAGNode.type``.
        self.op_dict = {
            "general_infer": "GeneralInferOp",
            "general_reader": "GeneralReaderOp",
            "general_response": "GeneralResponseOp",
            "general_text_reader": "GeneralTextReaderOp",
            "general_text_response": "GeneralTextResponseOp",
            "general_single_kv": "GeneralSingleKVOp",
            "general_dist_kv_infer": "GeneralDistKVInferOp",
            "general_dist_kv_quant_infer": "GeneralDistKVQuantInferOp",
            "general_copy": "GeneralCopyOp"
        }
        # Per-op-type counter used to build default node names
        # ("<node_type>_<n>").
        self.node_name_suffix_ = collections.defaultdict(int)

    def create(self, node_type, engine_name=None, inputs=None, outputs=None):
        """Build a DAG node and return it serialized as protobuf text.

        Args:
            node_type: key of ``self.op_dict`` selecting the op class.
            engine_name: explicit node name; when falsy a name of the form
                ``"<node_type>_<n>"`` is generated, with ``n`` counting the
                auto-named nodes of this type created so far.
            inputs: optional iterable of node strings (as returned by this
                method); each one becomes a read-only ("RO") dependency.
            outputs: accepted for interface compatibility; not used here.

        Returns:
            The node in protobuf text format. A string is returned because
            callers use the result as a dict key and proto messages are
            mutable and therefore not hashable.

        Raises:
            Exception: if ``node_type`` is not a key of ``self.op_dict``.
        """
        if node_type not in self.op_dict:
            raise Exception("Op type {} is not supported right now".format(
                node_type))
        node = server_sdk.DAGNode()
        # node.name will be used as the infer engine name
        if engine_name:
            node.name = engine_name
        else:
            node.name = '{}_{}'.format(node_type,
                                       self.node_name_suffix_[node_type])
            self.node_name_suffix_[node_type] += 1

        node.type = self.op_dict[node_type]
        # ``inputs`` defaults to None rather than a shared mutable list.
        for dep_node_str in inputs or ():
            dep_node = server_sdk.DAGNode()
            google.protobuf.text_format.Parse(dep_node_str, dep_node)
            dep = server_sdk.DAGNodeDependency()
            dep.name = dep_node.name
            dep.mode = "RO"
            node.dependencies.extend([dep])
        return google.protobuf.text_format.MessageToString(node)
G
guru4elephant 已提交
79

M
MRXLT 已提交
80

G
guru4elephant 已提交
81 82 83 84 85 86
class OpSeqMaker(object):
    """Assemble ops into a single linear ("Sequence") workflow."""

    def __init__(self):
        sequence = server_sdk.Workflow()
        sequence.name = "workflow1"
        sequence.workflow_type = "Sequence"
        self.workflow = sequence

    def add_op(self, node_str):
        """Append the op described by ``node_str`` to the end of the sequence.

        An op without a dependency is linked (mode "RO") to the op added
        just before it. An op that already names a dependency must name
        exactly that previous op.

        Raises:
            Exception: if the op declares more than one dependency, or its
                single dependency is not the last op of the sequence.
        """
        parsed = server_sdk.DAGNode()
        google.protobuf.text_format.Parse(node_str, parsed)
        dep_count = len(parsed.dependencies)
        if dep_count > 1:
            raise Exception(
                'Set more than one predecessor for op in OpSeqMaker is not allowed.'
            )
        # The first op of the sequence has no predecessor to link or check.
        if len(self.workflow.nodes) > 0:
            tail_name = self.workflow.nodes[-1].name
            if dep_count == 0:
                link = server_sdk.DAGNodeDependency()
                link.name = tail_name
                link.mode = "RO"
                parsed.dependencies.extend([link])
            elif parsed.dependencies[0].name != tail_name:
                raise Exception(
                    'You must add op in order in OpSeqMaker. The previous op is {}, but the current op is followed by {}.'.
                    format(parsed.dependencies[0].name, tail_name))
        self.workflow.nodes.extend([parsed])

    def get_op_sequence(self):
        """Return a new ``WorkflowConf`` holding a copy of the sequence."""
        conf = server_sdk.WorkflowConf()
        conf.workflows.extend([self.workflow])
        return conf

M
MRXLT 已提交
113

B
barrierye 已提交
114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131
class OpGraphMaker(object):
    """Collect ops into one workflow without checking their dependencies."""

    def __init__(self):
        graph = server_sdk.Workflow()
        graph.name = "workflow1"
        # Currently, SDK only supports "Sequence"
        graph.workflow_type = "Sequence"
        self.workflow = graph

    def add_op(self, node_str):
        """Parse ``node_str`` (text-format DAGNode) and add it to the workflow."""
        parsed = server_sdk.DAGNode()
        google.protobuf.text_format.Parse(node_str, parsed)
        self.workflow.nodes.extend([parsed])

    def get_op_graph(self):
        """Return a new ``WorkflowConf`` holding a copy of the workflow."""
        conf = server_sdk.WorkflowConf()
        conf.workflows.extend([self.workflow])
        return conf


G
guru4elephant 已提交
132 133 134 135 136
class Server(object):
    """Build the prototxt configuration of a serving instance and launch it.

    The call order implied by the attributes each step reads is:
    ``set_op_sequence``/``set_op_graph`` -> ``load_model_config`` ->
    ``prepare_server`` -> ``run_server``.
    """

    def __init__(self):
        self.server_handle_ = None
        self.infer_service_conf = None
        self.model_toolkit_conf = None
        self.resource_conf = None
        self.memory_optimization = False
        self.ir_optimization = False
        self.model_conf = None
        # Filled in by set_op_sequence() / set_op_graph().
        self.workflow_conf = None
        self.workflow_fn = "workflow.prototxt"
        self.resource_fn = "resource.prototxt"
        self.infer_service_fn = "infer_service.prototxt"
        self.model_toolkit_fn = "model_toolkit.prototxt"
        self.general_model_config_fn = "general_model.prototxt"
        self.cube_config_fn = "cube.conf"
        self.workdir = ""
        self.max_concurrency = 0
        self.num_threads = 4
        self.port = 8080
        self.reload_interval_s = 10
        self.max_body_size = 64 * 1024 * 1024
        self.module_path = os.path.dirname(paddle_serving_server.__file__)
        # Working directory at construction time; download_bin() returns to it.
        self.cur_path = os.getcwd()
        self.use_local_bin = False
        # Filled in by check_local_bin() / download_bin().
        self.bin_path = None
        self.server_path = None
        self.mkl_flag = False
        self.model_config_paths = None  # for multi-model in a workflow

    def set_max_concurrency(self, concurrency):
        """Set the value passed to the binary as ``-max_concurrency``."""
        self.max_concurrency = concurrency

    def set_num_threads(self, threads):
        """Set the value passed as ``-num_threads`` and ``-bthread_concurrency``."""
        self.num_threads = threads

    def set_max_body_size(self, body_size):
        """Raise the maximum body size; smaller values are ignored with a notice."""
        if body_size >= self.max_body_size:
            self.max_body_size = body_size
        else:
            print(
                "max_body_size is less than default value, will use default value in service."
            )

    def set_port(self, port):
        """Set the serving port (overwritten again by ``prepare_server``)."""
        self.port = port

    def set_reload_interval(self, interval):
        """Set the value passed to the binary as ``-reload_interval_s``."""
        self.reload_interval_s = interval

    def set_op_sequence(self, op_seq):
        """Use the ``WorkflowConf`` returned by ``OpSeqMaker.get_op_sequence``."""
        self.workflow_conf = op_seq

    def set_op_graph(self, op_graph):
        """Use the ``WorkflowConf`` returned by ``OpGraphMaker.get_op_graph``."""
        self.workflow_conf = op_graph

    def set_memory_optimize(self, flag=False):
        """Set ``enable_memory_optimization`` for engines prepared later."""
        self.memory_optimization = flag

    def set_ir_optimize(self, flag=False):
        """Set ``enable_ir_optimization`` for engines prepared later."""
        self.ir_optimization = flag

    def check_local_bin(self):
        """Use the binary named by the ``SERVING_BIN`` env variable, if set."""
        if "SERVING_BIN" in os.environ:
            self.use_local_bin = True
            self.bin_path = os.environ["SERVING_BIN"]

    @staticmethod
    def _touch(path):
        """Create ``path`` if missing and refresh its timestamps.

        Best effort, like the shell ``touch`` it replaces: a failure is
        reported but does not abort the caller.
        """
        try:
            with open(path, "a"):
                os.utime(path, None)
        except (IOError, OSError) as exc:
            print("Failed to touch {}: {}".format(path, exc))

    def _prepare_engine(self, model_config_paths, device):
        """Append one ``EngineDesc`` per ``{engine_name: model_dir}`` entry.

        ``device`` selects the engine type for "cpu" and "gpu"; for any
        other value ``engine.type`` is left unset.
        """
        if self.model_toolkit_conf is None:
            self.model_toolkit_conf = server_sdk.ModelToolkitConf()

        for engine_name, model_config_path in model_config_paths.items():
            engine = server_sdk.EngineDesc()
            engine.name = engine_name
            engine.reloadable_meta = model_config_path + "/fluid_time_file"
            self._touch(engine.reloadable_meta)
            engine.reloadable_type = "timestamp_ne"
            engine.runtime_thread_num = 0
            engine.batch_infer_size = 0
            engine.enable_batch_align = 0
            engine.model_data_path = model_config_path
            engine.enable_memory_optimization = self.memory_optimization
            engine.enable_ir_optimization = self.ir_optimization
            engine.static_optimization = False
            engine.force_update_static_cache = False

            if device == "cpu":
                engine.type = "FLUID_CPU_ANALYSIS_DIR"
            elif device == "gpu":
                engine.type = "FLUID_GPU_ANALYSIS_DIR"

            self.model_toolkit_conf.engines.extend([engine])

    def _prepare_infer_service(self, port):
        """Create the ``InferServiceConf`` once, bound to ``port``."""
        if self.infer_service_conf is None:
            self.infer_service_conf = server_sdk.InferServiceConf()
            self.infer_service_conf.port = port
            infer_service = server_sdk.InferService()
            infer_service.name = "GeneralModelService"
            infer_service.workflows.extend(["workflow1"])
            self.infer_service_conf.services.extend([infer_service])

    def _prepare_resource(self, workdir):
        """Write the general model config and create the ``ResourceConf`` once."""
        if self.resource_conf is None:
            with open("{}/{}".format(workdir, self.general_model_config_fn),
                      "w") as fout:
                fout.write(str(self.model_conf))
            self.resource_conf = server_sdk.ResourceConf()
            # Nodes whose name contains "dist_kv" need the cube config;
            # names that also contain "quant" use 8 quantization bits.
            for workflow in self.workflow_conf.workflows:
                for node in workflow.nodes:
                    if "dist_kv" in node.name:
                        self.resource_conf.cube_config_path = workdir
                        self.resource_conf.cube_config_file = self.cube_config_fn
                        if "quant" in node.name:
                            self.resource_conf.cube_quant_bits = 8
            self.resource_conf.model_toolkit_path = workdir
            self.resource_conf.model_toolkit_file = self.model_toolkit_fn
            self.resource_conf.general_model_path = workdir
            self.resource_conf.general_model_file = self.general_model_config_fn

    def _write_pb_str(self, filepath, pb_obj):
        """Write ``str(pb_obj)`` to ``filepath``."""
        with open(filepath, "w") as fout:
            fout.write(str(pb_obj))

    def load_model_config(self, model_config_paths):
        """Record the model directories and load the workflow's model config.

        Args:
            model_config_paths: either one model directory (``str``), which
                is bound to the first workflow node carrying a default
                engine name, or a dict ``{op_node_str: model_dir}``.

        Raises:
            Exception: if a ``str`` is given but no node has a default
                engine name, or if the argument is neither ``str`` nor
                ``dict``.
        """
        # At present, Serving needs to configure the model path in
        # the resource.prototxt file to determine the input and output
        # format of the workflow. To ensure that the input and output
        # of multiple models are the same.
        if isinstance(model_config_paths, str):
            # If there is only one model path, use the default infer_op.
            # Because there are several infer_op type, we need to find
            # it from workflow_conf.
            default_engine_names = [
                'general_infer_0', 'general_dist_kv_infer_0',
                'general_dist_kv_quant_infer_0'
            ]
            engine_name = None
            for node in self.workflow_conf.workflows[0].nodes:
                if node.name in default_engine_names:
                    engine_name = node.name
                    break
            if engine_name is None:
                raise Exception(
                    "You have set the engine_name of Op. Please use the form {op: model_path} to configure model path"
                )
            self.model_config_paths = {engine_name: model_config_paths}
            workflow_oi_config_path = model_config_paths
        elif isinstance(model_config_paths, dict):
            self.model_config_paths = {}
            for node_str, path in model_config_paths.items():
                node = server_sdk.DAGNode()
                google.protobuf.text_format.Parse(node_str, node)
                self.model_config_paths[node.name] = path
            print("You have specified multiple model paths, please ensure "
                  "that the input and output of multiple models are the same.")
            workflow_oi_config_path = list(self.model_config_paths.values())[0]
        else:
            # The literal braces are doubled: unescaped, str.format() treats
            # "{op: model_path}" as a field and raises KeyError instead.
            raise Exception("The type of model_config_paths must be str or "
                            "dict({{op: model_path}}), not {}.".format(
                                type(model_config_paths)))

        self.model_conf = m_config.GeneralModelConfig()
        with open(
                "{}/serving_server_conf.prototxt".format(
                    workflow_oi_config_path), 'r') as f:
            self.model_conf = google.protobuf.text_format.Merge(
                f.read(), self.model_conf)
        # check config here
        # print config here

    def use_mkl(self, flag):
        """Request the MKL build (honoured only when the CPU reports AVX)."""
        self.mkl_flag = flag

    def get_device_version(self):
        """Return the binary name prefix matching this CPU and ``mkl_flag``.

        AVX support is detected by looking for "avx" in ``/proc/cpuinfo``;
        when the file cannot be read the CPU is treated as having no AVX.
        """
        try:
            with open("/proc/cpuinfo", "rb") as cpuinfo:
                avx_flag = b"avx" in cpuinfo.read()
        except (IOError, OSError):
            avx_flag = False

        if avx_flag:
            if self.mkl_flag:
                return "serving-cpu-avx-mkl-"
            return "serving-cpu-avx-openblas-"
        if self.mkl_flag:
            print(
                "Your CPU does not support AVX, server will running with noavx-openblas mode."
            )
        return "serving-cpu-noavx-openblas-"

    def download_bin(self):
        """Download and unpack the serving binary unless it is already present.

        Works inside ``self.module_path`` while holding an exclusive
        ``flock`` on ``version.py``. The lock is released and the working
        directory is restored to ``self.cur_path`` even on failure.

        Raises:
            SystemExit: if the download or the extraction fails.
        """
        os.chdir(self.module_path)
        try:
            device_version = self.get_device_version()
            folder_name = device_version + serving_server_version
            tar_name = folder_name + ".tar.gz"
            bin_url = "https://paddle-serving.bj.bcebos.com/bin/" + tar_name
            self.server_path = os.path.join(self.module_path, folder_name)

            # Closing the file at the end of the ``with`` releases the lock.
            with open("{}/version.py".format(self.module_path),
                      "r") as version_file:
                fcntl.flock(version_file, fcntl.LOCK_EX)
                if not os.path.exists(self.server_path):
                    self._download_and_extract(bin_url, tar_name)
        finally:
            os.chdir(self.cur_path)
        self.bin_path = self.server_path + "/serving"

    def _download_and_extract(self, bin_url, tar_name):
        """Fetch ``bin_url`` into the current directory and unpack it."""
        print('First time run, downloading PaddleServing components ...')
        # NOTE(review): certificate checking is disabled and the archive is
        # extracted without validating member paths; consider hardening.
        r = os.system('wget ' + bin_url + ' --no-check-certificate')
        if r != 0:
            if os.path.exists(tar_name):
                os.remove(tar_name)
            raise SystemExit(
                'Download failed, please check your network or permission of {}.'.
                format(self.module_path))

        extracted = False
        try:
            print('Decompressing files ..')
            with tarfile.open(tar_name) as tar:
                tar.extractall()
            extracted = True
        except Exception:
            raise SystemExit(
                'Decompressing failed, please check your permission of {} or disk space left.'.
                format(self.module_path))
        finally:
            if not extracted:
                # A partial extraction would make the next run skip the
                # download, so remove it (also on KeyboardInterrupt).
                shutil.rmtree(self.server_path, ignore_errors=True)
            if os.path.exists(tar_name):
                os.remove(tar_name)

    def prepare_server(self, workdir=None, port=9292, device="cpu"):
        """Create ``workdir`` and write all prototxt files the binary reads.

        Args:
            workdir: directory for generated files; defaults to "./tmp".
            port: port to serve on; must not be in use.
            device: "cpu" or "gpu", forwarded to ``_prepare_engine``.

        Raises:
            SystemExit: if ``port`` is already in use.
        """
        if workdir is None:
            workdir = "./tmp"
        try:
            os.makedirs(workdir)
        except OSError:
            # An existing directory is fine; anything else is a real error.
            if not os.path.isdir(workdir):
                raise
        self._touch("{}/fluid_time_file".format(workdir))

        if not self.port_is_available(port):
            raise SystemExit("Port {} is already used".format(port))
        self._prepare_resource(workdir)
        self._prepare_engine(self.model_config_paths, device)
        self._prepare_infer_service(port)
        self.port = port
        self.workdir = workdir

        infer_service_fn = "{}/{}".format(workdir, self.infer_service_fn)
        workflow_fn = "{}/{}".format(workdir, self.workflow_fn)
        resource_fn = "{}/{}".format(workdir, self.resource_fn)
        model_toolkit_fn = "{}/{}".format(workdir, self.model_toolkit_fn)

        self._write_pb_str(infer_service_fn, self.infer_service_conf)
        self._write_pb_str(workflow_fn, self.workflow_conf)
        self._write_pb_str(resource_fn, self.resource_conf)
        self._write_pb_str(model_toolkit_fn, self.model_toolkit_conf)

    def port_is_available(self, port):
        """Return True when a TCP connect to ``('0.0.0.0', port)`` fails."""
        with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
            sock.settimeout(2)
            result = sock.connect_ex(('0.0.0.0', port))
        return result != 0

    def run_server(self):
        """Run the serving binary via ``os.system`` and wait for it to exit."""
        # just run server with system command
        # currently we do not load cube
        self.check_local_bin()
        if not self.use_local_bin:
            self.download_bin()
        else:
            print("Use local bin : {}".format(self.bin_path))
        command = "{} " \
                  "-enable_model_toolkit " \
                  "-inferservice_path {} " \
                  "-inferservice_file {} " \
                  "-max_concurrency {} " \
                  "-num_threads {} " \
                  "-port {} " \
                  "-reload_interval_s {} " \
                  "-resource_path {} " \
                  "-resource_file {} " \
                  "-workflow_path {} " \
                  "-workflow_file {} " \
                  "-bthread_concurrency {} " \
                  "-max_body_size {} ".format(
                      self.bin_path,
                      self.workdir,
                      self.infer_service_fn,
                      self.max_concurrency,
                      self.num_threads,
                      self.port,
                      self.reload_interval_s,
                      self.workdir,
                      self.resource_fn,
                      self.workdir,
                      self.workflow_fn,
                      self.num_threads,
                      self.max_body_size)
        print("Going to Run Command")
        print(command)
        os.system(command)
B
barrierye 已提交
439 440 441 442 443 444 445 446


class GServerService(
        gserver_general_model_service_pb2_grpc.GServerGeneralModelService):
    """gRPC servicer that forwards each request to a serving client."""

    def __init__(self, model_config_path, endpoints):
        """Load the model description and connect the client to ``endpoints``.

        Args:
            model_config_path: directory containing
                ``serving_server_conf.prototxt``.
            endpoints: endpoint list handed to ``Client.connect``.
        """
        from paddle_serving_client import Client
        self._parse_model_config(model_config_path)
        self.bclient_ = Client()
        self.bclient_.load_client_config(
            "{}/serving_server_conf.prototxt".format(model_config_path))
        self.bclient_.connect(endpoints)

    def _parse_model_config(self, model_config_path):
        """Read feed/fetch names, types, shapes and LoD flags from the config."""
        model_conf = m_config.GeneralModelConfig()
        with open("{}/serving_server_conf.prototxt".format(model_config_path),
                  'r') as f:
            model_conf = google.protobuf.text_format.Merge(f.read(),
                                                           model_conf)
        self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
        self.feed_types_ = {}
        self.feed_shapes_ = {}
        self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
        self.fetch_types_ = {}
        self.type_map_ = {0: "int64", 1: "float32"}
        # Alias names of every feed/fetch variable flagged as a LoD tensor.
        self.lod_tensor_set_ = set()
        for var in model_conf.feed_var:
            self.feed_types_[var.alias_name] = var.feed_type
            self.feed_shapes_[var.alias_name] = var.shape
            if var.is_lod_tensor:
                self.lod_tensor_set_.add(var.alias_name)
        for var in model_conf.fetch_var:
            self.fetch_types_[var.alias_name] = var.fetch_type
            if var.is_lod_tensor:
                self.lod_tensor_set_.add(var.alias_name)

    def _flatten_list(self, nested_list):
        """Yield the leaves of arbitrarily nested lists/tuples in order."""
        for item in nested_list:
            if isinstance(item, (list, tuple)):
                for sub_item in self._flatten_list(item):
                    yield sub_item
            else:
                yield item

    def _unpack_request(self, request):
        """Convert a request into ``(feed_batch, fetch_names)``.

        ``feed_batch`` has one dict per request instance, mapping each feed
        name to a numpy array reshaped to the shape carried by its tensor.

        Raises:
            Exception: if a feed variable has a type other than 0 or 1.
        """
        feed_names = list(request.feed_var_names)
        fetch_names = list(request.fetch_var_names)
        feed_batch = []
        for feed_inst in request.insts:
            feed_dict = {}
            for idx, name in enumerate(feed_names):
                tensor = feed_inst.tensor_array[idx]
                v_type = self.feed_types_[name]
                if v_type == 0:  # int64
                    data = np.array(list(tensor.int64_data), dtype="int64")
                elif v_type == 1:  # float32
                    # "float" would silently produce float64 arrays.
                    data = np.array(list(tensor.float_data), dtype="float32")
                else:
                    raise Exception("error type.")
                data.shape = list(tensor.shape)
                feed_dict[name] = data
            feed_batch.append(feed_dict)
        return feed_batch, fetch_names

    def _pack_resp_package(self, result, fetch_names, tag):
        """Pack the prediction ``result`` into a ``Response`` message.

        Args:
            result: mapping from fetch name to array; LoD tensors also
                provide a ``"<name>.lod"`` entry.
            fetch_names: names to copy into the response, in order.
            tag: value stored in ``resp.tag``.

        Raises:
            Exception: if a fetch variable has a type other than 0 or 1.
        """
        resp = gserver_general_model_service_pb2.Response()
        # Only one model is supported temporarily
        model_output = gserver_general_model_service_pb2.ModelOutput()
        inst = gserver_general_model_service_pb2.FetchInst()
        for name in fetch_names:
            # model_output.fetch_var_names.append(name)
            tensor = gserver_general_model_service_pb2.Tensor()
            v_type = self.fetch_types_[name]
            if v_type == 0:  # int64
                tensor.int64_data.extend(
                    self._flatten_list(result[name].tolist()))
            elif v_type == 1:  # float32
                tensor.float_data.extend(
                    self._flatten_list(result[name].tolist()))
            else:
                raise Exception("error type.")
            tensor.shape.extend(list(result[name].shape))
            if name in self.lod_tensor_set_:
                tensor.lod.extend(result["{}.lod".format(name)].tolist())
            inst.tensor_array.append(tensor)
        model_output.insts.append(inst)
        resp.outputs.append(model_output)
        resp.tag = tag
        return resp

    def inference(self, request, context):
        """Handle one RPC: unpack, predict through the client, pack the reply."""
        feed_batch, fetch_names = self._unpack_request(request)
        data, tag = self.bclient_.predict(
            feed=feed_batch, fetch=fetch_names, need_variant_tag=True)
        return self._pack_resp_package(data, fetch_names, tag)
B
barrierye 已提交
538 539 540 541


class GServer(object):
    """gRPC front end that runs a ``Server`` in a child process behind it."""

    def __init__(self, worker_num=2):
        """Create the backing ``Server``; ``worker_num`` sizes the gRPC pool."""
        self.bserver_ = Server()
        self.worker_num_ = worker_num

    def set_op_sequence(self, op_seq):
        """Forward the op sequence to the backing server."""
        self.bserver_.set_op_sequence(op_seq)

    def load_model_config(self, model_config_path):
        """Load a single model directory.

        Raises:
            Exception: if ``model_config_path`` is not a ``str``; the
                multi-model dict form is rejected here.
        """
        if not isinstance(model_config_path, str):
            raise Exception("GServer only supports single-model temporarily")
        self.bserver_.load_model_config(model_config_path)
        self.model_config_path_ = model_config_path

    def prepare_server(self, workdir=None, port=9292, device="cpu"):
        """Prepare the backing server on a free internal port.

        The internal port is the first available one in
        ``[12000, 13000)`` that differs from ``port``; ``port`` itself is
        kept for the gRPC endpoint.

        Raises:
            Exception: if no internal port is available.
        """
        default_port = 12000
        self.port_list_ = []
        for candidate in range(default_port, default_port + 1000):
            if candidate != port and self._port_is_available(candidate):
                self.port_list_.append(candidate)
                break
        if not self.port_list_:
            # Fail with a clear message instead of an IndexError below.
            raise Exception(
                "No available port found in range [{}, {}) for the internal "
                "server".format(default_port, default_port + 1000))
        self.bserver_.prepare_server(
            workdir=workdir, port=self.port_list_[0], device=device)
        self.gport_ = port

    def _launch_brpc_service(self, bserver):
        """Child-process entry point: run the backing server."""
        bserver.run_server()

    def _port_is_available(self, port):
        """Return True when ``port`` is free, using the backing server's check."""
        return self.bserver_.port_is_available(port)

    def run_server(self):
        """Start the backing server process and the gRPC server, then block."""
        p_bserver = Process(
            target=self._launch_brpc_service, args=(self.bserver_, ))
        p_bserver.start()
        server = grpc.server(
            futures.ThreadPoolExecutor(max_workers=self.worker_num_))
        gserver_general_model_service_pb2_grpc.add_GServerGeneralModelServiceServicer_to_server(
            GServerService(self.model_config_path_,
                           ["0.0.0.0:{}".format(self.port_list_[0])]), server)
        server.add_insecure_port('[::]:{}'.format(self.gport_))
        server.start()
        p_bserver.join()
        server.wait_for_termination()