__init__.py 29.9 KB
Newer Older
G
guru4elephant 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
# pylint: disable=doc-string-missing
G
guru4elephant 已提交
15

16
import paddle_serving_client
M
MRXLT 已提交
17
import os
18 19 20 21
from .proto import sdk_configure_pb2 as sdk
from .proto import general_model_config_pb2 as m_config
import google.protobuf.text_format
import numpy as np
H
HexToString 已提交
22 23 24
import requests
import json
import base64
D
dongdaxiang 已提交
25
import time
26
import sys
G
guru4elephant 已提交
27

B
barrierye 已提交
28
import grpc
B
barrierye 已提交
29
from .proto import multi_lang_general_model_service_pb2
B
barrierye 已提交
30 31
sys.path.append(
    os.path.join(os.path.abspath(os.path.dirname(__file__)), 'proto'))
B
barrierye 已提交
32
from .proto import multi_lang_general_model_service_pb2_grpc
B
barrierye 已提交
33

M
MRXLT 已提交
34 35 36
# Wire-format type codes for feed/fetch tensors.
# NOTE(review): these must stay in sync with the C++ serving client — confirm
# against the serving_client extension before changing.
int64_type = 0
float32_type = 1
int32_type = 2
bytes_type = 3
# Groupings used to route each feed value into the matching slot batch.
int_type = {int64_type, int32_type}
float_type = {float32_type}
string_type = {bytes_type}
G
guru4elephant 已提交
41

M
MRXLT 已提交
42

W
WangXi 已提交
43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70
class _NOPProfiler(object):
    """Do-nothing profiler used when client profiling is disabled."""

    def record(self, name):
        # Intentionally a no-op.
        pass

    def print_profile(self):
        # Intentionally a no-op.
        pass


class _TimeProfiler(object):
    """Accumulates named microsecond timestamps and flushes them to stderr."""

    def __init__(self):
        self.pid = os.getpid()
        self.print_head = 'PROFILE\tpid:{}\t'.format(self.pid)
        self.time_record = [self.print_head]

    def record(self, name):
        # Wall-clock timestamp in whole microseconds since the epoch.
        stamp_us = int(round(time.time() * 1000000))
        self.time_record.append('{}:{} '.format(name, stamp_us))

    def print_profile(self):
        # Emit one line per profiling window, then reset the buffer.
        self.time_record.append('\n')
        sys.stderr.write(''.join(self.time_record))
        self.time_record = [self.print_head]


# Profiling is opt-in via the FLAGS_profile_client environment variable.
_is_profile = int(os.environ.get('FLAGS_profile_client', 0))
_Profiler = _TimeProfiler if _is_profile else _NOPProfiler


G
guru4elephant 已提交
71 72 73
class SDKConfig(object):
    """Builds the SDK configuration proto consumed by the C++ predictor.

    Variants (tag + endpoint cluster + A/B weight) are accumulated via
    add_server_variant() and serialized into an sdk.SDKConf by gen_desc().
    """

    def __init__(self):
        self.sdk_desc = sdk.SDKConf()
        # Parallel lists: one entry per registered variant.
        self.tag_list = []
        self.cluster_list = []
        self.variant_weight_list = []
        self.rpc_timeout_ms = 20000
        self.load_balance_strategy = "la"

    def add_server_variant(self, tag, cluster, variant_weight):
        """Register one variant: its tag, endpoint list and A/B weight."""
        self.tag_list.append(tag)
        self.cluster_list.append(cluster)
        self.variant_weight_list.append(variant_weight)

    def set_load_banlance_strategy(self, strategy):
        # Kept under its historical misspelling for backward compatibility;
        # prefer set_load_balance_strategy().
        self.set_load_balance_strategy(strategy)

    def set_load_balance_strategy(self, strategy):
        """Set the naming-service load balance strategy (e.g. "la")."""
        self.load_balance_strategy = strategy

    def gen_desc(self, rpc_timeout_ms):
        """Serialize all accumulated settings into an sdk.SDKConf proto.

        Args:
            rpc_timeout_ms: per-request RPC timeout in milliseconds.

        Returns:
            The populated sdk.SDKConf message.
        """
        predictor_desc = sdk.Predictor()
        predictor_desc.name = "general_model"
        predictor_desc.service_name = \
            "baidu.paddle_serving.predictor.general_model.GeneralModelService"
        predictor_desc.endpoint_router = "WeightedRandomRender"
        predictor_desc.weighted_random_render_conf.variant_weight_list = "|".join(
            self.variant_weight_list)

        for idx, tag in enumerate(self.tag_list):
            variant_desc = sdk.VariantConf()
            variant_desc.tag = tag
            variant_desc.naming_conf.cluster = "list://{}".format(",".join(
                self.cluster_list[idx]))
            predictor_desc.variants.extend([variant_desc])

        self.sdk_desc.predictors.extend([predictor_desc])
        self.sdk_desc.default_variant_conf.tag = "default"
        self.sdk_desc.default_variant_conf.connection_conf.connect_timeout_ms = 2000
        self.sdk_desc.default_variant_conf.connection_conf.rpc_timeout_ms = rpc_timeout_ms
        self.sdk_desc.default_variant_conf.connection_conf.connect_retry_count = 2
        self.sdk_desc.default_variant_conf.connection_conf.max_connection_per_host = 100
        self.sdk_desc.default_variant_conf.connection_conf.hedge_request_timeout_ms = -1
        self.sdk_desc.default_variant_conf.connection_conf.hedge_fetch_retry_count = 2
        self.sdk_desc.default_variant_conf.connection_conf.connection_type = "pooled"

        self.sdk_desc.default_variant_conf.naming_conf.cluster_filter_strategy = "Default"
        # Bug fix: honor the configured strategy instead of hard-coding "la",
        # which made set_load_banlance_strategy() a silent no-op. The default
        # value is still "la", so existing behavior is unchanged.
        self.sdk_desc.default_variant_conf.naming_conf.load_balance_strategy = \
            self.load_balance_strategy

        self.sdk_desc.default_variant_conf.rpc_parameter.compress_type = 0
        self.sdk_desc.default_variant_conf.rpc_parameter.package_size = 20
        self.sdk_desc.default_variant_conf.rpc_parameter.protocol = "baidu_std"
        self.sdk_desc.default_variant_conf.rpc_parameter.max_channel_per_request = 3

        return self.sdk_desc
G
guru4elephant 已提交
123

G
guru4elephant 已提交
124 125 126 127 128 129

class Client(object):
    def __init__(self):
        """Create an empty client; call load_client_config() before use."""
        # Model signature, filled in by load_client_config().
        self.feed_names_ = []
        self.fetch_names_ = []
        self.feed_shapes_ = {}
        self.feed_types_ = {}
        self.feed_names_to_idx_ = {}
        # Underlying C++ predictor handle, created by connect().
        self.client_handle_ = None
        self.pid = os.getpid()
        # A/B-test SDK configuration, set up by add_variant()/connect().
        self.predictor_sdk_ = None
        self.producers = []
        self.consumer = None
        self.profile_ = _Profiler()
        # Track whether inputs are numpy arrays; predict() rejects a mix of
        # numpy and plain-list inputs.
        self.all_numpy_input = True
        self.has_numpy_input = False
        self.rpc_timeout_ms = 20000
        from .serving_client import PredictorRes
        self.predictorres_constructor = PredictorRes
M
MRXLT 已提交
143

G
guru4elephant 已提交
144
    def load_client_config(self, path):
        """Parse client model config file(s) and initialize the C++ client.

        Args:
            path: a prototxt path (str) or a list of paths. The first file
                provides the feed-variable description; the last file
                provides the fetch-variable description.

        Raises:
            ValueError: if path is neither a str nor a list.
        """
        if isinstance(path, str):
            path_list = [path]
        elif isinstance(path, list):
            path_list = path
        else:
            # Bug fix: an unsupported type previously fell through and
            # crashed later with UnboundLocalError; fail fast instead.
            raise ValueError("path must be a str or a list of str.")

        from .serving_client import PredictorClient
        model_conf = m_config.GeneralModelConfig()
        # Bug fix: use a context manager so the config file is closed
        # deterministically (the original leaked the handle).
        with open(path_list[0], 'r') as f:
            model_conf = google.protobuf.text_format.Merge(
                str(f.read()), model_conf)

        # load configuraion here
        # get feed vars, fetch vars
        # get feed shapes, feed types
        # map feed names to index
        self.client_handle_ = PredictorClient()
        self.client_handle_.init(path_list)
        if "FLAGS_max_body_size" not in os.environ:
            os.environ["FLAGS_max_body_size"] = str(512 * 1024 * 1024)
        read_env_flags = ["profile_client", "profile_server", "max_body_size"]
        self.client_handle_.init_gflags([sys.argv[
            0]] + ["--tryfromenv=" + ",".join(read_env_flags)])

        self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
        self.feed_names_to_idx_ = {}
        self.lod_tensor_set = set()
        self.feed_tensor_len = {}
        self.key = None
        for i, var in enumerate(model_conf.feed_var):
            self.feed_names_to_idx_[var.alias_name] = i
            self.feed_types_[var.alias_name] = var.feed_type
            self.feed_shapes_[var.alias_name] = var.shape

            if var.is_lod_tensor:
                self.lod_tensor_set.add(var.alias_name)
            else:
                # Dense tensors: pre-compute the flattened element count so
                # shape_check() can validate list inputs cheaply.
                counter = 1
                for dim in self.feed_shapes_[var.alias_name]:
                    counter *= dim
                self.feed_tensor_len[var.alias_name] = counter

        # Multi-file configs: fetch variables come from the last file.
        if len(path_list) > 1:
            model_conf = m_config.GeneralModelConfig()
            with open(path_list[-1], 'r') as f:
                model_conf = google.protobuf.text_format.Merge(
                    str(f.read()), model_conf)
        self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
        self.fetch_names_to_type_ = {}
        self.fetch_names_to_idx_ = {}
        for i, var in enumerate(model_conf.fetch_var):
            self.fetch_names_to_idx_[var.alias_name] = i
            self.fetch_names_to_type_[var.alias_name] = var.fetch_type
            if var.is_lod_tensor:
                self.lod_tensor_set.add(var.alias_name)
        return

H
HexToString 已提交
201

202
    def add_variant(self, tag, cluster, variant_weight):
        """Register an A/B-test variant; must be called before connect()."""
        # Lazily create the SDK config on the first registered variant.
        if self.predictor_sdk_ is None:
            self.predictor_sdk_ = SDKConfig()
        self.predictor_sdk_.add_server_variant(tag, cluster,
                                               str(variant_weight))

M
MRXLT 已提交
208 209 210 211 212 213
    def set_rpc_timeout_ms(self, rpc_timeout):
        if not isinstance(rpc_timeout, int):
            raise ValueError("rpc_timeout must be int type.")
        else:
            self.rpc_timeout_ms = rpc_timeout

H
HexToString 已提交
214
    def use_key(self, key_filename):
215
        with open(key_filename, "rb") as f:
H
HexToString 已提交
216 217 218 219
            self.key = f.read()

    def get_serving_port(self, endpoints):
        """Ask the key server at endpoints[0] for the real serving port.

        Used in encryption mode: posts the (optional) base64-encoded key and
        rewrites the endpoint list with the port the server returns.

        Raises:
            ValueError: if the server reply has no "endpoint_list".
        """
        if self.key is not None:
            payload = {"key": base64.b64encode(self.key).decode()}
        else:
            payload = {}
        req = json.dumps(payload)
        r = requests.post("http://" + endpoints[0], req)
        result = r.json()
        print(result)
        if "endpoint_list" not in result:
            raise ValueError("server not ready")
        host = endpoints[0].split(":")[0]
        return [host + ":" + str(result["endpoint_list"][0])]

    def connect(self, endpoints=None, encryption=False):
        """Create the underlying predictor from endpoints or registered variants.

        Either pass `endpoints` here, or register variants beforehand via
        add_variant() (in which case `endpoints` is ignored). With
        encryption=True the real ports are first fetched from the key server.

        Raises:
            ValueError: if neither endpoints nor a variant was provided.
        """
        # check whether current endpoint is available
        # init from client config
        # create predictor here
        if endpoints is None:
            if self.predictor_sdk_ is None:
                raise ValueError(
                    "You must set the endpoints parameter or use add_variant function to create a variant."
                )
        else:
            if encryption:
                # Resolve the real serving ports from the key server first.
                endpoints = self.get_serving_port(endpoints)
            if self.predictor_sdk_ is None:
                self.add_variant('default_tag_{}'.format(id(self)), endpoints,
                                 100)
            else:
                print(
                    "parameter endpoints({}) will not take effect, because you use the add_variant function.".
                    format(endpoints))
        sdk_desc = self.predictor_sdk_.gen_desc(self.rpc_timeout_ms)
        self.client_handle_.create_predictor_by_desc(sdk_desc.SerializeToString(
        ))
G
guru4elephant 已提交
257 258 259 260 261 262 263

    def get_feed_names(self):
        return self.feed_names_

    def get_fetch_names(self):
        return self.fetch_names_

M
MRXLT 已提交
264 265 266
    def shape_check(self, feed, key):
        if key in self.lod_tensor_set:
            return
M
MRXLT 已提交
267 268
        if isinstance(feed[key],
                      list) and len(feed[key]) != self.feed_tensor_len[key]:
M
MRXLT 已提交
269
            raise ValueError("The shape of feed tensor {} not match.".format(
M
MRXLT 已提交
270 271 272
                key))
        if type(feed[key]).__module__ == np.__name__ and np.size(feed[
                key]) != self.feed_tensor_len[key]:
M
MRXLT 已提交
273 274 275
            #raise SystemExit("The shape of feed tensor {} not match.".format(
            #    key))
            pass
M
MRXLT 已提交
276

W
wangjiawei04 已提交
277 278 279 280 281 282
    def predict(self,
                feed=None,
                fetch=None,
                batch=False,
                need_variant_tag=False,
                log_id=0):
        """Run one synchronous inference request.

        Args:
            feed: a dict (single sample) or a list of dicts (a batch). Keys
                are feed variable names; "<name>.lod" keys carry the LoD
                info for the matching variable.
            fetch: a fetch variable name or a list of names.
            batch: when False, a leading batch axis is added to each dense
                input on the first sample (the feed dict is modified in
                place).
            need_variant_tag: when True, return [result, variant_tag].
            log_id: request id forwarded to the server for tracing.

        Returns:
            A dict of fetched numpy arrays (or {engine_name: dict} when the
            server runs multiple models), None on RPC failure, or
            [result, tag] when need_variant_tag is True.

        Raises:
            ValueError: on missing/invalid feed or fetch arguments, unknown
                feed names, empty fetch list, or mixed list/numpy inputs.
        """
        self.profile_.record('py_prepro_0')

        if feed is None or fetch is None:
            raise ValueError("You should specify feed and fetch for prediction")

        if isinstance(fetch, str):
            fetch_list = [fetch]
        elif isinstance(fetch, list):
            fetch_list = fetch
        else:
            raise ValueError("Fetch only accepts string and list of string")

        if isinstance(feed, dict):
            feed_batch = [feed]
        elif isinstance(feed, list):
            feed_batch = feed
        else:
            raise ValueError("Feed only accepts dict and list of dict")

        # Per-dtype accumulators. Names/shapes/LoDs are recorded once (on
        # the first sample only); slots are collected per sample.
        int_slot_batch = []
        int_feed_names = []
        int_shape = []
        int_lod_slot_batch = []
        float_slot_batch = []
        float_feed_names = []
        float_lod_slot_batch = []
        float_shape = []
        string_slot_batch = []
        string_feed_names = []
        string_lod_slot_batch = []
        string_shape = []

        # Keep only fetch names known to the loaded model config.
        fetch_names = [key for key in fetch_list if key in self.fetch_names_]
        if len(fetch_names) == 0:
            # (The original also had an unreachable `return {}` after this
            # raise; it has been removed.)
            raise ValueError(
                "Fetch names should not be empty or out of saved fetch list.")

        for i, feed_i in enumerate(feed_batch):
            int_slot = []
            int_lod_slot = []
            float_slot = []
            float_lod_slot = []
            string_slot = []
            string_lod_slot = []
            for key in feed_i:
                if ".lod" not in key and key not in self.feed_names_:
                    raise ValueError("Wrong feed name: {}.".format(key))
                if ".lod" in key:
                    # LoD entries are consumed alongside their variable.
                    continue
                self.shape_check(feed_i, key)
                if self.feed_types_[key] in int_type:
                    if i == 0:
                        int_feed_names.append(key)
                        if not batch:
                            # Add the batch axis for single-sample input.
                            # NOTE(review): this indexing assumes a numpy
                            # input; a plain list would fail here — preserved
                            # from the original behavior.
                            feed_i[key] = feed_i[key][np.newaxis, :]
                        if isinstance(feed_i[key], np.ndarray):
                            int_shape.append(list(feed_i[key].shape))
                        else:
                            int_shape.append(self.feed_shapes_[key])
                        int_lod_slot_batch.append(
                            feed_i.get("{}.lod".format(key), []))
                    int_slot.append(feed_i[key])
                    if isinstance(feed_i[key], np.ndarray):
                        self.has_numpy_input = True
                    else:
                        self.all_numpy_input = False

                elif self.feed_types_[key] in float_type:
                    if i == 0:
                        float_feed_names.append(key)
                        if not batch:
                            # See the note on the int branch above.
                            feed_i[key] = feed_i[key][np.newaxis, :]
                        if isinstance(feed_i[key], np.ndarray):
                            float_shape.append(list(feed_i[key].shape))
                        else:
                            float_shape.append(self.feed_shapes_[key])
                        float_lod_slot_batch.append(
                            feed_i.get("{}.lod".format(key), []))
                    float_slot.append(feed_i[key])
                    if isinstance(feed_i[key], np.ndarray):
                        self.has_numpy_input = True
                    else:
                        self.all_numpy_input = False
                #if input is string, feed is not numpy.
                elif self.feed_types_[key] in string_type:
                    if i == 0:
                        string_feed_names.append(key)
                        string_shape.append(self.feed_shapes_[key])
                        string_lod_slot_batch.append(
                            feed_i.get("{}.lod".format(key), []))
                    string_slot.append(feed_i[key])
                    self.has_numpy_input = True

            int_slot_batch.append(int_slot)
            int_lod_slot_batch.append(int_lod_slot)
            float_slot_batch.append(float_slot)
            float_lod_slot_batch.append(float_lod_slot)
            string_slot_batch.append(string_slot)
            string_lod_slot_batch.append(string_lod_slot)

        self.profile_.record('py_prepro_1')
        self.profile_.record('py_client_infer_0')

        result_batch_handle = self.predictorres_constructor()
        if self.all_numpy_input:
            res = self.client_handle_.numpy_predict(
                float_slot_batch, float_feed_names, float_shape,
                float_lod_slot_batch, int_slot_batch, int_feed_names, int_shape,
                int_lod_slot_batch, string_slot_batch, string_feed_names,
                string_shape, string_lod_slot_batch, fetch_names,
                result_batch_handle, self.pid, log_id)
        elif not self.has_numpy_input:
            raise ValueError(
                "Please make sure all of your inputs are numpy array")
        else:
            raise ValueError(
                "Please make sure the inputs are all in list type or all in numpy.array type"
            )

        self.profile_.record('py_client_infer_1')
        self.profile_.record('py_postpro_0')

        if res == -1:
            return None

        model_engine_names = result_batch_handle.get_engine_names()
        multi_result_map = []
        for mi, engine_name in enumerate(model_engine_names):
            # result map needs to be a numpy array
            result_map = {}
            for name in fetch_names:
                arr, lod = self._unpack_fetch_var(result_batch_handle, mi,
                                                  name)
                if arr is None:
                    # Unknown fetch type: silently skipped, as before.
                    continue
                result_map[name] = arr
                if lod is not None:
                    result_map["{}.lod".format(name)] = lod
            multi_result_map.append(result_map)

        if len(model_engine_names) == 1:
            # If only one model result is returned, ret is the result_map.
            ret = multi_result_map[0]
        else:
            # Multiple model results: ret maps engine name -> result_map.
            ret = {
                engine_name: multi_result_map[mi]
                for mi, engine_name in enumerate(model_engine_names)
            }

        self.profile_.record('py_postpro_1')
        self.profile_.print_profile()

        # When using the A/B test, the tag of variant needs to be returned.
        return ret if not need_variant_tag else [
            ret, result_batch_handle.variant_tag()
        ]

    def _unpack_fetch_var(self, result_batch_handle, mi, name):
        """Read one fetched tensor (and its LoD) for model index `mi`.

        Consolidates the previously triplicated int64/float32/int32 unpack
        logic. Returns (array, lod) where lod is None when absent, or
        (None, None) for fetch types this client does not know.

        Raises:
            ValueError: when the fetched array is empty (usually a type
                mismatch with the model file).
        """
        v_type = self.fetch_names_to_type_[name]
        if v_type == int64_type:
            arr = result_batch_handle.get_int64_by_name(mi, name)
        elif v_type == float32_type:
            arr = result_batch_handle.get_float_by_name(mi, name)
        elif v_type == int32_type:
            arr = result_batch_handle.get_int32_by_name(mi, name)
        else:
            return None, None
        # An empty result means the server-side type mismatched.
        if arr.size == 0:
            raise ValueError("Failed to fetch, maybe the type of [{}]"
                             " is wrong, please check the model file".format(
                                 name))
        arr.shape = result_batch_handle.get_shape(mi, name)
        lod = None
        if name in self.lod_tensor_set:
            tmp_lod = result_batch_handle.get_lod(mi, name)
            if np.size(tmp_lod) > 0:
                lod = tmp_lod
        return arr, lod
B
barrierye 已提交
505

506 507
    def release(self):
        self.client_handle_.destroy_predictor()
G
guru4elephant 已提交
508
        self.client_handle_ = None
B
barrierye 已提交
509 510


511
class MultiLangClient(object):
B
barrierye 已提交
512 513
    def __init__(self):
        """Create an unconnected gRPC client; call connect() before use."""
        # gRPC channel and service stub, created by connect().
        self.channel_ = None
        self.stub_ = None
        # Default RPC timeout: 2 seconds.
        self.rpc_timeout_s_ = 2
        self.profile_ = _Profiler()
B
barrierye 已提交
517

B
barrierye 已提交
518 519
    def add_variant(self, tag, cluster, variant_weight):
        # TODO
B
barrierye 已提交
520
        raise Exception("cannot support ABtest yet")
B
barrierye 已提交
521

B
barrierye 已提交
522
    def set_rpc_timeout_ms(self, rpc_timeout):
        """Set the RPC timeout (ms) locally and on the server.

        Must be called after connect(). Returns True when the server
        accepted the new timeout.

        Raises:
            Exception: if called before connect().
            ValueError: if rpc_timeout is not an int.
        """
        if self.stub_ is None:
            raise Exception("set timeout must be set after connect.")
        if not isinstance(rpc_timeout, int):
            # for bclient
            raise ValueError("rpc_timeout must be int type.")
        self.rpc_timeout_s_ = rpc_timeout / 1000.0
        timeout_req = multi_lang_general_model_service_pb2.SetTimeoutRequest()
        timeout_req.timeout_ms = rpc_timeout
        resp = self.stub_.SetTimeout(timeout_req)
        return resp.err_code == 0
B
barrierye 已提交
533 534

    def connect(self, endpoints):
        """Open a gRPC channel to `endpoints` and pull the client config.

        The server-side model config is fetched over the new channel and
        used to populate feed/fetch metadata via _parse_model_config().
        """
        # https://github.com/tensorflow/serving/issues/1382
        options = [('grpc.max_receive_message_length', 512 * 1024 * 1024),
                   ('grpc.max_send_message_length', 512 * 1024 * 1024),
                   ('grpc.lb_policy_name', 'round_robin')]
        # TODO: weight round robin
        g_endpoint = 'ipv4:{}'.format(','.join(endpoints))
        self.channel_ = grpc.insecure_channel(g_endpoint, options=options)
        self.stub_ = multi_lang_general_model_service_pb2_grpc.MultiLangGeneralModelServiceStub(
            self.channel_)
        # get client model config
        get_client_config_req = multi_lang_general_model_service_pb2.GetClientConfigRequest(
        )
        resp = self.stub_.GetClientConfig(get_client_config_req)
        self._parse_model_config(resp.client_config_str)
B
barrierye 已提交
550

B
barrierye 已提交
551 552 553 554 555 556 557 558
    def _flatten_list(self, nested_list):
        for item in nested_list:
            if isinstance(item, (list, tuple)):
                for sub_item in self._flatten_list(item):
                    yield sub_item
            else:
                yield item

559
    def _parse_model_config(self, model_config_str):
        """Populate feed/fetch metadata from a serialized GeneralModelConfig.

        Args:
            model_config_str: prototxt text of a GeneralModelConfig message,
                as returned by the server's GetClientConfig RPC.
        """
        model_conf = m_config.GeneralModelConfig()
        model_conf = google.protobuf.text_format.Merge(model_config_str,
                                                       model_conf)
        self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
        self.feed_types_ = {}
        self.feed_shapes_ = {}
        self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
        self.fetch_types_ = {}
        self.lod_tensor_set_ = set()
        for var in model_conf.feed_var:
            self.feed_types_[var.alias_name] = var.feed_type
            self.feed_shapes_[var.alias_name] = var.shape
            if var.is_lod_tensor:
                self.lod_tensor_set_.add(var.alias_name)
            # NOTE: the original computed a dense element count ("counter")
            # for non-LoD vars here but never stored it; that dead code has
            # been removed.
        for var in model_conf.fetch_var:
            self.fetch_types_[var.alias_name] = var.fetch_type
            if var.is_lod_tensor:
                self.lod_tensor_set_.add(var.alias_name)
B
barrierye 已提交
582

B
barriery 已提交
583
    def _pack_inference_request(self, feed, fetch, is_python, log_id):
584
        req = multi_lang_general_model_service_pb2.InferenceRequest()
B
barrierye 已提交
585
        req.fetch_var_names.extend(fetch)
B
barrierye 已提交
586
        req.is_python = is_python
B
barriery 已提交
587
        req.log_id = log_id
W
wangjiawei04 已提交
588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606
        feed_var_names = []
        for key in feed.keys():
            if '.lod' not in key:
                feed_var_names.append(key)
        req.feed_var_names.extend(feed_var_names)
        inst = multi_lang_general_model_service_pb2.FeedInst()
        for name in req.feed_var_names:
            tensor = multi_lang_general_model_service_pb2.Tensor()
            var = feed[name]
            v_type = self.feed_types_[name]
            if is_python:
                data = None
                if isinstance(var, list):
                    if v_type == 0:  # int64
                        data = np.array(var, dtype="int64")
                    elif v_type == 1:  # float32
                        data = np.array(var, dtype="float32")
                    elif v_type == 2:  # int32
                        data = np.array(var, dtype="int32")
B
barrierye 已提交
607
                    else:
W
wangjiawei04 已提交
608 609 610 611 612 613 614 615 616 617 618 619
                        raise Exception("error tensor value type.")
                elif isinstance(var, np.ndarray):
                    data = var
                    if v_type == 0:
                        if data.dtype != 'int64':
                            data = data.astype("int64")
                    elif v_type == 1:
                        if data.dtype != 'float32':
                            data = data.astype("float32")
                    elif v_type == 2:
                        if data.dtype != 'int32':
                            data = data.astype("int32")
B
barrierye 已提交
620
                    else:
W
wangjiawei04 已提交
621
                        raise Exception("error tensor value type.")
B
barrierye 已提交
622
                else:
W
wangjiawei04 已提交
623 624 625 626 627 628 629
                    raise Exception("var must be list or ndarray.")
                tensor.data = data.tobytes()
            tensor.shape.extend(list(var.shape))
            if "{}.lod".format(name) in feed.keys():
                tensor.lod.extend(feed["{}.lod".format(name)])
            inst.tensor_array.append(tensor)
        req.insts.append(inst)
        return req

    def _unpack_inference_response(self, resp, fetch, is_python,
                                   need_variant_tag):
        """Convert a gRPC inference response into a result dict.

        Args:
            resp: InferenceResponse proto message returned by the server.
            fetch: list of fetch variable names; order must match the
                tensors in each output's first inst.
            is_python: if True, tensor payloads are raw bytes decoded with
                numpy.frombuffer; otherwise typed repeated fields are read.
            need_variant_tag: if True, return [result, variant_tag].

        Returns:
            None when the server reports an error (resp.err_code != 0).
            Otherwise a dict mapping each fetch name to a numpy array
            (plus "<name>.lod" entries for lod tensors, and a
            "serving_status_code" of 0). When the response carries more
            than one model output, a dict of engine_name -> result dict
            is returned instead.

        Raises:
            Exception: for an unsupported fetch tensor type.
        """
        if resp.err_code != 0:
            # NOTE(review): callers must be prepared for a None result
            # on a server-side error.
            return None
        tag = resp.tag
        multi_result_map = {}
        for model_result in resp.outputs:
            inst = model_result.insts[0]
            result_map = {}
            for i, name in enumerate(fetch):
                var = inst.tensor_array[i]
                v_type = self.fetch_types_[name]
                if is_python:
                    if v_type == 0:  # int64
                        result_map[name] = np.frombuffer(
                            var.data, dtype="int64")
                    elif v_type == 1:  # float32
                        result_map[name] = np.frombuffer(
                            var.data, dtype="float32")
                    elif v_type == 2:  # int32; the pack side supports it too
                        result_map[name] = np.frombuffer(
                            var.data, dtype="int32")
                    else:
                        raise Exception("error type.")
                else:
                    if v_type == 0:  # int64
                        result_map[name] = np.array(
                            list(var.int64_data), dtype="int64")
                    elif v_type == 1:  # float32
                        result_map[name] = np.array(
                            list(var.float_data), dtype="float32")
                    elif v_type == 2:  # int32
                        # NOTE(review): assumes the Tensor proto exposes an
                        # int_data field for int32 values -- confirm against
                        # multi_lang_general_model_service.proto.
                        result_map[name] = np.array(
                            list(var.int_data), dtype="int32")
                    else:
                        raise Exception("error type.")
                # Restore the server-side shape on the flat decoded buffer.
                result_map[name].shape = list(var.shape)
                if name in self.lod_tensor_set_:
                    result_map["{}.lod".format(name)] = np.array(list(var.lod))
            multi_result_map[model_result.engine_name] = result_map
        ret = None
        if len(resp.outputs) == 1:
            # Single-model response: unwrap to the bare result dict.
            ret = list(multi_result_map.values())[0]
        else:
            ret = multi_result_map
        ret["serving_status_code"] = 0
        return ret if not need_variant_tag else [ret, tag]

    def _done_callback_func(self, fetch, is_python, need_variant_tag):
        """Return a callback that decodes an async inference response.

        The returned callable captures fetch/is_python/need_variant_tag
        and forwards a raw response to _unpack_inference_response.
        """
        return (lambda resp: self._unpack_inference_response(
            resp, fetch, is_python, need_variant_tag))

    def get_feed_names(self):
        return self.feed_names_

    def predict(self,
                feed,
                fetch,
                batch=True,
                need_variant_tag=False,
                asyn=False,
                is_python=True,
                log_id=0):
        """Run one inference request over gRPC.

        Args:
            feed: dict mapping feed names to values (plus optional
                "<name>.lod" entries carrying lod info).
            fetch: list of fetch variable names.
            batch: if False, each non-lod value gets a leading batch axis
                added before packing (values must support numpy indexing).
            need_variant_tag: if True, the result also carries the
                variant tag reported by the server.
            asyn: if True, return a MultiLangPredictFuture immediately
                instead of blocking on the RPC.
            is_python: send tensors as raw numpy bytes instead of typed
                proto fields.
            log_id: request identifier forwarded to the server.

        Returns:
            The decoded result dict (containing "serving_status_code"),
            or a MultiLangPredictFuture when asyn is True. On RPC failure
            a dict holding only "serving_status_code" is returned.

        Raises:
            ValueError: if feed is not a dict.
        """
        if isinstance(feed, dict) is False:
            raise ValueError("Type Error. grpc feed must be dict.")
        if batch is False:
            # Rebuild on a shallow copy so the caller's dict is not
            # mutated in place (the original implementation clobbered
            # the caller's feed values).
            feed = {
                key: (value if ".lod" in key else value[np.newaxis, :])
                for key, value in feed.items()
            }
        if not asyn:
            try:
                self.profile_.record('py_prepro_0')
                req = self._pack_inference_request(
                    feed, fetch, is_python=is_python, log_id=log_id)
                self.profile_.record('py_prepro_1')

                self.profile_.record('py_client_infer_0')
                resp = self.stub_.Inference(req, timeout=self.rpc_timeout_s_)
                self.profile_.record('py_client_infer_1')

                self.profile_.record('py_postpro_0')
                ret = self._unpack_inference_response(
                    resp,
                    fetch,
                    is_python=is_python,
                    need_variant_tag=need_variant_tag)
                self.profile_.record('py_postpro_1')
                self.profile_.print_profile()
                return ret
            except grpc.RpcError as e:
                # Surface the gRPC status code instead of raising.
                return {"serving_status_code": e.code()}
        else:
            req = self._pack_inference_request(
                feed, fetch, is_python=is_python, log_id=log_id)
            call_future = self.stub_.Inference.future(
                req, timeout=self.rpc_timeout_s_)
            return MultiLangPredictFuture(
                call_future,
                self._done_callback_func(
                    fetch,
                    is_python=is_python,
                    need_variant_tag=need_variant_tag))


class MultiLangPredictFuture(object):
    """Wraps a gRPC call future and decodes its response on demand.

    callback_func is expected to turn the raw response proto into the
    client-facing result structure.
    """

    def __init__(self, call_future, callback_func):
        self.call_future_ = call_future
        self.callback_func_ = callback_func

    def result(self):
        """Block until the RPC completes and return the decoded result.

        On RPC failure, a dict holding only the gRPC status code is
        returned instead of raising.
        """
        try:
            raw_resp = self.call_future_.result()
        except grpc.RpcError as rpc_err:
            return {"serving_status_code": rpc_err.code()}
        return self.callback_func_(raw_resp)

    def add_done_callback(self, fn):
        """Register fn to be invoked with this wrapper once the RPC is done."""

        def _wrapped(call_future):
            assert call_future == self.call_future_
            fn(self)

        self.call_future_.add_done_callback(_wrapped)