__init__.py 25.0 KB
Newer Older
G
guru4elephant 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
# pylint: disable=doc-string-missing
G
guru4elephant 已提交
15

M
MRXLT 已提交
16 17
import paddle_serving_client
import os
18 19 20
from .proto import sdk_configure_pb2 as sdk
from .proto import general_model_config_pb2 as m_config
import google.protobuf.text_format
D
dongdaxiang 已提交
21 22
import numpy as np
import time
23
import sys
G
guru4elephant 已提交
24

B
barrierye 已提交
25
import grpc
B
barrierye 已提交
26
from .proto import multi_lang_general_model_service_pb2
B
barrierye 已提交
27 28
sys.path.append(
    os.path.join(os.path.abspath(os.path.dirname(__file__)), 'proto'))
B
barrierye 已提交
29
from .proto import multi_lang_general_model_service_pb2_grpc
B
barrierye 已提交
30

G
guru4elephant 已提交
31 32 33
# Type codes stored in the model config's feed_type / fetch_type fields:
# 0 -> int64 data, 1 -> float32 data.
int_type, float_type = 0, 1

M
MRXLT 已提交
34

W
WangXi 已提交
35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62
class _NOPProfiler(object):
    def record(self, name):
        pass

    def print_profile(self):
        pass


class _TimeProfiler(object):
    def __init__(self):
        self.pid = os.getpid()
        self.print_head = 'PROFILE\tpid:{}\t'.format(self.pid)
        self.time_record = [self.print_head]

    def record(self, name):
        self.time_record.append('{}:{} '.format(
            name, int(round(time.time() * 1000000))))

    def print_profile(self):
        self.time_record.append('\n')
        sys.stderr.write(''.join(self.time_record))
        self.time_record = [self.print_head]


def _env_int_flag(name, default=0):
    """Read an integer-valued flag from the environment.

    Numeric values are returned as ``int``. Boolean spellings that gflags
    also accepts ("true", "t", "yes", "y") map to 1, and any other
    non-numeric value maps to 0. Previously such values raised ValueError
    and broke the module import.
    """
    raw = os.environ.get(name, default)
    try:
        return int(raw)
    except ValueError:
        return 1 if str(raw).strip().lower() in ("true", "t", "yes", "y") else 0


# Client-side profiling is switched on by the FLAGS_profile_client env var.
_is_profile = _env_int_flag('FLAGS_profile_client', 0)
_Profiler = _TimeProfiler if _is_profile else _NOPProfiler


G
guru4elephant 已提交
63 64 65
class SDKConfig(object):
    """Builds the ``sdk.SDKConf`` protobuf describing the serving predictor.

    Server variants (used for A/B tests) are registered with
    ``add_server_variant``. ``gen_desc`` turns them into an ``SDKConf``
    message.
    """

    def __init__(self):
        self.sdk_desc = sdk.SDKConf()
        # Parallel lists: index i describes the i-th registered variant.
        self.tag_list = []
        self.cluster_list = []
        self.variant_weight_list = []
        self.rpc_timeout_ms = 20000
        self.load_balance_strategy = "la"

    def add_server_variant(self, tag, cluster, variant_weight):
        """Register a server variant.

        Args:
            tag: variant tag written into the variant config.
            cluster: one endpoint string or a list of endpoint strings.
            variant_weight: weight of this variant. All weights are joined
                with "|" into the weighted-random-render config.
        """
        if isinstance(cluster, str):
            # A bare endpoint string would otherwise be joined char by char.
            cluster = [cluster]
        self.tag_list.append(tag)
        self.cluster_list.append(cluster)
        self.variant_weight_list.append(variant_weight)

    def set_load_banlance_strategy(self, strategy):
        """Set the naming load-balance strategy used by ``gen_desc``.

        The misspelled name is kept for backward compatibility.
        """
        self.load_balance_strategy = strategy

    def set_load_balance_strategy(self, strategy):
        """Correctly spelled alias of ``set_load_banlance_strategy``."""
        self.set_load_banlance_strategy(strategy)

    def gen_desc(self, rpc_timeout_ms=None):
        """Build and return the SDK configuration message.

        Args:
            rpc_timeout_ms: RPC timeout in milliseconds. Defaults to
                ``self.rpc_timeout_ms`` when omitted.

        Returns:
            ``self.sdk_desc`` filled with one predictor and the default
            variant connection/naming/rpc parameters.
        """
        if rpc_timeout_ms is None:
            rpc_timeout_ms = self.rpc_timeout_ms

        predictor_desc = sdk.Predictor()
        predictor_desc.name = "general_model"
        predictor_desc.service_name = \
            "baidu.paddle_serving.predictor.general_model.GeneralModelService"
        predictor_desc.endpoint_router = "WeightedRandomRender"
        predictor_desc.weighted_random_render_conf.variant_weight_list = "|".join(
            str(weight) for weight in self.variant_weight_list)

        for tag, cluster in zip(self.tag_list, self.cluster_list):
            variant_desc = sdk.VariantConf()
            variant_desc.tag = tag
            variant_desc.naming_conf.cluster = "list://{}".format(
                ",".join(cluster))
            predictor_desc.variants.extend([variant_desc])

        # gen_desc may run more than once (e.g. a second connect()); drop the
        # predictor added by an earlier call instead of accumulating copies.
        self.sdk_desc.ClearField("predictors")
        self.sdk_desc.predictors.extend([predictor_desc])

        default_conf = self.sdk_desc.default_variant_conf
        default_conf.tag = "default"

        connection_conf = default_conf.connection_conf
        connection_conf.connect_timeout_ms = 2000
        connection_conf.rpc_timeout_ms = rpc_timeout_ms
        connection_conf.connect_retry_count = 2
        connection_conf.max_connection_per_host = 100
        connection_conf.hedge_request_timeout_ms = -1
        connection_conf.hedge_fetch_retry_count = 2
        connection_conf.connection_type = "pooled"

        naming_conf = default_conf.naming_conf
        naming_conf.cluster_filter_strategy = "Default"
        # Honour set_load_banlance_strategy(); previously hard-coded to "la".
        naming_conf.load_balance_strategy = self.load_balance_strategy

        rpc_parameter = default_conf.rpc_parameter
        rpc_parameter.compress_type = 0
        rpc_parameter.package_size = 20
        rpc_parameter.protocol = "baidu_std"
        rpc_parameter.max_channel_per_request = 3

        return self.sdk_desc
G
guru4elephant 已提交
115

G
guru4elephant 已提交
116 117 118 119 120 121

class Client(object):
    """Python client that talks to Paddle Serving through the native PredictorClient.

    Typical call order: ``load_client_config`` -> ``connect`` -> ``predict``
    (any number of times) -> ``release``.
    """

    def __init__(self):
        self.feed_names_ = []
        self.fetch_names_ = []
        self.client_handle_ = None
        self.feed_shapes_ = {}
        self.feed_types_ = {}
        self.feed_names_to_idx_ = {}
        self.pid = os.getpid()
        self.predictor_sdk_ = None
        self.producers = []
        self.consumer = None
        self.profile_ = _Profiler()
        # Input-kind flags of the most recent predict() call (recomputed on
        # every call).
        self.all_numpy_input = True
        self.has_numpy_input = False
        self.rpc_timeout_ms = 20000
        from .serving_client import PredictorRes
        self.predictorres_constructor = PredictorRes

    def load_client_config(self, path):
        """Read the text-format GeneralModelConfig at ``path`` and init the native client.

        This populates the feed/fetch name, type and shape tables. It also
        initializes gflags from the environment. FLAGS_max_body_size defaults
        to 512 MiB unless it is already set.
        """
        from .serving_client import PredictorClient
        model_conf = m_config.GeneralModelConfig()
        with open(path, 'r') as f:
            model_conf = google.protobuf.text_format.Merge(
                str(f.read()), model_conf)

        self.client_handle_ = PredictorClient()
        self.client_handle_.init(path)
        if "FLAGS_max_body_size" not in os.environ:
            os.environ["FLAGS_max_body_size"] = str(512 * 1024 * 1024)
        read_env_flags = ["profile_client", "profile_server", "max_body_size"]
        self.client_handle_.init_gflags(
            [sys.argv[0]] + ["--tryfromenv=" + ",".join(read_env_flags)])

        self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
        self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
        # Rebuild every lookup table so a reloaded config leaves no stale keys.
        self.feed_names_to_idx_ = {}
        self.feed_types_ = {}
        self.feed_shapes_ = {}
        self.fetch_names_to_type_ = {}
        self.fetch_names_to_idx_ = {}
        self.lod_tensor_set = set()
        self.feed_tensor_len = {}

        for i, var in enumerate(model_conf.feed_var):
            self.feed_names_to_idx_[var.alias_name] = i
            self.feed_types_[var.alias_name] = var.feed_type
            self.feed_shapes_[var.alias_name] = var.shape
            if var.is_lod_tensor:
                self.lod_tensor_set.add(var.alias_name)
            else:
                # Expected flattened element count of a dense feed.
                counter = 1
                for dim in var.shape:
                    counter *= dim
                self.feed_tensor_len[var.alias_name] = counter

        for i, var in enumerate(model_conf.fetch_var):
            self.fetch_names_to_idx_[var.alias_name] = i
            self.fetch_names_to_type_[var.alias_name] = var.fetch_type
            if var.is_lod_tensor:
                self.lod_tensor_set.add(var.alias_name)

    def add_variant(self, tag, cluster, variant_weight):
        """Register a server variant (endpoint string or list of endpoints) with a weight."""
        if self.predictor_sdk_ is None:
            self.predictor_sdk_ = SDKConfig()
        if isinstance(cluster, str):
            # A bare "ip:port" string would otherwise be joined char by char.
            cluster = [cluster]
        self.predictor_sdk_.add_server_variant(tag, cluster,
                                               str(variant_weight))

    def set_rpc_timeout_ms(self, rpc_timeout):
        """Set the RPC timeout (ms) used by the next ``connect``.

        Raises:
            ValueError: if ``rpc_timeout`` is not an int.
        """
        if not isinstance(rpc_timeout, int):
            raise ValueError("rpc_timeout must be int type.")
        self.rpc_timeout_ms = rpc_timeout

    def connect(self, endpoints=None):
        """Create the native predictor.

        Args:
            endpoints: server endpoints. They are used only when no variant
                was added via ``add_variant``.

        Raises:
            ValueError: if neither endpoints nor a variant is available.
        """
        if endpoints is None:
            if self.predictor_sdk_ is None:
                raise ValueError(
                    "You must set the endpoints parameter or use add_variant function to create a variant."
                )
        elif self.predictor_sdk_ is None:
            self.add_variant('default_tag_{}'.format(id(self)), endpoints,
                             100)
        else:
            print(
                "parameter endpoints({}) will not take effect, because you use the add_variant function.".
                format(endpoints))
        sdk_desc = self.predictor_sdk_.gen_desc(self.rpc_timeout_ms)
        self.client_handle_.create_predictor_by_desc(
            sdk_desc.SerializeToString())

    def get_feed_names(self):
        return self.feed_names_

    def get_fetch_names(self):
        return self.fetch_names_

    def shape_check(self, feed, key):
        """Check that a list-typed dense feed has the configured element count.

        LoD feeds are skipped. numpy arrays are intentionally not checked
        (that check was disabled in this client).

        Raises:
            ValueError: if a list feed's length does not match.
        """
        if key in self.lod_tensor_set:
            return
        value = feed[key]
        if isinstance(value, list) and len(value) != self.feed_tensor_len[key]:
            raise ValueError("The shape of feed tensor {} not match.".format(
                key))

    def predict(self, feed=None, fetch=None, need_variant_tag=False):
        """Run inference.

        Args:
            feed: a dict ``{feed_name: value}`` or a list of such dicts (a
                batch). Within one call, values must be all lists or all
                numpy arrays.
            fetch: a fetch name or a list of fetch names. Names that are
                not in the model config are ignored.
            need_variant_tag: also return the variant tag (for A/B tests).

        Returns:
            ``None`` if the native call returns -1. Otherwise, with a single
            model engine, ``{fetch_name: ndarray}`` (plus ``"<name>.lod"``
            entries for LoD outputs). With several engines, the result is
            ``{engine_name: that dict}``. When ``need_variant_tag`` is set,
            the return value is ``[ret, variant_tag]``.

        Raises:
            ValueError: on bad feed/fetch arguments or mixed list/numpy input.
        """
        self.profile_.record('py_prepro_0')

        if feed is None or fetch is None:
            raise ValueError("You should specify feed and fetch for prediction")

        if isinstance(fetch, str):
            fetch_list = [fetch]
        elif isinstance(fetch, list):
            fetch_list = fetch
        else:
            raise ValueError("Fetch only accepts string and list of string")

        if isinstance(feed, dict):
            feed_batch = [feed]
        elif isinstance(feed, list):
            feed_batch = feed
        else:
            raise ValueError("Feed only accepts dict and list of dict")

        fetch_names = [key for key in fetch_list if key in self.fetch_names_]
        if not fetch_names:
            raise ValueError(
                "Fetch names should not be empty or out of saved fetch list.")

        int_slot_batch = []
        float_slot_batch = []
        int_feed_names = []
        float_feed_names = []
        int_shape = []
        float_shape = []
        # Computed per call. These used to be sticky instance state, so one
        # numpy call made every later list-only call fail (and vice versa).
        all_numpy_input = True
        has_numpy_input = False

        for i, feed_i in enumerate(feed_batch):
            int_slot = []
            float_slot = []
            for key in feed_i:
                if key not in self.feed_names_:
                    raise ValueError("Wrong feed name: {}.".format(key))
                self.shape_check(feed_i, key)
                feed_type = self.feed_types_[key]
                if feed_type == int_type:
                    slot, names, shapes = int_slot, int_feed_names, int_shape
                elif feed_type == float_type:
                    slot, names, shapes = (float_slot, float_feed_names,
                                           float_shape)
                else:
                    # Other feed types are not sent by this client.
                    continue
                value = feed_i[key]
                is_numpy = isinstance(value, np.ndarray)
                if i == 0:
                    # Feed names and shapes are taken from the first instance.
                    names.append(key)
                    shapes.append(
                        list(value.shape)
                        if is_numpy else self.feed_shapes_[key])
                slot.append(value)
                if is_numpy:
                    has_numpy_input = True
                else:
                    all_numpy_input = False
            int_slot_batch.append(int_slot)
            float_slot_batch.append(float_slot)

        # Keep the public attributes in sync for callers that inspect them.
        self.all_numpy_input = all_numpy_input
        self.has_numpy_input = has_numpy_input

        self.profile_.record('py_prepro_1')
        self.profile_.record('py_client_infer_0')

        result_batch_handle = self.predictorres_constructor()
        if all_numpy_input:
            predict_fn = self.client_handle_.numpy_predict
        elif not has_numpy_input:
            predict_fn = self.client_handle_.batch_predict
        else:
            raise ValueError(
                "Please make sure the inputs are all in list type or all in numpy.array type"
            )
        res = predict_fn(float_slot_batch, float_feed_names, float_shape,
                         int_slot_batch, int_feed_names, int_shape,
                         fetch_names, result_batch_handle, self.pid)

        self.profile_.record('py_client_infer_1')
        self.profile_.record('py_postpro_0')

        if res == -1:
            return None

        model_engine_names = result_batch_handle.get_engine_names()
        multi_result_map = [
            self._build_result_map(result_batch_handle, mi, fetch_names)
            for mi, _ in enumerate(model_engine_names)
        ]
        if len(model_engine_names) == 1:
            # A single model returns its result_map directly.
            ret = multi_result_map[0]
        else:
            # Several models return {engine_name: result_map}.
            ret = {
                engine_name: multi_result_map[mi]
                for mi, engine_name in enumerate(model_engine_names)
            }

        self.profile_.record('py_postpro_1')
        self.profile_.print_profile()

        # With A/B testing, the caller may also need the variant tag.
        return ret if not need_variant_tag else [
            ret, result_batch_handle.variant_tag()
        ]

    def _build_result_map(self, result_batch_handle, model_idx, fetch_names):
        """Collect one model's fetched tensors into ``{name: ndarray}`` (+ ``"<name>.lod"``)."""
        result_map = {}
        for name in fetch_names:
            fetch_type = self.fetch_names_to_type_[name]
            if fetch_type == int_type:
                value = result_batch_handle.get_int64_by_name(model_idx, name)
            elif fetch_type == float_type:
                value = result_batch_handle.get_float_by_name(model_idx, name)
            else:
                continue
            value.shape = result_batch_handle.get_shape(model_idx, name)
            result_map[name] = value
            if name in self.lod_tensor_set:
                result_map["{}.lod".format(name)] = result_batch_handle.get_lod(
                    model_idx, name)
        return result_map

    def release(self):
        """Destroy the native predictor. Calling it again is a no-op."""
        if self.client_handle_ is None:
            return
        self.client_handle_.destroy_predictor()
        self.client_handle_ = None
B
barrierye 已提交
382 383


384
class MultiLangClient(object):
    """gRPC client for the multi-language general model service."""

    # Model-config type code -> numpy dtype, and -> Tensor proto repeated
    # field, used when packing feeds and unpacking fetches.
    _DTYPE_BY_TYPE = {0: "int64", 1: "float32", 2: "int32"}
    _PROTO_FIELD_BY_TYPE = {0: "int64_data", 1: "float_data", 2: "int32_data"}

    def __init__(self):
        self.channel_ = None
        self.stub_ = None
        self.rpc_timeout_s_ = 2

    def add_variant(self, tag, cluster, variant_weight):
        # TODO: A/B testing is not supported by the gRPC client yet.
        raise Exception("cannot support ABtest yet")

    def set_rpc_timeout_ms(self, rpc_timeout):
        """Set the per-call timeout locally and on the server.

        Must be called after ``connect``.

        Returns:
            True if the server accepted the timeout (``err_code == 0``).
        """
        if self.stub_ is None:
            raise Exception("set timeout must be set after connect.")
        if not isinstance(rpc_timeout, int):
            raise ValueError("rpc_timeout must be int type.")
        self.rpc_timeout_s_ = rpc_timeout / 1000.0
        timeout_req = multi_lang_general_model_service_pb2.SetTimeoutRequest()
        timeout_req.timeout_ms = rpc_timeout
        resp = self.stub_.SetTimeout(timeout_req)
        return resp.err_code == 0

    def connect(self, endpoints):
        """Open a round-robin gRPC channel and fetch the model config.

        Args:
            endpoints: an "ip:port" string or a list of them.
        """
        if isinstance(endpoints, str):
            # A bare string would otherwise be joined char by char.
            endpoints = [endpoints]
        # https://github.com/tensorflow/serving/issues/1382
        options = [('grpc.max_receive_message_length', 512 * 1024 * 1024),
                   ('grpc.max_send_message_length', 512 * 1024 * 1024),
                   ('grpc.lb_policy_name', 'round_robin')]
        # TODO: weight round robin
        g_endpoint = 'ipv4:{}'.format(','.join(endpoints))
        self.channel_ = grpc.insecure_channel(g_endpoint, options=options)
        self.stub_ = multi_lang_general_model_service_pb2_grpc.MultiLangGeneralModelServiceStub(
            self.channel_)
        # Get the client-side model config from the server.
        get_client_config_req = multi_lang_general_model_service_pb2.GetClientConfigRequest(
        )
        resp = self.stub_.GetClientConfig(get_client_config_req)
        self._parse_model_config(resp.client_config_str)

    def _flatten_list(self, nested_list):
        """Yield the leaves of arbitrarily nested lists/tuples in order."""
        for item in nested_list:
            if isinstance(item, (list, tuple)):
                for sub_item in self._flatten_list(item):
                    yield sub_item
            else:
                yield item

    def _parse_model_config(self, model_config_str):
        """Fill the feed/fetch name, type and shape tables from a text-format config."""
        model_conf = m_config.GeneralModelConfig()
        model_conf = google.protobuf.text_format.Merge(model_config_str,
                                                       model_conf)
        self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
        self.feed_types_ = {}
        self.feed_shapes_ = {}
        self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
        self.fetch_types_ = {}
        self.lod_tensor_set_ = set()
        for var in model_conf.feed_var:
            self.feed_types_[var.alias_name] = var.feed_type
            self.feed_shapes_[var.alias_name] = var.shape
            if var.is_lod_tensor:
                self.lod_tensor_set_.add(var.alias_name)
        for var in model_conf.fetch_var:
            self.fetch_types_[var.alias_name] = var.fetch_type
            if var.is_lod_tensor:
                self.lod_tensor_set_.add(var.alias_name)

    def _to_typed_ndarray(self, var, v_type):
        """Convert a list/ndarray feed to an ndarray of the dtype for ``v_type``.

        An ndarray that already has the right dtype is returned unchanged.
        Previously this case raised "error tensor value type.".
        """
        if not isinstance(var, (list, np.ndarray)):
            raise Exception("var must be list or ndarray.")
        dtype = self._DTYPE_BY_TYPE.get(v_type)
        if dtype is None:
            raise Exception("error tensor value type.")
        return np.asarray(var, dtype=dtype)

    def _fill_tensor_fields(self, tensor, var, v_type):
        """Append ``var``'s values to the typed repeated field of ``tensor`` (non-python mode)."""
        if not isinstance(var, (list, np.ndarray)):
            raise Exception("var must be list or ndarray.")
        field = self._PROTO_FIELD_BY_TYPE.get(v_type)
        if field is None:
            raise Exception("error tensor value type.")
        if isinstance(var, np.ndarray):
            values = var.reshape(-1).astype(self._DTYPE_BY_TYPE[v_type]).tolist()
        else:
            values = self._flatten_list(var)
        getattr(tensor, field).extend(values)

    def _pack_inference_request(self, feed, fetch, is_python):
        """Build an InferenceRequest from a feed dict (or list of dicts) and fetch names.

        With ``is_python`` set, each tensor carries raw bytes in ``data``.
        Otherwise values go into the typed repeated fields.
        """
        if isinstance(fetch, str):
            fetch = [fetch]
        req = multi_lang_general_model_service_pb2.InferenceRequest()
        req.fetch_var_names.extend(fetch)
        req.is_python = is_python
        if isinstance(feed, dict):
            feed_batch = [feed]
        elif isinstance(feed, list):
            feed_batch = feed
        else:
            raise Exception("{} not support".format(type(feed)))
        # Every instance is expected to carry the keys of the first one.
        req.feed_var_names.extend(feed_batch[0].keys())
        for feed_data in feed_batch:
            inst = multi_lang_general_model_service_pb2.FeedInst()
            for name in req.feed_var_names:
                tensor = multi_lang_general_model_service_pb2.Tensor()
                var = feed_data[name]
                v_type = self.feed_types_[name]
                if is_python:
                    tensor.data = self._to_typed_ndarray(var, v_type).tobytes()
                else:
                    self._fill_tensor_fields(tensor, var, v_type)
                if isinstance(var, np.ndarray):
                    tensor.shape.extend(list(var.shape))
                else:
                    tensor.shape.extend(self.feed_shapes_[name])
                inst.tensor_array.append(tensor)
            req.insts.append(inst)
        return req

    def _unpack_inference_response(self, resp, fetch, is_python,
                                   need_variant_tag):
        """Convert an inference response into numpy results.

        Returns:
            ``None`` if ``resp.err_code != 0``. For a single model output,
            returns ``{fetch_name: ndarray}``; for several outputs,
            ``{engine_name: that dict}``. When ``need_variant_tag`` is set,
            returns ``[ret, resp.tag]``.
        """
        if resp.err_code != 0:
            return None
        if isinstance(fetch, str):
            fetch = [fetch]
        tag = resp.tag
        multi_result_map = {}
        for model_result in resp.outputs:
            inst = model_result.insts[0]
            result_map = {}
            for i, name in enumerate(fetch):
                var = inst.tensor_array[i]
                v_type = self.fetch_types_[name]
                dtype = self._DTYPE_BY_TYPE.get(v_type)
                if dtype is None:
                    raise Exception("error type.")
                if is_python:
                    value = np.frombuffer(var.data, dtype=dtype)
                else:
                    value = np.array(
                        list(getattr(var, self._PROTO_FIELD_BY_TYPE[v_type])),
                        dtype=dtype)
                value.shape = list(var.shape)
                result_map[name] = value
                if name in self.lod_tensor_set_:
                    result_map["{}.lod".format(name)] = np.array(list(var.lod))
            multi_result_map[model_result.engine_name] = result_map
        if len(resp.outputs) == 1:
            # dict views are not indexable on Python 3.
            ret = next(iter(multi_result_map.values()))
        else:
            ret = multi_result_map
        return ret if not need_variant_tag else [ret, tag]

    def _done_callback_func(self, fetch, is_python, need_variant_tag):
        """Return a function that unpacks a response with the given options."""

        def unpack_resp(resp):
            return self._unpack_inference_response(resp, fetch, is_python,
                                                   need_variant_tag)

        return unpack_resp

    def get_feed_names(self):
        return self.feed_names_

    def get_fetch_names(self):
        return self.fetch_names_

    def predict(self,
                feed,
                fetch,
                need_variant_tag=False,
                asyn=False,
                is_python=True):
        """Run inference synchronously, or asynchronously when ``asyn`` is set.

        ``fetch`` may be one name or a list of names. When ``asyn`` is set,
        the return value is a ``MultiLangPredictFuture``.
        """
        if isinstance(fetch, str):
            # Normalize once so packing and unpacking agree on the names.
            fetch = [fetch]
        req = self._pack_inference_request(feed, fetch, is_python=is_python)
        if not asyn:
            resp = self.stub_.Inference(req, timeout=self.rpc_timeout_s_)
            return self._unpack_inference_response(
                resp,
                fetch,
                is_python=is_python,
                need_variant_tag=need_variant_tag)
        call_future = self.stub_.Inference.future(
            req, timeout=self.rpc_timeout_s_)
        return MultiLangPredictFuture(
            call_future,
            self._done_callback_func(
                fetch, is_python=is_python,
                need_variant_tag=need_variant_tag))
604 605 606 607 608 609 610 611 612 613


class MultiLangPredictFuture(object):
    """Wrap a gRPC call future and post-process its response on demand."""

    def __init__(self, call_future, callback_func):
        # The raw gRPC future and the function that unpacks its response.
        self.call_future_ = call_future
        self.callback_func_ = callback_func

    def result(self):
        """Block until the RPC completes and return the unpacked result."""
        return self.callback_func_(self.call_future_.result())

    def add_done_callback(self, fn):
        """Arrange for ``fn(self)`` to run once the underlying RPC completes."""
        wrapper = self

        def _on_done(finished_future):
            # Sanity check: the callback must come from our own future.
            assert finished_future == wrapper.call_future_
            fn(wrapper)

        wrapper.call_future_.add_done_callback(_on_done)