__init__.py 31.1 KB
Newer Older
G
guru4elephant 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
# pylint: disable=doc-string-missing
G
guru4elephant 已提交
15

16
import paddle_serving_client
M
MRXLT 已提交
17
import os
18 19 20 21
from .proto import sdk_configure_pb2 as sdk
from .proto import general_model_config_pb2 as m_config
import google.protobuf.text_format
import numpy as np
H
HexToString 已提交
22 23 24
import requests
import json
import base64
D
dongdaxiang 已提交
25
import time
26
import sys
G
guru4elephant 已提交
27

B
barrierye 已提交
28
import grpc
B
barrierye 已提交
29
from .proto import multi_lang_general_model_service_pb2
B
barrierye 已提交
30 31
sys.path.append(
    os.path.join(os.path.abspath(os.path.dirname(__file__)), 'proto'))
B
barrierye 已提交
32
from .proto import multi_lang_general_model_service_pb2_grpc
B
barrierye 已提交
33

M
MRXLT 已提交
34 35 36
# Numeric codes for tensor element types, mirroring the serving protocol.
int64_type = 0
float32_type = 1
int32_type = 2
bytes_type = 3

# Groupings used to dispatch feed/fetch handling by element type.
int_type = {int64_type, int32_type}
float_type = {float32_type}
string_type = {bytes_type}
G
guru4elephant 已提交
41

M
MRXLT 已提交
42

W
WangXi 已提交
43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70
class _NOPProfiler(object):
    def record(self, name):
        pass

    def print_profile(self):
        pass


class _TimeProfiler(object):
    def __init__(self):
        self.pid = os.getpid()
        self.print_head = 'PROFILE\tpid:{}\t'.format(self.pid)
        self.time_record = [self.print_head]

    def record(self, name):
        self.time_record.append('{}:{} '.format(
            name, int(round(time.time() * 1000000))))

    def print_profile(self):
        self.time_record.append('\n')
        sys.stderr.write(''.join(self.time_record))
        self.time_record = [self.print_head]


# Select the profiler implementation once at import time: timing is only
# enabled when the FLAGS_profile_client environment variable is non-zero.
_is_profile = int(os.environ.get('FLAGS_profile_client', 0))
_Profiler = _TimeProfiler if _is_profile else _NOPProfiler


G
guru4elephant 已提交
71 72 73
class SDKConfig(object):
    """Builds the SDK configuration protobuf consumed by the C++ predictor.

    Variants registered through add_server_variant() are rendered into a
    single ``general_model`` predictor whose traffic is split according to
    the variant weights.
    """

    def __init__(self):
        self.sdk_desc = sdk.SDKConf()
        self.tag_list = []
        self.cluster_list = []
        self.variant_weight_list = []
        # Default client-side RPC timeout in milliseconds.
        self.rpc_timeout_ms = 20000
        # "la" is the locality-aware load balancing strategy name.
        self.load_balance_strategy = "la"

    def add_server_variant(self, tag, cluster, variant_weight):
        """Register one server variant.

        Args:
            tag: unique name of the variant.
            cluster: list of "ip:port" endpoint strings.
            variant_weight: traffic weight of this variant, as a string.
        """
        self.tag_list.append(tag)
        self.cluster_list.append(cluster)
        self.variant_weight_list.append(variant_weight)

    def set_load_balance_strategy(self, strategy):
        """Set the load balancing strategy used by the default variant."""
        self.load_balance_strategy = strategy

    def set_load_banlance_strategy(self, strategy):
        """Backward-compatible alias kept for callers of the old misspelled
        name; delegates to set_load_balance_strategy()."""
        self.set_load_balance_strategy(strategy)

    def gen_desc(self, rpc_timeout_ms):
        """Assemble and return the complete ``SDKConf`` protobuf.

        Args:
            rpc_timeout_ms: RPC timeout (milliseconds) applied to the
                default variant's connection configuration.
        """
        predictor_desc = sdk.Predictor()
        predictor_desc.name = "general_model"
        predictor_desc.service_name = \
            "baidu.paddle_serving.predictor.general_model.GeneralModelService"
        predictor_desc.endpoint_router = "WeightedRandomRender"
        predictor_desc.weighted_random_render_conf.variant_weight_list = "|".join(
            self.variant_weight_list)

        for idx, tag in enumerate(self.tag_list):
            variant_desc = sdk.VariantConf()
            variant_desc.tag = tag
            variant_desc.naming_conf.cluster = "list://{}".format(",".join(
                self.cluster_list[idx]))
            predictor_desc.variants.extend([variant_desc])

        self.sdk_desc.predictors.extend([predictor_desc])
        self.sdk_desc.default_variant_conf.tag = "default"
        self.sdk_desc.default_variant_conf.connection_conf.connect_timeout_ms = 2000
        self.sdk_desc.default_variant_conf.connection_conf.rpc_timeout_ms = rpc_timeout_ms
        self.sdk_desc.default_variant_conf.connection_conf.connect_retry_count = 2
        self.sdk_desc.default_variant_conf.connection_conf.max_connection_per_host = 100
        self.sdk_desc.default_variant_conf.connection_conf.hedge_request_timeout_ms = -1
        self.sdk_desc.default_variant_conf.connection_conf.hedge_fetch_retry_count = 2
        self.sdk_desc.default_variant_conf.connection_conf.connection_type = "pooled"

        self.sdk_desc.default_variant_conf.naming_conf.cluster_filter_strategy = "Default"
        # Honor the configured strategy instead of hard-coding "la" — this
        # fixes set_load_banlance_strategy() having no effect on the
        # generated config.
        self.sdk_desc.default_variant_conf.naming_conf.load_balance_strategy = \
            self.load_balance_strategy

        self.sdk_desc.default_variant_conf.rpc_parameter.compress_type = 0
        self.sdk_desc.default_variant_conf.rpc_parameter.package_size = 20
        self.sdk_desc.default_variant_conf.rpc_parameter.protocol = "baidu_std"
        self.sdk_desc.default_variant_conf.rpc_parameter.max_channel_per_request = 3

        return self.sdk_desc
G
guru4elephant 已提交
123

G
guru4elephant 已提交
124 125 126 127 128 129

class Client(object):
    def __init__(self):
        """Create an empty client; call load_client_config() before use."""
        # Model signature, populated by load_client_config().
        self.feed_names_ = []
        self.fetch_names_ = []
        self.feed_shapes_ = {}
        self.feed_types_ = {}
        self.feed_names_to_idx_ = {}
        # Underlying C++ predictor handle and SDK configuration.
        self.client_handle_ = None
        self.predictor_sdk_ = None
        self.producers = []
        self.consumer = None
        self.pid = os.getpid()
        self.profile_ = _Profiler()
        # Feeds must be all-numpy or all-list; these flags track which.
        self.all_numpy_input = True
        self.has_numpy_input = False
        self.rpc_timeout_ms = 20000
        from .serving_client import PredictorRes
        self.predictorres_constructor = PredictorRes
M
MRXLT 已提交
143

H
HexToString 已提交
144 145 146 147 148 149 150 151 152 153 154 155 156
    def load_client_config(self, model_config_path_list):
        """Load model signature(s) and initialize the C++ predictor client.

        Args:
            model_config_path_list: a prototxt path, a model directory
                containing ``serving_server_conf.prototxt``, or a list of
                such paths.  When several configs are given, the first one
                provides the feed signature and the last one the fetch
                signature.

        Raises:
            ValueError: if the argument is neither a str nor a list.
        """
        if isinstance(model_config_path_list, str):
            model_config_path_list = [model_config_path_list]
        elif not isinstance(model_config_path_list, list):
            # was a silent fall-through before; fail loudly on bad input
            raise ValueError(
                "model_config_path_list must be a str or a list of str.")

        file_path_list = []
        for single_model_config in model_config_path_list:
            if os.path.isdir(single_model_config):
                file_path_list.append("{}/serving_server_conf.prototxt".format(
                    single_model_config))
            elif os.path.isfile(single_model_config):
                file_path_list.append(single_model_config)

        from .serving_client import PredictorClient
        model_conf = m_config.GeneralModelConfig()
        # use a context manager so the config file is always closed
        with open(file_path_list[0], 'r') as f:
            model_conf = google.protobuf.text_format.Merge(
                str(f.read()), model_conf)

        # load configuration here
        # get feed vars, fetch vars
        # get feed shapes, feed types
        # map feed names to index
        self.client_handle_ = PredictorClient()
        self.client_handle_.init(file_path_list)
        if "FLAGS_max_body_size" not in os.environ:
            os.environ["FLAGS_max_body_size"] = str(512 * 1024 * 1024)
        read_env_flags = ["profile_client", "profile_server", "max_body_size"]
        self.client_handle_.init_gflags([sys.argv[
            0]] + ["--tryfromenv=" + ",".join(read_env_flags)])

        self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
        self.feed_names_to_idx_ = {}  # kept for API compatibility; not used
        self.lod_tensor_set = set()
        self.feed_tensor_len = {}  # only used for shape checking
        self.key = None
        for i, var in enumerate(model_conf.feed_var):
            self.feed_names_to_idx_[var.alias_name] = i
            self.feed_types_[var.alias_name] = var.feed_type
            self.feed_shapes_[var.alias_name] = var.shape

            if var.is_lod_tensor:
                self.lod_tensor_set.add(var.alias_name)
            else:
                counter = 1
                for dim in self.feed_shapes_[var.alias_name]:
                    counter *= dim
                self.feed_tensor_len[var.alias_name] = counter

        # with several configs the fetch signature comes from the last one
        if len(file_path_list) > 1:
            model_conf = m_config.GeneralModelConfig()
            with open(file_path_list[-1], 'r') as f:
                model_conf = google.protobuf.text_format.Merge(
                    str(f.read()), model_conf)
        self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
        self.fetch_names_to_type_ = {}
        self.fetch_names_to_idx_ = {}
        for i, var in enumerate(model_conf.fetch_var):
            self.fetch_names_to_idx_[var.alias_name] = i
            self.fetch_names_to_type_[var.alias_name] = var.fetch_type
            if var.is_lod_tensor:
                self.lod_tensor_set.add(var.alias_name)
        return

H
HexToString 已提交
210

211
    def add_variant(self, tag, cluster, variant_weight):
        """Register an A/B-test server variant.

        Lazily creates the SDKConfig on first use; the weight is stored
        as a string.
        """
        sdk_conf = self.predictor_sdk_
        if sdk_conf is None:
            sdk_conf = SDKConfig()
            self.predictor_sdk_ = sdk_conf
        sdk_conf.add_server_variant(tag, cluster, str(variant_weight))

M
MRXLT 已提交
217 218 219 220 221 222
    def set_rpc_timeout_ms(self, rpc_timeout):
        if not isinstance(rpc_timeout, int):
            raise ValueError("rpc_timeout must be int type.")
        else:
            self.rpc_timeout_ms = rpc_timeout

H
HexToString 已提交
223
    def use_key(self, key_filename):
224
        with open(key_filename, "rb") as f:
H
HexToString 已提交
225 226 227 228
            self.key = f.read()

    def get_serving_port(self, endpoints):
        """Ask the key server at endpoints[0] for the real serving port.

        Posts the (optional) base64-encoded key and returns a one-element
        list ["ip:port"] built from the reported endpoint.
        """
        if self.key is not None:
            req = json.dumps({"key": base64.b64encode(self.key).decode()})
        else:
            req = json.dumps({})
        r = requests.post("http://" + endpoints[0], req)
        result = r.json()
        print(result)
        if "endpoint_list" not in result:
            raise ValueError("server not ready")
        host = endpoints[0].split(":")[0]
        return [host + ":" + str(result["endpoint_list"][0])]

    def connect(self, endpoints=None, encryption=False):
        """Create the underlying predictor for the given endpoints.

        Either pass ``endpoints`` here or register variants beforehand via
        add_variant().  With ``encryption`` the real serving port is first
        fetched from the key server.
        """
        # check whether current endpoint is available
        # init from client config
        # create predictor here
        if endpoints is None:
            # rely on variants registered earlier through add_variant()
            if self.predictor_sdk_ is None:
                raise ValueError(
                    "You must set the endpoints parameter or use add_variant function to create a variant."
                )
        else:
            if encryption:
                endpoints = self.get_serving_port(endpoints)
            if self.predictor_sdk_ is None:
                self.add_variant('default_tag_{}'.format(id(self)), endpoints,
                                 100)
            else:
                print(
                    "parameter endpoints({}) will not take effect, because you use the add_variant function.".
                    format(endpoints))
        sdk_desc = self.predictor_sdk_.gen_desc(self.rpc_timeout_ms)
        serialized = sdk_desc.SerializeToString()
        self.client_handle_.create_predictor_by_desc(serialized)
G
guru4elephant 已提交
266 267 268 269 270 271 272

    def get_feed_names(self):
        """Return the list of feed (input) variable alias names."""
        return self.feed_names_

    def get_fetch_names(self):
        """Return the list of fetch (output) variable alias names."""
        return self.fetch_names_

M
MRXLT 已提交
273 274 275
    def shape_check(self, feed, key):
        if key in self.lod_tensor_set:
            return
M
MRXLT 已提交
276 277
        if isinstance(feed[key],
                      list) and len(feed[key]) != self.feed_tensor_len[key]:
M
MRXLT 已提交
278
            raise ValueError("The shape of feed tensor {} not match.".format(
M
MRXLT 已提交
279 280 281
                key))
        if type(feed[key]).__module__ == np.__name__ and np.size(feed[
                key]) != self.feed_tensor_len[key]:
M
MRXLT 已提交
282 283 284
            #raise SystemExit("The shape of feed tensor {} not match.".format(
            #    key))
            pass
M
MRXLT 已提交
285

W
wangjiawei04 已提交
286 287 288 289 290 291
    def predict(self,
                feed=None,
                fetch=None,
                batch=False,
                need_variant_tag=False,
                log_id=0):
        """Run one prediction through the C++ predictor and return results.

        Args:
            feed: dict of {feed_name: value} or a list of such dicts.
                Values may be numpy arrays or plain lists, but all numpy
                or all list across the whole request.  An extra
                "<name>.lod" key supplies LoD info for that tensor.
            fetch: a fetch variable name (str) or a list of names.
            batch: when False, a leading batch axis is inserted into each
                numpy feed via np.newaxis.
            need_variant_tag: when True the A/B-test variant tag is
                returned alongside the result.
            log_id: id forwarded to the server for request tracing.

        Returns:
            dict of fetched numpy arrays (single model), a
            {engine_name: result_map} dict for multiple models, or None
            if the C++ predict call reported failure.

        Raises:
            ValueError: on missing feed/fetch, wrong argument types,
                unknown feed names, empty fetch list, mixed numpy/list
                inputs, or an empty fetched tensor.
        """
        self.profile_.record('py_prepro_0')

        if feed is None or fetch is None:
            raise ValueError("You should specify feed and fetch for prediction")

        # normalize fetch to a list of names
        fetch_list = []
        if isinstance(fetch, str):
            fetch_list = [fetch]
        elif isinstance(fetch, list):
            fetch_list = fetch
        else:
            raise ValueError("Fetch only accepts string and list of string")

        # normalize feed to a list of dicts (one per sample)
        feed_batch = []
        if isinstance(feed, dict):
            feed_batch.append(feed)
        elif isinstance(feed, list):
            feed_batch = feed
        else:
            raise ValueError("Feed only accepts dict and list of dict")

        # per-type marshalling buffers passed to the C++ layer
        int_slot_batch = []
        int_feed_names = []
        int_shape = []
        int_lod_slot_batch = []
        float_slot_batch = []
        float_feed_names = []
        float_lod_slot_batch = []
        float_shape = []
        string_slot_batch = []
        string_feed_names = []
        string_lod_slot_batch = []
        string_shape = []

        fetch_names = []
        counter = 0  # NOTE(review): never used below — TODO confirm removable
        batch_size = len(feed_batch)

        # keep only fetch names the loaded model actually exposes
        for key in fetch_list:
            if key in self.fetch_names_:
                fetch_names.append(key)

        if len(fetch_names) == 0:
            raise ValueError(
                "Fetch names should not be empty or out of saved fetch list.")
            return {}  # unreachable; kept as-is from the original

        for i, feed_i in enumerate(feed_batch):
            int_slot = []
            int_lod_slot = []
            float_slot = []
            float_lod_slot = []
            string_slot = []
            string_lod_slot = []
            for key in feed_i:
                # ".lod" keys carry LoD info and are handled with their tensor
                if ".lod" not in key and key not in self.feed_names_:
                    raise ValueError("Wrong feed name: {}.".format(key))
                if ".lod" in key:
                    continue
                #if not isinstance(feed_i[key], np.ndarray):
                self.shape_check(feed_i, key)
                if self.feed_types_[key] in int_type:
                    # names/shapes/lod are collected from the first sample only
                    if i == 0:
                        int_feed_names.append(key)
                        shape_lst = []
                        if batch == False:
                            feed_i[key] = feed_i[key][np.newaxis, :]
                        if isinstance(feed_i[key], np.ndarray):
                            shape_lst.extend(list(feed_i[key].shape))
                            int_shape.append(shape_lst)
                        else:
                            int_shape.append(self.feed_shapes_[key])
                        if "{}.lod".format(key) in feed_i:
                            int_lod_slot_batch.append(feed_i["{}.lod".format(
                                key)])
                        else:
                            int_lod_slot_batch.append([])

                    if isinstance(feed_i[key], np.ndarray):
                        int_slot.append(feed_i[key])
                        self.has_numpy_input = True
                    else:
                        int_slot.append(feed_i[key])
                        self.all_numpy_input = False

                elif self.feed_types_[key] in float_type:
                    if i == 0:
                        float_feed_names.append(key)
                        shape_lst = []
                        if batch == False:
                            feed_i[key] = feed_i[key][np.newaxis, :]
                        if isinstance(feed_i[key], np.ndarray):
                            shape_lst.extend(list(feed_i[key].shape))
                            float_shape.append(shape_lst)
                        else:
                            float_shape.append(self.feed_shapes_[key])
                        if "{}.lod".format(key) in feed_i:
                            float_lod_slot_batch.append(feed_i["{}.lod".format(
                                key)])
                        else:
                            float_lod_slot_batch.append([])

                    if isinstance(feed_i[key], np.ndarray):
                        float_slot.append(feed_i[key])
                        self.has_numpy_input = True
                    else:
                        float_slot.append(feed_i[key])
                        self.all_numpy_input = False
                #if input is string, feed is not numpy.
                elif self.feed_types_[key] in string_type:
                    if i == 0:
                        string_feed_names.append(key)
                        string_shape.append(self.feed_shapes_[key])
                        if "{}.lod".format(key) in feed_i:
                            string_lod_slot_batch.append(feed_i["{}.lod".format(
                                key)])
                        else:
                            string_lod_slot_batch.append([])
                    string_slot.append(feed_i[key])
                    self.has_numpy_input = True

            int_slot_batch.append(int_slot)
            int_lod_slot_batch.append(int_lod_slot)
            float_slot_batch.append(float_slot)
            float_lod_slot_batch.append(float_lod_slot)
            string_slot_batch.append(string_slot)
            string_lod_slot_batch.append(string_lod_slot)

        self.profile_.record('py_prepro_1')
        self.profile_.record('py_client_infer_0')

        result_batch_handle = self.predictorres_constructor()
        if self.all_numpy_input:
            res = self.client_handle_.numpy_predict(
                float_slot_batch, float_feed_names, float_shape,
                float_lod_slot_batch, int_slot_batch, int_feed_names, int_shape,
                int_lod_slot_batch, string_slot_batch, string_feed_names, string_shape,
                string_lod_slot_batch, fetch_names, result_batch_handle, self.pid,
                log_id)
        elif self.has_numpy_input == False:
            raise ValueError(
                "Please make sure all of your inputs are numpy array")
        else:
            raise ValueError(
                "Please make sure the inputs are all in list type or all in numpy.array type"
            )

        self.profile_.record('py_client_infer_1')
        self.profile_.record('py_postpro_0')

        if res == -1:
            return None

        # unpack per-engine results into numpy arrays keyed by fetch name
        multi_result_map = []
        model_engine_names = result_batch_handle.get_engine_names()
        for mi, engine_name in enumerate(model_engine_names):
            result_map = {}
            # result map needs to be a numpy array
            for i, name in enumerate(fetch_names):
                if self.fetch_names_to_type_[name] == int64_type:
                    # result_map[name] will be py::array(numpy array)
                    result_map[name] = result_batch_handle.get_int64_by_name(
                        mi, name)
                    shape = result_batch_handle.get_shape(mi, name)
                    if result_map[name].size == 0:
                        raise ValueError(
                            "Failed to fetch, maybe the type of [{}]"
                            " is wrong, please check the model file".format(
                                name))
                    result_map[name].shape = shape
                    if name in self.lod_tensor_set:
                        tmp_lod = result_batch_handle.get_lod(mi, name)
                        if np.size(tmp_lod) > 0:
                            result_map["{}.lod".format(name)] = tmp_lod
                elif self.fetch_names_to_type_[name] == float32_type:
                    result_map[name] = result_batch_handle.get_float_by_name(
                        mi, name)
                    if result_map[name].size == 0:
                        raise ValueError(
                            "Failed to fetch, maybe the type of [{}]"
                            " is wrong, please check the model file".format(
                                name))
                    shape = result_batch_handle.get_shape(mi, name)
                    result_map[name].shape = shape
                    if name in self.lod_tensor_set:
                        tmp_lod = result_batch_handle.get_lod(mi, name)
                        if np.size(tmp_lod) > 0:
                            result_map["{}.lod".format(name)] = tmp_lod
                elif self.fetch_names_to_type_[name] == int32_type:
                    # result_map[name] will be py::array(numpy array)
                    result_map[name] = result_batch_handle.get_int32_by_name(
                        mi, name)
                    if result_map[name].size == 0:
                        raise ValueError(
                            "Failed to fetch, maybe the type of [{}]"
                            " is wrong, please check the model file".format(
                                name))
                    shape = result_batch_handle.get_shape(mi, name)
                    result_map[name].shape = shape
                    if name in self.lod_tensor_set:
                        tmp_lod = result_batch_handle.get_lod(mi, name)
                        if np.size(tmp_lod) > 0:
                            result_map["{}.lod".format(name)] = tmp_lod
            multi_result_map.append(result_map)
        ret = None
        if len(model_engine_names) == 1:
            # If only one model result is returned, the format of ret is result_map
            ret = multi_result_map[0]
        else:
            # If multiple model results are returned, the format of ret is {name: result_map}
            ret = {
                engine_name: multi_result_map[mi]
                for mi, engine_name in enumerate(model_engine_names)
            }

        self.profile_.record('py_postpro_1')
        self.profile_.print_profile()

        # When using the A/B test, the tag of variant needs to be returned
        return ret if not need_variant_tag else [
            ret, result_batch_handle.variant_tag()
        ]
B
barrierye 已提交
514

515 516
    def release(self):
        self.client_handle_.destroy_predictor()
G
guru4elephant 已提交
517
        self.client_handle_ = None
B
barrierye 已提交
518 519


520
class MultiLangClient(object):
    """gRPC client for the Paddle Serving multi-language model service.

    After connect(), the client downloads the server-side model config to
    learn feed/fetch variable names, types and lod-tensor flags, then
    packs numpy/list feed data into InferenceRequest protos and unpacks
    InferenceResponse protos back into numpy arrays.
    """

    def __init__(self):
        self.channel_ = None
        self.stub_ = None
        # RPC timeout in seconds; default 2s, adjustable via set_rpc_timeout_ms().
        self.rpc_timeout_s_ = 2
        self.profile_ = _Profiler()

    def add_variant(self, tag, cluster, variant_weight):
        # TODO: A/B testing (variants) is not implemented for the gRPC client.
        raise Exception("cannot support ABtest yet")

    def set_rpc_timeout_ms(self, rpc_timeout):
        """Set the RPC timeout (milliseconds) on both client and server.

        Must be called after connect(). Returns True iff the server
        accepted the new timeout.
        """
        if self.stub_ is None:
            raise Exception("set timeout must be set after connect.")
        if not isinstance(rpc_timeout, int):
            # for bclient
            raise ValueError("rpc_timeout must be int type.")
        self.rpc_timeout_s_ = rpc_timeout / 1000.0
        timeout_req = multi_lang_general_model_service_pb2.SetTimeoutRequest()
        timeout_req.timeout_ms = rpc_timeout
        resp = self.stub_.SetTimeout(timeout_req)
        return resp.err_code == 0

    def connect(self, endpoints):
        """Open a round-robin gRPC channel and fetch the client model config.

        endpoints is a list of "host:port" strings.
        """
        # Raise the message size cap to 512MB for large tensors, see
        # https://github.com/tensorflow/serving/issues/1382
        options = [('grpc.max_receive_message_length', 512 * 1024 * 1024),
                   ('grpc.max_send_message_length', 512 * 1024 * 1024),
                   ('grpc.lb_policy_name', 'round_robin')]
        # TODO: weight round robin
        g_endpoint = 'ipv4:{}'.format(','.join(endpoints))
        self.channel_ = grpc.insecure_channel(g_endpoint, options=options)
        self.stub_ = multi_lang_general_model_service_pb2_grpc.MultiLangGeneralModelServiceStub(
            self.channel_)
        # get client model config
        get_client_config_req = multi_lang_general_model_service_pb2.GetClientConfigRequest(
        )
        resp = self.stub_.GetClientConfig(get_client_config_req)
        model_config_path_list = resp.client_config_str_list
        self._parse_model_config(model_config_path_list)

    def _flatten_list(self, nested_list):
        """Yield leaves of an arbitrarily nested list/tuple, left to right."""
        for item in nested_list:
            if isinstance(item, (list, tuple)):
                for sub_item in self._flatten_list(item):
                    yield sub_item
            else:
                yield item

    def _parse_model_config(self, model_config_path_list):
        """Load feed/fetch metadata from model config prototxt file(s).

        Accepts a single path, a list of paths, or any iterable of paths
        (e.g. a protobuf repeated field); directories are resolved to
        their serving_server_conf.prototxt. When several configs are
        given, the FIRST provides the feed variables and the LAST
        provides the fetch variables (multi-model pipeline).
        """
        if isinstance(model_config_path_list, str):
            model_config_path_list = [model_config_path_list]
        elif isinstance(model_config_path_list, list):
            pass

        file_path_list = []
        for single_model_config in model_config_path_list:
            if os.path.isdir(single_model_config):
                file_path_list.append("{}/serving_server_conf.prototxt".format(
                    single_model_config))
            elif os.path.isfile(single_model_config):
                file_path_list.append(single_model_config)

        model_conf = m_config.GeneralModelConfig()
        # 'with' guarantees the config file handle is closed (the original
        # leaked it).
        with open(file_path_list[0], 'r') as f:
            model_conf = google.protobuf.text_format.Merge(
                str(f.read()), model_conf)
        self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
        self.feed_types_ = {}
        self.feed_shapes_ = {}
        self.lod_tensor_set_ = set()
        for var in model_conf.feed_var:
            self.feed_types_[var.alias_name] = var.feed_type
            self.feed_shapes_[var.alias_name] = var.shape
            if var.is_lod_tensor:
                self.lod_tensor_set_.add(var.alias_name)

        # Fetch variables come from the last model of the pipeline.
        if len(file_path_list) > 1:
            model_conf = m_config.GeneralModelConfig()
            with open(file_path_list[-1], 'r') as f:
                model_conf = google.protobuf.text_format.Merge(
                    str(f.read()), model_conf)

        self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
        self.fetch_types_ = {}
        for var in model_conf.fetch_var:
            self.fetch_types_[var.alias_name] = var.fetch_type
            if var.is_lod_tensor:
                self.lod_tensor_set_.add(var.alias_name)

    def _pack_inference_request(self, feed, fetch, is_python, log_id):
        """Build an InferenceRequest proto from a feed dict.

        feed maps variable name -> list or ndarray; "<name>.lod" keys
        carry lod information for the matching variable. When is_python
        is True, tensor payloads are serialized as raw bytes.
        """
        req = multi_lang_general_model_service_pb2.InferenceRequest()
        req.fetch_var_names.extend(fetch)
        req.is_python = is_python
        req.log_id = log_id
        # lod entries ("<name>.lod") are side-channel data, not variables.
        feed_var_names = []
        for key in feed.keys():
            if '.lod' not in key:
                feed_var_names.append(key)
        req.feed_var_names.extend(feed_var_names)
        inst = multi_lang_general_model_service_pb2.FeedInst()
        for name in req.feed_var_names:
            tensor = multi_lang_general_model_service_pb2.Tensor()
            var = feed[name]
            v_type = self.feed_types_[name]
            if is_python:
                data = None
                if isinstance(var, list):
                    if v_type == 0:  # int64
                        data = np.array(var, dtype="int64")
                    elif v_type == 1:  # float32
                        data = np.array(var, dtype="float32")
                    elif v_type == 2:  # int32
                        data = np.array(var, dtype="int32")
                    else:
                        raise Exception("error tensor value type.")
                elif isinstance(var, np.ndarray):
                    data = var
                    if v_type == 0:
                        if data.dtype != 'int64':
                            data = data.astype("int64")
                    elif v_type == 1:
                        if data.dtype != 'float32':
                            data = data.astype("float32")
                    elif v_type == 2:
                        if data.dtype != 'int32':
                            data = data.astype("int32")
                    else:
                        raise Exception("error tensor value type.")
                else:
                    raise Exception("var must be list or ndarray.")
                tensor.data = data.tobytes()
                # Use the converted ndarray's shape: a plain Python list has
                # no .shape attribute (the original crashed on list feeds).
                tensor.shape.extend(list(data.shape))
            else:
                tensor.shape.extend(list(var.shape))
            if "{}.lod".format(name) in feed.keys():
                tensor.lod.extend(feed["{}.lod".format(name)])
            inst.tensor_array.append(tensor)
        req.insts.append(inst)
        return req

    def _unpack_inference_response(self, resp, fetch, is_python,
                                   need_variant_tag):
        """Convert an InferenceResponse proto into numpy results.

        Returns None when the server reports an error. On success returns
        a {fetch_name: ndarray} map (plus "<name>.lod" arrays for lod
        tensors and "serving_status_code": 0), or {engine_name: map} when
        the server ran multiple models; with need_variant_tag the result
        is [ret, tag].
        """
        if resp.err_code != 0:
            return None
        tag = resp.tag
        multi_result_map = {}
        for model_result in resp.outputs:
            inst = model_result.insts[0]
            result_map = {}
            for i, name in enumerate(fetch):
                var = inst.tensor_array[i]
                v_type = self.fetch_types_[name]
                if is_python:
                    if v_type == 0:  # int64
                        result_map[name] = np.frombuffer(
                            var.data, dtype="int64")
                    elif v_type == 1:  # float32
                        result_map[name] = np.frombuffer(
                            var.data, dtype="float32")
                    elif v_type == 2:  # int32
                        # Mirror the int32 support of _pack_inference_request;
                        # the original raised for int32 fetch vars.
                        result_map[name] = np.frombuffer(
                            var.data, dtype="int32")
                    else:
                        raise Exception("error type.")
                else:
                    if v_type == 0:  # int64
                        result_map[name] = np.array(
                            list(var.int64_data), dtype="int64")
                    elif v_type == 1:  # float32
                        result_map[name] = np.array(
                            list(var.float_data), dtype="float32")
                    else:
                        # NOTE(review): int32 over typed proto fields is still
                        # unsupported here — the proto field name for int32
                        # payloads is not visible from this file.
                        raise Exception("error type.")
                result_map[name].shape = list(var.shape)
                if name in self.lod_tensor_set_:
                    result_map["{}.lod".format(name)] = np.array(list(var.lod))
            multi_result_map[model_result.engine_name] = result_map
        ret = None
        if len(resp.outputs) == 1:
            # Single model: return its result_map directly.
            ret = list(multi_result_map.values())[0]
        else:
            # Multiple models: keyed by engine name.
            ret = multi_result_map

        ret["serving_status_code"] = 0
        return ret if not need_variant_tag else [ret, tag]

    def _done_callback_func(self, fetch, is_python, need_variant_tag):
        """Return a response-unpacking callback for asynchronous predict()."""

        def unpack_resp(resp):
            return self._unpack_inference_response(resp, fetch, is_python,
                                                   need_variant_tag)

        return unpack_resp

    def get_feed_names(self):
        """Return the feed variable alias names from the model config."""
        return self.feed_names_

    def predict(self,
                feed,
                fetch,
                batch=True,
                need_variant_tag=False,
                asyn=False,
                is_python=True,
                log_id=0):
        """Run inference over gRPC.

        feed: dict of variable name -> list/ndarray (".lod" keys allowed).
        fetch: list of fetch variable names.
        batch: when False, a leading batch dimension is prepended to every
            non-lod feed value (mutates the caller's arrays in the dict).
        asyn: when True, return a MultiLangPredictFuture immediately.
        On RPC failure returns {"serving_status_code": <grpc code>}.
        """
        if not isinstance(feed, dict):
            raise ValueError("Type Error. grpc feed must be dict.")
        if batch is False:
            for key in feed:
                if ".lod" not in key:
                    feed[key] = feed[key][np.newaxis, :]
        if not asyn:
            try:
                self.profile_.record('py_prepro_0')
                req = self._pack_inference_request(
                    feed, fetch, is_python=is_python, log_id=log_id)
                self.profile_.record('py_prepro_1')

                self.profile_.record('py_client_infer_0')
                resp = self.stub_.Inference(req, timeout=self.rpc_timeout_s_)
                self.profile_.record('py_client_infer_1')

                self.profile_.record('py_postpro_0')
                ret = self._unpack_inference_response(
                    resp,
                    fetch,
                    is_python=is_python,
                    need_variant_tag=need_variant_tag)
                self.profile_.record('py_postpro_1')
                self.profile_.print_profile()
                return ret
            except grpc.RpcError as e:
                return {"serving_status_code": e.code()}
        else:
            req = self._pack_inference_request(
                feed, fetch, is_python=is_python, log_id=log_id)
            call_future = self.stub_.Inference.future(
                req, timeout=self.rpc_timeout_s_)
            return MultiLangPredictFuture(
                call_future,
                self._done_callback_func(
                    fetch,
                    is_python=is_python,
                    need_variant_tag=need_variant_tag))
class MultiLangPredictFuture(object):
    """Handle for an in-flight asynchronous predict() call.

    Wraps the raw grpc call future together with the response-unpacking
    callback produced by MultiLangClient._done_callback_func.
    """

    def __init__(self, call_future, callback_func):
        self.call_future_ = call_future
        self.callback_func_ = callback_func

    def result(self):
        """Block until the RPC finishes and return the unpacked result.

        On RPC failure, returns {"serving_status_code": <grpc code>}
        instead of raising.
        """
        try:
            raw_response = self.call_future_.result()
        except grpc.RpcError as rpc_err:
            return {"serving_status_code": rpc_err.code()}
        return self.callback_func_(raw_response)

    def add_done_callback(self, fn):
        """Register fn to run on completion; fn receives this wrapper."""

        def _adapter(finished_future):
            # The grpc layer hands back its own future object; translate
            # that into a call with this wrapper instead.
            assert finished_future == self.call_future_
            fn(self)

        self.call_future_.add_done_callback(_adapter)