__init__.py 14.6 KB
Newer Older
G
guru4elephant 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
# pylint: disable=doc-string-missing
G
guru4elephant 已提交
15

M
MRXLT 已提交
16 17
import paddle_serving_client
import os
18 19 20
from .proto import sdk_configure_pb2 as sdk
from .proto import general_model_config_pb2 as m_config
import google.protobuf.text_format
D
dongdaxiang 已提交
21 22
import numpy as np
import time
23
import sys
G
guru4elephant 已提交
24

G
guru4elephant 已提交
25 26 27
# Numeric type codes compared against each variable's feed_type / fetch_type
# from the model config: int_type selects the int64 path and float_type the
# float32 path when Client packs inputs and decodes outputs.
int_type = 0
float_type = 1

M
MRXLT 已提交
28

W
WangXi 已提交
29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56
class _NOPProfiler(object):
    """Profiler stand-in whose methods do nothing.

    It exposes the same two methods as the timing profiler so callers can
    invoke them unconditionally.
    """

    def record(self, name):
        """Accept a stage name and discard it."""
        return None

    def print_profile(self):
        """Emit nothing."""
        return None


class _TimeProfiler(object):
    """Collects named timestamps and dumps them to stderr as one line.

    Each dumped line starts with a header carrying the current process id,
    followed by space-terminated ``name:timestamp`` entries in microseconds.
    """

    def __init__(self):
        self.pid = os.getpid()
        self.print_head = 'PROFILE\tpid:{}\t'.format(self.pid)
        # The buffer always begins with the header so a dump is a plain join.
        self.time_record = [self.print_head]

    def record(self, name):
        """Append a ``name:timestamp`` entry using the current wall clock."""
        now_us = int(round(time.time() * 1000000))
        entry = '{}:{} '.format(name, now_us)
        self.time_record.append(entry)

    def print_profile(self):
        """Write the buffered entries to stderr and start a fresh buffer."""
        entries = self.time_record
        entries.append('\n')
        sys.stderr.write(''.join(entries))
        self.time_record = [self.print_head]


# Profiling is opt-in: a non-zero integer in the FLAGS_profile_client
# environment variable selects the timestamp-recording profiler, anything
# else (including an unset variable) selects the no-op one.
_is_profile = int(os.environ.get('FLAGS_profile_client', 0))
if _is_profile:
    _Profiler = _TimeProfiler
else:
    _Profiler = _NOPProfiler


G
guru4elephant 已提交
57 58 59
class SDKConfig(object):
    """Accumulates server variants and renders them into an SDK config message.

    Variants are registered with ``add_server_variant`` and turned into an
    ``sdk.SDKConf`` protobuf message by ``gen_desc``.
    """

    def __init__(self):
        self.sdk_desc = sdk.SDKConf()
        # Parallel lists: entry i of each list describes variant i.
        self.tag_list = []
        self.cluster_list = []
        self.variant_weight_list = []

    def add_server_variant(self, tag, cluster, variant_weight):
        """Register one variant.

        Args:
            tag: value stored as the variant's tag.
            cluster: iterable of endpoint strings, or a single string that
                already holds a comma-separated endpoint list.
            variant_weight: weight of the variant; converted with ``str``
                when the description is generated.
        """
        self.tag_list.append(tag)
        self.cluster_list.append(cluster)
        self.variant_weight_list.append(variant_weight)

    def gen_desc(self):
        """Build and return the SDK description for all registered variants.

        The method may be called repeatedly; each call rebuilds the predictor
        entry from the currently registered variants.

        Returns:
            The ``sdk.SDKConf`` message held in ``self.sdk_desc``.
        """
        predictor_desc = sdk.Predictor()
        predictor_desc.name = "general_model"
        predictor_desc.service_name = \
            "baidu.paddle_serving.predictor.general_model.GeneralModelService"
        predictor_desc.endpoint_router = "WeightedRandomRender"
        # str() lets callers register numeric weights directly; joining raw
        # ints would raise TypeError.
        predictor_desc.weighted_random_render_conf.variant_weight_list = \
            "|".join(str(weight) for weight in self.variant_weight_list)

        for tag, cluster in zip(self.tag_list, self.cluster_list):
            variant_desc = sdk.VariantConf()
            variant_desc.tag = tag
            # Joining a plain string would put a comma between every
            # character, so a string cluster is used verbatim.
            if isinstance(cluster, str):
                endpoints = cluster
            else:
                endpoints = ",".join(cluster)
            variant_desc.naming_conf.cluster = "list://{}".format(endpoints)
            predictor_desc.variants.extend([variant_desc])

        # Drop predictors left by an earlier call; without this every call
        # would append another "general_model" predictor to the same message.
        del self.sdk_desc.predictors[:]
        self.sdk_desc.predictors.extend([predictor_desc])

        default_conf = self.sdk_desc.default_variant_conf
        default_conf.tag = "default"

        connection = default_conf.connection_conf
        connection.connect_timeout_ms = 2000
        connection.rpc_timeout_ms = 20000
        connection.connect_retry_count = 2
        connection.max_connection_per_host = 100
        connection.hedge_request_timeout_ms = -1
        connection.hedge_fetch_retry_count = 2
        connection.connection_type = "pooled"

        naming = default_conf.naming_conf
        naming.cluster_filter_strategy = "Default"
        naming.load_balance_strategy = "la"

        rpc = default_conf.rpc_parameter
        rpc.compress_type = 0
        rpc.package_size = 20
        rpc.protocol = "baidu_std"
        rpc.max_channel_per_request = 3

        return self.sdk_desc
G
guru4elephant 已提交
104

G
guru4elephant 已提交
105 106 107 108 109 110

class Client(object):
    """Python wrapper around the native predictor client.

    Typical use: ``load_client_config`` to read the model's feed/fetch
    description, ``connect`` (optionally after ``add_variant``) to create the
    predictor, ``predict`` to run requests, and ``release`` to tear down.
    """

    def __init__(self):
        self.feed_names_ = []
        self.fetch_names_ = []
        self.client_handle_ = None
        self.result_handle_ = None
        self.feed_shapes_ = {}
        self.feed_types_ = {}
        self.feed_names_to_idx_ = {}
        # Created here as well as in load_client_config so that the helper
        # methods do not fail with AttributeError before a config is loaded.
        self.fetch_names_to_type_ = {}
        self.fetch_names_to_idx_ = {}
        self.lod_tensor_set = set()
        self.feed_tensor_len = {}
        self.rpath()
        self.pid = os.getpid()
        self.predictor_sdk_ = None
        self.producers = []
        self.consumer = None
        self.profile_ = _Profiler()
        # Per-request flags, reset at the start of every predict() call.
        self.all_numpy_input = True
        self.has_numpy_input = False

    def rpath(self):
        """Add the package's ``lib`` directory to ``LD_LIBRARY_PATH``.

        The directory is appended only when it is not already listed, so
        creating several clients does not keep growing the variable.
        """
        package_dir = os.path.dirname(paddle_serving_client.__file__)
        lib_path = os.path.join(package_dir, 'lib')
        ld_path = os.getenv('LD_LIBRARY_PATH')
        if ld_path is None:
            os.environ['LD_LIBRARY_PATH'] = lib_path
        elif lib_path not in ld_path.split(':'):
            os.environ['LD_LIBRARY_PATH'] = ld_path + ':' + lib_path

    def load_client_config(self, path):
        """Load the model description at ``path`` and create native handles.

        Reads the text-format ``GeneralModelConfig``, creates the predictor
        and result handles, initialises gflags from the environment and
        rebuilds all feed/fetch lookup tables.

        Args:
            path: path of the client-side model config file.
        """
        from .serving_client import PredictorClient
        from .serving_client import PredictorRes
        model_conf = m_config.GeneralModelConfig()
        # The context manager closes the file even if parsing fails.
        with open(path, 'r') as f:
            model_conf = google.protobuf.text_format.Merge(
                str(f.read()), model_conf)

        # load configuration here
        # get feed vars, fetch vars
        # get feed shapes, feed types
        # map feed names to index
        self.result_handle_ = PredictorRes()
        self.client_handle_ = PredictorClient()
        self.client_handle_.init(path)
        if "FLAGS_max_body_size" not in os.environ:
            os.environ["FLAGS_max_body_size"] = str(512 * 1024 * 1024)
        read_env_flags = ["profile_client", "profile_server", "max_body_size"]
        self.client_handle_.init_gflags(
            [sys.argv[0]] + ["--tryfromenv=" + ",".join(read_env_flags)])
        self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
        self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
        # Every table is rebuilt from scratch so that loading a second config
        # does not leave entries from the previous one behind.
        self.feed_names_to_idx_ = {}
        self.feed_types_ = {}
        self.feed_shapes_ = {}
        self.fetch_names_to_type_ = {}
        self.fetch_names_to_idx_ = {}
        self.lod_tensor_set = set()
        self.feed_tensor_len = {}

        for i, var in enumerate(model_conf.feed_var):
            self.feed_names_to_idx_[var.alias_name] = i
            self.feed_types_[var.alias_name] = var.feed_type
            self.feed_shapes_[var.alias_name] = var.shape

            if var.is_lod_tensor:
                self.lod_tensor_set.add(var.alias_name)
            else:
                # Expected flat length of a non-LoD feed: product of its dims.
                counter = 1
                for dim in self.feed_shapes_[var.alias_name]:
                    counter *= dim
                self.feed_tensor_len[var.alias_name] = counter
        for i, var in enumerate(model_conf.fetch_var):
            self.fetch_names_to_idx_[var.alias_name] = i
            self.fetch_names_to_type_[var.alias_name] = var.fetch_type
            if var.is_lod_tensor:
                self.lod_tensor_set.add(var.alias_name)

    def add_variant(self, tag, cluster, variant_weight):
        """Register a server variant, creating the SDK config on first use.

        Args:
            tag: tag of the variant.
            cluster: endpoints of the variant.
            variant_weight: weight of the variant; stored as a string.
        """
        if self.predictor_sdk_ is None:
            self.predictor_sdk_ = SDKConfig()
        self.predictor_sdk_.add_server_variant(tag, cluster,
                                               str(variant_weight))

    def connect(self, endpoints=None):
        """Create the predictor from the configured variants.

        Args:
            endpoints: endpoints used for a single default variant. Ignored
                (with a printed notice) when variants were already added via
                ``add_variant``.

        Raises:
            SystemExit: if neither ``endpoints`` nor a variant was supplied.
        """
        # check whether current endpoint is available
        # init from client config
        # create predictor here
        if endpoints is None:
            if self.predictor_sdk_ is None:
                raise SystemExit(
                    "You must set the endpoints parameter or use add_variant function to create a variant."
                )
        elif self.predictor_sdk_ is None:
            self.add_variant('default_tag_{}'.format(id(self)), endpoints,
                             100)
        else:
            print(
                "parameter endpoints({}) will not take effect, because you use the add_variant function.".
                format(endpoints))
        sdk_desc = self.predictor_sdk_.gen_desc()
        self.client_handle_.create_predictor_by_desc(
            sdk_desc.SerializeToString())

    def get_feed_names(self):
        """Return the list of feed variable alias names."""
        return self.feed_names_

    def get_fetch_names(self):
        """Return the list of fetch variable alias names."""
        return self.fetch_names_

    def shape_check(self, feed, key):
        """Check that ``feed[key]`` has the length implied by the config.

        LoD tensors are skipped.

        Raises:
            SystemExit: if the length of ``feed[key]`` differs from the
                product of the configured shape.
        """
        if key in self.lod_tensor_set:
            return
        if len(feed[key]) != self.feed_tensor_len[key]:
            raise SystemExit("The shape of feed tensor {} not match.".format(
                key))

    def _pack_feed_batch(self, feed_batch):
        """Split a batch of feed dicts into int and float slots.

        Names and shapes are collected from the first sample only. Also
        updates ``all_numpy_input`` / ``has_numpy_input`` for this batch.

        Returns:
            Tuple ``(float_slot_batch, float_feed_names, float_shape,
            int_slot_batch, int_feed_names, int_shape)``.
        """
        int_slot_batch = []
        float_slot_batch = []
        int_feed_names = []
        float_feed_names = []
        int_shape = []
        float_shape = []

        for i, feed_i in enumerate(feed_batch):
            int_slot = []
            float_slot = []
            for key in feed_i:
                if key not in self.feed_names_:
                    raise ValueError("Wrong feed name: {}.".format(key))
                value = feed_i[key]
                is_numpy = isinstance(value, np.ndarray)
                if not is_numpy:
                    self.shape_check(feed_i, key)

                feed_type = self.feed_types_[key]
                if feed_type == int_type:
                    slot, names, shapes = int_slot, int_feed_names, int_shape
                elif feed_type == float_type:
                    slot, names, shapes = (float_slot, float_feed_names,
                                           float_shape)
                else:
                    # Feeds of any other type are not sent.
                    continue

                if i == 0:
                    names.append(key)
                    if is_numpy:
                        shapes.append(list(value.shape))
                    else:
                        shapes.append(self.feed_shapes_[key])
                slot.append(value)
                if is_numpy:
                    self.has_numpy_input = True
                else:
                    self.all_numpy_input = False
            int_slot_batch.append(int_slot)
            float_slot_batch.append(float_slot)

        return (float_slot_batch, float_feed_names, float_shape,
                int_slot_batch, int_feed_names, int_shape)

    def _unpack_results(self, result_batch, fetch_names):
        """Convert the native result into numpy arrays.

        Returns:
            Tuple ``(engine_names, result_maps)`` where ``result_maps[i]``
            maps each fetch name to a numpy array for engine ``i`` and, for
            LoD tensors, "{name}.lod" to its LoD array.
        """
        engine_names = result_batch.get_engine_names()
        result_maps = []
        for mi, _engine_name in enumerate(engine_names):
            result_map = {}
            # result map needs to be a numpy array
            for name in fetch_names:
                fetch_type = self.fetch_names_to_type_[name]
                if fetch_type == int_type:
                    data = result_batch.get_int64_by_name(mi, name)
                    dtype = 'int64'
                elif fetch_type == float_type:
                    data = result_batch.get_float_by_name(mi, name)
                    dtype = 'float32'
                else:
                    # Fetch variables of any other type are left out.
                    continue
                shape = result_batch.get_shape(mi, name)
                tensor = np.array(data, dtype=dtype)
                tensor.shape = shape
                result_map[name] = tensor
                if name in self.lod_tensor_set:
                    result_map["{}.lod".format(name)] = np.array(
                        result_batch.get_lod(mi, name))
            result_maps.append(result_map)
        return engine_names, result_maps

    def predict(self, feed=None, fetch=None, need_variant_tag=False):
        """Run one prediction request.

        Args:
            feed: a dict mapping feed names to values, or a list of such
                dicts. Values must be either all numpy arrays or all
                non-numpy sequences.
            fetch: a fetch name or a list of fetch names; names unknown to
                the loaded config are ignored.
            need_variant_tag: when true, the variant tag reported by the
                result handle is returned alongside the result.

        Returns:
            ``None`` if the native call returns -1. Otherwise a dict mapping
            fetch names to numpy arrays when a single engine answered, or a
            dict mapping engine names to such dicts when several did. With
            ``need_variant_tag`` the return value is ``[result, tag]``.

        Raises:
            ValueError: on missing or wrongly typed ``feed`` / ``fetch``, when
                no requested fetch name is known, or on an unknown feed name.
            SystemExit: on a feed length mismatch or when numpy and non-numpy
                inputs are mixed.
        """
        self.profile_.record('py_prepro_0')

        if feed is None or fetch is None:
            raise ValueError("You should specify feed and fetch for prediction")

        if isinstance(fetch, str):
            fetch_list = [fetch]
        elif isinstance(fetch, list):
            fetch_list = fetch
        else:
            raise ValueError("Fetch only accepts string and list of string")

        if isinstance(feed, dict):
            feed_batch = [feed]
        elif isinstance(feed, list):
            feed_batch = feed
        else:
            raise ValueError("Feed only accepts dict and list of dict")

        fetch_names = [key for key in fetch_list if key in self.fetch_names_]
        if len(fetch_names) == 0:
            raise ValueError(
                "Fetch names should not be empty or out of saved fetch list.")

        # Reset the input-kind flags for this request. They live on the
        # instance, so without the reset one list-based call would make every
        # later numpy-based call look like a mixed-type request.
        self.all_numpy_input = True
        self.has_numpy_input = False

        (float_slot_batch, float_feed_names, float_shape, int_slot_batch,
         int_feed_names, int_shape) = self._pack_feed_batch(feed_batch)

        self.profile_.record('py_prepro_1')
        self.profile_.record('py_client_infer_0')

        result_batch = self.result_handle_
        if self.all_numpy_input:
            res = self.client_handle_.numpy_predict(
                float_slot_batch, float_feed_names, float_shape, int_slot_batch,
                int_feed_names, int_shape, fetch_names, result_batch, self.pid)
        elif not self.has_numpy_input:
            res = self.client_handle_.batch_predict(
                float_slot_batch, float_feed_names, float_shape, int_slot_batch,
                int_feed_names, int_shape, fetch_names, result_batch, self.pid)
        else:
            raise SystemExit(
                "Please make sure the inputs are all in list type or all in numpy.array type"
            )

        self.profile_.record('py_client_infer_1')
        self.profile_.record('py_postpro_0')

        if res == -1:
            return None

        model_engine_names, multi_result_map = self._unpack_results(
            result_batch, fetch_names)
        if len(model_engine_names) == 1:
            # If only one model result is returned, the format of ret is result_map
            ret = multi_result_map[0]
        else:
            # If multiple model results are returned, the format of ret is {name: result_map}
            ret = dict(zip(model_engine_names, multi_result_map))

        self.profile_.record('py_postpro_1')
        self.profile_.print_profile()

        # When using the A/B test, the tag of variant needs to be returned
        if need_variant_tag:
            return [ret, self.result_handle_.variant_tag()]
        return ret

    def release(self):
        """Destroy the native predictor, if any, and drop the handle.

        Safe to call before a config was loaded and safe to call twice.
        """
        if self.client_handle_ is not None:
            self.client_handle_.destroy_predictor()
        self.client_handle_ = None