__init__.py 23.4 KB
Newer Older
G
guru4elephant 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
# pylint: disable=doc-string-missing
G
guru4elephant 已提交
15

M
MRXLT 已提交
16 17
import paddle_serving_client
import os
18 19 20
from .proto import sdk_configure_pb2 as sdk
from .proto import general_model_config_pb2 as m_config
import google.protobuf.text_format
D
dongdaxiang 已提交
21 22
import numpy as np
import time
23
import sys
G
guru4elephant 已提交
24

B
barrierye 已提交
25
import grpc
B
barrierye 已提交
26
from .proto import multi_lang_general_model_service_pb2
B
barrierye 已提交
27 28
sys.path.append(
    os.path.join(os.path.abspath(os.path.dirname(__file__)), 'proto'))
B
barrierye 已提交
29
from .proto import multi_lang_general_model_service_pb2_grpc
B
barrierye 已提交
30

M
MRXLT 已提交
31 32 33 34 35
# Numeric type codes compared against `feed_type` / `fetch_type` of the
# variables declared in the model config.
int64_type = 0
float32_type = 1
int32_type = 2
# Groupings used by Client.predict to route feed values into the integer or
# floating-point slot lists.
int_type = {int64_type, int32_type}
float_type = {float32_type}
G
guru4elephant 已提交
36

M
MRXLT 已提交
37

W
WangXi 已提交
38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65
class _NOPProfiler(object):
    def record(self, name):
        pass

    def print_profile(self):
        pass


class _TimeProfiler(object):
    def __init__(self):
        self.pid = os.getpid()
        self.print_head = 'PROFILE\tpid:{}\t'.format(self.pid)
        self.time_record = [self.print_head]

    def record(self, name):
        self.time_record.append('{}:{} '.format(
            name, int(round(time.time() * 1000000))))

    def print_profile(self):
        self.time_record.append('\n')
        sys.stderr.write(''.join(self.time_record))
        self.time_record = [self.print_head]


# Client-side profiling is enabled by setting the FLAGS_profile_client
# environment variable to a non-zero integer.
_is_profile = int(os.environ.get('FLAGS_profile_client', 0))
if _is_profile:
    _Profiler = _TimeProfiler
else:
    _Profiler = _NOPProfiler


G
guru4elephant 已提交
66 67 68
class SDKConfig(object):
    """Builds the ``sdk.SDKConf`` description used to create a predictor.

    Server variants are registered with ``add_server_variant`` (tag, endpoint
    list, routing weight). ``gen_desc`` turns them into a protobuf description.
    """

    def __init__(self):
        self.sdk_desc = sdk.SDKConf()
        self.tag_list = []
        self.cluster_list = []
        self.variant_weight_list = []
        self.rpc_timeout_ms = 20000
        self.load_balance_strategy = "la"

    def add_server_variant(self, tag, cluster, variant_weight):
        """Register a server variant.

        Args:
            tag: tag identifying the variant.
            cluster: list of endpoints for this variant. A single endpoint
                string is also accepted and treated as a one-element list.
                Otherwise ``",".join`` in ``gen_desc`` would split it into
                characters.
            variant_weight: routing weight. It is stored as ``str`` because
                ``gen_desc`` joins the weights with ``"|"``.
        """
        if isinstance(cluster, str):
            cluster = [cluster]
        self.tag_list.append(tag)
        self.cluster_list.append(cluster)
        self.variant_weight_list.append(str(variant_weight))

    def set_load_banlance_strategy(self, strategy):
        """Set the load-balance strategy written by ``gen_desc``.

        The misspelled name is kept for backward compatibility.
        """
        self.load_balance_strategy = strategy

    def set_load_balance_strategy(self, strategy):
        """Correctly spelled alias of ``set_load_banlance_strategy``."""
        self.set_load_banlance_strategy(strategy)

    def gen_desc(self, rpc_timeout_ms):
        """Build the SDK description for all registered variants.

        The description is rebuilt from scratch on every call. Calling this
        more than once (e.g. connecting again) therefore does not append
        duplicate predictors.

        Args:
            rpc_timeout_ms: RPC timeout written into the default connection
                config.

        Returns:
            The populated ``sdk.SDKConf``, also stored in ``self.sdk_desc``.
        """
        self.sdk_desc = sdk.SDKConf()

        predictor_desc = sdk.Predictor()
        predictor_desc.name = "general_model"
        predictor_desc.service_name = \
            "baidu.paddle_serving.predictor.general_model.GeneralModelService"
        predictor_desc.endpoint_router = "WeightedRandomRender"
        predictor_desc.weighted_random_render_conf.variant_weight_list = "|".join(
            str(weight) for weight in self.variant_weight_list)

        for tag, cluster in zip(self.tag_list, self.cluster_list):
            variant_desc = sdk.VariantConf()
            variant_desc.tag = tag
            variant_desc.naming_conf.cluster = "list://{}".format(
                ",".join(cluster))
            predictor_desc.variants.extend([variant_desc])

        self.sdk_desc.predictors.extend([predictor_desc])

        default_conf = self.sdk_desc.default_variant_conf
        default_conf.tag = "default"

        conn = default_conf.connection_conf
        conn.connect_timeout_ms = 2000
        conn.rpc_timeout_ms = rpc_timeout_ms
        conn.connect_retry_count = 2
        conn.max_connection_per_host = 100
        conn.hedge_request_timeout_ms = -1
        conn.hedge_fetch_retry_count = 2
        conn.connection_type = "pooled"

        naming = default_conf.naming_conf
        naming.cluster_filter_strategy = "Default"
        # Previously hard-coded to "la", which silently ignored
        # set_load_banlance_strategy().
        naming.load_balance_strategy = self.load_balance_strategy

        rpc_param = default_conf.rpc_parameter
        rpc_param.compress_type = 0
        rpc_param.package_size = 20
        rpc_param.protocol = "baidu_std"
        rpc_param.max_channel_per_request = 3

        return self.sdk_desc
G
guru4elephant 已提交
118

G
guru4elephant 已提交
119 120 121 122 123 124

class Client(object):
    """Prediction client built on the native ``PredictorClient``.

    Typical use: ``load_client_config`` -> ``connect`` (optionally after
    ``add_variant``) -> ``predict`` -> ``release``.
    """

    def __init__(self):
        self.feed_names_ = []
        self.fetch_names_ = []
        self.client_handle_ = None
        self.feed_shapes_ = {}
        self.feed_types_ = {}
        self.feed_names_to_idx_ = {}
        self.pid = os.getpid()
        self.predictor_sdk_ = None
        self.producers = []
        self.consumer = None
        self.profile_ = _Profiler()
        # Describe the kind of feed values seen by the current predict() call;
        # predict() resets them at the start of every call.
        self.all_numpy_input = True
        self.has_numpy_input = False
        self.rpc_timeout_ms = 20000
        from .serving_client import PredictorRes
        self.predictorres_constructor = PredictorRes

    def load_client_config(self, path):
        """Load the client model config and initialise the native handle.

        Args:
            path: path to a text-format ``GeneralModelConfig`` file. It is
                also passed to ``PredictorClient.init``.
        """
        from .serving_client import PredictorClient
        model_conf = m_config.GeneralModelConfig()
        # Use a context manager so the config file is closed after parsing.
        with open(path, 'r') as f:
            model_conf = google.protobuf.text_format.Merge(
                str(f.read()), model_conf)

        self.client_handle_ = PredictorClient()
        self.client_handle_.init(path)
        # Default the max body size to 512MB unless the user overrides it.
        if "FLAGS_max_body_size" not in os.environ:
            os.environ["FLAGS_max_body_size"] = str(512 * 1024 * 1024)
        read_env_flags = ["profile_client", "profile_server", "max_body_size"]
        self.client_handle_.init_gflags(
            [sys.argv[0]] + ["--tryfromenv=" + ",".join(read_env_flags)])

        self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
        self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
        # Reset every per-config table so that reloading a different config
        # does not leave stale entries behind.
        self.feed_names_to_idx_ = {}
        self.feed_types_ = {}
        self.feed_shapes_ = {}
        self.fetch_names_to_type_ = {}
        self.fetch_names_to_idx_ = {}
        self.lod_tensor_set = set()
        self.feed_tensor_len = {}

        for i, var in enumerate(model_conf.feed_var):
            self.feed_names_to_idx_[var.alias_name] = i
            self.feed_types_[var.alias_name] = var.feed_type
            self.feed_shapes_[var.alias_name] = var.shape
            if var.is_lod_tensor:
                self.lod_tensor_set.add(var.alias_name)
            else:
                # Number of elements expected for a dense (non-LoD) feed.
                counter = 1
                for dim in var.shape:
                    counter *= dim
                self.feed_tensor_len[var.alias_name] = counter

        for i, var in enumerate(model_conf.fetch_var):
            self.fetch_names_to_idx_[var.alias_name] = i
            self.fetch_names_to_type_[var.alias_name] = var.fetch_type
            if var.is_lod_tensor:
                self.lod_tensor_set.add(var.alias_name)

    def add_variant(self, tag, cluster, variant_weight):
        """Register a server variant (tag, endpoints, routing weight).

        Variant tags can be returned by ``predict(need_variant_tag=True)``,
        e.g. for A/B tests.
        """
        if self.predictor_sdk_ is None:
            self.predictor_sdk_ = SDKConfig()
        self.predictor_sdk_.add_server_variant(tag, cluster,
                                               str(variant_weight))

    def set_rpc_timeout_ms(self, rpc_timeout):
        """Set the RPC timeout in milliseconds used by ``connect``.

        Raises:
            ValueError: if ``rpc_timeout`` is not an int.
        """
        if not isinstance(rpc_timeout, int):
            raise ValueError("rpc_timeout must be int type.")
        self.rpc_timeout_ms = rpc_timeout

    def connect(self, endpoints=None):
        """Create the native predictor from the registered variants.

        Args:
            endpoints: server endpoints used to create a default variant.
                They are ignored (with a printed notice) if ``add_variant``
                was already called.

        Raises:
            ValueError: if neither ``endpoints`` nor a variant is given.
        """
        if endpoints is None:
            if self.predictor_sdk_ is None:
                raise ValueError(
                    "You must set the endpoints parameter or use add_variant function to create a variant."
                )
        else:
            if self.predictor_sdk_ is None:
                self.add_variant('default_tag_{}'.format(id(self)), endpoints,
                                 100)
            else:
                print(
                    "parameter endpoints({}) will not take effect, because you use the add_variant function.".
                    format(endpoints))
        sdk_desc = self.predictor_sdk_.gen_desc(self.rpc_timeout_ms)
        self.client_handle_.create_predictor_by_desc(sdk_desc.SerializeToString(
        ))

    def get_feed_names(self):
        """Return the feed variable names loaded from the client config."""
        return self.feed_names_

    def get_fetch_names(self):
        """Return the fetch variable names loaded from the client config."""
        return self.fetch_names_

    def shape_check(self, feed, key):
        """Validate the element count of a list-valued dense feed.

        LoD feeds are skipped.

        Raises:
            ValueError: if a list value's length differs from the element
                count declared in the config.
        """
        if key in self.lod_tensor_set:
            return
        if isinstance(feed[key],
                      list) and len(feed[key]) != self.feed_tensor_len[key]:
            raise ValueError("The shape of feed tensor {} not match.".format(
                key))
        # NOTE(review): numpy inputs are not size-checked; the original check
        # for them was commented out and did nothing.

    def predict(self, feed=None, fetch=None, need_variant_tag=False):
        """Run a prediction for one instance or a batch.

        Args:
            feed: dict of feed name -> value (one instance) or a list of such
                dicts (a batch). Values are lists or numpy arrays, and all
                values in a call must be of the same kind.
            fetch: a fetch name or a list of fetch names. Names that are not
                in the loaded fetch list are ignored.
            need_variant_tag: also return the tag of the serving variant.

        Returns:
            ``None`` if the native call returns -1. Otherwise:
            - a dict name -> array (plus ``"<name>.lod"`` entries for LoD
              tensors) when one engine answers;
            - a dict engine name -> such dict when several engines answer.
            With ``need_variant_tag`` the result is ``[ret, variant_tag]``.

        Raises:
            ValueError: on missing/invalid feed or fetch arguments, unknown
                feed names, shape mismatches, or mixed list/numpy inputs.
        """
        self.profile_.record('py_prepro_0')

        if feed is None or fetch is None:
            raise ValueError("You should specify feed and fetch for prediction")

        if isinstance(fetch, str):
            fetch_list = [fetch]
        elif isinstance(fetch, list):
            fetch_list = fetch
        else:
            raise ValueError("Fetch only accepts string and list of string")

        if isinstance(feed, dict):
            feed_batch = [feed]
        elif isinstance(feed, list):
            feed_batch = feed
        else:
            raise ValueError("Feed only accepts dict and list of dict")

        # Reset per call. These flags used to persist across calls, so a
        # list-based call followed by a numpy-based call was wrongly rejected
        # as mixed input.
        self.all_numpy_input = True
        self.has_numpy_input = False

        fetch_names = [key for key in fetch_list if key in self.fetch_names_]
        if len(fetch_names) == 0:
            raise ValueError(
                "Fetch names should not be empty or out of saved fetch list.")

        int_slot_batch = []
        float_slot_batch = []
        int_feed_names = []
        float_feed_names = []
        int_shape = []
        float_shape = []

        for i, feed_i in enumerate(feed_batch):
            int_slot = []
            float_slot = []
            for key in feed_i:
                if key not in self.feed_names_:
                    raise ValueError("Wrong feed name: {}.".format(key))
                self.shape_check(feed_i, key)
                value = feed_i[key]
                is_numpy = isinstance(value, np.ndarray)
                if self.feed_types_[key] in int_type:
                    slot, names, shapes = int_slot, int_feed_names, int_shape
                elif self.feed_types_[key] in float_type:
                    slot, names, shapes = (float_slot, float_feed_names,
                                           float_shape)
                else:
                    continue
                # Names and shapes are taken from the first instance only.
                if i == 0:
                    names.append(key)
                    if is_numpy:
                        shapes.append(list(value.shape))
                    else:
                        shapes.append(self.feed_shapes_[key])
                slot.append(value)
                if is_numpy:
                    self.has_numpy_input = True
                else:
                    self.all_numpy_input = False
            int_slot_batch.append(int_slot)
            float_slot_batch.append(float_slot)

        self.profile_.record('py_prepro_1')
        self.profile_.record('py_client_infer_0')

        result_batch_handle = self.predictorres_constructor()
        if self.all_numpy_input:
            predict_fn = self.client_handle_.numpy_predict
        elif not self.has_numpy_input:
            predict_fn = self.client_handle_.batch_predict
        else:
            raise ValueError(
                "Please make sure the inputs are all in list type or all in numpy.array type"
            )
        res = predict_fn(float_slot_batch, float_feed_names, float_shape,
                         int_slot_batch, int_feed_names, int_shape,
                         fetch_names, result_batch_handle, self.pid)

        self.profile_.record('py_client_infer_1')
        self.profile_.record('py_postpro_0')

        if res == -1:
            return None

        # Name of the result-handle accessor for each fetch type code. Fetch
        # types not listed here are skipped, as before.
        getter_names = {
            int64_type: "get_int64_by_name",
            float32_type: "get_float_by_name",
            int32_type: "get_int32_by_name",
        }
        multi_result_map = []
        model_engine_names = result_batch_handle.get_engine_names()
        for mi, _ in enumerate(model_engine_names):
            result_map = {}
            for name in fetch_names:
                getter_name = getter_names.get(self.fetch_names_to_type_[name])
                if getter_name is None:
                    continue
                # The accessor returns a numpy array; reshape it in place.
                result_map[name] = getattr(result_batch_handle,
                                           getter_name)(mi, name)
                result_map[name].shape = result_batch_handle.get_shape(mi,
                                                                       name)
                if name in self.lod_tensor_set:
                    result_map["{}.lod".format(
                        name)] = result_batch_handle.get_lod(mi, name)
            multi_result_map.append(result_map)

        if len(model_engine_names) == 1:
            # A single engine: return its result map directly.
            ret = multi_result_map[0]
        else:
            # Several engines: key each result map by its engine name.
            ret = {
                engine_name: multi_result_map[mi]
                for mi, engine_name in enumerate(model_engine_names)
            }

        self.profile_.record('py_postpro_1')
        self.profile_.print_profile()

        # When using the A/B test, the tag of variant needs to be returned.
        if need_variant_tag:
            return [ret, result_batch_handle.variant_tag()]
        return ret

    def release(self):
        """Destroy the native predictor and drop the handle.

        Calling this again after a release is a no-op.
        """
        if self.client_handle_ is not None:
            self.client_handle_.destroy_predictor()
        self.client_handle_ = None
B
barrierye 已提交
395 396


397
class MultiLangClient(object):
    """gRPC client for ``MultiLangGeneralModelService``.

    Typical use: ``load_client_config`` -> ``connect`` -> ``predict``.
    """

    def __init__(self):
        self.channel_ = None
        self.stub_ = None

    def load_client_config(self, path):
        """Load one model's client config.

        Args:
            path: path to a text-format ``GeneralModelConfig`` file.

        Raises:
            Exception: if ``path`` is not a str. Multiple model configs are
                not supported yet.
        """
        if not isinstance(path, str):
            raise Exception(
                "MultiLangClient only supports single-model temporarily")
        self._parse_model_config(path)

    def connect(self, endpoint):
        """Open an insecure gRPC channel and create the service stub.

        Args:
            endpoint: a list of endpoints (only the first is used for now) or
                a single endpoint string.
        """
        # https://github.com/tensorflow/serving/issues/1382
        options = [('grpc.max_receive_message_length', 512 * 1024 * 1024),
                   ('grpc.max_send_message_length', 512 * 1024 * 1024)]
        # A bare string used to be indexed, connecting to its first character.
        if isinstance(endpoint, str):
            target = endpoint
        else:
            target = endpoint[0]  # TODO: use the remaining endpoints
        self.channel_ = grpc.insecure_channel(target, options=options)
        self.stub_ = multi_lang_general_model_service_pb2_grpc.MultiLangGeneralModelServiceStub(
            self.channel_)

    def _flatten_list(self, nested_list):
        """Yield the leaves of arbitrarily nested lists/tuples, depth-first."""
        for item in nested_list:
            if isinstance(item, (list, tuple)):
                for sub_item in self._flatten_list(item):
                    yield sub_item
            else:
                yield item

    def _parse_model_config(self, model_config_path):
        """Read feed/fetch names, types, shapes and LoD flags from a config."""
        model_conf = m_config.GeneralModelConfig()
        with open(model_config_path, 'r') as f:
            model_conf = google.protobuf.text_format.Merge(
                str(f.read()), model_conf)
        self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
        self.feed_types_ = {}
        self.feed_shapes_ = {}
        self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
        self.fetch_types_ = {}
        self.lod_tensor_set_ = set()
        for var in model_conf.feed_var:
            self.feed_types_[var.alias_name] = var.feed_type
            self.feed_shapes_[var.alias_name] = var.shape
            if var.is_lod_tensor:
                self.lod_tensor_set_.add(var.alias_name)
        for var in model_conf.fetch_var:
            self.fetch_types_[var.alias_name] = var.fetch_type
            if var.is_lod_tensor:
                self.lod_tensor_set_.add(var.alias_name)

    def _pack_feed_data(self, feed, fetch, is_python):
        """Build a ``Request`` proto from the feed data.

        Args:
            feed: dict (one instance) or list of dicts (a batch). The feed
                names are taken from the first instance.
            fetch: list of fetch variable names.
            is_python: if True, each tensor is sent as raw bytes in
                ``tensor.data``. Otherwise the typed repeated fields
                (``int64_data``/``float_data``/``int_data``) are used.

        Raises:
            Exception: for an unsupported feed container or feed type code.
        """
        req = multi_lang_general_model_service_pb2.Request()
        req.fetch_var_names.extend(fetch)
        req.is_python = is_python
        if isinstance(feed, dict):
            feed_batch = [feed]
        elif isinstance(feed, list):
            feed_batch = feed
        else:
            raise Exception("{} not support".format(type(feed)))
        req.feed_var_names.extend(feed_batch[0].keys())
        for feed_data in feed_batch:
            inst = multi_lang_general_model_service_pb2.FeedInst()
            for name in req.feed_var_names:
                tensor = multi_lang_general_model_service_pb2.Tensor()
                var = feed_data[name]
                v_type = self.feed_types_[name]
                if is_python:
                    if isinstance(var, (list, tuple)):
                        if v_type == 0:  # int64
                            data = np.array(var, dtype="int64")
                        elif v_type == 1:  # float32
                            data = np.array(var, dtype="float32")
                        elif v_type == 2:  # int32
                            data = np.array(var, dtype="int32")
                        else:
                            raise Exception("error type.")
                    else:
                        data = var
                        if var.dtype == "float64":
                            data = data.astype("float32")
                    tensor.data = data.tobytes()
                else:
                    if isinstance(var, np.ndarray):
                        values = var.reshape(-1).tolist()
                    else:
                        values = self._flatten_list(var)
                    if v_type == 0:  # int64
                        tensor.int64_data.extend(values)
                    elif v_type == 1:  # float32
                        tensor.float_data.extend(values)
                    elif v_type == 2:  # int32
                        # Previously `isinstance(car, np.array)`: NameError.
                        tensor.int_data.extend(values)
                    else:
                        raise Exception("error type.")
                if isinstance(var, np.ndarray):
                    tensor.shape.extend(list(var.shape))
                else:
                    tensor.shape.extend(self.feed_shapes_[name])
                inst.tensor_array.append(tensor)
            req.insts.append(inst)
        return req

    def _unpack_resp(self, resp, fetch, is_python, need_variant_tag):
        """Convert the first output instance of a response to numpy arrays.

        Returns:
            A dict name -> array (plus ``"<name>.lod"`` for LoD tensors).
            With ``need_variant_tag`` it returns ``[result_map, resp.tag]``.

        Raises:
            Exception: for an unsupported fetch type code.
        """
        result_map = {}
        inst = resp.outputs[0].insts[0]
        tag = resp.tag
        for i, name in enumerate(fetch):
            var = inst.tensor_array[i]
            v_type = self.fetch_types_[name]
            if is_python:
                if v_type == 0:  # int64
                    result_map[name] = np.frombuffer(var.data, dtype="int64")
                elif v_type == 1:  # float32
                    result_map[name] = np.frombuffer(var.data, dtype="float32")
                elif v_type == 2:  # int32 (was missing, raising "error type.")
                    result_map[name] = np.frombuffer(var.data, dtype="int32")
                else:
                    raise Exception("error type.")
            else:
                if v_type == 0:  # int64
                    result_map[name] = np.array(
                        list(var.int64_data), dtype="int64")
                elif v_type == 1:  # float32
                    result_map[name] = np.array(
                        list(var.float_data), dtype="float32")
                elif v_type == 2:  # int32
                    result_map[name] = np.array(
                        list(var.int_data), dtype="int32")
                else:
                    raise Exception("error type.")
            result_map[name].shape = list(var.shape)
            if name in self.lod_tensor_set_:
                result_map["{}.lod".format(name)] = np.array(list(var.lod))
        return result_map if not need_variant_tag else [result_map, tag]

    def _done_callback_func(self, fetch, is_python, need_variant_tag):
        """Return a callable that unpacks a response with the given options."""

        def unpack_resp(resp):
            return self._unpack_resp(resp, fetch, is_python, need_variant_tag)

        return unpack_resp

    def get_feed_names(self):
        """Return the feed variable names loaded from the client config."""
        return self.feed_names_

    def get_fetch_names(self):
        """Return the fetch variable names loaded from the client config."""
        return self.fetch_names_

    def predict(self,
                feed,
                fetch,
                need_variant_tag=False,
                asyn=False,
                is_python=True):
        """Send an inference request.

        Args:
            feed: dict (one instance) or list of dicts (a batch).
            fetch: a fetch name or a list of fetch names.
            need_variant_tag: also return the serving variant's tag.
            asyn: if True, return a ``MultiLangPredictFuture`` instead of
                waiting for the result.
            is_python: choose the raw-bytes tensor encoding (see
                ``_pack_feed_data``).
        """
        # A bare string would otherwise be iterated character by character.
        if isinstance(fetch, str):
            fetch = [fetch]
        req = self._pack_feed_data(feed, fetch, is_python=is_python)
        if not asyn:
            resp = self.stub_.inference(req)
            return self._unpack_resp(
                resp,
                fetch,
                is_python=is_python,
                need_variant_tag=need_variant_tag)
        call_future = self.stub_.inference.future(req)
        return MultiLangPredictFuture(
            call_future,
            self._done_callback_func(
                fetch,
                is_python=is_python,
                need_variant_tag=need_variant_tag))
573 574 575 576 577 578 579 580 581 582


class MultiLangPredictFuture(object):
    """Future-like wrapper returned by ``MultiLangClient.predict(asyn=True)``.

    It wraps the call future from ``stub_.inference.future``. ``result()``
    waits for the raw response and post-processes it with the supplied
    callback.
    """

    def __init__(self, call_future, callback_func):
        self.call_future_ = call_future
        self.callback_func_ = callback_func

    def result(self):
        """Return the unpacked result of the underlying call."""
        return self.callback_func_(self.call_future_.result())

    def add_done_callback(self, fn):
        """Invoke ``fn(self)`` once the underlying call completes."""
        wrapper = self

        def _on_done(finished_future):
            assert finished_future == wrapper.call_future_
            fn(wrapper)

        self.call_future_.add_done_callback(_on_done)