__init__.py 14.2 KB
Newer Older
G
guru4elephant 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
# pylint: disable=doc-string-missing
G
guru4elephant 已提交
15

M
MRXLT 已提交
16 17
import paddle_serving_client
import os
18 19 20
from .proto import sdk_configure_pb2 as sdk
from .proto import general_model_config_pb2 as m_config
import google.protobuf.text_format
D
dongdaxiang 已提交
21 22
import numpy as np
import time
23
import sys
24
from .serving_client import PredictorRes
G
guru4elephant 已提交
25

G
guru4elephant 已提交
26 27 28
# Type codes compared against ``feed_type`` / ``fetch_type`` from the model
# config: 0 selects the int64 path and 1 the float path in ``Client.predict``.
int_type, float_type = 0, 1

M
MRXLT 已提交
29

W
WangXi 已提交
30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57
class _NOPProfiler(object):
    """Profiler stand-in selected when client profiling is disabled.

    Offers the same ``record`` / ``print_profile`` interface as the timing
    profiler, but every call does nothing and returns ``None``.
    """

    def record(self, name):
        """Ignore the checkpoint ``name``."""
        return None

    def print_profile(self):
        """Emit nothing."""
        return None


class _TimeProfiler(object):
    """Collect wall-clock checkpoints and dump them to stderr as one line.

    The buffer always starts with a ``PROFILE`` header that carries the
    current pid; each checkpoint is ``<name>:<microseconds since epoch>``.
    """

    def __init__(self):
        pid = os.getpid()
        self.pid = pid
        self.print_head = 'PROFILE\tpid:{}\t'.format(pid)
        self.time_record = [self.print_head]

    def record(self, name):
        """Append a ``name:<usec>`` checkpoint taken from ``time.time()``."""
        usec = int(round(time.time() * 1000000))
        self.time_record.append('{}:{} '.format(name, usec))

    def print_profile(self):
        """Write the buffered checkpoints as one line, then reset the buffer."""
        record = self.time_record
        record.append('\n')
        sys.stderr.write(''.join(record))
        self.time_record = [self.print_head]


def _read_int_env_flag(name, default=0):
    """Read a gflags-style ``FLAGS_*`` environment variable as an int.

    The same variables are forwarded to the native client through
    ``init_gflags`` with ``--tryfromenv``. There, boolean spellings such as
    ``true``/``false`` are presumably accepted, so a bare ``int()`` would
    crash this module at import time. Integer strings keep their integer
    value. Common "true" spellings map to 1, and anything else maps to 0.
    """
    raw = os.environ.get(name)
    if raw is None:
        return default
    try:
        return int(raw)
    except ValueError:
        if raw.strip().lower() in ('true', 't', 'yes', 'y', 'on'):
            return 1
        return 0


# A truthy FLAGS_profile_client selects the timing profiler; otherwise the
# no-op profiler is used.
_is_profile = _read_int_env_flag('FLAGS_profile_client', 0)
_Profiler = _TimeProfiler if _is_profile else _NOPProfiler


G
guru4elephant 已提交
58 59 60
class SDKConfig(object):
    """Build the ``SDKConf`` protobuf describing the predictor endpoints.

    Every variant registered with ``add_server_variant`` becomes one
    ``VariantConf`` whose naming cluster is a static ``list://`` URL. The
    variant weights are handed to the ``WeightedRandomRender`` router.
    """

    def __init__(self):
        self.sdk_desc = sdk.SDKConf()
        self.tag_list = []
        self.cluster_list = []
        self.variant_weight_list = []

    def add_server_variant(self, tag, cluster, variant_weight):
        """Register a variant.

        Args:
            tag: variant tag.
            cluster: one ``"host:port"``-style endpoint string, or an
                iterable of them.
            variant_weight: relative routing weight; converted with ``str``
                when the descriptor is generated.
        """
        self.tag_list.append(tag)
        self.cluster_list.append(cluster)
        self.variant_weight_list.append(variant_weight)

    @staticmethod
    def _cluster_url(cluster):
        """Return the ``list://a,b,...`` naming URL for ``cluster``.

        A bare string is treated as a single endpoint. Joining it directly
        would put a comma between every character.
        """
        if isinstance(cluster, str):
            cluster = [cluster]
        return "list://{}".format(",".join(cluster))

    @staticmethod
    def _join_weights(weights):
        """Join variant weights (any type) into the ``w1|w2|...`` form."""
        return "|".join(str(weight) for weight in weights)

    def gen_desc(self):
        """Fill ``self.sdk_desc`` from the registered variants and return it.

        Can be called repeatedly: the predictor list is rebuilt, not
        appended to.
        """
        predictor_desc = sdk.Predictor()
        predictor_desc.name = "general_model"
        predictor_desc.service_name = \
            "baidu.paddle_serving.predictor.general_model.GeneralModelService"
        predictor_desc.endpoint_router = "WeightedRandomRender"
        predictor_desc.weighted_random_render_conf.variant_weight_list = \
            self._join_weights(self.variant_weight_list)

        # tag_list and cluster_list are always appended together.
        for tag, cluster in zip(self.tag_list, self.cluster_list):
            variant_desc = sdk.VariantConf()
            variant_desc.tag = tag
            variant_desc.naming_conf.cluster = self._cluster_url(cluster)
            predictor_desc.variants.extend([variant_desc])

        # Rebuild rather than append so a second call (e.g. connecting
        # twice) does not register a duplicate "general_model" predictor.
        self.sdk_desc.ClearField("predictors")
        self.sdk_desc.predictors.extend([predictor_desc])

        default_conf = self.sdk_desc.default_variant_conf
        default_conf.tag = "default"

        connection_conf = default_conf.connection_conf
        connection_conf.connect_timeout_ms = 2000
        connection_conf.rpc_timeout_ms = 20000
        connection_conf.connect_retry_count = 2
        connection_conf.max_connection_per_host = 100
        connection_conf.hedge_request_timeout_ms = -1
        connection_conf.hedge_fetch_retry_count = 2
        connection_conf.connection_type = "pooled"

        naming_conf = default_conf.naming_conf
        naming_conf.cluster_filter_strategy = "Default"
        naming_conf.load_balance_strategy = "la"

        rpc_parameter = default_conf.rpc_parameter
        rpc_parameter.compress_type = 0
        rpc_parameter.package_size = 20
        rpc_parameter.protocol = "baidu_std"
        rpc_parameter.max_channel_per_request = 3

        return self.sdk_desc
G
guru4elephant 已提交
105

G
guru4elephant 已提交
106 107 108 109 110 111

class Client(object):
    """Python front end that drives the native ``PredictorClient``.

    Typical use: ``load_client_config`` (reads the text-format
    ``GeneralModelConfig``), then ``connect`` (optionally after
    ``add_variant``), then ``predict`` any number of times, then ``release``.
    """

    def __init__(self):
        self.feed_names_ = []
        self.fetch_names_ = []
        self.client_handle_ = None
        self.feed_shapes_ = {}
        self.feed_types_ = {}
        self.feed_names_to_idx_ = {}
        self.pid = os.getpid()
        self.predictor_sdk_ = None
        self.producers = []
        self.consumer = None
        self.profile_ = _Profiler()
        # Input-kind flags describing the most recent ``predict`` call; they
        # are reset at the start of every call.
        self.all_numpy_input = True
        self.has_numpy_input = False

    def load_client_config(self, path):
        """Load the client model config at ``path`` and init the native client.

        Populates feed/fetch names, feed types and shapes, the name->index and
        fetch name->type maps, the set of LoD tensor names, and the flattened
        element count of every non-LoD feed.
        """
        from .serving_client import PredictorClient
        model_conf = m_config.GeneralModelConfig()
        # Use a context manager so the config file handle is always closed.
        with open(path, 'r') as f:
            model_conf = google.protobuf.text_format.Merge(
                str(f.read()), model_conf)

        self.client_handle_ = PredictorClient()
        self.client_handle_.init(path)
        # Default FLAGS_max_body_size to 512 * 1024 * 1024 (presumably
        # bytes) unless the user already set it.
        if "FLAGS_max_body_size" not in os.environ:
            os.environ["FLAGS_max_body_size"] = str(512 * 1024 * 1024)
        read_env_flags = ["profile_client", "profile_server", "max_body_size"]
        self.client_handle_.init_gflags(
            [sys.argv[0]] + ["--tryfromenv=" + ",".join(read_env_flags)])

        self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
        self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
        # Rebuild every per-config table so loading another config does not
        # leave stale entries behind.
        self.feed_names_to_idx_ = {}
        self.feed_types_ = {}
        self.feed_shapes_ = {}
        self.fetch_names_to_type_ = {}
        self.fetch_names_to_idx_ = {}
        self.lod_tensor_set = set()
        self.feed_tensor_len = {}

        for i, var in enumerate(model_conf.feed_var):
            self.feed_names_to_idx_[var.alias_name] = i
            self.feed_types_[var.alias_name] = var.feed_type
            self.feed_shapes_[var.alias_name] = var.shape

            if var.is_lod_tensor:
                self.lod_tensor_set.add(var.alias_name)
            else:
                # Fixed-shape feeds are validated against this element count.
                counter = 1
                for dim in self.feed_shapes_[var.alias_name]:
                    counter *= dim
                self.feed_tensor_len[var.alias_name] = counter

        for i, var in enumerate(model_conf.fetch_var):
            self.fetch_names_to_idx_[var.alias_name] = i
            self.fetch_names_to_type_[var.alias_name] = var.fetch_type
            if var.is_lod_tensor:
                self.lod_tensor_set.add(var.alias_name)

    def add_variant(self, tag, cluster, variant_weight):
        """Register a server variant (tag, endpoint cluster, weight)."""
        if self.predictor_sdk_ is None:
            self.predictor_sdk_ = SDKConfig()
        self.predictor_sdk_.add_server_variant(tag, cluster,
                                               str(variant_weight))

    def connect(self, endpoints=None):
        """Create the native predictor from the registered variants.

        Args:
            endpoints: endpoint(s) used as a single default variant with
                weight 100 (presumably a list of ``"host:port"`` strings).
                They are ignored, with a printed notice, when variants were
                already added via ``add_variant``.

        Raises:
            SystemExit: neither ``endpoints`` nor any variant was provided.
        """
        if endpoints is None:
            if self.predictor_sdk_ is None:
                raise SystemExit(
                    "You must set the endpoints parameter or use add_variant function to create a variant."
                )
        else:
            if self.predictor_sdk_ is None:
                self.add_variant('default_tag_{}'.format(id(self)), endpoints,
                                 100)
            else:
                print(
                    "parameter endpoints({}) will not take effect, because you use the add_variant function.".
                    format(endpoints))
        sdk_desc = self.predictor_sdk_.gen_desc()
        self.client_handle_.create_predictor_by_desc(
            sdk_desc.SerializeToString())

    def get_feed_names(self):
        return self.feed_names_

    def get_fetch_names(self):
        return self.fetch_names_

    def shape_check(self, feed, key):
        """Validate the element count of a list-typed feed ``feed[key]``.

        LoD tensors are skipped. A plain list must hold exactly
        ``feed_tensor_len[key]`` elements, otherwise ``SystemExit`` is raised.
        numpy arrays are not size-checked here.
        """
        if key in self.lod_tensor_set:
            return
        value = feed[key]
        if isinstance(value, list) and len(value) != self.feed_tensor_len[key]:
            raise SystemExit("The shape of feed tensor {} not match.".format(
                key))

    def predict(self, feed=None, fetch=None, need_variant_tag=False):
        """Run one (optionally batched) inference request.

        Args:
            feed: a dict ``{feed_name: value}`` or a list of such dicts, one
                per sample. Within one call the values must be all lists or
                all numpy arrays.
            fetch: a fetch name or a list of fetch names. Names missing from
                the model config are ignored.
            need_variant_tag: also return the tag of the serving variant.

        Returns:
            ``None`` when the native call returns -1. Otherwise, a dict
            ``{fetch_name: numpy array}`` (plus ``"<name>.lod"`` entries for
            LoD outputs) when one engine answered, or
            ``{engine_name: that dict}`` for several engines. With
            ``need_variant_tag`` the result is ``[ret, variant_tag]``.

        Raises:
            ValueError: missing or ill-typed ``feed``/``fetch``, an unknown
                feed name, or no valid fetch name.
            SystemExit: list and numpy inputs mixed within one call, or a
                list feed of the wrong length.
        """
        self.profile_.record('py_prepro_0')

        if feed is None or fetch is None:
            raise ValueError("You should specify feed and fetch for prediction")

        if isinstance(fetch, str):
            fetch_list = [fetch]
        elif isinstance(fetch, list):
            fetch_list = fetch
        else:
            raise ValueError("Fetch only accepts string and list of string")

        if isinstance(feed, dict):
            feed_batch = [feed]
        elif isinstance(feed, list):
            feed_batch = feed
        else:
            raise ValueError("Feed only accepts dict and list of dict")

        # Keep only the fetch names the model actually exports.
        fetch_names = [key for key in fetch_list if key in self.fetch_names_]
        if len(fetch_names) == 0:
            raise ValueError(
                "Fetch names should not be empty or out of saved fetch list.")

        # The input-kind flags describe this request only. If they carried
        # over, a numpy call followed by a list call would wrongly raise the
        # "mixed inputs" error below.
        self.all_numpy_input = True
        self.has_numpy_input = False

        int_slot_batch = []
        float_slot_batch = []
        int_feed_names = []
        float_feed_names = []
        int_shape = []
        float_shape = []

        for i, feed_i in enumerate(feed_batch):
            int_slot = []
            float_slot = []
            for key in feed_i:
                if key not in self.feed_names_:
                    raise ValueError("Wrong feed name: {}.".format(key))
                self.shape_check(feed_i, key)
                value = feed_i[key]
                is_numpy = isinstance(value, np.ndarray)
                feed_type = self.feed_types_[key]
                if feed_type == int_type:
                    names, shapes, slot = int_feed_names, int_shape, int_slot
                elif feed_type == float_type:
                    names, shapes, slot = (float_feed_names, float_shape,
                                           float_slot)
                else:
                    # Feeds of any other type are not sent.
                    continue
                if i == 0:
                    # Feed names and shapes come from the first sample only.
                    names.append(key)
                    if is_numpy:
                        shapes.append(list(value.shape))
                    else:
                        shapes.append(self.feed_shapes_[key])
                slot.append(value)
                if is_numpy:
                    self.has_numpy_input = True
                else:
                    self.all_numpy_input = False
            int_slot_batch.append(int_slot)
            float_slot_batch.append(float_slot)

        self.profile_.record('py_prepro_1')
        self.profile_.record('py_client_infer_0')

        result_batch_handle = PredictorRes()
        if self.all_numpy_input:
            res = self.client_handle_.numpy_predict(
                float_slot_batch, float_feed_names, float_shape, int_slot_batch,
                int_feed_names, int_shape, fetch_names, result_batch_handle,
                self.pid)
        elif not self.has_numpy_input:
            res = self.client_handle_.batch_predict(
                float_slot_batch, float_feed_names, float_shape, int_slot_batch,
                int_feed_names, int_shape, fetch_names, result_batch_handle,
                self.pid)
        else:
            raise SystemExit(
                "Please make sure the inputs are all in list type or all in numpy.array type"
            )

        self.profile_.record('py_client_infer_1')
        self.profile_.record('py_postpro_0')

        if res == -1:
            return None

        multi_result_map = []
        model_engine_names = result_batch_handle.get_engine_names()
        for mi, engine_name in enumerate(model_engine_names):
            result_map = {}
            for name in fetch_names:
                fetch_type = self.fetch_names_to_type_[name]
                # Each getter yields a numpy array (py::array) that is then
                # reshaped in place.
                if fetch_type == int_type:
                    result_map[name] = result_batch_handle.get_int64_by_name(
                        mi, name)
                elif fetch_type == float_type:
                    result_map[name] = result_batch_handle.get_float_by_name(
                        mi, name)
                else:
                    continue
                result_map[name].shape = result_batch_handle.get_shape(mi,
                                                                       name)
                if name in self.lod_tensor_set:
                    result_map["{}.lod".format(
                        name)] = result_batch_handle.get_lod(mi, name)
            multi_result_map.append(result_map)

        if len(model_engine_names) == 1:
            # A single engine: return its result map directly.
            ret = multi_result_map[0]
        else:
            # Several engines: key each result map by its engine name.
            ret = dict(zip(model_engine_names, multi_result_map))

        self.profile_.record('py_postpro_1')
        self.profile_.print_profile()

        # For A/B tests the caller may also need the serving variant's tag.
        if need_variant_tag:
            return [ret, result_batch_handle.variant_tag()]
        return ret

    def release(self):
        """Destroy the native predictor and drop the client handle."""
        self.client_handle_.destroy_predictor()
        self.client_handle_ = None