__init__.py 21.6 KB
Newer Older
G
guru4elephant 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
# pylint: disable=doc-string-missing
G
guru4elephant 已提交
15

M
MRXLT 已提交
16 17
import paddle_serving_client
import os
18 19 20
from .proto import sdk_configure_pb2 as sdk
from .proto import general_model_config_pb2 as m_config
import google.protobuf.text_format
D
dongdaxiang 已提交
21 22
import numpy as np
import time
23
import sys
G
guru4elephant 已提交
24

B
barrierye 已提交
25
import grpc
B
barrierye 已提交
26 27
from .proto import multi_lang_general_model_service_pb2
from .proto import multi_lang_general_model_service_pb2_grpc
B
barrierye 已提交
28

G
guru4elephant 已提交
29 30 31
# Type tags for feed/fetch variables, mirroring the values used by the
# serving protocol: 0 -> int64 data, 1 -> float32 data.
int_type = 0
float_type = 1

M
MRXLT 已提交
32

W
WangXi 已提交
33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60
class _NOPProfiler(object):
    def record(self, name):
        pass

    def print_profile(self):
        pass


class _TimeProfiler(object):
    def __init__(self):
        self.pid = os.getpid()
        self.print_head = 'PROFILE\tpid:{}\t'.format(self.pid)
        self.time_record = [self.print_head]

    def record(self, name):
        self.time_record.append('{}:{} '.format(
            name, int(round(time.time() * 1000000))))

    def print_profile(self):
        self.time_record.append('\n')
        sys.stderr.write(''.join(self.time_record))
        self.time_record = [self.print_head]


# Choose the profiler implementation once at import time: set the
# FLAGS_profile_client environment variable to a non-zero integer to enable
# timing output on stderr; otherwise profiling calls are no-ops.
_is_profile = int(os.environ.get('FLAGS_profile_client', 0))
_Profiler = _TimeProfiler if _is_profile else _NOPProfiler


G
guru4elephant 已提交
61 62 63
class SDKConfig(object):
    """Builds the SDK (predictor) configuration protobuf for Client.

    Server variants registered through add_server_variant() are rendered
    into an sdk.SDKConf message by gen_desc().
    """

    def __init__(self):
        self.sdk_desc = sdk.SDKConf()
        self.tag_list = []
        self.cluster_list = []
        self.variant_weight_list = []
        self.rpc_timeout_ms = 20000
        self.load_balance_strategy = "la"

    def add_server_variant(self, tag, cluster, variant_weight):
        """Register one server variant.

        Args:
            tag: variant tag string, used for A/B testing.
            cluster: list of "ip:port" endpoint strings.
            variant_weight: weight of this variant, as a string.
        """
        self.tag_list.append(tag)
        self.cluster_list.append(cluster)
        self.variant_weight_list.append(variant_weight)

    def set_load_balance_strategy(self, strategy):
        """Set the load balance strategy (e.g. "la") used by gen_desc()."""
        self.load_balance_strategy = strategy

    # Backward-compatible alias: the method was originally published under
    # this misspelled name, so existing callers must keep working.
    set_load_banlance_strategy = set_load_balance_strategy

    def gen_desc(self, rpc_timeout_ms):
        """Render all registered variants into the sdk.SDKConf message.

        Args:
            rpc_timeout_ms: RPC timeout (milliseconds) for the default
                variant's connection configuration.

        Returns:
            The populated sdk.SDKConf protobuf message.
        """
        predictor_desc = sdk.Predictor()
        predictor_desc.name = "general_model"
        predictor_desc.service_name = \
            "baidu.paddle_serving.predictor.general_model.GeneralModelService"
        predictor_desc.endpoint_router = "WeightedRandomRender"
        predictor_desc.weighted_random_render_conf.variant_weight_list = "|".join(
            self.variant_weight_list)

        for idx, tag in enumerate(self.tag_list):
            variant_desc = sdk.VariantConf()
            variant_desc.tag = tag
            variant_desc.naming_conf.cluster = "list://{}".format(",".join(
                self.cluster_list[idx]))
            predictor_desc.variants.extend([variant_desc])

        self.sdk_desc.predictors.extend([predictor_desc])
        self.sdk_desc.default_variant_conf.tag = "default"

        conn_conf = self.sdk_desc.default_variant_conf.connection_conf
        conn_conf.connect_timeout_ms = 2000
        conn_conf.rpc_timeout_ms = rpc_timeout_ms
        conn_conf.connect_retry_count = 2
        conn_conf.max_connection_per_host = 100
        conn_conf.hedge_request_timeout_ms = -1
        conn_conf.hedge_fetch_retry_count = 2
        conn_conf.connection_type = "pooled"

        naming_conf = self.sdk_desc.default_variant_conf.naming_conf
        naming_conf.cluster_filter_strategy = "Default"
        # Honor the configured strategy (default "la") instead of hard-coding
        # it, so set_load_balance_strategy() actually takes effect.
        naming_conf.load_balance_strategy = self.load_balance_strategy

        rpc_parameter = self.sdk_desc.default_variant_conf.rpc_parameter
        rpc_parameter.compress_type = 0
        rpc_parameter.package_size = 20
        rpc_parameter.protocol = "baidu_std"
        rpc_parameter.max_channel_per_request = 3

        return self.sdk_desc
G
guru4elephant 已提交
113

G
guru4elephant 已提交
114 115 116 117 118 119

class Client(object):
    """RPC client for a Paddle Serving general-model server.

    Typical usage:
        client = Client()
        client.load_client_config(conf_path)
        client.connect(["127.0.0.1:9292"])
        fetch_map = client.predict(feed={...}, fetch=[...])
    """

    def __init__(self):
        self.feed_names_ = []
        self.fetch_names_ = []
        self.client_handle_ = None
        self.feed_shapes_ = {}
        self.feed_types_ = {}
        self.feed_names_to_idx_ = {}
        self.pid = os.getpid()
        self.predictor_sdk_ = None
        self.producers = []
        self.consumer = None
        self.profile_ = _Profiler()
        # A predict() call must be all-list or all-numpy; these flags track
        # what has been seen so the right C++ entry point is chosen.
        self.all_numpy_input = True
        self.has_numpy_input = False
        self.rpc_timeout_ms = 20000
        # Deferred import: the compiled extension is only needed at runtime.
        from .serving_client import PredictorRes
        self.predictorres_constructor = PredictorRes

    def load_client_config(self, path):
        """Parse the client-side model config and initialize the C++ handle.

        Fills feed/fetch names, types, shapes and lod metadata used later by
        predict().
        """
        from .serving_client import PredictorClient
        model_conf = m_config.GeneralModelConfig()
        # Close the config file promptly instead of leaking the handle.
        with open(path, 'r') as f:
            model_conf = google.protobuf.text_format.Merge(
                str(f.read()), model_conf)

        self.client_handle_ = PredictorClient()
        self.client_handle_.init(path)
        if "FLAGS_max_body_size" not in os.environ:
            os.environ["FLAGS_max_body_size"] = str(512 * 1024 * 1024)
        read_env_flags = ["profile_client", "profile_server", "max_body_size"]
        self.client_handle_.init_gflags([sys.argv[
            0]] + ["--tryfromenv=" + ",".join(read_env_flags)])

        self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
        self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
        self.feed_names_to_idx_ = {}
        self.fetch_names_to_type_ = {}
        self.fetch_names_to_idx_ = {}
        self.lod_tensor_set = set()
        self.feed_tensor_len = {}

        for i, var in enumerate(model_conf.feed_var):
            self.feed_names_to_idx_[var.alias_name] = i
            self.feed_types_[var.alias_name] = var.feed_type
            self.feed_shapes_[var.alias_name] = var.shape
            if var.is_lod_tensor:
                self.lod_tensor_set.add(var.alias_name)
            else:
                # Fixed-shape tensors: precompute the flat element count so
                # shape_check() can validate list inputs cheaply.
                counter = 1
                for dim in self.feed_shapes_[var.alias_name]:
                    counter *= dim
                self.feed_tensor_len[var.alias_name] = counter
        for i, var in enumerate(model_conf.fetch_var):
            self.fetch_names_to_idx_[var.alias_name] = i
            self.fetch_names_to_type_[var.alias_name] = var.fetch_type
            if var.is_lod_tensor:
                self.lod_tensor_set.add(var.alias_name)
        return

    def add_variant(self, tag, cluster, variant_weight):
        """Register a server variant (for A/B testing) before connect()."""
        if self.predictor_sdk_ is None:
            self.predictor_sdk_ = SDKConfig()
        self.predictor_sdk_.add_server_variant(tag, cluster,
                                               str(variant_weight))

    def set_rpc_timeout_ms(self, rpc_timeout):
        """Set the RPC timeout in milliseconds; call before connect()."""
        if not isinstance(rpc_timeout, int):
            raise ValueError("rpc_timeout must be int type.")
        self.rpc_timeout_ms = rpc_timeout

    def connect(self, endpoints=None):
        """Create the underlying predictor from endpoints or added variants.

        Args:
            endpoints: optional list of "ip:port" strings; ignored (with a
                warning) when variants were already added via add_variant().
        """
        if endpoints is None:
            if self.predictor_sdk_ is None:
                raise ValueError(
                    "You must set the endpoints parameter or use add_variant function to create a variant."
                )
        else:
            if self.predictor_sdk_ is None:
                self.add_variant('default_tag_{}'.format(id(self)), endpoints,
                                 100)
            else:
                print(
                    "parameter endpoints({}) will not take effect, because you use the add_variant function.".
                    format(endpoints))
        sdk_desc = self.predictor_sdk_.gen_desc(self.rpc_timeout_ms)
        self.client_handle_.create_predictor_by_desc(sdk_desc.SerializeToString(
        ))

    def get_feed_names(self):
        """Return the list of feed variable alias names."""
        return self.feed_names_

    def get_fetch_names(self):
        """Return the list of fetch variable alias names."""
        return self.fetch_names_

    def shape_check(self, feed, key):
        """Validate that the feed entry for *key* has the declared size.

        LoD tensors are variable-length and skipped. Raises ValueError when a
        plain-list feed has the wrong element count; numpy-input mismatches
        are deliberately left unenforced (historical behavior).
        """
        if key in self.lod_tensor_set:
            return
        if isinstance(feed[key],
                      list) and len(feed[key]) != self.feed_tensor_len[key]:
            raise ValueError("The shape of feed tensor {} not match.".format(
                key))
        if type(feed[key]).__module__ == np.__name__ and np.size(feed[
                key]) != self.feed_tensor_len[key]:
            # Kept as a no-op on purpose; see commented-out raise in history.
            pass

    def predict(self, feed=None, fetch=None, need_variant_tag=False):
        """Run one inference request.

        Args:
            feed: dict (single sample) or list of dicts (a batch) mapping
                feed names to list or numpy.ndarray values. All values must
                be lists, or all numpy arrays — mixing is rejected.
            fetch: a fetch name or a list of fetch names.
            need_variant_tag: when True, return [result, variant_tag].

        Returns:
            Dict of fetch name -> numpy array (single model), dict keyed by
            engine name (multiple models), or None when the RPC failed.

        Raises:
            ValueError: on missing/invalid feed or fetch arguments.
        """
        self.profile_.record('py_prepro_0')

        if feed is None or fetch is None:
            raise ValueError("You should specify feed and fetch for prediction")

        if isinstance(fetch, str):
            fetch_list = [fetch]
        elif isinstance(fetch, list):
            fetch_list = fetch
        else:
            raise ValueError("Fetch only accepts string and list of string")

        if isinstance(feed, dict):
            feed_batch = [feed]
        elif isinstance(feed, list):
            feed_batch = feed
        else:
            raise ValueError("Feed only accepts dict and list of dict")

        int_slot_batch = []
        float_slot_batch = []
        int_feed_names = []
        float_feed_names = []
        int_shape = []
        float_shape = []

        # Keep only fetch names the loaded config actually declares.
        fetch_names = [key for key in fetch_list if key in self.fetch_names_]
        if len(fetch_names) == 0:
            raise ValueError(
                "Fetch names should not be empty or out of saved fetch list.")

        for i, feed_i in enumerate(feed_batch):
            int_slot = []
            float_slot = []
            for key in feed_i:
                if key not in self.feed_names_:
                    raise ValueError("Wrong feed name: {}.".format(key))
                self.shape_check(feed_i, key)
                is_ndarray = isinstance(feed_i[key], np.ndarray)
                if self.feed_types_[key] == int_type:
                    if i == 0:
                        # Names/shapes are taken from the first sample only.
                        int_feed_names.append(key)
                        int_shape.append(
                            list(feed_i[key].shape)
                            if is_ndarray else self.feed_shapes_[key])
                    int_slot.append(feed_i[key])
                    if is_ndarray:
                        self.has_numpy_input = True
                    else:
                        self.all_numpy_input = False
                elif self.feed_types_[key] == float_type:
                    if i == 0:
                        float_feed_names.append(key)
                        float_shape.append(
                            list(feed_i[key].shape)
                            if is_ndarray else self.feed_shapes_[key])
                    float_slot.append(feed_i[key])
                    if is_ndarray:
                        self.has_numpy_input = True
                    else:
                        self.all_numpy_input = False
            int_slot_batch.append(int_slot)
            float_slot_batch.append(float_slot)

        self.profile_.record('py_prepro_1')
        self.profile_.record('py_client_infer_0')

        result_batch_handle = self.predictorres_constructor()
        if self.all_numpy_input:
            res = self.client_handle_.numpy_predict(
                float_slot_batch, float_feed_names, float_shape, int_slot_batch,
                int_feed_names, int_shape, fetch_names, result_batch_handle,
                self.pid)
        elif not self.has_numpy_input:
            res = self.client_handle_.batch_predict(
                float_slot_batch, float_feed_names, float_shape, int_slot_batch,
                int_feed_names, int_shape, fetch_names, result_batch_handle,
                self.pid)
        else:
            raise ValueError(
                "Please make sure the inputs are all in list type or all in numpy.array type"
            )

        self.profile_.record('py_client_infer_1')
        self.profile_.record('py_postpro_0')

        # -1 signals a failed RPC on the C++ side.
        if res == -1:
            return None

        multi_result_map = []
        model_engine_names = result_batch_handle.get_engine_names()
        for mi, engine_name in enumerate(model_engine_names):
            result_map = {}
            # Each fetched variable comes back as a numpy array reshaped to
            # the server-reported shape; lod info is exposed as "<name>.lod".
            for name in fetch_names:
                v_type = self.fetch_names_to_type_[name]
                if v_type == int_type:
                    result_map[name] = result_batch_handle.get_int64_by_name(
                        mi, name)
                elif v_type == float_type:
                    result_map[name] = result_batch_handle.get_float_by_name(
                        mi, name)
                else:
                    continue
                result_map[name].shape = result_batch_handle.get_shape(mi,
                                                                       name)
                if name in self.lod_tensor_set:
                    result_map["{}.lod".format(
                        name)] = result_batch_handle.get_lod(mi, name)
            multi_result_map.append(result_map)

        if len(model_engine_names) == 1:
            # If only one model result is returned, ret is the result_map.
            ret = multi_result_map[0]
        else:
            # Multiple model results: key each result_map by engine name.
            ret = {
                engine_name: multi_result_map[mi]
                for mi, engine_name in enumerate(model_engine_names)
            }

        self.profile_.record('py_postpro_1')
        self.profile_.print_profile()

        # When using the A/B test, the tag of the variant is also returned.
        return ret if not need_variant_tag else [
            ret, result_batch_handle.variant_tag()
        ]

    def release(self):
        """Destroy the underlying predictor and drop the C++ handle."""
        self.client_handle_.destroy_predictor()
        self.client_handle_ = None
B
barrierye 已提交
380 381


382
class MultiLangClient(object):
    """gRPC-based client for the multi-language general model service."""

    def __init__(self):
        self.channel_ = None

    def load_client_config(self, path):
        """Parse the client-side model config file at *path*."""
        if not isinstance(path, str):
            raise Exception("GClient only supports multi-model temporarily")
        self._parse_model_config(path)

    def connect(self, endpoint):
        """Open an insecure gRPC channel to the first endpoint and build the stub."""
        self.channel_ = grpc.insecure_channel(endpoint[0])  #TODO
        self.stub_ = multi_lang_general_model_service_pb2_grpc.MultiLangGeneralModelServiceStub(
            self.channel_)

    def _flatten_list(self, nested_list):
        """Yield the leaves of an arbitrarily nested list/tuple, left to right."""
        for item in nested_list:
            if isinstance(item, (list, tuple)):
                for sub_item in self._flatten_list(item):
                    yield sub_item
            else:
                yield item

    def _parse_model_config(self, model_config_path):
        """Load feed/fetch metadata (types, shapes, lod flags) from the config."""
        model_conf = m_config.GeneralModelConfig()
        # Close the config file promptly instead of leaking the handle.
        with open(model_config_path, 'r') as f:
            model_conf = google.protobuf.text_format.Merge(
                str(f.read()), model_conf)
        self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
        self.feed_types_ = {}
        self.feed_shapes_ = {}
        self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
        self.fetch_types_ = {}
        self.lod_tensor_set_ = set()
        for var in model_conf.feed_var:
            self.feed_types_[var.alias_name] = var.feed_type
            self.feed_shapes_[var.alias_name] = var.shape
            if var.is_lod_tensor:
                self.lod_tensor_set_.add(var.alias_name)
        for var in model_conf.fetch_var:
            self.fetch_types_[var.alias_name] = var.fetch_type
            if var.is_lod_tensor:
                self.lod_tensor_set_.add(var.alias_name)

    def _pack_feed_data(self, feed, fetch, is_python):
        """Build a Request proto from *feed*.

        Args:
            feed: dict (single sample) or list of dicts (a batch).
            fetch: list of fetch variable names.
            is_python: when True, tensors are shipped as raw ndarray bytes;
                otherwise as typed repeated fields.
        """
        req = multi_lang_general_model_service_pb2.Request()
        req.fetch_var_names.extend(fetch)
        req.is_python = is_python
        if isinstance(feed, dict):
            feed_batch = [feed]
        elif isinstance(feed, list):
            feed_batch = feed
        else:
            raise Exception("{} not support".format(type(feed)))
        # Bug fix: the original called feed.keys() before the type check
        # above, which crashed with AttributeError for list feeds. Take the
        # variable names from the first sample of the normalized batch.
        req.feed_var_names.extend(feed_batch[0].keys())
        for feed_data in feed_batch:
            inst = multi_lang_general_model_service_pb2.FeedInst()
            for name in req.feed_var_names:
                tensor = multi_lang_general_model_service_pb2.Tensor()
                var = feed_data[name]
                v_type = self.feed_types_[name]
                if is_python:
                    if isinstance(var, list):
                        if v_type == 0:  # int64
                            data = np.array(var, dtype="int64")
                        elif v_type == 1:  # float32
                            data = np.array(var, dtype="float32")
                        else:
                            raise Exception("error type.")
                    else:
                        data = var
                        if var.dtype == "float64":
                            data = data.astype("float32")
                    tensor.data = data.tobytes()
                else:
                    if v_type == 0:  # int64
                        if isinstance(var, np.ndarray):
                            tensor.int64_data.extend(var.reshape(-1).tolist())
                        else:
                            tensor.int64_data.extend(self._flatten_list(var))
                    elif v_type == 1:  # float32
                        if isinstance(var, np.ndarray):
                            tensor.float_data.extend(var.reshape(-1).tolist())
                        else:
                            tensor.float_data.extend(self._flatten_list(var))
                    else:
                        raise Exception("error type.")
                if isinstance(var, np.ndarray):
                    tensor.shape.extend(list(var.shape))
                else:
                    # Lists carry no shape; fall back to the configured one.
                    tensor.shape.extend(self.feed_shapes_[name])
                inst.tensor_array.append(tensor)
            req.insts.append(inst)
        return req

    def _unpack_resp(self, resp, fetch, is_python, need_variant_tag):
        """Convert a Response proto into {fetch_name: ndarray} (+ ".lod" keys)."""
        result_map = {}
        inst = resp.outputs[0].insts[0]
        tag = resp.tag
        for i, name in enumerate(fetch):
            var = inst.tensor_array[i]
            v_type = self.fetch_types_[name]
            if is_python:
                # Raw-bytes path: mirror of ndarray.tobytes() on the server.
                if v_type == 0:  # int64
                    result_map[name] = np.frombuffer(var.data, dtype="int64")
                elif v_type == 1:  # float32
                    result_map[name] = np.frombuffer(var.data, dtype="float32")
                else:
                    raise Exception("error type.")
            else:
                if v_type == 0:  # int64
                    result_map[name] = np.array(
                        list(var.int64_data), dtype="int64")
                elif v_type == 1:  # float32
                    result_map[name] = np.array(
                        list(var.float_data), dtype="float32")
                else:
                    raise Exception("error type.")
            result_map[name].shape = list(var.shape)
            if name in self.lod_tensor_set_:
                result_map["{}.lod".format(name)] = np.array(list(var.lod))
        return result_map if not need_variant_tag else [result_map, tag]

    def _done_callback_func(self, fetch, is_python, need_variant_tag):
        """Return a callback that unpacks an asynchronous Response."""

        def unpack_resp(resp):
            return self._unpack_resp(resp, fetch, is_python, need_variant_tag)

        return unpack_resp

    def predict(self,
                feed,
                fetch,
                need_variant_tag=False,
                asyn=False,
                is_python=True):
        """Run inference; returns a result map, or a future-like handle when asyn."""
        req = self._pack_feed_data(feed, fetch, is_python=is_python)
        if not asyn:
            resp = self.stub_.inference(req)
            return self._unpack_resp(
                resp,
                fetch,
                is_python=is_python,
                need_variant_tag=need_variant_tag)
        else:
            call_future = self.stub_.inference.future(req)
            return MultiLangPredictFuture(
                call_future,
                self._done_callback_func(
                    fetch,
                    is_python=is_python,
                    need_variant_tag=need_variant_tag))
539 540 541 542 543 544 545 546 547 548


class MultiLangPredictFuture(object):
    """Handle returned by MultiLangClient.predict(asyn=True)."""

    def __init__(self, call_future, callback_func):
        # call_future: future returned by stub.inference.future(req).
        # callback_func: unpacker built by MultiLangClient._done_callback_func.
        self.call_future_ = call_future
        self.callback_func_ = callback_func

    def result(self):
        """Block until the RPC finishes and return the unpacked result."""
        resp = self.call_future_.result()
        return self.callback_func_(resp)