You need to sign in or sign up before continuing.
__init__.py 28.8 KB
Newer Older
G
guru4elephant 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
# pylint: disable=doc-string-missing
G
guru4elephant 已提交
15

M
MRXLT 已提交
16 17
import paddle_serving_client
import os
18 19 20
from .proto import sdk_configure_pb2 as sdk
from .proto import general_model_config_pb2 as m_config
import google.protobuf.text_format
D
dongdaxiang 已提交
21 22
import numpy as np
import time
23
import sys
G
guru4elephant 已提交
24

B
barrierye 已提交
25
import grpc
B
barrierye 已提交
26
from .proto import multi_lang_general_model_service_pb2
B
barrierye 已提交
27 28
sys.path.append(
    os.path.join(os.path.abspath(os.path.dirname(__file__)), 'proto'))
B
barrierye 已提交
29
from .proto import multi_lang_general_model_service_pb2_grpc
B
barrierye 已提交
30

M
MRXLT 已提交
31 32 33 34 35
# Numeric type codes as they appear in feed_type / fetch_type of the model config.
int64_type = 0
float32_type = 1
int32_type = 2
# Code groups used for membership tests when sorting feeds into int/float slots.
int_type = {int64_type, int32_type}
float_type = {float32_type}
G
guru4elephant 已提交
36

M
MRXLT 已提交
37

W
WangXi 已提交
38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65
class _NOPProfiler(object):
    def record(self, name):
        pass

    def print_profile(self):
        pass


class _TimeProfiler(object):
    def __init__(self):
        self.pid = os.getpid()
        self.print_head = 'PROFILE\tpid:{}\t'.format(self.pid)
        self.time_record = [self.print_head]

    def record(self, name):
        self.time_record.append('{}:{} '.format(
            name, int(round(time.time() * 1000000))))

    def print_profile(self):
        self.time_record.append('\n')
        sys.stderr.write(''.join(self.time_record))
        self.time_record = [self.print_head]


# Client-side profiling is switched on by a non-zero integer in
# the FLAGS_profile_client environment variable.
_is_profile = int(os.environ.get('FLAGS_profile_client', 0))
_Profiler = _NOPProfiler if not _is_profile else _TimeProfiler


G
guru4elephant 已提交
66 67 68
class SDKConfig(object):
    """Collect server variants and render them into an ``sdk.SDKConf``.

    Variants are registered with :meth:`add_server_variant`; :meth:`gen_desc`
    builds the SDK configuration from everything registered so far.
    """

    def __init__(self):
        self.sdk_desc = sdk.SDKConf()
        self.tag_list = []
        self.cluster_list = []
        self.variant_weight_list = []
        self.rpc_timeout_ms = 20000
        self.load_balance_strategy = "la"

    def add_server_variant(self, tag, cluster, variant_weight):
        """Register one variant.

        Args:
            tag: variant tag.
            cluster: iterable of endpoint strings; joined with "," into a
                ``list://`` naming cluster.
            variant_weight: weight as a string (the weights are "|"-joined).
        """
        self.tag_list.append(tag)
        self.cluster_list.append(cluster)
        self.variant_weight_list.append(variant_weight)

    def set_load_balance_strategy(self, strategy):
        """Set the naming-layer load-balance strategy written by gen_desc."""
        self.load_balance_strategy = strategy

    # Backward-compatible alias for the original (misspelled) method name.
    set_load_banlance_strategy = set_load_balance_strategy

    def gen_desc(self, rpc_timeout_ms=None):
        """Build and return the SDK configuration for the registered variants.

        Args:
            rpc_timeout_ms: RPC timeout for the default variant's connection
                config; defaults to ``self.rpc_timeout_ms`` when omitted.

        Returns:
            A freshly built ``sdk.SDKConf``, also stored in ``self.sdk_desc``.
        """
        if rpc_timeout_ms is None:
            rpc_timeout_ms = self.rpc_timeout_ms

        # Rebuild from scratch: extending the previous config would register
        # the predictor a second time on every repeated call.
        self.sdk_desc = sdk.SDKConf()

        predictor_desc = sdk.Predictor()
        predictor_desc.name = "general_model"
        predictor_desc.service_name = \
            "baidu.paddle_serving.predictor.general_model.GeneralModelService"
        predictor_desc.endpoint_router = "WeightedRandomRender"
        predictor_desc.weighted_random_render_conf.variant_weight_list = "|".join(
            self.variant_weight_list)

        for tag, cluster in zip(self.tag_list, self.cluster_list):
            variant_desc = sdk.VariantConf()
            variant_desc.tag = tag
            variant_desc.naming_conf.cluster = "list://{}".format(
                ",".join(cluster))
            predictor_desc.variants.extend([variant_desc])

        self.sdk_desc.predictors.extend([predictor_desc])

        default_conf = self.sdk_desc.default_variant_conf
        default_conf.tag = "default"

        connection_conf = default_conf.connection_conf
        connection_conf.connect_timeout_ms = 2000
        connection_conf.rpc_timeout_ms = rpc_timeout_ms
        connection_conf.connect_retry_count = 2
        connection_conf.max_connection_per_host = 100
        connection_conf.hedge_request_timeout_ms = -1
        connection_conf.hedge_fetch_retry_count = 2
        connection_conf.connection_type = "pooled"

        naming_conf = default_conf.naming_conf
        naming_conf.cluster_filter_strategy = "Default"
        # Previously hard-coded to "la", which silently ignored the setter.
        naming_conf.load_balance_strategy = self.load_balance_strategy

        rpc_parameter = default_conf.rpc_parameter
        rpc_parameter.compress_type = 0
        rpc_parameter.package_size = 20
        rpc_parameter.protocol = "baidu_std"
        rpc_parameter.max_channel_per_request = 3

        return self.sdk_desc
G
guru4elephant 已提交
118

G
guru4elephant 已提交
119 120 121 122 123 124

class Client(object):
    """Prediction client built on the native ``PredictorClient``.

    Usage: :meth:`load_client_config`, then :meth:`connect` (optionally after
    :meth:`add_variant`), then :meth:`predict`; call :meth:`release` when done.
    """

    def __init__(self):
        self.feed_names_ = []
        self.fetch_names_ = []
        self.client_handle_ = None
        self.feed_shapes_ = {}
        self.feed_types_ = {}
        self.feed_names_to_idx_ = {}
        self.pid = os.getpid()
        self.predictor_sdk_ = None
        self.producers = []
        self.consumer = None
        self.profile_ = _Profiler()
        # Input-kind flags; predict() resets them at the start of every call.
        self.all_numpy_input = True
        self.has_numpy_input = False
        self.rpc_timeout_ms = 20000
        from .serving_client import PredictorRes
        self.predictorres_constructor = PredictorRes

    def load_client_config(self, path):
        """Read the client model config and initialise the native client.

        Args:
            path: path to a text-format ``GeneralModelConfig`` file.
        """
        from .serving_client import PredictorClient
        model_conf = m_config.GeneralModelConfig()
        # Use a context manager so the file handle is closed promptly.
        with open(path, 'r') as f:
            model_conf = google.protobuf.text_format.Merge(
                str(f.read()), model_conf)

        self.client_handle_ = PredictorClient()
        self.client_handle_.init(path)
        if "FLAGS_max_body_size" not in os.environ:
            os.environ["FLAGS_max_body_size"] = str(512 * 1024 * 1024)
        read_env_flags = ["profile_client", "profile_server", "max_body_size"]
        self.client_handle_.init_gflags(
            [sys.argv[0]] + ["--tryfromenv=" + ",".join(read_env_flags)])

        self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
        self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
        # Reset every per-model map so reloading a config leaves no stale keys.
        self.feed_names_to_idx_ = {}
        self.feed_types_ = {}
        self.feed_shapes_ = {}
        self.fetch_names_to_type_ = {}
        self.fetch_names_to_idx_ = {}
        self.lod_tensor_set = set()
        self.feed_tensor_len = {}

        for i, var in enumerate(model_conf.feed_var):
            self.feed_names_to_idx_[var.alias_name] = i
            self.feed_types_[var.alias_name] = var.feed_type
            self.feed_shapes_[var.alias_name] = var.shape
            if var.is_lod_tensor:
                self.lod_tensor_set.add(var.alias_name)
            else:
                # Flattened element count, used by shape_check for list feeds.
                counter = 1
                for dim in var.shape:
                    counter *= dim
                self.feed_tensor_len[var.alias_name] = counter

        for i, var in enumerate(model_conf.fetch_var):
            self.fetch_names_to_idx_[var.alias_name] = i
            self.fetch_names_to_type_[var.alias_name] = var.fetch_type
            if var.is_lod_tensor:
                self.lod_tensor_set.add(var.alias_name)
        return

    def add_variant(self, tag, cluster, variant_weight):
        """Register a server variant to be used by connect().

        Args:
            tag: variant tag.
            cluster: iterable of endpoint strings.
            variant_weight: weight; converted to ``str``.
        """
        if self.predictor_sdk_ is None:
            self.predictor_sdk_ = SDKConfig()
        self.predictor_sdk_.add_server_variant(tag, cluster,
                                               str(variant_weight))

    def set_rpc_timeout_ms(self, rpc_timeout):
        """Set the RPC timeout (milliseconds) applied at connect() time.

        Raises:
            ValueError: rpc_timeout is not an int.
        """
        if not isinstance(rpc_timeout, int):
            raise ValueError("rpc_timeout must be int type.")
        self.rpc_timeout_ms = rpc_timeout

    def connect(self, endpoints=None):
        """Create the native predictor from the configured variants.

        Args:
            endpoints: iterable of endpoint strings. When omitted, variants
                must have been registered with add_variant(). If variants
                were registered, endpoints is ignored.

        Raises:
            ValueError: neither endpoints nor a variant was provided.
        """
        if endpoints is None:
            if self.predictor_sdk_ is None:
                raise ValueError(
                    "You must set the endpoints parameter or use add_variant function to create a variant."
                )
        else:
            if self.predictor_sdk_ is None:
                self.add_variant('default_tag_{}'.format(id(self)), endpoints,
                                 100)
            else:
                print(
                    "parameter endpoints({}) will not take effect, because you use the add_variant function.".
                    format(endpoints))
        sdk_desc = self.predictor_sdk_.gen_desc(self.rpc_timeout_ms)
        self.client_handle_.create_predictor_by_desc(
            sdk_desc.SerializeToString())

    def get_feed_names(self):
        return self.feed_names_

    def get_fetch_names(self):
        return self.fetch_names_

    def shape_check(self, feed, key):
        """Raise ValueError when a flat-list feed has the wrong element count.

        LoD feeds are skipped. numpy arrays are not checked (the original
        size check for them was disabled).
        """
        if key in self.lod_tensor_set:
            return
        value = feed[key]
        if isinstance(value, list) and len(value) != self.feed_tensor_len[key]:
            raise ValueError("The shape of feed tensor {} not match.".format(
                key))

    def predict(self,
                feed=None,
                fetch=None,
                batch=False,
                need_variant_tag=False,
                log_id=0):
        """Run inference through the native client.

        Args:
            feed: dict of feed name -> numpy array, or a list of such dicts.
                A ``"<name>.lod"`` key supplies LoD data for ``<name>``.
            fetch: fetch name or list of fetch names; names absent from the
                model config are ignored.
            batch: when False, a leading axis is added to the numpy arrays of
                the first feed dict. The caller's dicts are not modified.
            need_variant_tag: also return the variant tag.
            log_id: passed through to the native ``numpy_predict`` call.

        Returns:
            None if the native call returned -1. Otherwise a dict of fetch
            results (``{engine_name: dict}`` when several engines answered),
            or ``[result, tag]`` when need_variant_tag is True.

        Raises:
            ValueError: missing/invalid feed or fetch, unknown feed names,
                list-length mismatch, non-numpy inputs, or empty results.
        """
        self.profile_.record('py_prepro_0')

        if feed is None or fetch is None:
            raise ValueError("You should specify feed and fetch for prediction")

        if isinstance(fetch, str):
            fetch_list = [fetch]
        elif isinstance(fetch, list):
            fetch_list = fetch
        else:
            raise ValueError("Fetch only accepts string and list of string")

        if isinstance(feed, dict):
            feed_batch = [feed]
        elif isinstance(feed, list):
            feed_batch = feed
        else:
            raise ValueError("Feed only accepts dict and list of dict")

        fetch_names = [key for key in fetch_list if key in self.fetch_names_]
        if not fetch_names:
            raise ValueError(
                "Fetch names should not be empty or out of saved fetch list.")

        int_slot_batch = []
        float_slot_batch = []
        int_feed_names = []
        float_feed_names = []
        int_shape = []
        float_shape = []
        int_lod_slot_batch = []
        float_lod_slot_batch = []

        # These flags used to persist across calls, so one list-valued call
        # made every later numpy call fail. Reset them per call.
        self.all_numpy_input = True
        self.has_numpy_input = False

        for i, feed_i in enumerate(feed_batch):
            int_slot = []
            float_slot = []
            for key in feed_i:
                if ".lod" in key:
                    continue
                if key not in self.feed_names_:
                    raise ValueError("Wrong feed name: {}.".format(key))
                self.shape_check(feed_i, key)

                feed_type = self.feed_types_[key]
                if feed_type in int_type:
                    names, shapes, lods = (int_feed_names, int_shape,
                                           int_lod_slot_batch)
                    slot = int_slot
                elif feed_type in float_type:
                    names, shapes, lods = (float_feed_names, float_shape,
                                           float_lod_slot_batch)
                    slot = float_slot
                else:
                    # Feeds of any other type are skipped, as before.
                    continue

                value = feed_i[key]
                if i == 0:
                    names.append(key)
                    if batch == False and isinstance(value, np.ndarray):
                        # Expand a local reference only; writing back into
                        # feed_i would grow the caller's array on every call.
                        value = np.expand_dims(value, 0)
                    if isinstance(value, np.ndarray):
                        shapes.append(list(value.shape))
                    else:
                        shapes.append(self.feed_shapes_[key])
                    lods.append(feed_i.get("{}.lod".format(key), []))

                if isinstance(value, np.ndarray):
                    self.has_numpy_input = True
                else:
                    self.all_numpy_input = False
                slot.append(value)

            int_slot_batch.append(int_slot)
            float_slot_batch.append(float_slot)
            # NOTE(review): one empty entry per instance is appended to the
            # LoD batches in addition to the per-key entries above. This is
            # preserved as-is; confirm against the native numpy_predict contract.
            int_lod_slot_batch.append([])
            float_lod_slot_batch.append([])

        self.profile_.record('py_prepro_1')
        self.profile_.record('py_client_infer_0')

        result_batch_handle = self.predictorres_constructor()
        if self.all_numpy_input:
            res = self.client_handle_.numpy_predict(
                float_slot_batch, float_feed_names, float_shape,
                float_lod_slot_batch, int_slot_batch, int_feed_names, int_shape,
                int_lod_slot_batch, fetch_names, result_batch_handle, self.pid,
                log_id)
        elif not self.has_numpy_input:
            raise ValueError(
                "Please make sure all of your inputs are numpy array")
        else:
            raise ValueError(
                "Please make sure the inputs are all in list type or all in numpy.array type"
            )

        self.profile_.record('py_client_infer_1')
        self.profile_.record('py_postpro_0')

        if res == -1:
            return None

        # Fetch-type code -> native accessor returning a numpy array.
        getters = {
            int64_type: result_batch_handle.get_int64_by_name,
            float32_type: result_batch_handle.get_float_by_name,
            int32_type: result_batch_handle.get_int32_by_name,
        }

        multi_result_map = []
        model_engine_names = result_batch_handle.get_engine_names()
        for mi, engine_name in enumerate(model_engine_names):
            result_map = {}
            for name in fetch_names:
                getter = getters.get(self.fetch_names_to_type_[name])
                if getter is None:
                    continue
                result_map[name] = getter(mi, name)
                if result_map[name].size == 0:
                    raise ValueError(
                        "Failed to fetch, maybe the type of [{}]"
                        " is wrong, please check the model file".format(name))
                result_map[name].shape = result_batch_handle.get_shape(mi,
                                                                       name)
                if name in self.lod_tensor_set:
                    result_map["{}.lod".format(
                        name)] = result_batch_handle.get_lod(mi, name)
            multi_result_map.append(result_map)

        if len(model_engine_names) == 1:
            # A single engine answered: return its result_map directly.
            ret = multi_result_map[0]
        else:
            # Several engines answered: {engine_name: result_map}.
            ret = {
                engine_name: multi_result_map[mi]
                for mi, engine_name in enumerate(model_engine_names)
            }

        self.profile_.record('py_postpro_1')
        self.profile_.print_profile()

        # For A/B tests the variant tag is returned alongside the result.
        if need_variant_tag:
            return [ret, result_batch_handle.variant_tag()]
        return ret

    def release(self):
        """Destroy the native predictor and drop the handle."""
        self.client_handle_.destroy_predictor()
        self.client_handle_ = None
B
barrierye 已提交
444 445


446
class MultiLangClient(object):
    """gRPC client for ``MultiLangGeneralModelService``.

    Call :meth:`connect` first (it also downloads the model config), then
    :meth:`predict`.
    """

    # Feed/fetch type code -> numpy dtype name.
    _DTYPES = {0: "int64", 1: "float32", 2: "int32"}
    # Feed/fetch type code -> repeated Tensor field used when is_python=False.
    _DATA_FIELDS = {0: "int64_data", 1: "float_data", 2: "int_data"}

    def __init__(self):
        self.channel_ = None
        self.stub_ = None
        self.rpc_timeout_s_ = 2
        self.profile_ = _Profiler()
        # Filled in by connect(); initialised here so accessors work earlier.
        self.feed_names_ = []
        self.feed_types_ = {}
        self.feed_shapes_ = {}
        self.fetch_names_ = []
        self.fetch_types_ = {}
        self.lod_tensor_set_ = set()

    def add_variant(self, tag, cluster, variant_weight):
        # TODO
        raise Exception("cannot support ABtest yet")

    def set_rpc_timeout_ms(self, rpc_timeout):
        """Set the per-call timeout locally and on the server.

        Returns:
            True if the server replied with ``err_code == 0``.

        Raises:
            Exception: called before connect().
            ValueError: rpc_timeout is not an int.
        """
        if self.stub_ is None:
            raise Exception("set timeout must be set after connect.")
        if not isinstance(rpc_timeout, int):
            raise ValueError("rpc_timeout must be int type.")
        self.rpc_timeout_s_ = rpc_timeout / 1000.0
        timeout_req = multi_lang_general_model_service_pb2.SetTimeoutRequest()
        timeout_req.timeout_ms = rpc_timeout
        resp = self.stub_.SetTimeout(timeout_req)
        return resp.err_code == 0

    def connect(self, endpoints):
        """Open a gRPC channel and load the server's client model config.

        Args:
            endpoints: list of endpoint strings, or a single endpoint string.
        """
        # A bare string would otherwise be ",".join-ed character by character.
        if isinstance(endpoints, str):
            endpoints = [endpoints]
        # https://github.com/tensorflow/serving/issues/1382
        options = [('grpc.max_receive_message_length', 512 * 1024 * 1024),
                   ('grpc.max_send_message_length', 512 * 1024 * 1024),
                   ('grpc.lb_policy_name', 'round_robin')]
        # TODO: weight round robin
        g_endpoint = 'ipv4:{}'.format(','.join(endpoints))
        self.channel_ = grpc.insecure_channel(g_endpoint, options=options)
        self.stub_ = multi_lang_general_model_service_pb2_grpc.MultiLangGeneralModelServiceStub(
            self.channel_)
        # get client model config
        get_client_config_req = multi_lang_general_model_service_pb2.GetClientConfigRequest(
        )
        resp = self.stub_.GetClientConfig(get_client_config_req)
        self._parse_model_config(resp.client_config_str)

    def _flatten_list(self, nested_list):
        """Yield the leaves of arbitrarily nested lists/tuples in order."""
        for item in nested_list:
            if isinstance(item, (list, tuple)):
                for sub_item in self._flatten_list(item):
                    yield sub_item
            else:
                yield item

    def _parse_model_config(self, model_config_str):
        """Populate feed/fetch metadata from a text-format model config."""
        model_conf = m_config.GeneralModelConfig()
        model_conf = google.protobuf.text_format.Merge(model_config_str,
                                                       model_conf)
        self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
        self.feed_types_ = {}
        self.feed_shapes_ = {}
        self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
        self.fetch_types_ = {}
        self.lod_tensor_set_ = set()
        for var in model_conf.feed_var:
            self.feed_types_[var.alias_name] = var.feed_type
            self.feed_shapes_[var.alias_name] = var.shape
            if var.is_lod_tensor:
                self.lod_tensor_set_.add(var.alias_name)
        for var in model_conf.fetch_var:
            self.fetch_types_[var.alias_name] = var.fetch_type
            if var.is_lod_tensor:
                self.lod_tensor_set_.add(var.alias_name)

    def _to_numpy(self, var, v_type):
        """Return *var* (list or ndarray) as an ndarray of the configured dtype."""
        dtype = self._DTYPES.get(v_type)
        if isinstance(var, list):
            if dtype is None:
                raise Exception("error tensor value type.")
            return np.array(var, dtype=dtype)
        if isinstance(var, np.ndarray):
            if dtype is None:
                raise Exception("error tensor value type.")
            return var if var.dtype == dtype else var.astype(dtype)
        raise Exception("var must be list or ndarray.")

    def _fill_repeated(self, tensor, var, v_type):
        """Copy *var* (list or ndarray) into the repeated field for v_type."""
        field_name = self._DATA_FIELDS.get(v_type)
        if isinstance(var, np.ndarray):
            if field_name is None:
                raise Exception("error tensor value type.")
            getattr(tensor, field_name).extend(
                var.reshape(-1).astype(self._DTYPES[v_type]).tolist())
        elif isinstance(var, list):
            if field_name is None:
                raise Exception("error tensor value type.")
            getattr(tensor, field_name).extend(self._flatten_list(var))
        else:
            raise Exception("var must be list or ndarray.")

    def _pack_inference_request(self, feed, fetch, is_python, log_id):
        """Build an ``InferenceRequest`` from a feed dict or list of dicts.

        With is_python=True tensors carry raw bytes in ``data``; otherwise
        values go into the typed repeated fields.
        """
        req = multi_lang_general_model_service_pb2.InferenceRequest()
        req.fetch_var_names.extend(fetch)
        req.is_python = is_python
        req.log_id = log_id
        if isinstance(feed, dict):
            feed_batch = [feed]
        elif isinstance(feed, list):
            feed_batch = feed
        else:
            raise Exception("{} not support".format(type(feed)))
        # The first instance's keys define the feed order for all instances.
        req.feed_var_names.extend(feed_batch[0].keys())
        for feed_data in feed_batch:
            inst = multi_lang_general_model_service_pb2.FeedInst()
            for name in req.feed_var_names:
                tensor = multi_lang_general_model_service_pb2.Tensor()
                var = feed_data[name]
                v_type = self.feed_types_[name]
                if is_python:
                    tensor.data = self._to_numpy(var, v_type).tobytes()
                else:
                    self._fill_repeated(tensor, var, v_type)
                if isinstance(var, np.ndarray):
                    tensor.shape.extend(list(var.shape))
                else:
                    tensor.shape.extend(self.feed_shapes_[name])
                inst.tensor_array.append(tensor)
            req.insts.append(inst)
        return req

    def _unpack_inference_response(self, resp, fetch, is_python,
                                   need_variant_tag):
        """Convert an inference response into numpy result dicts.

        Returns None when ``resp.err_code`` is non-zero. Otherwise it returns
        a single result dict, or ``{engine_name: dict}`` for several engines.
        The result always carries ``serving_status_code`` 0 and becomes
        ``[ret, tag]`` when need_variant_tag is True.
        """
        if resp.err_code != 0:
            return None
        tag = resp.tag
        multi_result_map = {}
        for model_result in resp.outputs:
            inst = model_result.insts[0]
            result_map = {}
            for i, name in enumerate(fetch):
                var = inst.tensor_array[i]
                v_type = self.fetch_types_[name]
                dtype = self._DTYPES.get(v_type)
                if dtype is None:
                    raise Exception("error type.")
                if is_python:
                    result_map[name] = np.frombuffer(var.data, dtype=dtype)
                else:
                    values = getattr(var, self._DATA_FIELDS[v_type])
                    result_map[name] = np.array(list(values), dtype=dtype)
                result_map[name].shape = list(var.shape)
                if name in self.lod_tensor_set_:
                    result_map["{}.lod".format(name)] = np.array(list(var.lod))
            multi_result_map[model_result.engine_name] = result_map
        if len(resp.outputs) == 1:
            ret = list(multi_result_map.values())[0]
        else:
            ret = multi_result_map

        ret["serving_status_code"] = 0
        return ret if not need_variant_tag else [ret, tag]

    def _done_callback_func(self, fetch, is_python, need_variant_tag):
        """Return a closure that unpacks a raw response with fixed options."""

        def unpack_resp(resp):
            return self._unpack_inference_response(resp, fetch, is_python,
                                                   need_variant_tag)

        return unpack_resp

    def get_feed_names(self):
        return self.feed_names_

    def get_fetch_names(self):
        return self.fetch_names_

    def predict(self,
                feed,
                fetch,
                need_variant_tag=False,
                asyn=False,
                is_python=True,
                log_id=0):
        """Run inference over gRPC.

        Args:
            feed: dict of feed name -> list/ndarray, or a list of such dicts.
            fetch: fetch name or list of fetch names.
            need_variant_tag: also return the response tag.
            asyn: return a MultiLangPredictFuture instead of blocking.
            is_python: send tensors as raw bytes (True) or typed fields.
            log_id: request log id.

        Returns:
            The unpacked result, ``{"serving_status_code": code}`` on a gRPC
            error in synchronous mode, or a MultiLangPredictFuture when asyn.
        """
        # Accept a single fetch name, as Client.predict does. A bare string
        # would otherwise be iterated character by character.
        if isinstance(fetch, str):
            fetch = [fetch]
        if not asyn:
            try:
                self.profile_.record('py_prepro_0')
                req = self._pack_inference_request(
                    feed, fetch, is_python=is_python, log_id=log_id)
                self.profile_.record('py_prepro_1')

                self.profile_.record('py_client_infer_0')
                resp = self.stub_.Inference(req, timeout=self.rpc_timeout_s_)
                self.profile_.record('py_client_infer_1')

                self.profile_.record('py_postpro_0')
                ret = self._unpack_inference_response(
                    resp,
                    fetch,
                    is_python=is_python,
                    need_variant_tag=need_variant_tag)
                self.profile_.record('py_postpro_1')
                self.profile_.print_profile()
                return ret
            except grpc.RpcError as e:
                return {"serving_status_code": e.code()}
        else:
            req = self._pack_inference_request(
                feed, fetch, is_python=is_python, log_id=log_id)
            call_future = self.stub_.Inference.future(
                req, timeout=self.rpc_timeout_s_)
            return MultiLangPredictFuture(
                call_future,
                self._done_callback_func(
                    fetch,
                    is_python=is_python,
                    need_variant_tag=need_variant_tag))
690 691 692 693 694 695 696 697


class MultiLangPredictFuture(object):
    """Wrap a gRPC call future and unpack its response on demand.

    ``callback_func`` is applied to the raw response when :meth:`result`
    is called.
    """

    def __init__(self, call_future, callback_func):
        self.call_future_ = call_future
        self.callback_func_ = callback_func

    def result(self):
        """Block for the response and return ``callback_func(response)``.

        A ``grpc.RpcError`` is not raised. It is reported as
        ``{"serving_status_code": <error code>}`` instead.
        """
        try:
            raw_resp = self.call_future_.result()
        except grpc.RpcError as err:
            return {"serving_status_code": err.code()}
        else:
            return self.callback_func_(raw_resp)

    def add_done_callback(self, fn):
        """Arrange for ``fn(self)`` to run when the wrapped call completes."""
        wrapper = self

        def _on_done(finished_future):
            # The completed future must be the one this object wraps.
            assert finished_future == wrapper.call_future_
            fn(wrapper)

        self.call_future_.add_done_callback(_on_done)