__init__.py 29.8 KB
Newer Older
G
guru4elephant 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
# pylint: disable=doc-string-missing
G
guru4elephant 已提交
15

M
MRXLT 已提交
16
import os
D
dongdaxiang 已提交
17
import time
18
import sys
M
MRXLT 已提交
19 20 21 22 23 24
import requests
import json
import base64
import numpy as np
import paddle_serving_client
import google.protobuf.text_format
G
guru4elephant 已提交
25

B
barrierye 已提交
26
import grpc
M
MRXLT 已提交
27 28
from .proto import sdk_configure_pb2 as sdk
from .proto import general_model_config_pb2 as m_config
B
barrierye 已提交
29
from .proto import multi_lang_general_model_service_pb2
B
barrierye 已提交
30 31
sys.path.append(
    os.path.join(os.path.abspath(os.path.dirname(__file__)), 'proto'))
B
barrierye 已提交
32
from .proto import multi_lang_general_model_service_pb2_grpc
B
barrierye 已提交
33

M
MRXLT 已提交
34 35 36 37 38
# Wire-format type codes for feed/fetch tensors.  The numeric values must
# stay in sync with the serving protobuf / native client definitions.
int64_type = 0
float32_type = 1
int32_type = 2
# Convenience sets used to route a feed value into the int or float slots.
int_type = set([int64_type, int32_type])
float_type = set([float32_type])
G
guru4elephant 已提交
39

M
MRXLT 已提交
40

W
WangXi 已提交
41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68
class _NOPProfiler(object):
    def record(self, name):
        pass

    def print_profile(self):
        pass


class _TimeProfiler(object):
    def __init__(self):
        self.pid = os.getpid()
        self.print_head = 'PROFILE\tpid:{}\t'.format(self.pid)
        self.time_record = [self.print_head]

    def record(self, name):
        self.time_record.append('{}:{} '.format(
            name, int(round(time.time() * 1000000))))

    def print_profile(self):
        self.time_record.append('\n')
        sys.stderr.write(''.join(self.time_record))
        self.time_record = [self.print_head]


# Profiling is opt-in via the FLAGS_profile_client environment variable;
# when disabled, the no-op profiler keeps the hot path free of overhead.
_is_profile = int(os.environ.get('FLAGS_profile_client', 0))
_Profiler = _TimeProfiler if _is_profile else _NOPProfiler


G
guru4elephant 已提交
69 70 71
class SDKConfig(object):
    """Builds the SDK configuration protobuf used to create a predictor.

    Variants registered through add_server_variant() are rendered by
    gen_desc() into a predictor description routed with
    WeightedRandomRender.
    """

    def __init__(self):
        self.sdk_desc = sdk.SDKConf()
        self.tag_list = []
        self.cluster_list = []
        self.variant_weight_list = []
        self.rpc_timeout_ms = 20000
        self.load_balance_strategy = "la"

    def add_server_variant(self, tag, cluster, variant_weight):
        """Register one server variant.

        Args:
            tag: unique name of the variant.
            cluster: list of "ip:port" endpoint strings.
            variant_weight: routing weight, as a string, consumed by the
                WeightedRandomRender strategy.
        """
        self.tag_list.append(tag)
        self.cluster_list.append(cluster)
        self.variant_weight_list.append(variant_weight)

    def set_load_balance_strategy(self, strategy):
        """Set the naming-service load-balance strategy (default "la")."""
        self.load_balance_strategy = strategy

    # Backward-compatible alias: the method was historically published under
    # this misspelled name, so existing callers must keep working.
    set_load_banlance_strategy = set_load_balance_strategy

    def gen_desc(self, rpc_timeout_ms):
        """Assemble and return the sdk.SDKConf for all registered variants.

        Args:
            rpc_timeout_ms: per-request RPC timeout (milliseconds) written
                into the default connection configuration.

        Returns:
            The populated sdk.SDKConf protobuf message.
        """
        predictor_desc = sdk.Predictor()
        predictor_desc.name = "general_model"
        predictor_desc.service_name = \
            "baidu.paddle_serving.predictor.general_model.GeneralModelService"
        predictor_desc.endpoint_router = "WeightedRandomRender"
        predictor_desc.weighted_random_render_conf.variant_weight_list = "|".join(
            self.variant_weight_list)

        for idx, tag in enumerate(self.tag_list):
            variant_desc = sdk.VariantConf()
            variant_desc.tag = tag
            variant_desc.naming_conf.cluster = "list://{}".format(",".join(
                self.cluster_list[idx]))
            predictor_desc.variants.extend([variant_desc])

        self.sdk_desc.predictors.extend([predictor_desc])
        default_conf = self.sdk_desc.default_variant_conf
        default_conf.tag = "default"
        default_conf.connection_conf.connect_timeout_ms = 2000
        default_conf.connection_conf.rpc_timeout_ms = rpc_timeout_ms
        default_conf.connection_conf.connect_retry_count = 2
        default_conf.connection_conf.max_connection_per_host = 100
        default_conf.connection_conf.hedge_request_timeout_ms = -1
        default_conf.connection_conf.hedge_fetch_retry_count = 2
        default_conf.connection_conf.connection_type = "pooled"

        default_conf.naming_conf.cluster_filter_strategy = "Default"
        # Honor the strategy configured via set_load_balance_strategy();
        # this was previously hard-coded to "la", making the setter a no-op.
        # The default is still "la", so default behavior is unchanged.
        default_conf.naming_conf.load_balance_strategy = self.load_balance_strategy

        default_conf.rpc_parameter.compress_type = 0
        default_conf.rpc_parameter.package_size = 20
        default_conf.rpc_parameter.protocol = "baidu_std"
        default_conf.rpc_parameter.max_channel_per_request = 3

        return self.sdk_desc
G
guru4elephant 已提交
121

G
guru4elephant 已提交
122 123 124 125 126 127

class Client(object):
    def __init__(self):
        # Feed/fetch metadata; populated by load_client_config().
        self.feed_names_ = []
        self.fetch_names_ = []
        self.client_handle_ = None
        self.feed_shapes_ = {}
        self.feed_types_ = {}
        self.feed_names_to_idx_ = {}
        self.pid = os.getpid()
        # SDKConfig built lazily by add_variant()/connect().
        self.predictor_sdk_ = None
        self.producers = []
        self.consumer = None
        self.profile_ = _Profiler()
        # Track whether the feed values seen so far are all / partly numpy
        # arrays; predict() rejects a mix of list and ndarray inputs.
        self.all_numpy_input = True
        self.has_numpy_input = False
        self.rpc_timeout_ms = 20000
        from .serving_client import PredictorRes
        self.predictorres_constructor = PredictorRes
M
MRXLT 已提交
141

G
guru4elephant 已提交
142
    def load_client_config(self, path):
        """Parse the client configuration at ``path`` and initialize the
        native PredictorClient plus all feed/fetch metadata tables.

        Args:
            path: path to the prototxt client configuration file.
        """
        from .serving_client import PredictorClient
        model_conf = m_config.GeneralModelConfig()
        # Use a context manager so the config file is always closed
        # (the previous code leaked the file handle).
        with open(path, 'r') as f:
            model_conf = google.protobuf.text_format.Merge(
                str(f.read()), model_conf)

        # load configuration here
        # get feed vars, fetch vars
        # get feed shapes, feed types
        # map feed names to index
        self.client_handle_ = PredictorClient()
        self.client_handle_.init(path)
        if "FLAGS_max_body_size" not in os.environ:
            os.environ["FLAGS_max_body_size"] = str(512 * 1024 * 1024)
        read_env_flags = ["profile_client", "profile_server", "max_body_size"]
        self.client_handle_.init_gflags([sys.argv[
            0]] + ["--tryfromenv=" + ",".join(read_env_flags)])
        self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
        self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
        self.feed_names_to_idx_ = {}
        self.fetch_names_to_type_ = {}
        self.fetch_names_to_idx_ = {}
        self.lod_tensor_set = set()
        self.feed_tensor_len = {}
        self.key = None

        for i, var in enumerate(model_conf.feed_var):
            self.feed_names_to_idx_[var.alias_name] = i
            self.feed_types_[var.alias_name] = var.feed_type
            self.feed_shapes_[var.alias_name] = var.shape

            if var.is_lod_tensor:
                self.lod_tensor_set.add(var.alias_name)
            else:
                # Dense tensor: pre-compute the flattened element count so
                # shape_check() can validate feeds cheaply.
                counter = 1
                for dim in self.feed_shapes_[var.alias_name]:
                    counter *= dim
                self.feed_tensor_len[var.alias_name] = counter
        for i, var in enumerate(model_conf.fetch_var):
            self.fetch_names_to_idx_[var.alias_name] = i
            self.fetch_names_to_type_[var.alias_name] = var.fetch_type
            if var.is_lod_tensor:
                self.lod_tensor_set.add(var.alias_name)
        return

188
    def add_variant(self, tag, cluster, variant_weight):
        """Register an A/B-test server variant.

        Creates the SDKConfig on first use; the weight is stored as a
        string, as required by the SDK description format.
        """
        if self.predictor_sdk_ is None:
            self.predictor_sdk_ = SDKConfig()
        self.predictor_sdk_.add_server_variant(tag, cluster,
                                               str(variant_weight))

M
MRXLT 已提交
194 195 196 197 198 199
    def set_rpc_timeout_ms(self, rpc_timeout):
        if not isinstance(rpc_timeout, int):
            raise ValueError("rpc_timeout must be int type.")
        else:
            self.rpc_timeout_ms = rpc_timeout

M
MRXLT 已提交
200 201 202 203
    def use_key(self, key_filename):
        with open(key_filename, "r") as f:
            self.key = f.read()

M
MRXLT 已提交
204
    def get_serving_port(self, endpoints):
        """Query the encryption front port for the real serving port.

        Posts the (optional) encryption key to ``endpoints[0]`` and returns
        a single-element endpoint list pointing at the port the server
        replies with.

        Raises:
            ValueError: if the server has not published an endpoint list yet.
        """
        if self.key is not None:
            # base64.b64encode requires bytes and returns bytes; self.key is
            # a str (use_key reads in text mode), so encode/decode to keep
            # the payload JSON-serializable under Python 3.  The old code
            # passed the str directly and raised TypeError.
            encoded_key = base64.b64encode(self.key.encode()).decode()
            req = json.dumps({"key": encoded_key})
        else:
            req = json.dumps({})
        r = requests.post("http://" + endpoints[0], req)
        result = r.json()
        print(result)
        if "endpoint_list" not in result:
            raise ValueError("server not ready")
        else:
            endpoints = [
                endpoints[0].split(":")[0] + ":" +
                str(result["endpoint_list"][0])
            ]
            return endpoints

M
MRXLT 已提交
221
    def connect(self, endpoints=None, encryption=False):
        """Create the underlying predictor from endpoints and/or variants.

        Args:
            endpoints: optional list of "ip:port" strings; may be omitted
                only if variants were registered via add_variant() first.
            encryption: when True, first ask the server for the real
                serving port (encrypted-model flow).
        """
        # check whether current endpoint is available
        # init from client config
        # create predictor here
        if endpoints is None:
            if self.predictor_sdk_ is None:
                raise ValueError(
                    "You must set the endpoints parameter or use add_variant function to create a variant."
                )
        else:
            if encryption:
                endpoints = self.get_serving_port(endpoints)
            if self.predictor_sdk_ is None:
                # No explicit variants: wrap the endpoints in a default one.
                self.add_variant('default_tag_{}'.format(id(self)), endpoints,
                                 100)
            else:
                # Variants win over the endpoints argument.
                print(
                    "parameter endpoints({}) will not take effect, because you use the add_variant function.".
                    format(endpoints))
        sdk_desc = self.predictor_sdk_.gen_desc(self.rpc_timeout_ms)
        self.client_handle_.create_predictor_by_desc(sdk_desc.SerializeToString(
        ))
G
guru4elephant 已提交
243 244 245 246 247 248 249

    def get_feed_names(self):
        """Return the list of feed variable alias names."""
        return self.feed_names_

    def get_fetch_names(self):
        """Return the list of fetch variable alias names."""
        return self.fetch_names_

M
MRXLT 已提交
250 251 252
    def shape_check(self, feed, key):
        if key in self.lod_tensor_set:
            return
M
MRXLT 已提交
253 254
        if isinstance(feed[key],
                      list) and len(feed[key]) != self.feed_tensor_len[key]:
M
MRXLT 已提交
255
            raise ValueError("The shape of feed tensor {} not match.".format(
M
MRXLT 已提交
256 257 258
                key))
        if type(feed[key]).__module__ == np.__name__ and np.size(feed[
                key]) != self.feed_tensor_len[key]:
M
MRXLT 已提交
259 260 261
            #raise SystemExit("The shape of feed tensor {} not match.".format(
            #    key))
            pass
M
MRXLT 已提交
262

W
wangjiawei04 已提交
263 264 265 266 267 268
    def predict(self,
                feed=None,
                fetch=None,
                batch=False,
                need_variant_tag=False,
                log_id=0):
        """Run one inference request through the native client.

        Args:
            feed: dict (or list of dicts for a batch) mapping feed names to
                list/ndarray values; keys of the form "<name>.lod" carry LoD
                information for the matching tensor.
            fetch: fetch variable name, or list of names.
            batch: whether the inputs already carry a batch dimension; when
                False a leading axis is inserted before sending.
            need_variant_tag: when True, also return the A/B variant tag.
            log_id: id forwarded to the server for request tracing.

        Returns:
            A dict of fetched numpy arrays (single model), or a dict
            {engine_name: result_map} when multiple model results come back,
            or None when the native call fails.  Wrapped as [ret, tag] when
            ``need_variant_tag`` is True.

        Raises:
            ValueError: on missing/invalid feed or fetch arguments, mixed
                list/ndarray inputs, or an empty fetch result.
        """
        self.profile_.record('py_prepro_0')

        if feed is None or fetch is None:
            raise ValueError("You should specify feed and fetch for prediction")

        # Normalize fetch to a list of names.
        fetch_list = []
        if isinstance(fetch, str):
            fetch_list = [fetch]
        elif isinstance(fetch, list):
            fetch_list = fetch
        else:
            raise ValueError("Fetch only accepts string and list of string")

        # Normalize feed to a list of dicts (one per sample).
        feed_batch = []
        if isinstance(feed, dict):
            feed_batch.append(feed)
        elif isinstance(feed, list):
            feed_batch = feed
        else:
            raise ValueError("Feed only accepts dict and list of dict")

        # Per-dtype slots: values, names, shapes and LoD info are passed to
        # the native client as parallel lists.
        int_slot_batch = []
        float_slot_batch = []
        int_feed_names = []
        float_feed_names = []
        int_shape = []
        int_lod_slot_batch = []
        float_lod_slot_batch = []
        float_shape = []

        fetch_names = []
        # NOTE(review): counter and batch_size appear unused below.
        counter = 0
        batch_size = len(feed_batch)

        # Keep only fetch names known from the client config.
        for key in fetch_list:
            if key in self.fetch_names_:
                fetch_names.append(key)

        if len(fetch_names) == 0:
            raise ValueError(
                "Fetch names should not be empty or out of saved fetch list.")
            # Unreachable after the raise above.
            return {}

        for i, feed_i in enumerate(feed_batch):
            int_slot = []
            float_slot = []
            int_lod_slot = []
            float_lod_slot = []
            for key in feed_i:
                # ".lod" keys are metadata, not tensors; they are consumed
                # when their base tensor is processed below.
                if ".lod" not in key and key not in self.feed_names_:
                    raise ValueError("Wrong feed name: {}.".format(key))
                if ".lod" in key:
                    continue
                #if not isinstance(feed_i[key], np.ndarray):
                self.shape_check(feed_i, key)
                if self.feed_types_[key] in int_type:
                    # Names/shapes/LoD are recorded once, from the first
                    # sample of the batch.
                    if i == 0:
                        int_feed_names.append(key)
                        shape_lst = []
                        if batch == False:
                            # Insert a leading batch axis.
                            feed_i[key] = feed_i[key][np.newaxis, :]
                        if isinstance(feed_i[key], np.ndarray):
                            shape_lst.extend(list(feed_i[key].shape))
                            int_shape.append(shape_lst)
                        else:
                            int_shape.append(self.feed_shapes_[key])
                        if "{}.lod".format(key) in feed_i:
                            int_lod_slot_batch.append(feed_i["{}.lod".format(
                                key)])
                        else:
                            int_lod_slot_batch.append([])

                    if isinstance(feed_i[key], np.ndarray):
                        int_slot.append(feed_i[key])
                        self.has_numpy_input = True
                    else:
                        int_slot.append(feed_i[key])
                        self.all_numpy_input = False

                elif self.feed_types_[key] in float_type:
                    if i == 0:
                        float_feed_names.append(key)
                        shape_lst = []
                        if batch == False:
                            feed_i[key] = feed_i[key][np.newaxis, :]
                        if isinstance(feed_i[key], np.ndarray):
                            shape_lst.extend(list(feed_i[key].shape))
                            float_shape.append(shape_lst)
                        else:
                            float_shape.append(self.feed_shapes_[key])
                        if "{}.lod".format(key) in feed_i:
                            float_lod_slot_batch.append(feed_i["{}.lod".format(
                                key)])
                        else:
                            float_lod_slot_batch.append([])

                    if isinstance(feed_i[key], np.ndarray):
                        float_slot.append(feed_i[key])
                        self.has_numpy_input = True
                    else:
                        float_slot.append(feed_i[key])
                        self.all_numpy_input = False
            int_slot_batch.append(int_slot)
            float_slot_batch.append(float_slot)
            int_lod_slot_batch.append(int_lod_slot)
            float_lod_slot_batch.append(float_lod_slot)

        self.profile_.record('py_prepro_1')
        self.profile_.record('py_client_infer_0')

        result_batch_handle = self.predictorres_constructor()
        # Only the all-numpy path is supported by the native client.
        if self.all_numpy_input:
            res = self.client_handle_.numpy_predict(
                float_slot_batch, float_feed_names, float_shape,
                float_lod_slot_batch, int_slot_batch, int_feed_names, int_shape,
                int_lod_slot_batch, fetch_names, result_batch_handle, self.pid,
                log_id)
        elif self.has_numpy_input == False:
            raise ValueError(
                "Please make sure all of your inputs are numpy array")
        else:
            raise ValueError(
                "Please make sure the inputs are all in list type or all in numpy.array type"
            )

        self.profile_.record('py_client_infer_1')
        self.profile_.record('py_postpro_0')

        if res == -1:
            return None

        # Unpack one result_map per model engine, converting each fetched
        # tensor to a shaped numpy array plus optional "<name>.lod" entry.
        multi_result_map = []
        model_engine_names = result_batch_handle.get_engine_names()
        for mi, engine_name in enumerate(model_engine_names):
            result_map = {}
            # result map needs to be a numpy array
            for i, name in enumerate(fetch_names):
                if self.fetch_names_to_type_[name] == int64_type:
                    # result_map[name] will be py::array(numpy array)
                    result_map[name] = result_batch_handle.get_int64_by_name(
                        mi, name)
                    shape = result_batch_handle.get_shape(mi, name)
                    if result_map[name].size == 0:
                        raise ValueError(
                            "Failed to fetch, maybe the type of [{}]"
                            " is wrong, please check the model file".format(
                                name))
                    result_map[name].shape = shape
                    if name in self.lod_tensor_set:
                        tmp_lod = result_batch_handle.get_lod(mi, name)
                        if np.size(tmp_lod) > 0:
                            result_map["{}.lod".format(name)] = tmp_lod
                elif self.fetch_names_to_type_[name] == float32_type:
                    result_map[name] = result_batch_handle.get_float_by_name(
                        mi, name)
                    if result_map[name].size == 0:
                        raise ValueError(
                            "Failed to fetch, maybe the type of [{}]"
                            " is wrong, please check the model file".format(
                                name))
                    shape = result_batch_handle.get_shape(mi, name)
                    result_map[name].shape = shape
                    if name in self.lod_tensor_set:
                        tmp_lod = result_batch_handle.get_lod(mi, name)
                        if np.size(tmp_lod) > 0:
                            result_map["{}.lod".format(name)] = tmp_lod
                elif self.fetch_names_to_type_[name] == int32_type:
                    # result_map[name] will be py::array(numpy array)
                    result_map[name] = result_batch_handle.get_int32_by_name(
                        mi, name)
                    if result_map[name].size == 0:
                        raise ValueError(
                            "Failed to fetch, maybe the type of [{}]"
                            " is wrong, please check the model file".format(
                                name))
                    shape = result_batch_handle.get_shape(mi, name)
                    result_map[name].shape = shape
                    if name in self.lod_tensor_set:
                        tmp_lod = result_batch_handle.get_lod(mi, name)
                        if np.size(tmp_lod) > 0:
                            result_map["{}.lod".format(name)] = tmp_lod
            multi_result_map.append(result_map)
        ret = None
        if len(model_engine_names) == 1:
            # If only one model result is returned, the format of ret is result_map
            ret = multi_result_map[0]
        else:
            # If multiple model results are returned, the format of ret is {name: result_map}
            ret = {
                engine_name: multi_result_map[mi]
                for mi, engine_name in enumerate(model_engine_names)
            }

        self.profile_.record('py_postpro_1')
        self.profile_.print_profile()

        # When using the A/B test, the tag of variant needs to be returned
        return ret if not need_variant_tag else [
            ret, result_batch_handle.variant_tag()
        ]
B
barrierye 已提交
469

470 471
    def release(self):
        """Destroy the native predictor and drop the handle."""
        self.client_handle_.destroy_predictor()
        self.client_handle_ = None
B
barrierye 已提交
473 474


475
class MultiLangClient(object):
B
barrierye 已提交
476 477
    def __init__(self):
        # gRPC channel and stub; created in connect().
        self.channel_ = None
        self.stub_ = None
        # Client-side RPC timeout in seconds (server side is updated via
        # set_rpc_timeout_ms()).
        self.rpc_timeout_s_ = 2
        self.profile_ = _Profiler()
B
barrierye 已提交
481

B
barrierye 已提交
482 483
    def add_variant(self, tag, cluster, variant_weight):
        """A/B testing is not implemented for the multi-language client."""
        # TODO
        raise Exception("cannot support ABtest yet")
B
barrierye 已提交
485

B
barrierye 已提交
486
    def set_rpc_timeout_ms(self, rpc_timeout):
        """Set the RPC timeout (milliseconds) on both client and server.

        Must be called after connect().

        Returns:
            True when the server accepted the new timeout (err_code == 0).

        Raises:
            Exception: if called before connect().
            ValueError: if ``rpc_timeout`` is not an int.
        """
        if self.stub_ is None:
            raise Exception("set timeout must be set after connect.")
        if not isinstance(rpc_timeout, int):
            # for bclient
            raise ValueError("rpc_timeout must be int type.")
        self.rpc_timeout_s_ = rpc_timeout / 1000.0
        timeout_req = multi_lang_general_model_service_pb2.SetTimeoutRequest()
        timeout_req.timeout_ms = rpc_timeout
        resp = self.stub_.SetTimeout(timeout_req)
        return resp.err_code == 0
B
barrierye 已提交
497 498

    def connect(self, endpoints):
        """Open an insecure gRPC channel to ``endpoints`` and pull the
        client model configuration from the server.

        Args:
            endpoints: list of "ip:port" strings, load-balanced round-robin.
        """
        # https://github.com/tensorflow/serving/issues/1382
        options = [('grpc.max_receive_message_length', 512 * 1024 * 1024),
                   ('grpc.max_send_message_length', 512 * 1024 * 1024),
                   ('grpc.lb_policy_name', 'round_robin')]
        # TODO: weight round robin
        g_endpoint = 'ipv4:{}'.format(','.join(endpoints))
        self.channel_ = grpc.insecure_channel(g_endpoint, options=options)
        self.stub_ = multi_lang_general_model_service_pb2_grpc.MultiLangGeneralModelServiceStub(
            self.channel_)
        # get client model config
        get_client_config_req = multi_lang_general_model_service_pb2.GetClientConfigRequest(
        )
        resp = self.stub_.GetClientConfig(get_client_config_req)
        model_config_str = resp.client_config_str
        self._parse_model_config(model_config_str)
B
barrierye 已提交
514

B
barrierye 已提交
515 516 517 518 519 520 521 522
    def _flatten_list(self, nested_list):
        for item in nested_list:
            if isinstance(item, (list, tuple)):
                for sub_item in self._flatten_list(item):
                    yield sub_item
            else:
                yield item

523
    def _parse_model_config(self, model_config_str):
        """Populate feed/fetch metadata tables from a prototxt config string."""
        model_conf = m_config.GeneralModelConfig()
        model_conf = google.protobuf.text_format.Merge(model_config_str,
                                                       model_conf)
        self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
        self.feed_types_ = {}
        self.feed_shapes_ = {}
        self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
        self.fetch_types_ = {}
        self.lod_tensor_set_ = set()
        for i, var in enumerate(model_conf.feed_var):
            self.feed_types_[var.alias_name] = var.feed_type
            self.feed_shapes_[var.alias_name] = var.shape
            if var.is_lod_tensor:
                self.lod_tensor_set_.add(var.alias_name)
            else:
                # NOTE(review): this product is computed but never stored —
                # looks like dead code carried over from
                # Client.load_client_config; confirm before removing.
                counter = 1
                for dim in self.feed_shapes_[var.alias_name]:
                    counter *= dim
        for i, var in enumerate(model_conf.fetch_var):
            self.fetch_types_[var.alias_name] = var.fetch_type
            if var.is_lod_tensor:
                self.lod_tensor_set_.add(var.alias_name)
B
barrierye 已提交
546

B
barriery 已提交
547
    def _pack_inference_request(self, feed, fetch, is_python, log_id):
        """Build an InferenceRequest protobuf from feed/fetch arguments.

        Args:
            feed: dict (one sample) or list of dicts (a batch) mapping feed
                names to list/ndarray values.
            fetch: list of fetch variable names.
            is_python: when True, tensors are shipped as raw bytes
                (ndarray.tobytes()); otherwise as typed repeated fields.
            log_id: id forwarded to the server for request tracing.

        Returns:
            A populated multi_lang_general_model_service_pb2.InferenceRequest.

        Raises:
            Exception: on an unsupported feed container, an unknown feed
                type code, or a value that is neither list nor ndarray.
        """
        req = multi_lang_general_model_service_pb2.InferenceRequest()
        req.fetch_var_names.extend(fetch)
        req.is_python = is_python
        req.log_id = log_id
        feed_batch = None
        if isinstance(feed, dict):
            feed_batch = [feed]
        elif isinstance(feed, list):
            feed_batch = feed
        else:
            raise Exception("{} not support".format(type(feed)))
        # Feed names are taken from the first sample; all samples in the
        # batch are expected to share the same keys.
        req.feed_var_names.extend(feed_batch[0].keys())
        init_feed_names = False
        for feed_data in feed_batch:
            inst = multi_lang_general_model_service_pb2.FeedInst()
            for name in req.feed_var_names:
                tensor = multi_lang_general_model_service_pb2.Tensor()
                var = feed_data[name]
                v_type = self.feed_types_[name]
                if is_python:
                    # Python peer: serialize as raw bytes after coercing to
                    # the dtype declared in the client config.
                    data = None
                    if isinstance(var, list):
                        if v_type == 0:  # int64
                            data = np.array(var, dtype="int64")
                        elif v_type == 1:  # float32
                            data = np.array(var, dtype="float32")
                        elif v_type == 2:  # int32
                            data = np.array(var, dtype="int32")
                        else:
                            raise Exception("error tensor value type.")
                    elif isinstance(var, np.ndarray):
                        data = var
                        if v_type == 0:
                            if data.dtype != 'int64':
                                data = data.astype("int64")
                        elif v_type == 1:
                            if data.dtype != 'float32':
                                data = data.astype("float32")
                        elif v_type == 2:
                            if data.dtype != 'int32':
                                data = data.astype("int32")
                        else:
                            raise Exception("error tensor value type.")
                    else:
                        raise Exception("var must be list or ndarray.")
                    tensor.data = data.tobytes()
                else:
                    # Non-Python peer: ship values in the typed repeated
                    # fields of the Tensor message.
                    if isinstance(var, np.ndarray):
                        if v_type == 0:  # int64
                            tensor.int64_data.extend(
                                var.reshape(-1).astype("int64").tolist())
                        elif v_type == 1:
                            tensor.float_data.extend(
                                var.reshape(-1).astype('float32').tolist())
                        elif v_type == 2:
                            tensor.int_data.extend(
                                var.reshape(-1).astype('int32').tolist())
                        else:
                            raise Exception("error tensor value type.")
                    elif isinstance(var, list):
                        if v_type == 0:
                            tensor.int64_data.extend(self._flatten_list(var))
                        elif v_type == 1:
                            tensor.float_data.extend(self._flatten_list(var))
                        elif v_type == 2:
                            tensor.int_data.extend(self._flatten_list(var))
                        else:
                            raise Exception("error tensor value type.")
                    else:
                        raise Exception("var must be list or ndarray.")
                if isinstance(var, np.ndarray):
                    tensor.shape.extend(list(var.shape))
                else:
                    # Lists fall back to the static shape from the config.
                    tensor.shape.extend(self.feed_shapes_[name])
                inst.tensor_array.append(tensor)
            req.insts.append(inst)
        return req
B
barrierye 已提交
625

626 627 628
    def _unpack_inference_response(self, resp, fetch, is_python,
                                   need_variant_tag):
        if resp.err_code != 0:
B
fix bug  
barrierye 已提交
629
            return None
B
barrierye 已提交
630
        tag = resp.tag
B
barrierye 已提交
631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646
        multi_result_map = {}
        for model_result in resp.outputs:
            inst = model_result.insts[0]
            result_map = {}
            for i, name in enumerate(fetch):
                var = inst.tensor_array[i]
                v_type = self.fetch_types_[name]
                if is_python:
                    if v_type == 0:  # int64
                        result_map[name] = np.frombuffer(
                            var.data, dtype="int64")
                    elif v_type == 1:  # float32
                        result_map[name] = np.frombuffer(
                            var.data, dtype="float32")
                    else:
                        raise Exception("error type.")
B
barrierye 已提交
647
                else:
B
barrierye 已提交
648 649 650 651 652 653 654 655 656 657 658 659 660 661
                    if v_type == 0:  # int64
                        result_map[name] = np.array(
                            list(var.int64_data), dtype="int64")
                    elif v_type == 1:  # float32
                        result_map[name] = np.array(
                            list(var.float_data), dtype="float32")
                    else:
                        raise Exception("error type.")
                result_map[name].shape = list(var.shape)
                if name in self.lod_tensor_set_:
                    result_map["{}.lod".format(name)] = np.array(list(var.lod))
            multi_result_map[model_result.engine_name] = result_map
        ret = None
        if len(resp.outputs) == 1:
B
barrierye 已提交
662
            ret = list(multi_result_map.values())[0]
B
barrierye 已提交
663 664
        else:
            ret = multi_result_map
B
barrierye 已提交
665

666
        ret["serving_status_code"] = 0
B
barrierye 已提交
667
        return ret if not need_variant_tag else [ret, tag]
668

B
barrierye 已提交
669
    def _done_callback_func(self, fetch, is_python, need_variant_tag):
670
        def unpack_resp(resp):
671 672
            return self._unpack_inference_response(resp, fetch, is_python,
                                                   need_variant_tag)
B
barrierye 已提交
673

674 675
        return unpack_resp

W
WangXi 已提交
676 677 678
    def get_feed_names(self):
        """Return the feed variable names of the loaded model config."""
        feed_names = self.feed_names_
        return feed_names

B
barrierye 已提交
679 680 681 682 683
    def predict(self,
                feed,
                fetch,
                need_variant_tag=False,
                asyn=False,
                is_python=True,
                log_id=0):
        """Run inference over the multi-language gRPC service.

        Args:
            feed: feed data (dict of var name -> value, or list of such).
            fetch: list of fetch variable names.
            need_variant_tag: if True, results include the variant tag.
            asyn: if True, return a ``MultiLangPredictFuture`` immediately
                instead of blocking on the RPC.
            is_python: pack tensors as raw bytes (numpy path) rather than
                typed repeated fields.
            log_id: request log id forwarded to the server.

        Returns:
            The unpacked result dict (sync), a ``MultiLangPredictFuture``
            (async), or ``{"serving_status_code": <grpc code>}`` when the
            synchronous RPC fails.
        """
        if asyn:
            # Asynchronous path: fire the RPC and hand back a future that
            # unpacks the response lazily via the done-callback.
            request = self._pack_inference_request(
                feed, fetch, is_python=is_python, log_id=log_id)
            call_future = self.stub_.Inference.future(
                request, timeout=self.rpc_timeout_s_)
            unpack_cb = self._done_callback_func(
                fetch,
                is_python=is_python,
                need_variant_tag=need_variant_tag)
            return MultiLangPredictFuture(call_future, unpack_cb)
        # Synchronous path, instrumented with the client-side profiler.
        try:
            self.profile_.record('py_prepro_0')
            request = self._pack_inference_request(
                feed, fetch, is_python=is_python, log_id=log_id)
            self.profile_.record('py_prepro_1')

            self.profile_.record('py_client_infer_0')
            response = self.stub_.Inference(
                request, timeout=self.rpc_timeout_s_)
            self.profile_.record('py_client_infer_1')

            self.profile_.record('py_postpro_0')
            result = self._unpack_inference_response(
                response,
                fetch,
                is_python=is_python,
                need_variant_tag=need_variant_tag)
            self.profile_.record('py_postpro_1')
            self.profile_.print_profile()
            return result
        except grpc.RpcError as e:
            # Surface RPC failures as a status dict rather than raising.
            return {"serving_status_code": e.code()}
719 720 721 722 723 724 725 726


class MultiLangPredictFuture(object):
    """Wrapper around a gRPC call future that unpacks the response.

    Bridges the raw gRPC future API and the client's result format:
    ``result()`` blocks and returns the unpacked prediction, and
    ``add_done_callback`` re-targets completion callbacks at this wrapper
    instead of the underlying gRPC future.
    """

    def __init__(self, call_future, callback_func):
        self.call_future_ = call_future
        self.callback_func_ = callback_func

    def result(self):
        """Block until the RPC finishes and return the unpacked result.

        On RPC failure, returns ``{"serving_status_code": <grpc code>}``
        instead of raising.
        """
        try:
            raw_response = self.call_future_.result()
        except grpc.RpcError as e:
            return {"serving_status_code": e.code()}
        return self.callback_func_(raw_response)

    def add_done_callback(self, fn):
        """Register ``fn`` to be called with this wrapper on completion."""

        def _on_done(call_future):
            # The gRPC layer hands back its own future; check it is ours,
            # then surface the wrapper object to the user's callback.
            assert call_future == self.call_future_
            fn(self)

        self.call_future_.add_done_callback(_on_done)