__init__.py 26.6 KB
Newer Older
G
guru4elephant 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
# pylint: disable=doc-string-missing
G
guru4elephant 已提交
15

M
MRXLT 已提交
16 17
import paddle_serving_client
import os
18 19 20
from .proto import sdk_configure_pb2 as sdk
from .proto import general_model_config_pb2 as m_config
import google.protobuf.text_format
D
dongdaxiang 已提交
21 22
import numpy as np
import time
23
import sys
G
guru4elephant 已提交
24

B
barrierye 已提交
25
import grpc
B
barrierye 已提交
26
from .proto import multi_lang_general_model_service_pb2
B
barrierye 已提交
27 28
sys.path.append(
    os.path.join(os.path.abspath(os.path.dirname(__file__)), 'proto'))
B
barrierye 已提交
29
from .proto import multi_lang_general_model_service_pb2_grpc
B
barrierye 已提交
30

M
MRXLT 已提交
31 32 33 34 35
# Element-type codes that appear as ``feed_type`` / ``fetch_type`` in the
# general model config; both clients below compare against these values.
int64_type, float32_type, int32_type = 0, 1, 2
# Groupings used to route a feed variable into the integer or float slots.
int_type = {int64_type, int32_type}
float_type = {float32_type}
G
guru4elephant 已提交
36

M
MRXLT 已提交
37

W
WangXi 已提交
38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65
class _NOPProfiler(object):
    def record(self, name):
        pass

    def print_profile(self):
        pass


class _TimeProfiler(object):
    def __init__(self):
        self.pid = os.getpid()
        self.print_head = 'PROFILE\tpid:{}\t'.format(self.pid)
        self.time_record = [self.print_head]

    def record(self, name):
        self.time_record.append('{}:{} '.format(
            name, int(round(time.time() * 1000000))))

    def print_profile(self):
        self.time_record.append('\n')
        sys.stderr.write(''.join(self.time_record))
        self.time_record = [self.print_head]


_is_profile = int(os.environ.get('FLAGS_profile_client', 0))
_Profiler = _TimeProfiler if _is_profile else _NOPProfiler


G
guru4elephant 已提交
66 67 68
class SDKConfig(object):
    """Build the ``sdk.SDKConf`` descriptor used to create the brpc predictor.

    Every variant registered with ``add_server_variant`` becomes a
    ``VariantConf`` of a single ``general_model`` predictor. Requests are
    routed by ``WeightedRandomRender`` using the variant weights.
    """

    def __init__(self):
        self.sdk_desc = sdk.SDKConf()
        self.tag_list = []
        self.cluster_list = []
        self.variant_weight_list = []
        self.rpc_timeout_ms = 20000
        self.load_balance_strategy = "la"

    def add_server_variant(self, tag, cluster, variant_weight):
        """Register a server variant.

        Args:
            tag: name of the variant.
            cluster: list of ``"ip:port"`` endpoints; a single endpoint
                string is also accepted.
            variant_weight: routing weight of the variant.
        """
        self.tag_list.append(tag)
        self.cluster_list.append(cluster)
        self.variant_weight_list.append(variant_weight)

    def set_load_balance_strategy(self, strategy):
        """Set the naming load-balance strategy written by ``gen_desc``."""
        self.load_balance_strategy = strategy

    # Original (misspelled) name, kept so existing callers keep working.
    set_load_banlance_strategy = set_load_balance_strategy

    @staticmethod
    def _cluster_to_naming(cluster):
        """Format an endpoint list (or one endpoint string) as a ``list://`` URL."""
        if isinstance(cluster, str):
            # ",".join on a bare string would insert commas between characters.
            cluster = [cluster]
        return "list://{}".format(",".join(cluster))

    def gen_desc(self, rpc_timeout_ms):
        """Fill in and return ``self.sdk_desc`` for the registered variants.

        Args:
            rpc_timeout_ms: per-RPC timeout in milliseconds.

        Returns:
            ``self.sdk_desc``. Its predictor list is rebuilt on every call, so
            calling this more than once does not accumulate duplicate
            predictors.
        """
        predictor_desc = sdk.Predictor()
        predictor_desc.name = "general_model"
        predictor_desc.service_name = \
            "baidu.paddle_serving.predictor.general_model.GeneralModelService"
        predictor_desc.endpoint_router = "WeightedRandomRender"
        predictor_desc.weighted_random_render_conf.variant_weight_list = "|".join(
            str(weight) for weight in self.variant_weight_list)

        for tag, cluster in zip(self.tag_list, self.cluster_list):
            variant_desc = sdk.VariantConf()
            variant_desc.tag = tag
            variant_desc.naming_conf.cluster = self._cluster_to_naming(cluster)
            predictor_desc.variants.extend([variant_desc])

        self.sdk_desc.ClearField("predictors")
        self.sdk_desc.predictors.extend([predictor_desc])

        default_conf = self.sdk_desc.default_variant_conf
        default_conf.tag = "default"

        connection_conf = default_conf.connection_conf
        connection_conf.connect_timeout_ms = 2000
        connection_conf.rpc_timeout_ms = rpc_timeout_ms
        connection_conf.connect_retry_count = 2
        connection_conf.max_connection_per_host = 100
        connection_conf.hedge_request_timeout_ms = -1
        connection_conf.hedge_fetch_retry_count = 2
        connection_conf.connection_type = "pooled"

        naming_conf = default_conf.naming_conf
        naming_conf.cluster_filter_strategy = "Default"
        # Previously hard-coded to "la", which made the setter above a no-op.
        naming_conf.load_balance_strategy = self.load_balance_strategy

        rpc_parameter = default_conf.rpc_parameter
        rpc_parameter.compress_type = 0
        rpc_parameter.package_size = 20
        rpc_parameter.protocol = "baidu_std"
        rpc_parameter.max_channel_per_request = 3

        return self.sdk_desc
G
guru4elephant 已提交
118

G
guru4elephant 已提交
119 120 121 122 123 124

class Client(object):
    """Python client for a Paddle Serving server reached through the brpc SDK.

    Typical use: ``load_client_config`` -> ``connect`` -> ``predict`` ->
    ``release``.
    """

    def __init__(self):
        self.feed_names_ = []
        self.fetch_names_ = []
        self.client_handle_ = None
        self.feed_shapes_ = {}
        self.feed_types_ = {}
        self.feed_names_to_idx_ = {}
        self.pid = os.getpid()
        self.predictor_sdk_ = None
        self.producers = []
        self.consumer = None
        self.profile_ = _Profiler()
        # Input-kind flags describing the most recent predict() call;
        # predict() resets them at the start of every call.
        self.all_numpy_input = True
        self.has_numpy_input = False
        self.rpc_timeout_ms = 20000
        from .serving_client import PredictorRes
        self.predictorres_constructor = PredictorRes
        self.write_profile_into_fetch_map_ = False  # only for grpc impl

    def load_client_config(self, path):
        """Load the client model config at ``path`` and initialise the predictor client.

        Records feed/fetch names, types and shapes, and which variables are
        LoD tensors. For each non-LoD feed variable it also stores the
        flattened element count (product of its shape) in
        ``feed_tensor_len``, which ``shape_check`` uses.
        """
        from .serving_client import PredictorClient
        model_conf = m_config.GeneralModelConfig()
        with open(path, 'r') as f:
            model_conf = google.protobuf.text_format.Merge(
                str(f.read()), model_conf)

        self.client_handle_ = PredictorClient()
        self.client_handle_.init(path)
        # Default the max body size to 512MB unless the user already set it.
        if "FLAGS_max_body_size" not in os.environ:
            os.environ["FLAGS_max_body_size"] = str(512 * 1024 * 1024)
        read_env_flags = ["profile_client", "profile_server", "max_body_size"]
        self.client_handle_.init_gflags([sys.argv[
            0]] + ["--tryfromenv=" + ",".join(read_env_flags)])

        self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
        self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
        self.feed_names_to_idx_ = {}
        self.feed_types_ = {}
        self.feed_shapes_ = {}
        self.fetch_names_to_type_ = {}
        self.fetch_names_to_idx_ = {}
        self.lod_tensor_set = set()
        self.feed_tensor_len = {}

        for i, var in enumerate(model_conf.feed_var):
            self.feed_names_to_idx_[var.alias_name] = i
            self.feed_types_[var.alias_name] = var.feed_type
            self.feed_shapes_[var.alias_name] = var.shape
            if var.is_lod_tensor:
                self.lod_tensor_set.add(var.alias_name)
            else:
                counter = 1
                for dim in var.shape:
                    counter *= dim
                self.feed_tensor_len[var.alias_name] = counter
        for i, var in enumerate(model_conf.fetch_var):
            self.fetch_names_to_idx_[var.alias_name] = i
            self.fetch_names_to_type_[var.alias_name] = var.fetch_type
            if var.is_lod_tensor:
                self.lod_tensor_set.add(var.alias_name)

    def add_variant(self, tag, cluster, variant_weight):
        """Register a server variant (for A/B tests) with the given weight."""
        if self.predictor_sdk_ is None:
            self.predictor_sdk_ = SDKConfig()
        self.predictor_sdk_.add_server_variant(tag, cluster,
                                               str(variant_weight))

    def set_rpc_timeout_ms(self, rpc_timeout):
        """Set the RPC timeout (ms) used by the next ``connect``."""
        if not isinstance(rpc_timeout, int):
            raise ValueError("rpc_timeout must be int type.")
        else:
            self.rpc_timeout_ms = rpc_timeout

    def connect(self, endpoints=None):
        """Create the predictor from the registered variants.

        Args:
            endpoints: list of ``"ip:port"`` strings (a single string is also
                accepted). It is used to create a default variant only when
                ``add_variant`` has not been called; otherwise it is ignored.

        Raises:
            ValueError: if neither ``endpoints`` nor any variant is given.
        """
        if endpoints is None:
            if self.predictor_sdk_ is None:
                raise ValueError(
                    "You must set the endpoints parameter or use add_variant function to create a variant."
                )
        else:
            if self.predictor_sdk_ is None:
                if isinstance(endpoints, str):
                    # A bare string would otherwise be split into characters.
                    endpoints = [endpoints]
                self.add_variant('default_tag_{}'.format(id(self)), endpoints,
                                 100)
            else:
                print(
                    "parameter endpoints({}) will not take effect, because you use the add_variant function.".
                    format(endpoints))
        sdk_desc = self.predictor_sdk_.gen_desc(self.rpc_timeout_ms)
        self.client_handle_.create_predictor_by_desc(sdk_desc.SerializeToString(
        ))

    def get_feed_names(self):
        return self.feed_names_

    def get_fetch_names(self):
        return self.fetch_names_

    def shape_check(self, feed, key):
        """Raise ValueError if a list-valued, non-LoD feed has the wrong element count.

        LoD feeds and numpy arrays are not size-checked here.
        """
        if key in self.lod_tensor_set:
            return
        value = feed[key]
        if isinstance(value, list) and len(value) != self.feed_tensor_len[key]:
            raise ValueError("The shape of feed tensor {} not match.".format(
                key))

    def _fetch_result_map(self, result_batch_handle, mi, fetch_names):
        """Collect engine ``mi``'s outputs into ``{fetch_name: ndarray}``.

        Each array is reshaped to the shape reported by ``get_shape``; LoD
        outputs also get a ``"<name>.lod"`` entry. Fetch vars whose type is not
        int64/float32/int32 are skipped.
        """
        getter_names = {
            int64_type: "get_int64_by_name",
            float32_type: "get_float_by_name",
            int32_type: "get_int32_by_name",
        }
        result_map = {}
        for name in fetch_names:
            getter_name = getter_names.get(self.fetch_names_to_type_[name])
            if getter_name is None:
                continue
            # The getter returns a numpy array (py::array); reshape it in place.
            result_map[name] = getattr(result_batch_handle, getter_name)(mi,
                                                                         name)
            result_map[name].shape = result_batch_handle.get_shape(mi, name)
            if name in self.lod_tensor_set:
                result_map["{}.lod".format(
                    name)] = result_batch_handle.get_lod(mi, name)
        return result_map

    def predict(self, feed=None, fetch=None, need_variant_tag=False):
        """Run inference for one sample (dict) or a batch (list of dicts).

        Args:
            feed: ``{feed_name: value}`` or a list of such dicts. Within one
                call, values must be all Python lists or all numpy arrays.
            fetch: a fetch name or a list of fetch names. Names that are not
                in the model's fetch list are ignored.
            need_variant_tag: if True, return ``[result, variant_tag]``.

        Returns:
            ``None`` if the underlying predict call returns -1. Otherwise, when
            one engine answered, a ``{fetch_name: ndarray}`` map (plus
            ``"<name>.lod"`` entries for LoD outputs). When several engines
            answered, ``{engine_name: result_map}``.

        Raises:
            ValueError: on missing or invalid feed/fetch, unknown feed names,
                a list-feed size mismatch, or mixed list/numpy inputs.
        """
        self.profile_.record('py_prepro_0')

        if feed is None or fetch is None:
            raise ValueError("You should specify feed and fetch for prediction")

        if isinstance(fetch, str):
            fetch_list = [fetch]
        elif isinstance(fetch, list):
            fetch_list = fetch
        else:
            raise ValueError("Fetch only accepts string and list of string")

        if isinstance(feed, dict):
            feed_batch = [feed]
        elif isinstance(feed, list):
            feed_batch = feed
        else:
            raise ValueError("Feed only accepts dict and list of dict")

        fetch_names = [key for key in fetch_list if key in self.fetch_names_]
        if not fetch_names:
            raise ValueError(
                "Fetch names should not be empty or out of saved fetch list.")

        # These flags describe this call only. Previously they were never
        # reset, so one earlier list (or numpy) call made every later numpy
        # (or list) call fail.
        self.all_numpy_input = True
        self.has_numpy_input = False

        int_slot_batch = []
        float_slot_batch = []
        int_feed_names = []
        float_feed_names = []
        int_shape = []
        float_shape = []

        for i, feed_i in enumerate(feed_batch):
            int_slot = []
            float_slot = []
            for key in feed_i:
                if key not in self.feed_names_:
                    raise ValueError("Wrong feed name: {}.".format(key))
                self.shape_check(feed_i, key)
                value = feed_i[key]
                feed_type = self.feed_types_[key]
                if feed_type in int_type:
                    slot, names, shapes = int_slot, int_feed_names, int_shape
                elif feed_type in float_type:
                    slot, names, shapes = float_slot, float_feed_names, float_shape
                else:
                    continue
                is_numpy = isinstance(value, np.ndarray)
                if i == 0:
                    # Feed names and shapes are taken from the first sample.
                    names.append(key)
                    shapes.append(
                        list(value.shape) if is_numpy else
                        self.feed_shapes_[key])
                slot.append(value)
                if is_numpy:
                    self.has_numpy_input = True
                else:
                    self.all_numpy_input = False
            int_slot_batch.append(int_slot)
            float_slot_batch.append(float_slot)

        self.profile_.record('py_prepro_1')
        self.profile_.record('py_client_infer_0')

        result_batch_handle = self.predictorres_constructor()
        if self.all_numpy_input:
            predict_fn = self.client_handle_.numpy_predict
        elif not self.has_numpy_input:
            predict_fn = self.client_handle_.batch_predict
        else:
            raise ValueError(
                "Please make sure the inputs are all in list type or all in numpy.array type"
            )
        res = predict_fn(float_slot_batch, float_feed_names, float_shape,
                         int_slot_batch, int_feed_names, int_shape,
                         fetch_names, result_batch_handle, self.pid)

        self.profile_.record('py_client_infer_1')
        self.profile_.record('py_postpro_0')

        if res == -1:
            return None

        model_engine_names = result_batch_handle.get_engine_names()
        multi_result_map = [
            self._fetch_result_map(result_batch_handle, mi, fetch_names)
            for mi, _ in enumerate(model_engine_names)
        ]
        if len(model_engine_names) == 1:
            # A single model answered: return its result_map directly.
            ret = multi_result_map[0]
        else:
            # Several models answered: {engine_name: result_map}.
            ret = {
                engine_name: multi_result_map[mi]
                for mi, engine_name in enumerate(model_engine_names)
            }

        self.profile_.record('py_postpro_1')
        self.profile_.print_profile()

        # When using the A/B test, the tag of variant needs to be returned
        if need_variant_tag:
            return [ret, result_batch_handle.variant_tag()]
        return ret

    def release(self):
        """Destroy the underlying predictor; calling it again is a no-op."""
        if self.client_handle_ is None:
            return
        self.client_handle_.destroy_predictor()
        self.client_handle_ = None
B
barrierye 已提交
396 397


398
class MultiLangClient(object):
    """gRPC client for the multi-language general model service.

    ``connect`` fetches the client model config from the server
    (``GetClientConfig``). That config supplies the feed/fetch names, types
    and shapes used to pack requests and unpack responses.
    """

    # Element-type code (feed_type / fetch_type) -> numpy dtype, and the typed
    # repeated Tensor field used when data is not sent as raw bytes.
    _DTYPE_BY_TYPE = {0: "int64", 1: "float32", 2: "int32"}
    _FIELD_BY_TYPE = {0: "int64_data", 1: "float_data", 2: "int32_data"}

    def __init__(self):
        self.channel_ = None
        self.stub_ = None
        self.rpc_timeout_s_ = 2
        self.profile_ = _Profiler()

    def add_variant(self, tag, cluster, variant_weight):
        # TODO
        raise Exception("cannot support ABtest yet")

    def set_rpc_timeout_ms(self, rpc_timeout):
        """Set the RPC timeout in milliseconds; must be called after ``connect``.

        Updates the client-side gRPC deadline and sends a ``SetTimeout``
        request to the server. Returns True if the server replies with
        ``err_code == 0``.
        """
        if self.stub_ is None:
            raise Exception("set timeout must be set after connect.")
        if not isinstance(rpc_timeout, int):
            # for bclient
            raise ValueError("rpc_timeout must be int type.")
        self.rpc_timeout_s_ = rpc_timeout / 1000.0
        timeout_req = multi_lang_general_model_service_pb2.SetTimeoutRequest()
        timeout_req.timeout_ms = rpc_timeout
        resp = self.stub_.SetTimeout(timeout_req)
        return resp.err_code == 0

    def connect(self, endpoints):
        """Open a round-robin gRPC channel to ``endpoints`` and load the model config.

        Args:
            endpoints: list of ``"ip:port"`` strings; a single string is also
                accepted.
        """
        if isinstance(endpoints, str):
            # ','.join on a bare string would split it into characters.
            endpoints = [endpoints]
        # https://github.com/tensorflow/serving/issues/1382
        options = [('grpc.max_receive_message_length', 512 * 1024 * 1024),
                   ('grpc.max_send_message_length', 512 * 1024 * 1024),
                   ('grpc.lb_policy_name', 'round_robin')]
        # TODO: weight round robin
        g_endpoint = 'ipv4:{}'.format(','.join(endpoints))
        self.channel_ = grpc.insecure_channel(g_endpoint, options=options)
        self.stub_ = multi_lang_general_model_service_pb2_grpc.MultiLangGeneralModelServiceStub(
            self.channel_)
        # get client model config
        get_client_config_req = multi_lang_general_model_service_pb2.GetClientConfigRequest(
        )
        resp = self.stub_.GetClientConfig(get_client_config_req)
        model_config_str = resp.client_config_str
        self._parse_model_config(model_config_str)

    def _flatten_list(self, nested_list):
        """Yield the leaves of arbitrarily nested lists/tuples in order."""
        for item in nested_list:
            if isinstance(item, (list, tuple)):
                for sub_item in self._flatten_list(item):
                    yield sub_item
            else:
                yield item

    def _parse_model_config(self, model_config_str):
        """Parse the text-format model config into feed/fetch lookup tables."""
        model_conf = m_config.GeneralModelConfig()
        model_conf = google.protobuf.text_format.Merge(model_config_str,
                                                       model_conf)
        self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
        self.feed_types_ = {}
        self.feed_shapes_ = {}
        self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
        self.fetch_types_ = {}
        self.lod_tensor_set_ = set()
        for var in model_conf.feed_var:
            self.feed_types_[var.alias_name] = var.feed_type
            self.feed_shapes_[var.alias_name] = var.shape
            if var.is_lod_tensor:
                self.lod_tensor_set_.add(var.alias_name)
        for var in model_conf.fetch_var:
            self.fetch_types_[var.alias_name] = var.fetch_type
            if var.is_lod_tensor:
                self.lod_tensor_set_.add(var.alias_name)

    def _pack_inference_request(self, feed, fetch, is_python):
        """Build an ``InferenceRequest`` for ``feed`` asking for ``fetch``.

        With ``is_python`` each tensor carries the raw bytes of a numpy array
        cast to the declared dtype. Otherwise values go into the typed
        repeated field. Feed names are taken from the first sample of the
        batch.
        """
        req = multi_lang_general_model_service_pb2.InferenceRequest()
        req.fetch_var_names.extend(fetch)
        req.is_python = is_python
        if isinstance(feed, dict):
            feed_batch = [feed]
        elif isinstance(feed, list):
            feed_batch = feed
        else:
            raise Exception("{} not support".format(type(feed)))
        req.feed_var_names.extend(feed_batch[0].keys())
        for feed_data in feed_batch:
            inst = multi_lang_general_model_service_pb2.FeedInst()
            for name in req.feed_var_names:
                tensor = multi_lang_general_model_service_pb2.Tensor()
                var = feed_data[name]
                v_type = self.feed_types_[name]
                if not isinstance(var, (list, np.ndarray)):
                    raise Exception("var must be list or ndarray.")
                if v_type not in self._DTYPE_BY_TYPE:
                    raise Exception("error tensor value type.")
                dtype = self._DTYPE_BY_TYPE[v_type]
                if is_python:
                    if isinstance(var, list):
                        data = np.array(var, dtype=dtype)
                    elif var.dtype != dtype:
                        data = var.astype(dtype)
                    else:
                        data = var
                    tensor.data = data.tobytes()
                else:
                    field = getattr(tensor, self._FIELD_BY_TYPE[v_type])
                    if isinstance(var, np.ndarray):
                        field.extend(var.reshape(-1).astype(dtype).tolist())
                    else:
                        field.extend(self._flatten_list(var))
                if isinstance(var, np.ndarray):
                    tensor.shape.extend(list(var.shape))
                else:
                    tensor.shape.extend(self.feed_shapes_[name])
                inst.tensor_array.append(tensor)
            req.insts.append(inst)
        return req

    def _unpack_inference_response(self, resp, fetch, is_python,
                                   need_variant_tag):
        """Convert an ``InferenceResponse`` into fetch results.

        Returns None if the server reported a non-zero ``err_code``. Otherwise
        returns a ``{fetch_name: ndarray}`` map (plus ``"<name>.lod"`` for LoD
        outputs) when there is one engine output, or
        ``{engine_name: result_map}`` when there are several. The top-level
        dict also gets ``"serving_status_code": 0``. With
        ``need_variant_tag`` the result is ``[ret, tag]``.
        """
        if resp.err_code != 0:
            return None
        tag = resp.tag
        multi_result_map = {}
        for model_result in resp.outputs:
            inst = model_result.insts[0]
            result_map = {}
            for i, name in enumerate(fetch):
                var = inst.tensor_array[i]
                v_type = self.fetch_types_[name]
                # int32 outputs are now handled like the int32 feeds above.
                if v_type not in self._DTYPE_BY_TYPE:
                    raise Exception("error type.")
                dtype = self._DTYPE_BY_TYPE[v_type]
                if is_python:
                    result_map[name] = np.frombuffer(var.data, dtype=dtype)
                else:
                    values = getattr(var, self._FIELD_BY_TYPE[v_type])
                    result_map[name] = np.array(list(values), dtype=dtype)
                result_map[name].shape = list(var.shape)
                if name in self.lod_tensor_set_:
                    result_map["{}.lod".format(name)] = np.array(list(var.lod))
            multi_result_map[model_result.engine_name] = result_map
        if len(resp.outputs) == 1:
            ret = list(multi_result_map.values())[0]
        else:
            ret = multi_result_map

        ret["serving_status_code"] = 0
        return ret if not need_variant_tag else [ret, tag]

    def _done_callback_func(self, fetch, is_python, need_variant_tag):
        """Return a callable that unpacks a response with the given options."""
        def unpack_resp(resp):
            return self._unpack_inference_response(resp, fetch, is_python,
                                                   need_variant_tag)

        return unpack_resp

    def get_feed_names(self):
        return self.feed_names_

    def get_fetch_names(self):
        return self.fetch_names_

    def predict(self,
                feed,
                fetch,
                need_variant_tag=False,
                asyn=False,
                is_python=True):
        """Send an inference request.

        Args:
            feed: ``{feed_name: list | ndarray}`` or a list of such dicts.
            fetch: a fetch name or a list of fetch names.
            need_variant_tag: if True, the result is ``[ret, tag]``.
            asyn: if True, return a ``MultiLangPredictFuture`` instead of
                waiting for the response.
            is_python: send tensors as raw numpy bytes (True) or as typed
                repeated fields (False).

        Returns:
            The unpacked result (see ``_unpack_inference_response``). A
            synchronous gRPC error yields ``{"serving_status_code": code}``.
        """
        if isinstance(fetch, str):
            # extend()/enumerate() on a bare string would split it into chars.
            fetch = [fetch]
        if not asyn:
            try:
                self.profile_.record('py_prepro_0')
                req = self._pack_inference_request(
                    feed, fetch, is_python=is_python)
                self.profile_.record('py_prepro_1')

                self.profile_.record('py_client_infer_0')
                resp = self.stub_.Inference(req, timeout=self.rpc_timeout_s_)
                self.profile_.record('py_client_infer_1')

                self.profile_.record('py_postpro_0')
                ret = self._unpack_inference_response(
                    resp,
                    fetch,
                    is_python=is_python,
                    need_variant_tag=need_variant_tag)
                self.profile_.record('py_postpro_1')
                self.profile_.print_profile()
                return ret
            except grpc.RpcError as e:
                return {"serving_status_code": e.code()}
        else:
            req = self._pack_inference_request(feed, fetch, is_python=is_python)
            call_future = self.stub_.Inference.future(
                req, timeout=self.rpc_timeout_s_)
            return MultiLangPredictFuture(
                call_future,
                self._done_callback_func(
                    fetch,
                    is_python=is_python,
                    need_variant_tag=need_variant_tag))
639 640 641 642 643 644 645 646


class MultiLangPredictFuture(object):
    """Handle returned by ``MultiLangClient.predict(..., asyn=True)``.

    Wraps a gRPC call future and converts the raw response with a callback.
    """

    def __init__(self, call_future, callback_func):
        self.call_future_ = call_future
        self.callback_func_ = callback_func

    def result(self):
        """Wait for the call and return the converted result.

        A gRPC error is reported as ``{"serving_status_code": <grpc code>}``
        rather than raised.
        """
        try:
            raw_response = self.call_future_.result()
        except grpc.RpcError as rpc_error:
            return {"serving_status_code": rpc_error.code()}
        else:
            return self.callback_func_(raw_response)

    def add_done_callback(self, fn):
        """Arrange for ``fn(self)`` to be called when the wrapped call completes."""
        wrapper = self

        def _on_done(finished_future):
            assert finished_future == wrapper.call_future_
            fn(wrapper)

        self.call_future_.add_done_callback(_on_done)