__init__.py 14.7 KB
Newer Older
G
guru4elephant 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
# pylint: disable=doc-string-missing
G
guru4elephant 已提交
15

M
MRXLT 已提交
16 17
import paddle_serving_client
import os
18 19 20
from .proto import sdk_configure_pb2 as sdk
from .proto import general_model_config_pb2 as m_config
import google.protobuf.text_format
D
dongdaxiang 已提交
21 22
import numpy as np
import time
23
import sys
24
from .serving_client import PredictorRes
G
guru4elephant 已提交
25

G
guru4elephant 已提交
26 27 28
# Type codes stored in the model config's feed_type / fetch_type fields.
int_type, float_type = 0, 1

M
MRXLT 已提交
29

W
WangXi 已提交
30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57
class _NOPProfiler(object):
    def record(self, name):
        pass

    def print_profile(self):
        pass


class _TimeProfiler(object):
    def __init__(self):
        self.pid = os.getpid()
        self.print_head = 'PROFILE\tpid:{}\t'.format(self.pid)
        self.time_record = [self.print_head]

    def record(self, name):
        self.time_record.append('{}:{} '.format(
            name, int(round(time.time() * 1000000))))

    def print_profile(self):
        self.time_record.append('\n')
        sys.stderr.write(''.join(self.time_record))
        self.time_record = [self.print_head]


# FLAGS_profile_client (read as an int, default 0) switches on timing output.
_is_profile = int(os.environ.get('FLAGS_profile_client', 0))
if _is_profile:
    _Profiler = _TimeProfiler
else:
    _Profiler = _NOPProfiler


G
guru4elephant 已提交
58 59 60
class SDKConfig(object):
    """Builds the ``SDKConf`` protobuf that tells the native client how to
    reach the serving cluster(s).

    Every variant registered with ``add_server_variant`` becomes one
    ``VariantConf`` of the single ``general_model`` predictor; ``gen_desc``
    assembles and returns the final descriptor.
    """

    def __init__(self):
        self.sdk_desc = sdk.SDKConf()
        self.tag_list = []
        self.cluster_list = []
        self.variant_weight_list = []
        self.rpc_timeout_ms = 20000
        self.load_balance_strategy = "la"

    def add_server_variant(self, tag, cluster, variant_weight):
        """Register variant ``tag`` served by the endpoints in ``cluster``.

        Args:
            tag: variant tag written to ``VariantConf.tag``.
            cluster: list of endpoint strings. A single string is accepted and
                treated as a one-element list; joining a bare string would
                otherwise split it into one "endpoint" per character.
            variant_weight: routing weight of this variant (joined with "|"
                into the weighted-random-render config).
        """
        if isinstance(cluster, str):
            cluster = [cluster]
        self.tag_list.append(tag)
        self.cluster_list.append(cluster)
        self.variant_weight_list.append(variant_weight)

    def set_load_banlance_strategy(self, strategy):
        """Set the naming-conf load-balance strategy used by ``gen_desc``.

        (Legacy, misspelled name kept for backward compatibility.)
        """
        self.load_balance_strategy = strategy

    def set_load_balance_strategy(self, strategy):
        """Correctly spelled alias of ``set_load_banlance_strategy``."""
        self.set_load_banlance_strategy(strategy)

    def gen_desc(self, rpc_timeout_ms=None):
        """Fill ``self.sdk_desc`` from the registered variants and return it.

        Args:
            rpc_timeout_ms: RPC timeout in milliseconds written to the default
                connection config; defaults to ``self.rpc_timeout_ms``.

        Returns:
            ``self.sdk_desc`` (an ``SDKConf`` message) containing exactly one
            ``general_model`` predictor, even if called repeatedly.
        """
        if rpc_timeout_ms is None:
            rpc_timeout_ms = self.rpc_timeout_ms

        predictor_desc = sdk.Predictor()
        predictor_desc.name = "general_model"
        predictor_desc.service_name = \
            "baidu.paddle_serving.predictor.general_model.GeneralModelService"
        predictor_desc.endpoint_router = "WeightedRandomRender"
        # str() so integer weights are accepted as well as strings.
        predictor_desc.weighted_random_render_conf.variant_weight_list = "|".join(
            str(weight) for weight in self.variant_weight_list)

        for tag, cluster in zip(self.tag_list, self.cluster_list):
            variant_desc = sdk.VariantConf()
            variant_desc.tag = tag
            variant_desc.naming_conf.cluster = "list://{}".format(
                ",".join(cluster))
            predictor_desc.variants.extend([variant_desc])

        # sdk_desc persists across calls: drop predictors from any earlier
        # gen_desc so repeated calls don't accumulate duplicate entries.
        self.sdk_desc.ClearField("predictors")
        self.sdk_desc.predictors.extend([predictor_desc])

        default_conf = self.sdk_desc.default_variant_conf
        default_conf.tag = "default"

        connection_conf = default_conf.connection_conf
        connection_conf.connect_timeout_ms = 2000
        connection_conf.rpc_timeout_ms = rpc_timeout_ms
        connection_conf.connect_retry_count = 2
        connection_conf.max_connection_per_host = 100
        connection_conf.hedge_request_timeout_ms = -1
        connection_conf.hedge_fetch_retry_count = 2
        connection_conf.connection_type = "pooled"

        naming_conf = default_conf.naming_conf
        naming_conf.cluster_filter_strategy = "Default"
        # Honour set_load_banlance_strategy (previously hard-coded to "la").
        naming_conf.load_balance_strategy = self.load_balance_strategy

        rpc_parameter = default_conf.rpc_parameter
        rpc_parameter.compress_type = 0
        rpc_parameter.package_size = 20
        rpc_parameter.protocol = "baidu_std"
        rpc_parameter.max_channel_per_request = 3

        return self.sdk_desc
G
guru4elephant 已提交
110

G
guru4elephant 已提交
111 112 113 114 115 116

class Client(object):
    """Python client for a Paddle Serving ``general_model`` service.

    Typical use: ``load_client_config(path)``; then ``connect(endpoints)``
    (or ``add_variant(...)`` followed by ``connect()``); then any number of
    ``predict`` calls; finally ``release()``.
    """

    def __init__(self):
        self.feed_names_ = []
        self.fetch_names_ = []
        self.client_handle_ = None
        self.feed_shapes_ = {}
        self.feed_types_ = {}
        self.feed_names_to_idx_ = {}
        # Lookup tables rebuilt by load_client_config; created here as well so
        # the attributes always exist.
        self.fetch_names_to_type_ = {}
        self.fetch_names_to_idx_ = {}
        self.lod_tensor_set = set()
        self.feed_tensor_len = {}
        self.pid = os.getpid()
        self.predictor_sdk_ = None
        self.producers = []
        self.consumer = None
        self.profile_ = _Profiler()
        # Input-kind flags describing the most recent predict() call; they are
        # reset at the start of every call.
        self.all_numpy_input = True
        self.has_numpy_input = False
        self.rpc_timeout_ms = 20000

    def load_client_config(self, path):
        """Load the text-format ``GeneralModelConfig`` at ``path``.

        Initialises the native ``PredictorClient`` with ``path``, sets up
        gflags (defaulting ``FLAGS_max_body_size`` to 512 * 1024 * 1024 when
        it is not already in the environment), and rebuilds the feed/fetch
        lookup tables used by ``predict``.
        """
        from .serving_client import PredictorClient
        model_conf = m_config.GeneralModelConfig()
        # Use a context manager so the config file handle is not leaked.
        with open(path, 'r') as f:
            model_conf = google.protobuf.text_format.Merge(
                str(f.read()), model_conf)

        self.client_handle_ = PredictorClient()
        self.client_handle_.init(path)
        if "FLAGS_max_body_size" not in os.environ:
            os.environ["FLAGS_max_body_size"] = str(512 * 1024 * 1024)
        read_env_flags = ["profile_client", "profile_server", "max_body_size"]
        self.client_handle_.init_gflags([sys.argv[
            0]] + ["--tryfromenv=" + ",".join(read_env_flags)])

        self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
        self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
        # Start every table from scratch so loading a different config does
        # not leave stale entries from a previous one.
        self.feed_names_to_idx_ = {}
        self.feed_types_ = {}
        self.feed_shapes_ = {}
        self.fetch_names_to_type_ = {}
        self.fetch_names_to_idx_ = {}
        self.lod_tensor_set = set()
        self.feed_tensor_len = {}

        for i, var in enumerate(model_conf.feed_var):
            self.feed_names_to_idx_[var.alias_name] = i
            self.feed_types_[var.alias_name] = var.feed_type
            self.feed_shapes_[var.alias_name] = var.shape

            if var.is_lod_tensor:
                self.lod_tensor_set.add(var.alias_name)
            else:
                # Element count a flattened (list) input must have.
                counter = 1
                for dim in var.shape:
                    counter *= dim
                self.feed_tensor_len[var.alias_name] = counter

        for i, var in enumerate(model_conf.fetch_var):
            self.fetch_names_to_idx_[var.alias_name] = i
            self.fetch_names_to_type_[var.alias_name] = var.fetch_type
            if var.is_lod_tensor:
                self.lod_tensor_set.add(var.alias_name)
        return

    def add_variant(self, tag, cluster, variant_weight):
        """Register a server variant (tag, endpoint list, weight) for A/B routing."""
        if self.predictor_sdk_ is None:
            self.predictor_sdk_ = SDKConfig()
        self.predictor_sdk_.add_server_variant(tag, cluster,
                                               str(variant_weight))

    def set_rpc_timeout_ms(self, rpc_timeout):
        """Set the RPC timeout (ms) used by the next ``connect``.

        Raises:
            ValueError: ``rpc_timeout`` is not an int.
        """
        if not isinstance(rpc_timeout, int):
            raise ValueError("rpc_timeout must be int type.")
        else:
            self.rpc_timeout_ms = rpc_timeout

    def connect(self, endpoints=None):
        """Create the native predictor from the configured variants.

        If no variant was added, ``endpoints`` becomes a single default
        variant with weight 100. If variants were added, ``endpoints`` is
        ignored (a notice is printed).

        Raises:
            ValueError: neither ``endpoints`` nor any variant was given.
        """
        if endpoints is None:
            if self.predictor_sdk_ is None:
                raise ValueError(
                    "You must set the endpoints parameter or use add_variant function to create a variant."
                )
        else:
            if self.predictor_sdk_ is None:
                self.add_variant('default_tag_{}'.format(id(self)), endpoints,
                                 100)
            else:
                print(
                    "parameter endpoints({}) will not take effect, because you use the add_variant function.".
                    format(endpoints))
        sdk_desc = self.predictor_sdk_.gen_desc(self.rpc_timeout_ms)
        self.client_handle_.create_predictor_by_desc(sdk_desc.SerializeToString(
        ))

    def get_feed_names(self):
        return self.feed_names_

    def get_fetch_names(self):
        return self.fetch_names_

    def shape_check(self, feed, key):
        """Raise ValueError if list input ``feed[key]`` has the wrong length.

        LoD inputs are not checked. numpy.ndarray inputs are deliberately not
        size-checked here either (that check existed but was disabled).
        """
        if key in self.lod_tensor_set:
            return
        if isinstance(feed[key],
                      list) and len(feed[key]) != self.feed_tensor_len[key]:
            raise ValueError("The shape of feed tensor {} not match.".format(
                key))

    def _take_slot(self, value):
        """Return ``value`` unchanged, updating the per-call numpy flags."""
        if isinstance(value, np.ndarray):
            self.has_numpy_input = True
        else:
            self.all_numpy_input = False
        return value

    def _prepare_feed(self, feed_batch):
        """Split ``feed_batch`` into per-type slot lists aligned with names.

        The first sample fixes the order of the int/float feed names and
        their shapes. Every sample's values are laid out in that same order,
        so slot k of any sample always belongs to name k.

        Returns:
            (float_slot_batch, float_feed_names, float_shape,
             int_slot_batch, int_feed_names, int_shape)

        Raises:
            ValueError: unknown feed name, wrong list length, or a sample
                whose keys differ from the first sample's keys.
        """
        int_slot_batch = []
        float_slot_batch = []
        int_feed_names = []
        float_feed_names = []
        int_shape = []
        float_shape = []
        first_keys = None

        for i, feed_i in enumerate(feed_batch):
            for key in feed_i:
                if key not in self.feed_names_:
                    raise ValueError("Wrong feed name: {}.".format(key))
                self.shape_check(feed_i, key)

            if i == 0:
                first_keys = set(feed_i)
                for key in feed_i:
                    value = feed_i[key]
                    if isinstance(value, np.ndarray):
                        shape = list(value.shape)
                    else:
                        shape = self.feed_shapes_[key]
                    if self.feed_types_[key] == int_type:
                        int_feed_names.append(key)
                        int_shape.append(shape)
                    elif self.feed_types_[key] == float_type:
                        float_feed_names.append(key)
                        float_shape.append(shape)
            elif set(feed_i) != first_keys:
                raise ValueError(
                    "Feed keys of sample {} do not match those of the first sample.".
                    format(i))

            int_slot_batch.append(
                [self._take_slot(feed_i[key]) for key in int_feed_names])
            float_slot_batch.append(
                [self._take_slot(feed_i[key]) for key in float_feed_names])

        return (float_slot_batch, float_feed_names, float_shape,
                int_slot_batch, int_feed_names, int_shape)

    def _engine_result_map(self, result_batch_handle, mi, fetch_names):
        """Collect engine ``mi``'s fetched outputs into a dict.

        Each fetch name maps to the array returned by the result handle with
        its shape set from ``get_shape``. LoD outputs also get a
        ``"<name>.lod"`` entry.
        """
        result_map = {}
        for name in fetch_names:
            fetch_type = self.fetch_names_to_type_[name]
            if fetch_type == int_type:
                getter = result_batch_handle.get_int64_by_name
            elif fetch_type == float_type:
                getter = result_batch_handle.get_float_by_name
            else:
                continue
            # result_map[name] will be py::array (numpy array)
            result_map[name] = getter(mi, name)
            result_map[name].shape = result_batch_handle.get_shape(mi, name)
            if name in self.lod_tensor_set:
                result_map["{}.lod".format(
                    name)] = result_batch_handle.get_lod(mi, name)
        return result_map

    def predict(self, feed=None, fetch=None, need_variant_tag=False):
        """Run inference on ``feed`` and return the requested ``fetch`` outputs.

        Args:
            feed: one sample (dict of feed name -> value) or a list of such
                dicts. Values must be all lists or all numpy arrays.
            fetch: one fetch name or a list of names; names not in the model
                config are ignored.
            need_variant_tag: also return the tag of the variant that served
                the request (for A/B tests).

        Returns:
            None if the native call returned -1. Otherwise, with one engine,
            a dict of fetch name -> array; with several engines, a dict of
            engine name -> such a dict. With ``need_variant_tag`` the value is
            ``[result, variant_tag]``.

        Raises:
            ValueError: on missing/invalid feed or fetch, or mixed list and
                numpy inputs.
        """
        self.profile_.record('py_prepro_0')

        if feed is None or fetch is None:
            raise ValueError("You should specify feed and fetch for prediction")

        if isinstance(fetch, str):
            fetch_list = [fetch]
        elif isinstance(fetch, list):
            fetch_list = fetch
        else:
            raise ValueError("Fetch only accepts string and list of string")

        if isinstance(feed, dict):
            feed_batch = [feed]
        elif isinstance(feed, list):
            feed_batch = feed
        else:
            raise ValueError("Feed only accepts dict and list of dict")

        fetch_names = [key for key in fetch_list if key in self.fetch_names_]
        if not fetch_names:
            raise ValueError(
                "Fetch names should not be empty or out of saved fetch list.")

        # The flags describe only this call. Without the reset, a list-input
        # call and a numpy-input call (in either order) would make every later
        # call fail the "all list or all numpy" check below.
        self.all_numpy_input = True
        self.has_numpy_input = False
        (float_slot_batch, float_feed_names, float_shape, int_slot_batch,
         int_feed_names, int_shape) = self._prepare_feed(feed_batch)

        self.profile_.record('py_prepro_1')
        self.profile_.record('py_client_infer_0')

        result_batch_handle = PredictorRes()
        if self.all_numpy_input:
            predict_fn = self.client_handle_.numpy_predict
        elif not self.has_numpy_input:
            predict_fn = self.client_handle_.batch_predict
        else:
            raise ValueError(
                "Please make sure the inputs are all in list type or all in numpy.array type"
            )
        res = predict_fn(float_slot_batch, float_feed_names, float_shape,
                         int_slot_batch, int_feed_names, int_shape,
                         fetch_names, result_batch_handle, self.pid)

        self.profile_.record('py_client_infer_1')
        self.profile_.record('py_postpro_0')

        if res == -1:
            return None

        model_engine_names = result_batch_handle.get_engine_names()
        multi_result_map = [
            self._engine_result_map(result_batch_handle, mi, fetch_names)
            for mi, _engine_name in enumerate(model_engine_names)
        ]
        if len(model_engine_names) == 1:
            # If only one model result is returned, the format of ret is result_map
            ret = multi_result_map[0]
        else:
            # If multiple model results are returned, the format of ret is {name: result_map}
            ret = dict(zip(model_engine_names, multi_result_map))

        self.profile_.record('py_postpro_1')
        self.profile_.print_profile()

        # When using the A/B test, the tag of variant needs to be returned
        if need_variant_tag:
            return [ret, result_batch_handle.variant_tag()]
        return ret

    def release(self):
        """Destroy the native predictor, if any, and drop the handle.

        Safe to call when no config was loaded or after a previous release.
        """
        if self.client_handle_ is not None:
            self.client_handle_.destroy_predictor()
        self.client_handle_ = None