__init__.py 22.6 KB
Newer Older
G
guru4elephant 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
# pylint: disable=doc-string-missing
G
guru4elephant 已提交
15

M
MRXLT 已提交
16 17
import paddle_serving_client
import os
18 19 20
from .proto import sdk_configure_pb2 as sdk
from .proto import general_model_config_pb2 as m_config
import google.protobuf.text_format
D
dongdaxiang 已提交
21 22
import numpy as np
import time
23
import sys
G
guru4elephant 已提交
24

B
barrierye 已提交
25
import grpc
B
barrierye 已提交
26
from .proto import multi_lang_general_model_service_pb2
B
barrierye 已提交
27 28
sys.path.append(
    os.path.join(os.path.abspath(os.path.dirname(__file__)), 'proto'))
G
gongweibao 已提交
29 30
from .proto import multi_lang_general_model_service_pb2 as pb2
from .proto import multi_lang_general_model_service_pb2_grpc as grpc_pb2
B
barrierye 已提交
31

G
guru4elephant 已提交
32 33 34
# Feed/fetch type codes used by the serving protos.
int_type = 0  # int64 tensors
float_type = 1  # float32 tensors

M
MRXLT 已提交
35

W
WangXi 已提交
36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63
class _NOPProfiler(object):
    def record(self, name):
        pass

    def print_profile(self):
        pass


class _TimeProfiler(object):
    def __init__(self):
        self.pid = os.getpid()
        self.print_head = 'PROFILE\tpid:{}\t'.format(self.pid)
        self.time_record = [self.print_head]

    def record(self, name):
        self.time_record.append('{}:{} '.format(
            name, int(round(time.time() * 1000000))))

    def print_profile(self):
        self.time_record.append('\n')
        sys.stderr.write(''.join(self.time_record))
        self.time_record = [self.print_head]


# Enable client-side timing by setting FLAGS_profile_client=1 in the environment;
# otherwise the no-op profiler is used so record() calls cost nothing.
_is_profile = int(os.environ.get('FLAGS_profile_client', 0))
_Profiler = _TimeProfiler if _is_profile else _NOPProfiler


G
guru4elephant 已提交
64 65 66
class SDKConfig(object):
    """Builds the brpc SDK configuration proto (SDKConf) for the client.

    Variants registered through add_server_variant() are serialized by
    gen_desc() together with default connection and RPC parameters.
    """

    def __init__(self):
        self.sdk_desc = sdk.SDKConf()
        self.tag_list = []
        self.cluster_list = []
        self.variant_weight_list = []
        self.rpc_timeout_ms = 20000
        self.load_balance_strategy = "la"

    def add_server_variant(self, tag, cluster, variant_weight):
        """Register one variant: its tag, endpoint cluster, and A/B weight (str)."""
        self.tag_list.append(tag)
        self.cluster_list.append(cluster)
        self.variant_weight_list.append(variant_weight)

    def set_load_balance_strategy(self, strategy):
        """Set the naming-conf load balance strategy (default "la")."""
        self.load_balance_strategy = strategy

    # Backward-compatible alias for the historical misspelled method name.
    set_load_banlance_strategy = set_load_balance_strategy

    def gen_desc(self, rpc_timeout_ms):
        """Assemble and return the complete SDKConf proto.

        Args:
            rpc_timeout_ms: RPC timeout (milliseconds) for the default variant.

        Returns:
            sdk.SDKConf: the populated configuration message.
        """
        predictor_desc = sdk.Predictor()
        predictor_desc.name = "general_model"
        predictor_desc.service_name = \
            "baidu.paddle_serving.predictor.general_model.GeneralModelService"
        predictor_desc.endpoint_router = "WeightedRandomRender"
        predictor_desc.weighted_random_render_conf.variant_weight_list = "|".join(
            self.variant_weight_list)

        for idx, tag in enumerate(self.tag_list):
            variant_desc = sdk.VariantConf()
            variant_desc.tag = tag
            variant_desc.naming_conf.cluster = "list://{}".format(",".join(
                self.cluster_list[idx]))
            predictor_desc.variants.extend([variant_desc])

        self.sdk_desc.predictors.extend([predictor_desc])
        self.sdk_desc.default_variant_conf.tag = "default"

        conn = self.sdk_desc.default_variant_conf.connection_conf
        conn.connect_timeout_ms = 2000
        conn.rpc_timeout_ms = rpc_timeout_ms
        conn.connect_retry_count = 2
        conn.max_connection_per_host = 100
        conn.hedge_request_timeout_ms = -1
        conn.hedge_fetch_retry_count = 2
        conn.connection_type = "pooled"

        naming = self.sdk_desc.default_variant_conf.naming_conf
        naming.cluster_filter_strategy = "Default"
        # Bug fix: honor the configured strategy; previously "la" was
        # hard-coded here, silently ignoring set_load_balance_strategy().
        naming.load_balance_strategy = self.load_balance_strategy

        rpc_param = self.sdk_desc.default_variant_conf.rpc_parameter
        rpc_param.compress_type = 0
        rpc_param.package_size = 20
        rpc_param.protocol = "baidu_std"
        rpc_param.max_channel_per_request = 3

        return self.sdk_desc
G
guru4elephant 已提交
116

G
guru4elephant 已提交
117 118 119 120 121 122

class Client(object):
    def __init__(self):
        """Initialize an empty client; call load_client_config() then connect()."""
        # Model I/O metadata, populated by load_client_config().
        self.feed_names_ = []
        self.fetch_names_ = []
        self.feed_shapes_ = {}
        self.feed_types_ = {}
        self.feed_names_to_idx_ = {}
        # Native predictor handle and endpoint (SDK) configuration.
        self.client_handle_ = None
        self.predictor_sdk_ = None
        self.producers = []
        self.consumer = None
        self.pid = os.getpid()
        self.profile_ = _Profiler()
        # Feed values must be either all numpy arrays or all plain lists;
        # these flags track which kind has been seen.
        self.all_numpy_input = True
        self.has_numpy_input = False
        self.rpc_timeout_ms = 20000
        from .serving_client import PredictorRes
        self.predictorres_constructor = PredictorRes
M
MRXLT 已提交
136

G
guru4elephant 已提交
137
    def load_client_config(self, path):
        """Load the client-side model config and initialize the native client.

        Parses the GeneralModelConfig prototxt at `path` to build the feed/fetch
        name, type, shape and LoD metadata used by predict().

        Args:
            path: filesystem path to the client config prototxt.
        """
        from .serving_client import PredictorClient
        model_conf = m_config.GeneralModelConfig()
        # Bug fix: use a context manager; the file handle was previously
        # opened and never closed.
        with open(path, 'r') as f:
            model_conf = google.protobuf.text_format.Merge(
                str(f.read()), model_conf)

        # Initialize the native predictor client from the same config file.
        self.client_handle_ = PredictorClient()
        self.client_handle_.init(path)
        if "FLAGS_max_body_size" not in os.environ:
            os.environ["FLAGS_max_body_size"] = str(512 * 1024 * 1024)
        read_env_flags = ["profile_client", "profile_server", "max_body_size"]
        self.client_handle_.init_gflags([sys.argv[
            0]] + ["--tryfromenv=" + ",".join(read_env_flags)])

        self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
        self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
        self.feed_names_to_idx_ = {}
        self.fetch_names_to_type_ = {}
        self.fetch_names_to_idx_ = {}
        self.lod_tensor_set = set()
        self.feed_tensor_len = {}

        for i, var in enumerate(model_conf.feed_var):
            self.feed_names_to_idx_[var.alias_name] = i
            self.feed_types_[var.alias_name] = var.feed_type
            self.feed_shapes_[var.alias_name] = var.shape
            if var.is_lod_tensor:
                self.lod_tensor_set.add(var.alias_name)
            else:
                # Dense tensors have a fixed element count used by shape_check().
                counter = 1
                for dim in self.feed_shapes_[var.alias_name]:
                    counter *= dim
                self.feed_tensor_len[var.alias_name] = counter
        for i, var in enumerate(model_conf.fetch_var):
            self.fetch_names_to_idx_[var.alias_name] = i
            self.fetch_names_to_type_[var.alias_name] = var.fetch_type
            if var.is_lod_tensor:
                self.lod_tensor_set.add(var.alias_name)
        return

182
    def add_variant(self, tag, cluster, variant_weight):
        """Register a serving variant, lazily creating the SDK config.

        Args:
            tag: variant tag name.
            cluster: list of "ip:port" endpoint strings.
            variant_weight: A/B routing weight (converted to str).
        """
        if self.predictor_sdk_ is None:
            self.predictor_sdk_ = SDKConfig()
        self.predictor_sdk_.add_server_variant(tag, cluster,
                                               str(variant_weight))

M
MRXLT 已提交
188 189 190 191 192 193
    def set_rpc_timeout_ms(self, rpc_timeout):
        if not isinstance(rpc_timeout, int):
            raise ValueError("rpc_timeout must be int type.")
        else:
            self.rpc_timeout_ms = rpc_timeout

B
barrierye 已提交
194
    def connect(self, endpoints=None):
        """Create the native predictor from explicit endpoints or added variants.

        Args:
            endpoints: optional list of "ip:port" strings; ignored (with a
                warning) if add_variant() was already used.

        Raises:
            ValueError: if neither endpoints nor a variant is available.
        """
        if endpoints is None and self.predictor_sdk_ is None:
            raise ValueError(
                "You must set the endpoints parameter or use add_variant function to create a variant."
            )
        if endpoints is not None:
            if self.predictor_sdk_ is None:
                # No variant configured yet: wrap the endpoints in a default one.
                self.add_variant('default_tag_{}'.format(id(self)), endpoints,
                                 100)
            else:
                print(
                    "parameter endpoints({}) will not take effect, because you use the add_variant function.".
                    format(endpoints))
        sdk_desc = self.predictor_sdk_.gen_desc(self.rpc_timeout_ms)
        self.client_handle_.create_predictor_by_desc(sdk_desc.SerializeToString(
        ))
G
guru4elephant 已提交
214 215 216 217 218 219 220

    def get_feed_names(self):
        """Return the feed variable alias names from the loaded config."""
        return self.feed_names_

    def get_fetch_names(self):
        """Return the fetch variable alias names from the loaded config."""
        return self.fetch_names_

M
MRXLT 已提交
221 222 223
    def shape_check(self, feed, key):
        if key in self.lod_tensor_set:
            return
M
MRXLT 已提交
224 225
        if isinstance(feed[key],
                      list) and len(feed[key]) != self.feed_tensor_len[key]:
M
MRXLT 已提交
226
            raise ValueError("The shape of feed tensor {} not match.".format(
M
MRXLT 已提交
227 228 229
                key))
        if type(feed[key]).__module__ == np.__name__ and np.size(feed[
                key]) != self.feed_tensor_len[key]:
M
MRXLT 已提交
230 231 232
            #raise SystemExit("The shape of feed tensor {} not match.".format(
            #    key))
            pass
M
MRXLT 已提交
233

234
    def predict(self, feed=None, fetch=None, need_variant_tag=False):
        """Run one synchronous prediction through the native client.

        Args:
            feed: dict mapping feed alias names to values (list or np.ndarray),
                or a list of such dicts for a batch. All values must be either
                lists or numpy arrays — mixing the two is rejected.
            fetch: fetch alias name (str) or list of names.
            need_variant_tag: if True, return [result, variant_tag] for A/B tests.

        Returns:
            A result map {name: np.array, "<name>.lod": lod} for a single model
            engine, or {engine_name: result_map} when several engines respond;
            None if the native call failed.

        Raises:
            ValueError: on missing/invalid feed or fetch arguments.
        """
        self.profile_.record('py_prepro_0')

        if feed is None or fetch is None:
            raise ValueError("You should specify feed and fetch for prediction")

        if isinstance(fetch, str):
            fetch_list = [fetch]
        elif isinstance(fetch, list):
            fetch_list = fetch
        else:
            raise ValueError("Fetch only accepts string and list of string")

        if isinstance(feed, dict):
            feed_batch = [feed]
        elif isinstance(feed, list):
            feed_batch = feed
        else:
            raise ValueError("Feed only accepts dict and list of dict")

        int_slot_batch = []
        float_slot_batch = []
        int_feed_names = []
        float_feed_names = []
        int_shape = []
        float_shape = []

        fetch_names = [key for key in fetch_list if key in self.fetch_names_]
        if len(fetch_names) == 0:
            # Bug fix: an unreachable `return {}` used to follow this raise.
            raise ValueError(
                "Fetch names should not be empty or out of saved fetch list.")

        for i, feed_i in enumerate(feed_batch):
            int_slot = []
            float_slot = []
            for key in feed_i:
                if key not in self.feed_names_:
                    raise ValueError("Wrong feed name: {}.".format(key))
                self.shape_check(feed_i, key)
                if self.feed_types_[key] == int_type:
                    if i == 0:
                        # Names/shapes are recorded once, from the first sample.
                        int_feed_names.append(key)
                        if isinstance(feed_i[key], np.ndarray):
                            int_shape.append(list(feed_i[key].shape))
                        else:
                            int_shape.append(self.feed_shapes_[key])
                    int_slot.append(feed_i[key])
                    if isinstance(feed_i[key], np.ndarray):
                        self.has_numpy_input = True
                    else:
                        self.all_numpy_input = False
                elif self.feed_types_[key] == float_type:
                    if i == 0:
                        float_feed_names.append(key)
                        if isinstance(feed_i[key], np.ndarray):
                            float_shape.append(list(feed_i[key].shape))
                        else:
                            float_shape.append(self.feed_shapes_[key])
                    float_slot.append(feed_i[key])
                    if isinstance(feed_i[key], np.ndarray):
                        self.has_numpy_input = True
                    else:
                        self.all_numpy_input = False
            int_slot_batch.append(int_slot)
            float_slot_batch.append(float_slot)

        self.profile_.record('py_prepro_1')
        self.profile_.record('py_client_infer_0')

        result_batch_handle = self.predictorres_constructor()
        if self.all_numpy_input:
            res = self.client_handle_.numpy_predict(
                float_slot_batch, float_feed_names, float_shape, int_slot_batch,
                int_feed_names, int_shape, fetch_names, result_batch_handle,
                self.pid)
        elif not self.has_numpy_input:
            res = self.client_handle_.batch_predict(
                float_slot_batch, float_feed_names, float_shape, int_slot_batch,
                int_feed_names, int_shape, fetch_names, result_batch_handle,
                self.pid)
        else:
            raise ValueError(
                "Please make sure the inputs are all in list type or all in numpy.array type"
            )

        self.profile_.record('py_client_infer_1')
        self.profile_.record('py_postpro_0')

        if res == -1:
            return None

        multi_result_map = []
        model_engine_names = result_batch_handle.get_engine_names()
        for mi, engine_name in enumerate(model_engine_names):
            result_map = {}
            # Each fetched value comes back as a numpy array; restore its shape
            # and attach LoD info for LoD tensors.
            for i, name in enumerate(fetch_names):
                if self.fetch_names_to_type_[name] == int_type:
                    result_map[name] = result_batch_handle.get_int64_by_name(
                        mi, name)
                    shape = result_batch_handle.get_shape(mi, name)
                    result_map[name].shape = shape
                    if name in self.lod_tensor_set:
                        result_map["{}.lod".format(
                            name)] = result_batch_handle.get_lod(mi, name)
                elif self.fetch_names_to_type_[name] == float_type:
                    result_map[name] = result_batch_handle.get_float_by_name(
                        mi, name)
                    shape = result_batch_handle.get_shape(mi, name)
                    result_map[name].shape = shape
                    if name in self.lod_tensor_set:
                        result_map["{}.lod".format(
                            name)] = result_batch_handle.get_lod(mi, name)
            multi_result_map.append(result_map)

        if len(model_engine_names) == 1:
            # Single model: return its result map directly.
            ret = multi_result_map[0]
        else:
            # Multiple models: key each result map by its engine name.
            ret = {
                engine_name: multi_result_map[mi]
                for mi, engine_name in enumerate(model_engine_names)
            }

        self.profile_.record('py_postpro_1')
        self.profile_.print_profile()

        # When A/B testing, the variant tag is returned alongside the result.
        return ret if not need_variant_tag else [
            ret, result_batch_handle.variant_tag()
        ]
B
barrierye 已提交
379

380 381
    def release(self):
        self.client_handle_.destroy_predictor()
G
guru4elephant 已提交
382
        self.client_handle_ = None
B
barrierye 已提交
383 384


385
class MultiLangClient(object):
B
barrierye 已提交
386 387
    def __init__(self):
        self.channel_ = None
G
gongweibao 已提交
388
        self._config = None
B
barrierye 已提交
389 390

    def load_client_config(self, path):
B
barrierye 已提交
391 392
        if not isinstance(path, str):
            raise Exception("GClient only supports multi-model temporarily")
G
gongweibao 已提交
393 394
        with open(path, 'r') as f:
            proto_txt = str(f.read())
G
gongweibao 已提交
395

G
gongweibao 已提交
396
        self._parse_model_config(proto_txt)
B
barrierye 已提交
397

G
gongweibao 已提交
398
    def _load_client_config(self):
        """Fetch the model config from the server over gRPC and parse it."""
        request = pb2.EmptyRequest()
        self._config = self.stub_.get_config(request)
        self._parse_model_config(self._config.proto_txt)
G
gongweibao 已提交
402 403

    def connect(self, endpoint, use_remote_config=True):
        """Open a gRPC channel and create the service stub.

        Args:
            endpoint: list of "ip:port" strings; only the first is used for now.
            use_remote_config: if True, fetch and parse the model config from
                the server after connecting.
        """
        # Raise the message size limits for large tensors.
        # https://github.com/tensorflow/serving/issues/1382
        # (duplicate max_receive_message_length entry removed)
        options = [('grpc.max_receive_message_length', 512 * 1024 * 1024),
                   ('grpc.max_send_message_length', 512 * 1024 * 1024)]

        self.channel_ = grpc.insecure_channel(
            endpoint[0], options=options)  #TODO: multi-endpoint load balancing
        # Bug fix: this module imports the grpc service module as `grpc_pb2`;
        # the unaliased name `multi_lang_general_model_service_pb2_grpc` is
        # never defined here and raised NameError.
        self.stub_ = grpc_pb2.MultiLangGeneralModelServiceStub(self.channel_)

        if use_remote_config:
            self._load_client_config()
G
gongweibao 已提交
416

B
barrierye 已提交
417 418 419 420 421 422 423 424
    def _flatten_list(self, nested_list):
        for item in nested_list:
            if isinstance(item, (list, tuple)):
                for sub_item in self._flatten_list(item):
                    yield sub_item
            else:
                yield item

G
gongweibao 已提交
425
    def _parse_model_config(self, proto_txt):
        """Populate feed/fetch metadata from a GeneralModelConfig prototxt string.

        Fills feed_names_/feed_types_/feed_shapes_, fetch_names_/fetch_types_,
        and lod_tensor_set_ used when packing requests and unpacking responses.
        """
        model_conf = m_config.GeneralModelConfig()
        model_conf = google.protobuf.text_format.Merge(proto_txt, model_conf)
        self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
        self.feed_types_ = {}
        self.feed_shapes_ = {}
        self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
        self.fetch_types_ = {}
        self.lod_tensor_set_ = set()
        for var in model_conf.feed_var:
            self.feed_types_[var.alias_name] = var.feed_type
            self.feed_shapes_[var.alias_name] = var.shape
            if var.is_lod_tensor:
                self.lod_tensor_set_.add(var.alias_name)
            # NOTE: the original also computed the dense tensor length here
            # but never stored it (dead code, removed).
        for var in model_conf.fetch_var:
            self.fetch_types_[var.alias_name] = var.fetch_type
            if var.is_lod_tensor:
                self.lod_tensor_set_.add(var.alias_name)
B
barrierye 已提交
447

B
barrierye 已提交
448
    def _pack_feed_data(self, feed, fetch, is_python):
        """Build the gRPC Request proto from feed data.

        Args:
            feed: dict (single sample) or list of dicts (batch) mapping feed
                alias names to list or np.ndarray values.
            fetch: list of fetch variable alias names.
            is_python: if True, ship raw ndarray bytes in Tensor.data; otherwise
                use the typed repeated fields (int64_data/float_data).

        Returns:
            multi_lang_general_model_service_pb2.Request

        Raises:
            Exception: on unsupported feed container or variable type code.
        """
        req = multi_lang_general_model_service_pb2.Request()
        req.fetch_var_names.extend(fetch)
        req.is_python = is_python
        if isinstance(feed, dict):
            feed_batch = [feed]
        elif isinstance(feed, list):
            feed_batch = feed
        else:
            raise Exception("{} not support".format(type(feed)))
        # Feed names are taken from the first sample; all samples must match.
        req.feed_var_names.extend(feed_batch[0].keys())
        for feed_data in feed_batch:
            inst = multi_lang_general_model_service_pb2.FeedInst()
            for name in req.feed_var_names:
                tensor = multi_lang_general_model_service_pb2.Tensor()
                var = feed_data[name]
                v_type = self.feed_types_[name]
                if is_python:
                    if isinstance(var, list):
                        if v_type == 0:  # int64
                            data = np.array(var, dtype="int64")
                        elif v_type == 1:  # float32
                            data = np.array(var, dtype="float32")
                        else:
                            raise Exception("error type.")
                    else:
                        data = var
                        # The server expects float32; downcast doubles.
                        if var.dtype == "float64":
                            data = data.astype("float32")
                    tensor.data = data.tobytes()
                else:
                    if v_type == 0:  # int64
                        if isinstance(var, np.ndarray):
                            tensor.int64_data.extend(var.reshape(-1).tolist())
                        else:
                            tensor.int64_data.extend(self._flatten_list(var))
                    elif v_type == 1:  # float32
                        if isinstance(var, np.ndarray):
                            tensor.float_data.extend(var.reshape(-1).tolist())
                        else:
                            tensor.float_data.extend(self._flatten_list(var))
                    else:
                        raise Exception("error type.")
                if isinstance(var, np.ndarray):
                    tensor.shape.extend(list(var.shape))
                else:
                    # Fall back to the configured shape for list inputs.
                    tensor.shape.extend(self.feed_shapes_[name])
                inst.tensor_array.append(tensor)
            req.insts.append(inst)
        return req
B
barrierye 已提交
501

B
barrierye 已提交
502
    def _unpack_resp(self, resp, fetch, is_python, need_variant_tag):
B
barrierye 已提交
503
        result_map = {}
B
barrierye 已提交
504 505 506 507 508
        inst = resp.outputs[0].insts[0]
        tag = resp.tag
        for i, name in enumerate(fetch):
            var = inst.tensor_array[i]
            v_type = self.fetch_types_[name]
B
barrierye 已提交
509 510 511 512 513 514 515
            if is_python:
                if v_type == 0:  # int64
                    result_map[name] = np.frombuffer(var.data, dtype="int64")
                elif v_type == 1:  # float32
                    result_map[name] = np.frombuffer(var.data, dtype="float32")
                else:
                    raise Exception("error type.")
B
barrierye 已提交
516
            else:
B
barrierye 已提交
517
                if v_type == 0:  # int64
518 519
                    result_map[name] = np.array(
                        list(var.int64_data), dtype="int64")
B
barrierye 已提交
520
                elif v_type == 1:  # float32
521 522
                    result_map[name] = np.array(
                        list(var.float_data), dtype="float32")
B
barrierye 已提交
523 524
                else:
                    raise Exception("error type.")
B
barrierye 已提交
525
            result_map[name].shape = list(var.shape)
B
barrierye 已提交
526
            if name in self.lod_tensor_set_:
B
barrierye 已提交
527
                result_map["{}.lod".format(name)] = np.array(list(var.lod))
528 529
        return result_map if not need_variant_tag else [result_map, tag]

B
barrierye 已提交
530
    def _done_callback_func(self, fetch, is_python, need_variant_tag):
531
        def unpack_resp(resp):
B
barrierye 已提交
532
            return self._unpack_resp(resp, fetch, is_python, need_variant_tag)
B
barrierye 已提交
533

534 535
        return unpack_resp

W
WangXi 已提交
536 537 538
    def get_feed_names(self):
        """Return the feed variable alias names from the parsed model config."""
        return self.feed_names_

B
barrierye 已提交
539 540 541 542 543 544 545
    def predict(self,
                feed,
                fetch,
                need_variant_tag=False,
                asyn=False,
                is_python=True):
        req = self._pack_feed_data(feed, fetch, is_python=is_python)
546 547
        if not asyn:
            resp = self.stub_.inference(req)
B
barrierye 已提交
548 549 550 551 552
            return self._unpack_resp(
                resp,
                fetch,
                is_python=is_python,
                need_variant_tag=need_variant_tag)
553 554 555
        else:
            call_future = self.stub_.inference.future(req)
            return MultiLangPredictFuture(
B
barrierye 已提交
556 557 558 559 560
                call_future,
                self._done_callback_func(
                    fetch,
                    is_python=is_python,
                    need_variant_tag=need_variant_tag))
561

G
gongweibao 已提交
562

563 564 565 566 567 568 569 570
class MultiLangPredictFuture(object):
    """Wraps a gRPC call future; resolves to the unpacked prediction result."""

    def __init__(self, call_future, callback_func):
        self.call_future_ = call_future
        self.callback_func_ = callback_func

    def result(self):
        """Block until the RPC finishes, then unpack and return the response."""
        raw_resp = self.call_future_.result()
        return self.callback_func_(raw_resp)

    def add_done_callback(self, fn):
        """Schedule fn(self) to run once the underlying RPC completes."""

        def _trampoline(done_future):
            assert done_future == self.call_future_
            fn(self)

        self.call_future_.add_done_callback(_trampoline)
        self.call_future_.add_done_callback(__fn__)