__init__.py 21.4 KB
Newer Older
G
guru4elephant 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
# pylint: disable=doc-string-missing
G
guru4elephant 已提交
15

M
MRXLT 已提交
16 17
import paddle_serving_client
import os
18 19 20
from .proto import sdk_configure_pb2 as sdk
from .proto import general_model_config_pb2 as m_config
import google.protobuf.text_format
D
dongdaxiang 已提交
21 22
import numpy as np
import time
23
import sys
24
from .serving_client import PredictorRes
G
guru4elephant 已提交
25

B
barrierye 已提交
26
import grpc
B
barrierye 已提交
27 28
from .proto import multi_lang_general_model_service_pb2
from .proto import multi_lang_general_model_service_pb2_grpc
B
barrierye 已提交
29

G
guru4elephant 已提交
30 31 32
# Type codes compared against ``feed_type`` / ``fetch_type`` in the model
# config: 0 is handled as int64 data, 1 as float32 data.
int_type, float_type = 0, 1

M
MRXLT 已提交
33

W
WangXi 已提交
34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61
class _NOPProfiler(object):
    def record(self, name):
        pass

    def print_profile(self):
        pass


class _TimeProfiler(object):
    def __init__(self):
        self.pid = os.getpid()
        self.print_head = 'PROFILE\tpid:{}\t'.format(self.pid)
        self.time_record = [self.print_head]

    def record(self, name):
        self.time_record.append('{}:{} '.format(
            name, int(round(time.time() * 1000000))))

    def print_profile(self):
        self.time_record.append('\n')
        sys.stderr.write(''.join(self.time_record))
        self.time_record = [self.print_head]


# Client-side profiling is opt-in through the FLAGS_profile_client
# environment variable (any non-zero integer enables it).
_is_profile = int(os.environ.get('FLAGS_profile_client', 0))
if _is_profile:
    _Profiler = _TimeProfiler
else:
    _Profiler = _NOPProfiler


G
guru4elephant 已提交
62 63 64
class SDKConfig(object):
    """Builds the ``sdk.SDKConf`` protobuf that configures the C++ predictor.

    Server variants (tag, cluster endpoints, traffic weight) are collected
    with ``add_server_variant``; ``gen_desc`` turns them into a descriptor.
    """

    def __init__(self):
        self.sdk_desc = sdk.SDKConf()
        self.tag_list = []
        self.cluster_list = []
        self.variant_weight_list = []
        self.rpc_timeout_ms = 20000
        self.load_balance_strategy = "la"

    def add_server_variant(self, tag, cluster, variant_weight):
        """Register a server variant.

        Args:
            tag: name of the variant.
            cluster: list of endpoint strings; a single endpoint string is
                also accepted and treated as a one-element list.
            variant_weight: relative weight of the variant. It is stored as a
                string because ``gen_desc`` joins the weights with "|".
        """
        if isinstance(cluster, str):
            # ",".join() over a bare string would put commas between its
            # characters and produce a bogus "list://" address.
            cluster = [cluster]
        self.tag_list.append(tag)
        self.cluster_list.append(cluster)
        self.variant_weight_list.append(str(variant_weight))

    def set_load_balance_strategy(self, strategy):
        """Set the naming load-balance strategy written by ``gen_desc``."""
        self.load_balance_strategy = strategy

    def set_load_banlance_strategy(self, strategy):
        """Misspelled alias of ``set_load_balance_strategy`` kept for compatibility."""
        self.set_load_balance_strategy(strategy)

    def gen_desc(self, rpc_timeout_ms=None):
        """Fill in and return ``self.sdk_desc`` for the registered variants.

        Args:
            rpc_timeout_ms: RPC timeout for the default variant config;
                falls back to ``self.rpc_timeout_ms`` when omitted.

        Returns:
            ``self.sdk_desc`` containing exactly one "general_model" predictor.
        """
        if rpc_timeout_ms is None:
            rpc_timeout_ms = self.rpc_timeout_ms

        predictor_desc = sdk.Predictor()
        predictor_desc.name = "general_model"
        predictor_desc.service_name = \
            "baidu.paddle_serving.predictor.general_model.GeneralModelService"
        predictor_desc.endpoint_router = "WeightedRandomRender"
        predictor_desc.weighted_random_render_conf.variant_weight_list = "|".join(
            self.variant_weight_list)

        for tag, cluster in zip(self.tag_list, self.cluster_list):
            variant_desc = sdk.VariantConf()
            variant_desc.tag = tag
            variant_desc.naming_conf.cluster = "list://{}".format(
                ",".join(cluster))
            predictor_desc.variants.extend([variant_desc])

        # gen_desc may run more than once on the same object (e.g. connect()
        # called again); replace the predictor instead of accumulating copies.
        del self.sdk_desc.predictors[:]
        self.sdk_desc.predictors.extend([predictor_desc])

        default_conf = self.sdk_desc.default_variant_conf
        default_conf.tag = "default"

        conn = default_conf.connection_conf
        conn.connect_timeout_ms = 2000
        conn.rpc_timeout_ms = rpc_timeout_ms
        conn.connect_retry_count = 2
        conn.max_connection_per_host = 100
        conn.hedge_request_timeout_ms = -1
        conn.hedge_fetch_retry_count = 2
        conn.connection_type = "pooled"

        naming = default_conf.naming_conf
        naming.cluster_filter_strategy = "Default"
        # Honour the strategy chosen via set_load_balance_strategy (default "la").
        naming.load_balance_strategy = self.load_balance_strategy

        rpc = default_conf.rpc_parameter
        rpc.compress_type = 0
        rpc.package_size = 20
        rpc.protocol = "baidu_std"
        rpc.max_channel_per_request = 3

        return self.sdk_desc
G
guru4elephant 已提交
114

G
guru4elephant 已提交
115 116 117 118 119 120

class Client(object):
    """Python front end of the C++ ``PredictorClient``.

    Typical call sequence: ``load_client_config`` -> ``connect`` (optionally
    after ``add_variant``) -> ``predict`` -> ``release``.
    """

    def __init__(self):
        self.feed_names_ = []
        self.fetch_names_ = []
        self.client_handle_ = None
        self.feed_shapes_ = {}
        self.feed_types_ = {}
        self.feed_names_to_idx_ = {}
        self.pid = os.getpid()
        self.predictor_sdk_ = None
        self.producers = []
        self.consumer = None
        self.profile_ = _Profiler()
        # Input-kind flags describing the current predict() call; predict()
        # resets them at the start of every call.
        self.all_numpy_input = True
        self.has_numpy_input = False
        self.rpc_timeout_ms = 20000

    def load_client_config(self, path):
        """Parse the model config at ``path`` and initialise the C++ handle.

        Args:
            path: path to a text-format ``GeneralModelConfig`` protobuf file.
        """
        from .serving_client import PredictorClient

        model_conf = m_config.GeneralModelConfig()
        # Use a context manager so the config file handle is not leaked.
        with open(path, 'r') as f:
            model_conf = google.protobuf.text_format.Merge(
                str(f.read()), model_conf)

        self.client_handle_ = PredictorClient()
        self.client_handle_.init(path)
        # Default FLAGS_max_body_size to 512 MB unless the user exported it;
        # the value is forwarded to the C++ side by --tryfromenv below.
        if "FLAGS_max_body_size" not in os.environ:
            os.environ["FLAGS_max_body_size"] = str(512 * 1024 * 1024)
        read_env_flags = ["profile_client", "profile_server", "max_body_size"]
        self.client_handle_.init_gflags(
            [sys.argv[0]] + ["--tryfromenv=" + ",".join(read_env_flags)])

        self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
        self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
        # Rebuild every lookup table so reloading a different config leaves
        # no stale entries behind.
        self.feed_names_to_idx_ = {}
        self.feed_types_ = {}
        self.feed_shapes_ = {}
        self.fetch_names_to_type_ = {}
        self.fetch_names_to_idx_ = {}
        self.lod_tensor_set = set()
        self.feed_tensor_len = {}

        for i, var in enumerate(model_conf.feed_var):
            self.feed_names_to_idx_[var.alias_name] = i
            self.feed_types_[var.alias_name] = var.feed_type
            self.feed_shapes_[var.alias_name] = var.shape
            if var.is_lod_tensor:
                self.lod_tensor_set.add(var.alias_name)
            else:
                # Fixed-shape input: remember its element count for shape_check.
                counter = 1
                for dim in var.shape:
                    counter *= dim
                self.feed_tensor_len[var.alias_name] = counter

        for i, var in enumerate(model_conf.fetch_var):
            self.fetch_names_to_idx_[var.alias_name] = i
            self.fetch_names_to_type_[var.alias_name] = var.fetch_type
            if var.is_lod_tensor:
                self.lod_tensor_set.add(var.alias_name)

    def add_variant(self, tag, cluster, variant_weight):
        """Register a server variant (used for A/B tests) on the SDK config."""
        if self.predictor_sdk_ is None:
            self.predictor_sdk_ = SDKConfig()
        self.predictor_sdk_.add_server_variant(tag, cluster,
                                               str(variant_weight))

    def set_rpc_timeout_ms(self, rpc_timeout):
        """Set the RPC timeout used by the next ``connect``.

        Raises:
            ValueError: ``rpc_timeout`` is not an int.
        """
        if not isinstance(rpc_timeout, int):
            raise ValueError("rpc_timeout must be int type.")
        self.rpc_timeout_ms = rpc_timeout

    def connect(self, endpoints=None):
        """Create the C++ predictor for ``endpoints`` or the added variants.

        Raises:
            ValueError: neither ``endpoints`` nor any variant was given.
        """
        if endpoints is None:
            if self.predictor_sdk_ is None:
                raise ValueError(
                    "You must set the endpoints parameter or use add_variant function to create a variant."
                )
        elif self.predictor_sdk_ is None:
            self.add_variant('default_tag_{}'.format(id(self)), endpoints, 100)
        else:
            print(
                "parameter endpoints({}) will not take effect, because you use the add_variant function.".
                format(endpoints))
        sdk_desc = self.predictor_sdk_.gen_desc(self.rpc_timeout_ms)
        self.client_handle_.create_predictor_by_desc(
            sdk_desc.SerializeToString())

    def get_feed_names(self):
        return self.feed_names_

    def get_fetch_names(self):
        return self.fetch_names_

    def shape_check(self, feed, key):
        """Check the size of a non-LoD list input against the model config.

        Raises:
            ValueError: a list input's length differs from the product of
                the configured shape.
        """
        if key in self.lod_tensor_set:
            return
        if isinstance(feed[key],
                      list) and len(feed[key]) != self.feed_tensor_len[key]:
            raise ValueError("The shape of feed tensor {} not match.".format(
                key))
        # numpy inputs are intentionally not size-checked (the check was
        # disabled upstream). NOTE(review): presumably because they may
        # carry a batch dimension -- confirm before re-enabling.

    def predict(self, feed=None, fetch=None, need_variant_tag=False):
        """Run inference on the server.

        Args:
            feed: a dict mapping feed alias -> list or np.ndarray, or a list
                of such dicts (one per instance). Feed names and shapes are
                taken from the first instance.
            fetch: a fetch alias or a list of them; names unknown to the
                model config are ignored.
            need_variant_tag: also return the tag of the variant that
                served the request.

        Returns:
            ``{name: np.ndarray}`` (plus ``"<name>.lod"`` for LoD outputs)
            when one engine answered, ``{engine_name: that_dict}`` otherwise;
            ``None`` if the C++ call returned -1; ``[ret, tag]`` when
            ``need_variant_tag`` is true.

        Raises:
            ValueError: invalid feed/fetch arguments, or list and numpy
                inputs mixed in one request.
        """
        self.profile_.record('py_prepro_0')

        if feed is None or fetch is None:
            raise ValueError("You should specify feed and fetch for prediction")

        if isinstance(fetch, str):
            fetch_list = [fetch]
        elif isinstance(fetch, list):
            fetch_list = fetch
        else:
            raise ValueError("Fetch only accepts string and list of string")

        if isinstance(feed, dict):
            feed_batch = [feed]
        elif isinstance(feed, list):
            feed_batch = feed
        else:
            raise ValueError("Feed only accepts dict and list of dict")

        fetch_names = [key for key in fetch_list if key in self.fetch_names_]
        if not fetch_names:
            raise ValueError(
                "Fetch names should not be empty or out of saved fetch list.")

        # The flags describe this request only. Without the reset, a list
        # call followed by a numpy call (or vice versa) was rejected as mixed.
        self.all_numpy_input = True
        self.has_numpy_input = False

        int_slot_batch = []
        float_slot_batch = []
        int_feed_names = []
        float_feed_names = []
        int_shape = []
        float_shape = []

        for i, feed_i in enumerate(feed_batch):
            int_slot = []
            float_slot = []
            for key in feed_i:
                if key not in self.feed_names_:
                    raise ValueError("Wrong feed name: {}.".format(key))
                self.shape_check(feed_i, key)
                feed_type = self.feed_types_[key]
                if feed_type == int_type:
                    names, shapes, slot = int_feed_names, int_shape, int_slot
                elif feed_type == float_type:
                    names, shapes, slot = float_feed_names, float_shape, float_slot
                else:
                    # Other feed types are skipped, as before.
                    continue
                value = feed_i[key]
                is_numpy = isinstance(value, np.ndarray)
                if i == 0:
                    names.append(key)
                    shapes.append(
                        list(value.shape) if is_numpy else self.feed_shapes_[key])
                slot.append(value)
                if is_numpy:
                    self.has_numpy_input = True
                else:
                    self.all_numpy_input = False
            int_slot_batch.append(int_slot)
            float_slot_batch.append(float_slot)

        self.profile_.record('py_prepro_1')
        self.profile_.record('py_client_infer_0')

        result_batch_handle = PredictorRes()
        if self.all_numpy_input:
            predict_fn = self.client_handle_.numpy_predict
        elif not self.has_numpy_input:
            predict_fn = self.client_handle_.batch_predict
        else:
            raise ValueError(
                "Please make sure the inputs are all in list type or all in numpy.array type"
            )
        res = predict_fn(float_slot_batch, float_feed_names, float_shape,
                         int_slot_batch, int_feed_names, int_shape,
                         fetch_names, result_batch_handle, self.pid)

        self.profile_.record('py_client_infer_1')
        self.profile_.record('py_postpro_0')

        if res == -1:
            return None

        multi_result_map = []
        model_engine_names = result_batch_handle.get_engine_names()
        for mi, engine_name in enumerate(model_engine_names):
            result_map = {}
            for name in fetch_names:
                fetch_type = self.fetch_names_to_type_[name]
                if fetch_type == int_type:
                    getter = result_batch_handle.get_int64_by_name
                elif fetch_type == float_type:
                    getter = result_batch_handle.get_float_by_name
                else:
                    continue
                # The getter returns a numpy array; reshape it in place.
                result_map[name] = getter(mi, name)
                result_map[name].shape = result_batch_handle.get_shape(mi, name)
                if name in self.lod_tensor_set:
                    result_map["{}.lod".format(
                        name)] = result_batch_handle.get_lod(mi, name)
            multi_result_map.append(result_map)

        if len(model_engine_names) == 1:
            # A single model answered: return its result map directly.
            ret = multi_result_map[0]
        else:
            # Several models answered: key each result map by engine name.
            ret = {
                engine_name: multi_result_map[mi]
                for mi, engine_name in enumerate(model_engine_names)
            }

        self.profile_.record('py_postpro_1')
        self.profile_.print_profile()

        # For A/B tests the caller may also need the serving variant's tag.
        if need_variant_tag:
            return [ret, result_batch_handle.variant_tag()]
        return ret

    def release(self):
        """Destroy the C++ predictor and drop the handle."""
        self.client_handle_.destroy_predictor()
        self.client_handle_ = None
B
barrierye 已提交
379 380


381
class MultiLangClient(object):
    """gRPC client for the multi-language general model service."""

    # numpy dtype used for each feed/fetch type code (0: int64, 1: float32).
    _NP_DTYPES = {0: "int64", 1: "float32"}

    def __init__(self):
        self.channel_ = None

    def load_client_config(self, path):
        """Load the model config; only a single config path is supported."""
        # NOTE(review): the message says "multi-model" although the check
        # rejects anything but a single path string -- confirm wording.
        if not isinstance(path, str):
            raise Exception("GClient only supports multi-model temporarily")
        self._parse_model_config(path)

    def connect(self, endpoint):
        """Open an insecure gRPC channel to the server.

        Args:
            endpoint: list of endpoint strings, or a single endpoint string.
        """
        if isinstance(endpoint, str):
            # Indexing a bare string would take only its first character.
            endpoint = [endpoint]
        # TODO: only the first endpoint is used.
        self.channel_ = grpc.insecure_channel(endpoint[0])
        self.stub_ = multi_lang_general_model_service_pb2_grpc.MultiLangGeneralModelServiceStub(
            self.channel_)

    def _flatten_list(self, nested_list):
        """Yield the leaves of arbitrarily nested lists/tuples in order."""
        for item in nested_list:
            if isinstance(item, (list, tuple)):
                for sub_item in self._flatten_list(item):
                    yield sub_item
            else:
                yield item

    def _parse_model_config(self, model_config_path):
        """Read feed/fetch names, types, shapes and LoD flags from the config."""
        model_conf = m_config.GeneralModelConfig()
        with open(model_config_path, 'r') as f:
            model_conf = google.protobuf.text_format.Merge(
                str(f.read()), model_conf)
        self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
        self.feed_types_ = {}
        self.feed_shapes_ = {}
        self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
        self.fetch_types_ = {}
        self.lod_tensor_set_ = set()
        for var in model_conf.feed_var:
            self.feed_types_[var.alias_name] = var.feed_type
            self.feed_shapes_[var.alias_name] = var.shape
            if var.is_lod_tensor:
                self.lod_tensor_set_.add(var.alias_name)
        for var in model_conf.fetch_var:
            self.fetch_types_[var.alias_name] = var.fetch_type
            if var.is_lod_tensor:
                self.lod_tensor_set_.add(var.alias_name)

    def _to_numpy(self, var, v_type):
        """Convert one Python-mode feed value to an array of its declared dtype."""
        dtype = self._NP_DTYPES.get(v_type)
        if isinstance(var, list):
            if dtype is None:
                raise Exception("error type.")
            return np.array(var, dtype=dtype)
        if dtype is not None:
            # The raw bytes are decoded by declared type (NOTE(review):
            # assumed from the symmetric np.frombuffer in _unpack_resp), so
            # e.g. int32 -> int64 and float64 -> float32 must happen here.
            return var.astype(dtype, copy=False)
        if var.dtype == "float64":
            return var.astype("float32")
        return var

    def _pack_feed_data(self, feed, fetch, is_python):
        """Build the inference Request for ``feed`` (a dict or a list of dicts)."""
        if isinstance(feed, dict):
            feed_batch = [feed]
        elif isinstance(feed, list):
            feed_batch = feed
        else:
            raise Exception("{} not support".format(type(feed)))
        if not feed_batch:
            raise ValueError("feed must not be an empty list")
        if isinstance(fetch, str):
            fetch = [fetch]

        req = multi_lang_general_model_service_pb2.Request()
        req.fetch_var_names.extend(fetch)
        # Feed names come from the first instance; the previous code called
        # feed.keys() on the raw argument and crashed for list input.
        req.feed_var_names.extend(feed_batch[0].keys())
        req.is_python = is_python

        for feed_data in feed_batch:
            inst = multi_lang_general_model_service_pb2.FeedInst()
            for name in req.feed_var_names:
                tensor = multi_lang_general_model_service_pb2.Tensor()
                var = feed_data[name]
                v_type = self.feed_types_[name]
                if is_python:
                    tensor.data = self._to_numpy(var, v_type).tobytes()
                elif v_type == 0:  # int64
                    if isinstance(var, np.ndarray):
                        tensor.int64_data.extend(var.reshape(-1).tolist())
                    else:
                        tensor.int64_data.extend(self._flatten_list(var))
                elif v_type == 1:  # float32
                    if isinstance(var, np.ndarray):
                        tensor.float_data.extend(var.reshape(-1).tolist())
                    else:
                        tensor.float_data.extend(self._flatten_list(var))
                else:
                    raise Exception("error type.")
                if isinstance(var, np.ndarray):
                    tensor.shape.extend(list(var.shape))
                else:
                    tensor.shape.extend(self.feed_shapes_[name])
                inst.tensor_array.append(tensor)
            req.insts.append(inst)
        return req

    def _unpack_resp(self, resp, fetch, is_python, need_variant_tag):
        """Turn the first output instance of ``resp`` into ``{name: ndarray}``."""
        if isinstance(fetch, str):
            fetch = [fetch]
        result_map = {}
        inst = resp.outputs[0].insts[0]
        tag = resp.tag
        for i, name in enumerate(fetch):
            var = inst.tensor_array[i]
            v_type = self.fetch_types_[name]
            if is_python:
                dtype = self._NP_DTYPES.get(v_type)
                if dtype is None:
                    raise Exception("error type.")
                result_map[name] = np.frombuffer(var.data, dtype=dtype)
            elif v_type == 0:  # int64
                result_map[name] = np.array(list(var.int64_data))
            elif v_type == 1:  # float32
                result_map[name] = np.array(list(var.float_data))
            else:
                raise Exception("error type.")
            result_map[name].shape = list(var.shape)
            if name in self.lod_tensor_set_:
                result_map["{}.lod".format(name)] = np.array(list(var.lod))
        return result_map if not need_variant_tag else [result_map, tag]

    def _done_callback_func(self, fetch, is_python, need_variant_tag):
        """Return a one-argument function that unpacks a response."""

        def unpack_resp(resp):
            return self._unpack_resp(resp, fetch, is_python, need_variant_tag)

        return unpack_resp

    def predict(self,
                feed,
                fetch,
                need_variant_tag=False,
                asyn=False,
                is_python=True):
        """Run inference; returns the result, or a MultiLangPredictFuture if ``asyn``."""
        if isinstance(fetch, str):
            # Treat a single fetch name as a one-element list everywhere.
            fetch = [fetch]
        req = self._pack_feed_data(feed, fetch, is_python=is_python)
        if not asyn:
            resp = self.stub_.inference(req)
            return self._unpack_resp(
                resp,
                fetch,
                is_python=is_python,
                need_variant_tag=need_variant_tag)
        call_future = self.stub_.inference.future(req)
        return MultiLangPredictFuture(
            call_future,
            self._done_callback_func(
                fetch,
                is_python=is_python,
                need_variant_tag=need_variant_tag))
536 537 538 539 540 541 542 543 544 545


class MultiLangPredictFuture(object):
    """Handle returned by an asynchronous ``MultiLangClient.predict`` call."""

    def __init__(self, call_future, callback_func):
        # future of the in-flight inference RPC
        self.call_future_ = call_future
        # turns the raw RPC response into the user-facing result
        self.callback_func_ = callback_func

    def result(self):
        """Block until the RPC completes and return the unpacked result."""
        return self.callback_func_(self.call_future_.result())